[
  {
    "path": ".gitignore",
    "content": ".idea\n__pycache__/\n*.DS_Store\n*.pth\n*.pt\n*.mp4\n*.npy\nvis_results/\ncheckpoints/\nlogs/\nslurm_logs/\nsubmit*\nlogs*\n\n/running\n/datasets\n/env.sh\n/eular_log\n/outputs\n/wandb"
  },
  {
    "path": "README.md",
    "content": "<div align=\"center\" style=\"line-height:1.2; margin:0; padding:0;\">\n<h1 style=\"margin-bottom:0em;\">Multi-View 3D Point Tracking</h1>\n\n<a href=\"https://arxiv.org/abs/2508.21060\"><img src=\"https://img.shields.io/badge/arXiv-2508.21060-b31b1b\" alt=\"arXiv\"></a>\n<a href=\"https://ethz-vlg.github.io/mvtracker/\"><img src=\"https://img.shields.io/badge/Project%20Page-009688?logo=internetcomputer&logoColor=white\" alt=\"Project Page\"></a>\n<a href=\"https://ethz-vlg.github.io/mvtracker/#qualitative-visualization\"><img src=\"https://img.shields.io/badge/Interactive%20Results-673ab7?logo=apachespark&logoColor=white\" alt=\"Interactive Results\"></a>\n[![](https://img.shields.io/badge/🤗%20Demo-Coming%20soon…-ffcc00)](#)\n<br>\n[**Frano Rajič**](https://m43.github.io/)<sup>1</sup> · \n[**Haofei Xu**](https://haofeixu.github.io/)<sup>1</sup> · \n[**Marko Mihajlovic**](https://markomih.github.io/)<sup>1</sup> · \n[**Siyuan Li**](https://siyuanliii.github.io/)<sup>1</sup> · \n[**Irem Demir**](https://github.com/iremddemir)<sup>1</sup>  \n[**Emircan Gündoğdu**](https://github.com/emircangun)<sup>1</sup> · \n[**Lei Ke**](https://www.kelei.site/)<sup>2</sup> · \n[**Sergey Prokudin**](https://vlg.inf.ethz.ch/team/Dr-Sergey-Prokudin.html)<sup>1,3</sup> · \n[**Marc Pollefeys**](https://people.inf.ethz.ch/marc.pollefeys/)<sup>1,4</sup> · \n[**Siyu Tang**](https://vlg.inf.ethz.ch/team/Prof-Dr-Siyu-Tang.html)<sup>1</sup>\n<br>\n<sup>1</sup>[ETH Zürich](https://vlg.inf.ethz.ch/) &emsp;\n<sup>2</sup>[Carnegie Mellon University](https://www.cmu.edu/) &emsp;\n<sup>3</sup>[Balgrist University Hospital](https://www.balgrist.ch/) &emsp;\n<sup>4</sup>[Microsoft](https://www.microsoft.com/)\n</div>\n\n<p float=\"left\">\n  <img alt=\"selfcap\" src=\"https://github.com/user-attachments/assets/b502d193-c37c-43be-af6c-653b5de7597e\" width=\"48%\" /> \n  <img alt=\"dexycb\" src=\"https://github.com/user-attachments/assets/d14d4c6c-152e-4040-b29b-3da4b7e8b913\" width=\"48%\" /> \n  <img alt=\"4d-dress-stretching\" src=\"https://github.com/user-attachments/assets/f3eabdda-59e1-4032-b345-c4603ea86fc0\" width=\"48%\" />\n  <img alt=\"4d-dress-avatarmove\" src=\"https://github.com/user-attachments/assets/3fef9924-84ad-4295-95e2-5b82ae7c3053\" width=\"48%\" />\n</p>\n\nMVTracker is the first **data-driven multi-view 3D point tracker** for tracking arbitrary 3D points across multiple cameras. It fuses multi-view features into a unified 3D feature point cloud, within which it leverages kNN-based correlation to capture spatiotemporal relationships across views. A transformer then iteratively refines the point tracks, handling occlusions and adapting to varying camera setups without per-sequence optimization.\n\n\n## Updates\n\n- <ins>August 28, 2025</ins>: Public release.\n\n\n## Quick Start\n\nThis repo was validated on **Python 3.10.12**, **PyTorch 2.3.0** (CUDA 12.1), **cuDNN 8903**, and **gcc 11.3.0**. 
If you want a fresh minimal environment that runs the Hub demo and `demo.py`:\n```bash\nconda create -n 3dpt python=3.10.12 -y\nconda activate 3dpt\nconda install pytorch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 pytorch-cuda=12.1 -c pytorch -c nvidia -y\npip install -r https://raw.githubusercontent.com/ethz-vlg/mvtracker/refs/heads/main/requirements.txt\n\n# Optional, speeds up the model\npip install --upgrade --no-build-isolation flash-attn==2.5.8  # Speeds up attention\npip install \"git+https://github.com/ethz-vlg/pointcept.git@2082918#subdirectory=libs/pointops\"  # Speeds up kNN search; may require gcc 11.3.0: conda install -c conda-forge gcc_linux-64=11.3.0 gxx_linux-64=11.3.0 gcc=11.3.0 gxx=11.3.0\n```\n\nWith the minimal dependencies in place, you can try MVTracker directly via **PyTorch Hub**:\n```python\nimport torch\nimport numpy as np\nfrom huggingface_hub import hf_hub_download\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\nmvtracker = torch.hub.load(\"ethz-vlg/mvtracker\", \"mvtracker\", pretrained=True, device=device)\n\n# Example input from demo sample (downloaded automatically)\nsample = np.load(hf_hub_download(\"ethz-vlg/mvtracker\", \"data_sample.npz\"))\nrgbs = torch.from_numpy(sample[\"rgbs\"]).float()\ndepths = torch.from_numpy(sample[\"depths\"]).float()\nintrs = torch.from_numpy(sample[\"intrs\"]).float()\nextrs = torch.from_numpy(sample[\"extrs\"]).float()\nquery_points = torch.from_numpy(sample[\"query_points\"]).float()\n\nwith torch.no_grad():\n    results = mvtracker(\n        rgbs=rgbs[None].to(device) / 255.0,\n        depths=depths[None].to(device),\n        intrs=intrs[None].to(device),\n        extrs=extrs[None].to(device),\n        query_points_3d=query_points[None].to(device),\n    )\n\npred_tracks = results[\"traj_e\"].cpu()  # [T,N,3]\npred_vis = results[\"vis_e\"].cpu()      # [T,N]\nprint(pred_tracks.shape, pred_vis.shape)\n```\n\nAlternatively, you can run our interactive demo:\n\n```bash\npython demo.py --rerun save --lightweight\n```\n\nBy default this saves a lightweight `.rrd` recording (e.g., `mvtracker_demo.rrd`) that you can open in any Rerun viewer. The simplest option is to drag and drop the file into the [online viewer](https://app.rerun.io/version/0.21.0). For the best experience, you can also install Rerun locally (`pip install rerun-sdk==0.21.0; rerun`). Results can be explored interactively in the viewer with WASD/QE navigation, mouse rotation and zoom, and timeline playback controls.\n\n<details>\n<summary>[Interactive viewer on a cluster or with GUI support - click to expand]</summary>\n  \nIf you are working on a cluster, you can stream results directly to your laptop by forwarding a port (`ssh -R 9876:localhost:9876 user@cluster`) and then running the demo in streaming mode (`python demo.py --rerun stream`), which sends live data into your local Rerun instance. If you are running the demo locally with GUI support, you can automatically spawn a Rerun window (`python demo.py --rerun spawn`).\n\n</details>\n\n\n## Installation\n\nYou can use a pretrained model directly via **PyTorch Hub** (see Quick Start above), or clone this repo if you want to run our demo, evaluation, or training. We recommend using **PyTorch with CUDA** for best performance. 
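The Hub entry point accepts a `device` argument, so a machine without a GPU can still load the predictor (a minimal sketch of the same call as in the Quick Start):\n```python\nimport torch\n\nmvtracker = torch.hub.load(\"ethz-vlg/mvtracker\", \"mvtracker\", pretrained=True, device=\"cpu\")\n```\n\n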
CPU-only runs are possible but very slow.\n\n```bash\ngit clone https://github.com/ethz-vlg/mvtracker.git\ncd mvtracker\n```\n\nTo extend the conda environment from the Quick Start to support training and evaluation, install the full requirements by running `pip install -r requirements.full.txt`. Baselines based on SpatialTracker V1 also require cupy:\n```bash\npip install tensorflow==2.12.1 tensorflow-datasets tensorflow-graphics tensorboard\npip install cupy-cuda12x==12.2.0\npython -m cupyx.tools.install_library --cuda 12.x --library cutensor\npython -m cupyx.tools.install_library --cuda 12.x --library nccl\npython -m cupyx.tools.install_library --cuda 12.x --library cudnn\n```\n\n\n## Datasets\n\nTo benchmark multi-view 3D point tracking, we provide preprocessed versions of three datasets:\n\n- **MV-Kubric**: a synthetic training dataset adapted from single-view Kubric into a multi-view setting.  \n- **Panoptic Studio**: evaluation benchmark with real-world activities such as basketball, juggling, and toy play (10 sequences).  \n- **DexYCB**: evaluation benchmark with real-world hand–object interactions (10 sequences).  \n\n<details>\n<summary>[Downloading our preprocessed datasets - click to expand]</summary>\n  \nYou can download and extract them as (~72 GB after extraction):\n\n```bash\n# MV-Kubric (simulated + DUSt3R depths)\nwget https://huggingface.co/datasets/ethz-vlg/mv3dpt-datasets/resolve/main/kubric-multiview--test.tar.gz -P datasets/\nwget https://huggingface.co/datasets/ethz-vlg/mv3dpt-datasets/resolve/main/kubric-multiview--test--dust3r-depth.tar.gz -P datasets/\ntar -xvzf datasets/kubric-multiview--test.tar.gz -C datasets/\ntar -xvzf datasets/kubric-multiview--test--dust3r-depth.tar.gz -C datasets/\nrm datasets/kubric-multiview*.tar.gz\n\n# Panoptic Studio (optimization-based depth from Dynamic3DGS)\nwget https://huggingface.co/datasets/ethz-vlg/mv3dpt-datasets/resolve/main/panoptic-multiview.tar.gz -P datasets/\ntar -xvzf datasets/panoptic-multiview.tar.gz -C datasets/\nrm datasets/panoptic-multiview.tar.gz\n\n# DexYCB (Kinect + DUSt3R depths)\nwget https://huggingface.co/datasets/ethz-vlg/mv3dpt-datasets/resolve/main/dex-ycb-multiview.tar.gz -P datasets/\nwget https://huggingface.co/datasets/ethz-vlg/mv3dpt-datasets/resolve/main/dex-ycb-multiview--dust3r-depth.tar.gz -P datasets/\ntar -xvzf datasets/dex-ycb-multiview.tar.gz -C datasets/\ntar -xvzf datasets/dex-ycb-multiview--dust3r-depth.tar.gz -C datasets/\nrm datasets/dex-ycb-multiview*.tar.gz\n\n# $ du -sch datasets/*\n# 31G     kubric-multiview\n# 13G     panoptic-multiview\n# 29G     dex-ycb-multiview\n# 72G     total\n```\n\n</details>\n\n\n<details>\n<summary>[Regenerating datasets from scratch - click to expand]</summary>\n  \nIf you wish to regenerate datasets from scratch, we provide scripts with docstrings that explain usage and list the commands we used. For licensing and usage terms, please refer to the original datasets. 
\n- MV-Kubric data for training and testing can be generated with [ethz-vlg/kubric](https://github.com/ethz-vlg/kubric/blob/multiview-point-tracking/challenges/point_tracking_3d/worker.py).\n- DexYCB can be downloaded and labels regenerated using [`scripts/dex_ycb_to_neus_format.py`](./scripts/dex_ycb_to_neus_format.py); note that we have created labels for 10 sequences, but DexYCB is much larger and more labels could be produced if needed.\n- Panoptic Studio can be downloaded and labels regenerated using [`scripts/panoptic_studio_preprocessing.py`](./scripts/panoptic_studio_preprocessing.py).\n- DUSt3R depths can be produced for any dataset with [`scripts/estimate_depth_with_duster.py`](./scripts/estimate_depth_with_duster.py).\n- For unlabeled datasets used only in qualitative experiments, we provide the following preprocessing scripts: [4D-Dress](./scripts/4ddress_preprocessing.py), [Hi4D](./scripts/hi4d_preprocessing.py), [EgoExo4D](./scripts/egoexo4d_preprocessing.py), and [SelfCap](./scripts/selfcap_preprocessing.py).  \n\n</details>\n\nFor quick testing, we also release a small **demo sample** (~200 MB):\n\n```bash\npython demo.py --random_query_points\n```\n\nOur generic loader [`GenericSceneDataset`](./mvtracker/datasets/generic_scene_dataset.py) supports adding new datasets. It can compute depths on the fly with [DUSt3R](https://github.com/naver/dust3r), [VGGT](https://vgg-t.github.io), [MonoFusion](https://imnotprepared.github.io/research/25_DSR/index.html), or [MoGe-2](https://github.com/microsoft/MoGe), and can also estimate camera poses with VGGT.  \n\n\n\n## Evaluation\n\nEvaluation is driven by Hydra configs. See [`mvtracker/cli/eval.py`](./mvtracker/cli/eval.py) and [`configs/eval.yaml`](./configs/eval.yaml) for details.\n\nTo evaluate MVTracker with our best model, first download the checkpoint from [Hugging Face](https://huggingface.co/ethz-vlg/mvtracker):\n\n```bash\nwget https://huggingface.co/ethz-vlg/mvtracker/resolve/main/mvtracker_200000_june2025.pth -P checkpoints/\n```\n\nThen run:\n\n```bash\npython -m mvtracker.cli.eval \\\n  experiment_path=logs/mvtracker \\\n  model=mvtracker \\\n  datasets.eval.names=[kubric-multiview-v3-views0123] \\\n  restore_ckpt_path=checkpoints/mvtracker_200000_june2025.pth\n\n# Expected result:\n# {\n#   \"eval_kubric-multiview-v3-views0123/model__ate_visible__dynamic-static-mean\": 5.07,\n#   \"eval_kubric-multiview-v3-views0123/model__average_jaccard__dynamic-static-mean\": 81.42,\n#   \"eval_kubric-multiview-v3-views0123/model__average_pts_within_thresh__dynamic-static-mean\": 90.00\n# }\n```\n\nTo evaluate a baseline, e.g. CoTracker3-Online (auto-downloaded checkpoint), run:\n\n```bash\npython -m mvtracker.cli.eval experiment_path=logs/cotracker3-online model=cotracker3_online\n\n# Expected result:\n# {\n#   \"eval_panoptic-multiview-views1_7_14_20/model__average_jaccard__any\": 74.56\n# }\n```\n\nFor more baselines and dataset setups (e.g. varying camera counts, camera subsets, etc.), see [`scripts/slurm/eval.sh`](./scripts/slurm/eval.sh) for the commands used in our experiments.\n\n<details>\n<summary>[Details on evaluation parameters - click to expand]</summary>\n  \nThe evaluation datasets are specified with `datasets.eval.names`. Each name is parsed by the dataset `from_name()` factory (see e.g. [`DexYCBMultiViewDataset.from_name`](./mvtracker/datasets/dexycb_multiview_dataset.py)), which supports modifiers such as `-views`, `-duster`, `-novelviews`, `-removehand`, `-2dpt`, or `-cached`. 
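For instance, a sketch of building a dataset from such a name (the exact `from_name` signature is defined in the linked dataset file):\n```python\nfrom mvtracker.datasets.dexycb_multiview_dataset import DexYCBMultiViewDataset\n\n# Cameras 0-3 with DUSt3R-estimated depths instead of sensor depths\ndataset = DexYCBMultiViewDataset.from_name(\"dex-ycb-multiview-duster0123\", \"./datasets\")\n```\n\n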
This makes it easy to select subsets of cameras, enable different depth sources, or ensure deterministic track sampling. The main labeled benchmarks are:\n- **Kubric (synthetic)** — e.g. `kubric-multiview-v3-views0123`  \n- **Panoptic Studio (real)** — e.g. `panoptic-multiview-views1_7_14_20`  \n- **DexYCB (real)** — e.g. `dex-ycb-multiview-views0123`  \n\nFor reproducibility of our main results, we also provide *cached* variants of each benchmark, which freeze track selection exactly as used in our paper. Without `-cached`, random seeding ensures reproducibility, but cached versions guarantee identical tracks across environments. The following cached variants are included in the released datasets:\n- `kubric-multiview-v3-views0123-cached`  \n- `kubric-multiview-v3-duster0123-cached`  \n- `panoptic-multiview-views1_7_14_20-cached`  \n- `panoptic-multiview-views27_16_14_8-cached`  \n- `panoptic-multiview-views1_4_7_11-cached`  \n- `dex-ycb-multiview-views0123-cached`  \n- `dex-ycb-multiview-duster0123-cached`  \n\n</details>\n\n## Training\n\nTo run a small overfitting test that fits into 24 GB GPU RAM:\n\n```bash\npython -m mvtracker.cli.train +experiment=mvtracker_overfit_mini\n```\n\nFor a full-scale MVTracker on an 80 GB GPU:\n\n```bash\npython -m mvtracker.cli.train +experiment=mvtracker_overfit\n```\n\n## Practical Considerations\n\n<details>\n<summary>[Scene normalization - click to expand]</summary>\n\nPerformance depends strongly on scene normalization. MVTracker was trained on Kubric with randomized but bounded scales and camera setups. At test time, scenes with very different scales, rotations, or translations must be aligned to this distribution. Our generic loader provides an automatic normalization that assumes the ground plane is parallel to the XY plane. This automatic normalization worked reasonably well for 4D-Dress, Hi4D, EgoExo4D, and SelfCap. For Panoptic and DexYCB, we applied manual similarity transforms, which are encoded in the respective dataloaders. Robust, general-purpose normalization remains an open challenge.  \n\n</details>\n\n\n<details>\n<summary>[Challenges and future directions - click to expand]</summary>\n\nThe central challenge in multi-view 3D point tracking is 4D reconstruction: obtaining depth maps that are accurate, temporally consistent, and available in real time, especially under sparse-view setups. MVTracker performs well when sensor depth and camera calibration are provided, but in settings where both must be estimated, errors in reconstruction quickly make tracking unreliable. While learned motion priors help tolerate moderate noise, they cannot replace a robust reconstruction backbone. We believe progress will hinge on methods that jointly solve depth estimation and tracking for mutual refinement, or large-scale foundation models for 4D reconstruction and tracking that fully leverage data and compute. We hope the community will direct future efforts toward this goal.\n\n\n</details>\n\n\n## Acknowledgements\n\nOur code builds upon and was inspired by many prior works, including [SpaTracker](https://github.com/henry123-boy/SpaTracker), [CoTracker](https://github.com/facebookresearch/co-tracker), and [DUSt3R](https://github.com/naver/dust3r). We thank the authors for releasing their code and pretrained models. 
We are also grateful to the maintainers of [Rerun](https://rerun.io) for their helpful visualization toolkit.\n\n## Citation\n\nIf you find our repository useful, please consider giving it a star ⭐ and citing our work:\n```bibtex\n@inproceedings{rajic2025mvtracker,\n  title     = {Multi-View 3D Point Tracking},\n  author    = {Raji{\\v{c}}, Frano and Xu, Haofei and Mihajlovic, Marko and Li, Siyuan and Demir, Irem and G{\\\"u}ndo{\\u{g}}du, Emircan and Ke, Lei and Prokudin, Sergey and Pollefeys, Marc and Tang, Siyu},\n  booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},\n  year      = {2025}\n}\n```\n"
  },
  {
    "path": "configs/eval.yaml",
    "content": "defaults:\n  - train\n  - _self_\n\nmodes:\n  eval_only: true\n\ntrainer:\n  precision: 32-true\n\n# Optional overrides specific to evaluation runs\ndatasets:\n  eval:\n    names: [ \"panoptic-multiview-views1_7_14_20\" ]\n    max_seq_len: 1000\nevaluation:\n  consume_model_stats: false        # whether to report model stats (which can slow down the forward pass)\n  evaluator:\n    rerun_viz_indices: null\n    forward_pass_log_indices: null\n    mp4_track_viz_indices: null\n\n#    rerun_viz_indices: [ 0,1,2 ]\n#    forward_pass_log_indices: [ 0,1,2 ]\n#    mp4_track_viz_indices: [ 0,1,2 ]\n\n#    rerun_viz_indices: [ 0,3,27, 2,23 ]\n#    forward_pass_log_indices: null\n#    mp4_track_viz_indices: [ 0,3,27, 2,23 ]\n\n#    rerun_viz_indices: [ 0, 7 ]\n#    forward_pass_log_indices: null\n#    mp4_track_viz_indices: [ 0, 7 ]\n\n#    rerun_viz_indices: [ 0, 5 ]\n#    forward_pass_log_indices: null\n#    mp4_track_viz_indices: [ 0, 5 ]\n\n#    rerun_viz_indices: [ 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29 ]\n#    forward_pass_log_indices: [ 0,1,2,3,4 ]\n#    mp4_track_viz_indices: [ 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29 ]\n"
  },
  {
    "path": "configs/experiment/mvtracker.yaml",
    "content": "# @package _global_\ndefaults:\n  - override /model: mvtracker\n\nexperiment_path: ./logs/mvtracker\n"
  },
  {
    "path": "configs/experiment/mvtracker_overfit.yaml",
    "content": "# @package _global_\ndefaults:\n  - override /model: mvtracker\n\nexperiment_path: ./logs/debug/mvtracker-overfit\n\ndatasets:\n  root: ./datasets\n  train:\n    name: kubric-multiview-v3-views0123-training\n    batch_size: 1\n    sequence_len: 24\n    traj_per_sample: 512\n    num_workers: 4\n  eval:\n    names: [kubric-multiview-v3-views0123-overfit-on-training]\n    num_workers: 2\n    max_seq_len: 1000\n\ntrainer:\n  num_steps: 1500\n  eval_freq: 500\n  viz_freq: 500\n  save_ckpt_freq: 500\n  augment_train_iters: false\n\naugmentations:\n  probability: 1.0\n  rgb: false\n  depth: false\n  cropping: true\n  variable_trajpersample: false\n  scene_transform: false\n  camera_params_noise: false\n  variable_depth_type: false\n  variable_num_views: false\n\nmodes:\n  tune_per_scene: true\n  dont_validate_at_start: true\n  do_initial_static_pretrain: false\n  pretrain_only: false\n  eval_only: false\n  debug: false\n"
  },
  {
    "path": "configs/experiment/mvtracker_overfit_mini.yaml",
    "content": "# @package _global_\ndefaults:\n  - mvtracker_overfit\n\nexperiment_path: ./logs/debug/mvtracker-overfit-mini\n\ndatasets:\n  train:\n    traj_per_sample: 8\n\nmodel:\n  fmaps_dim: 32\n"
  },
  {
    "path": "configs/model/copycat.yaml",
    "content": "# @package _global_\ndefaults:\n  - default\n\nmodel:\n  _target_: mvtracker.models.core.copycat.CopyCat\n"
  },
  {
    "path": "configs/model/cotracker1_offline.yaml",
    "content": "# @package _global_\ndefaults:\n  - default\n\nmodel:\n  _target_: mvtracker.models.core.monocular_baselines.MonocularToMultiViewAdapter\n  model:\n    _target_: mvtracker.models.core.monocular_baselines.CoTrackerOfflineWrapper\n    model_name: cotracker2v1\n    grid_size: 10\n"
  },
  {
    "path": "configs/model/cotracker1_online.yaml",
    "content": "# @package _global_\ndefaults:\n  - default\n\nmodel:\n  _target_: mvtracker.models.core.monocular_baselines.MonocularToMultiViewAdapter\n  model:\n    _target_: mvtracker.models.core.monocular_baselines.CoTrackerOnlineWrapper\n    model_name: cotracker2v1_online\n    grid_size: 10\n"
  },
  {
    "path": "configs/model/cotracker2_offline.yaml",
    "content": "# @package _global_\ndefaults:\n  - default\n\nmodel:\n  _target_: mvtracker.models.core.monocular_baselines.MonocularToMultiViewAdapter\n  model:\n    _target_: mvtracker.models.core.monocular_baselines.CoTrackerOfflineWrapper\n    model_name: cotracker2\n    grid_size: 10\n"
  },
  {
    "path": "configs/model/cotracker2_online.yaml",
    "content": "# @package _global_\ndefaults:\n  - default\n\nmodel:\n  _target_: mvtracker.models.core.monocular_baselines.MonocularToMultiViewAdapter\n  model:\n    _target_: mvtracker.models.core.monocular_baselines.CoTrackerOnlineWrapper\n    model_name: cotracker2_online\n    grid_size: 10\n"
  },
  {
    "path": "configs/model/cotracker3_offline.yaml",
    "content": "# @package _global_\ndefaults:\n  - default\n\nmodel:\n  _target_: mvtracker.models.core.monocular_baselines.MonocularToMultiViewAdapter\n  model:\n    _target_: mvtracker.models.core.monocular_baselines.CoTrackerOfflineWrapper\n    model_name: cotracker3_offline\n    grid_size: 10\n"
  },
  {
    "path": "configs/model/cotracker3_online.yaml",
    "content": "# @package _global_\ndefaults:\n  - default\n\nmodel:\n  _target_: mvtracker.models.core.monocular_baselines.MonocularToMultiViewAdapter\n  model:\n    _target_: mvtracker.models.core.monocular_baselines.CoTrackerOnlineWrapper\n    model_name: cotracker3_online\n    grid_size: 10\n"
  },
  {
    "path": "configs/model/default.yaml",
    "content": "# @package _global_\nmodel:\n  _target_: ???\n\ntrainer:\n  train_iters: 4\n\nevaluation:\n  eval_iters: 4\n  interp_shape: null\n\n  predictor_settings:\n    kubric:\n      visibility_threshold: 0.9\n      grid_size: 0\n      n_grids_per_view: 1\n      local_grid_size: 0\n      local_extent: 50\n      sift_size: 0\n      num_uniformly_sampled_pts: 0\n    dex_ycb:\n      visibility_threshold: 0.9\n      grid_size: 0\n      n_grids_per_view: 1\n      local_grid_size: 0\n      local_extent: 50\n      sift_size: 0\n      num_uniformly_sampled_pts: 0\n    panoptic:\n      visibility_threshold: 0.9\n      grid_size: 0\n      n_grids_per_view: 1\n      local_grid_size: 0\n      local_extent: 50\n      sift_size: 0\n      num_uniformly_sampled_pts: 0\n    tapvid2d-davis:\n      visibility_threshold: 0.9\n      grid_size: 0\n      n_grids_per_view: 1\n      local_grid_size: 0\n      local_extent: 50\n      sift_size: 0\n      num_uniformly_sampled_pts: 0\n    generic:\n      visibility_threshold: 0.9\n      grid_size: 0\n      n_grids_per_view: 1\n      local_grid_size: 0\n      local_extent: 50\n      sift_size: 0\n      num_uniformly_sampled_pts: 0\n"
  },
  {
    "path": "configs/model/delta.yaml",
    "content": "# @package _global_\ndefaults:\n  - default\n\nmodel:\n  _target_: mvtracker.models.core.monocular_baselines.MonocularToMultiViewAdapter\n  model:\n    _target_: mvtracker.models.core.monocular_baselines.DELTAWrapper\n    ckpt: checkpoints/densetrack3d.pth\n    upsample_factor: 4\n    grid_size: 20\n    return_2d_track: false\n"
  },
  {
    "path": "configs/model/locotrack.yaml",
    "content": "# @package _global_\ndefaults:\n  - default\n\nmodel:\n  _target_: mvtracker.models.core.monocular_baselines.MonocularToMultiViewAdapter\n  model:\n    _target_: mvtracker.models.core.monocular_baselines.LocoTrackWrapper\n    model_size: base\n\nevaluation:\n  interp_shape: [ 256, 256 ]\n"
  },
  {
    "path": "configs/model/mvtracker.yaml",
    "content": "# @package _global_\ndefaults:\n  - default\n\nmodel:\n  _target_: mvtracker.models.core.mvtracker.mvtracker.MVTracker\n  sliding_window_len: 12\n  stride: 4\n  normalize_scene_in_fwd_pass: false\n  fmaps_dim: 128\n  add_space_attn: true\n  num_heads: 6\n  hidden_size: 256\n  space_depth: 6\n  time_depth: 6\n  num_virtual_tracks: 64\n  use_flash_attention: true\n  corr_n_groups: 1\n  corr_n_levels: 4\n  corr_neighbors: 16\n  corr_add_neighbor_offset: true\n  corr_add_neighbor_xyz: false\n  corr_filter_invalid_depth: false   # slower, but would make sure points with invalid depth are not considered in corr\n\nevaluation:\n  interp_shape: [ 384, 512 ]\n\n  predictor_settings:\n    kubric:\n      visibility_threshold: 0.5\n      grid_size: 4\n      local_grid_size: 18\n    dex_ycb:\n      visibility_threshold: 0.01\n      grid_size: 4\n      local_grid_size: 18\n    panoptic:\n      visibility_threshold: 0.01\n      grid_size: 6\n      local_grid_size: 18\n    tapvid2d-davis:\n      visibility_threshold: 0.01\n      grid_size: 6\n      n_grids_per_view: 6\n      local_grid_size: 0\n      local_extent: 50\n      sift_size: 0\n      num_uniformly_sampled_pts: 0\n    generic:\n      visibility_threshold: 0.01\n      grid_size: 4\n      local_grid_size: 18\n\ntrainer:\n  precision: bf16-mixed\n"
  },
  {
    "path": "configs/model/scenetracker.yaml",
    "content": "# @package _global_\ndefaults:\n  - default\n\nmodel:\n  _target_: mvtracker.models.core.monocular_baselines.MonocularToMultiViewAdapter\n  model:\n    _target_: mvtracker.models.core.monocular_baselines.SceneTrackerWrapper\n    ckpt: checkpoints/scenetracker-odyssey-200k.pth\n    return_2d_track: false\n\nevaluation:\n  interp_shape: [ 384, 512 ]\n"
  },
  {
    "path": "configs/model/spatialtrackerv2.yaml",
    "content": "# @package _global_\ndefaults:\n  - default\n\nmodel:\n  _target_: mvtracker.models.core.monocular_baselines.MonocularToMultiViewAdapter\n  model:\n    _target_: mvtracker.models.core.monocular_baselines.SpaTrackerV2Wrapper\n    model_type: online  # or offline, whichever is better on a specific dataset\n    vo_points: 756\n\nevaluation:\n  predictor_settings:\n    kubric:\n      visibility_threshold: 0.01\n    dex_ycb:\n      visibility_threshold: 0.01\n    panoptic:\n      visibility_threshold: 0.01\n"
  },
  {
    "path": "configs/model/spatracker_monocular.yaml",
    "content": "# @package _global_\ndefaults:\n  - default\n\nmodel:\n  _target_: mvtracker.models.core.spatracker.spatracker_monocular.SpaTrackerMultiViewAdapter\n\n  sliding_window_len: 12\n  stride: 4\n  add_space_attn: true\n  num_heads: 8\n  hidden_size: 384\n  space_depth: 6\n  time_depth: 6\n  triplane_zres: 128\n\nevaluation:\n  interp_shape: [ 512, 512 ]    # This checkpoint was trained on 512x512 Kubric sequences\n  predictor_settings:\n    kubric:\n      visibility_threshold: 0.5\n      grid_size: 4\n      local_grid_size: 18\n    dex_ycb:\n      visibility_threshold: 0.5\n      grid_size: 0\n      local_grid_size: 18\n    panoptic:\n      visibility_threshold: 0.5\n      grid_size: 4\n      local_grid_size: 18\n\n#restore_ckpt_path: checkpoints/spatracker_monocular_trained-on-kubric-depth_069800.pth\n#restore_ckpt_path: checkpoints/spatracker_monocular_trained-on-duster-depth_090800.pth\n"
  },
  {
    "path": "configs/model/spatracker_monocular_pretrained.yaml",
    "content": "# @package _global_\ndefaults:\n  - default\n\nmodel:\n  _target_: mvtracker.models.core.spatracker.spatracker_monocular.SpaTrackerMultiViewAdapter\n\n  sliding_window_len: 12\n  stride: 4\n  add_space_attn: true\n  num_heads: 8\n  hidden_size: 384\n  space_depth: 6\n  time_depth: 6\n  triplane_zres: 128\n\nevaluation:\n  interp_shape: [ 384, 512 ]\n  predictor_settings:\n    kubric:\n      visibility_threshold: 0.9\n      grid_size: 4\n      local_grid_size: 18\n    dex_ycb:\n      visibility_threshold: 0.9\n      grid_size: 4\n      local_grid_size: 18\n    panoptic:\n      visibility_threshold: 0.9\n      grid_size: 4\n      local_grid_size: 18\n\n#restore_ckpt_path: checkpoints/spatracker_monocular_original-authors-ckpt.pth\n"
  },
  {
    "path": "configs/model/spatracker_multiview.yaml",
    "content": "# @package _global_\ndefaults:\n  - default\n\nmodel:\n  _target_: mvtracker.models.core.spatracker.spatracker_multiview.MultiViewSpaTracker\n\n  sliding_window_len: 12\n  stride: 4\n  add_space_attn: true\n  use_3d_pos_embed: true\n  remove_zeromlpflow: true\n  concat_triplane_features: true\n  num_heads: 8\n  hidden_size: 384\n  space_depth: 6\n  time_depth: 6\n  fmaps_dim: 128\n  triplane_xres: 128\n  triplane_yres: 128\n  triplane_zres: 128\n\nevaluation:\n  interp_shape: [ 512, 512 ]    # This checkpoint was trained on 512x512 Kubric sequences\n  predictor_settings:\n    kubric:\n      visibility_threshold: 0.5\n      grid_size: 4\n      local_grid_size: 18\n    dex_ycb:\n      visibility_threshold: 0.01\n      grid_size: 4\n      local_grid_size: 18\n    panoptic:\n      visibility_threshold: 0.01\n      grid_size: 4\n      local_grid_size: 18\n\n#restore_ckpt_path: checkpoints/spatracker_multiview_trained-on-kubric-depth_100000.pth\n#model:\n#  triplane_xres: 128\n#  triplane_yres: 128\n#  triplane_zres: 128\n\n#restore_ckpt_path: checkpoints/spatracker_multiview_trained-on-duster-depth_100000.pth\n#model:\n#  triplane_xres: 256\n#  triplane_yres: 256\n#  triplane_zres: 128\n"
  },
  {
    "path": "configs/model/tapip3d.yaml",
    "content": "# @package _global_\ndefaults:\n  - default\n\nmodel:\n  _target_: mvtracker.models.core.monocular_baselines.MonocularToMultiViewAdapter\n  model:\n    _target_: mvtracker.models.core.monocular_baselines.TAPIP3DWrapper\n    ckpt: checkpoints/tapip3d_final.pth\n    num_iters: 6\n    grid_size: 8\n    resolution_factor: 1 # --> [ 384, 512 ]\n#    resolution_factor: 2 # --> [ 543, 724 ]\n\nevaluation:\n  interp_shape: [ 384, 512 ] # --> resolution_factor = 1\n#  interp_shape: [ 543, 724 ] # --> resolution_factor = 2\n  predictor_settings:\n    kubric:\n      visibility_threshold: 0.01\n    dex_ycb:\n      visibility_threshold: 0.01\n    panoptic:\n      visibility_threshold: 0.01\n"
  },
  {
    "path": "configs/train.yaml",
    "content": "defaults:\n  - _self_\n  - model: mvtracker\n\nexperiment_path: ???                # where to store checkpoints, visualizations, etc.\nrestore_ckpt_path: null             # resume from checkpoint\n\n# === Datasets ===\ndatasets:\n  root: ./datasets\n  train:\n    name: kubric-multiview-v3-training\n    batch_size: 1\n    sequence_len: 24                # frames per sequence\n    traj_per_sample: 2048           # number of 3D points/trajectories per sample\n    max_videos: null                # takes all training videos by default\n    kubric_max_depth: 24\n    num_workers: 8\n  eval:\n    names:\n      - panoptic-multiview-views1_7_14_20\n      - kubric-multiview-v3-overfit-on-training\n      - kubric-multiview-v3-views0123\n      - kubric-multiview-v3-duster0123\n      - dex-ycb-multiview\n      - dex-ycb-multiview-duster0123\n    num_workers: 4\n    max_seq_len: 1000\n\n# === Trainer Settings ===\ntrainer:\n  num_steps: 200000\n  eval_freq: 10000\n  viz_freq: 10000\n  save_ckpt_freq: 500\n\n  lr: 0.0005\n  gamma: 0.8\n  wdecay: 0.00001\n  anneal_strategy: linear\n  grad_clip: 1.0\n  precision: 16-mixed               # training precision (e.g., 16-mixed, bf16-mixed or 32-true)\n  visibility_loss_weight: 0.1\n\n  augment_train_iters: false\n  augment_train_iters_warmup: 2000\n\n# === Evaluation Settings ===\nevaluation:\n  consume_model_stats: false        # whether to report model stats (which can slow down the forward pass)\n  evaluator:\n    _target_: mvtracker.evaluation.evaluator_3dpt.Evaluator\n    rerun_viz_indices: null\n    forward_pass_log_indices: null\n    mp4_track_viz_indices: [0]\n\n# === Execution Modes ===\nmodes:\n  debug: false                      # enable for quick iteration\n  tune_per_scene: false             # overfit to single scene (debugging)\n  validate_at_start: false          # run eval before train starts\n  do_initial_static_pretrain: false # run static-only phase first\n  pretrain_only: false              # stop after static pretraining\n  eval_only: false                  # skip training, just run evaluation\n\n  debugging_hotfix_datapoint_path: null  # path to a dumped datapoint (no need to set debug flag)\n\n# === Reproducibility ===\nreproducibility:\n  # Note that reproducibility will not work if\n  # floating point precision is set to 16-mixed,\n  # but with 32 it will. Note also that the number\n  # of data loading workers (num_workers) might\n  # affect reproducibility as well. 
The number of\n  # GPUs surely affects reproducibility.\n  seed: 36\n  deterministic: false              # speeds up training at expense of determinism\n\n# === Augmentations ===\naugmentations:\n  probability: 0.8\n\n  rgb: true\n  depth: true\n  variable_depth_type: true\n  variable_num_views: true\n\n  cropping: true\n  cropping_size: [384, 512]\n  variable_vggt_crop_size: false\n  keep_principal_point_centered: false\n\n  variable_trajpersample: true\n\n  scene_transform: true\n  camera_params_noise: true\n  normalize_scene_following_vggt: false\n\n# === Logging ===\nlogging:\n  log_wandb: false\n  wandb_project: mvtracker-ablation\n  tags: [\"kubric\", \"3dpt\", \"multiview\"]\n\n# === Extras ===\nextras:\n  print_config: true                     # pretty print config tree at the start\n  ignore_warnings: false                 # disable python warnings if they annoy you\n  enable_faulthandler_traceback: false   # enable traceback dump on timeout for debugging of main process hanging\n  faulthandler_traceback_timeout: 600    # timeout in seconds before dumping traceback (e.g. 600 = 10 min)\n\n# === Hydra Settings ===\nhydra:\n  run:\n    dir: ${experiment_path}\n"
  },
  {
    "path": "demo.py",
    "content": "import argparse\nimport os\nimport warnings\n\nimport numpy as np\nimport rerun as rr  # pip install rerun-sdk==0.21.0\nimport torch\nfrom huggingface_hub import hf_hub_download\n\nfrom mvtracker.utils.visualizer_rerun import log_pointclouds_to_rerun, log_tracks_to_rerun\n\n\ndef main():\n    p = argparse.ArgumentParser()\n    p.add_argument(\n        \"--rerun\",\n        choices=[\"save\", \"spawn\", \"stream\"],\n        default=\"save\",\n        help=(\n            \"Whether to save recording to disk, spawn a new Rerun instance, or stream to an existing one. \"\n            \"If 'spawn', make sure a rerun window can be spawned in your environment. \"\n            \"If 'stream', make sure a rerun instance is running at port 9876. \"\n            \"If 'save', the recording will be saved to a `.rrd` file that can be drag-and-dropped into \"\n            \"a running rerun viewer, including the online viewer at https://app.rerun.io/version/0.21.0. \"\n            \"For the online viewer, you want to create low memory-usage recordings with --lightweight.\"\n        ),\n    )\n    p.add_argument(\n        \"--lightweight\",\n        action=\"store_true\",\n        help=(\n            \"Use lightweight rerun logging (less memory usage). This is recommended if you want to \"\n            \"view the recording in the online Rerun viewer at https://app.rerun.io/version/0.21.0.\"\n        ),\n    )\n    p.add_argument(\n        \"--random_query_points\",\n        action=\"store_true\",\n        help=\"Use random query points instead of demo ones.\",\n    )\n    p.add_argument(\n        \"--rrd\",\n        default=\"mvtracker_demo.rrd\",\n        help=(\n            \"Path to save a .rrd file if `--rerun save` is used. \"\n            \"Note that rerun prefers recordings to have a .rrd suffix.\"\n        ),\n    )\n    args = p.parse_args()\n    np.random.seed(72)\n    torch.manual_seed(72)\n    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n    # Load MVTracker predictor\n    mvtracker = torch.hub.load(\"ethz-vlg/mvtracker\", \"mvtracker\", pretrained=True, device=device)\n\n    # Download demo sample from Hugging Face Hub\n    sample_path = hf_hub_download(\n        repo_id=\"ethz-vlg/mvtracker\",\n        filename=\"data_sample.npz\",\n        token=os.getenv(\"HF_TOKEN\"),\n        repo_type=\"model\",\n    )\n    sample = np.load(sample_path)\n\n    rgbs = torch.from_numpy(sample[\"rgbs\"]).float()\n    depths = torch.from_numpy(sample[\"depths\"]).float()\n    intrs = torch.from_numpy(sample[\"intrs\"]).float()\n    extrs = torch.from_numpy(sample[\"extrs\"]).float()\n    query_points = torch.from_numpy(sample[\"query_points\"]).float()\n\n    # Optionally, sample random queries in a cylinder of radius 12, height [-1, +10] and replace the demo queries\n    if args.random_query_points:\n        from mvtracker.models.core.model_utils import init_pointcloud_from_rgbd\n        num_queries = 512\n        t0 = 0\n        xy_radius = 12.0\n        z_min, z_max = -1.0, 10.0\n        xyz, _ = init_pointcloud_from_rgbd(\n            fmaps=rgbs[None],  # [1,V,T,1,H,W], uint8 0–255\n            depths=depths[None],  # [1,V,T,1,H,W]\n            intrs=intrs[None],  # [1,V,T,3,3]\n            extrs=extrs[None],  # [1,V,T,3,4]\n            stride=1,\n            level=0,\n        )\n        pts = xyz[t0]  # [V*H*W, 3] at t=0\n        assert pts.numel() > 0, \"No valid depth points to sample queries from.\"\n\n        r2 = pts[:, 0] ** 2 + pts[:, 1] ** 2\n        mask = (r2 <= 
xy_radius ** 2) & (pts[:, 2] >= z_min) & (pts[:, 2] <= z_max)\n        pool = pts[mask]\n        assert pool.shape[0] > 0, \"Cylinder mask removed all points; increase radius or z-range.\"\n\n        idx = torch.randperm(pool.shape[0])[:num_queries]\n        pts = pool[idx]\n        ts = torch.full((pts.shape[0], 1), float(t0), device=pts.device)\n        query_points = torch.cat([ts, pts], dim=1).float()  # (N,4): (t,x,y,z)\n        print(f\"Sampled {pts.shape[0]} queries from depth at t={t0} within r<={xy_radius}, z∈[{z_min},{z_max}].\")\n\n    # Run prediction\n    torch.set_float32_matmul_precision(\"high\")\n    amp_dtype = torch.bfloat16 if (device == \"cuda\" and torch.cuda.get_device_capability()[0] >= 8) else torch.float16\n    with torch.no_grad(), torch.cuda.amp.autocast(enabled=device == \"cuda\", dtype=amp_dtype):\n        results = mvtracker(\n            rgbs=rgbs[None].to(device) / 255.0,\n            depths=depths[None].to(device),\n            intrs=intrs[None].to(device),\n            extrs=extrs[None].to(device),\n            query_points_3d=query_points[None].to(device),\n        )\n    pred_tracks = results[\"traj_e\"].cpu()  # [T,N,3]\n    pred_vis = results[\"vis_e\"].cpu()  # [T,N]\n\n    # Visualize results\n    rr.init(\"3dpt\", recording_id=\"v0.16\")\n    if args.rerun == \"stream\":\n        rr.connect_tcp()\n    elif args.rerun == \"spawn\":\n        rr.spawn()\n    log_pointclouds_to_rerun(\n        dataset_name=\"demo\",\n        datapoint_idx=0,\n        rgbs=rgbs[None],\n        depths=depths[None],\n        intrs=intrs[None],\n        extrs=extrs[None],\n        depths_conf=None,\n        conf_thrs=[5.0],\n        log_only_confident_pc=False,\n        radii=-2.45,\n        fps=12,\n        bbox_crop=None,\n        sphere_radius_crop=12.0,\n        sphere_center_crop=np.array([0, 0, 0]),\n        log_rgb_image=False,\n        log_depthmap_as_image_v1=False,\n        log_depthmap_as_image_v2=False,\n        log_camera_frustrum=True,\n        log_rgb_pointcloud=True,\n    )\n    log_tracks_to_rerun(\n        dataset_name=\"demo\",\n        datapoint_idx=0,\n        predictor_name=\"MVTracker\",\n        gt_trajectories_3d_worldspace=None,\n        gt_visibilities_any_view=None,\n        query_points_3d=query_points[None],\n        pred_trajectories=pred_tracks,\n        pred_visibilities=pred_vis,\n        per_track_results=None,\n        radii_scale=1.0,\n        fps=12,\n        sphere_radius_crop=12.0,\n        sphere_center_crop=np.array([0, 0, 0]),\n        log_per_interval_results=False,\n        max_tracks_to_log=100 if args.lightweight else None,\n        track_batch_size=50,\n        method_id=None,\n        color_per_method_id=None,\n        memory_lightweight_logging=args.lightweight,\n    )\n    if args.rerun == \"save\":\n        rr.save(args.rrd)\n        print(f\"Saved Rerun recording to: {os.path.abspath(args.rrd)}\")\n\n\nif __name__ == \"__main__\":\n    warnings.filterwarnings(\"ignore\", message=\".*DtypeTensor constructors are no longer.*\", module=\"pointops.query\")\n    warnings.filterwarnings(\"ignore\", message=\".*Plan failed with a cudnnException.*\", module=\"torch.nn.modules.conv\")\n    main()\n"
  },
  {
    "path": "hubconf.py",
    "content": "# Copyright (c) ETH VLG.\n# Licensed under the terms in the LICENSE file at the root of this repo.\n\nfrom pathlib import Path\nimport os\nimport torch\n\n_WEIGHTS = {\n    \"mvtracker_main\": \"hf://ethz-vlg/mvtracker::mvtracker_200000_june2025.pth\",\n    \"mvtracker_cleandepth\": \"hf://ethz-vlg/mvtracker::mvtracker_200000_june2025_cleandepth.pth\",\n}\n\n\ndef _load_ckpt(spec: str):\n    if spec.startswith(\"http\"):\n        return torch.hub.load_state_dict_from_url(spec, map_location=\"cpu\")\n    if spec.startswith(\"hf://\"):\n        from huggingface_hub import hf_hub_download\n        repo_id, filename = spec[len(\"hf://\"):].split(\"::\", 1)\n        path = hf_hub_download(repo_id=repo_id, filename=filename, token=os.getenv(\"HF_TOKEN\"))\n        return torch.load(path, map_location=\"cpu\")\n    path = Path(spec).expanduser().resolve()\n    return torch.load(str(path), map_location=\"cpu\")\n\n\ndef _extract_model_state(sd):\n    \"\"\"\n    Accept:\n      - plain state dict\n      - {'state_dict': ...}\n      - {'model': ..., 'optimizer': ..., 'scheduler': ..., 'total_steps': ...}\n    Returns a clean model state_dict.\n    \"\"\"\n    if isinstance(sd, dict):\n        if \"state_dict\" in sd and isinstance(sd[\"state_dict\"], dict):\n            sd = sd[\"state_dict\"]\n        elif \"model\" in sd and isinstance(sd[\"model\"], dict):\n            sd = sd[\"model\"]\n    # Strip optional \"model.\" prefix\n    sd = {k.replace(\"model.\", \"\", 1): v for k, v in sd.items()}\n    return sd\n\n\ndef _build_model(**overrides):\n    from mvtracker.models.core.mvtracker.mvtracker import MVTracker\n    cfg = dict(\n        sliding_window_len=12,\n        stride=4,\n        normalize_scene_in_fwd_pass=False,\n        fmaps_dim=128,\n        add_space_attn=True,\n        num_heads=6,\n        hidden_size=256,\n        space_depth=6,\n        time_depth=6,\n        num_virtual_tracks=64,\n        use_flash_attention=True,\n        corr_n_groups=1,\n        corr_n_levels=4,\n        corr_neighbors=16,\n        corr_add_neighbor_offset=True,\n        corr_add_neighbor_xyz=False,\n        corr_filter_invalid_depth=False,\n    )\n    cfg.update(overrides)\n    return MVTracker(**cfg)\n\n\ndef _load_into(model, checkpoint_key: str):\n    raw = _load_ckpt(_WEIGHTS[checkpoint_key])\n    sd = _extract_model_state(raw)\n    missing, unexpected = model.load_state_dict(sd, strict=False)\n    if unexpected:\n        raise RuntimeError(f\"Unexpected keys in state_dict: {unexpected}\")\n    return model\n\n\ndef mvtracker_model(*,\n                    pretrained: bool = False,\n                    device: str = \"cuda\",\n                    checkpoint: str = \"mvtracker_main\",\n                    **model_kwargs):\n    \"\"\"\n    Return a bare MVTracker nn.Module.\n\n    - pretrained=False: random init with model_kwargs.\n    - pretrained=True : load from _WEIGHTS[checkpoint], then .eval().\n    \"\"\"\n    model = _build_model(**model_kwargs).to(device)\n    if pretrained:\n        model = _load_into(model, checkpoint)\n        model.eval()\n    return model\n\n\ndef mvtracker_predictor(*,\n                        pretrained: bool = True,\n                        device: str = \"cuda\",\n                        checkpoint: str = \"mvtracker_main\",\n                        model_kwargs: dict | None = None,\n                        predictor_kwargs: dict | None = None):\n    \"\"\"\n    Return EvaluationPredictor wrapped around MVTracker.\n\n    Pass model configuration via 
`model_kwargs={...}` (matches MVTracker.__init__).\n    Pass predictor configuration via `predictor_kwargs={...}`:\n      - interp_shape, visibility_threshold, grid_size, n_grids_per_view,\n        local_grid_size, local_extent, sift_size, num_uniformly_sampled_pts, n_iters\n    \"\"\"\n    from mvtracker.models.evaluation_predictor_3dpt import EvaluationPredictor\n\n    model_kwargs = {} if model_kwargs is None else dict(model_kwargs)\n    predictor_kwargs = {} if predictor_kwargs is None else dict(predictor_kwargs)\n\n    predictor_defaults = dict(\n        interp_shape=(384, 512),\n        visibility_threshold=0.5,\n        grid_size=4,\n        n_grids_per_view=1,\n        local_grid_size=18,\n        local_extent=50,\n        sift_size=0,\n        num_uniformly_sampled_pts=0,\n        n_iters=6,\n    )\n    pk = {**predictor_defaults, **predictor_kwargs}\n\n    model = mvtracker_model(pretrained=pretrained, device=device, checkpoint=checkpoint, **model_kwargs)\n    return EvaluationPredictor(multiview_model=model, **pk)\n\n\ndef mvtracker(pretrained: bool = True, device: str = \"cuda\"):\n    \"\"\"Default public endpoint: predictor with main checkpoint.\"\"\"\n    return mvtracker_predictor(pretrained=pretrained, device=device, checkpoint=\"mvtracker_main\")\n\n\ndef mvtracker_cleandepth(pretrained: bool = True, device: str = \"cuda\"):\n    \"\"\"Predictor with 'clean depth only' checkpoint.\"\"\"\n    return mvtracker_predictor(pretrained=pretrained, device=device, checkpoint=\"mvtracker_cleandepth\")"
  },
  {
    "path": "mvtracker/__init__.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n"
  },
  {
    "path": "mvtracker/cli/__init__.py",
    "content": ""
  },
  {
    "path": "mvtracker/cli/eval.py",
    "content": "import hydra\nfrom omegaconf import DictConfig\n\nfrom mvtracker.cli.train import main as train_main\n\n\n@hydra.main(version_base=\"1.3\", config_path=\"../../configs\", config_name=\"eval\")\ndef main(cfg: DictConfig):\n    train_main(cfg)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "mvtracker/cli/train.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\nimport torch\n\ntorch.set_float32_matmul_precision('high')\n\nfrom lightning.fabric.wrappers import _unwrap_objects\nfrom mvtracker.datasets.generic_scene_dataset import GenericSceneDataset\n\nfrom torch.utils.tensorboard import SummaryWriter\nimport gpustat\nimport json\nimport threading\nimport warnings\nfrom pathlib import Path\n\nimport hydra\nimport numpy as np\nimport pandas as pd\nimport torch.optim as optim\nimport wandb\nfrom lightning.fabric import Fabric\nfrom lightning.fabric.utilities import AttributeDict\nfrom omegaconf import DictConfig, OmegaConf\nfrom torch import nn\nfrom torch.utils.data import DataLoader\nfrom tqdm import tqdm\nimport signal, sys\n\nfrom mvtracker.datasets import KubricMultiViewDataset\nfrom mvtracker.datasets import TapVidDataset\nfrom mvtracker.datasets import kubric_multiview_dataset\nfrom mvtracker.datasets.dexycb_multiview_dataset import DexYCBMultiViewDataset\nfrom mvtracker.datasets.panoptic_studio_multiview_dataset import PanopticStudioMultiViewDataset\nfrom mvtracker.datasets.utils import collate_fn, dataclass_to_cuda_\nfrom mvtracker.models.core.losses import balanced_ce_loss, sequence_loss_3d\nfrom mvtracker.models.core.model_utils import world_space_to_pixel_xy_and_camera_z, pixel_xy_and_camera_z_to_world_space\nfrom mvtracker.models.evaluation_predictor_3dpt import EvaluationPredictor as EvaluationPredictor3D\nfrom mvtracker.utils.visualizer_mp4 import MultiViewVisualizer, Visualizer\nfrom mvtracker.cli.utils import extras\nfrom mvtracker.cli.utils.helpers import maybe_close_wandb\n\nimport logging\nimport os\n\nimport torch\nimport time\nfrom collections import deque\nfrom torchdata.stateful_dataloader import StatefulDataLoader\n\n\ndef fetch_optimizer(trainer_cfg, model):\n    \"\"\"Create the optimizer and learning rate scheduler\"\"\"\n    optimizer = optim.AdamW(model.parameters(), lr=trainer_cfg.lr, weight_decay=trainer_cfg.wdecay)\n    if trainer_cfg.anneal_strategy in [\"linear\", \"cos\"]:\n        scheduler = optim.lr_scheduler.OneCycleLR(\n            optimizer,\n            trainer_cfg.lr,\n            trainer_cfg.num_steps + 100,\n            pct_start=0.05,\n            cycle_momentum=False,\n            anneal_strategy=trainer_cfg.anneal_strategy,\n        )\n    elif trainer_cfg.anneal_strategy == \"restarts\":\n        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(\n            optimizer,\n            T_0=5000,\n            T_mult=1,\n            eta_min=trainer_cfg.lr / 1000,\n        )\n\n    return optimizer, scheduler\n\n\ndef forward_batch_multi_view(batch, model, cfg, step, train_iters, gamma, save_debug_logs=False, debug_logs_path=''):\n    # Per view data\n    rgbs = batch.video\n    depths = batch.videodepth\n    image_features = batch.feats\n    intrs = batch.intrs\n    extrs = batch.extrs\n    gt_trajectories_2d_pixelspace_w_z_cameraspace = batch.trajectory\n    gt_visibilities_per_view = batch.visibility\n    query_points_3d = batch.query_points_3d\n\n    # Non-per-view data\n    gt_trajectories_3d_worldspace = batch.trajectory_3d\n    valid_tracks_per_frame = batch.valid\n    track_upscaling_factor = batch.track_upscaling_factor\n\n    batch_size, num_views, num_frames, _, height, width = rgbs.shape\n    num_points = gt_trajectories_2d_pixelspace_w_z_cameraspace.shape[3]\n\n    
# Assert shapes of per-view data\n    assert rgbs.shape == (batch_size, num_views, num_frames, 3, height, width)\n    assert depths.shape == (batch_size, num_views, num_frames, 1, height, width)\n    assert intrs.shape == (batch_size, num_views, num_frames, 3, 3)\n    assert extrs.shape == (batch_size, num_views, num_frames, 3, 4)\n    assert gt_trajectories_2d_pixelspace_w_z_cameraspace.shape == (batch_size, num_views, num_frames, num_points, 3)\n    assert gt_visibilities_per_view.shape == (batch_size, num_views, num_frames, num_points)\n\n    # Assert shapes of non-per-view data\n    assert query_points_3d.shape == (batch_size, num_points, 4)\n    assert gt_trajectories_3d_worldspace.shape == (batch_size, num_frames, num_points, 3)\n    assert valid_tracks_per_frame.shape == (batch_size, num_frames, num_points)\n\n    gt_visibilities_any_view = gt_visibilities_per_view.any(dim=1)\n    assert gt_visibilities_any_view.any(dim=1).all(), \"All points should be visible at in least one frame.\"\n\n    for batch_idx in range(batch_size):\n        for point_idx in range(num_points):\n            t = query_points_3d[batch_idx, point_idx, 0].long().item()\n            valid_tracks_per_frame[batch_idx, :t, point_idx] = False\n\n    # Run the model\n    results = model(\n        rgbs=rgbs,\n        depths=depths,\n        image_features=image_features,\n        query_points=query_points_3d,\n        iters=train_iters,\n        is_train=True,\n        intrs=intrs,\n        extrs=extrs,\n        save_debug_logs=save_debug_logs,\n        debug_logs_path=debug_logs_path,\n    )\n    pred_trajectories = results[\"traj_e\"]\n    pred_visibilities = results[\"vis_e\"]\n    vis_predictions = results[\"train_data\"][\"vis_predictions\"]\n    coord_predictions = results[\"train_data\"][\"coord_predictions\"]\n    p_idx_end_list = results[\"train_data\"][\"p_idx_end_list\"]\n    sort_inds = results[\"train_data\"][\"sort_inds\"]\n\n    # Prepare the ground truth for the loss functions,\n    # which expect the data to be in the sliding-window\n    vis_gts = []\n    traj_gts = []\n    valids_gts = []\n    query_points_t_min = query_points_3d[:, :, 0].long().min()\n    for i, wind_p_idx_end in enumerate(p_idx_end_list):\n        gt_visibilities_any_view_sorted = gt_visibilities_any_view[:, :, sort_inds]\n        gt_trajectories_3d_worldspace_sorted = gt_trajectories_3d_worldspace[:, :, sort_inds]\n        valid_tracks_per_frame_sorted = valid_tracks_per_frame[:, :, sort_inds]\n        ind = query_points_t_min + i * (cfg.model.sliding_window_len // 2)\n        vis_gts.append(gt_visibilities_any_view_sorted[:, ind: ind + cfg.model.sliding_window_len, :wind_p_idx_end])\n        traj_gts.append(\n            gt_trajectories_3d_worldspace_sorted[:, ind: ind + cfg.model.sliding_window_len, :wind_p_idx_end])\n        valids_gts.append(valid_tracks_per_frame_sorted[:, ind: ind + cfg.model.sliding_window_len, :wind_p_idx_end])\n\n    # Compute the losses\n    logging.info(f\"[DEBUG] \"\n                 f\"{step=} \"\n                 f\"{track_upscaling_factor=} \"\n                 f\"{coord_predictions[0][0][0, 0, 0]=} \"\n                 f\"{coord_predictions[-1][0][0, 0, 0]=} \"\n                 f\"{vis_predictions[0][0, 0, 0]=} \"\n                 f\"{vis_predictions[-1][0, 0, 0]=}\")\n    xyz_loss = sequence_loss_3d(coord_predictions, traj_gts, vis_gts, valids_gts, gamma) * track_upscaling_factor\n    vis_loss = balanced_ce_loss(vis_predictions, vis_gts, valids_gts)\n\n    # Compute 3DPT metrics\n    # 
eval_3dpt_results_dict = evaluate_3dpt(\n    #     gt_tracks=gt_trajectories_3d_worldspace[0].cpu().numpy(),\n    #     gt_visibilities=gt_visibilities_any_view[0].cpu().numpy(),\n    #     pred_tracks=pred_trajectories[0].detach().cpu().numpy(),\n    #     pred_visibilities=(pred_visibilities[0] > 0.5).detach().cpu().numpy(),\n    #     evaluation_setting=\"kubric-multiview\",\n    #     track_upscaling_factor=track_upscaling_factor,\n    #     prefix=\"train_3dpt\",\n    #     verbose=False,\n    #     query_points=query_points_3d[0].cpu().numpy(),\n    # )\n\n    # Invert the intrinsics and extrinsics matrices\n    intrs_inv = torch.inverse(intrs.float())\n    extrs_square = torch.eye(4).to(extrs.device)[None].repeat(batch_size, num_views, num_frames, 1, 1)\n    extrs_square[:, :, :, :3, :] = extrs\n    extrs_inv = torch.inverse(extrs_square.float())\n\n    # Project the predictions to pixel space\n    pred_trajectories = pred_trajectories[0].detach()\n    pred_trajectories_pixel_xy_camera_z_per_view = torch.stack([\n        torch.cat(world_space_to_pixel_xy_and_camera_z(\n            world_xyz=pred_trajectories,\n            intrs=intrs[0, view_idx],\n            extrs=extrs[0, view_idx],\n        ), dim=-1)\n        for view_idx in range(num_views)\n    ], dim=0)\n    for view_idx in range(num_views):\n        pred_trajectories_reproduced = pixel_xy_and_camera_z_to_world_space(\n            pixel_xy=pred_trajectories_pixel_xy_camera_z_per_view[view_idx, :, :, :2],\n            camera_z=pred_trajectories_pixel_xy_camera_z_per_view[view_idx, :, :, 2:],\n            intrs_inv=intrs_inv[0, view_idx],\n            extrs_inv=extrs_inv[0, view_idx],\n        )\n        if not torch.allclose(pred_trajectories_reproduced, pred_trajectories, atol=1):\n            warnings.warn(f\"Reprojection of the predicted trajectories failed: \"\n                          f\"view_idx={view_idx}, \"\n                          f\"max_diff={torch.max(torch.abs(pred_trajectories_reproduced - pred_trajectories))}\")\n\n    logging.info(\n        f\"{step=}, \"\n        f\"seq={batch.seq_name}, \"\n        f\"{xyz_loss.item()=}, \"\n        f\"{vis_loss.item()=}, \"\n    )\n\n    output = {\n        \"flow\": {\n            \"loss\": xyz_loss * 1.0,\n            \"predictions\": pred_trajectories_pixel_xy_camera_z_per_view,\n            \"predictions_worldspace\": pred_trajectories,\n        },\n        \"visibility\": {\n            \"loss\": vis_loss * cfg.trainer.visibility_loss_weight,\n            \"predictions\": pred_visibilities[0].detach(),\n        },\n        # \"metrics\": {\n        #     k: v\n        #     for k, v in eval_3dpt_results_dict.items()\n        #     if \"per_track\" not in k\n        # },\n    }\n    return output\n\n\ndef run_test_eval(cfg, evaluator, model, dataloaders, writer, step):\n    if len(dataloaders) == 0:\n        return\n\n    logging.info(f\"Eval – GPU usage A: {gpustat.new_query()}\")\n\n    log_dir = cfg.experiment_path\n    model.eval()\n    for ds_name, dataloader in dataloaders:\n        if ds_name.startswith(\"kubric\"):\n            predictor_settings = cfg.evaluation.predictor_settings[\"kubric\"]\n        elif ds_name.startswith(\"dex-ycb\"):\n            predictor_settings = cfg.evaluation.predictor_settings[\"dex_ycb\"]\n        elif ds_name.startswith(\"panoptic\"):\n            predictor_settings = cfg.evaluation.predictor_settings[\"panoptic\"]\n        elif ds_name.startswith(\"tapvid2d-davis\"):\n            predictor_settings = 
cfg.evaluation.predictor_settings[\"tapvid2d-davis\"]\n        else:\n            predictor_settings = cfg.evaluation.predictor_settings[\"generic\"]\n            logging.info(f\"Using generic predictor settings for dataset with name {ds_name}\")\n\n        predictor = EvaluationPredictor3D(\n            multiview_model=model,\n            interp_shape=cfg.evaluation.interp_shape,\n            single_point=\"single\" in ds_name,\n            n_iters=cfg.evaluation.eval_iters,\n            **predictor_settings\n        )\n\n        log_dir_ds = os.path.join(log_dir, f\"eval_{ds_name}\")\n        os.makedirs(log_dir_ds, exist_ok=True)\n\n        if cfg.evaluation.consume_model_stats and hasattr(model, \"init_stats\"):\n            model.init_stats()\n        metrics = evaluator.evaluate_sequence(\n            model=predictor,\n            test_dataloader=dataloader,\n            dataset_name=ds_name,\n            writer=writer,\n            step=step,\n            log_dir=log_dir_ds,\n        )\n        if cfg.evaluation.consume_model_stats and hasattr(model, \"consume_stats\"):\n            model.consume_stats()\n\n        metrics_to_log = {\n            k: np.nanmean([v[k] for v in metrics.values() if k in v]).round(2)\n            for k in metrics[0].keys()\n        }\n        for k, v in metrics_to_log.items():\n            writer.add_scalar(k, v, step)\n\n        with pd.option_context(\n                'display.max_rows', None,\n                'display.max_columns', None,\n                'display.max_colwidth', None,\n                'display.width', None,\n        ):\n            logging.info(f\"Per-sequence Metrics for {ds_name}: {pd.DataFrame(metrics)}\")\n            logging.info(f\"Average metrics for {ds_name}: {json.dumps(metrics_to_log, indent=4)}\")\n\n        # Save metrics to csv\n        if log_dir_ds is not None:\n            df = pd.DataFrame(metrics)\n            df = df.T\n            assert df.map(lambda x: (len(x) == 1) if isinstance(x, np.ndarray) else True).all().all()\n            df = df.map(lambda x: x[0] if isinstance(x, np.ndarray) or isinstance(x, list) else x)\n            df.to_csv(f\"{log_dir_ds}/step-{step}_metrics.csv\")\n\n            df = pd.DataFrame(metrics_to_log, index=[\"score\"])\n            df = df.T\n            df.to_csv(f\"{log_dir_ds}/step-{step}_metrics_avg.csv\")\n            logging.info(f\"Saved metrics to {log_dir_ds}/step-{step}_metrics_avg.csv\")\n        # logging.info(f\"Eval – GPU usage (after {ds_name}): {gpustat.new_query()}\")\n\n    # logging.info(f\"Eval – GPU usage B: {gpustat.new_query()}\")\n    del predictor\n    del metrics\n    # logging.info(f\"Eval – GPU usage C: {gpustat.new_query()}\")\n    torch.cuda.empty_cache()\n    # logging.info(f\"Eval – GPU usage D: {gpustat.new_query()}\")\n\n    model.train()\n\n\ndef augment_train_iters(train_iters: int, current_step: int, warmup_steps: int = 1000) -> int:\n    \"\"\"\n    Adaptive iteration scheduler with warmup:\n    - During warmup_steps: always return 1\n    - After warmup:\n        - 10% chance: return 1\n        - 15% chance: return random int in [2, train_iters - 1]\n        - 75% chance: return train_iters\n    \"\"\"\n    if current_step < warmup_steps or train_iters <= 1:\n        return 1\n\n    rng = torch.Generator().manual_seed(current_step)\n    p = torch.rand(1, generator=rng).item()\n\n    if p < 0.10:\n        return 1\n    elif p < 0.25 and train_iters > 2:\n        mid_candidates = list(range(2, train_iters))\n        idx = 
torch.randint(len(mid_candidates), (1,), generator=rng).item()\n        return mid_candidates[idx]\n    else:\n        return train_iters\n\n\n@hydra.main(version_base=\"1.3\", config_path=\"../../configs\", config_name=\"train.yaml\")\n@maybe_close_wandb\ndef main(cfg: DictConfig):\n    \"\"\"Main entry point for training.\n\n    :param cfg: DictConfig configuration composed by Hydra.\n    :return: Optional[float] with optimized metric value.\n    \"\"\"\n    extras(cfg)\n    Path(cfg.experiment_path).mkdir(exist_ok=True, parents=True)\n\n    num_nodes = int(os.environ.get(\"SLURM_JOB_NUM_NODES\", 1))\n    devices = int(os.environ.get(\"SLURM_GPUS_PER_NODE\", torch.cuda.device_count()))\n    logging.info(f\"SLURM job num nodes: {num_nodes}\")\n    logging.info(f\"SLURM tasks per node (devices): {devices}\")\n\n    from lightning.fabric.strategies import DDPStrategy\n    fabric = Fabric(\n        num_nodes=num_nodes,\n        devices=devices,\n        precision=cfg.trainer.precision,\n        strategy=DDPStrategy(find_unused_parameters=True),\n    )\n    fabric.launch()\n    fabric.seed_everything(cfg.reproducibility.seed, workers=True)\n    if cfg.reproducibility.deterministic:\n        torch.use_deterministic_algorithms(True)\n        torch.backends.cudnn.benchmark = False\n        torch.backends.cudnn.deterministic = True\n        torch.autograd.set_detect_anomaly(True)\n\n    if cfg.logging.get(\"log_wandb\", False) and fabric.global_rank == 0:\n        exp_name = cfg.experiment_path.replace(\"./logs/\", \"\").replace(\"/\", \"_\").replace(\"\\\\\", \"_\")\n        wandb.init(\n            project=cfg.logging.wandb_project,\n            name=exp_name,\n            tags=cfg.logging.get(\"tags\", []),\n            config=OmegaConf.to_container(cfg, resolve=True),\n            sync_tensorboard=True,\n        )\n\n    original_numpy = torch.Tensor.numpy\n\n    def patched_numpy(self, *args, **kwargs):\n        if self.dtype == torch.bfloat16:\n            return original_numpy(self.float(), *args, **kwargs)\n        return original_numpy(self, *args, **kwargs)\n\n    torch.Tensor.numpy = patched_numpy\n\n    eval_dataloaders = []\n    for dataset_name in cfg.datasets.eval.names:\n        if dataset_name.startswith(\"tapvid2d-davis-\"):\n            eval_dataset = TapVidDataset.from_name(dataset_name, cfg.datasets.root)\n        elif dataset_name.startswith(\"kubric-multiview-v3-25views\"):\n            kubric_kwargs = {\n                \"data_root\": os.path.join(cfg.datasets.root, \"kubric_multiview_003\", \"kubric_25_view\"),\n                \"seq_len\": 24,\n                \"traj_per_sample\": 200,\n                \"seed\": 72,\n                \"sample_vis_1st_frame\": True,\n                \"tune_per_scene\": False,\n                \"max_videos\": 30,\n                \"use_duster_depths\": False,\n                \"duster_views\": None,\n                \"clean_duster_depths\": False,\n                \"views_to_return\": list(range(20)),\n                \"novel_views\": list(range(20, 25)),\n                \"num_views\": -1,\n                \"depth_noise_std\": 0,\n            }\n            eval_dataset = KubricMultiViewDataset(**kubric_kwargs)\n        elif dataset_name.startswith(\"kubric-multiview-v3\"):\n            eval_dataset = KubricMultiViewDataset.from_name(dataset_name, cfg.datasets.root, cfg)\n        elif dataset_name.startswith(\"panoptic-multiview\"):\n            eval_dataset = PanopticStudioMultiViewDataset.from_name(dataset_name, cfg.datasets.root)\n      
  elif dataset_name.startswith(\"dex-ycb-multiview\"):\n            eval_dataset = DexYCBMultiViewDataset.from_name(dataset_name, cfg.datasets.root)\n        elif dataset_name == \"egoexo4d\":\n            eval_dataset = GenericSceneDataset(\n                dataset_dir=\"datasets/egoexo4d-processed/maxframes-300_downsample-1_downscale-512/\",\n                drop_first_n_frames=44,\n            )\n        elif dataset_name == \"4d-dress\":\n            eval_dataset = GenericSceneDataset(\n                dataset_dir=\"datasets/4d-dress-processed-resized-512-selection\",\n                use_duster_depths=False,\n            )\n        elif dataset_name == \"hi4d\":\n            eval_dataset = GenericSceneDataset(\n                dataset_dir=\"datasets/hi4d-processed-resized-512\",\n                use_duster_depths=False,\n                use_vggt_depths_with_aligned_cameras=True,\n            )\n        elif dataset_name == \"selfcap-v1\":\n            eval_dataset = GenericSceneDataset(\n                dataset_dir=\"datasets/selfcap-processed/numcams-8-seq-False_startframe-90_maxframes-256_downsample-10_downscale-512/\",\n                drop_first_n_frames=72,\n            )\n        elif dataset_name == \"selfcap-v2\":\n            eval_dataset = GenericSceneDataset(\n                dataset_dir=\"datasets/selfcap-processed/numcams-8-seq-True_startframe-90_maxframes-256_downsample-10_downscale-512/\",\n                drop_first_n_frames=72,\n            )\n        elif dataset_name == \"selfcap-v3\":\n            eval_dataset = GenericSceneDataset(\n                dataset_dir=\"datasets/selfcap-processed/numcams-8-seq-False_startframe-90_maxframes-256_downsample-20_downscale-512/\",\n                drop_first_n_frames=36,\n            )\n        elif dataset_name == \"selfcap-v4\":\n            eval_dataset = GenericSceneDataset(\n                dataset_dir=\"datasets/selfcap-processed/numcams-8-seq-False_startframe-90_maxframes-256_downsample-30_downscale-512/\",\n                drop_first_n_frames=24,\n            )\n        elif dataset_name == \"selfcap-v5\":\n            eval_dataset = GenericSceneDataset(\n                dataset_dir=\"datasets/selfcap-processed/numcams-8-seq-False_startframe-90_maxframes-256_downsample-5_downscale-512/\",\n                drop_first_n_frames=144,\n            )\n        elif dataset_name == \"selfcap-v6\":\n            eval_dataset = GenericSceneDataset(\n                dataset_dir=\"datasets/selfcap-processed/numcams-8-seq-False_startframe-90_maxframes-2560_downsample-10_downscale-512/\",\n                drop_first_n_frames=44,\n            )\n        elif dataset_name == \"selfcap-v7\":\n            eval_dataset = GenericSceneDataset(\n                dataset_dir=\"datasets/selfcap-processed/numcams-4-seq-False_startframe-90_maxframes-256_downsample-10_downscale-512/\",\n                drop_first_n_frames=72,\n            )\n        else:\n            raise ValueError(f\"Dataset {dataset_name} not supported for evaluation.\")\n        eval_dataloader = torch.utils.data.DataLoader(\n            eval_dataset,\n            batch_size=1,\n            shuffle=False,\n            num_workers=cfg.datasets.eval.num_workers,\n            collate_fn=collate_fn,\n        )\n        eval_dataloaders.append((dataset_name, eval_dataloader))\n\n    # # Let each rank handle a subset of the evaluation dataloaders\n    # eval_dataloaders_for_rank = []\n    # for idx, (dset_name, dset_loader) in enumerate(eval_dataloaders):\n    #     if (idx % 
fabric.world_size) == fabric.global_rank:\n    #         eval_dataloaders_for_rank.append((dset_name, fabric.setup_dataloaders(dset_loader)))\n    # eval_dataloaders = eval_dataloaders_for_rank\n\n    train_viz_save_dir = os.path.join(cfg.experiment_path, f\"train_{cfg.datasets.train.name}\")\n    os.makedirs(train_viz_save_dir, exist_ok=True)\n    visualizer = MultiViewVisualizer(\n        save_dir=train_viz_save_dir,\n        pad_value=16,\n        fps=12,\n        show_first_frame=0,\n        tracks_leave_trace=0,\n    )\n\n    evaluator = hydra.utils.instantiate(cfg.evaluation.evaluator)\n\n    if cfg.modes.do_initial_static_pretrain and not cfg.modes.eval_only:\n        pretraining_datasets = [\n            kubric_multiview_dataset.KubricMultiViewDataset(\n                data_root=os.path.join(cfg.datasets.root, \"kubric_multiview_003\", \"train\"),\n                traj_per_sample=cfg.datasets.train.traj_per_sample,\n                ratio_dynamic=0.1,\n                ratio_very_dynamic=0.0,\n                num_views=4,\n                enable_cropping_augs=cfg.augmentations.cropping,\n\n                seq_len=seq_len,\n                static_cropping=static_cropping,\n                max_videos=max_videos,\n            )\n            for seq_len, static_cropping, max_videos in [\n                (12, True, 500),\n                (18, True, 500),\n                (24, True, 1000),\n                (24, False, 2000),\n            ]\n        ]\n        pretraining_dataset = torch.utils.data.ConcatDataset(pretraining_datasets)\n        pretraining_dataloader = StatefulDataLoader(\n            pretraining_dataset,\n            batch_size=cfg.datasets.train.batch_size,\n            shuffle=False,\n            num_workers=cfg.datasets.train.num_workers,\n            pin_memory=True,\n            pin_memory_device=\"cuda\",\n            collate_fn=collate_fn,\n            drop_last=True,\n            in_order=cfg.reproducibility.deterministic,\n        )\n        pretraining_dataloader = fabric.setup_dataloaders(pretraining_dataloader)\n    else:\n        pretraining_dataloader = None\n\n    if cfg.modes.eval_only:\n        train_dataset = None\n    elif cfg.datasets.train.name.startswith(\"kubric-multiview-v3\"):\n        train_dataset = KubricMultiViewDataset.from_name(cfg.datasets.train.name, cfg.datasets.root, cfg, fabric)\n    else:\n        raise ValueError(f\"Dataset {cfg.datasets.train.name} not supported for training\")\n\n    if not cfg.modes.eval_only:\n        train_loader = StatefulDataLoader(\n            train_dataset,\n            batch_size=cfg.datasets.train.batch_size,\n            shuffle=True,\n            num_workers=cfg.datasets.train.num_workers,\n            pin_memory=True,\n            collate_fn=collate_fn,\n            drop_last=True,\n            prefetch_factor=4 if cfg.datasets.train.num_workers > 0 else None,\n            in_order=cfg.reproducibility.deterministic,\n        )\n        # eval_dataloaders += [(\"kubric-multiview-v3-training\", train_loader)]\n        train_loader = fabric.setup_dataloaders(train_loader)\n        logging.info(f\"LEN TRAIN LOADER={len(train_loader)}\")\n        num_epochs = cfg.trainer.num_steps // len(train_loader) + 1 + (1 if cfg.modes.do_initial_static_pretrain else 0)\n        if cfg.modes.do_initial_static_pretrain:\n            cfg.trainer.num_steps += len(pretraining_dataloader)\n    else:\n        train_loader = None\n        num_epochs = None\n\n    epoch = -1\n    total_steps = 0\n\n    model: nn.Module = 
hydra.utils.instantiate(cfg.model)\n    model.cuda()\n    optimizer, scheduler = fetch_optimizer(cfg.trainer, model)\n    model, optimizer = fabric.setup(model, optimizer)\n\n    folder_ckpts = [\n        f\n        for f in os.listdir(cfg.experiment_path)\n        if f.endswith(\".pth\")\n           and not os.path.isdir(f)\n           and not \"final\" in f\n           and not \"unwrap_model\" in f\n           and not \"unwrap_module\" in f\n    ]\n    logging.info(f\"Found {len(folder_ckpts)} checkpoints: {folder_ckpts}\")\n    if len(folder_ckpts) > 0:\n        # We can load this checkpoint directly since we have saved it during training\n        ckpt_name = sorted(folder_ckpts)[-1]\n        experiment_path = os.path.join(cfg.experiment_path, ckpt_name)\n        state = AttributeDict(\n            model=model,\n            optimizer=optimizer,\n            scheduler=scheduler,\n            total_steps=total_steps,\n        )\n        logging.info(f\"Total steps before loading checkpoint: {total_steps}\")\n        fabric.load(experiment_path, state)\n        total_steps = state.total_steps  # Integers are immutable, so they cannot be changed inplace\n        if train_loader is not None:\n            epoch = total_steps // len(train_loader) - 1\n        logging.info(f\"Loaded checkpoint {experiment_path} (total_steps={total_steps})\")\n        logging.info(f\"Total steps after loading checkpoint: {total_steps}\")\n\n    elif cfg.restore_ckpt_path is not None:\n        restore_ckpt_path = cfg.restore_ckpt_path\n        assert restore_ckpt_path.endswith(\".pth\")\n        logging.info(f\"Restoring pre-trained weights from {os.path.abspath(restore_ckpt_path)}\")\n        training_ckpt = \"total_steps\" in torch.load(restore_ckpt_path)\n        if training_ckpt:\n            # Loading a checkpoint saved by fabric during training\n            logging.info(\"Trying to load as a training checkpoint...\")\n            state = AttributeDict(model=model)\n            try:\n                fabric.load(restore_ckpt_path, state, strict=True)\n            except RuntimeError as e:\n                logging.warning(f\"Failed to load weights from {restore_ckpt_path} with strict=True: {e}. 
\"\n                                f\"Trying again with strict=False.\")\n                fabric.load(restore_ckpt_path, state, strict=False)\n            logging.info(f\"Loaded checkpoint {restore_ckpt_path}\")\n        else:\n            fabric.load_raw(restore_ckpt_path, model)\n\n    tb_writer = SummaryWriter(log_dir=os.path.join(cfg.experiment_path, f\"runs_{fabric.global_rank}\"))\n    if cfg.modes.eval_only or cfg.modes.validate_at_start:\n        run_test_eval(cfg, evaluator, model, eval_dataloaders, tb_writer, total_steps - 1)\n        fabric.barrier()\n        if cfg.modes.eval_only:\n            return\n\n    total_durations = deque()\n    dataloader_durations = deque()\n    fwd_durations = deque()\n    sync_durations = deque()\n    bwd_durations = deque()\n    timing_log_freq = 100\n\n    def handle_sigterm(signum, frame):\n        logging.error(f\"Signal {signum} received, saving checkpoint and exiting...\")\n        ckpt_iter = \"0\" * (6 - len(str(total_steps))) + str(total_steps)\n        save_path = Path(f\"{cfg.experiment_path}/model_{ckpt_iter}.pth\")\n        state = AttributeDict(\n            model=model,\n            optimizer=optimizer,\n            scheduler=scheduler,\n            total_steps=total_steps + 1,\n        )\n        fabric.save(save_path, state)\n        logging.info(f\"Saved checkpoint to {save_path}. Waiting for all ranks to finish...\")\n        fabric.barrier()\n        logging.info(f\"Calling sys.exit(0) now.\")\n        sys.exit(0)\n\n    signal.signal(signal.SIGUSR1, handle_sigterm)\n    signal.signal(signal.SIGTERM, handle_sigterm)\n    logging.info(f\"Registered signal handlers for SIGUSR1 and SIGTERM.\")\n\n    model.train()\n    should_keep_training = True if cfg.trainer.num_steps > 0 else False\n    total_batches_loaded = 0\n    total_batches_failed = 0\n    if fabric.global_rank == 0:\n        tqdm_total_steps = tqdm(\n            total=cfg.trainer.num_steps,\n            desc=f\"Total Training Progress (rank={fabric.global_rank})\",\n            unit=\"batch\",\n            initial=total_steps,\n            position=0,\n        )\n    threads = []\n    had_run_pretraining_epoch = cfg.modes.do_initial_static_pretrain and total_steps > len(pretraining_dataloader)\n    logging.info(f\"{total_steps=}, {epoch=}/{num_epochs}, {had_run_pretraining_epoch=}\")\n    while should_keep_training:\n        epoch += 1\n        i_batch = -1\n\n        if cfg.modes.do_initial_static_pretrain and not had_run_pretraining_epoch:\n            had_run_pretraining_epoch = True\n            data_iter = iter(pretraining_dataloader)\n            n_batches = len(pretraining_dataloader)\n        else:\n            data_iter = iter(train_loader)\n            n_batches = len(train_loader)\n        if fabric.global_rank == 0:\n            tqdm_epoch = tqdm(total=n_batches, desc=f\"Epoch {epoch + 1}/{num_epochs}\", unit=\"batch\", position=1)\n\n        while i_batch < n_batches:\n            start_time_1 = time.time()\n            logging.info(f\"Gonna load batch {i_batch + 1}/{n_batches} (rank={fabric.global_rank})\")\n            try:\n                batch = next(data_iter)\n            except StopIteration:\n                data_iter = iter(train_loader)\n                n_batches = len(train_loader)\n                batch = next(data_iter)\n\n            batch, gotit = batch\n            total_batches_loaded += 1\n\n            if cfg.modes.debugging_hotfix_datapoint_path is not None:\n                logging.info(f\"Debugging hotfix: loading batch from 
{cfg.modes.debugging_hotfix_datapoint_path}\")\n                batch = torch.load(cfg.modes.debugging_hotfix_datapoint_path, map_location=\"cuda:0\")\n                logging.info(f\"Debugging hotfix: loaded batch {batch.seq_name} \"\n                             f\"with {len(batch.video)} views and {batch.video.shape[2]} frames\")\n\n            if not all(gotit):\n                total_batches_failed += 1\n                logging.info(f\"batch is None: \"\n                             f\"failed {total_batches_failed} / {total_batches_loaded} \"\n                             f\"({total_batches_failed / total_batches_loaded * 100:.2f}%) batches\")\n                continue\n\n            i_batch += 1\n            dataclass_to_cuda_(batch)\n            assert model.training\n\n            start_time_2 = time.time()\n            dataloader_duration = start_time_2 - start_time_1\n            logging.info(f\"Datapoint: {batch.seq_name} (Waited for {dataloader_duration:>5.2f}s)\")\n\n            train_iters = cfg.trainer.train_iters\n            if cfg.trainer.augment_train_iters:\n                train_iters = augment_train_iters(train_iters, total_steps, cfg.trainer.augment_train_iters_warmup)\n            optimizer.zero_grad()\n\n            try:\n                output = forward_batch_multi_view(\n                    batch=batch,\n                    model=model,\n                    cfg=cfg,\n                    step=total_steps,\n                    train_iters=train_iters,\n                    gamma=cfg.trainer.gamma,\n                    save_debug_logs=(\n                            ((total_steps % cfg.trainer.viz_freq) == (cfg.trainer.viz_freq - 1))\n                            or (total_steps in [0, 10, 100, cfg.trainer.num_steps - 1])\n                    ),\n                    debug_logs_path=os.path.join(\n                        cfg.experiment_path,\n                        f'forward_pass__train_step-{total_steps}_global_rank-{fabric.global_rank}'\n                    ),\n                )\n            except Exception as e:\n                logging.critical(f\"Forward pass crashed at step {total_steps}: {e}\")\n\n                # Save current checkpoint\n                save_path = Path(f\"{cfg.experiment_path}/test_{total_steps:06d}.pth\")\n                state = AttributeDict(\n                    model=model,\n                    optimizer=optimizer,\n                    scheduler=scheduler,\n                    total_steps=total_steps + 1,\n                )\n                fabric._strategy.checkpoint_io.save_checkpoint(\n                    checkpoint=fabric._strategy._convert_stateful_objects_in_state(_unwrap_objects(state), filter={}),\n                    path=save_path,\n                )\n                logging.info(f\"Saved crash checkpoint to {save_path}\")\n\n                # Save the batch\n                batch_path = Path(f\"{cfg.experiment_path}/crash_batch_step_{total_steps:06d}.pt\")\n                try:\n                    torch.save(batch, batch_path)\n                    logging.info(f\"Saved crashing batch to {batch_path}\")\n                except Exception as batch_exc:\n                    logging.error(f\"Failed to save crashing batch as .pt: {batch_exc}\")\n\n                raise  # re-raise to crash the job after saving artifacts\n\n            loss = torch.tensor(0.0).cuda()\n            for k, v in output.items():\n                if k == \"metrics\":\n                    for metric_name, metric_value in v.items():\n                       
 tb_writer.add_scalar(metric_name, metric_value, total_steps)\n                elif \"loss\" in v:\n                    loss += v[\"loss\"]\n                    tb_writer.add_scalar(f\"live_{k}_loss\", v[\"loss\"].item(), total_steps)\n                else:\n                    raise ValueError(f\"Unknown key {k} in output\")\n\n            start_time_3 = time.time()\n            fwd_duration = start_time_3 - start_time_2\n\n            fabric.barrier()\n\n            start_time_4 = time.time()\n            sync_duration = start_time_4 - start_time_3\n\n            fabric.backward(loss)\n            # Log a limited number of grad + optimizer state pairs, also log current learning rate\n            if (total_steps <= 10) or (total_steps % cfg.trainer.viz_freq == 0):\n                log_limit = 5\n                logged = 0\n                prefix = f\"[DEBUG] [RANK={fabric.global_rank:03d}]\"\n                logging.info(f\"{prefix} RNG seed: {torch.initial_seed()}\")\n                logging.info(f\"{prefix} Step={total_steps} – Gradients and Optimizer State\")\n                for name, param in model.named_parameters():\n                    if param.grad is not None and param in optimizer.state:\n                        state = optimizer.state[param]\n                        exp_avg_norm = state['exp_avg'].norm().item() if 'exp_avg' in state else float('nan')\n                        exp_avg_sq_norm = state['exp_avg_sq'].norm().item() if 'exp_avg_sq' in state else float('nan')\n                        grad_norm = param.grad.norm().item()\n                        logging.info(\n                            f\"{prefix} Param: {name:<60s} | \"\n                            f\"grad_norm={grad_norm:>14.9f} | \"\n                            f\"exp_avg_norm={exp_avg_norm:>14.9f} | \"\n                            f\"exp_avg_sq_norm={exp_avg_sq_norm:>14.9f}\"\n                        )\n                        logged += 1\n                        if logged >= log_limit:\n                            break\n                for name, param in model.named_parameters():\n                    if param.grad_fn:\n                        print(f\"{prefix} {name} grad_fn: {param.grad_fn}\")\n                logging.info(f\"{prefix} LR at step {total_steps}: {scheduler.get_last_lr()}\")\n            fabric.clip_gradients(model, optimizer, clip_val=cfg.trainer.grad_clip)\n            optimizer.step()\n            scheduler.step()\n\n            start_time_5 = time.time()\n            bwd_duration = start_time_5 - start_time_4\n\n            if fabric.global_rank == 0:\n                if (total_steps % cfg.trainer.viz_freq == 0) or (\n                        total_steps == cfg.trainer.num_steps - 1) or total_steps in [0, 10, 100]:\n                    logging.info(f\"Creating training viz logs (rank: {fabric.global_rank}, step: {total_steps})\")\n                    video = batch.video.clone().cpu()\n                    video_depth = batch.videodepth.clone().cpu()\n                    gt_viz, vector_colors = visualizer.visualize(\n                        video=video,\n                        video_depth=video_depth,\n                        tracks=batch.trajectory.clone().cpu(),\n                        visibility=batch.visibility.clone().cpu(),\n                        query_frame=batch.query_points_3d[..., 0].long().clone().cpu(),\n                        filename=\"train_gt_traj\",\n                        writer=tb_writer,\n                        step=total_steps,\n                        
save_video=False,\n                    )\n                    pred_viz, _ = visualizer.visualize(\n                        video=video,\n                        video_depth=video_depth,\n                        tracks=output[\"flow\"][\"predictions\"][None].cpu(),\n                        visibility=(output[\"visibility\"][\"predictions\"][None] > 0.5).cpu(),\n                        query_frame=batch.query_points_3d[..., 0].long().clone().cpu(),\n                        filename=\"train_pred_traj\",\n                        writer=tb_writer,\n                        step=total_steps,\n                        save_video=False,\n                    )\n                    viz = torch.cat([gt_viz[..., :gt_viz.shape[-1] // 2], pred_viz], dim=-1)\n                    thread = threading.Thread(\n                        target=Visualizer.save_video,\n                        args=(viz, visualizer.save_dir, f\"train\", tb_writer, visualizer.fps, total_steps)\n                    )\n                    thread.start()\n                    threads.append(thread)\n\n                if len(output) > 1:\n                    tb_writer.add_scalar(f\"live_total_loss\", loss.item(), total_steps)\n                tb_writer.add_scalar(f\"learning_rate\", optimizer.param_groups[0][\"lr\"], total_steps)\n\n            if total_steps % cfg.trainer.save_ckpt_freq == 0:\n                ckpt_iter = \"0\" * (6 - len(str(total_steps))) + str(total_steps)\n                save_path = Path(f\"{cfg.experiment_path}/model_{ckpt_iter}.pth\")\n                logging.info(f\"Saving file {save_path}\")\n                state = AttributeDict(\n                    model=model,\n                    optimizer=optimizer,\n                    scheduler=scheduler,\n                    total_steps=total_steps + 1,\n                )\n                fabric.save(save_path, state)\n\n            if total_steps % cfg.trainer.eval_freq == 0 and total_steps > 1:\n                run_test_eval(cfg, evaluator, model, eval_dataloaders, tb_writer, total_steps)\n                fabric.barrier()\n\n            total_steps += 1\n            if fabric.global_rank == 0:\n                tqdm_epoch.update(1)\n                tqdm_total_steps.update(1)\n                tqdm_epoch.set_postfix(\n                    loss=loss.item(),\n                    lr=optimizer.param_groups[0][\"lr\"],\n                    train_iters=cfg.trainer.train_iters,\n                    gamma=cfg.trainer.gamma,\n                    seq_name=batch.seq_name,\n                )\n\n            total_duration = time.time() - start_time_1\n            logging.info(\n                f\"[timing:{total_steps:06d}] \"\n                f\"Total: {total_duration:>6.2f}s | \"\n                f\"Data: {dataloader_duration:>6.2f}s | \"\n                f\"Fwd: {fwd_duration:>6.2f}s | \"\n                f\"Sync: {sync_duration:>6.2f}s | \"\n                f\"Bwd: {bwd_duration:>6.2f}s | \"\n            )\n            if fabric.global_rank == 0:\n                dataloader_durations.append(dataloader_duration)\n                fwd_durations.append(fwd_duration)\n                sync_durations.append(sync_duration)\n                bwd_durations.append(bwd_duration)\n                total_durations.append(total_duration)\n\n                tb_writer.add_scalar(f\"timing/step\", total_duration, total_steps)\n                tb_writer.add_scalar(f\"timing/only_fwd\", fwd_durations[-1], total_steps)\n                tb_writer.add_scalar(f\"timing/only_sync\", sync_durations[-1], 
total_steps)\n                tb_writer.add_scalar(f\"timing/only_bwd\", bwd_durations[-1], total_steps)\n                tb_writer.add_scalar(f\"timing/only_dataloader\", dataloader_duration, total_steps)\n\n                if len(total_durations) >= timing_log_freq:\n                    total_durations_np = np.array(total_durations)\n                    fwd_durations_np = np.array(fwd_durations)\n                    sync_durations_np = np.array(sync_durations)\n                    bwd_durations_np = np.array(bwd_durations)\n                    dataloader_durations_np = np.array(dataloader_durations)\n\n                    total_duration_mean = np.mean(total_durations_np)\n                    fwd_duration_mean = np.mean(fwd_durations_np)\n                    sync_duration_mean = np.mean(sync_durations_np)\n                    bwd_duration_mean = np.mean(bwd_durations_np)\n                    dataloader_duration_mean = np.mean(dataloader_durations_np)\n\n                    total_duration_median = np.median(total_durations_np)\n                    fwd_duration_median = np.median(fwd_durations_np)\n                    sync_duration_median = np.median(sync_durations_np)\n                    bwd_duration_median = np.median(bwd_durations_np)\n                    dataloader_duration_median = np.median(dataloader_durations_np)\n\n                    total_duration_std = np.std(total_durations_np)\n                    fwd_duration_std = np.std(fwd_durations_np)\n                    sync_duration_std = np.std(sync_durations_np)\n                    bwd_duration_std = np.std(bwd_durations_np)\n                    dataloader_duration_std = np.std(dataloader_durations_np)\n\n                    tb_writer.add_scalar(\"timing/step_mean\", total_duration_mean, total_steps)\n                    tb_writer.add_scalar(\"timing/step_median\", total_duration_median, total_steps)\n                    tb_writer.add_scalar(\"timing/only_fwd_mean\", fwd_duration_mean, total_steps)\n                    tb_writer.add_scalar(\"timing/only_fwd_median\", fwd_duration_median, total_steps)\n                    tb_writer.add_scalar(\"timing/only_sync_mean\", sync_duration_mean, total_steps)\n                    tb_writer.add_scalar(\"timing/only_sync_median\", sync_duration_median, total_steps)\n                    tb_writer.add_scalar(\"timing/only_bwd_mean\", bwd_duration_mean, total_steps)\n                    tb_writer.add_scalar(\"timing/only_bwd_median\", bwd_duration_median, total_steps)\n                    tb_writer.add_scalar(\"timing/only_dataloader_mean\", dataloader_duration_mean, total_steps)\n                    tb_writer.add_scalar(\"timing/only_dataloader_median\", dataloader_duration_median, total_steps)\n\n                    logging.info(\n                        f\"[timing:total] \"\n                        f\"Mean: {total_duration_mean:>6.2f}s | \"\n                        f\"Median: {total_duration_median:>6.2f}s | \"\n                        f\"Std: {total_duration_std:6.2f}s\"\n                    )\n                    logging.info(\n                        f\"[timing:fwd]   \"\n                        f\"Mean: {fwd_duration_mean:>6.2f}s | \"\n                        f\"Median: {fwd_duration_median:>6.2f}s | \"\n                        f\"Std: {fwd_duration_std:6.2f}s\"\n                    )\n                    logging.info(\n                        f\"[timing:sync]  \"\n                        f\"Mean: {sync_duration_mean:>6.2f}s | \"\n                        f\"Median: 
{sync_duration_median:>6.2f}s | \"\n                        f\"Std: {sync_duration_std:6.2f}s\"\n                    )\n                    logging.info(\n                        f\"[timing:bwd]   \"\n                        f\"Mean: {bwd_duration_mean:>6.2f}s | \"\n                        f\"Median: {bwd_duration_median:>6.2f}s | \"\n                        f\"Std: {bwd_duration_std:6.2f}s\"\n                    )\n                    logging.info(\n                        f\"[timing:datal] \"\n                        f\"Mean: {dataloader_duration_mean:>6.2f}s | \"\n                        f\"Median: {dataloader_duration_median:>6.2f}s | \"\n                        f\"Std: {dataloader_duration_std:6.2f}s\"\n                    )\n\n                    total_durations.clear()\n                    fwd_durations.clear()\n                    sync_durations.clear()\n                    bwd_durations.clear()\n                    dataloader_durations.clear()\n\n            if total_steps > cfg.trainer.num_steps:\n                should_keep_training = False\n                break\n\n        if fabric.global_rank == 0:\n            tqdm_epoch.close()\n\n    if fabric.global_rank == 0:\n        tqdm_total_steps.close()\n    logging.info(\"FINISHED TRAINING\")\n\n    save_path = f\"{cfg.experiment_path}/model_final.pth\"\n    logging.info(f\"Saving file {save_path}\")\n    state = AttributeDict(\n        model=model,\n        optimizer=optimizer,\n        scheduler=scheduler,\n        total_steps=total_steps,\n    )\n    fabric.save(save_path, state)\n    run_test_eval(cfg, evaluator, model, eval_dataloaders, tb_writer, total_steps)\n    for thread in threads:\n        thread.join()\n    tb_writer.flush()\n    tb_writer.close()\n    fabric.barrier()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "mvtracker/cli/utils/__init__.py",
    "content": "from .pylogger import RankedLogger\nfrom .rich_utils import enforce_tags, print_config_tree\nfrom .helpers import extras, get_metric_value, task_wrapper\n"
  },
  {
    "path": "mvtracker/cli/utils/helpers.py",
    "content": "import faulthandler\nimport warnings\nfrom functools import wraps\nfrom importlib.util import find_spec\nfrom typing import Any, Callable, Dict, Optional, Tuple\n\nimport wandb\nfrom omegaconf import DictConfig\n\nfrom mvtracker.cli.utils import pylogger, rich_utils\n\nlog = pylogger.RankedLogger(__name__, rank_zero_only=True)\n\n\ndef extras(cfg: DictConfig) -> None:\n    \"\"\"Applies optional utilities before the task is started.\n\n    Utilities:\n        - Ignoring python warnings\n        - Setting tags from command line\n        - Rich config printing\n\n    :param cfg: A DictConfig object containing the config tree.\n    \"\"\"\n    # return if no `extras` config\n    if not cfg.get(\"extras\"):\n        log.warning(\"Extras config not found! <cfg.extras=null>\")\n        return\n\n    # disable python warnings\n    if cfg.extras.get(\"ignore_warnings\"):\n        log.info(\"Disabling python warnings! <cfg.extras.ignore_warnings=True>\")\n        warnings.filterwarnings(\"ignore\")\n\n    # prompt user to input tags from command line if none are provided in the config\n    if cfg.extras.get(\"enforce_tags\"):\n        log.info(\"Enforcing tags! <cfg.extras.enforce_tags=True>\")\n        rich_utils.enforce_tags(cfg, save_to_file=True)\n\n    # pretty print config tree using Rich library\n    if cfg.extras.get(\"print_config\"):\n        log.info(\"Printing config tree with Rich! <cfg.extras.print_config=True>\")\n        rich_utils.print_config_tree(cfg, print_order=None, resolve=True, save_to_file=True)\n\n    if cfg.extras.get(\"enable_faulthandler_traceback\"):\n        log.info(\"Enabling faulthandler timeouts!\")\n        faulthandler.dump_traceback_later(timeout=cfg.extras.faulthandler_traceback_timeout, repeat=True)\n\n\ndef task_wrapper(task_func: Callable) -> Callable:\n    \"\"\"Optional decorator that controls the failure behavior when executing the task function.\n\n    This wrapper can be used to:\n        - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)\n        - save the exception to a `.log` file\n        - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)\n        - etc. 
(adjust depending on your needs)\n\n    Example:\n    ```\n    @utils.task_wrapper\n    def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:\n        ...\n        return metric_dict, object_dict\n    ```\n\n    :param task_func: The task function to be wrapped.\n\n    :return: The wrapped task function.\n    \"\"\"\n\n    def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:\n        # execute the task\n        try:\n            metric_dict, object_dict = task_func(cfg=cfg)\n\n        # things to do if exception occurs\n        except Exception as ex:\n            # save exception to `.log` file\n            log.exception(\"\")\n\n            # some hyperparameter combinations might be invalid or cause out-of-memory errors\n            # so when using hparam search plugins like Optuna, you might want to disable\n            # raising the below exception to avoid multirun failure\n            raise ex\n\n        # things to always do after either success or exception\n        finally:\n            # display output dir path in terminal\n            log.info(f\"Output dir: {cfg.paths.output_dir}\")\n\n            # always close wandb run (even if exception occurs so multirun won't fail)\n            if find_spec(\"wandb\"):  # check if wandb is installed\n                import wandb\n\n                if wandb.run:\n                    log.info(\"Closing wandb!\")\n                    wandb.finish()\n\n        return metric_dict, object_dict\n\n    return wrap\n\n\ndef get_metric_value(metric_dict: Dict[str, Any], metric_name: Optional[str]) -> Optional[float]:\n    \"\"\"Safely retrieves value of the metric logged in LightningModule.\n\n    :param metric_dict: A dict containing metric values.\n    :param metric_name: If provided, the name of the metric to retrieve.\n    :return: If a metric name was provided, the value of the metric.\n    \"\"\"\n    if not metric_name:\n        log.info(\"Metric name is None! Skipping metric value retrieval...\")\n        return None\n\n    if metric_name not in metric_dict:\n        raise Exception(\n            f\"Metric value not found! <metric_name={metric_name}>\\n\"\n            \"Make sure metric name logged in LightningModule is correct!\\n\"\n            \"Make sure `optimized_metric` name in `hparams_search` config is correct!\"\n        )\n\n    metric_value = metric_dict[metric_name].item()\n    log.info(f\"Retrieved metric value! <{metric_name}={metric_value}>\")\n\n    return metric_value\n\n\ndef maybe_close_wandb(fn: Callable) -> Callable:\n    @wraps(fn)\n    def wrapper(cfg, *args, **kwargs):\n        try:\n            return fn(cfg, *args, **kwargs)\n        finally:\n            if wandb.run is not None:\n                wandb.finish()\n\n    return wrapper\n"
  },
  {
    "path": "mvtracker/cli/utils/pylogger.py",
    "content": "import logging\nfrom typing import Mapping, Optional\n\nfrom lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only\n\n\nclass RankedLogger(logging.LoggerAdapter):\n    \"\"\"A multi-GPU-friendly python command line logger.\"\"\"\n\n    def __init__(\n            self,\n            name: str = __name__,\n            rank_zero_only: bool = False,\n            extra: Optional[Mapping[str, object]] = None,\n    ) -> None:\n        \"\"\"Initializes a multi-GPU-friendly python command line logger that logs on all processes\n        with their rank prefixed in the log message.\n\n        :param name: The name of the logger. Default is ``__name__``.\n        :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`.\n        :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.\n        \"\"\"\n        logger = logging.getLogger(name)\n        super().__init__(logger=logger, extra=extra)\n        self.rank_zero_only = rank_zero_only\n\n    def log(self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs) -> None:\n        \"\"\"Delegate a log call to the underlying logger, after prefixing its message with the rank\n        of the process it's being logged from. If `'rank'` is provided, then the log will only\n        occur on that rank/process.\n\n        :param level: The level to log at. Look at `logging.__init__.py` for more information.\n        :param msg: The message to log.\n        :param rank: The rank to log at.\n        :param args: Additional args to pass to the underlying logging function.\n        :param kwargs: Any additional keyword args to pass to the underlying logging function.\n        \"\"\"\n        if self.isEnabledFor(level):\n            msg, kwargs = self.process(msg, kwargs)\n            current_rank = getattr(rank_zero_only, \"rank\", None)\n            if current_rank is None:\n                raise RuntimeError(\"The `rank_zero_only.rank` needs to be set before use\")\n            msg = rank_prefixed_message(msg, current_rank)\n            if self.rank_zero_only:\n                if current_rank == 0:\n                    self.logger.log(level, msg, *args, **kwargs)\n            else:\n                if rank is None:\n                    self.logger.log(level, msg, *args, **kwargs)\n                elif current_rank == rank:\n                    self.logger.log(level, msg, *args, **kwargs)\n"
  },
  {
    "path": "mvtracker/cli/utils/rich_utils.py",
    "content": "from pathlib import Path\nfrom typing import Sequence, Optional\n\nimport rich\nimport rich.syntax\nimport rich.tree\nfrom hydra.core.hydra_config import HydraConfig\nfrom lightning_utilities.core.rank_zero import rank_zero_only\nfrom omegaconf import DictConfig, OmegaConf, open_dict\nfrom rich.prompt import Prompt\n\nfrom mvtracker.cli.utils import pylogger\n\nlog = pylogger.RankedLogger(__name__, rank_zero_only=True)\n\n\n@rank_zero_only\ndef print_config_tree(\n        cfg: DictConfig,\n        print_order: Optional[Sequence[str]] = (\n                \"experiment_paths\",\n                \"model\",\n                \"predictor_settings\",\n        ),\n        resolve: bool = False,\n        save_to_file: bool = False,\n) -> None:\n    \"\"\"Prints the contents of a DictConfig as a tree structure using the Rich library.\n\n    :param cfg: A DictConfig composed by Hydra.\n    :param print_order: Determines in what order config components are printed.\n    :param resolve: Whether to resolve reference fields of DictConfig.\n    :param save_to_file: Whether to export config to the hydra output folder.\n    \"\"\"\n    style = \"italic cyan\"\n    tree = rich.tree.Tree(\"CONFIG\", style=style, guide_style=style)\n\n    queue = []\n\n    # add fields from `print_order` to queue\n    if print_order is not None:\n        for field in print_order:\n            queue.append(field) if field in cfg else log.warning(\n                f\"Field '{field}' not found in config. Skipping '{field}' config printing...\"\n            )\n\n    # add all the other fields to queue (not specified in `print_order`)\n    for field in cfg:\n        if field not in queue:\n            queue.append(field)\n\n    # generate config tree from queue\n    for field in queue:\n        branch = tree.add(field, style=style, guide_style=style)\n\n        config_group = cfg[field]\n        if isinstance(config_group, DictConfig):\n            branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)\n        else:\n            branch_content = str(config_group)\n\n        branch.add(rich.syntax.Syntax(branch_content, \"yaml\"))\n\n    # print config tree\n    rich.print(tree)\n\n    # save config tree to file\n    if save_to_file:\n        with open(Path(HydraConfig.get().runtime.output_dir, \"config_tree.log\"), \"w\") as file:\n            rich.print(tree, file=file)\n\n\n@rank_zero_only\ndef enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:\n    \"\"\"Prompts user to input tags from command line if no tags are provided in config.\n\n    :param cfg: A DictConfig composed by Hydra.\n    :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``.\n    \"\"\"\n    if not cfg.get(\"tags\"):\n        if \"id\" in HydraConfig().cfg.hydra.job:\n            raise ValueError(\"Specify tags before launching a multirun!\")\n\n        log.warning(\"No tags provided in config. Prompting user to input tags...\")\n        tags = Prompt.ask(\"Enter a list of comma separated tags\", default=\"dev\")\n        tags = [t.strip() for t in tags.split(\",\") if t != \"\"]\n\n        with open_dict(cfg):\n            cfg.tags = tags\n\n        log.info(f\"Tags: {cfg.tags}\")\n\n    if save_to_file:\n        with open(Path(cfg.paths.output_dir, \"tags.log\"), \"w\") as file:\n            rich.print(cfg.tags, file=file)\n"
  },
  {
    "path": "mvtracker/datasets/__init__.py",
    "content": "from .dexycb_multiview_dataset import DexYCBMultiViewDataset\nfrom .kubric_multiview_dataset import KubricMultiViewDataset\nfrom .panoptic_studio_multiview_dataset import PanopticStudioMultiViewDataset\nfrom .tap_vid_datasets import TapVidDataset\n"
  },
  {
    "path": "mvtracker/datasets/dexycb_multiview_dataset.py",
    "content": "import logging\nimport os\nimport pathlib\nimport re\nimport time\nimport warnings\n\nimport cv2\nimport matplotlib\nimport numpy as np\nimport pandas as pd\nimport torch\nimport torch.nn.functional as F\nfrom scipy.spatial.transform import Rotation as R\nfrom torch.utils.data import Dataset\n\nfrom mvtracker.datasets.utils import Datapoint, transform_scene\n\n\nclass DexYCBMultiViewDataset(Dataset):\n\n    @staticmethod\n    def from_name(dataset_name: str, dataset_root: str):\n        \"\"\"\n        Examples of datasets supported by this factory method:\n        - \"dex-ycb-multiview\",\n        - \"dex-ycb-multiview-single\",\n        - \"dex-ycb-multiview-removehand\",\n        - \"dex-ycb-multiview-duster0123\",\n        - \"dex-ycb-multiview-duster0123cleaned\",\n        - \"dex-ycb-multiview-duster0123cleaned-views0123\",\n        - \"dex-ycb-multiview-duster0123cleaned-views0123-novelviews45\",\n        - \"dex-ycb-multiview-duster0123cleaned-views0123-novelviews45-removehand\",\n        - \"dex-ycb-multiview-duster0123cleaned-views0123-novelviews45-removehand-single\",\n        - \"dex-ycb-multiview-duster0123cleaned-views0123-novelviews45-removehand-2dpt-single\",\n        - \"dex-ycb-multiview-duster0123cleaned-views0123-novelviews45-removehand-2dpt-single-cached\",\n        \"\"\"\n        # Parse the dataset name, chunk by chunk\n        non_parsed = dataset_name.replace(\"dex-ycb-multiview\", \"\", 1)\n\n        if non_parsed.startswith(\"-duster\"):\n            match = re.match(r\"-duster(\\d+)(cleaned)?\", non_parsed)\n            assert match is not None\n            duster_views = list(map(int, match.group(1)))\n            use_duster = True\n            use_duster_cleaned = match.group(2) is not None\n            non_parsed = non_parsed.replace(match.group(0), \"\", 1)\n        else:\n            use_duster = False\n            use_duster_cleaned = False\n            duster_views = None\n\n        if non_parsed.startswith(\"-views\"):\n            match = re.match(r\"-views(\\d+)\", non_parsed)\n            assert match is not None\n            views = list(map(int, match.group(1)))\n            if duster_views is not None:\n                assert all(v in duster_views for v in views)\n            non_parsed = non_parsed.replace(match.group(0), \"\", 1)\n        else:\n            views = duster_views\n\n        if non_parsed.startswith(\"-novelviews\"):\n            match = re.match(r\"-novelviews(\\d+)\", non_parsed)\n            assert match is not None\n            novel_views = list(map(int, match.group(1)))\n            non_parsed = non_parsed.replace(match.group(0), \"\", 1)\n        else:\n            novel_views = None\n\n        if non_parsed.startswith(\"-removehand\"):\n            remove_hand = True\n            non_parsed = non_parsed.replace(\"-removehand\", \"\", 1)\n        else:\n            remove_hand = False\n\n        if non_parsed.startswith(\"-single\"):\n            single_point = True\n            non_parsed = non_parsed.replace(\"-single\", \"\", 1)\n        else:\n            single_point = False\n\n        if non_parsed.startswith(\"-2dpt\"):\n            eval_2dpt = True\n            non_parsed = non_parsed.replace(\"-2dpt\", \"\", 1)\n        else:\n            eval_2dpt = False\n\n        if non_parsed.startswith(\"-cached\"):\n            use_cached_tracks = True\n            non_parsed = non_parsed.replace(\"-cached\", \"\", 1)\n        else:\n            use_cached_tracks = False\n\n        assert non_parsed == \"\", 
f\"Unparsed part of the dataset name: {non_parsed}\"\n\n        if views is None and duster_views is None:\n            views = [0, 1, 2, 3]  # Make the legacy \"dex-ycb-multiview\" name take the first 4 views (not all 8)\n\n        return DexYCBMultiViewDataset(\n            data_root=os.path.join(dataset_root, \"dex-ycb-multiview\"),\n            views_to_return=views,\n            novel_views=novel_views,\n            remove_hand=remove_hand,\n            use_duster_depths=use_duster,\n            duster_views=duster_views,\n            clean_duster_depths=use_duster_cleaned,\n            traj_per_sample=384,\n            seed=72,\n            max_videos=10,\n            perform_sanity_checks=False,\n            use_cached_tracks=use_cached_tracks,\n        )\n\n    def __init__(\n            self,\n            data_root,\n            remove_hand=False,\n            views_to_return=None,\n            novel_views=None,\n            use_duster_depths=False,\n            clean_duster_depths=False,\n            duster_views=None,\n            traj_per_sample=768,\n            seed=None,\n            max_videos=None,\n            perform_sanity_checks=False,\n            use_cached_tracks=False,\n    ):\n        super().__init__()\n        self.data_root = data_root\n        self.remove_hand = remove_hand\n        self.views_to_return = views_to_return\n        self.novel_views = novel_views\n        self.use_duster_depths = use_duster_depths\n        self.clean_duster_depths = clean_duster_depths\n        self.duster_views = duster_views\n        self.traj_per_sample = traj_per_sample\n        self.seed = seed\n        self.perform_sanity_checks = perform_sanity_checks\n        self.use_cached_tracks = use_cached_tracks\n        self.cache_name = self._cache_key()\n        self.seq_names = self._get_sequence_names(max_videos)\n        self.getitem_calls = 0\n\n    def _get_sequence_names(self, max_videos):\n        \"\"\"\n        Fetch all valid sequence names from the dataset root.\n\n        Args:\n            max_videos (int): Limit the number of sequences to load.\n\n        Returns:\n            List[str]: Sorted list of valid sequence names.\n        \"\"\"\n        seq_names = [\n            fname\n            for fname in os.listdir(self.data_root)\n            if os.path.isdir(os.path.join(self.data_root, fname))\n               and not fname.startswith(\".\")\n               and not fname.startswith(\"_\")\n        ]\n        seq_names = sorted(seq_names)\n        valid_seqs = []\n\n        for seq_name in seq_names:\n            scene_path = os.path.join(self.data_root, seq_name)\n            view_folders = [\n                d for d in os.listdir(scene_path)\n                if os.path.isdir(os.path.join(scene_path, d)) and d.startswith(\"view_\")\n            ]\n            if not view_folders:\n                warnings.warn(f\"Skipping {scene_path} because it has no views.\")\n                continue\n\n            valid_seqs.append(seq_name)\n\n        if max_videos is not None:\n            valid_seqs = valid_seqs[:max_videos]\n\n        print(f\"Using {len(valid_seqs)} videos from {self.data_root}\")\n        return valid_seqs\n\n    def _cache_key(self):\n        name = f\"cachedtracks--seed{self.seed}\"\n        if self.views_to_return is not None:\n            name += f\"-views{'_'.join(map(str, self.views_to_return))}\"\n        if self.traj_per_sample is not None:\n            name += f\"-n{self.traj_per_sample}\"\n        if self.remove_hand:\n            name += 
\"-removehand\"\n        return name + \"--v1\"  # bump this if you change the selection policy\n\n    def __len__(self):\n        return len(self.seq_names)\n\n    def __getitem__(self, index):\n        start_time = time.time()\n        sample = self._getitem_helper(index)\n\n        self.getitem_calls += 1\n        if self.getitem_calls < 10:\n            print(f\"Loading {index:>06d} took  {time.time() - start_time:.3f} sec. Getitem calls: {self.getitem_calls}\")\n\n        return sample, True\n\n    def _getitem_helper(self, index):\n        \"\"\"\n        Helper function to load a single sample.\n\n        Args:\n            index (int): Index of the sample to load.\n\n        Returns:\n            CoTrackerData, bool: Sample data and success flag.\n        \"\"\"\n        if self.seed is None:\n            seed = torch.randint(0, 2 ** 32 - 1, (1,)).item()\n        else:\n            seed = self.seed\n        rnd_torch = torch.Generator().manual_seed(seed)\n        rnd_np = np.random.RandomState(seed=seed)\n\n        datapoint_path = os.path.join(self.data_root, self.seq_names[index])\n\n        views = {}\n        view_folders = sorted([f for f in os.listdir(datapoint_path) if f.startswith(\"view_\")])\n        if self.views_to_return is not None:\n            views_to_return = self.views_to_return\n        else:\n            views_to_return = sorted(list(range(len(view_folders))))\n        views_to_load = views_to_return.copy()\n        if self.novel_views is not None:\n            views_to_load = list(set(views_to_load + self.novel_views))\n        for v in views_to_load:\n            view_path = os.path.join(datapoint_path, view_folders[v])\n\n            # Load RGB images\n            rgb_folder = os.path.join(view_path, \"rgb\")\n            rgb_files = sorted(os.listdir(rgb_folder))\n            rgb_images = [cv2.imread(os.path.join(rgb_folder, f))[:, :, ::-1] for f in rgb_files]\n\n            # Load depth maps\n            depth_folder = os.path.join(view_path, \"depth\")\n            depth_files = sorted(os.listdir(depth_folder))\n            depth_images = [cv2.imread(os.path.join(depth_folder, f), cv2.IMREAD_ANYDEPTH) for f in depth_files]\n\n            # Load camera parameters\n            camera_params_file = os.path.join(view_path, \"intrinsics_extrinsics.npz\")\n            params = np.load(camera_params_file)\n            intrinsics = params[\"intrinsics\"][:3, :3]  # Extract K\n            extrinsics = params[\"extrinsics\"][:3, :]  # Extract R|t (world to camera)\n\n            views[v] = {\n                \"rgb\": np.stack(rgb_images),\n                \"depth\": np.stack(depth_images),\n                \"intrinsics\": intrinsics,\n                \"extrinsics\": extrinsics,\n            }\n\n        rgbs = np.stack([views[v][\"rgb\"] for v in views_to_return])\n        n_views, n_frames, h, w, _ = rgbs.shape\n        depths = np.stack([views[v][\"depth\"] for v in views_to_return])[..., None].astype(np.float32) / 1000\n        intrs = np.stack([views[v][\"intrinsics\"] for v in views_to_return])[:, None, :, :].repeat(n_frames, axis=1)\n        extrs = np.stack([views[v][\"extrinsics\"] for v in views_to_return])[:, None, :, :].repeat(n_frames, axis=1)\n\n        # Load novel views if they exist\n        novel_rgbs = None\n        novel_intrs = None\n        novel_extrs = None\n        if self.novel_views is not None:\n            novel_rgbs = np.stack([views[v][\"rgb\"] for v in self.novel_views])\n            novel_intrs = np.stack([views[v][\"intrinsics\"] 
for v in self.novel_views])[:, None, :, :].repeat(n_frames,\n                                                                                                             axis=1)\n            novel_extrs = np.stack([views[v][\"extrinsics\"] for v in self.novel_views])[:, None, :, :].repeat(n_frames,\n                                                                                                             axis=1)\n\n        # Load Duster's features and estimated depths if they exist\n        duster_views = self.duster_views if self.duster_views is not None else views_to_return\n        duster_views_str = ''.join(str(v) for v in duster_views)\n        duster_root = pathlib.Path(datapoint_path) / f'duster-views-{duster_views_str}'\n        if self.use_duster_depths:\n            assert duster_root.exists() and (duster_root / f\"3d_model__{n_frames - 1:05d}__scene.npz\").exists(), \\\n                f\"Duster root {duster_root} does not exist.\"\n\n        feats = None\n        feat_dim = None\n        feat_stride = None\n        depth_confs = None\n        if duster_root.exists() and (duster_root / f\"3d_model__{n_frames - 1:05d}__scene.npz\").exists():\n            duster_depths = []\n            duster_confs = []\n            duster_feats = []\n            for frame_idx in range(n_frames):\n                scene = np.load(duster_root / f\"3d_model__{frame_idx:05d}__scene.npz\")\n                duster_depth = torch.from_numpy(scene[\"depths\"])\n                duster_conf = torch.from_numpy(scene[\"confs\"])\n                duster_msk = torch.from_numpy(scene[\"cleaned_mask\"])\n\n                if self.clean_duster_depths:\n                    duster_depth = duster_depth * duster_msk\n\n                duster_depth = F.interpolate(duster_depth[:, None], (h, w), mode='nearest')\n                duster_depths.append(duster_depth[:, 0, :, :, None])\n\n                duster_conf = F.interpolate(duster_conf[:, None], (h, w), mode='nearest')\n                duster_confs.append(duster_conf[:, 0, :, :, None])\n\n                if \"feats\" in scene:\n                    duster_feats.append(torch.from_numpy(scene[\"feats\"]))\n\n            duster_depths = torch.stack(duster_depths, dim=1).numpy()\n            duster_confs = torch.stack(duster_confs, dim=1).numpy()\n            if duster_feats:\n                feats = torch.stack(duster_feats, dim=1).numpy()\n\n            # Extract the correct views\n            assert duster_depths.shape[0] == len(duster_views)\n            duster_depths = duster_depths[[duster_views.index(v) for v in views_to_return]]\n            duster_confs = duster_confs[[duster_views.index(v) for v in views_to_return]]\n            if feats is not None:\n                assert feats.shape[0] == len(duster_views)\n                feats = feats[[duster_views.index(v) for v in views_to_return]]\n\n            # Reshape the features\n            if feats is not None:\n                assert feats.ndim == 4\n                assert feats.shape[0] == n_views\n                assert feats.shape[1] == n_frames\n                feat_stride = np.round(np.sqrt(h * w / feats.shape[2])).astype(int)\n                feat_dim = feats.shape[3]\n                feats = feats.reshape(n_views, n_frames, h // feat_stride, w // feat_stride, feat_dim)\n\n            # Replace the depths with the Duster depths, if configured so\n            if self.use_duster_depths:\n                depths = duster_depths\n                depth_confs = duster_confs\n\n        tracks_3d_file = 
os.path.join(datapoint_path, \"tracks_3d.npz\")\n        tracks_3d_data = np.load(tracks_3d_file, allow_pickle=True)\n        traj3d_world = tracks_3d_data[\"tracks_3d\"]\n        traj2d = tracks_3d_data[\"tracks_2d\"][views_to_return]\n        traj2d_w_z = np.concatenate((traj2d, tracks_3d_data[\"tracks_2d_z\"][views_to_return][:, :, :, None]), axis=-1)\n        visibility = tracks_3d_data[\"tracks_2d_visibilities\"][views_to_return]\n\n        # Label the trajectories according to: 0: hand, 1: moving ycb object, 2: static ycb objects\n        object_id_to_name = tracks_3d_data[\"object_id_to_name\"].item()\n        traj_object_id = tracks_3d_data[\"object_ids\"]\n        for object_name in object_id_to_name.values():\n            assert object_name == \"mano-right-hand\" or object_name.startswith(\"ycb\")\n        avg_movement_per_object_id = {}\n        for object_id in np.unique(traj_object_id):\n            object_mask = traj_object_id == object_id\n            object_traj = traj3d_world[:, object_mask]\n            avg_movement_per_object_id[object_id] = np.linalg.norm(object_traj[1:] - object_traj[:-1], axis=-1).mean()\n        hand_id = {v: k for k, v in object_id_to_name.items()}[\"mano-right-hand\"]\n        dynamic_ycb_object_ids = [k for k, v in avg_movement_per_object_id.items() if v >= 1e-4 and k != hand_id]\n        assert len(dynamic_ycb_object_ids) == 1\n        dynamic_ycb_object_id = dynamic_ycb_object_ids[0]\n        static_ycb_object_ids = [k for k, v in avg_movement_per_object_id.items() if v < 1e-4 and k != hand_id]\n        assert 1 + 1 + len(static_ycb_object_ids) == len(object_id_to_name)\n        # remap object ids to 0: hand, 1: dynamic ycb object, 2: static ycb objects\n        traj_object_id = (\n                0 * (traj_object_id == hand_id) +\n                1 * (traj_object_id == dynamic_ycb_object_id) +\n                2 * np.isin(traj_object_id, static_ycb_object_ids)\n        )\n\n        if self.remove_hand:\n            traj3d_world = traj3d_world[:, traj_object_id > 0]\n            traj2d = traj2d[:, :, traj_object_id > 0]\n            traj2d_w_z = traj2d_w_z[:, :, traj_object_id > 0]\n            visibility = visibility[:, :, traj_object_id > 0]\n            traj_object_id = traj_object_id[traj_object_id > 0]\n\n        n_tracks = traj3d_world.shape[1]\n        assert rgbs.shape == (n_views, n_frames, h, w, 3)\n        assert depths.shape == (n_views, n_frames, h, w, 1)\n        assert depth_confs is None or depth_confs.shape == (n_views, n_frames, h, w, 1)\n        assert feats is None or feats.shape == (n_views, n_frames, h // feat_stride, w // feat_stride, feat_dim)\n        assert intrs.shape == (n_views, n_frames, 3, 3)\n        assert extrs.shape == (n_views, n_frames, 3, 4)\n        assert traj2d.shape == (n_views, n_frames, n_tracks, 2)\n        assert visibility.shape == (n_views, n_frames, n_tracks)\n        assert traj3d_world.shape == (n_frames, n_tracks, 3)\n        assert traj_object_id.shape == (n_tracks,)\n\n        if novel_rgbs is not None:\n            assert novel_rgbs.shape == (len(self.novel_views), n_frames, h, w, 3)\n            assert novel_intrs.shape == (len(self.novel_views), n_frames, 3, 3)\n            assert novel_extrs.shape == (len(self.novel_views), n_frames, 3, 4)\n\n        # Make sure our intrinsics and extrinsics work correctly\n        point_3d_world = traj3d_world\n        point_4d_world_homo = np.concatenate([point_3d_world, np.ones_like(point_3d_world[..., :1])], axis=-1)\n        point_3d_camera = 
np.einsum('ABij,BCj->ABCi', extrs, point_4d_world_homo)\n        if self.perform_sanity_checks:\n            point_2d_pixel_homo = np.einsum('ABij,ABCj->ABCi', intrs, point_3d_camera)\n            point_2d_pixel = point_2d_pixel_homo[..., :2] / point_2d_pixel_homo[..., 2:]\n            point_2d_pixel_gt = traj2d\n\n            point_2d_pixel_no_nan = np.nan_to_num(point_2d_pixel, nan=0)\n            point_2d_pixel_gt_no_nan = np.nan_to_num(point_2d_pixel_gt, nan=0)\n\n            assert np.allclose(point_2d_pixel_no_nan[0, :, 0, :], point_2d_pixel_no_nan[0, :, 0, :],\n                               atol=1), f\"Proj. failed\"\n            assert np.allclose(point_2d_pixel_gt_no_nan, point_2d_pixel_gt_no_nan, atol=1), f\"Point projection failed\"\n\n            assert np.allclose(point_3d_camera[..., 2:], traj2d_w_z[..., -1:], atol=1)\n\n        # Convert everything to torch tensors\n        rgbs = torch.from_numpy(rgbs).permute(0, 1, 4, 2, 3).float()\n        depths = torch.from_numpy(depths).permute(0, 1, 4, 2, 3).float()\n        depth_confs = torch.from_numpy(depth_confs).permute(0, 1, 4, 2, 3).float() if depth_confs is not None else None\n        feats = torch.from_numpy(feats).permute(0, 1, 4, 2, 3).float() if feats is not None else None\n        intrs = torch.from_numpy(intrs).float()\n        extrs = torch.from_numpy(extrs).float()\n        traj2d = torch.from_numpy(traj2d)\n        traj2d_w_z = torch.from_numpy(traj2d_w_z)\n        traj3d_world = torch.from_numpy(traj3d_world)\n        traj_object_id = torch.from_numpy(traj_object_id)\n        visibility = torch.from_numpy(visibility)\n        if novel_rgbs is not None:\n            novel_rgbs = torch.from_numpy(novel_rgbs).permute(0, 1, 4, 2, 3).float()\n            novel_intrs = torch.from_numpy(novel_intrs).float()\n            novel_extrs = torch.from_numpy(novel_extrs).float()\n\n        # Track selection\n        cache_root = os.path.join(self.data_root, self.seq_names[index], \"cache\")\n        os.makedirs(cache_root, exist_ok=True)\n        cache_file = os.path.join(cache_root, f\"{self.cache_name}.npz\")\n\n        # Check if we can use cached tracks\n        use_cache = bool(self.use_cached_tracks) and os.path.isfile(cache_file)\n        if use_cache:\n            cache = np.load(cache_file)\n            inds_sampled = torch.from_numpy(cache[\"track_indices\"])\n            traj2d_w_z = torch.from_numpy(cache[\"traj2d_w_z\"])\n            traj3d_world = torch.from_numpy(cache[\"traj3d_world\"])\n            traj_object_id = torch.from_numpy(cache[\"traj_object_id\"])\n            visibility = torch.from_numpy(cache[\"visibility\"])\n            valids = torch.from_numpy(cache[\"valids\"])\n            query_points = torch.from_numpy(cache[\"query_points\"])\n\n        # Otherwise, sample the tracks and create query points\n        else:\n            # Force query points on hand to appear later\n            # This avoids querying when the GT hand reconstruction is severely lacking\n            # Identify tracks that are invisible in the first frame across all views (as they are probably on the hand)\n            invisible_at_first_frame = visibility[:, 0, :] == 0\n            invisible_at_first_frame = invisible_at_first_frame.unsqueeze(1).expand(-1, 5, -1)\n            # Set visibility to 0 for the first 5 frames where the first frame was invisible\n            visibility[:, 0:5, :] *= ~invisible_at_first_frame  # Keep visible ones, set others to 0\n\n            # Sample the points to track\n            
visible_for_at_least_two_frames = visibility.any(0).sum(0) >= 2\n            hectic_visibility = ((visibility[:, :-1] & ~visibility[:, 1:]).sum(0) >= 3).any(0)\n            valid_tracks = visible_for_at_least_two_frames & ~hectic_visibility\n            valid_tracks = valid_tracks.nonzero(as_tuple=False)[:, 0]\n\n            point_inds = torch.randperm(len(valid_tracks), generator=rnd_torch)\n            traj_per_sample = self.traj_per_sample if self.traj_per_sample is not None else len(point_inds)\n            assert len(point_inds) >= traj_per_sample\n            point_inds = point_inds[:traj_per_sample]\n            inds_sampled = valid_tracks[point_inds]\n\n            n_tracks = len(inds_sampled)\n            traj2d = traj2d[:, :, inds_sampled].float()\n            traj2d_w_z = traj2d_w_z[:, :, inds_sampled].float()\n            traj3d_world = traj3d_world[:, inds_sampled].float()\n            traj_object_id = traj_object_id[inds_sampled]\n            visibility = visibility[:, :, inds_sampled]\n\n            valids = ~torch.isnan(traj2d).any(dim=-1).any(dim=0)\n\n            # Create the query points\n            gt_visibilities_any_view = visibility.any(dim=0)\n            assert (gt_visibilities_any_view.sum(dim=0) >= 2).all(), \"All points should be visible in least two frames.\"\n            last_visible_index = (torch.arange(n_frames).unsqueeze(-1) * gt_visibilities_any_view).max(0).values\n            assert gt_visibilities_any_view[last_visible_index[None, :], torch.arange(n_tracks)].all()\n            gt_visibilities_any_view[last_visible_index[None, :], torch.arange(n_tracks)] = False\n            assert (gt_visibilities_any_view.sum(dim=0) >= 1).all()\n\n            n_non_first_point_appearance_queries = n_tracks // 4\n            n_first_point_appearance_queries = n_tracks - n_non_first_point_appearance_queries\n\n            first_point_appearances = torch.argmax(\n                gt_visibilities_any_view[..., -n_first_point_appearance_queries:].float(), dim=0)\n            non_first_point_appearances = first_point_appearances.new_zeros((n_non_first_point_appearance_queries,))\n            for track_idx in range(n_tracks)[:n_non_first_point_appearance_queries]:\n                # Randomly take a timestep where the point is visible\n                non_zero_timesteps = torch.nonzero(gt_visibilities_any_view[:, track_idx] == 1)\n                random_timestep = non_zero_timesteps[rnd_np.randint(len(non_zero_timesteps))].item()\n                non_first_point_appearances[track_idx] = random_timestep\n\n            query_points_t = torch.cat([non_first_point_appearances, first_point_appearances], dim=0)\n            query_points_xyz_worldspace = traj3d_world[query_points_t, torch.arange(n_tracks)]\n            query_points = torch.cat([query_points_t[:, None], query_points_xyz_worldspace], dim=1)\n            assert gt_visibilities_any_view[query_points_t, torch.arange(n_tracks)].all()\n\n            # Replace nans with zeros\n            traj2d[torch.isnan(traj2d)] = 0\n            traj2d_w_z[torch.isnan(traj2d_w_z)] = 0\n            traj3d_world[torch.isnan(traj3d_world)] = 0\n            assert torch.isnan(visibility).sum() == 0\n\n            # Cache the selected tracks and query points\n            if self.use_cached_tracks:\n                logging.warn(f\"Caching tracks for {self.seq_names[index]} at {os.path.abspath(cache_file)}\")\n                np.savez_compressed(\n                    cache_file,\n                    track_indices=inds_sampled.numpy(),\n        
            traj2d_w_z=traj2d_w_z.numpy(),\n                    traj3d_world=traj3d_world.numpy(),\n                    traj_object_id=traj_object_id.numpy(),\n                    visibility=visibility.numpy(),\n                    valids=valids.numpy(),\n                    query_points=query_points.numpy(),\n                )\n\n        # Normalize the scene to be similar to Kubric's scene\n        scale = 6\n        rot_x = R.from_euler('x', 220, degrees=True).as_matrix()\n        rot_y = R.from_euler('y', 3, degrees=True).as_matrix()\n        rot_z = R.from_euler('z', -30, degrees=True).as_matrix()\n        rot = torch.from_numpy(rot_z @ rot_y @ rot_x)\n        translation = torch.tensor([0.0, 0.0, 0.5], dtype=torch.float32)\n        (\n            depths_trans, extrs_trans, query_points_trans, traj3d_world_trans, traj2d_w_z_trans\n        ) = transform_scene(scale, rot, translation, depths, extrs, query_points, traj3d_world, traj2d_w_z)\n        novel_extrs_trans = transform_scene(scale, rot, translation, None, novel_extrs, None, None, None)[1]\n\n        # rerun_viz_scene(\"nane/scene__no_transform/\", rgbs, depths, intrs, extrs, traj3d_world, 0.1)\n        # rerun_viz_scene(\"nane/scene_transformed/\", rgbs, depths_trans, intrs, extrs_trans, traj3d_world_trans, 1)\n\n        # # Use the auto scene normalization of generic scenes\n        # from mvtracker.datasets.generic_scene_dataset import compute_auto_scene_normalization\n        # scale, rot, translation = compute_auto_scene_normalization(depths, torch.ones_like(depths) * 100, extrs_trans, intrs)\n        # scale = scale * T[0, 0].item()\n        # print(f\"{scale=}\")\n        # (depths_trans, extrs_trans, query_points_trans, traj3d_world_trans, traj2d_w_z_trans\n        # ) = transform_scene(scale, rot, translation, depths_trans, extrs_trans, query_points_trans, traj3d_world_trans, traj2d_w_z_trans)\n        # _, novel_extrs_trans, _, _, _ = transform_scene(scale, rot, translation, None, novel_extrs_trans, None, None, None)\n        # 82.7 91.1 --> 80.8 89.1\n\n        segs = torch.ones((n_frames, 1, h, w))  # Dummy segmentation masks\n        datapoint = Datapoint(\n            video=rgbs,\n            videodepth=depths_trans,\n            videodepthconf=depth_confs.float() if depth_confs is not None else None,\n            feats=feats,\n            segmentation=segs,\n            trajectory=traj2d_w_z_trans,\n            trajectory_3d=traj3d_world_trans,\n            trajectory_category=traj_object_id,\n            visibility=visibility,\n            valid=valids,\n            seq_name=self.seq_names[index],\n            intrs=intrs,\n            extrs=extrs_trans,\n            query_points=None,\n            query_points_3d=query_points_trans,\n            track_upscaling_factor=1 / scale,\n\n            novel_video=novel_rgbs,\n            novel_intrs=novel_intrs,\n            novel_extrs=novel_extrs_trans,\n        )\n        return datapoint\n\n\ndef rerun_viz_scene(entity_prefix, rgbs, depths, intrs, extrs, tracks, radii_scale,\n                    viz_camera=False, viz_point_cloud=True, fps=12):\n    import rerun as rr\n\n    # Initialize Rerun\n    rr.init(f\"3dpt\", recording_id=\"v0.16\")\n    rr.connect_tcp()\n\n    V, T, _, H, W = rgbs.shape\n    _, N, _ = tracks.shape\n    assert rgbs.shape == (V, T, 3, H, W)\n    assert depths.shape == (V, T, 1, H, W)\n    assert intrs.shape == (V, T, 3, 3)\n    assert extrs.shape == (V, T, 3, 4)\n    assert tracks.shape == (T, N, 3)\n\n    # Compute inverse intrinsics and 
extrinsics\n    intrs_inv = torch.inverse(intrs.float()).type(intrs.dtype)\n    extrs_square = torch.eye(4).to(extrs.device).repeat(V, T, 1, 1)\n    extrs_square[:, :, :3, :] = extrs\n    extrs_inv = torch.inverse(extrs_square.float()).type(extrs.dtype)\n    assert intrs_inv.shape == (V, T, 3, 3)\n    assert extrs_inv.shape == (V, T, 4, 4)\n\n    for v in range(V):  # Iterate over views\n        for t in range(T):  # Iterate over frames\n            rr.set_time_seconds(\"frame\", t / fps)\n\n            # Log RGB image\n            rgb_image = rgbs[v, t].permute(1, 2, 0).cpu().numpy()\n            if viz_camera:\n                rr.log(f\"{entity_prefix}image/view-{v}/rgb\", rr.Image(rgb_image))\n\n            # Log Depth map\n            depth_map = depths[v, t, 0].cpu().numpy()\n            if viz_camera:\n                rr.log(f\"{entity_prefix}image/view-{v}/depth\", rr.DepthImage(depth_map, point_fill_ratio=0.2))\n\n            # Log Camera\n            K = intrs[v, t].cpu().numpy()\n            world_T_cam = np.eye(4)\n            world_T_cam[:3, :3] = extrs_inv[v, t, :3, :3].cpu().numpy()\n            world_T_cam[:3, 3] = extrs_inv[v, t, :3, 3].cpu().numpy()\n            if viz_camera:\n                rr.log(f\"{entity_prefix}image/view-{v}\", rr.Pinhole(image_from_camera=K, width=W, height=H))\n                rr.log(f\"{entity_prefix}image/view-{v}\",\n                       rr.Transform3D(translation=world_T_cam[:3, 3], mat3x3=world_T_cam[:3, :3]))\n\n            # Generate and log point cloud colored by RGB values\n            # Compute 3D points from depth map\n            y, x = np.indices((H, W))\n            homo_pixel_coords = np.stack([x.ravel(), y.ravel(), np.ones_like(x).ravel()], axis=1).T\n            depth_values = depth_map.ravel()\n            cam_coords = (intrs_inv[v, t].cpu().numpy() @ homo_pixel_coords) * depth_values\n            cam_coords = np.vstack((cam_coords, np.ones((1, cam_coords.shape[1]))))\n            world_coords = (world_T_cam @ cam_coords)[:3].T\n\n            # Filter out points with zero depth\n            valid_mask = depth_values > 0\n            world_coords = world_coords[valid_mask]\n            rgb_colors = rgb_image.reshape(-1, 3)[valid_mask].astype(np.uint8)\n\n            # Log the point cloud\n            if viz_point_cloud:\n                rr.log(f\"{entity_prefix}point_cloud/view-{v}\",\n                       rr.Points3D(world_coords, colors=rgb_colors, radii=0.02 * radii_scale))\n\n    # Log 3D tracks\n    x = tracks[0, :, 0]\n    c = (x - x.min()) / (x.max() - x.min() + 1e-8)\n    colors = (matplotlib.colormaps[\"gist_rainbow\"](c)[:, :3] * 255).astype(np.uint8)\n    for t in range(T):\n        rr.set_time_seconds(\"frame\", t / fps)\n        rr.log(\n            f\"{entity_prefix}tracks/points\",\n            rr.Points3D(positions=tracks[t], colors=colors, radii=0.01 * radii_scale),\n        )\n        if t > 0:\n            strips = np.concatenate(\n                [np.stack([tracks[:t, n], tracks[1:t + 1, n]], axis=-2) for n in range(N)],\n                axis=0,\n            )\n            strip_colors = np.concatenate(\n                [np.repeat(colors[n][None], t, axis=0) for n in range(N)],\n                axis=0,\n            )\n            rr.log(\n                f\"{entity_prefix}tracks/lines\",\n                rr.LineStrips3D(strips=strips, colors=strip_colors, radii=0.005 * radii_scale),\n            )\n"
  },
  {
    "path": "mvtracker/datasets/generic_scene_dataset.py",
    "content": "import logging\nimport os\nimport pickle\nimport sys\nfrom contextlib import ExitStack\nfrom typing import Tuple\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom PIL import Image\nfrom torch.nn.functional import interpolate\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms as TF\nfrom tqdm import tqdm\n\nfrom mvtracker.datasets.utils import Datapoint, transform_scene, align_umeyama, apply_sim3_to_extrinsics\n\n\nclass GenericSceneDataset(Dataset):\n    def __init__(\n            self,\n            dataset_dir,\n\n            use_duster_depths=True,\n            use_vggt_depths_with_aligned_cameras=False,\n            use_vggt_depths_with_raw_cameras=False,\n            use_monofusion_depths=False,\n            use_moge2_depths=False,\n\n            skip_depth_computation_if_cached=True,\n\n            drop_first_n_frames=0,\n\n            scene_normalization_mode=\"auto\",  # \"auto\" | \"manual\" | \"none\"\n            scene_normalization_auto_conf_thresh=4.8,\n            scene_normalization_auto_target_radius=6.3,\n            scene_normalization_auto_rescale_by_camera_radius=True,\n            scene_normalization_manual_scale=None,  # Optional float\n            scene_normalization_manual_rotation=None,  # Optional 3x3 torch.Tensor rotation matrix\n            scene_normalization_manual_translation=None,  # Optional 3D torch.Tensor post-scale translation vector\n            # E.g., the manual transform that translates up by 1.4 units and scales 2.5 times (was good for EgoExo4D):\n            #   scale = 2.5\n            #   translate_x = 0\n            #   translate_y = 0\n            #   translate_z = 1.4 * scale\n            #   T = torch.tensor([\n            #       [scale, 0.0, 0.0, translate_x],\n            #       [0.0, scale, 0.0, translate_y],\n            #       [0.0, 0.0, scale, translate_z],\n            #       [0.0, 0.0, 0.0, 1.0],\n            #   ], dtype=torch.float32)\n\n            stream_viz_to_rerun=False,\n    ):\n        self.dataset_dir = dataset_dir\n\n        self.use_duster_depths = use_duster_depths\n        self.use_vggt_depths_with_aligned_cameras = use_vggt_depths_with_aligned_cameras\n        self.use_vggt_depths_with_raw_cameras = use_vggt_depths_with_raw_cameras\n        self.use_monofusion_depths = use_monofusion_depths\n        self.use_moge2_depths = use_moge2_depths\n        # --- Assert exclusive depth-source configuration ---\n        # Exactly 0 or 1 of these should be True. 
(0 => fall back to pkl/dust3r.)\n        depth_flags = (int(self.use_duster_depths)\n                       + int(self.use_vggt_depths_with_aligned_cameras)\n                       + int(self.use_vggt_depths_with_raw_cameras)\n                       + int(self.use_monofusion_depths)\n                       + int(self.use_moge2_depths))\n        assert depth_flags <= 1, (\n            \"Misconfigured dataset: choose at most one depth source among \"\n            \"`use_monofusion_depths`, `use_moge2_depths`, `use_duster_depths`.\"\n        )\n\n        self.skip_depth_computation_if_cached = skip_depth_computation_if_cached\n        self.drop_first_n_frames = drop_first_n_frames\n\n        self.scene_normalization_mode = scene_normalization_mode\n        self.scene_normalization_auto_conf_thresh = scene_normalization_auto_conf_thresh\n        self.scene_normalization_auto_target_radius = scene_normalization_auto_target_radius\n        self.scene_normalization_auto_rescale_by_camera_radius = scene_normalization_auto_rescale_by_camera_radius\n        self.scene_normalization_manual_scale = scene_normalization_manual_scale\n        self.scene_normalization_manual_rotation = scene_normalization_manual_rotation\n        self.scene_normalization_manual_translation = scene_normalization_manual_translation\n\n        self.stream_viz_to_rerun = stream_viz_to_rerun\n\n        self.seq_names = sorted([\n            f.replace(\".pkl\", \"\")\n            for f in os.listdir(dataset_dir)\n            if f.endswith(\".pkl\")\n        ])\n        assert self.seq_names, f\"No sequences found in {dataset_dir}\"\n\n    def __len__(self):\n        return len(self.seq_names)\n\n    def __getitem__(self, idx):\n        seq_name = self.seq_names[idx]\n        pkl_path = os.path.join(self.dataset_dir, f\"{seq_name}.pkl\")\n        with open(pkl_path, \"rb\") as f:\n            data = pickle.load(f)\n\n        ego_cam = data.get(\"ego_cam_name\", None)\n        rgbs_dict = data[\"rgbs\"]\n        intrs_dict = data[\"intrs\"]\n        extrs_dict = data[\"extrs\"]\n        depths_dict = data.get(\"depths\", None)\n\n        if ego_cam:\n            rgbs_dict.pop(ego_cam)\n            intrs_dict.pop(ego_cam)\n            extrs_dict.pop(ego_cam)\n            if depths_dict is not None:\n                depths_dict.pop(ego_cam)\n\n        cam_names = sorted(rgbs_dict.keys())\n        n_views = len(cam_names)\n        n_frames, _, H, W = rgbs_dict[cam_names[0]].shape\n\n        rgbs = torch.stack([torch.from_numpy(rgbs_dict[cam]) for cam in cam_names])  # [V, T, 3, H, W]\n        intrs = torch.stack([torch.from_numpy(intrs_dict[cam]) for cam in cam_names])  # [V, 3, 3]\n        intrs = intrs[:, None].expand(-1, n_frames, -1, -1)  # [V, T, 3, 3]\n\n        extr_list = []\n        for cam in cam_names:\n            e = extrs_dict[cam]\n            if e.ndim == 2:\n                e = np.broadcast_to(e[None, ...], (n_frames, 3, 4))\n            extr_list.append(torch.from_numpy(e.copy()))\n        extrs = torch.stack(extr_list)  # [V, T, 3, 4]\n\n        # ------- Depth selection & caching -------\n        if self.use_duster_depths:\n            depth_root = os.path.join(self.dataset_dir, f\"duster_depths__{seq_name}\")\n            if not os.path.exists(os.path.join(depth_root, f\"3d_model__{n_frames - 1:05d}__scene.npz\")):\n                if \"../duster\" not in sys.path:\n                    sys.path.insert(0, \"../duster\")\n                from scripts.egoexo4d_preprocessing import main_estimate_duster_depth\n       
         pkl_path = os.path.join(self.dataset_dir, f\"{seq_name}.pkl\")\n\n                # Re-enable autograd locally (overrides any surrounding no_grad/inference_mode)\n                with ExitStack() as stack:\n                    stack.enter_context(torch.inference_mode(False))\n                    stack.enter_context(torch.enable_grad())\n                    main_estimate_duster_depth(pkl_path, depth_root, self.skip_depth_computation_if_cached)\n            duster_depths, duster_confs = [], []\n            for t in range(n_frames):\n                scene_path = os.path.join(depth_root, f\"3d_model__{t:05d}__scene.npz\")\n                scene = np.load(scene_path)\n                d = torch.from_numpy(scene[\"depths\"])  # [V, H', W']\n                d = interpolate(d[:, None], size=(H, W), mode=\"nearest\")  # [V, 1, H, W]\n                duster_depths.append(d)\n                c = torch.from_numpy(scene[\"confs\"])\n                c = interpolate(c[:, None], size=(H, W), mode=\"nearest\")\n                duster_confs.append(c)\n            depths = torch.stack(duster_depths, dim=1)  # [V, T, 1, H, W]\n            depth_confs = torch.stack(duster_confs, dim=1)\n\n        elif self.use_vggt_depths_with_aligned_cameras:\n            depths, depth_confs, intrs, extrs = _ensure_vggt_aligned_cache_and_load(\n                rgbs=rgbs,\n                seq_name=seq_name,\n                dataset_root=self.dataset_dir,\n                extrs_gt=extrs,  # your current GT world->cam\n                vggt_cache_subdir=\"vggt_cache\",\n                skip_if_cached=self.skip_depth_computation_if_cached,\n                model_id=\"facebook/VGGT-1B\",\n            )\n\n        elif self.use_vggt_depths_with_raw_cameras:\n            # Only use VGGT’s own (raw) cameras and depths\n            depths, depth_confs, intrs, extrs = _ensure_vggt_raw_cache_and_load(\n                rgbs=rgbs,\n                seq_name=seq_name,\n                dataset_root=self.dataset_dir,\n                vggt_cache_subdir=\"vggt_cache\",\n                skip_if_cached=self.skip_depth_computation_if_cached,\n                model_id=\"facebook/VGGT-1B\",\n            )\n\n        elif self.use_monofusion_depths:\n            # MonoFusion (Dust3r + FG/BG-heuristic + MoGE-2) with caching\n            final_depths, final_confs = _ensure_monofusion_cache_and_load(\n                rgbs=rgbs,\n                seq_name=seq_name,\n                dataset_root=self.dataset_dir,\n                monofusion_cache_subdir=\"monofusion_cache\",\n                skip_if_cached=self.skip_depth_computation_if_cached,\n            )\n            depths = final_depths\n            depth_confs = final_confs\n\n        elif self.use_moge2_depths:\n            # Raw MoGe-2 (metric) with caching\n            depths, depth_confs = _ensure_moge2_cache_and_load(\n                rgbs=rgbs,\n                seq_name=seq_name,\n                dataset_root=self.dataset_dir,\n                moge2_cache_subdir=\"moge2_cache\",\n                skip_if_cached=self.skip_depth_computation_if_cached,\n            )\n\n        elif depths_dict is not None:\n            depths = torch.stack([torch.from_numpy(depths_dict[cam]) for cam in cam_names]).unsqueeze(2)\n            depth_confs = depths.new_zeros(depths.shape)\n            depth_confs[depths > 0] = 1000\n\n        else:\n            raise ValueError(\"No depths available/configured\")\n\n        # Sometimes the first frames are noisy, e.g., due to timesync calibration\n        if 
self.drop_first_n_frames:\n            assert type(self.drop_first_n_frames) == int\n            n_frames -= self.drop_first_n_frames\n            rgbs = rgbs[:, self.drop_first_n_frames:]\n            depths = depths[:, self.drop_first_n_frames:]\n            depth_confs = depth_confs[:, self.drop_first_n_frames:]\n            intrs = intrs[:, self.drop_first_n_frames:]\n            extrs = extrs[:, self.drop_first_n_frames:]\n\n        if self.scene_normalization_mode == \"auto\":\n            scale, translation = compute_auto_scene_normalization(\n                depths, depth_confs, extrs, intrs,\n                conf_thresh=self.scene_normalization_auto_conf_thresh,\n                target_radius=self.scene_normalization_auto_target_radius,\n                rescale_by_camera_radius=self.scene_normalization_auto_rescale_by_camera_radius,\n            )\n            rot = torch.eye(3, dtype=torch.float32, device=depths.device)\n        elif self.scene_normalization_mode == \"manual\":\n            assert self.scene_normalization_manual_scale is not None\n            assert self.scene_normalization_manual_rotation is not None\n            assert self.scene_normalization_manual_translation is not None\n            scale = self.scene_normalization_manual_scale\n            rot = self.scene_normalization_manual_rotation.to(depths.device)\n            translation = self.scene_normalization_manual_translation.to(depths.device)\n        elif self.scene_normalization_mode == \"none\":\n            scale = 1.0\n            rot = torch.eye(3, dtype=torch.float32, device=depths.device)\n            translation = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device=depths.device)\n        else:\n            raise ValueError(f\"Unknown scene_normalization_mode: {self.scene_normalization_mode}\")\n\n        depths_trans, extrs_trans, _, _, _ = transform_scene(scale, rot, translation, depths, extrs, None, None, None)\n\n        assert rgbs.shape == (n_views, n_frames, 3, H, W)\n        assert depths.shape == (n_views, n_frames, 1, H, W)\n        assert depth_confs.shape == (n_views, n_frames, 1, H, W)\n        assert intrs.shape == (n_views, n_frames, 3, 3)\n        assert extrs.shape == (n_views, n_frames, 3, 4)\n        assert extrs_trans.shape == (n_views, n_frames, 3, 4)\n\n        if self.stream_viz_to_rerun:\n            import rerun as rr\n            from mvtracker.utils.visualizer_rerun import log_pointclouds_to_rerun\n            rr.init(f\"3dpt\", recording_id=\"v0.16\")\n            rr.connect_tcp()\n            log_pointclouds_to_rerun(f\"generic-1-before-norm\", idx, rgbs[None], depths[None],\n                                     intrs[None], extrs[None], depth_confs[None], [1.0])\n            log_pointclouds_to_rerun(f\"generic-2-after-norm\", idx, rgbs[None], depths[None],\n                                     intrs[None], extrs_trans[None], depth_confs[None], [1.0])\n\n        datapoint = Datapoint(\n            video=rgbs.float(),\n            videodepth=depths_trans.float(),\n            videodepthconf=depth_confs.float(),\n            feats=None,\n            segmentation=torch.ones((n_views, n_frames, 1, H, W), dtype=torch.float32),\n            trajectory=None,\n            trajectory_3d=None,\n            visibility=None,\n            valid=None,\n            seq_name=seq_name,\n            intrs=intrs.float(),\n            extrs=extrs_trans.float(),\n            query_points=None,\n            query_points_3d=None,\n            trajectory_category=None,\n            
track_upscaling_factor=1.0,\n            novel_video=None,\n            novel_intrs=None,\n            novel_extrs=None,\n        )\n\n        return datapoint, True\n\n\ndef compute_auto_scene_normalization(\n        depths,\n        depth_confs,\n        extrs,\n        intrs,\n        conf_thresh=4.8,\n        target_radius=6.3,\n        rescale_by_camera_radius=True,\n):\n    V, T, _, H, W = depths.shape\n    device = depths.device\n\n    extrs_square = torch.eye(4, device=device)[None, None].repeat(V, T, 1, 1)\n    extrs_square[:, :, :3, :] = extrs\n    extrs_inv = torch.inverse(extrs_square.float())\n    intrs_inv = torch.inverse(intrs.float())\n\n    y, x = torch.meshgrid(torch.arange(H, device=device), torch.arange(W, device=device), indexing=\"ij\")\n    homog = torch.stack([x, y, torch.ones_like(x)], dim=-1).reshape(-1, 3).float()\n    homog = homog[None].expand(V, -1, -1)\n\n    pts_all = []\n    for v in range(V):\n        d = depths[v, 0, 0]\n        c = depth_confs[v, 0, 0]\n        mask = (c > conf_thresh) & (d > 0)\n        if mask.sum() < 100:\n            continue\n\n        d_flat = d.flatten()\n        conf_mask = mask.flatten()\n        intr_inv = intrs_inv[v, 0]\n        extr_inv = extrs_inv[v, 0]\n\n        cam_pts = (intr_inv @ homog[v].T).T * d_flat[:, None]\n        cam_pts = cam_pts[conf_mask]\n        cam_pts_h = torch.cat([cam_pts, torch.ones_like(cam_pts[:, :1])], dim=-1)\n        world_pts = (extr_inv @ cam_pts_h.T).T[:, :3]\n\n        pts_all.append(world_pts)\n\n    pts_all = torch.cat(pts_all, dim=0)\n    if pts_all.shape[0] < 100:\n        raise RuntimeError(\"Too few valid points for normalization.\")\n\n    # --- Center scene ---\n    centroid = pts_all.mean(dim=0)\n    pts_centered = pts_all - centroid\n\n    # --- Lift scene so floor is at z=0 ---\n    floor_z = pts_centered[:, 2].quantile(0.12)  # robust floor estimate\n    pts_lifted = pts_centered.clone()\n    pts_lifted[:, 2] -= floor_z\n\n    # --- Compute scale ---\n    if rescale_by_camera_radius:\n        cam_centers = extrs[:, 0, :, 3]  # (V, 3)\n        cam_centers_centered = cam_centers - centroid  # shift\n        cam_centers_centered[:, 2] -= floor_z  # lift\n        cam_dists = cam_centers_centered.norm(dim=1)\n        median_dist = cam_dists.median()\n        scale = target_radius / median_dist\n    else:\n        scene_radius = pts_lifted.norm(dim=1).quantile(0.95)\n        scale = target_radius / scene_radius\n\n    # --- Compute translation (after scaling) ---\n    translate = -scale * centroid\n    translate[2] -= scale * floor_z  # lift to z=0\n\n    return scale, translate\n\n\ndef _ensure_moge2_cache_and_load(rgbs, seq_name, dataset_root, moge2_cache_subdir, skip_if_cached=True):\n    \"\"\"\n    Raw MoGe-2 depth (metric) with per-sequence caching.\n    Returns (depths, confs) shaped [V,T,1,H,W] on CPU.\n    \"\"\"\n    V, T, _, H, W = rgbs.shape\n    cache_root = os.path.join(dataset_root, moge2_cache_subdir, seq_name)\n    os.makedirs(cache_root, exist_ok=True)\n    depths_path = os.path.join(cache_root, \"moge2_depths.npy\")\n    confs_path = os.path.join(cache_root, \"moge2_confs.npy\")\n\n    if skip_if_cached and os.path.isfile(depths_path) and os.path.isfile(confs_path):\n        d = torch.from_numpy(np.load(depths_path)).float()  # [V,T,H,W]\n        c = torch.from_numpy(np.load(confs_path)).float()  # [V,T,H,W]\n        return d.unsqueeze(2), c.unsqueeze(2)\n\n    d = _moge_depths(seq_name, rgbs, cache_root)  # [V,T,H,W], CPU float\n\n    # Simple constant confidence for 
MoGe-2\n    c = torch.full_like(d, 100.0)\n\n    np.save(depths_path, d.numpy())\n    np.save(confs_path, c.numpy())\n    return d.unsqueeze(2), c.unsqueeze(2)\n\n\ndef _ensure_monofusion_cache_and_load(rgbs, seq_name, dataset_root, monofusion_cache_subdir, skip_if_cached=True):\n    \"\"\"\n    MONOFUSION:\n      - Background mask: patch-change detector over temporal window (static -> BG)\n      - DUSt3R depth: load per frame/view; build static background depth by BG-temporal-average.\n      - MoGe-2 monocular depth per frame/view; align to background by affine (a,b).\n      - Merge BG (DUSt3R static) with FG (aligned MoGe).\n      - Cache final depths & confs.\n    \"\"\"\n    V, T, _, H, W = rgbs.shape\n\n    cache_root = os.path.join(dataset_root, monofusion_cache_subdir, seq_name)\n    os.makedirs(cache_root, exist_ok=True)\n    final_depths_path = os.path.join(cache_root, \"final_depths.npy\")\n    final_confs_path = os.path.join(cache_root, \"final_confs.npy\")\n\n    if skip_if_cached and os.path.isfile(final_depths_path) and os.path.isfile(final_confs_path):\n        fd = torch.from_numpy(np.load(final_depths_path))  # [V,T,H,W]\n        fc = torch.from_numpy(np.load(final_confs_path))  # [V,T,H,W]\n        return fd.unsqueeze(2), fc.unsqueeze(2)\n\n    # ---- DUSt3R depths per frame/view ----\n    depth_root = os.path.join(dataset_root, f\"duster_depths__{seq_name}\")\n    if not os.path.exists(os.path.join(depth_root, f\"3d_model__{T - 1:05d}__scene.npz\")):\n        if \"../duster\" not in sys.path:\n            sys.path.insert(0, \"../duster\")\n        from scripts.egoexo4d_preprocessing import main_estimate_duster_depth\n        pkl_path = os.path.join(dataset_root, f\"{seq_name}.pkl\")\n\n        # Re-enable autograd locally (overrides any surrounding no_grad/inference_mode)\n        with ExitStack() as stack:\n            stack.enter_context(torch.inference_mode(False))\n            stack.enter_context(torch.enable_grad())\n            main_estimate_duster_depth(pkl_path, depth_root, skip_if_cached)\n\n    duster_depths = []\n    for t in range(T):\n        scene_path = os.path.join(depth_root, f\"3d_model__{t:05d}__scene.npz\")\n        scene = np.load(scene_path)\n        d = torch.from_numpy(scene[\"depths\"])  # [V, H', W']\n        d = interpolate(d[:, None], size=(H, W), mode=\"nearest\")[:, 0]  # [V, H, W]\n        duster_depths.append(d)\n    duster_depths = torch.stack(duster_depths, dim=1)  # [V, T, H, W]\n\n    # ---- Background mask (patch-change) ----\n    compute_device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n    bg_mask = _static_bg_mask_from_window(rgbs.to(compute_device)).cpu()  # [V,T,H,W] bool\n\n    # ---- Static background depth per camera via temporal average on BG pixels ----\n    V, T, _, _ = duster_depths.shape\n    D_bg = torch.zeros((V, H, W), dtype=torch.float32)\n    for v in range(V):\n        valid = bg_mask[v]  # [T,H,W]\n        num = (duster_depths[v] * valid).sum(dim=0)\n        den = valid.sum(dim=0).clamp_min(1)\n        D_bg[v] = num / den\n\n    # ---- MoGe-2 monocular depths per frame/view ----\n    moge_depths = _moge_depths(seq_name, rgbs, cache_root)  # [V,T,H,W]\n\n    # ---- Align MoGe to background (solve a,b on BG pixels) ----\n    compute_device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n    moge_depths = moge_depths.to(compute_device, dtype=torch.float32)  # [V,T,H,W]\n    D_bg_exp = D_bg[:, None].expand_as(moge_depths).to(compute_device)  # [V,T,H,W]\n    bg_mask = bg_mask.to(compute_device)  # 
[V,T,H,W]\n\n    # Valid BG pixels\n    valid = bg_mask & torch.isfinite(moge_depths) & (moge_depths > 0) \\\n            & torch.isfinite(D_bg_exp) & (D_bg_exp > 0)\n\n    # Flatten over pixels\n    X = moge_depths.view(V, T, -1)  # [V,T,HW]\n    Y = D_bg_exp.view(V, T, -1)  # [V,T,HW]\n    M = valid.view(V, T, -1).float()  # [V,T,HW]\n\n    # Count valid pixels\n    n = M.sum(dim=-1)  # [V,T]\n    min_bg = 200\n    if (n < min_bg).any():\n        bad = torch.nonzero(n < min_bg, as_tuple=False)\n        raise RuntimeError(\n            f\"Too few background pixels in frames: {[(int(v), int(t)) for v, t in bad.tolist()]}\"\n        )\n\n    # Sufficient statistics\n    sx = (X * M).sum(dim=-1)\n    sy = (Y * M).sum(dim=-1)\n    sxx = (X * X * M).sum(dim=-1)\n    sxy = (X * Y * M).sum(dim=-1)\n\n    # Closed-form least squares for a, b\n    eps = 1e-8\n    mx = sx / n\n    my = sy / n\n    varx = sxx / n - mx * mx\n    cov = sxy / n - mx * my\n\n    a = cov / (varx + eps)  # [V,T]\n    b = my - a * mx\n\n    # Apply alignment\n    aligned_moge = (a[..., None] * X + b[..., None]).view(V, T, H, W)\n\n    # Optionally save scale/shift\n    scale = a.float().cpu()\n    shift = b.float().cpu()\n\n    # ---- Merge FG/BG ----\n    final_depths = torch.where(bg_mask, D_bg_exp, aligned_moge)  # [V,T,H,W]\n\n    # ---- Confidence map: high for BG, moderate for FG ----\n    final_confs = torch.zeros_like(final_depths)\n    final_confs[bg_mask] = 1000.0\n    final_confs[~bg_mask] = 10.0\n\n    # ---- Cache results ----\n    np.save(final_depths_path, final_depths.cpu().numpy())\n    np.save(final_confs_path, final_confs.cpu().numpy())\n    np.save(os.path.join(cache_root, \"scale.npy\"), scale.cpu().numpy())\n    np.save(os.path.join(cache_root, \"shift.npy\"), shift.cpu().numpy())\n\n    return final_depths.unsqueeze(2).cpu(), final_confs.unsqueeze(2).cpu()\n\n\ndef _static_bg_mask_from_window(\n        rgbs: torch.Tensor,\n        win: int = -1,\n        r: int = 7,  # spatial patch radius -> (2r+1)x(2r+1)\n        diff_thresh: float = 10.0  # uint8 scale threshold\n):\n    \"\"\"\n    Fast BG detector using 3D max-pooling over frame-to-frame diffs.\n    \"\"\"\n    V, T, C, H, W = rgbs.shape\n    device = rgbs.device\n\n    if T == 1:\n        return torch.ones((V, T, H, W), dtype=torch.bool, device=device)\n\n    if win == -1:\n        win = T\n\n    # 1) Frame-to-frame abs diff (channel-mean): boundaries of length T-1\n    x = rgbs.float()\n    diffs = (x[:, 1:] - x[:, :-1]).abs().mean(dim=2)  # [V, T-1, H, W]\n    diffs = diffs.unsqueeze(1)  # [V, 1, T-1, H, W]  (N,C,D,H,W for 3D pool)\n\n    # 2) 3D max pool over time & space:\n    #    - temporal kernel spans (2*win-1) boundaries\n    #    - spatial kernel spans (2r+1)x(2r+1) patch\n    kt = max(1, 2 * win - 1)\n    kh = kw = 2 * r + 1\n    pt = (kt - 1) // 2\n    ph = pw = r\n    pooled = F.max_pool3d(diffs, kernel_size=(kt, kh, kw), stride=1, padding=(pt, ph, pw))\n    pooled = pooled[:, 0]  # [V, T-1, H, W]\n\n    # 3) Map boundary maxima back to frame centers (symmetric nearest-window approx)\n    change = torch.zeros((V, T, H, W), device=device, dtype=pooled.dtype)\n    change[:, 0] = pooled[:, 0]\n    change[:, 1:-1] = torch.maximum(pooled[:, :-1], pooled[:, 1:])\n    change[:, -1] = pooled[:, -1]\n\n    # 4) Threshold -> background\n    bg_mask = (change < diff_thresh)\n    return bg_mask\n\n\ndef _moge_depths(seq_name, rgbs, cache_root, resize_to=512, batch_size=18):\n    \"\"\"Runs (and caches) MoGe-2; returns [V,T,H,W] float32 at 
native resolution.\"\"\"\n\n    # pip install git+https://github.com/microsoft/MoGe.git\n    from moge.model.v2 import MoGeModel as MoGe2Model\n\n    depths_path = os.path.join(cache_root, \"moge_depths.npy\")\n    if os.path.isfile(depths_path):\n        logging.info(f\"Loading cached MoGe-2 depths for {seq_name} from {depths_path}\")\n        return torch.from_numpy(np.load(depths_path)).float()\n\n    V, T, C, H, W = rgbs.shape\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n    model = MoGe2Model.from_pretrained(\"Ruicheng/moge-2-vitl-normal\").to(device).eval()\n\n    if resize_to is None:\n        h1, w1 = H, W\n    else:\n        if H >= W:\n            h1, w1 = int(resize_to), max(1, round(resize_to * W / H))\n        else:\n            w1, h1 = int(resize_to), max(1, round(resize_to * H / W))\n\n    imgs = rgbs.view(V * T, C, H, W).float()\n    if (h1, w1) != (H, W):\n        imgs = F.interpolate(imgs, size=(h1, w1), mode=\"bilinear\", align_corners=False)\n    imgs = (imgs / 255.0).to(device, non_blocking=True)  # [N,3,h1,w1]\n\n    out_small = torch.empty((V * T, h1, w1), dtype=torch.float32, device=device)\n\n    with torch.inference_mode(), torch.autocast(device_type=device.type, dtype=torch.bfloat16,\n                                                enabled=(device.type == \"cuda\")):\n        N = imgs.shape[0]\n        for i in range(0, N, batch_size):\n            chunk = imgs[i:i + batch_size]  # [b,3,h1,w1]\n            pred = model.infer(chunk)  # expects batched input\n            assert isinstance(pred, dict) and \"depth\" in pred, \"MoGe-2 infer() must return dict with 'depth'.\"\n            d = torch.as_tensor(pred[\"depth\"], device=device)\n            assert d.ndim == 3 and d.shape[0] == chunk.shape[0] and tuple(d.shape[1:]) == (h1, w1), \\\n                f\"Depth shape {tuple(d.shape)} != ({chunk.shape[0]},{h1},{w1})\"\n            out_small[i:i + chunk.shape[0]] = d\n\n    if (h1, w1) != (H, W):\n        out = F.interpolate(out_small[:, None], size=(H, W), mode=\"bilinear\", align_corners=False)[:, 0]\n    else:\n        out = out_small\n    out = out.clamp_min(0).view(V, T, H, W).cpu()\n\n    np.save(depths_path, out.numpy())\n    return out\n\n\ndef _ensure_vggt_raw_cache_and_load(\n        rgbs: torch.Tensor,  # uint8 [V,T,3,H,W]\n        seq_name: str,\n        dataset_root: str,\n        vggt_cache_subdir: str = \"vggt_cache\",\n        skip_if_cached: bool = True,\n        model_id: str = \"facebook/VGGT-1B\",\n):\n    \"\"\"\n    Run VGGT and cache RAW predictions (no alignment).\n    Returns CPU float32 tensors:\n      depths_raw   [V,T,1,H,W]\n      confs        [V,T,1,H,W]  (constant 100)\n      intrs_raw    [V,T,3,3]\n      extrs_raw    [V,T,3,4]    (world->cam as predicted by VGGT)\n    \"\"\"\n    from mvtracker.models.core.vggt.models.vggt import VGGT\n    from mvtracker.models.core.vggt.utils.pose_enc import pose_encoding_to_extri_intri\n\n    assert rgbs.dtype == torch.uint8 and rgbs.ndim == 5 and rgbs.shape[2] == 3, \"rgbs must be uint8 [V,T,3,H,W]\"\n    V, T, _, H, W = rgbs.shape\n    cache_root = os.path.join(dataset_root, vggt_cache_subdir, seq_name)\n    os.makedirs(cache_root, exist_ok=True)\n\n    f_depths_raw = os.path.join(cache_root, \"vggt_depths_raw.npy\")  # [V,T,H,W]\n    f_confs = os.path.join(cache_root, \"vggt_confs.npy\")  # [V,T,H,W]\n    f_intr_raw = os.path.join(cache_root, \"vggt_intrinsics_raw.npy\")\n    f_extr_raw = os.path.join(cache_root, \"vggt_extrinsics_raw.npy\")\n\n    all_cached 
= all(os.path.isfile(p) for p in [f_depths_raw, f_confs, f_intr_raw, f_extr_raw])\n    if skip_if_cached and all_cached:\n        depths_raw = torch.from_numpy(np.load(f_depths_raw)).float().unsqueeze(2)\n        confs = torch.from_numpy(np.load(f_confs)).float().unsqueeze(2)\n        intrs_raw = torch.from_numpy(np.load(f_intr_raw)).float()\n        extrs_raw = torch.from_numpy(np.load(f_extr_raw)).float()\n        return depths_raw, confs, intrs_raw, extrs_raw\n\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n    model = VGGT.from_pretrained(model_id).to(device).eval()\n    amp_dtype = torch.bfloat16 if (\n            device.type == \"cuda\" and torch.cuda.get_device_capability()[0] >= 8) else torch.float16\n\n    def _compute_pad_to_518(H0: int, W0: int, target: int = 518) -> Tuple[int, int, int, int, int, int]:\n        \"\"\"\n        Mirror VGGT's load_and_preprocess_images(mode='pad') padding math so we can undo it.\n        Returns: new_h, new_w, pad_top, pad_bottom, pad_left, pad_right\n        \"\"\"\n        # Make largest dim target, keep aspect, round smaller dim to /14*14, then pad to (target, target)\n        if W0 >= H0:\n            new_w = target\n            new_h = int(round((H0 * (new_w / W0)) / 14.0) * 14)\n            h_pad = max(0, target - new_h)\n            w_pad = 0\n        else:\n            new_h = target\n            new_w = int(round((W0 * (new_h / H0)) / 14.0) * 14)\n            h_pad = 0\n            w_pad = max(0, target - new_w)\n\n        pad_top = h_pad // 2\n        pad_bottom = h_pad - pad_top\n        pad_left = w_pad // 2\n        pad_right = w_pad - pad_left\n        return new_h, new_w, pad_top, pad_bottom, pad_left, pad_right\n\n    depths_raw_arr = torch.empty((V, T, H, W), dtype=torch.float32)\n    confs_arr = torch.full((V, T, H, W), 100.0, dtype=torch.float32)\n    intr_raw_arr = torch.empty((V, T, 3, 3), dtype=torch.float32)\n    extr_raw_arr = torch.empty((V, T, 3, 4), dtype=torch.float32)\n\n    with torch.no_grad(), torch.cuda.amp.autocast(enabled=(device.type == \"cuda\"), dtype=amp_dtype):\n        for t in tqdm(range(T), desc=f\"VGGT RAW {seq_name}\", unit=\"f\"):\n            image_items = [rgbs[v, t].cpu() for v in range(V)]  # each: [3,H,W] uint8\n            images = _vggt_load_and_preprocess_images(image_items, mode=\"pad\").to(device)[None]  # [1,V,3,518,518]\n\n            tokens, ps_idx = model.aggregator(images)\n            pose_enc = model.camera_head(tokens)[-1]\n            extr_pred, intr_pred = pose_encoding_to_extri_intri(pose_enc, images.shape[-2:])  # [1,V,3,4],[1,V,3,3]\n            depth_maps, _ = model.depth_head(tokens, images, ps_idx)  # [1,V,518,518]\n\n            # per-view: undo pad, resize back to (H0,W0), adjust intrinsics\n            d_full_list, K_list = [], []\n            for v in range(V):\n                H0, W0 = int(rgbs[v, t].shape[-2]), int(rgbs[v, t].shape[-1])\n                new_h, new_w, pt, pb, pl, pr = _compute_pad_to_518(H0, W0)\n\n                # crop padding region out of the 518x518 depth\n                d_small = depth_maps[0, v:v + 1, pt:518 - pb, pl:518 - pr]  # [1,new_h,new_w]\n                d_full_v = F.interpolate(d_small[:, None, :, :, 0], size=(H0, W0), mode=\"nearest\")[:, 0]  # [1,H0,W0]\n                d_full_list.append(d_full_v.squeeze(0))\n\n                # adjust intrinsics: subtract removed pad, then scale to (H0,W0)\n                K = intr_pred[0, v].detach().cpu().float().clone()\n                K[0, 2] -= float(pl)\n       
         K[1, 2] -= float(pt)\n                S = torch.tensor([[W0 / float(new_w), 0.0, 0.0],\n                                  [0.0, H0 / float(new_h), 0.0],\n                                  [0.0, 0.0, 1.0]], dtype=torch.float32)\n                K_list.append((S @ K).unsqueeze(0))\n\n            depths_raw_arr[:, t] = torch.stack(d_full_list, dim=0)\n            intr_raw_arr[:, t] = torch.cat(K_list, dim=0)\n            extr_raw_arr[:, t] = extr_pred[0].detach().cpu().float()  # raw VGGT w2c\n\n    # save raw cache\n    np.save(f_depths_raw, depths_raw_arr.numpy())\n    np.save(f_confs, confs_arr.numpy())\n    np.save(f_intr_raw, intr_raw_arr.numpy())\n    np.save(f_extr_raw, extr_raw_arr.numpy())\n\n    return depths_raw_arr.unsqueeze(2), confs_arr.unsqueeze(2), intr_raw_arr, extr_raw_arr\n\n\ndef _vggt_load_and_preprocess_images(image_items, mode=\"crop\"):\n    \"\"\"\n    Same as VGGT loader, but accepts in-memory items as well.\n    \"\"\"\n    if len(image_items) == 0:\n        raise ValueError(\"At least 1 image is required\")\n\n    # Validate mode\n    if mode not in [\"crop\", \"pad\"]:\n        raise ValueError(\"Mode must be either 'crop' or 'pad'\")\n\n    images = []\n    shapes = set()\n    to_tensor = TF.ToTensor()\n    target_size = 518\n\n    def _to_pil(item):\n        # path\n        if isinstance(item, str):\n            img = Image.open(item)\n            return img\n        # numpy HWC\n        if isinstance(item, np.ndarray):\n            if item.ndim == 3 and item.shape[2] in (3, 4):\n                if item.dtype != np.uint8:\n                    item = item.astype(np.uint8)\n                return Image.fromarray(item)\n        # torch CHW\n        if torch.is_tensor(item):\n            x = item\n            if x.ndim == 3 and x.shape[0] in (3, 4):\n                if x.dtype == torch.uint8:\n                    arr = x.permute(1, 2, 0).cpu().numpy()\n                    return Image.fromarray(arr)\n                else:\n                    # assume float [0,1]\n                    arr = (x.clamp(0, 1) * 255.0).byte().permute(1, 2, 0).cpu().numpy()\n                    return Image.fromarray(arr)\n        raise ValueError(\"Unsupported image item type/shape\")\n\n    for item in image_items:\n        img = _to_pil(item)\n\n        # If there's an alpha channel, blend onto white background:\n        if img.mode == \"RGBA\":\n            # Create white background\n            background = Image.new(\"RGBA\", img.size, (255, 255, 255, 255))\n            # Alpha composite onto the white background\n            img = Image.alpha_composite(background, img)\n\n        # Now convert to \"RGB\" (this step assigns white for transparent areas)\n        img = img.convert(\"RGB\")\n\n        width, height = img.size\n\n        if mode == \"pad\":\n            # Make the largest dimension 518px while maintaining aspect ratio\n            if width >= height:\n                new_width = target_size\n                new_height = round(height * (new_width / width) / 14) * 14  # Make divisible by 14\n            else:\n                new_height = target_size\n                new_width = round(width * (new_height / height) / 14) * 14  # Make divisible by 14\n        else:  # mode == \"crop\"\n            # Original behavior: set width to 518px\n            new_width = target_size\n            # Calculate height maintaining aspect ratio, divisible by 14\n            new_height = round(height * (new_width / width) / 14) * 14\n\n        # Resize with new dimensions (width, 
height)\n        img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)\n        img = to_tensor(img)  # Convert to tensor (0, 1)\n\n        # Center crop height if it's larger than 518 (only in crop mode)\n        if mode == \"crop\" and new_height > target_size:\n            start_y = (new_height - target_size) // 2\n            img = img[:, start_y: start_y + target_size, :]\n\n        # For pad mode, pad to make a square of target_size x target_size\n        if mode == \"pad\":\n            h_padding = target_size - img.shape[1]\n            w_padding = target_size - img.shape[2]\n\n            if h_padding > 0 or w_padding > 0:\n                pad_top = h_padding // 2\n                pad_bottom = h_padding - pad_top\n                pad_left = w_padding // 2\n                pad_right = w_padding - pad_left\n\n                # Pad with white (value=1.0)\n                img = torch.nn.functional.pad(\n                    img, (pad_left, pad_right, pad_top, pad_bottom), mode=\"constant\", value=1.0\n                )\n\n        shapes.add((img.shape[1], img.shape[2]))\n        images.append(img)\n\n    # Check if we have different shapes\n    # In theory our model can also work well with different shapes\n    if len(shapes) > 1:\n        print(f\"Warning: Found images with different shapes: {shapes}\")\n        # Find maximum dimensions\n        max_height = max(shape[0] for shape in shapes)\n        max_width = max(shape[1] for shape in shapes)\n\n        # Pad images if necessary\n        padded_images = []\n        for img in images:\n            h_padding = max_height - img.shape[1]\n            w_padding = max_width - img.shape[2]\n\n            if h_padding > 0 or w_padding > 0:\n                pad_top = h_padding // 2\n                pad_bottom = h_padding - pad_top\n                pad_left = w_padding // 2\n                pad_right = w_padding - pad_left\n\n                img = torch.nn.functional.pad(\n                    img, (pad_left, pad_right, pad_top, pad_bottom), mode=\"constant\", value=1.0\n                )\n            padded_images.append(img)\n        images = padded_images\n\n    images = torch.stack(images)  # concatenate images\n\n    # Ensure correct shape when single image\n    if len(image_items) == 1:\n        # Verify shape is (1, C, H, W)\n        if images.dim() == 3:\n            images = images.unsqueeze(0)\n\n    return images\n\n\ndef _ensure_vggt_aligned_cache_and_load(\n        rgbs: torch.Tensor,  # uint8 [V,T,3,H,W]\n        seq_name: str,\n        dataset_root: str,\n        extrs_gt: torch.Tensor,  # [V,T,3,4] GT world->cam\n        vggt_cache_subdir: str = \"vggt_cache\",\n        skip_if_cached: bool = True,\n        model_id: str = \"facebook/VGGT-1B\",\n):\n    \"\"\"\n    Ensure RAW VGGT cache exists (running VGGT if needed), then align VGGT cameras to GT via\n    Umeyama (pred→gt) per frame. 
Returns CPU float32:\n\n      depths_aligned  [V,T,1,H,W]   (RAW depths scaled by s)\n      confs           [V,T,1,H,W]   (same constant 100 as RAW)\n      intr_aligned    [V,T,3,3]     (equal to RAW intrinsics; alignment is Sim3 in world)\n      extr_aligned    [V,T,3,4]     (VGGT w2c aligned to GT)\n    \"\"\"\n    # 1) Get RAW results (runs VGGT if needed)\n    depths_raw, confs_raw, intr_raw, extr_raw = _ensure_vggt_raw_cache_and_load(\n        rgbs=rgbs,\n        seq_name=seq_name,\n        dataset_root=dataset_root,\n        vggt_cache_subdir=vggt_cache_subdir,\n        skip_if_cached=skip_if_cached,\n        model_id=model_id,\n    )\n\n    # 2) Aligned cache file paths\n    cache_root = os.path.join(dataset_root, vggt_cache_subdir, seq_name)\n    f_depths_aln = os.path.join(cache_root, \"vggt_depths_aligned.npy\")\n    f_intr_aln = os.path.join(cache_root, \"vggt_intrinsics_aligned.npy\")\n    f_extr_aln = os.path.join(cache_root, \"vggt_extrinsics_aligned.npy\")\n\n    # 3) If aligned already cached, return it\n    if skip_if_cached and all(os.path.isfile(p) for p in [f_depths_aln, f_intr_aln, f_extr_aln]):\n        depths_aln = torch.from_numpy(np.load(f_depths_aln)).float().unsqueeze(2)\n        intr_aln = torch.from_numpy(np.load(f_intr_aln)).float()\n        extr_aln = torch.from_numpy(np.load(f_extr_aln)).float()\n        return depths_aln, confs_raw, intr_aln, extr_aln\n\n    # 4) Compute alignment\n    depths_raw_ = depths_raw.squeeze(2)  # [V,T,H,W]\n    V, T, H, W = depths_raw_.shape\n    assert extrs_gt.shape[:2] == (V, T), \"GT extrinsics must be [V,T,3,4]\"\n\n    depths_aln = depths_raw_.clone()\n    intr_aln = intr_raw.clone()  # intrinsics unchanged by world Sim3\n    extr_aln = extr_raw.clone()\n\n    def _camera_center_from_affine_extr(extr):\n        extr_sq = np.eye(4, dtype=np.float32)[None].repeat(extr.shape[0], 0)\n        extr_sq[:, :3, :4] = extr\n        extr_sq_inv = np.linalg.inv(extr_sq)\n        return extr_sq_inv[:, :3, 3]\n\n    for t in range(T):\n        gt_w2c = extrs_gt[:, t].cpu().numpy()\n        pred_w2c = extr_raw[:, t].cpu().numpy()\n\n        s, R_align, t_align = align_umeyama(\n            _camera_center_from_affine_extr(gt_w2c),\n            _camera_center_from_affine_extr(pred_w2c),\n        )\n        pred_w2c_aligned = apply_sim3_to_extrinsics(pred_w2c, s, R_align, t_align)\n\n        extr_aln[:, t] = torch.from_numpy(np.array(pred_w2c_aligned)).float()\n\n    # 5) Save aligned cache\n    np.save(f_depths_aln, depths_aln.numpy())\n    np.save(f_intr_aln, intr_aln.numpy())\n    np.save(f_extr_aln, extr_aln.numpy())\n\n    return depths_aln.unsqueeze(2), confs_raw, intr_aln, extr_aln\n"
  },
  {
    "path": "mvtracker/datasets/kubric_multiview_dataset.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\nimport logging\nimport os\nimport pathlib\nimport re\nimport time\n\nimport cv2\nimport kornia\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom PIL import Image\nfrom scipy.spatial.transform import Rotation as R\nfrom torch.utils.data import get_worker_info\nfrom torchvision.transforms import ColorJitter, GaussianBlur\nfrom torchvision.transforms import functional as F_torchvision\n\nfrom mvtracker.datasets.utils import Datapoint, read_json, read_tiff, read_png, transform_scene, add_camera_noise, \\\n    aug_depth\n\n\nclass KubricMultiViewDataset(torch.utils.data.Dataset):\n\n    @staticmethod\n    def from_name(\n            dataset_name: str,\n            dataset_root: str,\n            training_args=None,\n            fabric=None,\n            just_return_kwargs: bool = False,\n            subset: str = \"test\",\n    ):\n        \"\"\"\n        Examples of evaluation datasets supported by this factory method:\n        - kubric-multiview-v3\n        - kubric-multiview-v3-duster0123\n        - kubric-multiview-v3-duster01234567\n        - kubric-multiview-v3-duster01234567cleaned\n        - kubric-multiview-v3-duster01234567cleaned-views012\n        - kubric-multiview-v3-duster01234567cleaned-views012-novelviews7\n        - kubric-multiview-v3-duster01234567cleaned-views012-novelviews7-overfit-on-training\n        - kubric-multiview-v3-duster01234567cleaned-views012-novelviews7-overfit-on-training-single\n        - kubric-multiview-v3-duster01234567cleaned-views012-novelviews7-overfit-on-training-2dpt-single\n        - kubric-multiview-v3-duster01234567cleaned-views012-novelviews7-overfit-on-training-2dpt-single-cached\n        - kubric-multiview-v3-noise1.23cm\n\n        Example of a training dataset:\n        - kubric-multiview-v3-training\n        \"\"\"\n        # Parse the dataset name, chunk by chunk\n        non_parsed = dataset_name.replace(\"kubric-multiview-v3\", \"\", 1)\n\n        if non_parsed.startswith(\"-noise\"):\n            match = re.match(r\"-noise([\\d.]+)cm\", non_parsed)\n            assert match is not None\n            depth_noise_std = float(match.group(1))\n            depth_noise_std = depth_noise_std / 13  # real-world cm to kubric's metric unit\n            non_parsed = non_parsed.replace(match.group(0), \"\", 1)\n        else:\n            depth_noise_std = 0.0\n\n        if non_parsed.startswith(\"-duster\"):\n            match = re.match(r\"-duster(\\d+)(cleaned)?\", non_parsed)\n            assert match is not None\n            duster_views = list(map(int, match.group(1)))\n            use_duster = True\n            use_duster_cleaned = match.group(2) is not None\n            non_parsed = non_parsed.replace(match.group(0), \"\", 1)\n        else:\n            use_duster = False\n            use_duster_cleaned = False\n            duster_views = None\n\n        if non_parsed.startswith(\"-views\"):\n            match = re.match(r\"-views(\\d+)\", non_parsed)\n            assert match is not None\n            views = list(map(int, match.group(1)))\n            if duster_views is not None:\n                assert all(v in duster_views for v in views)\n            non_parsed = non_parsed.replace(match.group(0), \"\", 1)\n        else:\n            views = duster_views\n\n        if 
non_parsed.startswith(\"-novelviews\"):\n            match = re.match(r\"-novelviews(\\d+)\", non_parsed)\n            assert match is not None\n            novel_views = list(map(int, match.group(1)))\n            non_parsed = non_parsed.replace(match.group(0), \"\", 1)\n        else:\n            novel_views = None\n\n        if non_parsed.startswith(\"-training\"):\n            training = True\n            non_parsed = non_parsed.replace(\"-training\", \"\", 1)\n            assert training_args is not None\n            assert fabric is not None\n        else:\n            training = False\n\n        if non_parsed.startswith(\"-overfit-on-training\"):\n            overfit_on_train = True\n            non_parsed = non_parsed.replace(\"-overfit-on-training\", \"\", 1)\n            assert not training, \"Either ...-training or ...-overfit-on-training[-single][-2dpt]\"\n            assert training_args is not None\n            expected_training_dset_name = (dataset_name.replace(\"-overfit-on-training\", \"-training\")\n                                           .replace(\"-single\", \"\").replace(\"2dpt\", \"\"))\n            assert training_args.datasets.train.name == expected_training_dset_name, \\\n                f\"{expected_training_dset_name} != {training_args.datasets.train.name}\"\n        else:\n            overfit_on_train = False\n\n        if non_parsed.startswith(\"-single\"):\n            assert not training, \"The single-point evaluation options is not relevant for a training dataset\"\n            single_point = True\n            non_parsed = non_parsed.replace(\"-single\", \"\", 1)\n        else:\n            single_point = False\n\n        if non_parsed.startswith(\"-2dpt\"):\n            eval_2dpt = True\n            non_parsed = non_parsed.replace(\"-2dpt\", \"\", 1)\n        else:\n            eval_2dpt = False\n\n        if non_parsed.startswith(\"-cached\"):\n            use_cached_tracks = True\n            non_parsed = non_parsed.replace(\"-cached\", \"\", 1)\n        else:\n            use_cached_tracks = False\n\n        assert non_parsed == \"\", f\"Unparsed part of the dataset name: {non_parsed}\"\n\n        kubric_kwargs = {\n            \"data_root\": os.path.join(dataset_root, \"kubric-multiview\", subset),\n            \"seq_len\": 24,\n            \"traj_per_sample\": 512,\n            \"seed\": 72,\n            \"sample_vis_1st_frame\": False,\n            \"tune_per_scene\": False,\n            \"max_videos\": 30,\n            \"use_duster_depths\": use_duster,\n            \"duster_views\": duster_views,\n            \"clean_duster_depths\": use_duster_cleaned,\n            \"views_to_return\": views,\n            \"novel_views\": novel_views,\n            \"num_views\": -1 if views is not None else 4,\n            \"depth_noise_std\": depth_noise_std,\n            \"ratio_dynamic\": 0.5,\n            \"ratio_very_dynamic\": 0.25,\n            \"use_cached_tracks\": use_cached_tracks,\n        }\n        if training:\n            kubric_kwargs[\"virtual_dataset_size\"] = fabric.world_size * (training_args.trainer.num_steps + 1000)\n        if training or overfit_on_train:\n            kubric_kwargs[\"data_root\"] = (\n                os.path.join(training_args.datasets.root, \"kubric-multiview\", \"train\")\n                if not training_args.modes.debug else\n                os.path.join(training_args.datasets.root, \"kubric-multiview\", \"validation\")\n            )\n            kubric_kwargs[\"seq_len\"] = 
training_args.datasets.train.sequence_len\n            kubric_kwargs[\"traj_per_sample\"] = training_args.datasets.train.traj_per_sample\n            kubric_kwargs[\"max_depth\"] = training_args.datasets.train.kubric_max_depth\n            kubric_kwargs[\"tune_per_scene\"] = training_args.modes.tune_per_scene\n            if training:\n                kubric_kwargs[\"max_videos\"] = training_args.datasets.train.max_videos\n            else:\n                kubric_kwargs[\"max_videos\"] = 30\n\n            kubric_kwargs[\"augmentation_probability\"] = training_args.augmentations.probability\n            kubric_kwargs[\"enable_rgb_augs\"] = training_args.augmentations.rgb\n            kubric_kwargs[\"enable_depth_augs\"] = training_args.augmentations.depth\n            kubric_kwargs[\"enable_cropping_augs\"] = training_args.augmentations.cropping\n            kubric_kwargs[\"aug_crop_size\"] = training_args.augmentations.cropping_size\n            kubric_kwargs[\"enable_variable_trajpersample_augs\"] = training_args.augmentations.variable_trajpersample\n            kubric_kwargs[\"enable_scene_transform_augs\"] = training_args.augmentations.scene_transform\n            kubric_kwargs[\"enable_camera_params_noise_augs\"] = training_args.augmentations.camera_params_noise\n            kubric_kwargs[\"enable_variable_depth_type_augs\"] = training_args.augmentations.variable_depth_type\n            kubric_kwargs[\"enable_variable_num_views_augs\"] = training_args.augmentations.variable_num_views\n            kubric_kwargs[\"normalize_scene_following_vggt\"] = training_args.augmentations.normalize_scene_following_vggt\n            kubric_kwargs[\"enable_variable_vggt_crop_size_augs\"] = training_args.augmentations.variable_vggt_crop_size\n            kubric_kwargs[\"keep_principal_point_centered\"] = training_args.augmentations.keep_principal_point_centered\n\n            if training_args.modes.pretrain_only:\n                kubric_kwargs[\"ratio_dynamic\"] = 0.0\n                kubric_kwargs[\"ratio_very_dynamic\"] = 0.0\n\n            if training_args.augmentations.variable_num_views:\n                kubric_kwargs[\"num_views\"] = None\n                kubric_kwargs[\"views_to_return\"] = None\n                kubric_kwargs[\"duster_views\"] = None\n                kubric_kwargs[\"supported_duster_views_sets\"] = [\n                    [0, 1, 2, 3],\n                    [0, 1, 2, 3, 4, 5, 6, 7],\n                ]\n\n        if just_return_kwargs:\n            return kubric_kwargs\n\n        return KubricMultiViewDataset(**kubric_kwargs)\n\n    def __init__(\n            self,\n            data_root,\n            views_to_return=None,\n            novel_views=None,\n            use_duster_depths=False,\n            clean_duster_depths=False,\n            duster_views=None,\n            supported_duster_views_sets=None,\n            seq_len=24,\n            num_views=4,\n            traj_per_sample=768,\n            max_depth=1000,\n            sample_vis_1st_frame=False,\n            ratio_dynamic=0.5,\n            ratio_very_dynamic=0.25,\n            depth_noise_std=0.0,\n\n            augmentation_probability=0.0,\n            enable_rgb_augs=False,\n            enable_depth_augs=False,\n            enable_cropping_augs=False,\n            aug_crop_size=(384, 512),\n            enable_variable_trajpersample_augs=False,\n            enable_scene_transform_augs=False,\n            enable_camera_params_noise_augs=False,\n            enable_variable_depth_type_augs=False,\n            
enable_variable_num_views_augs=False,\n\n            normalize_scene_following_vggt=False,\n            enable_variable_vggt_crop_size_augs=False,\n            keep_principal_point_centered=False,\n\n            static_cropping=False,\n            seed=None,\n            tune_per_scene=False,\n            max_videos=None,\n            virtual_dataset_size=None,\n            max_tracks_to_preload=18000,\n            perform_sanity_checks=False,\n            use_cached_tracks=False,\n    ):\n        super(KubricMultiViewDataset, self).__init__()\n\n        self.data_root = data_root\n        self.views_to_return = views_to_return\n        self.novel_views = novel_views\n        self.use_duster_depths = use_duster_depths\n        self.clean_duster_depths = clean_duster_depths\n        self.duster_views = duster_views\n        self.supported_duster_views_sets = supported_duster_views_sets\n        if self.use_duster_depths:\n            assert self.duster_views is not None, \"When using Duster depths, duster_views must be set.\"\n            if self.supported_duster_views_sets is None:\n                self.supported_duster_views_sets = [self.duster_views]\n\n        self.seq_len = seq_len\n        self.num_views = num_views\n        self.traj_per_sample = traj_per_sample\n        self.sample_vis_1st_frame = sample_vis_1st_frame\n        self.ratio_dynamic = ratio_dynamic\n        self.ratio_very_dynamic = ratio_very_dynamic\n\n        self.seed = seed\n        self.add_index_to_seed = not tune_per_scene\n\n        self.perform_sanity_checks = perform_sanity_checks\n        self.use_cached_tracks = use_cached_tracks\n        self.cache_name = self._cache_key()\n        self.max_tracks_to_preload = max_tracks_to_preload\n        if self.traj_per_sample is not None and self.max_tracks_to_preload is not None:\n            assert self.traj_per_sample <= self.max_tracks_to_preload, \"We need to preload more tracks than we sample.\"\n\n        self.depth_noise_std = depth_noise_std\n\n        # Augmentation settings\n        self.augmentation_probability = augmentation_probability\n        if any([enable_rgb_augs, enable_depth_augs, enable_variable_trajpersample_augs,\n                enable_scene_transform_augs, enable_camera_params_noise_augs, enable_variable_num_views_augs,\n                enable_variable_depth_type_augs]):\n            assert self.augmentation_probability > 0, \"Augmentations are enabled, but augmentation probability is 0%.\"\n        if self.augmentation_probability > 0:\n            assert not self.use_cached_tracks, \"caching tracks not supported with augs\"\n\n        self.enable_rgb_augs = enable_rgb_augs\n        self.enable_depth_augs = enable_depth_augs\n        self.enable_cropping_augs = enable_cropping_augs\n        self.enable_variable_trajpersample_augs = enable_variable_trajpersample_augs\n        self.enable_scene_transform_augs = enable_scene_transform_augs\n        self.enable_camera_params_noise_augs = enable_camera_params_noise_augs\n        self.enable_variable_num_views_augs = enable_variable_num_views_augs\n        self.enable_variable_depth_type_augs = enable_variable_depth_type_augs\n        self.enable_variable_depth_type_augs__depth_type_probability = {\n            \"gt\": 0.70, \"duster\": 0.20, \"duster_cleaned\": 0.10,\n        }\n        # TODO: self.enable_seqlen_augs = enable_seqlen_augs\n        if self.enable_variable_depth_type_augs:\n            assert not self.use_duster_depths, \"Cannot force depth type when using variable depth type 
augs.\"\n            assert not self.clean_duster_depths, \"Cannot force depth type when using variable depth type augs.\"\n        self.enable_variable_num_views_augs__n_views_probability = {\n            # v2\n            1: 0.20,\n            2: 0.10,\n            3: 0.10,\n            4: 0.25,\n            5: 0.10,\n            6: 0.25,\n\n            # # v1\n            # 1: 0.20,\n            # 2: 0.10,\n            # 3: 0.10,\n            # 4: 0.25,\n            # 5: 0.10,\n            # 6: 0.05,\n            # 7: 0.05,\n            # 8: 0.15,\n        }\n        self.enable_variable_num_views_augs__trajpersample_adjustment_factor = {\n            1: 1.00,\n            2: 1.00,\n            3: 1.00,\n            4: 1.00,\n            5: 0.40,\n            6: 0.25,\n        }\n        if self.enable_variable_num_views_augs:\n            assert self.num_views is None, \"Cannot use enable_variable_num_views_augs with num_views != None.\"\n            assert self.views_to_return is None, \"Cannot use enable_variable_num_views_augs with views_to_return.\"\n\n        # photometric augmentation\n        # TODO: \"Override\" ColorJitter and GaussianBlur to take in a random state\n        #       in forward pass so we can assure reproducibility. This affects\n        #       only training as augmentation is disabled during evaluation.\n        self.photo_aug = ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.25 / 3.14)\n        self.blur_aug = GaussianBlur(11, sigma=(0.1, 2.0))\n        self.blur_aug_prob = 0.25\n        self.color_aug_prob = 0.25\n\n        # occlusion augmentation\n        self.eraser_aug_prob = 0.5\n        self.eraser_bounds = [2, 100]\n        self.eraser_max = 10\n\n        # occlusion augmentation\n        self.replace_aug_prob = 0.5\n        self.replace_bounds = [2, 100]\n        self.replace_max = 10\n\n        # spatial augmentations\n        self.crop_size = aug_crop_size\n        self.normalize_scene_following_vggt = normalize_scene_following_vggt\n        self.enable_variable_vggt_crop_size_augs = enable_variable_vggt_crop_size_augs\n        self.keep_principal_point_centered = keep_principal_point_centered\n\n        self.max_depth = max_depth\n\n        self.pad_bounds = [0, 45]\n        self.resize_lim = [0.8, 1.2]\n        self.resize_delta = 0.15\n        self.max_crop_offset = 36\n        if static_cropping or tune_per_scene:\n            self.pad_bounds = [0, 1]\n            self.resize_lim = [1.0, 1.0]\n            self.resize_delta = 0.0\n            self.max_crop_offset = 0\n\n        if self.keep_principal_point_centered:\n            self.pad_bounds = [0, 45]\n            self.resize_lim = [1.02, 1.25]\n            self.resize_delta = None\n            self.max_crop_offset = None\n            if static_cropping or tune_per_scene:\n                self.pad_bounds = [0, 1]\n                self.resize_lim = [1.04, 1.04]\n\n        self.seq_names = [\n            fname\n            for fname in os.listdir(self.data_root)\n            if os.path.isdir(os.path.join(self.data_root, fname))\n               and not fname.startswith(\".\")\n               and not fname.startswith(\"_\")\n        ]\n        self.seq_names = sorted(self.seq_names, key=lambda x: int(x))\n        seq_names_clean = []\n        for seq_name in self.seq_names:\n            scene_path = os.path.join(self.data_root, seq_name)\n            view_folders = [\n                d for d in os.listdir(scene_path)\n                if os.path.isdir(os.path.join(scene_path, d)) 
and d.startswith('view_')\n            ]\n            if len(view_folders) == 0:\n                logging.warning(f\"Skipping {scene_path} because it has no views.\")\n                continue\n            if self.num_views is not None and len(view_folders) < self.num_views:\n                logging.warning(f\"Skipping {scene_path} because it has {len(view_folders)} views (<{self.num_views}).\")\n                continue\n            seq_names_clean.append(seq_name)\n        self.seq_names = seq_names_clean\n\n        if self.supported_duster_views_sets is not None:\n            supported_duster_views_sets_cleaned = []\n            for s in self.supported_duster_views_sets:\n                duster_views_str = ''.join(str(v) for v in s)\n                if os.path.isdir(os.path.join(self.data_root, self.seq_names[0], f\"duster-views-{duster_views_str}\")):\n                    supported_duster_views_sets_cleaned.append(s)\n                else:\n                    logging.warning(f\"Skipping duster views set {s} because it does not exist.\")\n            self.supported_duster_views_sets = supported_duster_views_sets_cleaned\n\n        if tune_per_scene:\n            self.seq_names = self.seq_names[3:4]\n        if max_videos is not None:\n            self.seq_names = self.seq_names[:max_videos]\n        logging.info(\"Using %d videos from %s\" % (len(self.seq_names), self.data_root))\n\n        self.real_len = len(self.seq_names)\n        if virtual_dataset_size is not None:\n            self.virtual_len = virtual_dataset_size\n        else:\n            self.virtual_len = self.real_len\n        logging.info(f\"Real dataset size: {self.real_len}. Virtual dataset size: {self.virtual_len}.\")\n\n        self.getitem_calls = 0\n\n    def _cache_key(self):\n        name = f\"cachedtracks--seed{self.seed}-dynamic{self.ratio_dynamic}-verydynamic-{self.ratio_very_dynamic}\"\n        if self.views_to_return is not None:\n            name += f\"-views{'_'.join(map(str, self.views_to_return))}\"\n        if self.traj_per_sample is not None:\n            name += f\"-n{self.traj_per_sample}\"\n        if self.num_views is not None:\n            name += f\"-numviews{self.num_views}\"\n        if self.seq_len is not None:\n            name += f\"-t{self.seq_len}\"\n        if self.sample_vis_1st_frame:\n            name += f\"-sample_vis_1st_frame\"\n        return name + \"--v1\"  # bump this if you change the selection policy\n\n    def __len__(self):\n        return self.virtual_len\n\n    def __getitem__(self, index):\n        index = index % self.real_len\n\n        sample, gotit = self._getitem_helper(index)\n\n        if not gotit:\n            logging.warning(\"warning: sampling failed\")\n            # fake sample, so we can still collate\n            num_views = self.num_views if self.num_views is not None else 4\n            h, w = 384, 512\n            traj_per_sample = self.traj_per_sample if self.traj_per_sample is not None else 768\n            sample = Datapoint(\n                video=torch.zeros((num_views, self.seq_len, 3, h, w)),\n                videodepth=torch.zeros((num_views, self.seq_len, 1, h, w)),\n                segmentation=torch.zeros((num_views, self.seq_len, 1, h, w)),\n                trajectory=torch.zeros((self.seq_len, traj_per_sample, 2)),\n                visibility=torch.zeros((self.seq_len, traj_per_sample)),\n                valid=torch.zeros((self.seq_len, traj_per_sample)),\n            )\n\n        return sample, gotit\n\n    def _getitem_helper(self, 
index):\n        start_time_1 = time.time()\n\n        gotit = True\n\n        # Take a new seed from torch or use self.seed if set\n        # The rest of the code will use generators initialized with this seed\n        if self.seed is None:\n            seed = torch.randint(0, 2 ** 32 - 1, (1,)).item()\n        else:\n            seed = self.seed\n            if self.add_index_to_seed:\n                seed += index\n        rnd_torch = torch.Generator().manual_seed(seed)\n        rnd_np = np.random.RandomState(seed=seed)\n\n        # Load the data\n        datapoint = KubricMultiViewDataset.getitem_raw_datapoint(os.path.join(self.data_root, self.seq_names[index]))\n\n        traj3d_world = datapoint[\"tracks_3d\"].numpy()\n        tracks_segmentation_ids = datapoint[\"tracks_segmentation_ids\"].numpy()\n        tracked_objects = datapoint[\"tracked_objects\"]\n        camera_positions = datapoint[\"camera_positions\"].numpy()\n        lookat_positions = datapoint[\"lookat_positions\"].numpy()\n        views = datapoint[\"views\"]\n\n        # Take a random depth type, if enabled\n        if self.enable_variable_depth_type_augs:\n            assert self.use_duster_depths is False, \"Cannot force depth type when using variable depth type augs.\"\n            assert self.clean_duster_depths is False, \"Cannot force depth type when using variable depth type augs.\"\n            depth_type = rnd_np.choice(\n                a=list(self.enable_variable_depth_type_augs__depth_type_probability.keys()),\n                size=1,\n                p=list(self.enable_variable_depth_type_augs__depth_type_probability.values()),\n            )[0]\n            use_duster_depths, clean_duster_depths = {\n                \"gt\": (False, False),\n                \"duster\": (True, False),\n                \"duster_cleaned\": (True, True),\n            }[depth_type]\n        else:\n            use_duster_depths = self.use_duster_depths\n            clean_duster_depths = self.clean_duster_depths\n\n        # Take a random number of views, if enabled\n        all_views = sorted(list(range(len(views))))\n        if self.enable_variable_num_views_augs:\n            assert self.num_views is None, \"Cannot use enable_variable_num_views_augs with num_views != None.\"\n            assert self.views_to_return is None, \"Cannot use enable_variable_num_views_augs with views_to_return.\"\n            num_views = rnd_np.choice(\n                a=list(self.enable_variable_num_views_augs__n_views_probability.keys()),\n                size=1,\n                p=list(self.enable_variable_num_views_augs__n_views_probability.values()),\n            )[0]\n            if use_duster_depths:\n                num_views = min(num_views, max([len(s) for s in self.supported_duster_views_sets]))\n                # Take only those that have the closest number of views that is greater or equal to num_views\n                closest_num_views_in_supported_duster_views_set = min([\n                    len(vs)\n                    for vs in self.supported_duster_views_sets\n                    if len(vs) >= num_views\n                ])\n                supported_duster_views_sets = [\n                    vs\n                    for vs in self.supported_duster_views_sets\n                    if len(vs) == closest_num_views_in_supported_duster_views_set\n                ]\n                duster_views = supported_duster_views_sets[rnd_np.randint(len(supported_duster_views_sets))]\n                views_to_return = rnd_np.choice(duster_views, 
num_views, replace=False).tolist()\n            else:\n                views_to_return = rnd_np.choice(all_views, num_views, replace=False).tolist()\n                duster_views = views_to_return\n        else:\n            num_views = self.num_views\n            if self.views_to_return is not None:\n                assert num_views == -1, \"Cannot use views_to_return with num_views != -1.\"\n                views_to_return = self.views_to_return\n            elif use_duster_depths:\n                if self.duster_views is not None:\n                    duster_views = self.duster_views\n                else:\n                    # Take only those that have the closest number of views that is greater or equal to num_views\n                    closest_num_views_in_supported_duster_views_set = min([\n                        len(vs)\n                        for vs in self.supported_duster_views_sets\n                        if len(vs) >= num_views\n                    ])\n                    supported_duster_views_sets = [\n                        vs\n                        for vs in self.supported_duster_views_sets\n                        if len(vs) == closest_num_views_in_supported_duster_views_set\n                    ]\n                    duster_views = supported_duster_views_sets[rnd_np.randint(len(supported_duster_views_sets))]\n                views_to_return = duster_views\n            else:\n                if num_views == -1:\n                    # Take all views\n                    views_to_return = all_views\n                elif num_views is None:\n                    # Randomly sample a number of views\n                    n = rnd_np.randint(min(3, len(views)), len(views) + 1)\n                    views_to_return = rnd_np.choice(all_views, n, replace=False).tolist()\n                else:\n                    # Take a fixed number of views\n                    assert num_views > 0, \"Fixed number of views must be positive.\"\n                    assert num_views <= len(views), f\"Not enough views available (idx={index}).\"\n                    views_to_return = rnd_np.choice(all_views, num_views, replace=False).tolist()\n            if self.duster_views is not None:\n                duster_views = self.duster_views\n            else:\n                duster_views = views_to_return\n\n        # Extract only the data we need\n        rgbs = np.stack([views[v][\"rgba\"][..., :3].numpy() for v in views_to_return])\n        depths = np.stack([views[v][\"depth\"].numpy() for v in views_to_return])\n        # segs = np.stack([views[v][\"segmentation\"].numpy() for v in views_to_return])\n        segs = np.ones(((rgbs.shape[0], rgbs.shape[1], rgbs.shape[2], rgbs.shape[3], 1)), dtype=np.float32)\n        intrs = np.stack([views[v][\"intrinsics\"].numpy() for v in views_to_return])\n        intrs = intrs[:, None, :, :].repeat(rgbs.shape[1], axis=1)\n        extrs = np.stack([views[v][\"extrinsics\"].numpy() for v in views_to_return])\n        traj2d = np.stack([views[v][\"tracks_2d\"].numpy() for v in views_to_return])\n        visibility = ~np.stack([views[v][\"occlusion\"].numpy() for v in views_to_return])\n\n        novel_rgbs = None\n        novel_intrs = None\n        novel_extrs = None\n        if self.novel_views is not None:\n            novel_rgbs = np.stack([views[v][\"rgba\"][..., :3].numpy() for v in self.novel_views])\n            novel_intrs = np.stack([views[v][\"intrinsics\"].numpy() for v in self.novel_views])\n            novel_intrs = novel_intrs[:, None, :, 
:].repeat(rgbs.shape[1], axis=1)\n            novel_extrs = np.stack([views[v][\"extrinsics\"].numpy() for v in self.novel_views])\n\n        # Load Duster's features and estimated depths if they exist\n        duster_views_str = ''.join(str(v) for v in duster_views)\n        duster_root = pathlib.Path(self.data_root) / self.seq_names[index] / f'duster-views-{duster_views_str}'\n\n        num_views, n_frames, h, w, _ = rgbs.shape\n        feats = None\n        feat_dim = None\n        feat_stride = None\n        duster_outputs_exist = duster_root.exists() and (\n                duster_root / f\"3d_model__{n_frames - 1:05d}__scene.npz\").exists()\n        if use_duster_depths:\n            assert duster_outputs_exist, \"use_duster_depths --> duster_output_exist\"\n        if duster_outputs_exist:\n            duster_depths = []\n            duster_feats = []\n            for frame_idx in range(n_frames):\n                scene = np.load(duster_root / f\"3d_model__{frame_idx:05d}__scene.npz\")\n                duster_depth = torch.from_numpy(scene[\"depths\"])\n                duster_conf = torch.from_numpy(scene[\"confs\"])\n                duster_msk = torch.from_numpy(scene[\"cleaned_mask\"])\n                duster_feat = torch.from_numpy(scene[\"feats\"])\n\n                if clean_duster_depths:\n                    ## Filter based on the confidence\n                    # conf_threshold = max(0.00001, min(0.1, torch.quantile(duster_conf.flatten(), 0.3).item()))\n                    # duster_depth = duster_depth * (duster_conf > conf_threshold)\n\n                    # Filter based on the mask\n                    duster_depth = duster_depth * duster_msk\n\n                duster_depth = F.interpolate(duster_depth[:, None], (depths.shape[2], depths.shape[3]), mode='nearest')\n\n                duster_depths.append(duster_depth[:, 0, :, :, None])\n                duster_feats.append(duster_feat)\n\n            duster_depths = torch.stack(duster_depths, dim=1).numpy()\n            feats = torch.stack(duster_feats, dim=1).numpy()\n\n            # Extract the correct views\n            assert duster_depths.shape[0] == feats.shape[0] == len(duster_views)\n            duster_depths = duster_depths[[duster_views.index(v) for v in views_to_return]]\n            feats = feats[[duster_views.index(v) for v in views_to_return]]\n\n            # Reshape the features\n            assert feats.ndim == 4\n            assert feats.shape[0] == num_views\n            assert feats.shape[1] == n_frames\n            feat_stride = np.round(np.sqrt(h * w / feats.shape[2])).astype(int)\n            feat_dim = feats.shape[3]\n            feats = feats.reshape(num_views, n_frames, h // feat_stride, w // feat_stride, feat_dim)\n\n            # Replace the depths with the Duster depths, if configured so\n            if use_duster_depths:\n                depths = duster_depths\n\n        start_time_2 = time.time()\n\n        # Strategically select dynamic points to track\n        visible_at_t_and_t_plus_1 = (visibility[:, :-1] & visibility[:, 1:]).any(0)\n        movement = np.linalg.norm(traj3d_world[1:] - traj3d_world[:-1], axis=-1)\n        movement[~visible_at_t_and_t_plus_1] = 0\n        movement = movement.sum(axis=0)\n        assert np.isfinite(movement).all(), \"Movement contains NaN or Inf values.\"\n\n        static_threshold = 0.01  # < 1 cm\n        dynamic_threshold = 0.1  # > 10 cm\n        very_dynamic_threshold = 2.0  # > 2 m\n\n        static_points = movement < static_threshold  # 1 cm\n        
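# Note: very_dynamic_points is a subset of dynamic_points; the ratio-based sampling\n        # below draws from these (overlapping) pools independently.\n        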
dynamic_points = movement > dynamic_threshold  # 10 cm\n        very_dynamic_points = movement > very_dynamic_threshold  # 2 m\n\n        if self.perform_sanity_checks:\n            logging.info(f\"Movement stats: \"\n                         f\"static: {static_points.sum()} ({static_points.mean() * 100:.2f}), \"\n                         f\"dynamic: {dynamic_points.sum()} ({dynamic_points.mean() * 100:.2f}), \"\n                         f\"very dynamic: {very_dynamic_points.sum()} ({very_dynamic_points.mean() * 100:.2f})\"\n                         f\"other: {(~static_points & ~dynamic_points & ~very_dynamic_points).sum()}\")\n\n        # Sample the points according to the desired ratios if possible\n        max_tracks_to_preload = traj3d_world.shape[1]\n        max_tracks_to_preload = min([\n            max_tracks_to_preload,\n            int(dynamic_points.sum() / self.ratio_dynamic) if self.ratio_dynamic > 0 else max_tracks_to_preload,\n            int(very_dynamic_points.sum() // self.ratio_very_dynamic) if self.ratio_very_dynamic > 0 else max_tracks_to_preload,\n            int(static_points.sum() / (1 - self.ratio_dynamic - self.ratio_very_dynamic)),\n        ])\n        if self.max_tracks_to_preload is not None:\n            max_tracks_to_preload = min(max_tracks_to_preload, self.max_tracks_to_preload)\n        n_dynamic = min(int(max_tracks_to_preload * self.ratio_dynamic), dynamic_points.sum())\n        n_very_dynamic = min(int(max_tracks_to_preload * self.ratio_very_dynamic), very_dynamic_points.sum())\n        n_static = max_tracks_to_preload - n_dynamic - n_very_dynamic\n\n        dynamic_indices = rnd_np.choice(np.where(dynamic_points)[0], n_dynamic, replace=False)\n        very_dynamic_indices = rnd_np.choice(np.where(very_dynamic_points)[0], n_very_dynamic, replace=False)\n        static_indices = rnd_np.choice(np.where(static_points)[0], n_static, replace=False)\n\n        selected_indices = np.concatenate([dynamic_indices, very_dynamic_indices, static_indices])\n        rnd_np.shuffle(selected_indices)\n\n        traj3d_world = traj3d_world[:, selected_indices]\n        traj2d = traj2d[:, :, selected_indices]\n        visibility = visibility[:, :, selected_indices]\n        tracks_segmentation_ids = tracks_segmentation_ids[selected_indices]\n\n        if traj3d_world.shape[1] > max_tracks_to_preload:\n            traj3d_world = traj3d_world[:, :max_tracks_to_preload]\n            traj2d = traj2d[:, :, :max_tracks_to_preload]\n            visibility = visibility[:, :, :max_tracks_to_preload]\n\n        n_tracks = traj3d_world.shape[1]\n        num_views, n_frames, h, w, _ = rgbs.shape\n        assert n_frames >= self.seq_len\n        assert rgbs.shape == (num_views, n_frames, h, w, 3)\n        assert depths.shape == (num_views, n_frames, h, w, 1)\n        assert segs.shape == (num_views, n_frames, h, w, 1)\n        assert feats is None or feats.shape == (num_views, n_frames, h // feat_stride, w // feat_stride, feat_dim)\n        assert intrs.shape == (num_views, n_frames, 3, 3)\n        assert extrs.shape == (num_views, n_frames, 3, 4)\n        assert traj2d.shape == (num_views, n_frames, n_tracks, 2)\n        assert visibility.shape == (num_views, n_frames, n_tracks)\n        assert traj3d_world.shape == (n_frames, n_tracks, 3)\n\n        if novel_rgbs is not None:\n            assert novel_rgbs.shape == (len(self.novel_views), n_frames, h, w, 3)\n            assert novel_intrs.shape == (len(self.novel_views), n_frames, 3, 3)\n            assert novel_extrs.shape == 
(len(self.novel_views), n_frames, 3, 4)\n\n        if ((depths < 0.01) & (depths != 0)).mean() > 0.5:\n            raise ValueError(\"Depth map might be invalid? Values that are too small will be ignored by SpaTracker, \"\n                             \"but found that more than half of non-zero depths are below 0.01 in the loaded depths.\")\n\n        # Make sure our intrinsics and extrinsics work correctly\n        point_3d_world = traj3d_world\n        point_4d_world_homo = np.concatenate([point_3d_world, np.ones_like(point_3d_world[..., :1])], axis=-1)\n        point_3d_camera = np.einsum('ABij,BCj->ABCi', extrs, point_4d_world_homo)\n        if self.perform_sanity_checks:\n            point_2d_pixel_homo = np.einsum('ABij,ABCj->ABCi', intrs, point_3d_camera)\n            point_2d_pixel = point_2d_pixel_homo[..., :2] / point_2d_pixel_homo[..., 2:]\n            point_2d_pixel_gt = traj2d\n            assert np.allclose(point_2d_pixel[0, :, 0, :], point_2d_pixel_gt[0, :, 0, :], atol=1e-3), f\"Proj. failed\"\n            assert np.allclose(point_2d_pixel, point_2d_pixel_gt, atol=1e-3), f\"Point projection failed\"\n\n        # Now save the z value in traj3d_camera as usual, just if needed\n        traj3d_camera = point_3d_camera\n        assert traj3d_camera.shape == (num_views, n_frames, n_tracks, 3)\n\n        # Also sanity check that pix2cam is working correctly with the intrinsics\n        if self.perform_sanity_checks:\n            from mvtracker.models.core.spatracker.blocks import pix2cam\n            xyz = np.concatenate([traj2d, traj3d_camera[..., 2:]], axis=-1)\n            pix2cam_xyz = torch.from_numpy(xyz).double()\n            pix2cam_intr = torch.from_numpy(intrs).double()\n            traj_3d_repro = pix2cam(pix2cam_xyz, pix2cam_intr).numpy()\n            assert np.allclose(traj3d_camera, traj_3d_repro, atol=0.1)\n\n        # If the video is too long, randomly crop self.seq_len frames\n        if self.seq_len < n_frames:\n            start_ind = rnd_np.choice(n_frames - self.seq_len, 1)[0]\n            rgbs = rgbs[:, start_ind: start_ind + self.seq_len]\n            depths = depths[:, start_ind: start_ind + self.seq_len]\n            segs = segs[:, start_ind: start_ind + self.seq_len]\n            if feats is not None:\n                feats = feats[:, start_ind: start_ind + self.seq_len]\n            intrs = intrs[:, start_ind: start_ind + self.seq_len]\n            extrs = extrs[:, start_ind: start_ind + self.seq_len]\n            traj2d = traj2d[:, start_ind: start_ind + self.seq_len]\n            visibility = visibility[:, start_ind: start_ind + self.seq_len]\n            traj3d_camera = traj3d_camera[:, start_ind: start_ind + self.seq_len]\n            traj3d_world = traj3d_world[start_ind: start_ind + self.seq_len]\n            n_frames = self.seq_len\n\n        # Add the z value to the traj2d\n        traj2d_w_z = np.concatenate((traj2d[..., :], traj3d_camera[..., 2:]), axis=-1)\n\n        start_time_3 = time.time()\n        augment_this_datapoint = False\n        if self.augmentation_probability > 0:\n            augment_this_datapoint = rnd_np.rand() <= self.augmentation_probability\n        if augment_this_datapoint and self.enable_rgb_augs:\n            rgbs, visibility = self._add_photometric_augs(rgbs, traj2d_w_z, visibility, rnd_np)\n\n        crop_size = self.crop_size\n        if augment_this_datapoint and self.enable_variable_vggt_crop_size_augs:\n            sizes = list(range(168, 518 + 14, 14))  # VIT-friendly sizes\n            weights = np.array(sizes) 
** 2  # Quadratic bias toward larger sizes\n            probs = weights / weights.sum()\n            shorter_side = rnd_np.choice(a=sizes, size=1, p=probs)[0]\n            longer_side = max(crop_size)\n            crop_size = (shorter_side, longer_side)\n        if self.enable_cropping_augs and not self.keep_principal_point_centered:\n            rgbs, depths, intrs, traj2d_w_z, visibility = self._add_cropping_augs(\n                crop_size=crop_size,\n                rgbs=rgbs,\n                depths=depths,\n                intrs=intrs,\n                trajs=traj2d_w_z,\n                visibles=visibility,\n            )\n            h, w = rgbs.shape[-3:-1]\n        if self.enable_cropping_augs and self.keep_principal_point_centered:\n            rgbs, depths, intrs, traj2d_w_z, visibility = self._add_cropping_augs_with_pp_at_center(\n                crop_size=crop_size,\n                rgbs=rgbs,\n                depths=depths,\n                intrs=intrs,\n                trajs=traj2d_w_z,\n                visibles=visibility,\n            )\n            h, w = rgbs.shape[-3:-1]\n\n        depths[depths > self.max_depth] = 0.0\n        if augment_this_datapoint and self.enable_depth_augs:\n            invalid_depth_mask = depths <= 0.0\n            depths = aug_depth(\n                torch.from_numpy(depths).reshape(num_views * n_frames, 1, h, w),\n                grid=(16, 16),\n                scale=(0.99, 1.01),\n                shift=(-0.001, 0.001),\n                gn_kernel=(5, 5),\n                gn_sigma=(2, 2),\n                generator=rnd_torch,\n            ).reshape(num_views, n_frames, h, w, 1).numpy()\n            depths, visibility = self._rescale_and_erase_depth_patches(depths, traj2d_w_z, visibility, rnd_np)\n            depths[invalid_depth_mask] = 0.0  # Restore invalid depths\n\n        if self.depth_noise_std > 0.0:\n            invalid_depth_mask = depths <= 0.0\n            noise = np.random.normal(loc=0.0, scale=self.depth_noise_std, size=depths.shape)\n            depths = depths + noise.astype(depths.dtype)\n            depths = np.clip(depths, 0.0, self.max_depth)\n            depths[invalid_depth_mask] = 0.0  # Restore invalid depths\n\n        rgbs = torch.from_numpy(rgbs).permute(0, 1, 4, 2, 3).float()\n        depths = torch.from_numpy(depths).permute(0, 1, 4, 2, 3).float()\n        segs = torch.from_numpy(segs).permute(0, 1, 4, 2, 3).float()\n        feats = torch.from_numpy(feats).permute(0, 1, 4, 2, 3).float() if feats is not None else None\n        intrs = torch.from_numpy(intrs).float()\n        extrs = torch.from_numpy(extrs).float()\n        visibility = torch.from_numpy(visibility)\n        traj2d = torch.from_numpy(traj2d)\n        traj2d_w_z = torch.from_numpy(traj2d_w_z)\n        traj3d_camera = torch.from_numpy(traj3d_camera)\n        traj3d_world = torch.from_numpy(traj3d_world)\n        if novel_rgbs is not None:\n            novel_rgbs = torch.from_numpy(novel_rgbs).permute(0, 1, 4, 2, 3).float()\n            novel_intrs = torch.from_numpy(novel_intrs).float()\n            novel_extrs = torch.from_numpy(novel_extrs).float()\n\n        # Track selection\n        cache_root = os.path.join(self.data_root, self.seq_names[index], \"cache\")\n        os.makedirs(cache_root, exist_ok=True)\n        cache_file = os.path.join(cache_root, f\"{self.cache_name}.npz\")\n\n        # Check if we can use cached tracks\n        use_cache = bool(self.use_cached_tracks) and os.path.isfile(cache_file)\n        if use_cache:\n            cache = 
np.load(cache_file)\n            visible_inds_sampled = torch.from_numpy(cache[\"track_indices\"])\n            traj2d_w_z = torch.from_numpy(cache[\"traj2d_w_z\"])\n            traj3d_world = torch.from_numpy(cache[\"traj3d_world\"])\n            visibility = torch.from_numpy(cache[\"visibility\"])\n            valids = torch.from_numpy(cache[\"valids\"])\n            query_points = torch.from_numpy(cache[\"query_points\"])\n\n        # Otherwise, sample the tracks and create query points\n        else:\n            # Sample the points to track\n            visibile_pts_first_frame_inds = (visibility.any(0)[0]).nonzero(as_tuple=False)[:, 0]\n            if self.sample_vis_1st_frame:\n                visibile_pts_inds = visibile_pts_first_frame_inds\n            else:\n                visibile_pts_mid_frame_inds = (visibility.any(0)[self.seq_len // 2]).nonzero(as_tuple=False)[:, 0]\n                visibile_pts_inds = torch.cat((visibile_pts_first_frame_inds, visibile_pts_mid_frame_inds), dim=0)\n                visibile_pts_inds = torch.unique(visibile_pts_inds)\n            visible_for_at_least_two_frames = (visibility.any(0).sum(0) >= 2).nonzero(as_tuple=False)[:, 0]\n            visibile_pts_inds = visibile_pts_inds[torch.isin(visibile_pts_inds, visible_for_at_least_two_frames)]\n            point_inds = torch.randperm(len(visibile_pts_inds), generator=rnd_torch)\n\n            traj_per_sample = self.traj_per_sample if self.traj_per_sample is not None else len(point_inds)\n            if self.enable_variable_num_views_augs:\n                adj_factor = self.enable_variable_num_views_augs__trajpersample_adjustment_factor.get(num_views, 1.0)\n                traj_per_sample = int(traj_per_sample * adj_factor)\n            if len(point_inds) == 0 or len(point_inds) < traj_per_sample // 4:\n                gotit = False\n                return None, gotit\n            if augment_this_datapoint and self.enable_variable_trajpersample_augs:\n                if index % 20 == 0:\n                    traj_per_sample = traj_per_sample // 8\n                elif index % 21 == 0:\n                    pass  # keep the same number of trajectories\n                else:\n                    low = max(1, traj_per_sample // 4)\n                    high = min(len(point_inds), traj_per_sample) + 1\n                    traj_per_sample = torch.randint(low=low, high=high, size=(1,), generator=rnd_torch).item()\n            else:\n                traj_per_sample = min(len(point_inds), traj_per_sample)\n            point_inds = point_inds[:traj_per_sample]\n            logging.info(\n                f\"[i={index:04d};seq={self.seq_names[index]};seed={seed}]\"\n                f\"Selected {len(point_inds)}/{len(visibile_pts_inds)} tracks. \"\n                f\"{num_views=}. 
\"\n                f\"{point_inds[0]=} max_depth={self.max_depth}.\"\n            )\n\n            visible_inds_sampled = visibile_pts_inds[point_inds]\n\n            n_tracks = len(visible_inds_sampled)\n            traj2d = traj2d[:, :, visible_inds_sampled].float()\n            traj2d_w_z = traj2d_w_z[:, :, visible_inds_sampled].float()\n            traj3d_camera = traj3d_camera[:, :, visible_inds_sampled].float()\n            traj3d_world = traj3d_world[:, visible_inds_sampled].float()\n            visibility = visibility[:, :, visible_inds_sampled]\n            valids = torch.ones((n_frames, n_tracks))\n\n            # Create the query points\n            gt_visibilities_any_view = visibility.any(dim=0)\n            assert (gt_visibilities_any_view.sum(dim=0) >= 2).all(), \"All points should be visible in least two frames.\"\n            last_visible_index = (torch.arange(n_frames).unsqueeze(-1) * gt_visibilities_any_view).max(0).values\n            assert gt_visibilities_any_view[last_visible_index[None, :], torch.arange(n_tracks)].all()\n            gt_visibilities_any_view[last_visible_index[None, :], torch.arange(n_tracks)] = False\n            assert (gt_visibilities_any_view.sum(dim=0) >= 1).all()\n\n            if self.sample_vis_1st_frame:\n                n_non_first_point_appearance_queries = 0\n                n_first_point_appearance_queries = n_tracks\n            else:\n                n_non_first_point_appearance_queries = n_tracks // 4\n                n_first_point_appearance_queries = n_tracks - n_non_first_point_appearance_queries\n\n            first_point_appearances = torch.argmax(\n                gt_visibilities_any_view[..., -n_first_point_appearance_queries:].float(), dim=0)\n            non_first_point_appearances = first_point_appearances.new_zeros((n_non_first_point_appearance_queries,))\n            for track_idx in range(n_tracks)[:n_non_first_point_appearance_queries]:\n                # Randomly take a timestep where the point is visible\n                non_zero_timesteps = torch.nonzero(gt_visibilities_any_view[:, track_idx] == 1)\n                random_timestep = non_zero_timesteps[rnd_np.randint(len(non_zero_timesteps))].item()\n                non_first_point_appearances[track_idx] = random_timestep\n\n            query_points_t = torch.cat([non_first_point_appearances, first_point_appearances], dim=0)\n            query_points_xyz_worldspace = traj3d_world[query_points_t, torch.arange(n_tracks)]\n            query_points = torch.cat([query_points_t[:, None], query_points_xyz_worldspace], dim=1)\n            assert gt_visibilities_any_view[query_points_t, torch.arange(n_tracks)].all()\n\n            # Cache the selected tracks and query points\n            if self.use_cached_tracks:\n                logging.warn(f\"Caching tracks for {self.seq_names[index]} at {os.path.abspath(cache_file)}\")\n                np.savez_compressed(\n                    cache_file,\n                    track_indices=visible_inds_sampled.numpy(),\n                    traj2d_w_z=traj2d_w_z.numpy(),\n                    traj3d_world=traj3d_world.numpy(),\n                    visibility=visibility.numpy(),\n                    valids=valids.numpy(),\n                    query_points=query_points.numpy(),\n                )\n\n        # Apply a transform to the world space\n        scale = 1.0\n        rot = torch.eye(3, dtype=torch.float32)\n        translation = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)\n\n        if self.enable_scene_transform_augs:\n        
    rot_x_angle = rnd_np.uniform(-15, 15)\n            rot_y_angle = rnd_np.uniform(-15, 15)\n            rot_z_angle = 0.0\n            scale = rnd_np.uniform(0.8, 1.5)\n            translate_x = rnd_np.uniform(-2, 2)\n            translate_y = rnd_np.uniform(-2, 2)\n            translate_z = rnd_np.uniform(-2, 2)\n\n            rot_x = R.from_euler('x', rot_x_angle, degrees=True).as_matrix()\n            rot_y = R.from_euler('y', rot_y_angle, degrees=True).as_matrix()\n            rot_z = R.from_euler('z', rot_z_angle, degrees=True).as_matrix()\n            rot = rot_z @ rot_y @ rot_x\n            T_rot = torch.eye(4)\n            T_rot[:3, :3] = torch.from_numpy(rot)\n            T_scale_and_translate = torch.tensor([\n                [scale, 0.0, 0.0, translate_x],\n                [0.0, scale, 0.0, translate_y],\n                [0.0, 0.0, scale, translate_z],\n                [0.0, 0.0, 0.0, 1.0],\n            ], dtype=torch.float32)\n            T = T_scale_and_translate @ T_rot\n\n        if self.normalize_scene_following_vggt:\n            assert not self.enable_scene_transform_augs, \"Cannot normalize scene with scene transform augs enabled.\"\n            extrs_square = torch.eye(4, device=extrs.device)[None, None].repeat(num_views, n_frames, 1, 1)\n            extrs_square[:, :, :3, :] = extrs\n            extrs_inv = torch.inverse(extrs_square)\n            intrs_inv = torch.inverse(intrs)\n\n            y, x = torch.meshgrid(\n                torch.arange(h, device=extrs.device),\n                torch.arange(w, device=extrs.device),\n                indexing=\"ij\",\n            )\n            homog = torch.stack([x, y, torch.ones_like(x)], dim=-1).float().reshape(-1, 3)\n            homog = homog[None].expand(num_views, -1, -1)\n            cam_points = torch.einsum(\"Vij, VNj->VNi\", intrs_inv[:, 0], homog) * depths[:, 0].reshape(num_views, -1, 1)\n            cam_points_h = torch.cat([cam_points, torch.ones_like(cam_points[..., :1])], dim=-1)\n            world_points_h = torch.einsum(\"Vij, VNj->VNi\", extrs_inv[:, 0], cam_points_h)\n\n            world_points_in_first = torch.einsum(\"ij, VNj->VNi\", extrs[0, 0], world_points_h)\n\n            mask = (depths[:, 0] > 0).reshape(num_views, -1)\n            valid_points = world_points_in_first[mask]\n            avg_dist = valid_points.norm(dim=1).mean()\n            scale = 1.0 / avg_dist\n\n            depths *= scale\n            traj3d_world *= scale\n            traj3d_camera *= scale\n            traj2d_w_z[..., 2] *= scale\n            extrs[:, :, :3, 3] *= scale\n\n            T_first_cam_to_origin = torch.eye(4, device=extrs.device)\n            T_first_cam_to_origin[:3, :4] = extrs[0, 0]\n            T = T_first_cam_to_origin\n\n        (\n            depths_trans, extrs_trans, query_points_trans, traj3d_world_trans, traj2d_w_z_trans\n        ) = transform_scene(scale, rot, translation, depths, extrs, query_points, traj3d_world, traj2d_w_z)\n        novel_extrs_trans = transform_scene(scale, rot, translation, None, novel_extrs, None, None, None)[1]\n\n        if self.enable_camera_params_noise_augs:\n            intrs, extrs_trans = add_camera_noise(\n                intrs=intrs.numpy(),\n                extrs=extrs_trans.numpy(),\n                noise_std_intr=0.001,\n                noise_std_extr=0.001,\n                rnd=rnd_np,\n            )\n            intrs = torch.from_numpy(intrs)\n            extrs_trans = torch.from_numpy(extrs_trans)\n\n        # Dump non-normalized tracks to disk\n        if 
self.augmentation_probability == 0.0 and not self.enable_variable_trajpersample_augs and seed is not None:\n            num_views_str = self.num_views if self.num_views is not None else \"none\"\n            views_str = ''.join(str(v) for v in self.views_to_return) if self.views_to_return is not None else \"none\"\n            duster_views_str = ''.join(str(v) for v in self.duster_views) if self.duster_views is not None else \"none\"\n            sample_identifier_str = (\n                f\"seed-{seed:06d}\"\n                f\"_tracks-{self.traj_per_sample}\"\n                f\"_use-duster-depths-{self.use_duster_depths}\"\n                f\"_clean-duster-depths-{self.clean_duster_depths}\"\n                f\"_num-views-{num_views_str}\"\n                f\"_views-{views_str}\"\n                f\"_duster-views-{duster_views_str}\"\n                f\"_ratio-dynamic-{self.ratio_dynamic}\"\n                f\"_ratio-very-dynamic-{self.ratio_very_dynamic}\"\n                f\"_aug-prob-{self.augmentation_probability}\"\n                f\"_max-tracks-to-preload-{self.max_tracks_to_preload}\"\n            )\n            datapoint_path = os.path.join(self.data_root, self.seq_names[index])\n            dumped_path = os.path.join(datapoint_path, f\"{sample_identifier_str}.npz\")\n            # if not os.path.exists(dumped_path):\n            #     logging.info(f\"Dumping {dumped_path}\")\n            #     np.savez(\n            #         dumped_path,\n            #         trajectories=traj3d_world.numpy(),\n            #         trajectories_pixelspace=traj2d.numpy(),\n            #         per_view_visibilities=visibility.numpy(),\n            #         query_points_3d=query_points.numpy(),\n            #         extrinsics=extrs.numpy(),\n            #         intrinsics=intrs.numpy(),\n            #         transform_that_would_have_been_applied=T,\n            #     )\n\n        datapoint = Datapoint(\n            video=rgbs,\n            videodepth=depths_trans,\n            feats=feats,\n            segmentation=segs,\n            trajectory=traj2d_w_z_trans,\n            trajectory_3d=traj3d_world_trans,\n            visibility=visibility,\n            valid=valids,\n            seq_name=self.seq_names[index],\n            intrs=intrs,\n            extrs=extrs_trans,\n            query_points=None,\n            query_points_3d=query_points_trans,\n            track_upscaling_factor=1 / scale,\n\n            novel_video=novel_rgbs,\n            novel_intrs=novel_intrs,\n            novel_extrs=novel_extrs_trans,\n        )\n\n        # Log timings\n        start_time_4 = time.time()\n        self.getitem_calls += 1\n        top_duration = start_time_2 - start_time_1\n        middle_duration = start_time_3 - start_time_2\n        bottom_duration = start_time_4 - start_time_3\n        total_duration = start_time_4 - start_time_1\n        logging.info(f\"Loading {index:>06d} took {total_duration:>7.3f}s \"\n                     f\"[top:{top_duration:>7.3f}s, middle:{middle_duration:>7.3f}s, bottom:{bottom_duration:>7.3f}s] \"\n                     f\"Getitem calls: {self.getitem_calls:>6d}. 
\"\n                     f\"n_views={num_views}, {n_tracks=:>4d}, augmented={int(augment_this_datapoint)} {rgbs.shape=}\")\n\n        min_valid_depth_ratio_threshold = 0.1\n        valid_depth_ratio = (depths > 0).float().mean()\n        if valid_depth_ratio < min_valid_depth_ratio_threshold:\n            logging.warning(f\"Skipping datapoint {index} due to too little valid depth values: \"\n                            f\"{valid_depth_ratio * 100:.1f}% (< {min_valid_depth_ratio_threshold * 100:.1f}%)\")\n            return None, False\n\n        return datapoint, gotit\n\n    @staticmethod\n    def getitem_raw_datapoint(scene_path, perform_2d_projection_sanity_check=True):\n        # Load global scene data\n        tracks_3d = torch.from_numpy(\n            np.load(os.path.join(scene_path, 'tracks_3d.npz'))['tracks_3d'],\n        )\n        tracks_segmentation_ids = torch.from_numpy(\n            np.load(os.path.join(scene_path, 'tracks_segmentation_ids.npz'))['tracks_segmentation_ids'],\n        )\n        tracked_objects = read_json(os.path.join(scene_path, 'tracked_objects.json'))\n\n        if os.path.exists(os.path.join(scene_path, 'views.npz')):\n            # V2 (lookat fixed to 0)\n            camera_positions = torch.from_numpy(np.load(os.path.join(scene_path, 'views.npz'))['views'])\n            lookat_positions = 0. * camera_positions\n        elif os.path.exists(os.path.join(scene_path, 'cameras.npz')):\n            # V3 (with randomized lookat)\n            camera_positions = torch.from_numpy(np.load(os.path.join(scene_path, 'cameras.npz'))['camera_positions'])\n            lookat_positions = torch.from_numpy(np.load(os.path.join(scene_path, 'cameras.npz'))['lookat_positions'])\n        else:\n            raise ValueError(\"No camera data found: neither views.npz nor cameras.npz exist.\")\n\n        n_frames = tracks_3d.shape[0]\n        n_tracks = tracks_3d.shape[1]\n        n_views = camera_positions.shape[0]\n        assert tracks_3d.shape == (n_frames, n_tracks, 3)\n        assert tracks_segmentation_ids.shape == (n_tracks,)\n        assert camera_positions.shape == (n_views, 3)\n        assert lookat_positions.shape == (n_views, 3)\n\n        # Initialize views data\n        views_data = []\n        view_folders = [\n            d for d in os.listdir(scene_path)\n            if os.path.isdir(os.path.join(scene_path, d)) and d.startswith('view_')\n        ]\n        view_folders = sorted(view_folders, key=lambda x: int(x.split('_')[-1]))\n\n        for view_folder in view_folders:\n            view_path = os.path.join(scene_path, view_folder)\n\n            # Load per-view data\n            view_data = {\n                'rgba': [],\n                'depth': [],\n                # 'segmentation': [],\n            }\n\n            frame_files = sorted(os.listdir(view_path))\n            for frame_file in frame_files:\n                if frame_file.startswith('rgba_'):\n                    view_data['rgba'].append(read_png(os.path.join(view_path, frame_file)))\n                elif frame_file.startswith('depth_'):\n                    view_data['depth'].append(read_tiff(os.path.join(view_path, frame_file)))\n                # elif frame_file.startswith('segmentation_'):\n                #     view_data['segmentation'].append(read_png(os.path.join(view_path, frame_file)))\n\n            assert len(view_data['rgba']) == n_frames, f\"{len(view_data['rgba'])}!={n_frames}\"\n            assert len(view_data['depth']) == n_frames, f\"{len(view_data['depth'])}!={n_frames}\"\n       
     # assert len(view_data['segmentation']) == n_frames, f\"{len(view_data['segmentation'])}!={n_frames}\"\n\n            # Convert lists to torch tensors\n            for key in view_data:\n                if view_data[key][0].dtype == np.uint16:\n                    view_data[key] = [a.astype(np.int32) for a in view_data[key]]\n                view_data[key] = torch.stack([torch.from_numpy(np.array(img)) for img in view_data[key]])\n\n            # Load additional per-view data\n            view_data.update({\n                'tracks_2d': torch.from_numpy(np.load(os.path.join(view_path, 'tracks_2d.npz'))['tracks_2d']),\n                'occlusion': torch.from_numpy(np.load(os.path.join(view_path, 'tracks_2d.npz'))['occlusion']),\n                'data_ranges': \"NOT LOADED\",  # read_json(os.path.join(view_path, 'data_ranges.json')),\n                'metadata': read_json(os.path.join(view_path, 'metadata.json')),\n                'events': \"NOT LOADED\",  # read_json(os.path.join(view_path, 'events.json')),\n                'object_id_to_segmentation_id': read_json(os.path.join(view_path, 'object_id_to_segmentation_id.json')),\n            })\n\n            # Extracting the intrinsics\n            view_data['intrinsics'] = torch.tensor(view_data['metadata']['camera']['K'], dtype=torch.float64)\n            assert view_data['intrinsics'].shape == (3, 3)\n\n            # Extracting the extrinsics\n            positions = torch.tensor(view_data['metadata']['camera']['positions'], dtype=torch.float64)\n            quaternions = torch.tensor(view_data['metadata']['camera']['quaternions'], dtype=torch.float64)\n            rotation_matrices = kornia.geometry.quaternion_to_rotation_matrix(quaternions)\n            assert positions.shape == (n_frames, 3)\n            assert quaternions.shape == (n_frames, 4)\n            assert rotation_matrices.shape == (n_frames, 3, 3)\n            extrinsics_inv = torch.zeros((n_frames, 4, 4), dtype=torch.float64)\n            extrinsics_inv[:, :3, :3] = rotation_matrices\n            extrinsics_inv[:, :3, 3] = positions\n            extrinsics_inv[:, 3, 3] = 1\n            view_data['extrinsics'] = extrinsics_inv.inverse()\n            assert torch.allclose(view_data['extrinsics'][:, 3, :3], torch.zeros(n_frames, 3, dtype=torch.float64))\n            assert torch.allclose(view_data['extrinsics'][:, 3, 3], torch.ones(n_frames, dtype=torch.float64))\n            view_data['extrinsics'] = view_data['extrinsics'][:, :3, :]\n\n            # Change the intrinsics to the format\n            w, h = view_data[\"metadata\"][\"metadata\"][\"resolution\"]\n            view_data['intrinsics'] = np.diag([w, h, 1]) @ view_data['intrinsics'].numpy() @ np.diag([1, -1, -1])\n            view_data['extrinsics'] = np.diag([1, -1, -1]) @ view_data['extrinsics'].numpy()\n            view_data['intrinsics'] = torch.from_numpy(view_data['intrinsics'])\n            view_data['extrinsics'] = torch.from_numpy(view_data['extrinsics'])\n\n            # Project one point to the image plane to check if the extrinsics are correct\n            if perform_2d_projection_sanity_check:\n                point_3d_world = tracks_3d[0, 0]\n                point_4d_world_homo = torch.cat([point_3d_world, torch.ones(1)])\n                point_2d_pixel = view_data['intrinsics'] @ view_data['extrinsics'][0] @ point_4d_world_homo\n                point_2d_pixel = point_2d_pixel[:2] / point_2d_pixel[2]\n                point_2d_pixel_gt = view_data[\"tracks_2d\"][0, 0]\n                assert 
torch.allclose(point_2d_pixel, point_2d_pixel_gt, atol=1e-3), f\"Point projection failed\"\n\n            # The original depth is the euclidean distance from the camera\n            # Compute the depth in z format instead (so the z coordinate in the camera space)\n            view_data['depth'] = KubricMultiViewDataset.depth_from_euclidean_to_z(\n                depth=view_data['depth'],\n                sensor_width=view_data['metadata']['camera']['sensor_width'],\n                focal_length=view_data['metadata']['camera']['focal_length'],\n            )\n\n            # Sometimes the Kubric depths contains very high values of 10e9\n            # We will clip those to 10e3 to avoid problems with inf and nan\n            larger_than_1000 = view_data['depth'] > 1000\n            if larger_than_1000.any():\n                logging.info(f\"Datapoint {scene_path} has depths larger than 1000: \"\n                             f\"{view_data['depth'][larger_than_1000]}. \"\n                             f\"Replacing those by 0 to denote invalid depth and avoid inf and nan values later.\")\n                view_data['depth'][larger_than_1000] = 0\n\n            view_data['view_path'] = view_path\n            views_data.append(view_data)\n\n        datapoint = {\n            \"tracks_3d\": tracks_3d,\n            \"tracks_segmentation_ids\": tracks_segmentation_ids,\n            \"tracked_objects\": tracked_objects,\n            \"camera_positions\": camera_positions,\n            \"lookat_positions\": lookat_positions,\n            \"views\": views_data\n        }\n\n        return datapoint\n\n    @staticmethod\n    def depth_from_euclidean_to_z(depth, sensor_width, focal_length):\n        n_frames, h, w, _ = depth.shape\n        sensor_height = sensor_width / w * h\n        pixel_centers_x = (np.arange(-w / 2, w / 2, dtype=np.float32) + 0.5) / w * sensor_width\n        pixel_centers_y = (np.arange(-h / 2, h / 2, dtype=np.float32) + 0.5) / h * sensor_height\n\n        # Calculate squared distance from the center of the image\n        pixel_centers_x, pixel_centers_y = np.meshgrid(pixel_centers_x, pixel_centers_y, indexing=\"xy\")\n        squared_distance_from_center = np.square(pixel_centers_x) + np.square(pixel_centers_y)\n\n        # Calculate rescaling factor for each pixel\n        z_to_eucl_rescaling = np.sqrt(1 + squared_distance_from_center / focal_length ** 2)\n\n        # Apply the rescaling to each depth value\n        z_to_eucl_rescaling = np.expand_dims(z_to_eucl_rescaling, axis=-1)  # Add a dimension for broadcasting\n        depth_z = depth / z_to_eucl_rescaling\n        return depth_z\n\n    def _add_photometric_augs(\n            self,\n            rgbs,\n            trajs,\n            visibles,\n            rndstate,\n            eraser=True,\n            replace=True,\n    ):\n        V, T, H, W, _ = rgbs.shape\n        _, _, N, _ = trajs.shape\n        assert rgbs.dtype == np.uint8\n        assert rgbs.shape == (V, T, H, W, 3)\n        assert trajs.shape == (V, T, N, 3)\n        assert visibles.shape == (V, T, N)\n\n        rgbs = rgbs.copy()\n        visibles = visibles.copy()\n\n        if eraser:  # eraser the specific region in the image\n            for v in range(V):\n                rgbs_view = rgbs[v]\n                rgbs_view = [rgb.astype(np.float32) for rgb in rgbs_view]\n                ############ eraser transform (per image after the first) ############\n                for i in range(1, T):\n                    if rndstate.rand() < self.eraser_aug_prob:\n     
                   for _ in range(\n                                rndstate.randint(1, self.eraser_max + 1)\n                        ):  # number of times to occlude\n                            xc = rndstate.randint(0, W)\n                            yc = rndstate.randint(0, H)\n                            dx = rndstate.randint(self.eraser_bounds[0], self.eraser_bounds[1])\n                            dy = rndstate.randint(self.eraser_bounds[0], self.eraser_bounds[1])\n                            x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)\n                            x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)\n                            y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)\n                            y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)\n                            mean_color = np.mean(rgbs_view[i][y0:y1, x0:x1, :].reshape(-1, 3), axis=0)\n                            rgbs_view[i][y0:y1, x0:x1, :] = mean_color\n                            occ_inds = np.logical_and(\n                                np.logical_and(trajs[v, i, :, 0] >= x0, trajs[v, i, :, 0] < x1),\n                                np.logical_and(trajs[v, i, :, 1] >= y0, trajs[v, i, :, 1] < y1),\n                            )\n                            visibles[v, i, occ_inds] = 0\n                rgbs_view = [rgb.astype(np.uint8) for rgb in rgbs_view]\n                rgbs[v] = np.stack(rgbs_view)\n\n        if replace:\n            for v in range(V):\n                rgbs_view = rgbs[v]\n                rgbs_view_alt = [\n                    np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8)\n                    for rgb in rgbs_view\n                ]\n                rgbs_view_alt = [\n                    np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8)\n                    for rgb in rgbs_view_alt\n                ]\n\n                ############ replace transform (per image after the first) ############\n                rgbs_view = [rgb.astype(np.float32) for rgb in rgbs_view]\n                rgbs_view_alt = [rgb.astype(np.float32) for rgb in rgbs_view_alt]\n                for i in range(1, T):\n                    if rndstate.rand() < self.replace_aug_prob:\n                        for _ in range(\n                                rndstate.randint(1, self.replace_max + 1)\n                        ):  # number of times to occlude\n                            xc = rndstate.randint(0, W)\n                            yc = rndstate.randint(0, H)\n                            dx = rndstate.randint(self.replace_bounds[0], self.replace_bounds[1])\n                            dy = rndstate.randint(self.replace_bounds[0], self.replace_bounds[1])\n                            x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)\n                            x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)\n                            y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)\n                            y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)\n\n                            wid = x1 - x0\n                            hei = y1 - y0\n                            y00 = rndstate.randint(0, H - hei)\n                            x00 = rndstate.randint(0, W - wid)\n                            fr = rndstate.randint(0, T)\n                            rep = rgbs_view_alt[fr][y00: y00 + hei, x00: x00 + wid, :]\n                            rgbs_view[i][y0:y1, x0:x1, :] = 
rep\n\n                            occ_inds = np.logical_and(\n                                np.logical_and(trajs[v, i, :, 0] >= x0, trajs[v, i, :, 0] < x1),\n                                np.logical_and(trajs[v, i, :, 1] >= y0, trajs[v, i, :, 1] < y1),\n                            )\n                            visibles[v, i, occ_inds] = 0\n                rgbs_view = [rgb.astype(np.uint8) for rgb in rgbs_view]\n                rgbs[v] = np.stack(rgbs_view)\n\n        ############ photometric augmentation ############\n        if rndstate.rand() < self.color_aug_prob:\n            # random per-frame amount of aug\n            # but shared across all views\n            for i in range(T):\n                fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.photo_aug.get_params(\n                    self.photo_aug.brightness, self.photo_aug.contrast, self.photo_aug.saturation, self.photo_aug.hue\n                )\n                for v in range(V):\n                    rgb = rgbs[v, i]\n                    rgb = Image.fromarray(rgb)\n                    for fn_id in fn_idx:\n                        if fn_id == 0 and brightness_factor is not None:\n                            rgb = F_torchvision.adjust_brightness(rgb, brightness_factor)\n                        elif fn_id == 1 and contrast_factor is not None:\n                            rgb = F_torchvision.adjust_contrast(rgb, contrast_factor)\n                        elif fn_id == 2 and saturation_factor is not None:\n                            rgb = F_torchvision.adjust_saturation(rgb, saturation_factor)\n                        elif fn_id == 3 and hue_factor is not None:\n                            rgb = F_torchvision.adjust_hue(rgb, hue_factor)\n                    rgb = np.array(rgb, dtype=np.uint8)\n                    rgbs[v, i] = rgb\n\n        if rndstate.rand() < self.blur_aug_prob:\n            # random per-frame amount of blur\n            # but shared across all views\n            for i in range(T):\n                sigma = self.blur_aug.get_params(self.blur_aug.sigma[0], self.blur_aug.sigma[1])\n                for v in range(V):\n                    rgb = rgbs[v, i]\n                    rgb = Image.fromarray(rgb)\n                    F_torchvision.gaussian_blur(rgb, self.blur_aug.kernel_size, [sigma, sigma])\n                    rgb = np.array(rgb, dtype=np.uint8)\n                    rgbs[v, i] = rgb\n\n        return rgbs, visibles\n\n    def _add_cropping_augs(self, crop_size, rgbs, depths, intrs, trajs, visibles):\n        V, T, H, W, _ = rgbs.shape\n        _, _, N, _ = trajs.shape\n        assert rgbs.dtype == np.uint8\n        assert depths.dtype == np.float32\n        assert rgbs.shape == (V, T, H, W, 3)\n        assert depths.shape == (V, T, H, W, 1)\n        assert intrs.shape == (V, T, 3, 3)\n        assert trajs.shape == (V, T, N, 3)\n        assert visibles.shape == (V, T, N)\n\n        rgbs = rgbs.copy()\n        depths = depths.copy()\n        intrs = intrs.copy()\n        trajs = trajs.copy()\n        visibles = visibles.copy()\n\n        ############ spatial transform ############\n        rgbs_new = np.zeros((V, T, crop_size[0], crop_size[1], 3), dtype=np.uint8)\n        depths_new = np.zeros((V, T, crop_size[0], crop_size[1], 1), dtype=np.float32)\n        for v in range(V):\n            # padding\n            pad_x0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])\n            pad_x1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])\n            
pad_y0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])\n            pad_y1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])\n\n            rgbs_view = [np.pad(rgb, ((pad_y0, pad_y1), (pad_x0, pad_x1), (0, 0))) for rgb in rgbs[v]]\n            depths_view = [np.pad(depth, ((pad_y0, pad_y1), (pad_x0, pad_x1), (0, 0))) for depth in depths[v]]\n            intrs[v, :, 0, 2] += pad_x0\n            intrs[v, :, 1, 2] += pad_y0\n            trajs[v, :, :, 0] += pad_x0\n            trajs[v, :, :, 1] += pad_y0\n            H_padded, W_padded = rgbs_view[0].shape[:2]\n\n            # scaling + stretching\n            scale = np.random.uniform(self.resize_lim[0], self.resize_lim[1])\n            scale_x = scale\n            scale_y = scale\n\n            scale_delta_x = 0.0\n            scale_delta_y = 0.0\n\n            for t in range(T):\n                if t == 1:\n                    scale_delta_x = np.random.uniform(-self.resize_delta, self.resize_delta)\n                    scale_delta_y = np.random.uniform(-self.resize_delta, self.resize_delta)\n                elif t > 1:\n                    scale_delta_x = (\n                            scale_delta_x * 0.8\n                            + np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2\n                    )\n                    scale_delta_y = (\n                            scale_delta_y * 0.8\n                            + np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2\n                    )\n                scale_x = scale_x + scale_delta_x\n                scale_y = scale_y + scale_delta_y\n\n                # bring h/w closer\n                scale_xy = (scale_x + scale_y) * 0.5\n                scale_x = scale_x * 0.5 + scale_xy * 0.5\n                scale_y = scale_y * 0.5 + scale_xy * 0.5\n\n                # don't get too crazy\n                scale_x = np.clip(scale_x, self.resize_lim[0], self.resize_lim[1])\n                scale_y = np.clip(scale_y, self.resize_lim[0], self.resize_lim[1])\n\n                H_new = int(H_padded * scale_y)\n                W_new = int(W_padded * scale_x)\n\n                # make it at least slightly bigger than the crop area,\n                # so that the random cropping can add diversity\n                H_new = np.clip(H_new, crop_size[0] + 10, None)\n                W_new = np.clip(W_new, crop_size[1] + 10, None)\n                # recompute scale in case we clipped\n                scale_x = (W_new - 1) / float(W_padded - 1)\n                scale_y = (H_new - 1) / float(H_padded - 1)\n                rgbs_view[t] = cv2.resize(rgbs_view[t], (W_new, H_new), interpolation=cv2.INTER_LINEAR)\n                depths_view[t] = cv2.resize(depths_view[t], (W_new, H_new), interpolation=cv2.INTER_NEAREST)\n                intrs[v, t, 0, :] *= scale_x\n                intrs[v, t, 1, :] *= scale_y\n                trajs[v, t, :, 0] *= scale_x\n                trajs[v, t, :, 1] *= scale_y\n            ok_inds = visibles[v, 0, :] > 0\n            vis_trajs = trajs[v, :, ok_inds]  # S,?,2\n\n            if vis_trajs.shape[0] > 0:\n                mid_x = np.mean(vis_trajs[:, 0, 0])\n                mid_y = np.mean(vis_trajs[:, 0, 1])\n            else:\n                mid_y = crop_size[0] // 2\n                mid_x = crop_size[1] // 2\n\n            x0 = int(mid_x - crop_size[1] // 2)\n            y0 = int(mid_y - crop_size[0] // 2)\n\n            offset_x = 0\n            offset_y = 0\n\n            for t in range(T):\n                # on each 
frame, shift a bit more\n                if t == 1:\n                    offset_x = np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1)\n                    offset_y = np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1)\n                elif t > 1:\n                    offset_x = int(\n                        offset_x * 0.8\n                        + np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1)\n                        * 0.2\n                    )\n                    offset_y = int(\n                        offset_y * 0.8\n                        + np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1)\n                        * 0.2\n                    )\n                x0 = x0 + offset_x\n                y0 = y0 + offset_y\n\n                H_new, W_new = rgbs_view[t].shape[:2]\n                if H_new == crop_size[0]:\n                    y0 = 0\n                else:\n                    y0 = min(max(0, y0), H_new - crop_size[0] - 1)\n\n                if W_new == crop_size[1]:\n                    x0 = 0\n                else:\n                    x0 = min(max(0, x0), W_new - crop_size[1] - 1)\n\n                rgbs_view[t] = rgbs_view[t][y0: y0 + crop_size[0], x0: x0 + crop_size[1]]\n                depths_view[t] = depths_view[t][y0: y0 + crop_size[0], x0: x0 + crop_size[1]]\n                intrs[v, t, 0, 2] -= x0\n                intrs[v, t, 1, 2] -= y0\n                trajs[v, t, :, 0] -= x0\n                trajs[v, t, :, 1] -= y0\n\n            H_new = crop_size[0]\n            W_new = crop_size[1]\n\n            # # h flip\n            # if self.do_flip and np.random.rand() < self.h_flip_prob:\n            #     rgbs_view = [rgb[:, ::-1] for rgb in rgbs_view]\n            #     depths_view = [depth[:, ::-1] for depth in depths_view]\n            #     intrs[v, :, 0, 2] = W_new - intrs[v, :, 0, 2]\n            #     trajs[v, :, :, 0] = W_new - trajs[v, :, :, 0]\n            #\n            # # v flip\n            # if np.random.rand() < self.v_flip_prob:\n            #     rgbs_view = [rgb[::-1] for rgb in rgbs_view]\n            #     depths_view = [depth[::-1] for depth in depths_view]\n            #     intrs[v, :, 1, 2] = H_new - intrs[v, :, 1, 2]\n            #     trajs[v, :, :, 1] = H_new - trajs[v, :, :, 1]\n\n            rgbs_new[v] = np.stack(rgbs_view)\n            depths_new[v] = np.stack(depths_view)[..., None]\n\n        visibles = (visibles &\n                    (trajs[..., 0] >= 0) &\n                    (trajs[..., 1] >= 0) &\n                    (trajs[..., 0] < crop_size[1]) &\n                    (trajs[..., 1] < crop_size[0]))\n\n        return rgbs_new, depths_new, intrs, trajs, visibles\n\n    def _add_cropping_augs_with_pp_at_center(self, crop_size, rgbs, depths, intrs, trajs, visibles):\n        V, T, H, W, _ = rgbs.shape\n        _, _, N, _ = trajs.shape\n        assert rgbs.dtype == np.uint8\n        assert depths.dtype == np.float32\n        assert rgbs.shape == (V, T, H, W, 3)\n        assert depths.shape == (V, T, H, W, 1)\n        assert intrs.shape == (V, T, 3, 3)\n        assert trajs.shape == (V, T, N, 3)\n        assert visibles.shape == (V, T, N)\n\n        rgbs = rgbs.copy()\n        depths = depths.copy()\n        intrs = intrs.copy()\n        trajs = trajs.copy()\n        visibles = visibles.copy()\n\n        rgbs_new = np.zeros((V, T, crop_size[0], crop_size[1], 3), dtype=np.uint8)\n        depths_new = np.zeros((V, T, crop_size[0], crop_size[1], 1), 
dtype=np.float32)\n\n        for v in range(V):\n            pad_x0 = pad_x1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])\n            pad_y0 = pad_y1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])\n\n            rgbs_view = [np.pad(rgb, ((pad_y0, pad_y1), (pad_x0, pad_x1), (0, 0))) for rgb in rgbs[v]]\n            depths_view = [np.pad(depth, ((pad_y0, pad_y1), (pad_x0, pad_x1), (0, 0))) for depth in depths[v]]\n            intrs[v, :, 0, 2] += pad_x0\n            intrs[v, :, 1, 2] += pad_y0\n            trajs[v, :, :, 0] += pad_x0\n            trajs[v, :, :, 1] += pad_y0\n            H_padded, W_padded = rgbs_view[0].shape[:2]\n\n            scale_x = np.random.uniform(self.resize_lim[0], self.resize_lim[1])\n            scale_y = scale_x + np.random.uniform(-0.01, 0.01)\n            scale_y = max(self.resize_lim[0], min(self.resize_lim[1], scale_y))\n            H_new = max(int(H_padded * scale_y) + int(H_padded * scale_y) % 2, crop_size[0] + 10)\n            W_new = max(int(W_padded * scale_x) + int(W_padded * scale_x) % 2, crop_size[1] + 10)\n            scale_x = W_new / W_padded\n            scale_y = H_new / H_padded\n\n            for t in range(T):\n                rgbs_view[t] = cv2.resize(rgbs_view[t], (W_new, H_new), interpolation=cv2.INTER_LINEAR)\n                depths_view[t] = cv2.resize(depths_view[t], (W_new, H_new), interpolation=cv2.INTER_NEAREST)\n            intrs[v, :, 0, :] *= scale_x\n            intrs[v, :, 1, :] *= scale_y\n            trajs[v, :, :, 0] *= scale_x\n            trajs[v, :, :, 1] *= scale_y\n\n            for t in range(T):\n                cx = intrs[v, t, 0, 2]\n                cy = intrs[v, t, 1, 2]\n                x0 = round(cx - crop_size[1] / 2)\n                y0 = round(cy - crop_size[0] / 2)\n\n                H_new, W_new = rgbs_view[t].shape[:2]\n                assert x0 >= 0\n                assert y0 >= 0\n                assert (H_new - crop_size[0]) >= 0\n                assert (W_new - crop_size[1]) >= 0\n                assert (H_new - crop_size[0]) >= y0\n                assert (W_new - crop_size[1]) >= x0\n\n                rgbs_view[t] = rgbs_view[t][y0:y0 + crop_size[0], x0:x0 + crop_size[1]]\n                depths_view[t] = depths_view[t][y0:y0 + crop_size[0], x0:x0 + crop_size[1]]\n                intrs[v, t, 0, 2] -= x0\n                intrs[v, t, 1, 2] -= y0\n                trajs[v, t, :, 0] -= x0\n                trajs[v, t, :, 1] -= y0\n\n                # Assert principal point is centered\n                assert rgbs_view[t].shape[0] == crop_size[0]\n                assert rgbs_view[t].shape[1] == crop_size[1]\n                assert np.allclose(intrs[v, t, 0, 2], crop_size[1] / 2, atol=0.01)\n                assert np.allclose(intrs[v, t, 1, 2], crop_size[0] / 2, atol=0.01)\n\n            rgbs_new[v] = np.stack(rgbs_view)\n            depths_new[v] = np.stack(depths_view)[..., None]\n\n        visibles = (visibles &\n                    (trajs[..., 0] >= 0) &\n                    (trajs[..., 1] >= 0) &\n                    (trajs[..., 0] < crop_size[1]) &\n                    (trajs[..., 1] < crop_size[0]))\n\n        return rgbs_new, depths_new, intrs, trajs, visibles\n\n    def _rescale_and_erase_depth_patches(self, depths, trajs, visibles, rndstate):\n        V, T, H, W, _ = depths.shape\n        _, _, N, _ = trajs.shape\n        assert depths.dtype == np.float32\n        assert depths.shape == (V, T, H, W, 1)\n        assert trajs.shape == (V, T, N, 3)\n        assert 
visibles.shape == (V, T, N)\n\n        depths = depths.copy()\n        visibles = visibles.copy()\n\n        ############ eraser transform (per image after the first) ############\n        for v in range(V):\n            for i in range(1, T):\n                if rndstate.rand() < self.eraser_aug_prob:\n                    n = rndstate.randint(1, self.eraser_max + 1)  # number of times to occlude\n                    for _ in range(n):\n                        xc = rndstate.randint(0, W)\n                        yc = rndstate.randint(0, H)\n                        dx = rndstate.randint(self.eraser_bounds[0], self.eraser_bounds[1])\n                        dy = rndstate.randint(self.eraser_bounds[0], self.eraser_bounds[1])\n                        x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)\n                        x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)\n                        y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)\n                        y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)\n                        eraser_depth = {\n                            0: depths[v, i, y0:y1, x0:x1].mean(),\n                            1: depths[v, i, y0:y1, x0:x1].min(),\n                            2: depths[v, i, y0:y1, x0:x1].max(),\n                            3: 0,\n                        }[rndstate.choice([0, 1, 2, 3], p=[0.2, 0.1, 0.35, 0.35])]\n                        depths[v, i, y0:y1, x0:x1] = eraser_depth\n                        occ_inds = np.logical_and(\n                            np.logical_and(trajs[v, i, :, 0] >= x0, trajs[v, i, :, 0] < x1),\n                            np.logical_and(trajs[v, i, :, 1] >= y0, trajs[v, i, :, 1] < y1),\n                        )\n                        visibles[v, i, occ_inds] = 0\n\n        ############ replace transform (per image after the first) ############\n        for v in range(V):\n            for i in range(1, T):\n                if rndstate.rand() < self.replace_aug_prob:\n                    n = rndstate.randint(1, self.replace_max + 1)  # number of times to occlude\n                    for _ in range(n):\n                        xc = rndstate.randint(0, W)\n                        yc = rndstate.randint(0, H)\n                        dx = rndstate.randint(self.replace_bounds[0], self.replace_bounds[1])\n                        dy = rndstate.randint(self.replace_bounds[0], self.replace_bounds[1])\n                        x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)\n                        x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)\n                        y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)\n                        y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)\n                        wid = x1 - x0\n                        hei = y1 - y0\n                        y00 = rndstate.randint(0, H - hei)\n                        x00 = rndstate.randint(0, W - wid)\n                        v_rnd = rndstate.randint(0, V)\n                        i_rnd = rndstate.randint(0, T)\n                        depths[v, i, y0:y1, x0:x1] = depths[v_rnd, i_rnd, y00: y00 + hei, x00: x00 + wid]\n                        occ_inds = np.logical_and(\n                            np.logical_and(trajs[v, i, :, 0] >= x0, trajs[v, i, :, 0] < x1),\n                            np.logical_and(trajs[v, i, :, 1] >= y0, trajs[v, i, :, 1] < y1),\n                        )\n                        visibles[v, i, occ_inds] = 0\n        
return depths, visibles\n\n    def _crop(self, rgbs, trajs, crop_size):\n        T, N, _ = trajs.shape\n\n        S = len(rgbs)\n        H, W = rgbs[0].shape[:2]\n        assert S == T\n\n        ############ spatial transform ############\n\n        H_new = H\n        W_new = W\n\n        # simple random crop\n        y0 = 0 if crop_size[0] >= H_new else (H_new - crop_size[0]) // 2\n        # np.random.randint(0,\n        x0 = 0 if crop_size[1] >= W_new else np.random.randint(0, W_new - crop_size[1])\n        rgbs = [rgb[y0: y0 + crop_size[0], x0: x0 + crop_size[1]] for rgb in rgbs]\n\n        trajs[:, :, 0] -= x0\n        trajs[:, :, 1] -= y0\n\n        return np.stack(rgbs), trajs\n"
  },
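The `normalize_scene_following_vggt` branch in the Kubric loader above unprojects the first frame of every view, moves the points into the reference camera's frame, and rescales the scene so their mean distance from that camera is 1. A condensed sketch of that step, with a hypothetical helper name and toy inputs (not part of the repo):

```python
# Condensed sketch of the VGGT-style normalization: express first-frame points in
# the reference camera's frame and rescale so their mean distance from it is 1.
import torch

def vggt_style_scale(world_points_h, extr_ref, valid_mask):
    # world_points_h: (N, 4) homogeneous world points unprojected from frame 0
    # extr_ref:       (3, 4) world-to-camera matrix of the reference view
    # valid_mask:     (N,) bool, True where the source depth was > 0
    points_in_ref = torch.einsum("ij,nj->ni", extr_ref, world_points_h)  # (N, 3)
    avg_dist = points_in_ref[valid_mask].norm(dim=1).mean()
    return 1.0 / avg_dist  # applied to depths, 3D tracks, and extrinsic translations

# Toy usage: points at an average distance of ~2 from the origin give scale ~0.5.
pts = 2.0 * torch.nn.functional.normalize(torch.randn(1000, 3), dim=1)
pts_h = torch.cat([pts, torch.ones(1000, 1)], dim=1)
scale = vggt_style_scale(pts_h, torch.eye(3, 4), torch.ones(1000, dtype=torch.bool))
print(scale)  # ~0.5
```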
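The intrinsics/extrinsics adjustment in `getitem_raw_datapoint` applies `diag(w, h, 1)` and `diag(1, -1, -1)` factors. A hedged toy illustration of what those products do, assuming made-up K and E values rather than real Kubric metadata: the first scales a normalized K to pixel units, and the sign flips take a camera looking down -z with y up into the OpenCV-style frame the rest of the loader expects:

```python
import numpy as np

w, h = 512, 512
K_normalized = np.array([[1.2, 0.0, 0.5],
                         [0.0, 1.2, 0.5],
                         [0.0, 0.0, 1.0]])            # toy normalized intrinsics
E_src = np.hstack([np.eye(3), np.zeros((3, 1))])      # toy world-to-camera matrix

K_pixels = np.diag([w, h, 1]) @ K_normalized @ np.diag([1, -1, -1])
E_opencv = np.diag([1, -1, -1]) @ E_src

# A point one unit in front of the source camera (z = -1 in its frame) gets
# z = +1 after the flip and projects to the principal point in pixel units.
p_cam = E_opencv @ np.array([0.0, 0.0, -1.0, 1.0])
uv_h = K_pixels @ p_cam
print(p_cam, uv_h[:2] / uv_h[2])  # [0. 0. 1.]  [256. 256.]
```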
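`depth_from_euclidean_to_z` divides Kubric's per-pixel Euclidean ray distance by sqrt(1 + (x^2 + y^2) / f^2), where (x, y) is the pixel center's metric offset on the sensor plane. A standalone sketch of the same computation; the function name and the `sensor_width`/`focal_length` values below are illustrative only:

```python
import numpy as np

def euclidean_depth_to_z(depth, sensor_width, focal_length):
    # depth: (T, H, W, 1) Euclidean ray distances -> z-depth of the same shape
    _, h, w, _ = depth.shape
    sensor_height = sensor_width / w * h
    xs = (np.arange(-w / 2, w / 2, dtype=np.float32) + 0.5) / w * sensor_width
    ys = (np.arange(-h / 2, h / 2, dtype=np.float32) + 0.5) / h * sensor_height
    xg, yg = np.meshgrid(xs, ys, indexing="xy")
    rescale = np.sqrt(1.0 + (xg ** 2 + yg ** 2) / focal_length ** 2)  # (H, W)
    return depth / rescale[None, :, :, None]

# Toy usage with made-up camera parameters: a constant Euclidean depth of 1.0
# yields z-depth slightly below 1 everywhere, smallest at the image corners.
depth = np.ones((1, 4, 4, 1), dtype=np.float32)
z = euclidean_depth_to_z(depth, sensor_width=0.036, focal_length=0.05)
print(z.min(), z.max())
```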
  {
    "path": "mvtracker/datasets/panoptic_studio_multiview_dataset.py",
    "content": "import logging\nimport os\nimport pathlib\nimport re\nimport time\nimport warnings\n\nimport cv2\nimport numpy as np\nimport pandas as pd\nimport torch\nimport torch.nn.functional as F\nfrom scipy.spatial.transform import Rotation as R\nfrom torch.utils.data import Dataset\n\nfrom mvtracker.datasets.utils import Datapoint, transform_scene\n\n\nclass PanopticStudioMultiViewDataset(Dataset):\n    @staticmethod\n    def from_name(dataset_name: str, dataset_root: str):\n        \"\"\"\n        Examples of datasets supported by this factory method:\n        - panoptic-multiview\n        - panoptic-multiview-views27_16_14_8\n        - panoptic-multiview-duster27_16_14_8\n        - panoptic-multiview-duster27_16_14_8cleaned\n        - panoptic-multiview-duster27_16_14_8cleaned-views27_16\n        - panoptic-multiview-duster27_16_14_8cleaned-views27_16-novelviews1_4\n        - panoptic-multiview-duster27_16_14_8cleaned-views27_16-novelviews1_4-single\n        - panoptic-multiview-duster27_16_14_8cleaned-views27_16-novelviews1_4-single-2dpt\n        - panoptic-multiview-duster27_16_14_8cleaned-views27_16-novelviews1_4-single-2dpt-cached\n        \"\"\"\n        # Parse the dataset name, chunk by chunk\n        non_parsed = dataset_name.replace(\"panoptic-multiview\", \"\", 1)\n\n        if non_parsed.startswith(\"-duster\"):\n            match = re.match(r\"-duster((?:\\d+_?)+)(cleaned)?\", non_parsed)\n            assert match is not None\n            duster_views = list(map(int, match.group(1).split(\"_\")))\n            use_duster = True\n            use_duster_cleaned = match.group(2) is not None\n            non_parsed = non_parsed.replace(match.group(0), \"\", 1)\n        else:\n            use_duster = False\n            use_duster_cleaned = False\n            duster_views = None\n\n        if non_parsed.startswith(\"-views\"):\n            match = re.match(r\"-views((?:\\d+_?)+)\", non_parsed)\n            assert match is not None\n            views = list(map(int, match.group(1).split(\"_\")))\n            if duster_views is not None:\n                assert all(v in duster_views for v in views)\n            non_parsed = non_parsed.replace(match.group(0), \"\", 1)\n        else:\n            views = duster_views\n\n        if non_parsed.startswith(\"-novelviews\"):\n            match = re.match(r\"-novelviews((?:\\d+_?)+)\", non_parsed)\n            assert match is not None\n            novel_views = list(map(int, match.group(1).split(\"_\")))\n            non_parsed = non_parsed.replace(match.group(0), \"\", 1)\n        else:\n            novel_views = None\n\n        if non_parsed.startswith(\"-single\"):\n            single_point = True\n            non_parsed = non_parsed.replace(\"-single\", \"\", 1)\n        else:\n            single_point = False\n\n        if non_parsed.startswith(\"-2dpt\"):\n            eval_2dpt = True\n            non_parsed = non_parsed.replace(\"-2dpt\", \"\", 1)\n        else:\n            eval_2dpt = False\n\n        if non_parsed.startswith(\"-cached\"):\n            use_cached_tracks = True\n            non_parsed = non_parsed.replace(\"-cached\", \"\", 1)\n        else:\n            use_cached_tracks = False\n\n        assert non_parsed == \"\", f\"Unparsed part of the dataset name: {non_parsed}\"\n\n        return PanopticStudioMultiViewDataset(\n            data_root=os.path.join(dataset_root, \"panoptic-multiview\"),\n            views_to_return=views,\n            novel_views=novel_views,\n            use_duster_depths=use_duster,\n   
         clean_duster_depths=use_duster_cleaned,\n            traj_per_sample=384,\n            seed=72,\n            max_videos=6,\n            perform_sanity_checks=False,\n            use_cached_tracks=use_cached_tracks,\n        )\n\n    def __init__(\n            self,\n            data_root,\n            views_to_return=None,\n            novel_views=None,\n            use_duster_depths=False,\n            clean_duster_depths=False,\n            traj_per_sample=512,\n            seed=None,\n            max_videos=None,\n            perform_sanity_checks=False,\n            use_cached_tracks=False,\n    ):\n        super().__init__()\n        self.data_root = data_root\n        self.views_to_return = views_to_return\n        self.novel_views = novel_views\n        self.use_duster_depths = use_duster_depths\n        self.clean_duster_depths = clean_duster_depths\n        self.traj_per_sample = traj_per_sample\n        self.seed = seed\n        self.perform_sanity_checks = perform_sanity_checks\n        self.use_cached_tracks = use_cached_tracks\n        self.cache_name = self._cache_key()\n        self.seq_names = self._get_sequence_names(max_videos)\n        self.getitem_calls = 0\n\n    def _get_sequence_names(self, max_videos):\n        \"\"\"\n        Fetch all valid sequence names from the dataset root.\n\n        Args:\n            max_videos (int): Limit the number of sequences to load.\n\n        Returns:\n            List[str]: Sorted list of valid sequence names.\n        \"\"\"\n        seq_names = [\n            fname\n            for fname in os.listdir(self.data_root)\n            if os.path.isdir(os.path.join(self.data_root, fname))\n               and not fname.startswith(\".\")\n               and not fname.startswith(\"_\")\n        ]\n        seq_names = sorted(seq_names)\n        valid_seqs = []\n\n        for seq_name in seq_names:\n            scene_path = os.path.join(self.data_root, seq_name)\n            if not os.path.exists(os.path.join(scene_path, \"tapvid3d_annotations.npz\")):\n                warnings.warn(f\"Skipping {scene_path} because it has no tapvid3d_annotations.npz labels file.\")\n                continue\n\n            valid_seqs.append(seq_name)\n\n        if max_videos is not None:\n            valid_seqs = valid_seqs[:max_videos]\n\n        print(f\"Using {len(valid_seqs)} videos from {self.data_root}\")\n        return valid_seqs\n\n    def _cache_key(self):\n        name = f\"cachedtracks--seed{self.seed}\"\n        if self.views_to_return is not None:\n            name += f\"-views{'_'.join(map(str, self.views_to_return))}\"\n        if self.traj_per_sample is not None:\n            name += f\"-n{self.traj_per_sample}\"\n        return name + \"--v1\"  # bump this if you change the selection policy\n\n    def __len__(self):\n        return len(self.seq_names)\n\n    def __getitem__(self, index):\n        start_time = time.time()\n        sample = self._getitem_helper(index)\n\n        self.getitem_calls += 1\n        if self.getitem_calls < 10:\n            print(f\"Loading {index:>06d} took  {time.time() - start_time:.3f} sec. 
Getitem calls: {self.getitem_calls}\")\n\n        return sample, True\n\n    def _getitem_helper(self, index):\n        \"\"\"\n        Helper function to load a single sample.\n\n        Args:\n            index (int): Index of the sample to load.\n\n        Returns:\n            CoTrackerData, bool: Sample data and success flag.\n        \"\"\"\n        if self.seed is None:\n            seed = torch.randint(0, 2 ** 32 - 1, (1,)).item()\n        else:\n            seed = self.seed\n        rnd_torch = torch.Generator().manual_seed(seed)\n        rnd_np = np.random.RandomState(seed=seed)\n\n        datapoint_path = os.path.join(self.data_root, self.seq_names[index])\n        ims_path = os.path.join(datapoint_path, \"ims\")\n        depths_path = os.path.join(datapoint_path, \"dynamic3dgs_depth\")\n\n        tapvid3d_merged_annotations = np.load(os.path.join(datapoint_path, \"tapvid3d_annotations.npz\"))\n        traj3d_world = tapvid3d_merged_annotations[\"trajectories\"]\n        traj2d = tapvid3d_merged_annotations[\"trajectories_pixelspace\"]\n        visibility = tapvid3d_merged_annotations[\"per_view_visibilities\"]\n        query_points_3d = tapvid3d_merged_annotations[\"query_points_3d\"]\n        extrs = tapvid3d_merged_annotations[\"extrinsics\"]\n        intrs = tapvid3d_merged_annotations[\"intrinsics\"]\n\n        views = {}\n        view_folders = sorted([f for f in os.listdir(ims_path)], key=lambda x: int(x))\n        if self.views_to_return is not None:\n            views_to_return = self.views_to_return\n        else:\n            views_to_return = sorted(list(range(len(view_folders))))\n        views_to_load = views_to_return.copy()\n        if self.novel_views is not None:\n            views_to_load = list(set(views_to_load + self.novel_views))\n        for v in views_to_load:\n            rgb_folder = os.path.join(ims_path, str(v))\n            rgb_files = sorted(os.listdir(rgb_folder))\n            rgb_images = [cv2.imread(os.path.join(rgb_folder, f))[:, :, ::-1] for f in rgb_files]\n            depth = np.load(os.path.join(depths_path, f\"depths_{v:02d}.npy\"))\n            views[v] = {\n                \"rgb\": np.stack(rgb_images),\n                \"depth\": depth,\n            }\n\n        rgbs = np.stack([views[v][\"rgb\"] for v in views_to_return])\n        n_views, n_frames, h, w, _ = rgbs.shape\n        depths = np.stack([views[v][\"depth\"] for v in views_to_return])[..., None].astype(np.float32)\n        intrs = np.stack([intrs[v] for v in views_to_return])[:, None, :, :].repeat(n_frames, axis=1)\n        extrs = np.stack([extrs[v][:3, :] for v in views_to_return])[:, None, :, :].repeat(n_frames, axis=1)\n        visibility = visibility[views_to_return]\n        traj2d = traj2d[views_to_return]\n\n        # Load novel views if they exist\n        novel_rgbs = None\n        novel_intrs = None\n        novel_extrs = None\n        if self.novel_views is not None:\n            novel_rgbs = np.stack([views[v][\"rgb\"]\n                                   for v in self.novel_views])\n            novel_intrs = np.stack([tapvid3d_merged_annotations[\"intrinsics\"][v]\n                                    for v in self.novel_views])[:, None, :, :].repeat(n_frames, axis=1)\n            novel_extrs = np.stack([tapvid3d_merged_annotations[\"extrinsics\"][v][:3, :]\n                                    for v in self.novel_views])[:, None, :, :].repeat(n_frames, axis=1)\n\n        # Load Duster's features and estimated depths if they exist\n        views_selection_str = 
'-'.join(str(v) for v in self.views_to_return)\n        duster_root = pathlib.Path(datapoint_path) / f'duster-views-{views_selection_str}'\n        if self.use_duster_depths:\n            assert duster_root.exists(), f\"Duster root {duster_root} does not exist.\"\n            last_frame_scene_file = duster_root / f\"3d_model__{n_frames - 1:05d}__scene.npz\"\n            assert last_frame_scene_file.exists(), f\"Duster scene file {last_frame_scene_file} does not exist.\"\n\n        feats = None\n        feat_dim = None\n        feat_stride = None\n        if duster_root.exists() and (duster_root / f\"3d_model__{n_frames - 1:05d}__scene.npz\").exists():\n            duster_depths = []\n            duster_feats = []\n            for frame_idx in range(n_frames):\n                scene = np.load(duster_root / f\"3d_model__{frame_idx:05d}__scene.npz\")\n                duster_depth = torch.from_numpy(scene[\"depths\"])\n                duster_conf = torch.from_numpy(scene[\"confs\"])\n                duster_msk = torch.from_numpy(scene[\"cleaned_mask\"])\n                duster_feat = torch.from_numpy(scene[\"feats\"])\n\n                if self.clean_duster_depths:\n                    duster_depth = duster_depth * duster_msk\n\n                duster_depth = F.interpolate(duster_depth[:, None], (h, w), mode='nearest')\n                duster_depths.append(duster_depth[:, 0, :, :, None])\n                duster_feats.append(duster_feat)\n\n            feats = torch.stack(duster_feats, dim=1).numpy()\n            assert feats.ndim == 4\n            assert feats.shape[0] == n_views\n            assert feats.shape[1] == n_frames\n            feat_stride = np.round(np.sqrt(h * w / feats.shape[2])).astype(int)\n            feat_dim = feats.shape[3]\n            feats = feats.reshape(n_views, n_frames, h // feat_stride, w // feat_stride, feat_dim)\n\n            # Replace the depths with the Duster depths, if configured so\n            if self.use_duster_depths:\n                depths = torch.stack(duster_depths, dim=1).numpy()\n\n        n_tracks = traj3d_world.shape[1]\n        assert rgbs.shape == (n_views, n_frames, h, w, 3)\n        assert depths.shape == (n_views, n_frames, h, w, 1)\n        assert feats is None or feats.shape == (n_views, n_frames, h // feat_stride, w // feat_stride, feat_dim)\n        assert intrs.shape == (n_views, n_frames, 3, 3)\n        assert extrs.shape == (n_views, n_frames, 3, 4)\n        assert traj2d.shape == (n_views, n_frames, n_tracks, 2)\n        assert visibility.shape == (n_views, n_frames, n_tracks)\n        assert traj3d_world.shape == (n_frames, n_tracks, 3)\n\n        if novel_rgbs is not None:\n            assert novel_rgbs.shape == (len(self.novel_views), n_frames, h, w, 3)\n            assert novel_intrs.shape == (len(self.novel_views), n_frames, 3, 3)\n            assert novel_extrs.shape == (len(self.novel_views), n_frames, 3, 4)\n\n        # Make sure our intrinsics and extrinsics work correctly\n        point_3d_world = traj3d_world\n        point_4d_world_homo = np.concatenate([point_3d_world, np.ones_like(point_3d_world[..., :1])], axis=-1)\n        point_3d_camera = np.einsum('ABij,BCj->ABCi', extrs, point_4d_world_homo)\n        if self.perform_sanity_checks:\n            point_2d_pixel_homo = np.einsum('ABij,ABCj->ABCi', intrs, point_3d_camera)\n            point_2d_pixel = point_2d_pixel_homo[..., :2] / point_2d_pixel_homo[..., 2:]\n            point_2d_pixel_gt = traj2d\n\n            point_2d_pixel_no_nan = np.nan_to_num(point_2d_pixel, 
nan=0)\n            point_2d_pixel_gt_no_nan = np.nan_to_num(point_2d_pixel_gt, nan=0)\n\n            assert np.allclose(point_2d_pixel_no_nan[0, :, 0, :], point_2d_pixel_no_nan[0, :, 0, :], atol=.01)\n            assert np.allclose(point_2d_pixel_gt_no_nan, point_2d_pixel_gt_no_nan, atol=.01), f\"Point projection failed\"\n        traj2d_w_z = np.concatenate([traj2d, point_3d_camera[..., 2:]], axis=-1)\n\n        rgbs = torch.from_numpy(rgbs).permute(0, 1, 4, 2, 3).float()\n        depths = torch.from_numpy(depths).permute(0, 1, 4, 2, 3).float()\n        feats = torch.from_numpy(feats).permute(0, 1, 4, 2, 3).float() if feats is not None else None\n        intrs = torch.from_numpy(intrs).float()\n        extrs = torch.from_numpy(extrs).float()\n        traj2d = torch.from_numpy(traj2d)\n        traj2d_w_z = torch.from_numpy(traj2d_w_z)\n        traj3d_world = torch.from_numpy(traj3d_world)\n        visibility = torch.from_numpy(visibility)\n        if novel_rgbs is not None:\n            novel_rgbs = torch.from_numpy(novel_rgbs).permute(0, 1, 4, 2, 3).float()\n            novel_intrs = torch.from_numpy(novel_intrs).float()\n            novel_extrs = torch.from_numpy(novel_extrs).float()\n\n        # Track selection\n        cache_root = os.path.join(self.data_root, self.seq_names[index], \"cache\")\n        os.makedirs(cache_root, exist_ok=True)\n        cache_file = os.path.join(cache_root, f\"{self.cache_name}.npz\")\n\n        # Check if we can use cached tracks\n        use_cache = bool(self.use_cached_tracks) and os.path.isfile(cache_file)\n        if use_cache:\n            cache = np.load(cache_file)\n            inds_sampled = torch.from_numpy(cache[\"track_indices\"])\n            traj2d_w_z = torch.from_numpy(cache[\"traj2d_w_z\"])\n            traj3d_world = torch.from_numpy(cache[\"traj3d_world\"])\n            visibility = torch.from_numpy(cache[\"visibility\"])\n            valids = torch.from_numpy(cache[\"valids\"])\n            query_points = torch.from_numpy(cache[\"query_points\"])\n\n        # Otherwise, sample the tracks and create query points\n        else:\n            # Prefer TAPVid-3D's merged query points when selecting the query points\n            # First, denote the points in time before the query points appeared as non-visible\n            # Second, choose the query points as the first appearance of the points in the selected views (which might be\n            # later than in the TAPVid-3D annotations because the query might not be visible in the selected views)\n            tapvid3d_merged_query_point_timestep = query_points_3d[:, 0].round().astype(int)\n            visibility *= (np.arange(n_frames)[None, :, None] >= tapvid3d_merged_query_point_timestep[None, None, :])\n\n            # Sample the points to track\n            visible_for_at_least_two_frames = visibility.any(0).sum(0) >= 2\n            valid_tracks = visible_for_at_least_two_frames\n            valid_tracks = valid_tracks.nonzero(as_tuple=False)[:, 0]\n\n            point_inds = torch.randperm(len(valid_tracks), generator=rnd_torch)\n            traj_per_sample = self.traj_per_sample if self.traj_per_sample is not None else len(point_inds)\n            assert len(point_inds) >= traj_per_sample\n            point_inds = point_inds[:traj_per_sample]\n            inds_sampled = valid_tracks[point_inds]\n\n            n_tracks = len(inds_sampled)\n            traj2d = traj2d[:, :, inds_sampled].float()\n            traj2d_w_z = traj2d_w_z[:, :, inds_sampled].float()\n            traj3d_world = 
traj3d_world[:, inds_sampled].float()\n            visibility = visibility[:, :, inds_sampled]\n\n            valids = ~torch.isnan(traj2d).any(dim=-1).any(dim=0)\n\n            # Create the query points\n            gt_visibilities_any_view = visibility.any(dim=0)\n            assert (gt_visibilities_any_view.sum(dim=0) >= 2).all(), \"All points should be visible in least two frames.\"\n            last_visible_index = (torch.arange(n_frames).unsqueeze(-1) * gt_visibilities_any_view).max(0).values\n            assert gt_visibilities_any_view[last_visible_index[None, :], torch.arange(n_tracks)].all()\n            gt_visibilities_any_view[last_visible_index[None, :], torch.arange(n_tracks)] = False\n            assert (gt_visibilities_any_view.sum(dim=0) >= 1).all()\n\n            query_points_t = torch.argmax(gt_visibilities_any_view.float(), dim=0)\n            query_points_xyz_worldspace = traj3d_world[query_points_t, torch.arange(n_tracks)]\n            query_points = torch.cat([query_points_t[:, None], query_points_xyz_worldspace], dim=1)\n            assert gt_visibilities_any_view[query_points_t, torch.arange(n_tracks)].all()\n\n            # Replace nans with zeros\n            traj2d[torch.isnan(traj2d)] = 0\n            traj2d_w_z[torch.isnan(traj2d_w_z)] = 0\n            traj3d_world[torch.isnan(traj3d_world)] = 0\n            assert torch.isnan(visibility).sum() == 0\n\n            # Cache the selected tracks and query points\n            if self.use_cached_tracks:\n                logging.warn(f\"Caching tracks for {self.seq_names[index]} at {os.path.abspath(cache_file)}\")\n                np.savez_compressed(\n                    cache_file,\n                    track_indices=inds_sampled.numpy(),\n                    traj2d_w_z=traj2d_w_z.numpy(),\n                    traj3d_world=traj3d_world.numpy(),\n                    visibility=visibility.numpy(),\n                    valids=valids.numpy(),\n                    query_points=query_points.numpy(),\n                )\n\n        # Normalize the scene to be similar to Kubric's scene\n        scale = 2.5\n        rot_x = R.from_euler('x', -90, degrees=True).as_matrix()\n        rot_y = R.from_euler('y', 0, degrees=True).as_matrix()\n        rot_z = R.from_euler('z', 0, degrees=True).as_matrix()\n        rot = torch.from_numpy(rot_z @ rot_y @ rot_x)\n        translate = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)\n        (\n            depths_trans, extrs_trans, query_points_trans, traj3d_world_trans, traj2d_w_z_trans\n        ) = transform_scene(scale, rot, translate, depths, extrs, query_points, traj3d_world, traj2d_w_z)\n        novel_extrs_trans = transform_scene(scale, rot, translate, None, novel_extrs, None, None, None)[1]\n\n        # # Use the auto scene normalization of generic scenes\n        # from mvtracker.datasets.generic_scene_dataset import compute_auto_scene_normalization\n        # scale, rot, translation = compute_auto_scene_normalization(depths, torch.ones_like(depths) * 100, extrs_trans, intrs)\n        # scale = scale * T[0, 0].item()\n        # print(f\"{scale=}\")\n        # (depths_trans, extrs_trans, query_points_trans, traj3d_world_trans, traj2d_w_z_trans\n        # ) = transform_scene(scale, rot, translation, depths_trans, extrs_trans, query_points_trans, traj3d_world_trans, traj2d_w_z_trans)\n        # _, novel_extrs_trans, _, _, _ = transform_scene(scale, rot, translation, None, novel_extrs_trans, None, None, None)\n        # 85.7 94.5 92.3 --> 86.0 94.8 92.2\n\n        # from 
mvtracker.datasets.dexycb_multiview_dataset import rerun_viz_scene\n        # rerun_viz_scene(\"nane/pc__no_transform/\", rgbs[:, ::20], depths[:, ::20], intrs[:, ::20], extrs[:, ::20], traj3d_world[:, ::20], 0.1)\n        # rerun_viz_scene(\"nane/pc_transformed/\", rgbs[:, ::20], depths[:, ::20], intrs[:, ::20], extrs_trans[:, ::20], traj3d_world_trans[:, ::20], 1)\n\n        segs = torch.ones((n_frames, 1, h, w))  # Dummy segmentation masks\n        datapoint = Datapoint(\n            video=rgbs,\n            videodepth=depths_trans,\n            feats=feats,\n            segmentation=segs,\n            trajectory=traj2d_w_z_trans,\n            trajectory_3d=traj3d_world_trans,\n            trajectory_category=None,\n            visibility=visibility,\n            valid=valids,\n            seq_name=self.seq_names[index],\n            intrs=intrs,\n            extrs=extrs_trans,\n            query_points=None,\n            query_points_3d=query_points_trans,\n            track_upscaling_factor=1 / scale,\n\n            novel_video=novel_rgbs,\n            novel_intrs=novel_intrs,\n            novel_extrs=novel_extrs_trans,\n        )\n        return datapoint\n"
  },
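For reference, a standalone walk-through of how `PanopticStudioMultiViewDataset.from_name` decomposes a dataset name, using the same regexes as the code above and an example name from its docstring:

```python
import re

name = "panoptic-multiview-duster27_16_14_8cleaned-views27_16-novelviews1_4"
rest = name.replace("panoptic-multiview", "", 1)

m = re.match(r"-duster((?:\d+_?)+)(cleaned)?", rest)
duster_views = list(map(int, m.group(1).split("_")))   # [27, 16, 14, 8]
clean_duster_depths = m.group(2) is not None            # True
rest = rest.replace(m.group(0), "", 1)

m = re.match(r"-views((?:\d+_?)+)", rest)
views = list(map(int, m.group(1).split("_")))           # [27, 16], a subset of duster_views
rest = rest.replace(m.group(0), "", 1)

m = re.match(r"-novelviews((?:\d+_?)+)", rest)
novel_views = list(map(int, m.group(1).split("_")))     # [1, 4]

print(duster_views, clean_duster_depths, views, novel_views)
```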
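In the non-cached branch above, each query frame is the first frame where the track is visible in any selected view, with the last visible frame excluded as a candidate so at least one later visible frame remains to be tracked. A toy sketch of that selection, with made-up visibility data:

```python
import torch

vis_any_view = torch.tensor([  # (n_frames, n_tracks), visibility in >= 1 view
    [0, 1],
    [1, 1],
    [1, 0],
    [0, 1],
], dtype=torch.bool)
n_frames, n_tracks = vis_any_view.shape

candidates = vis_any_view.clone()
last_visible = (torch.arange(n_frames).unsqueeze(-1) * candidates).max(0).values
candidates[last_visible, torch.arange(n_tracks)] = False  # drop the last visible frame

query_t = torch.argmax(candidates.float(), dim=0)  # first remaining visible frame
print(query_t)  # tensor([1, 0])
```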
  {
    "path": "mvtracker/datasets/tap_vid_datasets.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\nimport glob\nimport io\nimport logging\nimport os\nimport pickle\nimport re\nimport sys\nfrom pathlib import Path\nfrom typing import *\n\nimport matplotlib\nimport mediapy as media\nimport numpy as np\nimport rerun as rr\nimport torch\nfrom PIL import Image\nfrom scipy.spatial.transform import Rotation as R\n\nfrom mvtracker.datasets.utils import Datapoint, transform_scene\n\nDatasetElement = Mapping[str, Mapping[str, Union[np.ndarray, str]]]\n\n\ndef resize_video(video: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:\n    \"\"\"Resize a video to output_size.\"\"\"\n    # If you have a GPU, consider replacing this with a GPU-enabled resize op,\n    # such as a jitted jax.image.resize.  It will make things faster.\n    return media.resize_video(video, output_size)\n\n\ndef sample_queries_first(\n        target_occluded: np.ndarray,\n        target_points: np.ndarray,\n        frames: np.ndarray,\n) -> Mapping[str, np.ndarray]:\n    \"\"\"Package a set of frames and tracks for use in TAPNet evaluations.\n    Given a set of frames and tracks with no query points, use the first\n    visible point in each track as the query.\n    Args:\n      target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],\n        where True indicates occluded.\n      target_points: Position, of shape [n_tracks, n_frames, 2], where each point\n        is [x,y] scaled between 0 and 1.\n      frames: Video tensor, of shape [n_frames, height, width, 3].  Scaled between\n        -1 and 1.\n    Returns:\n      A dict with the keys:\n        video: Video tensor of shape [1, n_frames, height, width, 3]\n        query_points: Query points of shape [1, n_queries, 3] where\n          each point is [t, y, x] scaled to the range [-1, 1]\n        target_points: Target points of shape [1, n_queries, n_frames, 2] where\n          each point is [x, y] scaled to the range [-1, 1]\n    \"\"\"\n    valid = np.sum(~target_occluded, axis=1) > 0\n    target_points = target_points[valid, :]\n    target_occluded = target_occluded[valid, :]\n\n    query_points = []\n    for i in range(target_points.shape[0]):\n        index = np.where(target_occluded[i] == 0)[0][0]\n        x, y = target_points[i, index, 0], target_points[i, index, 1]\n        query_points.append(np.array([index, x, y]))  # [t, x, y]\n    query_points = np.stack(query_points, axis=0)\n\n    return {\n        \"video\": frames[np.newaxis, ...],\n        \"query_points\": query_points[np.newaxis, ...],\n        \"target_points\": target_points[np.newaxis, ...],\n        \"occluded\": target_occluded[np.newaxis, ...],\n    }\n\n\ndef sample_queries_strided(\n        target_occluded: np.ndarray,\n        target_points: np.ndarray,\n        frames: np.ndarray,\n        query_stride: int = 5,\n) -> Mapping[str, np.ndarray]:\n    \"\"\"Package a set of frames and tracks for use in TAPNet evaluations.\n\n    Given a set of frames and tracks with no query points, sample queries\n    strided every query_stride frames, ignoring points that are not visible\n    at the selected frames.\n\n    Args:\n      target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],\n        where True indicates occluded.\n      target_points: Position, of shape [n_tracks, n_frames, 2], where each point\n        is [x,y] scaled between 0 and 1.\n      
frames: Video tensor, of shape [n_frames, height, width, 3].  Scaled between\n        -1 and 1.\n      query_stride: When sampling query points, search for un-occluded points\n        every query_stride frames and convert each one into a query.\n\n    Returns:\n      A dict with the keys:\n        video: Video tensor of shape [1, n_frames, height, width, 3].  The video\n          has floats scaled to the range [-1, 1].\n        query_points: Query points of shape [1, n_queries, 3] where\n          each point is [t, y, x] scaled to the range [-1, 1].\n        target_points: Target points of shape [1, n_queries, n_frames, 2] where\n          each point is [x, y] scaled to the range [-1, 1].\n        trackgroup: Index of the original track that each query point was\n          sampled from.  This is useful for visualization.\n    \"\"\"\n    tracks = []\n    occs = []\n    queries = []\n    trackgroups = []\n    total = 0\n    trackgroup = np.arange(target_occluded.shape[0])\n    for i in range(0, target_occluded.shape[1], query_stride):\n        mask = target_occluded[:, i] == 0\n        query = np.stack(\n            [\n                i * np.ones(target_occluded.shape[0:1]),\n                target_points[:, i, 1],\n                target_points[:, i, 0],\n            ],\n            axis=-1,\n        )\n        queries.append(query[mask])\n        tracks.append(target_points[mask])\n        occs.append(target_occluded[mask])\n        trackgroups.append(trackgroup[mask])\n        total += np.array(np.sum(target_occluded[:, i] == 0))\n\n    return {\n        \"video\": frames[np.newaxis, ...],\n        \"query_points\": np.concatenate(queries, axis=0)[np.newaxis, ...],\n        \"target_points\": np.concatenate(tracks, axis=0)[np.newaxis, ...],\n        \"occluded\": np.concatenate(occs, axis=0)[np.newaxis, ...],\n        \"trackgroup\": np.concatenate(trackgroups, axis=0)[np.newaxis, ...],\n    }\n\n\nclass TapVidDataset(torch.utils.data.Dataset):\n\n    @staticmethod\n    def from_name(dataset_name: str, dataset_root: str):\n        \"\"\"\n        Examples of datasets supported by this factory method:\n        - tapvid2d-davis-nodepth\n        - tapvid2d-davis-moge\n        - tapvid2d-davis-zoedepth\n        - tapvid2d-davis-videodepthanything\n        - tapvid2d-davis-megasam\n        - tapvid2d-davis-mogewithextrinsics\n        - tapvid2d-davis-mogewithextrinsics-256x256\n        - tapvid2d-davis-mogewithextrinsics-256x256-single\n        \"\"\"\n        if dataset_name.startswith(\"tapvid2d-davis-\"):\n            # Parse the dataset name, chunk by chunk\n            non_parsed = dataset_name.replace(\"tapvid2d-davis-\", \"\", 1)\n\n            # Extract depth estimator (until first possible resolution or single flag)\n            match = re.match(r\"([^-]+)\", non_parsed)\n            assert match is not None\n            depth_estimator_name = match.group(1)\n            non_parsed = non_parsed.replace(depth_estimator_name, \"\", 1)\n\n            # Extract resolution\n            resize_to = None\n            match = re.search(r\"-([0-9]+x[0-9]+)\", non_parsed)\n            if match:\n                width, height = map(int, match.group(1).split(\"x\"))\n                resize_to = (width, height)\n                non_parsed = non_parsed.replace(match.group(0), \"\", 1)\n\n            # Check for single point flag\n            single_point = \"-single\" in non_parsed\n            non_parsed = non_parsed.replace(\"-single\", \"\", 1) if single_point else non_parsed\n\n            # 
Ensure no unparsed parts left\n            assert non_parsed == \"\", f\"Unparsed part of the dataset name: {non_parsed}\"\n\n            data_root = os.path.join(dataset_root, \"tapvid_davis/tapvid_davis.pkl\")\n            return TapVidDataset(\n                dataset_type=\"davis\",\n                data_root=data_root,\n                resize_to=resize_to,\n                queried_first=True,\n                depth_estimator_name=depth_estimator_name,\n                depth_estimator_batch_size=2,\n                depth_estimator_device=\"cuda\",\n                stream_rerun_depth_viz=False,\n                save_rerun_depth_viz=False,\n            )\n\n    def __init__(\n            self,\n            data_root,\n            dataset_type=\"davis\",\n            resize_to=(256, 256),\n            queried_first=True,\n            depth_estimator_name=\"moge-with-extrinsics\",\n            depth_estimator_batch_size=2,\n            depth_estimator_device=\"cuda\",\n            stream_rerun_depth_viz=False,\n            save_rerun_depth_viz=False,\n    ):\n        self.dataset_type = dataset_type\n        self.resize_to = resize_to\n        self.queried_first = queried_first\n        if self.dataset_type == \"kinetics\":\n            self.depth_cache_root = os.path.join(data_root, \"depth_cache\")\n        else:\n            self.depth_cache_root = os.path.join(os.path.dirname(data_root), \"depth_cache\")\n        os.makedirs(self.depth_cache_root, exist_ok=True)\n        if self.dataset_type == \"kinetics\":\n            all_paths = glob.glob(os.path.join(data_root, \"*_of_0010.pkl\"))\n            points_dataset = []\n            for pickle_path in all_paths:\n                with open(pickle_path, \"rb\") as f:\n                    data = pickle.load(f)\n                    points_dataset = points_dataset + data\n            self.points_dataset = points_dataset\n        else:\n            with open(data_root, \"rb\") as f:\n                self.points_dataset = pickle.load(f)\n            if self.dataset_type == \"davis\":\n                self.video_names = list(self.points_dataset.keys())\n        logging.info(\"found %d unique videos in %s\" % (len(self.points_dataset), data_root))\n\n        self.depth_estimator_name = depth_estimator_name\n        self.depth_estimator_batch_size = depth_estimator_batch_size\n        self.depth_estimator_device = depth_estimator_device\n\n        self.stream_rerun_depth_viz = stream_rerun_depth_viz\n        self.save_rerun_depth_viz = save_rerun_depth_viz\n\n        # # Dummy call all items to generate rerun visualizations\n        # self.stream_rerun_depth_viz = False\n        # self.save_rerun_depth_viz = True\n        # for i in tqdm(range(len(self.points_dataset))):\n        #     try:\n        #         self[i]\n        #     except Exception as e:\n        #         logging.error(f\"Error processing video {i}: {e}\")\n        #         logging.info(f\"But we continue anyway\")\n        #         continue\n        # exit()\n\n    def __getitem__(self, index):\n        if self.dataset_type == \"davis\":\n            video_name = self.video_names[index]\n        else:\n            video_name = index\n        frames = self.points_dataset[video_name][\"video\"].copy()\n\n        if isinstance(frames[0], bytes):\n            # TAP-Vid is stored and JPEG bytes rather than `np.ndarray`s.\n            def decode(frame):\n                byteio = io.BytesIO(frame)\n                img = Image.open(byteio)\n                return np.array(img)\n\n     
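       # Decode all JPEG-encoded frames into a (T, H, W, 3) uint8 array.\n     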
       frames = np.array([decode(frame) for frame in frames])\n\n        target_points = self.points_dataset[video_name][\"points\"].copy()\n        if self.resize_to is not None:\n            frames = resize_video(frames, self.resize_to)\n            target_points *= np.array([self.resize_to[1] - 1, self.resize_to[0] - 1])\n        else:\n            target_points *= np.array([frames.shape[2] - 1, frames.shape[1] - 1])\n        assert target_points[:, :, 0].min() >= 0\n        assert target_points[:, :, 0].max() <= frames.shape[2] - 1\n        assert target_points[:, :, 1].min() >= 0\n        assert target_points[:, :, 1].max() <= frames.shape[1] - 1\n\n        T, H, W, C = frames.shape\n        N, T, D = target_points.shape\n\n        target_occ = self.points_dataset[video_name][\"occluded\"].copy()\n        if self.queried_first:\n            converted = sample_queries_first(target_occ, target_points, frames)\n        else:\n            converted = sample_queries_strided(target_occ, target_points, frames)\n        assert converted[\"target_points\"].shape[1] == converted[\"query_points\"].shape[1]\n\n        trajs = (torch.from_numpy(converted[\"target_points\"])[0].permute(1, 0, 2).float())  # T, N, D\n\n        rgbs = torch.from_numpy(frames).permute(0, 3, 1, 2).float()\n        visibles = torch.logical_not(torch.from_numpy(converted[\"occluded\"]))[0].permute(1, 0)  # T, N\n        query_points_2d = torch.from_numpy(converted[\"query_points\"])[0]  # T, N\n\n        # Let's estimate depths RIGHT HERE\n        res = f\"{H}x{W}\"\n        cached_file_zoedepth_nk = os.path.join(self.depth_cache_root, f\"zoedepth_nk__{video_name}__{res}.npz\")\n        cached_file_moge = os.path.join(self.depth_cache_root, f\"moge__{video_name}__{res}.npz\")\n        cached_file_megasam = os.path.join(self.depth_cache_root, f\"megasam__{video_name}__{res}-v1.npz\")\n        if self.depth_estimator_name == \"nodepth\":\n            depth = np.ones((T, H, W))\n            intrs = np.eye(3) * max(H, W)\n            extrs = np.eye(4)[None].repeat(T, axis=0)\n        elif self.depth_estimator_name == \"zoedepth\":\n            depth = zoedepth_nk(rgbs, self.depth_estimator_batch_size, self.depth_estimator_device,\n                                cached_file_zoedepth_nk)\n            _, intrs, _, _, _ = moge(rgbs, self.depth_estimator_batch_size, self.depth_estimator_device,\n                                     cached_file_moge)\n            extrs = np.eye(4)[None].repeat(T, axis=0)\n        elif self.depth_estimator_name == \"moge\":\n            depth, intrs, _, _, mask = moge(rgbs, self.depth_estimator_batch_size, self.depth_estimator_device,\n                                            cached_file_moge)\n            depth[~mask] = 0\n            extrs = np.eye(4)[None].repeat(T, axis=0)\n        elif self.depth_estimator_name == \"mogewithextrinsics\":\n            depth, intrs, extrs, _, mask = moge(rgbs, self.depth_estimator_batch_size, self.depth_estimator_device,\n                                                cached_file_moge)\n            depth[~mask] = 0\n        elif self.depth_estimator_name == \"videodepthanything\":\n            raise NotImplementedError(\"videodepthanything is not implemented yet\")\n        elif self.depth_estimator_name == \"megasam\":\n            try:\n                depth, intrs, extrs = megasam(\n                    rgbs=rgbs,\n                    batch_size=self.depth_estimator_batch_size,\n                    device=self.depth_estimator_device,\n                    
cached_file=cached_file_megasam,\n                )\n            except Exception as e:\n                logging.error(f\"MegaSAM error for {video_name} ({rgbs.shape=}) (we will use moge depth instead): {e}\")\n                depth, intrs, extrs, _, mask = moge(rgbs, self.depth_estimator_batch_size, self.depth_estimator_device,\n                                                    cached_file_moge)\n                depth[~mask] = 0\n        else:\n            raise NotImplementedError\n\n        depth = torch.from_numpy(depth).float()\n        if intrs.ndim == 2:\n            intrs = intrs[None].repeat(T, axis=0)\n        intrs = torch.from_numpy(intrs).float()\n        extrs_square = torch.from_numpy(extrs).float()\n        extrs = extrs_square[:, :3, :]\n\n        intrs_inv = torch.inverse(intrs)\n        extrs_inv = torch.inverse(extrs_square)\n\n        # Project trajectories to 3D\n        trajs_depth = trajs.new_ones((T, N, 1)) * np.inf\n        for t in range(T):\n            # # V1: Not good enough, depths are jumping to the background near edges because of interpolation\n            # trajs_depth[t] = bilinear_sample2d(\n            #     im=depth[t][None, None],\n            #     x=trajs[t, :, 0][None],\n            #     y=trajs[t, :, 1][None],\n            # )[0].permute(1, 0).type(trajs_depth.dtype)\n\n            # V2: Still not good, taking the closest pixel only (without interpolating) still has jumps at edges\n            x_nearest = trajs[t, :, 0].round().long()\n            y_nearest = trajs[t, :, 1].round().long()\n            depth_nearest = depth[t].view(-1)[(y_nearest * W + x_nearest).view(-1)]\n            depth_nearest = depth_nearest.view(1, -1).type(trajs_depth.dtype).permute(1, 0)\n            trajs_depth[t] = depth_nearest\n\n            # # V3: Taking the minimum depth value of the neighbors also fails when there are other things in front.\n            # depth_pad = F.pad(depth[t][None, None], (1, 1, 1, 1), mode=\"replicate\")  # Pad to handle edges\n            # depth_min = -F.max_pool2d(-depth_pad, kernel_size=9, stride=1)  # Min pooling using negation\n            # depth_min_sampled = depth_min[0, 0, trajs[t, :, 1].long(), trajs[t, :, 0].long()].type(trajs_depth.dtype)\n            # trajs_depth[t] = depth_min_sampled[:, None]\n        assert torch.all(torch.isfinite(trajs_depth)).item()\n        trajs_camera = torch.einsum(\"Tij,TNj->TNi\", intrs_inv, to_homogenous_torch(trajs)) * trajs_depth\n        trajs_world = torch.einsum(\"Tij,TNj->TNi\", extrs_inv, to_homogenous_torch(trajs_camera))[..., :3]\n        trajs_3d = trajs_world\n\n        trajs_w_z = torch.cat([trajs, trajs_depth], dim=2)\n\n        # Project query points to 3D\n        qp_t = query_points_2d[:, 0].float()\n        qp_xyz_pixel = query_points_2d[:, 1:].float()\n        qp_depth = qp_xyz_pixel.new_ones((N, 1)) * np.inf\n        qp_xyz_world = qp_xyz_pixel.new_ones((N, 3)) * np.inf\n        for t in range(T):\n            qp_mask = qp_t == t\n            if qp_mask.sum() == 0:\n                continue\n\n            # V2 depth interpolation\n            x_nearest = qp_xyz_pixel[qp_mask, 0].round().long()\n            y_nearest = qp_xyz_pixel[qp_mask, 1].round().long()\n            depth_nearest = depth[t].view(-1)[(y_nearest * W + x_nearest).view(-1)]\n            depth_nearest = depth_nearest.view(1, -1).type(trajs_depth.dtype).permute(1, 0)\n            qp_depth[qp_mask] = depth_nearest\n\n            qp_xyz_pixel_t = to_homogenous_torch(qp_xyz_pixel[qp_mask])\n            
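# Back-project the query pixel: pixel -> camera ray via K^-1, scaled by the sampled depth,\n            # then camera -> world via the inverse extrinsics, mirroring the trajectory unprojection above.\n            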
qp_xyz_camera_t = torch.einsum(\"ij,Nj->Ni\", intrs_inv[t], qp_xyz_pixel_t) * qp_depth[qp_mask]\n            qp_xyz_world_t = torch.einsum(\"ij,Nj->Ni\", extrs_inv[t], to_homogenous_torch(qp_xyz_camera_t))[..., :3]\n            qp_xyz_world[qp_mask] = qp_xyz_world_t\n        assert torch.all(torch.isfinite(qp_depth))\n        assert torch.all(torch.isfinite(qp_xyz_world))\n        query_points_3d = torch.cat([qp_t[:, None], qp_xyz_world], dim=1)\n\n        # Visualize the depth estimation in Rerun\n        radii_scale = 0.1\n        streams = []\n        if self.stream_rerun_depth_viz: streams += [True]\n        if self.save_rerun_depth_viz: streams += [False]\n        for stream in streams:\n            # depth_zoedepth = zoedepth_nk(rgbs, self.depth_estimator_batch_size, self.depth_estimator_device,\n            #                              cached_file_zoedepth_nk)\n            depth_moge, intrinsics_moge, w2c_moge, _, mask_moge = moge(\n                rgbs=rgbs,\n                batch_size=self.depth_estimator_batch_size,\n                device=self.depth_estimator_device,\n                cached_file=cached_file_moge,\n            )\n\n            # TODO: But what intrinsics did Zoe really assume or use, if any?\n            K = intrinsics_moge\n            K_inv = np.linalg.inv(K)\n\n            rr.init(\"TAPVid-2D Estimated Depths\", recording_id=\"v0.1\")\n            if stream:\n                rr.connect_tcp()\n            rr.log(\"world\", rr.ViewCoordinates.RIGHT_HAND_Z_UP, static=True)\n\n            rr.set_time_seconds(\"frame\", 0)\n            rr.log(\n                \"world/xyz\",\n                rr.Arrows3D(\n                    vectors=[[1, 0, 0], [0, 2, 0], [0, 0, 3]],\n                    colors=[[255, 0, 0], [0, 255, 0], [0, 0, 255]],\n                ),\n            )\n            for t in range(T):\n                rr.set_time_seconds(\"frame\", t / 12)\n                rgb = rgbs[t].permute(1, 2, 0).numpy()\n\n                # Log the depth used for 3D tracking\n                rr.log(f\"{video_name}/image/depth_for_tracking\", rr.Pinhole(\n                    image_from_camera=intrs[t].numpy(),\n                    width=W,\n                    height=H,\n                ))\n                rr.log(f\"{video_name}/image/depth_for_tracking\", rr.Transform3D(\n                    translation=np.linalg.inv(extrs_square[t].numpy())[:3, 3],\n                    mat3x3=np.linalg.inv(extrs_square[t].numpy())[:3, :3],\n                ))\n                rr.log(f\"{video_name}/image/depth_for_tracking/depth\", rr.DepthImage(\n                    image=depth[t].numpy(),\n                    point_fill_ratio=0.2,\n                ))\n                rr.log(f\"{video_name}/image/depth_for_tracking/rgb\", rr.Image(rgb))\n\n                # Log all other depth maps\n                # d_zoe = depth_zoedepth[t, 0]\n                d_moge = depth_moge[t]\n                c2w_moge = np.linalg.inv(w2c_moge[t])\n                # for name, archetype in [\n                #     (\"depth-zoe\", rr.DepthImage(d_zoe, point_fill_ratio=0.2)),\n                #     (\"depth-moge\", rr.DepthImage(d_moge, point_fill_ratio=0.2)),\n                #     (\"depth-moge-with-extrinsics\", rr.DepthImage(d_moge, point_fill_ratio=0.2)),\n                # ]:\n                #     rr.log(f\"{video_name}/image/{name}\", rr.Pinhole(image_from_camera=K, width=W, height=H))\n                #     rr.log(f\"{video_name}/image/{name}/{name}\", archetype)\n                #     if name == 
\"depth-moge-with-extrinsics\":\n                #         transform = rr.Transform3D(translation=c2w_moge[:3, 3], mat3x3=c2w_moge[:3, :3])\n                #         rr.log(f\"{video_name}/image/{name}\", transform)\n\n                # Convert depth map to 3D point cloud\n                y, x = np.indices((H, W))\n                homo_pixel_coords = np.stack([x.ravel(), y.ravel(), np.ones_like(x).ravel()], axis=1).T\n\n                for _name, _depth, _w2c in [\n                    (\"used_for_tracking\", depth[t].numpy(), extrs_square[t]),\n                    # (\"zoe\", d_zoe, None),\n                    (\"moge\", d_moge, w2c_moge[t]),\n                    (\"moge-with-extrinsics\", d_moge, w2c_moge[t]),\n                ]:\n                    depth_values = _depth.ravel()\n                    cam_coords = (K_inv @ homo_pixel_coords) * depth_values\n                    if _w2c is None:\n                        world_coords = cam_coords.T\n                    else:\n                        world_coords = from_homogeneous(\n                            np.einsum(\"ij,Nj->Ni\", np.linalg.inv(_w2c), to_homogeneous(cam_coords.T)))\n                    valid_mask = depth_values > 0\n                    world_coords = world_coords[valid_mask]\n                    rgb_colors = rgb.reshape(-1, 3)[valid_mask].astype(np.uint8)\n                    rr.log(f\"{video_name}/pointcloud/{_name}\",\n                           rr.Points3D(world_coords, colors=rgb_colors, radii=0.001))\n\n            def log_tracks(\n                    tracks: np.ndarray,\n                    visibles: np.ndarray,\n                    query_timestep: np.ndarray,\n                    colors: np.ndarray,\n                    track_names=None,\n\n                    entity_format_str=\"{}\",\n\n                    log_points=True,\n                    points_radii=0.03 * radii_scale,\n                    invisible_color=[0., 0., 0.],\n\n                    log_line_strips=True,\n                    max_strip_length_past=6,\n                    max_strip_length_future=1,\n                    hide_invisible_strips=True,\n                    strips_radii=0.0027 * radii_scale,\n\n                    log_error_lines=False,\n                    error_lines_radii=0.0042 * radii_scale,\n                    error_lines_color=[1., 0., 0.],\n                    gt_for_error_lines=None,\n            ) -> None:\n                \"\"\"\n                Log tracks to Rerun.\n\n                Parameters:\n                    tracks: Shape (T, N, 3), the 3D trajectories of points.\n                    visibles: Shape (T, N), boolean visibility mask for each point at each timestep.\n                    query_timestep: Shape (T, N), the frame index after which the tracks start.\n                    colors: Shape (N, 4), RGBA colors for each point.\n                    entity_prefix: String prefix for entity hierarchy in Rerun.\n                    entity_suffix: String suffix for entity hierarchy in Rerun.\n                \"\"\"\n\n                T, N, _ = tracks.shape\n                assert tracks.shape == (T, N, 3)\n                assert visibles.shape == (T, N)\n                assert query_timestep.shape == (N,)\n                assert query_timestep.min() >= 0\n                assert query_timestep.max() < T\n                assert colors.shape == (N, 4)\n\n                for n in range(N):\n                    track_name = track_names[n] if track_names is not None else f\"track-{n}\"\n                    
rr.log(entity_format_str.format(track_name, rr.Clear(recursive=True)))\n                    for t in range(query_timestep[n], T):\n                        rr.set_time_seconds(\"frame\", t / 12)\n\n                        # Log the point (special handling for invisible points)\n                        if log_points:\n                            rr.log(\n                                entity_format_str.format(f\"{track_name}/point\"),\n                                rr.Points3D(\n                                    positions=[tracks[t, n]],\n                                    colors=[colors[n, :3]] if visibles[t, n] else [invisible_color],\n                                    radii=points_radii,\n                                ),\n                            )\n\n                        # Log line segments for visible tracks\n                        if log_line_strips and t > query_timestep[n]:\n                            strip_t_start = max(t - max_strip_length_past, query_timestep[n].item())\n                            strip_t_end = min(t + max_strip_length_future, T - 1)\n\n                            if not hide_invisible_strips:\n                                strips = np.stack([\n                                    tracks[strip_t_start:strip_t_end, n],\n                                    tracks[strip_t_start + 1:strip_t_end + 1, n],\n                                ], axis=-2)\n                                strips_visibility = visibles[strip_t_start + 1:strip_t_end + 1, n]\n                                strips_colors = np.where(\n                                    strips_visibility[:, None],\n                                    colors[None, n, :3],\n                                    [invisible_color],\n                                )\n                            else:\n                                point_sequence = tracks[strip_t_start:strip_t_end + 1, n]\n                                point_sequence_visible = point_sequence[visibles[strip_t_start:strip_t_end + 1, n]]\n                                strips = np.stack([point_sequence_visible[:-1], point_sequence_visible[1:]], axis=-2)\n                                strips_colors = colors[None, n, :3]\n\n                            rr.log(\n                                entity_format_str.format(f\"{track_name}/line\"),\n                                rr.LineStrips3D(strips=strips, colors=strips_colors, radii=strips_radii),\n                            )\n\n                        if log_error_lines:\n                            assert gt_for_error_lines is not None\n                            strips = np.stack([\n                                tracks[t, n],\n                                gt_for_error_lines[t, n],\n                            ], axis=-2)\n                            rr.log(\n                                entity_format_str.format(f\"{track_name}/error\"),\n                                rr.LineStrips3D(strips=strips, colors=error_lines_color, radii=error_lines_radii),\n                            )\n\n            # Log the tracks\n            trajs_3d_np = trajs_3d.cpu().numpy()\n            visibles_np = visibles.cpu().numpy()\n            query_timestep_np = query_points_3d[:, 0].cpu().numpy().round().astype(int)\n            cmap = matplotlib.colormaps[\"gist_rainbow\"]\n            norm = matplotlib.colors.Normalize(vmin=trajs_3d_np[..., 0].min(), vmax=trajs_3d_np[..., 0].max())\n            track_color = cmap(norm(trajs_3d_np[-1, :, 0]))\n            # track_color = track_color * 0 + 
1  # Just make all tracks white\n\n            log_tracks(\n                tracks=trajs_3d_np,\n                visibles=visibles_np,\n                query_timestep=query_timestep_np,\n                colors=track_color,\n                entity_format_str=f\"{video_name}/tracks/{{}}\",\n                max_strip_length_future=0,\n            )\n\n            if not stream:\n                rr_rrd_path = os.path.join(self.depth_cache_root, f\"rerun_viz__{video_name}.rrd\")\n                rr.save(rr_rrd_path)\n                logging.info(f\"Saved Rerun recording to: {os.path.abspath(rr_rrd_path)}\")\n\n        V = 1\n        rgbs = rgbs[None]\n        trajs = trajs[None]\n        trajs_w_z = trajs_w_z[None]\n        trajs_3d = trajs_3d\n        query_points_3d = query_points_3d\n        visibles = visibles[None]\n        depth = depth[None, :, None]\n        feats = None\n        intrs = intrs[None]\n        extrs = extrs[None]\n\n        assert rgbs.shape == (V, T, 3, H, W)\n        assert depth.shape == (V, T, 1, H, W)\n        assert feats is None\n        assert intrs.shape == (V, T, 3, 3)\n        assert extrs.shape == (V, T, 3, 4)\n        assert trajs.shape == (V, T, N, 2)\n        assert trajs_w_z.shape == (V, T, N, 3)\n        assert visibles.shape == (V, T, N)\n        assert trajs_3d.shape == (T, N, 3)\n        assert query_points_3d.shape == (N, 4)\n\n        # Normalize the scene to be similar to training scenes\n        rot_x = R.from_euler('x', -90, degrees=True).as_matrix()\n        rot_y = R.from_euler('y', 0, degrees=True).as_matrix()\n        rot_z = R.from_euler('z', 0, degrees=True).as_matrix()\n        rot = rot_z @ rot_y @ rot_x\n        T_rot = torch.eye(4)\n        T_rot[:3, :3] = torch.from_numpy(rot)\n\n        ## V1: GT track-agnostic transformation\n        # scale = 10\n        # translate_x = 0\n        # translate_y = -15\n        # translate_z = 2\n        #\n        # T_scale_and_translate = torch.tensor([\n        #     [scale, 0.0, 0.0, translate_x],\n        #     [0.0, scale, 0.0, translate_y],\n        #     [0.0, 0.0, scale, translate_z],\n        #     [0.0, 0.0, 0.0, 1.0],\n        # ], dtype=torch.float32)\n\n        ## V2: GT track-aware transformation\n        # Rotate the 3D GT tracks first\n        trajs_3d_homo = torch.cat([trajs_3d, torch.ones_like(trajs_3d[..., :1])], dim=-1)\n        trajs_3d_rotated = torch.einsum('ij,TNj->TNi', T_rot, trajs_3d_homo)[..., :3]\n\n        # Mask out non-visible points\n        visible_mask = visibles[0]  # (T, N)\n        trajs_3d_visible = trajs_3d_rotated[visible_mask]  # (V, 3)\n\n        # Compute bbox over only visible points\n        bbox_min = trajs_3d_visible.amin(dim=0)\n        bbox_max = trajs_3d_visible.amax(dim=0)\n        bbox_center = (bbox_min + bbox_max) / 2\n        bbox_size = bbox_max - bbox_min\n\n        # Target bounds (half-extent of desired cube)\n        target_bounds = torch.tensor([10.0, 10.0, 6.0])\n        scale = (target_bounds / bbox_size).min().item()\n        translation = -bbox_center * scale\n        rot = torch.from_numpy(rot)\n\n        # Optional: clamp depth map if needed (max Z-depth defined in scaled space)\n        logging.info(f\"[datapoint_idx={index}] Scale={scale:.2f}, Translate={translation.tolist()}\")\n        # depth[depth > 50 / scale] = 50 / scale\n        depth[depth > 20] = 20\n\n        # Apply to scene\n        (\n            depth_trans, extrs_trans, query_points_3d_trans, trajs_3d_trans, trajs_w_z_trans\n        ) = transform_scene(scale, rot, 
translation, depth, extrs, query_points_3d, trajs_3d, trajs_w_z)\n        assert torch.allclose(trajs_w_z[..., :2], trajs_w_z_trans[..., :2])\n\n        gotit = True\n        return Datapoint(\n            video=rgbs,\n            videodepth=depth_trans,\n            feats=None,\n            segmentation=torch.ones(T, 1, H, W).float(),\n            trajectory=trajs_w_z_trans,\n            trajectory_3d=trajs_3d_trans,\n            visibility=visibles,\n            valid=torch.ones((T, N)),\n            seq_name=str(video_name),\n            intrs=intrs,\n            extrs=extrs_trans,\n            query_points=query_points_2d,\n            query_points_3d=query_points_3d_trans,\n        ), gotit\n\n    def __len__(self):\n        return len(self.points_dataset)\n\n\n@torch.no_grad()\ndef zoedepth_nk(rgbs, batch_size=2, device=\"cuda\", cached_file=None):\n    if cached_file is not None and os.path.exists(cached_file):\n        return np.load(cached_file)[\"depth\"]\n\n    # needs timm==0.6.7, but megasam needs timm==1.0.15\n    model = torch.hub.load(\"isl-org/ZoeDepth\", \"ZoeD_NK\", pretrained=True).to(device)\n    model.eval()\n\n    T, _, H, W = rgbs.shape\n    depth = []\n    for i in range(0, T, batch_size):\n        rgbs_i = rgbs[i:i + batch_size].to(device) / 255.\n        depth_i = model.infer(rgbs_i).clamp(0.01, 65.0).cpu()\n        depth.append(depth_i)\n    depth = torch.cat(depth, dim=0).numpy()[:, 0]\n\n    if cached_file is not None:\n        np.savez(cached_file, depth=depth)\n\n    del model\n    torch.cuda.empty_cache()\n\n    return depth\n\n\ndef rigid_registration(\n        p: np.ndarray,\n        q: np.ndarray,\n        w: np.ndarray = None,\n        eps: float = 1e-12\n) -> Tuple[float, np.ndarray, np.ndarray]:\n    from moge.utils.geometry_numpy import weighted_mean_numpy\n\n    if w is None:\n        w = np.ones(p.shape[0])\n    centroid_p = weighted_mean_numpy(p, w[:, None], axis=0)\n    centroid_q = weighted_mean_numpy(q, w[:, None], axis=0)\n\n    p_centered = p - centroid_p\n    q_centered = q - centroid_q\n    w = w / (np.sum(w) + eps)\n\n    cov = (w[:, None] * p_centered).T @ q_centered\n    U, S, Vh = np.linalg.svd(cov)\n    R = Vh.T @ U.T\n    if np.linalg.det(R) < 0:\n        Vh[2, :] *= -1\n        R = Vh.T @ U.T\n    scale = np.sum(S) / np.trace((w[:, None] * p_centered).T @ p_centered)\n    t = centroid_q - scale * (centroid_p @ R.T)\n    return scale, R, t\n\n\ndef rigid_registration_ransac(\n        p: np.ndarray,\n        q: np.ndarray,\n        w: np.ndarray = None,\n        max_iters: int = 20,\n        hypothetical_size: int = 10,\n        inlier_thresh: float = 0.02\n) -> Tuple[Tuple[float, np.ndarray, np.ndarray], np.ndarray]:\n    n = p.shape[0]\n    if w is None:\n        w = np.ones(p.shape[0])\n\n    best_score, best_inlines = 0., np.zeros(n, dtype=bool)\n    best_solution = (np.array(1.), np.eye(3), np.zeros(3))\n\n    for _ in range(max_iters):\n        maybe_inliers = np.random.choice(n, size=hypothetical_size, replace=False)\n        try:\n            s, R, t = rigid_registration(p[maybe_inliers], q[maybe_inliers], w[maybe_inliers])\n        except np.linalg.LinAlgError:\n            continue\n        transformed_p = s * p @ R.T + t\n        errors = w * np.linalg.norm(transformed_p - q, axis=1)\n        inliers = errors < inlier_thresh\n\n        score = inlier_thresh * n - np.clip(errors, None, inlier_thresh).sum()\n        if score > best_score:\n            best_score, best_inlines = score, inliers\n            best_solution = 
rigid_registration(p[inliers], q[inliers], w[inliers])\n\n    return best_solution, best_inlines\n\n\ndef to_homogeneous(x):\n    return np.concatenate([x, np.ones_like(x[..., :1])], axis=-1)\n\n\ndef from_homogeneous(x, assert_homogeneous_part_is_equal_to_1=False, eps=0.001):\n    if assert_homogeneous_part_is_equal_to_1:\n        assert np.allclose(x[..., -1:], 1, atol=eps), f\"Expected homogeneous part to be 1, got {x[..., -1:]}\"\n    return x[..., :-1] / x[..., -1:]\n\n\ndef to_homogenous_torch(x):\n    return torch.cat([x, torch.ones_like(x[..., :1])], axis=-1)\n\n\n@torch.no_grad()\ndef moge(rgbs, batch_size=10, device=\"cuda\", cached_file=None, intrinsics=None):\n    if cached_file is not None and os.path.exists(cached_file):\n        cached_data = np.load(cached_file)\n        depths_with_normalized_scale = cached_data[\"depth\"]\n        points_in_world_space = cached_data[\"points\"]\n        w2c = cached_data[\"w2c\"]\n        intrinsics = cached_data[\"intrinsics\"]\n        mask = cached_data[\"mask\"]\n        return depths_with_normalized_scale, intrinsics, w2c, points_in_world_space, mask\n\n    # git clone https://github.com/microsoft/MoGe.git ../moge\n    # cd ../moge\n    # git checkout dd158c0\n    sys.path.append(\"../moge\")  # TODO: Find a clean way to do this so that it is not hardcoded\n    from moge.model import MoGeModel\n    import utils3d\n    model = MoGeModel.from_pretrained(\"Ruicheng/moge-vitl\").to(device)\n\n    T, _, H, W = rgbs.shape\n    assert rgbs.shape == (T, 3, H, W)\n\n    points = []\n    depth = []\n    mask = []\n    for rgb in rgbs:\n        rgb = rgb.to(device)\n        output = model.infer(\n            image=rgb / 255,\n            resolution_level=9,\n            force_projection=True,\n            apply_mask=True,\n            fov_x=np.rad2deg(utils3d.intrinsics_to_fov(intrinsics)[0]) if intrinsics is not None else None,\n        )\n        points.append(output[\"points\"].cpu().numpy())\n        depth.append(output[\"depth\"].cpu().numpy())\n        mask.append(output[\"mask\"].cpu().numpy())\n        if intrinsics is None:\n            intrinsics = output[\"intrinsics\"].cpu().numpy()\n        assert np.allclose(intrinsics, output[\"intrinsics\"].cpu().numpy(), atol=0.01), \"Intr. 
changed between frames\"\n    points = np.stack(points)\n    depth = np.stack(depth)\n    mask = np.stack(mask)\n    intrinsics = np.diag([W, H, 1]) @ intrinsics\n\n    # Assert we can reproduce the points from the depth maps already (should be enforced with force_projection=True)\n    pixel_xy = np.stack(np.meshgrid(np.arange(W), np.arange(H)), axis=-1)\n    pixel_xy_homo = to_homogeneous(pixel_xy)\n    depthmap_camera_xyz = np.einsum('ij,HWj->HWi', np.linalg.inv(intrinsics), pixel_xy_homo)\n    depthmap_camera_xyz = depthmap_camera_xyz[None, :, :, :] * depth[:, :, :, None]\n    valid = mask & (depth > 0)\n    assert np.allclose(points[valid], depthmap_camera_xyz[valid], atol=1, rtol=0.1)\n\n    depths_with_normalized_scale = depth.copy()\n    points_in_world_space = points.copy()\n    w2c = np.eye(4)[None].repeat(T, axis=0)\n    for t in range(1, T):\n        valid_p = mask[t] & (depth[t] > 0)  # & (depth[t] <= 4.20)  # TODO: magic number here!\n        valid_q = mask[t - 1] & (depth[t] > 0)  # & (depth[t] <= 4.20)  # TODO: magic number here!\n        valid = valid_p & valid_q\n        (scale, rotation, translation), inliers = rigid_registration_ransac(\n            p=points[t][valid].reshape(-1, 3),\n            q=points_in_world_space[t - 1][valid].reshape(-1, 3),\n            w=(1 / depths_with_normalized_scale[t - 1][valid]).reshape(-1),\n            max_iters=20,\n            hypothetical_size=10,\n            inlier_thresh=0.02\n        )\n        depths_with_normalized_scale[t] = scale * depths_with_normalized_scale[t]\n\n        # Transforming points[t] -> points_in_world_space[t - 1] already tells us how to transform to the\n        # world space since points_in_world_space[t - 1] had already been transformed to the world space\n        points_in_world_space[t] = scale * points_in_world_space[t] @ rotation.T + translation\n\n        # I prefer to use column vectors: Q = q.T, P = p.T\n        # q = p @ R.T + t -> Q = R @ P + t.T\n        # p = q @ R - t @ rotation -> P = R.T @ Q - R.T @ t.T\n        w2c[t, :3, :3] = rotation.T\n        w2c[t, :3, 3] = -rotation.T @ translation.T\n\n    # Assert no nans\n    assert not np.isnan(depths_with_normalized_scale).any()\n    assert not np.isnan(w2c).any()\n    assert np.allclose(w2c[:, 3, 3], 1)\n\n    # Now let's make sure we can go from scale-normalized depth maps to the points in world space\n    # Pixel --> Camera --> World\n    pixel_xy = np.stack(np.meshgrid(np.arange(W), np.arange(H)), axis=-1)\n    pixel_xy_homo = to_homogeneous(pixel_xy)\n    depthmap_camera_xyz = np.einsum('ij,HWj->HWi', np.linalg.inv(intrinsics), pixel_xy_homo)\n    depthmap_camera_xyz = depthmap_camera_xyz[None, :, :, :] * depths_with_normalized_scale[:, :, :, None]\n    depthmap_camera_xyz_homo = to_homogeneous(depthmap_camera_xyz)\n    depthmap_world_xyz_homo = np.einsum('Tij,THWj->THWi', np.linalg.inv(w2c), depthmap_camera_xyz_homo)\n    depthmap_world_xyz = from_homogeneous(depthmap_world_xyz_homo)\n    points_in_world_space_reproduced = depthmap_world_xyz\n    valid = mask & (depths_with_normalized_scale > 0)\n    assert np.allclose(points_in_world_space[valid], points_in_world_space_reproduced[valid], atol=0.1, rtol=0.1)\n\n    if cached_file is not None:\n        np.savez(\n            cached_file,\n            depth=depths_with_normalized_scale,\n            points=points_in_world_space,\n            w2c=w2c,\n            intrinsics=intrinsics,\n            mask=mask,\n        )\n\n    return depths_with_normalized_scale, intrinsics, w2c, 
points_in_world_space, mask\n\n\ndef megasam(rgbs: torch.Tensor, batch_size: int = 10, device: str = \"cuda\", cached_file: Optional[str] = None):\n    if cached_file is not None and os.path.exists(cached_file):\n        cached_data = np.load(cached_file)\n        return (\n            cached_data[\"depths\"].astype(np.float32),\n            cached_data[\"intrinsics\"].astype(np.float32),\n            cached_data[\"extrinsics\"].astype(np.float32),\n        )\n    # else:\n    #     raise NotImplementedError(\"TMP ERR\")\n\n    T, C, H, W = rgbs.shape\n    assert C == 3, \"Expected shape (T, 3, H, W)\"\n\n    # Convert to NumPy format for MegaSAM (T, H, W, 3), uint8 [0, 255]\n    rgbs_np = (rgbs.permute(0, 2, 3, 1).cpu().numpy()).astype(np.uint8)\n\n    # git clone https://github.com/zbw001/TAPIP3D.git ../tapip3d\n    # cd ../tapip3d\n    # git checkout 8871375\n    sys.path.append(\"../tapip3d\")\n    from annotation.megasam import MegaSAMAnnotator\n    megasam = MegaSAMAnnotator(\n        script_path=Path(\"../tapip3d\") / \"third_party\" / \"megasam\" / \"inference.py\",\n        depth_model=\"moge\",\n        resolution=H * W\n    )\n    megasam.to(device)\n    depths, intrinsics, extrinsics = megasam.process_video(\n        rgbs=rgbs_np,\n        gt_intrinsics=None,\n        return_raw_depths=False,\n    )\n\n    if cached_file is not None:\n        np.savez(cached_file, depths=depths, intrinsics=intrinsics, extrinsics=extrinsics)\n    return depths, intrinsics, extrinsics\n"
  },
  {
    "path": "mvtracker/datasets/utils.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\nimport dataclasses\nimport json\nimport pathlib\nfrom dataclasses import dataclass\nfrom typing import Any, Optional, List\n\nimport numpy as np\nimport png\nimport torch\nfrom torch.nn import functional as F\nfrom torchvision.transforms import functional as TF\n\nfrom mvtracker.utils.basic import to_homogeneous, from_homogeneous\n\n\n@dataclass(eq=False)\nclass Datapoint:\n    \"\"\"\n    Dataclass for storing video tracks data.\n    \"\"\"\n\n    video: torch.Tensor  # B, S, C, H, W\n    segmentation: torch.Tensor  # B, S, 1, H, W\n\n    # optional data\n    videodepth: Optional[torch.Tensor] = None  # B, S, 1, H, W\n    videodepthconf: Optional[torch.Tensor] = None  # B, S, 1, H, W\n    feats: Optional[torch.Tensor] = None  # B, S, C, H_strided, W_strided\n    valid: Optional[torch.Tensor] = None  # B, S, N\n    seq_name: Optional[List[str]] = None  # B\n    intrs: Optional[torch.Tensor] = torch.eye(3).unsqueeze(0)  # B, 3, 3\n\n    query_points: Optional[torch.Tensor] = None  # TapVID evaluation format\n    query_points_3d: Optional[torch.Tensor] = None  # TapVID evaluation format\n\n    trajectory: Optional[torch.Tensor] = None  # B, S, N, 2\n    visibility: Optional[torch.Tensor] = None  # B, S, N\n    trajectory_3d: Optional[torch.Tensor] = None  # B, S, 4, 4\n    trajectory_category: Optional[torch.Tensor] = None  # B, S, 1\n    extrs: Optional[torch.Tensor] = None  # B, S, 4, 4\n\n    track_upscaling_factor: Optional[float] = 1.0\n\n    novel_video: Optional[torch.Tensor] = None  # B, S, C, H, W\n    novel_intrs: Optional[torch.Tensor] = torch.eye(3).unsqueeze(0)  # B, 3, 3\n    novel_extrs: Optional[torch.Tensor] = None  # B, S, 4, 4\n\n\ndef collate_fn(batch):\n    gotit = [gotit for _, gotit in batch]\n    video = torch.stack([b.video for b, _ in batch], dim=0)\n    videodepth = torch.stack([b.videodepth for b, _ in batch], dim=0)\n    segmentation = torch.stack([b.segmentation for b, _ in batch], dim=0)\n    seq_name = [b.seq_name for b, _ in batch]\n    intrs = torch.stack([b.intrs for b, _ in batch], dim=0)\n\n    videodepthconf = (\n        torch.stack([b.videodepthconf for b, _ in batch], dim=0)\n        if batch[0][0].videodepthconf is not None\n        else None\n    )\n    feats = (\n        torch.stack([b.feats for b, _ in batch], dim=0)\n        if batch[0][0].feats is not None\n        else None\n    )\n    trajectory = (\n        torch.stack([b.trajectory for b, _ in batch], dim=0)\n        if batch[0][0].trajectory is not None\n        else None\n    )\n    valid = (\n        torch.stack([b.valid for b, _ in batch], dim=0)\n        if batch[0][0].valid is not None\n        else None\n    )\n    visibility = (\n        torch.stack([b.visibility for b, _ in batch], dim=0)\n        if batch[0][0].visibility is not None\n        else None\n    )\n    trajectory_3d = (\n        torch.stack([b.trajectory_3d for b, _ in batch], dim=0)\n        if batch[0][0].trajectory_3d is not None\n        else None\n    )\n    extrs = (\n        torch.stack([b.extrs for b, _ in batch], dim=0)\n        if batch[0][0].extrs is not None\n        else None\n    )\n    query_points = (\n        torch.stack([b.query_points for b, _ in batch], dim=0)\n        if batch[0][0].query_points is not None\n        else None\n    )\n    query_points_3d = (\n        
torch.stack([b.query_points_3d for b, _ in batch], dim=0)\n        if batch[0][0].query_points_3d is not None\n        else None\n    )\n\n    track_upscaling_factor = batch[0][0].track_upscaling_factor\n\n    novel_video = None\n    novel_intrs = None\n    novel_extrs = None\n    if batch[0][0].novel_video is not None:\n        novel_video = torch.stack([b.novel_video for b, _ in batch], dim=0)\n        novel_intrs = torch.stack([b.novel_intrs for b, _ in batch], dim=0)\n        novel_extrs = torch.stack([b.novel_extrs for b, _ in batch], dim=0)\n\n    return (\n        Datapoint(\n            video=video,\n            videodepth=videodepth,\n            videodepthconf=videodepthconf,\n            feats=feats,\n            segmentation=segmentation,\n            trajectory=trajectory,\n            trajectory_3d=trajectory_3d,\n            visibility=visibility,\n            valid=valid,\n            seq_name=seq_name,\n            intrs=intrs,\n            extrs=extrs,\n            query_points=query_points,\n            query_points_3d=query_points_3d,\n            track_upscaling_factor=track_upscaling_factor,\n            novel_video=novel_video,\n            novel_intrs=novel_intrs,\n            novel_extrs=novel_extrs\n        ),\n        gotit,\n    )\n\n\ndef try_to_cuda(t: Any) -> Any:\n    \"\"\"\n    Try to move the input variable `t` to a cuda device.\n\n    Args:\n        t: Input.\n\n    Returns:\n        t_cuda: `t` moved to a cuda device, if supported.\n    \"\"\"\n    try:\n        t = t.float().cuda()\n    except AttributeError:\n        pass\n    return t\n\n\ndef dataclass_to_cuda_(obj):\n    \"\"\"\n    Move all contents of a dataclass to cuda inplace if supported.\n\n    Args:\n        batch: Input dataclass.\n\n    Returns:\n        batch_cuda: `batch` moved to a cuda device, if supported.\n    \"\"\"\n    for f in dataclasses.fields(obj):\n        setattr(obj, f.name, try_to_cuda(getattr(obj, f.name)))\n    return obj\n\n\ndef read_json(filename: str) -> Any:\n    with open(filename, \"r\") as fp:\n        return json.load(fp)\n\n\ndef read_tiff(filename: str) -> np.ndarray:\n    import imageio\n    img = imageio.v2.imread(pathlib.Path(filename).read_bytes(), format=\"tiff\")\n    if img.ndim == 2:\n        img = img[:, :, None]\n    return img\n\n\ndef read_png(filename: str, rescale_range=None) -> np.ndarray:\n    png_reader = png.Reader(bytes=pathlib.Path(filename).read_bytes())\n    width, height, pngdata, info = png_reader.read()\n    del png_reader\n\n    bitdepth = info[\"bitdepth\"]\n    if bitdepth == 8:\n        dtype = np.uint8\n    elif bitdepth == 16:\n        dtype = np.uint16\n    else:\n        raise NotImplementedError(f\"Unsupported bitdepth: {bitdepth}\")\n\n    plane_count = info[\"planes\"]\n    pngdata = np.vstack(list(map(dtype, pngdata)))\n    if rescale_range is not None:\n        minv, maxv = rescale_range\n        pngdata = pngdata / 2 ** bitdepth * (maxv - minv) + minv\n\n    return pngdata.reshape((height, width, plane_count))\n\ndef transform_scene(\n        transformation_scale: float = 1.0,\n        transformation_rotation: torch.Tensor = torch.eye(3, dtype=torch.float32),\n        transformation_translation: torch.Tensor = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32),\n\n        depth: torch.Tensor = None,  # [V,T,1,H,W]\n        extrs: torch.Tensor = None,  # [V,T,3,4] world->cam\n\n        query_points: torch.Tensor = None,  # [N,4] (t, x, y, z) in world\n        traj3d_world: torch.Tensor = None,  # [T,N,3]\n        
traj2d_w_z: torch.Tensor = None,  # [V,T,N,3] (x_px, y_px, z_cam)\n):\n    \"\"\"\n    Make the world space `transformation_scale` larger, then rotate it by `transformation_rotation`,\n    then translate it by `transformation_translation`. In other words, apply the following transformation:\n    X_world' = transformation_translation + transformation_rotation @ (transformation_scale * X_world).\n\n    Implemented as:\n      - depth (z_cam) *= scale\n      - extrinsics: scale translation by 'scale', then right-multiply by rigid inverse\n      - query/world trajectories: scale then rigid\n      - traj2d_w_z: only z scaled; (x,y) unchanged\n    \"\"\"\n    is_rot_orthonormal = torch.allclose(\n        transformation_rotation @ transformation_rotation.T,\n        torch.eye(3, dtype=transformation_rotation.dtype, device=transformation_rotation.device),\n        atol=1e-3,\n    )\n    assert is_rot_orthonormal, \"The rotation matrix should be orthonormal.\"\n\n    Rt = torch.eye(4, dtype=transformation_rotation.dtype, device=transformation_rotation.device)\n    Rt[:3, :3] = transformation_rotation\n    Rt[:3, 3] = transformation_translation\n\n    # Transform depth\n    if depth is not None:\n        depth_trans = depth * transformation_scale\n    else:\n        depth_trans = None\n\n    # Transform extrinsics\n    if extrs is not None:\n        n_views, n_frames, _, _ = extrs.shape\n        assert extrs.shape == (n_views, n_frames, 3, 4)\n        src_dtype = extrs.dtype\n        extrs = extrs.type(Rt.dtype)\n        extrs_trans_square = torch.eye(4, dtype=extrs.dtype, device=extrs.device).repeat(n_views, n_frames, 1, 1)\n        extrs_trans_square[:, :, :3, :3] = extrs[:, :, :3, :3]\n        extrs_trans_square[:, :, :3, 3] = extrs[:, :, :3, 3] * transformation_scale\n        extrs_trans_square = torch.einsum('ABki,ij->ABkj', extrs_trans_square, torch.inverse(Rt))\n        extrs_trans = extrs_trans_square[..., :3, :]\n        extrs_trans = extrs_trans.type(src_dtype)\n    else:\n        extrs_trans = None\n\n    # Transform query points\n    if query_points is not None:\n        n_tracks = query_points.shape[0]\n        assert query_points.shape == (n_tracks, 4)\n        src_dtype = query_points.dtype\n        query_points = query_points.type(Rt.dtype)\n        query_points_xyz_scaled_homo = to_homogeneous(query_points[..., 1:4] * transformation_scale)\n        query_points_xyz_trans_homo = torch.einsum('ij,Nj->Ni', Rt, query_points_xyz_scaled_homo)\n        query_points_xyz_trans = from_homogeneous(query_points_xyz_trans_homo)\n        query_points_trans = torch.cat([query_points[..., :1], query_points_xyz_trans], dim=-1)\n        query_points_trans = query_points_trans.type(src_dtype)\n    else:\n        query_points_trans = None\n\n    # Transform 3D trajectories\n    if traj3d_world is not None:\n        n_frames, n_tracks, _ = traj3d_world.shape\n        assert traj3d_world.shape == (n_frames, n_tracks, 3)\n        src_dtype = traj3d_world.dtype\n        traj3d_world = traj3d_world.type(Rt.dtype)\n        traj3d_world_scaled_homo = to_homogeneous(traj3d_world * transformation_scale)\n        traj3d_world_trans_homo = torch.einsum('ij,TNj->TNi', Rt, traj3d_world_scaled_homo)\n        traj3d_world_trans = from_homogeneous(traj3d_world_trans_homo)\n        traj3d_world_trans = traj3d_world_trans.type(src_dtype)\n    else:\n        traj3d_world_trans = None\n\n    # Transform 2D+depth trajectories\n    if traj2d_w_z is not None:\n        n_views, n_frames, n_tracks, _ = traj2d_w_z.shape\n        
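# Pixel (x, y) is invariant to this similarity transform of the world, so only the camera-space\n        # depth channel is rescaled here. E.g. with scale=2 and a pure rotation/translation,\n        # a point at z_cam=1.5 becomes z_cam=3.0 while its pixel coordinates stay fixed.\n        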
assert traj2d_w_z.shape == (n_views, n_frames, n_tracks, 3)\n        traj2d_w_z_trans = traj2d_w_z.clone()\n        traj2d_w_z_trans[:, :, :, 2] *= transformation_scale\n    else:\n        traj2d_w_z_trans = None\n\n    return depth_trans, extrs_trans, query_points_trans, traj3d_world_trans, traj2d_w_z_trans\n\n\ndef add_camera_noise(intrs, extrs, noise_std_intr=0.01, noise_std_extr=0.001, rnd=np.random):\n    \"\"\"\n    Add small Gaussian noise to intrinsic and extrinsic camera parameters.\n\n    Args:\n        intrs (np.ndarray): (V, T, 3, 3) intrinsic matrices.\n        extrs (np.ndarray): (V, T, 3, 4) extrinsic matrices.\n        noise_std_intr (float): Standard deviation of intrinsic matrix noise.\n        noise_std_extr (float): Standard deviation of extrinsic matrix noise.\n        rnd (module): Random number generator (e.g., np.random or torch).\n\n    Returns:\n        intrs (same type as input): Noisy intrinsic matrices.\n        extrs (same type as input): Noisy extrinsic matrices.\n    \"\"\"\n    V, T, _, _ = intrs.shape\n    assert isinstance(intrs, np.ndarray)\n    assert intrs.shape == (V, T, 3, 3)\n    assert extrs.shape == (V, T, 3, 4)\n\n    intrs, extrs = intrs.copy(), extrs.copy()\n\n    intrs += rnd.normal(0, noise_std_intr, size=intrs.shape)\n    extrs += rnd.normal(0, noise_std_extr, size=extrs.shape)\n\n    return intrs, extrs\n\n\ndef aug_depth(depth, grid=(8, 8), scale=(0.7, 1.3), shift=(-0.1, 0.1),\n              gn_kernel=(7, 7), gn_sigma=(2.0, 2.0), generator=None):\n    \"\"\"\n    Augment depth for training.\n    \"\"\"\n    B, T, H, W = depth.shape\n    msk = (depth != 0)\n\n    # fallback to global generator if none is provided\n    gen = generator if generator is not None else torch.default_generator\n\n    # generate scale and shift maps\n    H_s, W_s = grid\n    scale_map = (torch.rand(B, T, H_s, W_s, device=depth.device, generator=gen) * (scale[1] - scale[0]) + scale[0])\n    shift_map = (torch.rand(B, T, H_s, W_s, device=depth.device, generator=gen) * (shift[1] - shift[0]) + shift[0])\n\n    # scale and shift the depth map\n    scale_map = F.interpolate(scale_map, (H, W), mode='bilinear', align_corners=True)\n    shift_map = F.interpolate(shift_map, (H, W), mode='bilinear', align_corners=True)\n\n    # local scale and shift the depth\n    depth[msk] = (depth[msk] * scale_map[msk]) + shift_map[msk] * (depth[msk].mean())\n\n    # gaussian blur\n    depth = TF.gaussian_blur(depth, kernel_size=gn_kernel, sigma=gn_sigma)\n    depth[~msk] = 0\n\n    return depth\n\n\ndef align_umeyama(model, data, known_scale=False, yaw_only=False):\n    mu_M = model.mean(0)\n    mu_D = data.mean(0)\n    model_zerocentered = model - mu_M\n    data_zerocentered = data - mu_D\n    n = np.shape(model)[0]\n\n    # correlation\n    C = 1.0 / n * np.dot(model_zerocentered.transpose(), data_zerocentered)\n    sigma2 = 1.0 / n * np.multiply(data_zerocentered, data_zerocentered).sum()\n    U_svd, D_svd, V_svd = np.linalg.linalg.svd(C)\n    D_svd = np.diag(D_svd)\n    V_svd = np.transpose(V_svd)\n\n    S = np.eye(3)\n    if np.linalg.det(U_svd) * np.linalg.det(V_svd) < 0:\n        S[2, 2] = -1\n\n    if yaw_only:\n        rot_C = np.dot(data_zerocentered.transpose(), model_zerocentered)\n        theta = get_best_yaw(rot_C)\n        R = rot_z(theta)\n    else:\n        R = np.dot(U_svd, np.dot(S, np.transpose(V_svd)))\n\n    if known_scale:\n        s = 1\n    else:\n        s = 1.0 / sigma2 * np.trace(np.dot(D_svd, S))\n\n    t = mu_M - s * np.dot(R, mu_D)\n\n    return s, R, 
t\n\n\ndef get_camera_center(extr):\n    R = extr[:, :3]\n    t = extr[:, 3]\n    return -R.T @ t\n\n\ndef apply_sim3_to_extrinsics(vggt_extrinsics, s, R_align, t_align):\n    aligned_extrinsics = []\n    R_inv = R_align.T\n    t_inv = -R_inv @ t_align / s\n    for extr in vggt_extrinsics:\n        extr_h = np.eye(4)\n        extr_h[:3, :4] = extr\n        sim3_inv = np.eye(4)\n        sim3_inv[:3, :3] = R_inv / s\n        sim3_inv[:3, 3] = t_inv\n        aligned = extr_h @ sim3_inv\n        aligned_extrinsics.append(aligned[:3, :])\n    return aligned_extrinsics\n\n\ndef get_best_yaw(C):\n    \"\"\"\n    maximize trace(Rz(theta) * C)\n    \"\"\"\n    assert C.shape == (3, 3)\n\n    A = C[0, 1] - C[1, 0]\n    B = C[0, 0] + C[1, 1]\n    theta = np.pi / 2 - np.arctan2(B, A)\n\n    return theta\n\n\ndef rot_z(theta):\n    R = rotation_matrix(theta, [0, 0, 1])\n    R = R[0:3, 0:3]\n\n    return R\n"
  },
  {
    "path": "mvtracker/evaluation/__init__.py",
    "content": ""
  },
  {
    "path": "mvtracker/evaluation/evaluator_3dpt.py",
    "content": "import json\nimport logging\nimport os\nimport re\nimport time\nimport warnings\nfrom collections import namedtuple\nfrom typing import Iterable\nfrom typing import Optional\n\nimport imageio\nimport matplotlib.cm as cm\nimport numpy as np\nimport rerun as rr\nimport torch\nfrom sklearn.cluster import KMeans\nfrom threadpoolctl import threadpool_limits\nfrom torch.utils.tensorboard import SummaryWriter\nfrom tqdm import tqdm\n\nfrom mvtracker.datasets.utils import dataclass_to_cuda_\nfrom mvtracker.evaluation.metrics import compute_tapvid_metrics_original, evaluate_predictions\nfrom mvtracker.models.core.model_utils import world_space_to_pixel_xy_and_camera_z, \\\n    pixel_xy_and_camera_z_to_world_space, init_pointcloud_from_rgbd\nfrom mvtracker.utils.visualizer_mp4 import log_mp4_track_viz\nfrom mvtracker.utils.visualizer_rerun import log_pointclouds_to_rerun, log_tracks_to_rerun\n\n\nclass NumpyEncoder(json.JSONEncoder):\n    def default(self, obj):\n        if isinstance(obj, np.ndarray):\n            if obj.size == 1:\n                return obj.item()\n            return obj.tolist()\n        if isinstance(obj, np.integer):\n            return int(obj)\n        if isinstance(obj, np.floating):\n            return float(obj)\n        return json.JSONEncoder.default(self, obj)\n\n\ndef kmeans_sample(pts, count):\n    \"\"\"\n    Given (N, 3) torch tensor of 3D points, return (count, 3) tensor of kmeans centers.\n    \"\"\"\n    if len(pts) <= count:\n        return pts\n\n    logging.info(f\"Computing k-means (k={count}, N={len(pts)})...\")\n\n    start = time.time()\n    with threadpool_limits(limits=1):\n        pts_np = pts.detach().cpu().numpy()\n        kmeans = KMeans(n_clusters=count, n_init='auto', random_state=0).fit(pts_np)\n    duration = time.time() - start\n\n    logging.info(f\"K-means clustering completed in {duration:.2f} seconds.\")\n    centers = torch.tensor(kmeans.cluster_centers_, dtype=pts.dtype, device=pts.device)\n    return centers\n\n\ndef evaluate_3dpt(\n        gt_tracks,\n        gt_visibilities,\n        pred_tracks,\n        pred_visibilities,\n        evaluation_setting,\n        track_upscaling_factor,\n        query_points=None,\n        prefix=\"3dpt\",\n        verbose=True,\n        add_per_track_results=True,\n):\n    n_frames, n_tracks, n_point_dim = gt_tracks.shape\n    assert gt_tracks.shape == pred_tracks.shape\n    assert gt_visibilities.shape == (n_frames, n_tracks)\n    assert pred_visibilities.shape == (n_frames, n_tracks)\n\n    if query_points is None:\n        query_points_frame_id = gt_visibilities.argmax(axis=0)\n        query_points_xyz = gt_tracks[query_points_frame_id, np.arange(gt_tracks.shape[1]), :]\n        query_points = np.concatenate([query_points_frame_id[:, None], query_points_xyz], axis=-1)\n    else:\n        query_points_frame_id = query_points[:, 0].astype(int)\n        query_points_xyz = query_points[:, 1:]\n\n    if evaluation_setting == \"kubric-multiview\":\n        assert n_point_dim == 3\n        distance_thresholds = [0.05, 0.1, 0.2, 0.4, 0.8]  # The scale is non-metric\n        survival_distance_threshold = 0.5  # 50 cm\n        static_threshold = 0.01  # < 1 cm\n        dynamic_threshold = 0.1  # > 10 cm\n        very_dynamic_threshold = 2.0  # > 2 m\n    elif evaluation_setting == \"dexycb-multiview\":\n        assert n_point_dim == 3\n        distance_thresholds = [0.01, 0.02, 0.05, 0.1, 0.2]  # 1 cm, 2 cm, 5 cm, 10 cm, 20 cm\n        survival_distance_threshold = 0.1  # 10 cm\n        
static_threshold = 0.01  # < 1 cm\n        dynamic_threshold = 0.1  # > 10 cm\n        very_dynamic_threshold = 0.5  # > 50 cm\n    elif evaluation_setting == \"panoptic-multiview\":\n        assert n_point_dim == 3\n        distance_thresholds = [0.05, 0.10, 0.20, 0.40]  # 5 cm, 10 cm, 20 cm, 40 cm\n        survival_distance_threshold = 1.0  # 1 m\n        static_threshold = None\n        dynamic_threshold = None\n        very_dynamic_threshold = None\n    elif evaluation_setting == \"tapvid2d\":\n        assert n_point_dim == 2\n        distance_thresholds = [1, 2, 4, 8, 16]  # pixels\n        survival_distance_threshold = 50\n        static_threshold = None\n        dynamic_threshold = None\n        very_dynamic_threshold = None\n    elif evaluation_setting == \"2dpt_ablation\":\n        assert n_point_dim == 2\n        distance_thresholds = [1, 2, 4, 8, 16]  # pixels\n        survival_distance_threshold = 50\n        static_threshold = 1\n        dynamic_threshold = 1\n        very_dynamic_threshold = 50\n    else:\n        raise NotImplementedError\n\n    if verbose:\n        logging.info(f\"n_frames: {n_frames}, n_tracks: {n_tracks}\")\n        logging.info(f\"GT TRACKS (min, max): {gt_tracks.min()}, {gt_tracks.max()}\")\n        logging.info(f\"query_points_xyz (min, max): {query_points_xyz.min()}, {query_points_xyz.max()}\")\n\n    df_model, df_model_per_track = evaluate_predictions(\n        gt_tracks * track_upscaling_factor,\n        gt_visibilities,\n        pred_tracks * track_upscaling_factor,\n        ~pred_visibilities,\n        np.concatenate([query_points[:, 0:1], query_points[:, 1:] * track_upscaling_factor], axis=-1),\n        distance_thresholds=distance_thresholds,\n        survival_distance_threshold=survival_distance_threshold,\n        static_threshold=static_threshold,\n        dynamic_threshold=dynamic_threshold,\n        very_dynamic_threshold=very_dynamic_threshold,\n    )\n\n    if verbose:\n        logging.info(f\"DF Model:\\n{df_model}\")\n        logging.info(f\"DF Model (key metrics):\\n{df_model.loc[['average_pts_within_thresh', 'survival']]}\")\n\n    # Save to results_dict\n    results_dict = {}\n\n    # Report all metrics for each point type\n    for point_type in [\"dynamic-static-mean\", \"dynamic\", \"very_dynamic\", \"static\", \"any\"]:\n        if f'all_{point_type}' not in df_model.columns:\n            continue\n        for metric in sorted(df_model.index):\n            results_dict[f'{prefix}/model__{metric}__{point_type}'] = df_model.loc[metric, f'all_{point_type}']\n\n    # For additional point types (disabled by default), report only selected metrics\n    for point_type in []:\n        if f'all_{point_type}' not in df_model.columns:\n            continue\n        for metric in [\"average_pts_within_thresh\", \"survival\", \"occlusion_accuracy\", \"average_jaccard\"]:\n            results_dict[f'{prefix}/model__{metric}__{point_type}'] = df_model.loc[metric, f'all_{point_type}']\n\n    for k in results_dict:\n        results_dict[k] = results_dict[k].item()\n\n    if verbose:\n        logging.info(f\"3DPT results:\\n{results_dict}\")\n\n    if add_per_track_results:\n        results_dict[f'{prefix}/model__per_track_results'] = df_model_per_track\n\n    return results_dict\n\n\nclass Evaluator:\n    def __init__(\n            self,\n            rerun_viz_indices: Optional[Iterable[int]] = None,\n            forward_pass_log_indices: Optional[Iterable[int]] = None,\n            mp4_track_viz_indices: Optional[Iterable[int]] = (0, 3, 4, 5),\n    ) -> None:\n        \"\"\"\n        Initializes 
the Evaluator.\n\n        Parameters\n        ----------\n        rerun_viz_indices : Optional[Iterable[int]]\n            Indices of datapoints for which rerun 3D visualizations should be saved.\n            If None, no rerun visualizations will be logged.\n\n        forward_pass_log_indices : Optional[Iterable[int]]\n            Indices of datapoints for which debug logs from the model's forward pass should be saved.\n            If None, no forward pass debug logs will be generated.\n\n        mp4_track_viz_indices : Optional[Iterable[int]]\n            Indices of datapoints for which 2D trajectory visualizations (MP4 videos) should be saved.\n            If None, MP4 visualizations will not be generated.\n        \"\"\"\n        self.rerun_viz_indices = rerun_viz_indices\n        self.forward_pass_log_indices = forward_pass_log_indices\n        self.mp4_track_viz_indices = mp4_track_viz_indices\n\n        if self.rerun_viz_indices is None:\n            self.rerun_viz_indices = []\n        if self.forward_pass_log_indices is None:\n            self.forward_pass_log_indices = []\n        if self.mp4_track_viz_indices is None:\n            self.mp4_track_viz_indices = []\n\n    @torch.no_grad()\n    def evaluate_sequence(\n            self,\n            model,\n            test_dataloader,\n            dataset_name,\n            log_dir,\n            writer: Optional[SummaryWriter] = None,\n            step: Optional[int] = 0,\n    ):\n        metrics = {}\n        assert len(test_dataloader) > 0\n        total_fps = 0.0\n        count = 0\n        for datapoint_idx, datapoint in enumerate(tqdm(test_dataloader)):\n            should_save_mp4_viz = datapoint_idx in self.mp4_track_viz_indices\n            should_save_forward_pass_logs = datapoint_idx in self.forward_pass_log_indices\n            should_save_rerun_viz = datapoint_idx in self.rerun_viz_indices\n\n            # Hotfix for debugging: Load an edge-case datapoint directly from disk\n            if False:\n                # Batch 10060\n                datapoint = torch.load(\"logs/debug/ablation-E07/mvtracker-ptv3-512/crash_batch_step_010060.pt\",\n                                       map_location=\"cuda:0\")\n                # (datapoint.videodepth > 0).float().mean() --> 0\n\n                # Batch 8145\n                datapoint = torch.load(\"logs/ablation-E07/mvtracker-ptv3-512-2/crash_batch_step_008145.pt\",\n                                       map_location=\"cuda:0\")\n                datapoint.videodepth = datapoint.videodepth.clip(0.0, 1000.0)\n\n                should_save_mp4_viz = True\n                should_save_rerun_viz = True\n                should_save_forward_pass_logs = False\n                model.model.use_ptv3 = False\n\n            if isinstance(datapoint, (tuple, list)) and len(datapoint) == 2:\n                datapoint, gotit = datapoint\n                if not all(gotit):\n                    logging.warning(\"batch is None\")\n                    continue\n            if torch.cuda.is_available():\n                dataclass_to_cuda_(datapoint)\n                device = torch.device(\"cuda\")\n            else:\n                device = torch.device(\"cpu\")\n\n            # Per-view data\n            rgbs = datapoint.video\n            depths = datapoint.videodepth\n            depths_conf = datapoint.videodepthconf\n            image_features = datapoint.feats\n            intrs = datapoint.intrs\n            extrs = datapoint.extrs\n            
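# Per-view ground-truth tracks: shape (B, V, T, N, 3), last dim is (pixel x, pixel y, camera-space z); None for datasets without tracking labels.\n            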
gt_trajectories_2d_pixelspace_w_z_cameraspace = datapoint.trajectory\n            gt_visibilities_per_view = datapoint.visibility\n\n            query_points_2d = (datapoint.query_points.clone().float().to(device)\n                               if datapoint.query_points is not None else None)\n            query_points_3d = (datapoint.query_points_3d.clone().float().to(device)\n                               if datapoint.query_points_3d is not None else None)\n\n            # Non-per-view data\n            gt_trajectories_3d_worldspace = datapoint.trajectory_3d\n            valid_tracks_per_frame = datapoint.valid\n            track_upscaling_factor = datapoint.track_upscaling_factor\n            seq_name = datapoint.seq_name[0]\n\n            # Novel view data\n            novel_rgbs = datapoint.novel_video\n            novel_intrs = datapoint.novel_intrs\n            novel_extrs = datapoint.novel_extrs\n\n            batch_size, num_views, num_frames, _, height, width = rgbs.shape\n\n            # For generic datasets without labels, we will try sampling queries from depthmap points and around origin\n            no_tracking_labels = False\n            if query_points_2d is None and query_points_3d is None:\n                no_tracking_labels = True\n                assert batch_size == 1\n                assert gt_trajectories_2d_pixelspace_w_z_cameraspace is None\n                assert gt_visibilities_per_view is None\n                assert gt_trajectories_3d_worldspace is None\n                assert valid_tracks_per_frame is None\n\n                assert depths is not None\n                assert depths_conf is not None\n\n                # Config: (frame_idx, z_min, z_max, count)\n                if \"selfcap\" in dataset_name:\n                    sampling_spec = [\n                        (0, -0.1, 0.2, 1.8, 100, \"\"),\n                        (0, 0.2, 2.1, 1.8, 200, \"\"),\n                        # (0, 0.2, 2.1, 1.8, 200, \"kmeans\"),\n                        (36, 0.2, 2.1, 1.8, 200, \"\"),\n                        (120, 0.2, 2.1, 1.8, 200, \"\"),\n                    ]\n                    x0, y0, zmin, zmax, radius = 0.25, 0.7, -0.15, 3.6, 1.8\n                    xyz, _ = init_pointcloud_from_rgbd(\n                        fmaps=depths_conf,\n                        depths=depths,\n                        intrs=intrs,\n                        extrs=extrs,\n                        stride=1,\n                        level=0,\n                        depth_interp_mode=\"N/A\",\n                    )\n                    x, y, z = xyz[..., 0], xyz[..., 1], xyz[..., 2]\n                    x -= x0\n                    y -= y0\n                    mask = (x ** 2 + y ** 2 < radius ** 2) & (z >= zmin) & (z <= zmax)\n                    mask = mask.reshape(batch_size, num_frames, num_views, height, width).permute(0, 2, 1, 3, 4)\n                    # depths[~mask[:, :, :, None, :, :]] = 0.0\n                    # depths[depths_conf < 5] = 0.0\n                    depths_conf[~mask[:, :, :, None, :, :]] = 2.0\n\n                elif \"4d-dress\" in dataset_name:\n                    sampling_spec = [\n                        # (0, -10, +10, 10, 1500, \"\"),\n                        # (0, -10, +10, 10, 500, \"\"),\n                        (0, -10, +10, 10, 300, \"kmeans\"),\n                        # (72, -10, +10, 10, 500, \"kmeans\"),\n                    ]\n\n                elif \"hi4d\" in dataset_name:\n                    sampling_spec = [\n                        (0, 
-np.inf, +np.inf, np.inf, 1000, \"\"),\n                    ]\n\n                else:\n                    sampling_spec = [\n                        (0, -0.1, +4.2, 2.1, 1000, \"kmeans\"),\n                    ]\n\n                depth_conf_threshold = 0.9\n                query_list = []\n\n                for t, zmin, zmax, radius, count, method in sampling_spec:\n                    if t >= num_frames:\n                        continue  # skip invalid timestep\n\n                    dmap = depths[:, :, t:t + 1]\n                    conf = depths_conf[:, :, t:t + 1]\n                    xyz, c = init_pointcloud_from_rgbd(\n                        fmaps=conf,\n                        depths=dmap,\n                        intrs=intrs[:, :, t:t + 1],\n                        extrs=extrs[:, :, t:t + 1],\n                        stride=1,\n                        level=0,\n                        depth_interp_mode=\"N/A\",\n                    )\n                    xyz = xyz[0]  # (N, 3)\n                    conf = c[0, :, 0]  # (N,)\n                    valid = conf > depth_conf_threshold\n                    pts = xyz[valid]\n                    if pts.numel() == 0:\n                        continue\n\n                    x, y, z = pts[:, 0], pts[:, 1], pts[:, 2]\n                    mask = (x ** 2 + y ** 2 < radius ** 2) & (z >= zmin) & (z <= zmax)\n                    pts = pts[mask]\n                    if pts.numel() == 0:\n                        continue\n\n                    if len(pts) >= count:\n                        if method == \"\":\n                            pts = pts[torch.randperm(len(pts))[:count]]\n                        elif method == \"kmeans\":\n                            pts = kmeans_sample(pts, count)\n                        else:\n                            raise NotImplementedError\n\n                    t_col = torch.full((len(pts), 1), float(t), device=pts.device)\n                    query_list.append(torch.cat([t_col, pts], dim=1))\n\n                # Finalize query points\n                query_points_3d = torch.cat(query_list, dim=0)[None]  # (1, N, 4)\n\n                # Dummy GT trajectory\n                num_points = query_points_3d.shape[1]\n                gt_trajectories_3d_worldspace = query_points_3d[:, None, :, 1:].repeat(1, num_frames, 1, 1)\n                gt_trajectories_2d_pixelspace_w_z_cameraspace = torch.stack([\n                    torch.cat(world_space_to_pixel_xy_and_camera_z(\n                        world_xyz=gt_trajectories_3d_worldspace[0],\n                        intrs=intrs[0, view_idx],\n                        extrs=extrs[0, view_idx],\n                    ), dim=-1)\n                    for view_idx in range(num_views)\n                ], dim=0).unsqueeze(0)\n                d = query_points_3d.device\n                gt_visibilities_per_view = torch.ones((batch_size, num_views, num_frames, num_points), dtype=bool).to(d)\n                valid_tracks_per_frame = torch.ones((batch_size, num_frames, num_points), dtype=bool).to(d)\n\n            if no_tracking_labels and not any([should_save_mp4_viz,\n                                               should_save_rerun_viz,\n                                               should_save_forward_pass_logs]):\n                continue\n\n            # Assert shapes of per-view data\n            num_points = gt_trajectories_2d_pixelspace_w_z_cameraspace.shape[3]\n            assert depths is not None, \"Depth is required for evaluation.\"\n            assert rgbs.shape == 
(batch_size, num_views, num_frames, 3, height, width)\n            assert depths.shape == (batch_size, num_views, num_frames, 1, height, width)\n            assert depths_conf is None or depths_conf.shape == (batch_size, num_views, num_frames, 1, height, width)\n            assert intrs.shape == (batch_size, num_views, num_frames, 3, 3)\n            assert extrs.shape == (batch_size, num_views, num_frames, 3, 4)\n            assert gt_trajectories_2d_pixelspace_w_z_cameraspace.shape == (\n                batch_size, num_views, num_frames, num_points, 3)\n            assert gt_visibilities_per_view.shape == (batch_size, num_views, num_frames, num_points)\n\n            # Assert shapes of non-per-view data\n            assert query_points_3d.shape == (batch_size, num_points, 4)\n            assert gt_trajectories_3d_worldspace.shape == (batch_size, num_frames, num_points, 3)\n            assert valid_tracks_per_frame.shape == (batch_size, num_frames, num_points)\n\n            # Dump the RGBs and depths to disk\n            if should_save_rerun_viz:\n                for v in range(num_views):\n                    rgb_path = os.path.join(log_dir, f\"rgbs__{dataset_name}--seq-{datapoint_idx}__view-{v}.mp4\")\n                    depth_path = os.path.join(log_dir, f\"depths__{dataset_name}--seq-{datapoint_idx}__view-{v}.mp4\")\n                    conf_path = os.path.join(log_dir, f\"depth_confs__{dataset_name}--seq-{datapoint_idx}__view-{v}.mp4\")\n\n                    # Precompute global min/max\n                    d_all = depths[0, v, :, 0].reshape(-1, height, width).cpu().numpy()\n                    d_min, d_max = d_all.min(), d_all.max()\n                    if depths_conf is not None:\n                        c_all = depths_conf[0, v, :, 0].reshape(-1, height, width).cpu().numpy()\n                        c_min, c_max = c_all.min(), c_all.max()\n\n                    # Colormaps\n                    depth_cmap = cm.get_cmap(\"turbo\")\n                    conf_cmap = cm.get_cmap(\"inferno\")\n\n                    rgb_video, depth_video, conf_video = [], [], []\n                    for t in range(num_frames):\n                        rgb = (rgbs[0, v, t].permute(1, 2, 0).cpu().numpy()).astype(np.uint8)\n                        rgb_video.append(rgb)\n\n                        d = depths[0, v, t, 0].cpu().numpy()\n                        d_norm = (d - d_min) / (d_max - d_min + 1e-5)\n                        depth_color = (depth_cmap(d_norm)[..., :3] * 255).astype(np.uint8)\n                        depth_video.append(depth_color)\n\n                        if depths_conf is not None:\n                            c = depths_conf[0, v, t, 0].cpu().numpy()\n                            c_norm = (c - c_min) / (c_max - c_min + 1e-5)\n                            conf_color = (conf_cmap(c_norm)[..., :3] * 255).astype(np.uint8)\n                            conf_video.append(conf_color)\n\n                    if \"selfcap-v1\" in dataset_name:\n                        fps = 12\n                    elif \"4d-dress\" in dataset_name or \"egoexo4d\" in dataset_name:\n                        fps = 30\n                    else:\n                        fps = 12\n                    imageio.mimsave(rgb_path, rgb_video, fps=fps)\n                    imageio.mimsave(depth_path, depth_video, fps=fps)\n                    if depths_conf is not None:\n                        imageio.mimsave(conf_path, conf_video, fps=fps)\n\n            # Run the model\n            fwd_kwargs = {\n                \"rgbs\": 
rgbs,\n                \"depths\": depths,\n                \"image_features\": image_features,\n                \"query_points_3d\": query_points_3d,\n                \"intrs\": intrs,\n                \"extrs\": extrs,\n                \"save_debug_logs\": should_save_forward_pass_logs,\n                \"debug_logs_path\": os.path.join(\n                    log_dir, f\"forward_pass__eval_{dataset_name}_step-{step}_seq-{datapoint_idx}\",\n                ),\n                \"save_rerun_logs\": should_save_rerun_viz,\n                \"save_rerun_logs_output_rrd_path\": os.path.join(\n                    log_dir, f\"rerun__{dataset_name}--seq-{datapoint_idx}--name-{seq_name}--fwd.rrd\"\n                ),\n\n            }\n            if \"2dpt\" in dataset_name:\n                assert batch_size == 1\n                query_timestep = query_points_3d[0, :, 0].cpu().numpy().astype(int)\n                query_points_view = gt_visibilities_per_view.argmax(dim=1)[0, query_timestep, torch.arange(num_points)]\n                fwd_kwargs[\"query_points_view\"] = query_points_view[None]\n\n            start_time = time.time()\n            if \"shape_of_motion\" in log_dir or \"dynamic_3dgs\" in log_dir:\n                if \"dynamic_3dgs\" in log_dir:\n                    cached_output_path = os.path.join(log_dir, f\"step-0_seq-{seq_name}_tracks.npz\")\n                else:\n                    cached_output_path = os.path.join(log_dir, f\"step-{step}_seq-{seq_name}_tracks.npz\")\n                cached_output_path = re.sub(r\"-novelviews\\d+(_\\d+)*\", \"\", cached_output_path)\n                assert os.path.exists(cached_output_path), cached_output_path\n                cached_data = np.load(cached_output_path)\n                if \"dynamic_3dgs\" in log_dir:\n                    results = {\n                        \"traj_e\": torch.from_numpy(cached_data[\"pred_trajectories_3d\"]).to(device)[None],\n                        \"vis_e\": torch.from_numpy(cached_data[\"pred_visibilities_any_view\"]).to(device).any(1),\n                    }\n                else:\n                    results = {\n                        \"traj_e\": torch.from_numpy(cached_data[\"pred_trajectories_3d\"]).to(device),\n                        \"vis_e\": torch.from_numpy(cached_data[\"pred_visibilities_any_view\"]).to(device),\n                    }\n            else:\n                results = model(**fwd_kwargs)\n\n            end_time = time.time()\n            frames_processed = batch_size * num_frames\n            elapsed = end_time - start_time\n            fps = frames_processed / elapsed\n            logging.info(f\"[Datapoint {datapoint_idx}] FPS: {fps:.1f}\")\n            total_fps += fps\n            count += 1\n\n            pred_trajectories = results[\"traj_e\"]\n            pred_visibilities = results[\"vis_e\"]\n            pred_trajectories_2d = results[\"traj2d_e\"] if \"traj2d_e\" in results else None\n            assert \"strided\" not in dataset_name, \"Strided evaluation is not supported yet.\"\n\n            # Determine the evaluation setting\n            if \"kubric\" in dataset_name:\n                evaluation_setting = \"kubric-multiview\"\n            elif \"panoptic-multiview\" in dataset_name:\n                evaluation_setting = \"panoptic-multiview\"\n            elif \"dex-ycb\" in dataset_name:\n                evaluation_setting = \"dexycb-multiview\"\n            elif \"tapvid2d\" in dataset_name:\n                evaluation_setting = \"tapvid2d\"\n            elif 
no_tracking_labels:\n                evaluation_setting = \"no-tracking-labels\"\n            else:\n                raise NotImplementedError\n\n            # Invert the intrinsics and extrinsics matrices\n            intrs_inv = torch.inverse(intrs.float()).type(intrs.dtype)\n            extrs_square = torch.eye(4).to(extrs.device)[None].repeat(batch_size, num_views, num_frames, 1, 1)\n            extrs_square[:, :, :, :3, :] = extrs\n            extrs_inv = torch.inverse(extrs_square.float()).type(extrs.dtype)\n            assert intrs_inv.shape == (batch_size, num_views, num_frames, 3, 3)\n            assert extrs_inv.shape == (batch_size, num_views, num_frames, 4, 4)\n\n            # Project the predictions to pixel space for visualization\n            pred_trajectories_pixel_xy_camera_z_per_view = torch.stack([\n                torch.cat(world_space_to_pixel_xy_and_camera_z(\n                    world_xyz=pred_trajectories[0],\n                    intrs=intrs[0, view_idx],\n                    extrs=extrs[0, view_idx],\n                ), dim=-1)\n                for view_idx in range(num_views)\n            ], dim=0)\n            for view_idx in range(num_views):\n                pred_trajectories_reproduced = pixel_xy_and_camera_z_to_world_space(\n                    pixel_xy=pred_trajectories_pixel_xy_camera_z_per_view[view_idx, :, :, :2],\n                    camera_z=pred_trajectories_pixel_xy_camera_z_per_view[view_idx, :, :, 2:],\n                    intrs_inv=intrs_inv[0, view_idx],\n                    extrs_inv=extrs_inv[0, view_idx],\n                )\n                if not torch.allclose(pred_trajectories_reproduced, pred_trajectories, atol=1):\n                    warnings.warn(f\"Reprojection of the predicted trajectories failed: \"\n                                  f\"view_idx={view_idx}, \"\n                                  f\"max_diff={torch.max(torch.abs(pred_trajectories_reproduced - pred_trajectories))}\")\n            pred_trajectories_pixel_xy_camera_z_per_view = pred_trajectories_pixel_xy_camera_z_per_view[None]\n\n            # Compute 3D metrics\n            gt_visibilities_any_view = gt_visibilities_per_view.any(dim=1)\n            assert gt_visibilities_any_view.any(dim=1).all(), \"All points should be visible in at least one view.\"\n            per_track_results = None\n            if evaluation_setting in [\"kubric-multiview\", \"panoptic-multiview\", \"dexycb-multiview\"]:\n                eval_3dpt_results_dict = evaluate_3dpt(\n                    gt_tracks=gt_trajectories_3d_worldspace[0].cpu().numpy(),\n                    gt_visibilities=gt_visibilities_any_view[0].cpu().numpy(),\n                    query_points=query_points_3d[0].cpu().numpy(),\n                    pred_tracks=pred_trajectories[0].cpu().numpy(),\n                    pred_visibilities=pred_visibilities[0].cpu().numpy(),\n                    evaluation_setting=evaluation_setting,\n                    track_upscaling_factor=track_upscaling_factor,\n                    prefix=f\"eval_{dataset_name}\",\n                    add_per_track_results=should_save_rerun_viz,\n                    verbose=False,\n                )\n                if should_save_rerun_viz:\n                    per_track_results = eval_3dpt_results_dict[f'eval_{dataset_name}/model__per_track_results']\n                    del eval_3dpt_results_dict[f'eval_{dataset_name}/model__per_track_results']\n                metrics[datapoint_idx] = eval_3dpt_results_dict\n\n                if \"2dpt\" in 
dataset_name:\n                    assert batch_size == 1\n                    if pred_trajectories_2d is None:\n                        pred_trajectories_2d = pred_trajectories_pixel_xy_camera_z_per_view[:, :, :, :, :2]\n                    _rescale_to_256x256 = np.array([256, 256]) / np.array([width, height])\n                    _metrics = {}\n                    for view_idx in range(num_views):\n                        track_mask = (query_points_view == view_idx).cpu().numpy()\n                        if track_mask.sum() == 0:\n                            continue\n                        _n_tracks = track_mask.sum()\n                        _gt_tracks = gt_trajectories_2d_pixelspace_w_z_cameraspace[0, view_idx, :, track_mask, :2]\n                        _gt_tracks = _gt_tracks.cpu().numpy()\n                        _gt_visibilities = gt_visibilities_per_view[0, view_idx, :, track_mask].cpu().bool().numpy()\n                        _query_t = query_timestep[track_mask]\n                        _query_xy = _gt_tracks[_query_t, np.arange(_n_tracks)]\n                        _query = np.concatenate([_query_t[:, None], _query_xy], axis=-1)\n                        _pred_tracks = pred_trajectories_2d[0, view_idx, :, track_mask].cpu().numpy()\n                        _pred_visibilities = np.zeros_like(_gt_visibilities)\n                        assert _gt_visibilities[_query_t, np.arange(_n_tracks)].all()\n                        eval_2dpt_results_dict = evaluate_3dpt(\n                            gt_tracks=_gt_tracks,\n                            gt_visibilities=_gt_visibilities,\n                            query_points=_query,\n                            pred_tracks=_pred_tracks,\n                            pred_visibilities=_pred_visibilities,\n                            evaluation_setting=\"2dpt_ablation\",\n                            track_upscaling_factor=_rescale_to_256x256,\n                            prefix=f\"eval_{dataset_name}\",\n                            add_per_track_results=False,\n                            verbose=False,\n                        )\n                        tapvid2d_original_metrics = compute_tapvid_metrics_original(\n                            query_points=np.concatenate([_query_t[:, None], _query_xy * _rescale_to_256x256], axis=-1),\n                            gt_occluded=~_gt_visibilities[None].transpose(0, 2, 1),\n                            gt_tracks=_gt_tracks[None].transpose(0, 2, 1, 3) * _rescale_to_256x256,\n                            pred_occluded=~_pred_visibilities[None].transpose(0, 2, 1),\n                            pred_tracks=_pred_tracks[None].transpose(0, 2, 1, 3) * _rescale_to_256x256,\n                            query_mode=\"first\",\n                        )\n                        tapvid2d_original_metrics = {\n                            f\"eval_{dataset_name}/model__tapvid2d_{k}\":\n                                (tapvid2d_original_metrics[k] * 100).round(2).item()\n                            for k in sorted(tapvid2d_original_metrics)\n                        }\n                        _metrics[view_idx] = {}\n                        _metrics[view_idx].update(eval_2dpt_results_dict)\n                        _metrics[view_idx].update(tapvid2d_original_metrics)\n                        _metrics[view_idx] = {\n                            k.replace(\"model__\", \"model__2dpt__\"): v\n                            for k, v in _metrics[view_idx].items()\n                            if \"jaccard\" not in k and \"occlusion\" 
not in k\n                        }\n                    _metrics_avg = {}\n                    for k in _metrics[next(iter(_metrics.keys()))]:\n                        _metrics_avg[k] = np.mean([\n                            _metrics[view_idx][k]\n                            for view_idx in _metrics\n                            if k in _metrics[view_idx]\n                        ]).round(2)\n\n                    metrics[datapoint_idx].update(_metrics_avg)\n                    for view_idx in _metrics:\n                        metrics[datapoint_idx].update({\n                            f\"{k}__view-{view_idx}\": v\n                            for k, v in _metrics[view_idx].items()\n                        })\n\n            # Compute 2D metrics\n            elif evaluation_setting in [\"tapvid2d\"]:\n                assert num_views == 1\n                if pred_trajectories_2d is None:\n                    pred_trajectories_2d = pred_trajectories_pixel_xy_camera_z_per_view[:, :, :, :, :2]\n                eval_2dpt_results_dict = evaluate_3dpt(\n                    gt_tracks=gt_trajectories_2d_pixelspace_w_z_cameraspace[0, 0, :, :, :2].cpu().numpy(),\n                    gt_visibilities=gt_visibilities_per_view[0, 0].cpu().bool().numpy(),\n                    query_points=query_points_2d[0].cpu().numpy(),\n                    pred_tracks=pred_trajectories_2d[0, 0].cpu().numpy(),\n                    pred_visibilities=pred_visibilities[0].cpu().numpy(),\n                    evaluation_setting=evaluation_setting,\n                    track_upscaling_factor=track_upscaling_factor,\n                    prefix=f\"eval_{dataset_name}\",\n                    add_per_track_results=should_save_rerun_viz,\n                    verbose=False,\n                )\n                if should_save_rerun_viz:\n                    per_track_results = eval_2dpt_results_dict[f'eval_{dataset_name}/model__per_track_results']\n                    del eval_2dpt_results_dict[f'eval_{dataset_name}/model__per_track_results']\n                metrics[datapoint_idx] = eval_2dpt_results_dict\n\n                tapvid2d_original_metrics = compute_tapvid_metrics_original(\n                    query_points_2d[0].cpu().numpy(),\n                    torch.logical_not(gt_visibilities_per_view[:, 0].clone().permute(0, 2, 1)).cpu().numpy(),\n                    gt_trajectories_2d_pixelspace_w_z_cameraspace[:, 0, :, :, :2].clone().permute(0, 2, 1,\n                                                                                                  3).cpu().numpy(),\n                    torch.logical_not(pred_visibilities.clone().permute(0, 2, 1)).cpu().numpy(),\n                    pred_trajectories_2d[:, 0].permute(0, 2, 1, 3).cpu().numpy(),\n                    query_mode=\"first\",\n                )\n                tapvid2d_original_metrics = {\n                    f\"eval_{dataset_name}/model__tapvid2d_{k}\": (tapvid2d_original_metrics[k] * 100).round(2).item()\n                    for k in sorted(tapvid2d_original_metrics)\n                }\n                metrics[datapoint_idx].update(tapvid2d_original_metrics)\n\n            elif evaluation_setting in [\"no-tracking-labels\"]:\n                metrics[datapoint_idx] = {}\n\n            np.savez(\n                os.path.join(log_dir, f\"step-{step}_seq-{seq_name}_tracks.npz\"),\n                gt_trajectories_2d=gt_trajectories_2d_pixelspace_w_z_cameraspace.cpu().numpy(),\n                gt_trajectories_3d=gt_trajectories_3d_worldspace.cpu().numpy(),\n                
gt_visibilities_per_view=gt_visibilities_per_view.cpu().numpy(),\n                gt_visibilities_any_view=gt_visibilities_any_view.cpu().numpy(),\n                pred_trajectories_2d=pred_trajectories_pixel_xy_camera_z_per_view.cpu().numpy(),\n                pred_trajectories_3d=pred_trajectories.cpu().numpy(),\n                pred_visibilities_any_view=pred_visibilities.cpu().numpy(),\n                query_points_2d=query_points_2d.cpu().numpy() if query_points_2d is not None else None,\n                query_points_3d=query_points_3d.cpu().numpy(),\n                track_upscaling_factor=track_upscaling_factor,\n            )\n\n            # Visualize the results with rerun.io\n            viz_fps = 30\n            if \"panoptic\" in dataset_name:\n                viz_fps = 30\n            elif \"dex\" in dataset_name:\n                viz_fps = 10\n            elif \"kubric\" in dataset_name:\n                viz_fps = 12\n\n            if should_save_rerun_viz:\n                # Log the visualizations to rerun\n                if \"mvtracker\" in log_dir:\n                    method_id = 0\n                    method_name = \"MVTracker\"\n\n                elif \"spatracker_mono\" in log_dir:\n                    method_id = 1\n                    method_name = \"SpatialTrackerV1\"\n\n                elif \"tapip3d\" in log_dir:\n                    method_id = 2\n                    method_name = \"TAPIP3D\"\n\n                elif \"spatracker_multi\" in log_dir:\n                    method_id = 3\n                    method_name = \"Triplane\"\n\n                else:\n                    method_id = None\n                    method_name = \"x\"\n\n                if \"panoptic\" in dataset_name:\n                    sphere_radius = 12\n                else:\n                    sphere_radius = 6.0\n\n                max_tracks = None\n                if \"dress\" in dataset_name:\n                    max_tracks = 300\n                elif \"panoptic\" in dataset_name:\n                    max_tracks = 100\n                elif \"kubric\" in dataset_name or \"dex-ycb\" in dataset_name:\n                    max_tracks = 36\n\n                LogConfig = namedtuple(\"LogConfig\", [\n                    \"suffix\", \"method_id\", \"max_tracks\", \"track_batch_size\", \"sphere_radius\",\n                    \"conf_thrs\", \"log_only_confident_pc\", \"memory_lightweight_logging\"\n                ])\n                log_configs = [\n                    LogConfig(\n                        suffix=\"\",\n                        method_id=None,\n                        max_tracks=None,\n                        track_batch_size=50,\n                        sphere_radius=None,\n                        conf_thrs=[1.0, 5.0],\n                        log_only_confident_pc=False,\n                        memory_lightweight_logging=False,\n                    ),\n                    LogConfig(\n                        suffix=\".comparisons\",\n                        method_id=method_id,\n                        max_tracks=100,\n                        track_batch_size=50,\n                        sphere_radius=None,\n                        conf_thrs=[1.0, 5.0],\n                        log_only_confident_pc=False,\n                        memory_lightweight_logging=True,\n                    ),\n                    LogConfig(\n                        suffix=\".lightweight\",\n                        method_id=None,\n                        max_tracks=max_tracks,\n                        
track_batch_size=50,\n                        sphere_radius=sphere_radius,\n                        conf_thrs=[5.0],\n                        log_only_confident_pc=True,\n                        memory_lightweight_logging=True,\n                    ),\n                    LogConfig(\n                        suffix=\".lightweight.comparisons\",\n                        method_id=method_id,\n                        max_tracks=50,\n                        track_batch_size=50,\n                        sphere_radius=sphere_radius,\n                        conf_thrs=[5.0],\n                        log_only_confident_pc=True,\n                        memory_lightweight_logging=True,\n                    ),\n                ]\n\n                for cfg in log_configs:\n                    logfile_name = f\"rerun__{dataset_name}--seq-{datapoint_idx}--name-{seq_name}--eval{cfg.suffix}.rrd\"\n                    rr.init(\"3dpt\", recording_id=\"v0.16\")\n\n                    if cfg.method_id is None or cfg.method_id == 0:\n                        log_pointclouds_to_rerun(\n                            dataset_name=dataset_name,\n                            datapoint_idx=datapoint_idx,\n                            rgbs=rgbs,\n                            depths=depths,\n                            intrs=intrs,\n                            extrs=extrs,\n                            depths_conf=depths_conf,\n                            conf_thrs=cfg.conf_thrs,\n                            log_only_confident_pc=cfg.log_only_confident_pc,\n                            radii=-2.45,\n                            fps=viz_fps,\n                            bbox_crop=None,\n                            sphere_radius_crop=cfg.sphere_radius,\n                            sphere_center_crop=np.array([0, 0, 0]),\n                            log_rgb_image=not cfg.memory_lightweight_logging,\n                            log_depthmap_as_image_v1=False,\n                            log_depthmap_as_image_v2=False,\n                            log_camera_frustrum=True,\n                            log_rgb_pointcloud=True,\n                        )\n\n                    log_tracks_to_rerun(\n                        dataset_name=dataset_name,\n                        datapoint_idx=datapoint_idx,\n                        predictor_name=method_name,\n                        gt_trajectories_3d_worldspace=None if no_tracking_labels else gt_trajectories_3d_worldspace,\n                        gt_visibilities_any_view=None if no_tracking_labels else gt_visibilities_any_view,\n                        query_points_3d=query_points_3d,\n                        pred_trajectories=pred_trajectories,\n                        pred_visibilities=pred_visibilities,\n                        per_track_results=per_track_results,\n                        radii_scale=1.0,\n                        fps=viz_fps,\n                        sphere_radius_crop=cfg.sphere_radius,\n                        sphere_center_crop=np.array([0, 0, 0]),\n                        log_per_interval_results=False,\n                        max_tracks_to_log=cfg.max_tracks,\n                        track_batch_size=cfg.track_batch_size,\n                        method_id=cfg.method_id,\n                        memory_lightweight_logging=cfg.memory_lightweight_logging,\n                    )\n\n                    rr_rrd_path = os.path.join(log_dir, logfile_name)\n                    rr.save(rr_rrd_path)\n                    logging.info(f\"Saved Rerun recording to: 
{rr_rrd_path}\")\n\n            # Visualize the results as mp4\n            if should_save_mp4_viz:\n                log_mp4_track_viz(\n                    log_dir=log_dir,\n                    dataset_name=dataset_name,\n                    datapoint_idx=datapoint_idx,\n                    rgbs=rgbs,\n                    intrs=intrs,\n                    extrs=extrs,\n                    gt_trajectories=gt_trajectories_3d_worldspace,\n                    gt_visibilities=gt_visibilities_any_view,\n                    pred_trajectories=pred_trajectories,\n                    pred_visibilities=pred_visibilities,\n                    query_points_3d=query_points_3d,\n                    step=step,\n                    prefix=\"comparison__v4a-train__\",\n                    max_tracks_to_visualize=36,\n                    max_individual_tracks_to_visualize=6,\n                )\n                if novel_rgbs is not None:\n                    log_mp4_track_viz(\n                        log_dir=log_dir,\n                        dataset_name=dataset_name,\n                        datapoint_idx=datapoint_idx,\n                        rgbs=novel_rgbs,\n                        intrs=novel_intrs,\n                        extrs=novel_extrs,\n                        gt_trajectories=gt_trajectories_3d_worldspace,\n                        gt_visibilities=gt_visibilities_any_view,\n                        pred_trajectories=pred_trajectories,\n                        pred_visibilities=pred_visibilities,\n                        query_points_3d=query_points_3d,\n                        step=step,\n                        prefix=\"comparison__v4b-novel__\",\n                        max_tracks_to_visualize=36,\n                        max_individual_tracks_to_visualize=0,\n                    )\n\n            metrics[datapoint_idx][\"fps\"] = fps\n\n            try:\n                params_total = sum(p.numel() for p in model.parameters())\n                params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n                params_non_trainable = sum(p.numel() for p in model.parameters() if not p.requires_grad)\n                metrics[datapoint_idx][\"params_total\"] = params_total\n                metrics[datapoint_idx][\"params_trainable\"] = params_trainable\n                metrics[datapoint_idx][\"params_non_trainable\"] = params_non_trainable\n            except Exception as e:\n                logging.info(f\"Error calculating model parameters: {e}\")\n\n        # Compute average\n        if count > 0:\n            avg_fps = total_fps / count\n            logging.info(f\"\\nAverage FPS across {count} datapoints: {avg_fps:.1f}\")\n        else:\n            logging.warning(\"No datapoints were processed.\")\n\n        return metrics\n"
  },
  {
    "path": "mvtracker/evaluation/metrics.py",
    "content": "import logging\nimport warnings\nfrom typing import Mapping\n\nimport numpy as np\nimport pandas as pd\nimport torch\n\n\ndef compute_metrics(\n        query_points,\n        gt_occluded,\n        gt_tracks,\n        pred_occluded,\n        pred_tracks,\n        distance_thresholds=[1, 2, 4, 8, 16],\n        survival_distance_threshold=50,\n        query_mode=\"first\",\n):\n    n_batches, n_frames, n_points, n_point_dim = gt_tracks.shape\n\n    # First, we compute the original TAP-Vid metrics\n    tapvid_metrics = compute_tapvid_metrics(query_points, gt_occluded, gt_tracks, pred_occluded,\n                                            pred_tracks, distance_thresholds, query_mode)\n\n    # Compute distances only for visible points\n    visible_mask = ~gt_occluded\n    distances = torch.norm(pred_tracks - gt_tracks, dim=-1)\n    distances[~visible_mask] = float('nan')\n    distances[torch.arange(n_frames)[None, :, None] < query_points[:, :, 0].long()[:, None, :]] = float('nan')\n\n    # Compute Median Trajectory Error (MTE) and Average Trajectory Error (ATE) for visible points\n    mte_per_track = torch.nanmedian(distances, dim=1).values\n    ate_per_track = torch.nanmean(distances, dim=1)\n    assert torch.isnan(mte_per_track).sum() == 0\n    assert torch.isnan(ate_per_track).sum() == 0\n\n    # Compute Final Trajectory Error (FDE) for the last visible frame\n    last_visible_idx = torch.argmax(visible_mask * np.arange(n_frames)[None, :, None], dim=1)\n    fde_per_track = distances[torch.arange(n_batches)[:, None], last_visible_idx, torch.arange(n_points)]\n\n    # Compute \"Survival\" rate for visible points\n    tracking_failed = (distances > survival_distance_threshold) * visible_mask\n    failure_index = tracking_failed.float().argmax(dim=1)\n    failure_index[(~tracking_failed).all(dim=1)] = n_frames  # If all points survived, survival is 1.0\n    survival_per_track = (failure_index - query_points[:, :, 0].long()) / (n_frames - query_points[:, :, 0].long())\n\n    assert mte_per_track.shape == ate_per_track.shape == survival_per_track.shape == fde_per_track.shape\n\n    metrics = {\n        'mte_visible_per_track': mte_per_track,\n        'ate_visible_per_track': ate_per_track,\n        'fde_visible_per_track': fde_per_track,\n        'survival_per_track': survival_per_track,\n        **tapvid_metrics,\n    }\n\n    return metrics\n\n\ndef compute_tapvid_metrics(\n        query_points,\n        gt_occluded,\n        gt_tracks,\n        pred_occluded,\n        pred_tracks,\n        distance_thresholds,\n        query_mode=\"first\",\n):\n    \"\"\"\n    Computes metrics from TAP-Vid (https://arxiv.org/abs/2211.03726) based on given ground truth and predictions.\n    The computations are performed separately for each video in the batch.\n\n    Parameters\n    ----------\n    query_points : torch.Tensor\n        Tensor of shape (n_batches, n_points, 3) representing the query points.\n    gt_occluded : torch.Tensor\n        Boolean tensor of shape (n_batches, n_frames, n_points) indicating if a point is occluded in the ground truth.\n    gt_tracks : torch.Tensor\n        Tensor of shape (n_batches, n_frames, n_points, n_point_dim) representing the ground truth tracks.\n    pred_occluded : torch.Tensor\n        Boolean tensor of shape (n_batches, n_frames, n_points) indicating if a point is occluded in the predictions.\n    pred_tracks : torch.Tensor\n        Tensor of shape (n_batches, n_frames, n_points, n_point_dim) representing the predicted tracks.\n    query_mode : str, 
optional\n        Either \"first\" or \"strided\", default is \"first\". Indicates how the query points are sampled.\n\n    Returns\n    -------\n    dict\n        A dictionary containing:\n        - 'occlusion_accuracy_per_track': Accuracy at predicting occlusion, per track.\n        - 'pts_within_{x}_per_track' for x in [1, 2, 4, 8, 16]: Fraction of points predicted to be within the given pixel threshold, ignoring occlusion prediction, per track.\n        - 'jaccard_{x}_per_track' for x in [1, 2, 4, 8, 16]: Jaccard metric for the given pixel threshold. Combines occlusion and point prediction accuracy, per track.\n        - 'average_jaccard_per_track': Average Jaccard metric across thresholds, per track.\n        - 'average_pts_within_thresh_per_track': Average fraction of points within threshold across thresholds, per track.\n    \"\"\"\n\n    metrics = {}\n\n    # Check shapes.\n    n_batches, n_frames, n_points, n_point_dim = gt_tracks.shape\n    assert n_point_dim in [2, 3]\n    assert query_points.shape == (n_batches, n_points, n_point_dim + 1)\n    assert gt_occluded.shape == (n_batches, n_frames, n_points)\n    assert gt_tracks.shape == (n_batches, n_frames, n_points, n_point_dim)\n    assert pred_occluded.shape == (n_batches, n_frames, n_points)\n    assert pred_tracks.shape == (n_batches, n_frames, n_points, n_point_dim)\n    assert query_mode in [\"first\", \"strided\"]\n    assert query_points.dtype == torch.float32\n    assert gt_occluded.dtype == torch.bool\n    assert gt_tracks.dtype == torch.float32\n    assert pred_occluded.dtype == torch.bool\n    assert pred_tracks.dtype == torch.float32\n\n    # Don't evaluate the query point.\n    evaluation_points = torch.ones_like(gt_occluded, dtype=torch.bool)\n    for batch_idx in range(n_batches):\n        t = query_points[batch_idx, :, 0].long()\n        evaluation_points[batch_idx, t, torch.arange(n_points)] = False\n\n    # In first query mode, don't evaluate points before the query point.\n    if query_mode == \"first\":\n        t = query_points[:, :, 0].long()\n        mask = torch.arange(n_frames).unsqueeze(-1) < t.unsqueeze(1)\n        evaluation_points[mask] = False\n\n    # Compute occlusion accuracy per track.\n    occ_acc = ((pred_occluded == gt_occluded) & evaluation_points).float().sum(dim=1) / evaluation_points.sum(dim=1)\n    metrics[\"occlusion_accuracy_per_track\"] = occ_acc\n\n    # Let's report the numbers separately for gt=0 and gt=1\n    numer0 = ((pred_occluded == gt_occluded) & (gt_occluded == 1) & evaluation_points).float().sum(dim=1)\n    numer1 = ((pred_occluded == gt_occluded) & (gt_occluded == 0) & evaluation_points).float().sum(dim=1)\n    denom0 = ((gt_occluded == 1) & evaluation_points).float().sum(dim=1)\n    denom1 = ((gt_occluded == 0) & evaluation_points).float().sum(dim=1)\n    occ_acc_for_vis0 = numer0 / denom0\n    occ_acc_for_vis1 = numer1 / denom1\n    metrics[\"occlusion_accuracy_for_vis0_per_track\"] = occ_acc_for_vis0\n    metrics[\"occlusion_accuracy_for_vis1_per_track\"] = occ_acc_for_vis1\n\n    # Compute position metrics per track.\n    distances = torch.norm(pred_tracks - gt_tracks, dim=-1)\n    thresholds = torch.tensor(distance_thresholds, device=distances.device)\n    for thresh in thresholds:\n        within_threshold = distances < thresh\n        correct_positions = (within_threshold & ~gt_occluded & evaluation_points).float().sum(dim=1)\n        visible_points = (~gt_occluded & evaluation_points).float().sum(dim=1)\n        assert visible_points.min() > 0, \"No visible points 
to evaluate. Make sure at least two timesteps were visible.\"\n        metrics[f\"pts_within_{thresh:.2f}_per_track\"] = correct_positions / visible_points\n\n        true_positives = (within_threshold & ~pred_occluded & ~gt_occluded & evaluation_points).float().sum(dim=1)\n        gt_positives = (~gt_occluded & evaluation_points).float().sum(dim=1)\n        false_positives = (~within_threshold & ~pred_occluded) | (~pred_occluded & gt_occluded)\n        false_positives = (false_positives & evaluation_points).float().sum(dim=1)\n        jaccard = true_positives / (gt_positives + false_positives)\n        metrics[f\"jaccard_{thresh:.2f}_per_track\"] = jaccard\n\n    metrics[\"average_jaccard_per_track\"] = torch.stack([metrics[f\"jaccard_{thresh:.2f}_per_track\"]\n                                                        for thresh in thresholds], dim=-1).mean(dim=-1)\n    metrics[\"average_pts_within_thresh_per_track\"] = torch.stack([metrics[f\"pts_within_{thresh:.2f}_per_track\"]\n                                                                  for thresh in thresholds], dim=-1).mean(dim=-1)\n\n    # Assert no nans\n    for k, v in metrics.items():\n        if k in [\"occlusion_accuracy_for_vis0_per_track\", \"occlusion_accuracy_for_vis1_per_track\"]:\n            continue  # They can have nans and will be handled later\n        assert not torch.isnan(v).any(), f\"NaN found in {k}\"\n\n    return metrics\n\n\ndef compute_tapvid_metrics_original(\n        query_points: np.ndarray,\n        gt_occluded: np.ndarray,\n        gt_tracks: np.ndarray,\n        pred_occluded: np.ndarray,\n        pred_tracks: np.ndarray,\n        query_mode: str,\n) -> Mapping[str, np.ndarray]:\n    \"\"\"Computes TAP-Vid metrics (Jaccard, Pts. Within Thresh, Occ. Acc.)\n    See the TAP-Vid paper for details on the metric computation.  All inputs are\n    given in raster coordinates.  The first three arguments should be the direct\n    outputs of the reader: the 'query_points', 'occluded', and 'target_points'.\n    The paper metrics assume these are scaled relative to 256x256 images.\n    pred_occluded and pred_tracks are your algorithm's predictions.\n    This function takes a batch of inputs, and computes metrics separately for\n    each video.  The metrics for the full benchmark are a simple mean of the\n    metrics across the full set of videos.  These numbers are between 0 and 1,\n    but the paper multiplies them by 100 to ease reading.\n    Args:\n       query_points: The query points, an in the format [t, y, x].  Its size is\n         [b, n, 3], where b is the batch size and n is the number of queries\n       gt_occluded: A boolean array of shape [b, n, t], where t is the number\n         of frames.  True indicates that the point is occluded.\n       gt_tracks: The target points, of shape [b, n, t, 2].  Each point is\n         in the format [x, y]\n       pred_occluded: A boolean array of predicted occlusions, in the same\n         format as gt_occluded.\n       pred_tracks: An array of track predictions from your algorithm, in the\n         same format as gt_tracks.\n       query_mode: Either 'first' or 'strided', depending on how queries are\n         sampled.  
If 'first', we assume the prior knowledge that all points\n         before the query point are occluded, and these are removed from the\n         evaluation.\n    Returns:\n        A dict with the following keys:\n        occlusion_accuracy: Accuracy at predicting occlusion.\n        pts_within_{x} for x in [1, 2, 4, 8, 16]: Fraction of points\n          predicted to be within the given pixel threshold, ignoring occlusion\n          prediction.\n        jaccard_{x} for x in [1, 2, 4, 8, 16]: Jaccard metric for the given\n          threshold\n        average_pts_within_thresh: average across pts_within_{x}\n        average_jaccard: average across jaccard_{x}\n    \"\"\"\n\n    metrics = {}\n    # Fixed bug is described in:\n    # https://github.com/facebookresearch/co-tracker/issues/20\n    eye = np.eye(gt_tracks.shape[2], dtype=np.int32)\n\n    if query_mode == \"first\":\n        # evaluate frames after the query frame\n        query_frame_to_eval_frames = np.cumsum(eye, axis=1) - eye\n    elif query_mode == \"strided\":\n        # evaluate all frames except the query frame\n        query_frame_to_eval_frames = 1 - eye\n    else:\n        raise ValueError(\"Unknown query mode \" + query_mode)\n\n    query_frame = query_points[..., 0]\n    query_frame = np.round(query_frame).astype(np.int32)\n    evaluation_points = query_frame_to_eval_frames[query_frame] > 0\n\n    # Occlusion accuracy is simply how often the predicted occlusion equals the\n    # ground truth.\n    occ_acc = np.sum(\n        np.equal(pred_occluded, gt_occluded) & evaluation_points,\n        axis=(1, 2),\n    ) / np.sum(evaluation_points)\n    metrics[\"occlusion_accuracy\"] = occ_acc\n\n    # Next, convert the predictions and ground truth positions into pixel\n    # coordinates.\n    visible = np.logical_not(gt_occluded)\n    pred_visible = np.logical_not(pred_occluded)\n    all_frac_within = []\n    all_jaccard = []\n    for thresh in [1, 2, 4, 8, 16]:\n        # True positives are points that are within the threshold and where both\n        # the prediction and the ground truth are listed as visible.\n        within_dist = np.sum(\n            np.square(pred_tracks - gt_tracks),\n            axis=-1,\n        ) < np.square(thresh)\n        is_correct = np.logical_and(within_dist, visible)\n\n        # Compute the frac_within_threshold, which is the fraction of points\n        # within the threshold among points that are visible in the ground truth,\n        # ignoring whether they're predicted to be visible.\n        count_correct = np.sum(\n            is_correct & evaluation_points,\n            axis=(1, 2),\n        )\n        count_visible_points = np.sum(visible & evaluation_points, axis=(1, 2))\n        frac_correct = count_correct / count_visible_points\n        metrics[\"pts_within_\" + str(thresh)] = frac_correct\n        all_frac_within.append(frac_correct)\n\n        true_positives = np.sum(\n            is_correct & pred_visible & evaluation_points, axis=(1, 2)\n        )\n\n        # The denominator of the jaccard metric is the true positives plus\n        # false positives plus false negatives.  
However, note that true positives\n        # plus false negatives is simply the number of points in the ground truth\n        # which is easier to compute than trying to compute all three quantities.\n        # Thus we just add the number of points in the ground truth to the number\n        # of false positives.\n        #\n        # False positives are simply points that are predicted to be visible,\n        # but the ground truth is not visible or too far from the prediction.\n        gt_positives = np.sum(visible & evaluation_points, axis=(1, 2))\n        false_positives = (~visible) & pred_visible\n        false_positives = false_positives | ((~within_dist) & pred_visible)\n        false_positives = np.sum(false_positives & evaluation_points, axis=(1, 2))\n        jaccard = true_positives / (gt_positives + false_positives)\n        metrics[\"jaccard_\" + str(thresh)] = jaccard\n        all_jaccard.append(jaccard)\n    metrics[\"average_jaccard\"] = np.mean(\n        np.stack(all_jaccard, axis=1),\n        axis=1,\n    )\n    metrics[\"average_pts_within_thresh\"] = np.mean(\n        np.stack(all_frac_within, axis=1),\n        axis=1,\n    )\n    return metrics\n\n\ndef evaluate_predictions(\n        gt_tracks,\n        gt_visibilities,\n        pred_tracks,\n        pred_occluded,\n        query_points=None,\n        distance_thresholds=[0.01, 0.02, 0.04, 0.08, 0.16],  # 1 cm, 2 cm, 4 cm, 8 cm, 16 cm\n        survival_distance_threshold=0.5,  # 50 cm\n        static_threshold=0.01,  # < 0.01 cm\n        dynamic_threshold=0.1,  # > 10 cm\n        very_dynamic_threshold=2.0,  # > 2 m\n):\n    n_frames, n_points, n_point_dim = gt_tracks.shape\n\n    if query_points is None:\n        warnings.warn(\"Query points are not provided. Using the first visible frame as query points.\")\n        query_points_t = np.argmax(gt_visibilities, axis=0)\n        query_points_xyz = gt_tracks[query_points_t, np.arange(n_points)]\n        query_points = np.concatenate([query_points_t[:, None], query_points_xyz], axis=-1)\n\n    at_query_timestep_or_later = (np.arange(n_frames)[:, None] >= query_points[:, 0][None, :])\n    gt_visibilities = gt_visibilities.copy() * at_query_timestep_or_later\n\n    movement = np.zeros(n_points)\n    for point_idx in range(n_points):\n        point_track = gt_tracks[gt_visibilities[:, point_idx], point_idx, :]\n        movement[point_idx] = np.linalg.norm(point_track[1:] - point_track[:-1], axis=-1).sum()\n\n    point_types = [\"any\"]\n    static_points = None\n    dynamic_points = None\n    very_dynamic_points = None\n    if static_threshold is not None:\n        point_types += [\"static\"]\n        static_points = movement < static_threshold\n    if dynamic_threshold is not None:\n        point_types += [\"dynamic\"]\n        dynamic_points = movement > dynamic_threshold\n    if very_dynamic_threshold is not None:\n        point_types += [\"very_dynamic\"]\n        very_dynamic_points = movement > very_dynamic_threshold\n\n    mask_1 = gt_visibilities.sum(axis=0) >= 2  # At least two visible, the first one is a query\n\n    results = {}\n    results_per_track = {}\n    for short_name, mask_a in [\n        (\"all\", mask_1),\n    ]:\n        for point_type in point_types:\n            if point_type == \"any\":\n                mask_b = np.ones_like(mask_a)\n            elif point_type == \"static\":\n                mask_b = static_points\n            elif point_type == \"dynamic\":\n                mask_b = dynamic_points\n            elif point_type == \"very_dynamic\":\n 
               mask_b = very_dynamic_points\n            else:\n                raise ValueError\n            mask_ab = mask_a & mask_b\n            short_name_ = f\"{short_name}_{point_type}\"\n\n            if mask_ab.sum() == 0:\n                logging.info(f\"No points for {short_name_} (empty mask).\")\n                continue\n\n            pred_tracks_ = pred_tracks[:, mask_ab, :][None]\n            out_metrics_3d = compute_metrics(\n                torch.from_numpy(query_points[mask_ab, :][None]).float(),\n                torch.from_numpy(~gt_visibilities[:, mask_ab][None]),\n                torch.from_numpy(gt_tracks[:, mask_ab, :][None]).float(),\n                torch.from_numpy(pred_occluded[:, mask_ab][None]),\n                torch.from_numpy(pred_tracks_).float(),\n                distance_thresholds=distance_thresholds,\n                survival_distance_threshold=survival_distance_threshold,\n                query_mode=\"first\",\n            )\n            results[short_name_] = {}\n            for k, v in out_metrics_3d.items():\n                assert \"_per_track\" in k\n                results[short_name_][k.replace(\"_per_track\", \"\")] = v.nanmean().item() * 100\n            results[short_name_][\"n\"] = mask_ab.sum() / n_points * 100\n            results[short_name_][\"v\"] = (gt_visibilities[:, mask_ab].sum() / mask_ab.sum() / n_frames) * 100\n\n            results_per_track[short_name_] = {}\n            for k, v in out_metrics_3d.items():\n                assert v.ndim == 2 and v.shape[0] == 1\n                v = v[0]\n                results_per_track[short_name_][k] = v.cpu().numpy() * 100\n            results_per_track[short_name_][\"indices\"] = np.where(mask_ab)[0]\n\n    if \"all_static\" in results.keys() and \"all_dynamic\" in results.keys():\n        results[\"all_dynamic-static-mean\"] = {}\n        for k in results[\"all_static\"].keys():\n            results[\"all_dynamic-static-mean\"][k] = (results[\"all_dynamic\"][k] + results[\"all_static\"][k]) / 2\n\n    df = pd.DataFrame(results)\n    df = df.round(2)\n\n    df_per_track = pd.DataFrame(results_per_track)\n    df_per_track = df_per_track.round(2)\n\n    return df, df_per_track\n"
  },
  {
    "path": "mvtracker/models/__init__.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n"
  },
  {
    "path": "mvtracker/models/core/__init__.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n"
  },
  {
    "path": "mvtracker/models/core/copycat.py",
    "content": "import torch\nfrom torch import nn as nn\n\n\nclass CopyCat(nn.Module):\n    \"\"\"\n    Dummy, no-movement baseline that always outputs the query points as the predicted points.\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n        self.dummy_learnable_param = nn.Parameter(torch.zeros(1))\n\n    def forward(\n            self,\n            rgbs,\n            depths,\n            query_points,\n            intrs,\n            extrs,\n            **kwargs,\n    ):\n        batch_size, num_views, num_frames, _, height, width = rgbs.shape\n        _, num_points, _ = query_points.shape\n        assert rgbs.shape == (batch_size, num_views, num_frames, 3, height, width)\n        assert depths.shape == (batch_size, num_views, num_frames, 1, height, width)\n        assert query_points.shape == (batch_size, num_points, 4)\n        assert intrs.shape == (batch_size, num_views, num_frames, 3, 3)\n        assert extrs.shape == (batch_size, num_views, num_frames, 3, 4)\n\n        traj_e = query_points[:, None, :, 1:].repeat(1, num_frames, 1, 1)\n        vis_e = query_points.new_ones((batch_size, num_frames, num_points))\n\n        results = {\n            \"traj_e\": traj_e,\n            \"feat_init\": None,\n            \"vis_e\": vis_e,\n        }\n        return results\n"
  },
  {
    "path": "mvtracker/models/core/cotracker2/__init__.py",
    "content": ""
  },
  {
    "path": "mvtracker/models/core/cotracker2/blocks.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\nimport collections\nfrom itertools import repeat\nfrom typing import Callable\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\n# From PyTorch internals\ndef _ntuple(n):\n    def parse(x):\n        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):\n            return tuple(x)\n        return tuple(repeat(x, n))\n\n    return parse\n\n\ndef exists(val):\n    return val is not None\n\n\ndef default(val, d):\n    return val if exists(val) else d\n\n\nto_2tuple = _ntuple(2)\n\n\nclass Mlp(nn.Module):\n    \"\"\"MLP as used in Vision Transformer, MLP-Mixer and related networks\"\"\"\n\n    def __init__(\n            self,\n            in_features,\n            hidden_features=None,\n            out_features=None,\n            act_layer=nn.GELU,\n            bias=True,\n            drop=0.0,\n    ):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        bias = to_2tuple(bias)\n        drop_probs = to_2tuple(drop)\n\n        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])\n        self.act = act_layer()\n        self.drop1 = nn.Dropout(drop_probs[0])\n        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])\n        self.drop2 = nn.Dropout(drop_probs[1])\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop1(x)\n        x = self.fc2(x)\n        x = self.drop2(x)\n        return x\n\n\nclass ResidualBlock(nn.Module):\n    def __init__(self, in_planes, planes, norm_fn=\"group\", stride=1):\n        super(ResidualBlock, self).__init__()\n\n        self.conv1 = nn.Conv2d(\n            in_planes,\n            planes,\n            kernel_size=3,\n            padding=1,\n            stride=stride,\n            padding_mode=\"zeros\",\n        )\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, padding_mode=\"zeros\")\n        self.relu = nn.ReLU(inplace=True)\n\n        num_groups = planes // 8\n\n        if norm_fn == \"group\":\n            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)\n            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)\n            if not stride == 1:\n                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)\n\n        elif norm_fn == \"batch\":\n            self.norm1 = nn.BatchNorm2d(planes)\n            self.norm2 = nn.BatchNorm2d(planes)\n            if not stride == 1:\n                self.norm3 = nn.BatchNorm2d(planes)\n\n        elif norm_fn == \"instance\":\n            self.norm1 = nn.InstanceNorm2d(planes)\n            self.norm2 = nn.InstanceNorm2d(planes)\n            if not stride == 1:\n                self.norm3 = nn.InstanceNorm2d(planes)\n\n        elif norm_fn == \"none\":\n            self.norm1 = nn.Sequential()\n            self.norm2 = nn.Sequential()\n            if not stride == 1:\n                self.norm3 = nn.Sequential()\n\n        if stride == 1:\n            self.downsample = None\n\n        else:\n            self.downsample = nn.Sequential(\n                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3\n            )\n\n    def forward(self, x):\n        y = x\n        y = 
self.relu(self.norm1(self.conv1(y)))\n        y = self.relu(self.norm2(self.conv2(y)))\n\n        if self.downsample is not None:\n            x = self.downsample(x)\n\n        return self.relu(x + y)\n\n\nclass BasicEncoder(nn.Module):\n    def __init__(self, input_dim=3, output_dim=128, stride=4):\n        super(BasicEncoder, self).__init__()\n        self.stride = stride\n        self.norm_fn = \"instance\"\n        self.in_planes = output_dim // 2\n\n        self.norm1 = nn.InstanceNorm2d(self.in_planes)\n        self.norm2 = nn.InstanceNorm2d(output_dim * 2)\n\n        self.conv1 = nn.Conv2d(\n            input_dim,\n            self.in_planes,\n            kernel_size=7,\n            stride=2,\n            padding=3,\n            padding_mode=\"zeros\",\n        )\n        self.relu1 = nn.ReLU(inplace=True)\n        self.layer1 = self._make_layer(output_dim // 2, stride=1)\n        self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)\n        self.layer3 = self._make_layer(output_dim, stride=2)\n        self.layer4 = self._make_layer(output_dim, stride=2)\n\n        self.conv2 = nn.Conv2d(\n            output_dim * 3 + output_dim // 4,\n            output_dim * 2,\n            kernel_size=3,\n            padding=1,\n            padding_mode=\"zeros\",\n        )\n        self.relu2 = nn.ReLU(inplace=True)\n        self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode=\"fan_out\", nonlinearity=\"relu\")\n            elif isinstance(m, (nn.InstanceNorm2d)):\n                if m.weight is not None:\n                    nn.init.constant_(m.weight, 1)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n\n    def _make_layer(self, dim, stride=1):\n        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)\n        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)\n        layers = (layer1, layer2)\n\n        self.in_planes = dim\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        _, _, H, W = x.shape\n\n        x = self.conv1(x)\n        x = self.norm1(x)\n        x = self.relu1(x)\n\n        a = self.layer1(x)\n        b = self.layer2(a)\n        c = self.layer3(b)\n        d = self.layer4(c)\n\n        def _bilinear_intepolate(x):\n            return F.interpolate(\n                x,\n                (H // self.stride, W // self.stride),\n                mode=\"bilinear\",\n                align_corners=True,\n            )\n\n        a = _bilinear_intepolate(a)\n        b = _bilinear_intepolate(b)\n        c = _bilinear_intepolate(c)\n        d = _bilinear_intepolate(d)\n\n        x = self.conv2(torch.cat([a, b, c, d], dim=1))\n        x = self.norm2(x)\n        x = self.relu2(x)\n        x = self.conv3(x)\n        return x\n\n\nclass Attention(nn.Module):\n    def __init__(self, query_dim, context_dim=None, num_heads=8, dim_head=48, qkv_bias=False):\n        super().__init__()\n        inner_dim = dim_head * num_heads\n        context_dim = default(context_dim, query_dim)\n        self.scale = dim_head ** -0.5\n        self.heads = num_heads\n\n        self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)\n        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)\n        self.to_out = nn.Linear(inner_dim, query_dim)\n\n    def forward(self, x, context=None, attn_mask=None):\n        B, N1, _ = x.shape\n        h = 
self.heads\n\n        q = self.to_q(x).reshape(B, N1, h, -1).permute(0, 2, 1, 3)\n        context = default(context, x)\n        k, v = self.to_kv(context).chunk(2, dim=-1)\n\n        N2 = context.shape[1]\n        k = k.reshape(B, N2, h, -1).permute(0, 2, 1, 3)\n        v = v.reshape(B, N2, h, -1).permute(0, 2, 1, 3)\n\n        sim = (q @ k.transpose(-2, -1)) * self.scale\n\n        if attn_mask is not None:\n            sim = sim.masked_fill(~attn_mask, float('-inf'))\n        attn = sim.softmax(dim=-1)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N1, -1)\n        return self.to_out(x)\n\n\nclass FlashAttention(nn.Module):\n    def __init__(self, query_dim, context_dim=None, num_heads=8, dim_head=48, qkv_bias=False):\n        super().__init__()\n        inner_dim = dim_head * num_heads\n        context_dim = default(context_dim, query_dim)\n        self.num_heads = num_heads\n        self.dim_head = dim_head\n\n        self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)\n        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)\n        self.to_out = nn.Linear(inner_dim, query_dim)\n\n    def forward(self, x, context=None, attn_mask=None):\n        B, N1, _ = x.shape\n        h = self.num_heads\n\n        q = self.to_q(x).reshape(B, N1, h, self.dim_head).transpose(1, 2)\n        context = default(context, x)\n        k, v = self.to_kv(context).chunk(2, dim=-1)\n        N2 = context.shape[1]\n        k = k.reshape(B, N2, h, self.dim_head).transpose(1, 2)\n        v = v.reshape(B, N2, h, self.dim_head).transpose(1, 2)\n\n        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)\n        x = x.transpose(1, 2).reshape(B, N1, -1)\n        return self.to_out(x)\n\n\nclass AttnBlock(nn.Module):\n    def __init__(\n            self,\n            hidden_size,\n            num_heads,\n            mlp_ratio=4.0,\n            attn_class: Callable[..., nn.Module] = Attention,\n            **block_kwargs,\n    ):\n        super().__init__()\n        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)\n        self.attn = attn_class(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)\n\n        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)\n        mlp_hidden_dim = int(hidden_size * mlp_ratio)\n        approx_gelu = lambda: nn.GELU(approximate=\"tanh\")\n        self.mlp = Mlp(\n            in_features=hidden_size,\n            hidden_features=mlp_hidden_dim,\n            act_layer=approx_gelu,\n            drop=0,\n        )\n\n    def forward(self, x, attn_mask=None):\n        x = x + self.attn(self.norm1(x), attn_mask=attn_mask)\n        x = x + self.mlp(self.norm2(x))\n        return x\n\n\nclass CrossAttnBlock(nn.Module):\n    def __init__(\n            self,\n            hidden_size,\n            context_dim,\n            num_heads,\n            mlp_ratio=4.0,\n            attn_class: Callable[..., nn.Module] = Attention,\n            **block_kwargs,\n    ):\n        super().__init__()\n        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)\n        self.norm_context = nn.LayerNorm(hidden_size)\n        self.cross_attn = attn_class(\n            query_dim=hidden_size,\n            context_dim=context_dim,\n            num_heads=num_heads,\n            qkv_bias=True,\n            **block_kwargs,\n        )\n\n        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)\n        mlp_hidden_dim = int(hidden_size * mlp_ratio)\n        
approx_gelu = lambda: nn.GELU(approximate=\"tanh\")\n        self.mlp = Mlp(\n            in_features=hidden_size,\n            hidden_features=mlp_hidden_dim,\n            act_layer=approx_gelu,\n            drop=0,\n        )\n\n    def forward(self, x, context, attn_mask=None):\n        x = x + self.cross_attn(self.norm1(x), context=self.norm_context(context), attn_mask=attn_mask)\n        x = x + self.mlp(self.norm2(x))\n        return x\n\n\nclass EfficientUpdateFormer(nn.Module):\n    \"\"\"\n    Transformer model that updates track estimates.\n    \"\"\"\n\n    def __init__(\n            self,\n            space_depth=6,\n            time_depth=6,\n            input_dim=320,\n            hidden_size=384,\n            num_heads=8,\n            output_dim=130,\n            mlp_ratio=4.0,\n            add_space_attn=True,\n            num_virtual_tracks=64,\n            attn_class: Callable[..., nn.Module] = Attention,\n            linear_layer_for_vis_conf=False,\n    ):\n        super().__init__()\n        self.out_channels = 2\n        self.num_heads = num_heads\n        self.hidden_size = hidden_size\n        self.add_space_attn = add_space_attn\n        self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)\n        self.linear_layer_for_vis_conf = linear_layer_for_vis_conf\n        if self.linear_layer_for_vis_conf:\n            self.flow_head = nn.Sequential(\n                nn.Linear(hidden_size, output_dim, bias=True),\n                nn.ReLU(inplace=True),\n                nn.Linear(output_dim, output_dim, bias=True),\n                nn.ReLU(inplace=True),\n                nn.Linear(output_dim, output_dim - 2, bias=True)\n            )\n            self.vis_conf_head = torch.nn.Linear(hidden_size, 2, bias=True)\n        else:\n            self.flow_head = nn.Sequential(\n                nn.Linear(hidden_size, output_dim, bias=True),\n                nn.ReLU(inplace=True),\n                nn.Linear(output_dim, output_dim, bias=True),\n                nn.ReLU(inplace=True),\n                nn.Linear(output_dim, output_dim, bias=True)\n            )\n        self.num_virtual_tracks = num_virtual_tracks\n        self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size))\n        self.time_blocks = nn.ModuleList(\n            [\n                AttnBlock(\n                    hidden_size,\n                    num_heads,\n                    mlp_ratio=mlp_ratio,\n                    attn_class=attn_class,\n                )\n                for _ in range(time_depth)\n            ]\n        )\n\n        if add_space_attn:\n            self.space_virtual_blocks = nn.ModuleList(\n                [\n                    AttnBlock(\n                        hidden_size,\n                        num_heads,\n                        mlp_ratio=mlp_ratio,\n                        attn_class=attn_class,\n                    )\n                    for _ in range(space_depth)\n                ]\n            )\n            self.space_point2virtual_blocks = nn.ModuleList(\n                [\n                    CrossAttnBlock(\n                        hidden_size,\n                        hidden_size,\n                        num_heads,\n                        mlp_ratio=mlp_ratio,\n                        attn_class=attn_class,\n                    )\n                    for _ in range(space_depth)\n                ]\n            )\n            self.space_virtual2point_blocks = nn.ModuleList(\n                [\n                    
CrossAttnBlock(\n                        hidden_size,\n                        hidden_size,\n                        num_heads,\n                        mlp_ratio=mlp_ratio,\n                        attn_class=attn_class,\n                    )\n                    for _ in range(space_depth)\n                ]\n            )\n            assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)\n        self.initialize_weights()\n\n    def initialize_weights(self):\n        def xavier_init(module):\n            if isinstance(module, nn.Linear):\n                torch.nn.init.xavier_uniform_(module.weight)\n                if module.bias is not None:\n                    nn.init.constant_(module.bias, 0)\n\n        def trunc_init(module):\n            if isinstance(module, nn.Linear):\n                torch.nn.init.trunc_normal_(module.weight, std=0.001)\n\n        # Apply xavier to all except flow_head\n        self.apply(xavier_init)\n\n        # Then override flow_head with trunc_normal\n        self.flow_head.apply(trunc_init)\n        if self.linear_layer_for_vis_conf:\n            self.vis_conf_head.apply(trunc_init)\n\n    def forward(self, input_tensor, mask=None):\n        tokens = self.input_transform(input_tensor)\n        B, _, T, _ = tokens.shape\n        virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)\n        tokens = torch.cat([tokens, virtual_tokens], dim=1)\n        _, N, _, _ = tokens.shape\n\n        j = 0\n        for i in range(len(self.time_blocks)):\n            time_tokens = tokens.contiguous().view(B * N, T, -1)  # B N T C -> (B N) T C\n            time_tokens = self.time_blocks[i](time_tokens)\n\n            tokens = time_tokens.view(B, N, T, -1)  # (B N) T C -> B N T C\n            if self.add_space_attn and (\n                    i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0\n            ):\n                space_tokens = (\n                    tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)\n                )  # B N T C -> (B T) N C\n                point_tokens = space_tokens[:, : N - self.num_virtual_tracks]\n                virtual_tokens = space_tokens[:, N - self.num_virtual_tracks:]\n\n                virtual_tokens = self.space_virtual2point_blocks[j](\n                    virtual_tokens, point_tokens, attn_mask=mask\n                )\n                virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)\n                point_tokens = self.space_point2virtual_blocks[j](\n                    point_tokens, virtual_tokens, attn_mask=mask\n                )\n                space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)\n                tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3)  # (B T) N C -> B N T C\n                j += 1\n        tokens = tokens[:, : N - self.num_virtual_tracks]\n\n        flow = self.flow_head(tokens)\n        if self.linear_layer_for_vis_conf:\n            vis_conf = self.vis_conf_head(tokens)\n            flow = torch.cat([flow, vis_conf], dim=-1)\n\n        return flow\n"
  },
  {
    "path": "mvtracker/models/core/dpt/__init__.py",
    "content": ""
  },
  {
    "path": "mvtracker/models/core/dpt/base_model.py",
    "content": "import torch\n\n\nclass BaseModel(torch.nn.Module):\n    def load(self, path):\n        \"\"\"Load model from file.\n\n        Args:\n            path (str): file path\n        \"\"\"\n        parameters = torch.load(path, map_location=torch.device(\"cpu\"))\n\n        if \"optimizer\" in parameters:\n            parameters = parameters[\"model\"]\n\n        self.load_state_dict(parameters)\n"
  },
  {
    "path": "mvtracker/models/core/dpt/blocks.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom mvtracker.models.core.dpt.vit import (\n    _make_pretrained_vitb_rn50_384,\n    _make_pretrained_vitl16_384,\n    _make_pretrained_vitb16_384,\n    _make_pretrained_vit_tiny\n)\n\n\ndef _make_encoder(\n        backbone,\n        features,\n        use_pretrained,\n        groups=1,\n        expand=False,\n        exportable=True,\n        hooks=None,\n        use_vit_only=False,\n        use_readout=\"ignore\",\n        enable_attention_hooks=False,\n):\n    if backbone == \"vitl16_384\":\n        pretrained = _make_pretrained_vitl16_384(\n            use_pretrained,\n            hooks=hooks,\n            use_readout=use_readout,\n            enable_attention_hooks=enable_attention_hooks,\n        )\n        scratch = _make_scratch(\n            [256, 512, 1024, 1024], features, groups=groups, expand=expand\n        )  # ViT-L/16 - 85.0% Top1 (backbone)\n    elif backbone == \"vitb_rn50_384\":\n        pretrained = _make_pretrained_vitb_rn50_384(\n            use_pretrained,\n            hooks=hooks,\n            use_vit_only=use_vit_only,\n            use_readout=use_readout,\n            enable_attention_hooks=enable_attention_hooks,\n        )\n        scratch = _make_scratch(\n            [256, 512, 768, 768], features, groups=groups, expand=expand\n        )  # ViT-H/16 - 85.0% Top1 (backbone)\n    elif backbone == \"vitb16_384\":\n        pretrained = _make_pretrained_vitb16_384(\n            use_pretrained,\n            hooks=hooks,\n            use_readout=use_readout,\n            enable_attention_hooks=enable_attention_hooks,\n        )\n        scratch = _make_scratch(\n            [96, 192, 384, 768], features, groups=groups, expand=expand\n        )  # ViT-B/16 - 84.6% Top1 (backbone)\n    elif backbone == \"resnext101_wsl\":\n        pretrained = _make_pretrained_resnext101_wsl(use_pretrained)\n        scratch = _make_scratch(\n            [256, 512, 1024, 2048], features, groups=groups, expand=expand\n        )  # efficientnet_lite3\n    elif backbone == \"vit_tiny_r_s16_p8_384\":\n        pretrained = _make_pretrained_vit_tiny(\n            use_pretrained,\n            hooks=hooks,\n            use_readout=use_readout,\n            enable_attention_hooks=enable_attention_hooks,\n        )\n        scratch = _make_scratch(\n            [96, 192, 384, 768], features, groups=groups, expand=expand\n        )\n    else:\n        print(f\"Backbone '{backbone}' not implemented\")\n        assert False\n\n    return pretrained, scratch\n\n\ndef _make_scratch(in_shape, out_shape, groups=1, expand=False):\n    scratch = nn.Module()\n\n    out_shape1 = out_shape\n    out_shape2 = out_shape\n    out_shape3 = out_shape\n    out_shape4 = out_shape\n    if expand == True:\n        out_shape1 = out_shape\n        out_shape2 = out_shape * 2\n        out_shape3 = out_shape * 4\n        out_shape4 = out_shape * 8\n\n    scratch.layer1_rn = nn.Conv2d(\n        in_shape[0],\n        out_shape1,\n        kernel_size=3,\n        stride=1,\n        padding=1,\n        bias=False,\n        groups=groups,\n    )\n    scratch.layer2_rn = nn.Conv2d(\n        in_shape[1],\n        out_shape2,\n        kernel_size=3,\n        stride=1,\n        padding=1,\n        bias=False,\n        groups=groups,\n    )\n    scratch.layer3_rn = nn.Conv2d(\n        in_shape[2],\n        out_shape3,\n        kernel_size=3,\n        stride=1,\n        padding=1,\n        bias=False,\n        groups=groups,\n    )\n    scratch.layer4_rn = nn.Conv2d(\n        
in_shape[3],\n        out_shape4,\n        kernel_size=3,\n        stride=1,\n        padding=1,\n        bias=False,\n        groups=groups,\n    )\n\n    return scratch\n\n\ndef _make_resnet_backbone(resnet):\n    pretrained = nn.Module()\n    pretrained.layer1 = nn.Sequential(\n        resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1\n    )\n\n    pretrained.layer2 = resnet.layer2\n    pretrained.layer3 = resnet.layer3\n    pretrained.layer4 = resnet.layer4\n\n    return pretrained\n\n\ndef _make_pretrained_resnext101_wsl(use_pretrained):\n    resnet = torch.hub.load(\"facebookresearch/WSL-Images\", \"resnext101_32x8d_wsl\")\n    return _make_resnet_backbone(resnet)\n\n\nclass Interpolate(nn.Module):\n    \"\"\"Interpolation module.\"\"\"\n\n    def __init__(self, scale_factor, mode, align_corners=False):\n        \"\"\"Init.\n\n        Args:\n            scale_factor (float): scaling\n            mode (str): interpolation mode\n        \"\"\"\n        super(Interpolate, self).__init__()\n\n        self.interp = nn.functional.interpolate\n        self.scale_factor = scale_factor\n        self.mode = mode\n        self.align_corners = align_corners\n\n    def forward(self, x):\n        \"\"\"Forward pass.\n\n        Args:\n            x (tensor): input\n\n        Returns:\n            tensor: interpolated data\n        \"\"\"\n\n        x = self.interp(\n            x,\n            scale_factor=self.scale_factor,\n            mode=self.mode,\n            align_corners=self.align_corners,\n        )\n\n        return x\n\n\nclass ResidualConvUnit(nn.Module):\n    \"\"\"Residual convolution module.\"\"\"\n\n    def __init__(self, features):\n        \"\"\"Init.\n\n        Args:\n            features (int): number of features\n        \"\"\"\n        super().__init__()\n\n        self.conv1 = nn.Conv2d(\n            features, features, kernel_size=3, stride=1, padding=1, bias=True\n        )\n\n        self.conv2 = nn.Conv2d(\n            features, features, kernel_size=3, stride=1, padding=1, bias=True\n        )\n\n        self.relu = nn.ReLU(inplace=True)\n\n    def forward(self, x):\n        \"\"\"Forward pass.\n\n        Args:\n            x (tensor): input\n\n        Returns:\n            tensor: output\n        \"\"\"\n        out = self.relu(x)\n        out = self.conv1(out)\n        out = self.relu(out)\n        out = self.conv2(out)\n\n        return out + x\n\n\nclass FeatureFusionBlock(nn.Module):\n    \"\"\"Feature fusion block.\"\"\"\n\n    def __init__(self, features):\n        \"\"\"Init.\n\n        Args:\n            features (int): number of features\n        \"\"\"\n        super(FeatureFusionBlock, self).__init__()\n\n        self.resConfUnit1 = ResidualConvUnit(features)\n        self.resConfUnit2 = ResidualConvUnit(features)\n\n    def forward(self, *xs):\n        \"\"\"Forward pass.\n\n        Returns:\n            tensor: output\n        \"\"\"\n        output = xs[0]\n\n        if len(xs) == 2:\n            output += self.resConfUnit1(xs[1])\n\n        output = self.resConfUnit2(output)\n\n        output = nn.functional.interpolate(\n            output, scale_factor=2, mode=\"bilinear\", align_corners=True\n        )\n\n        return output\n\n\nclass ResidualConvUnit_custom(nn.Module):\n    \"\"\"Residual convolution module.\"\"\"\n\n    def __init__(self, features, activation, bn):\n        \"\"\"Init.\n\n        Args:\n            features (int): number of features\n        \"\"\"\n        super().__init__()\n\n        self.bn = bn\n\n        
self.groups = 1\n\n        self.conv1 = nn.Conv2d(\n            features,\n            features,\n            kernel_size=3,\n            stride=1,\n            padding=1,\n            bias=not self.bn,\n            groups=self.groups,\n        )\n\n        self.conv2 = nn.Conv2d(\n            features,\n            features,\n            kernel_size=3,\n            stride=1,\n            padding=1,\n            bias=not self.bn,\n            groups=self.groups,\n        )\n\n        if self.bn == True:\n            self.bn1 = nn.BatchNorm2d(features)\n            self.bn2 = nn.BatchNorm2d(features)\n\n        self.activation = activation\n\n        self.skip_add = nn.quantized.FloatFunctional()\n\n    def forward(self, x):\n        \"\"\"Forward pass.\n\n        Args:\n            x (tensor): input\n\n        Returns:\n            tensor: output\n        \"\"\"\n\n        out = self.activation(x)\n        out = self.conv1(out)\n        if self.bn == True:\n            out = self.bn1(out)\n\n        out = self.activation(out)\n        out = self.conv2(out)\n        if self.bn == True:\n            out = self.bn2(out)\n\n        if self.groups > 1:\n            out = self.conv_merge(out)\n\n        return self.skip_add.add(out, x)\n\n        # return out + x\n\n\nclass FeatureFusionBlock_custom(nn.Module):\n    \"\"\"Feature fusion block.\"\"\"\n\n    def __init__(\n            self,\n            features,\n            activation,\n            deconv=False,\n            bn=False,\n            expand=False,\n            align_corners=True,\n    ):\n        \"\"\"Init.\n\n        Args:\n            features (int): number of features\n        \"\"\"\n        super(FeatureFusionBlock_custom, self).__init__()\n\n        self.deconv = deconv\n        self.align_corners = align_corners\n\n        self.groups = 1\n\n        self.expand = expand\n        out_features = features\n        if self.expand == True:\n            out_features = features // 2\n\n        self.out_conv = nn.Conv2d(\n            features,\n            out_features,\n            kernel_size=1,\n            stride=1,\n            padding=0,\n            bias=True,\n            groups=1,\n        )\n\n        self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)\n        self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)\n\n        self.skip_add = nn.quantized.FloatFunctional()\n\n    def forward(self, *xs):\n        \"\"\"Forward pass.\n\n        Returns:\n            tensor: output\n        \"\"\"\n        output = xs[0]\n\n        if len(xs) == 2:\n            res = self.resConfUnit1(xs[1])\n            output = self.skip_add.add(output, res)\n            # output += res\n\n        output = self.resConfUnit2(output)\n\n        output = nn.functional.interpolate(\n            output, scale_factor=2, mode=\"bilinear\", align_corners=self.align_corners\n        )\n\n        output = self.out_conv(output)\n\n        return output\n"
  },
  {
    "path": "mvtracker/models/core/dpt/midas_net.py",
    "content": "\"\"\"MidashNet: Network for monocular depth estimation trained by mixing several datasets.\nThis file contains code that is adapted from\nhttps://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py\n\"\"\"\nimport torch\nimport torch.nn as nn\n\nfrom mvtracker.models.core.dpt.base_model import BaseModel\nfrom mvtracker.models.core.dpt.blocks import FeatureFusionBlock, Interpolate, _make_encoder\n\n\nclass MidasNet_large(BaseModel):\n    \"\"\"Network for monocular depth estimation.\"\"\"\n\n    def __init__(self, path=None, features=256, non_negative=True):\n        \"\"\"Init.\n\n        Args:\n            path (str, optional): Path to saved model. Defaults to None.\n            features (int, optional): Number of features. Defaults to 256.\n            backbone (str, optional): Backbone network for encoder. Defaults to resnet50\n        \"\"\"\n        print(\"Loading weights: \", path)\n\n        super(MidasNet_large, self).__init__()\n\n        use_pretrained = False if path is None else True\n\n        self.pretrained, self.scratch = _make_encoder(\n            backbone=\"resnext101_wsl\", features=features, use_pretrained=use_pretrained\n        )\n\n        self.scratch.refinenet4 = FeatureFusionBlock(features)\n        self.scratch.refinenet3 = FeatureFusionBlock(features)\n        self.scratch.refinenet2 = FeatureFusionBlock(features)\n        self.scratch.refinenet1 = FeatureFusionBlock(features)\n\n        self.scratch.output_conv = nn.Sequential(\n            nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),\n            Interpolate(scale_factor=2, mode=\"bilinear\"),\n            nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),\n            nn.ReLU(True),\n            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),\n            nn.ReLU(True) if non_negative else nn.Identity(),\n        )\n\n        if path:\n            self.load(path)\n\n    def forward(self, x):\n        \"\"\"Forward pass.\n\n        Args:\n            x (tensor): input data (image)\n\n        Returns:\n            tensor: depth\n        \"\"\"\n\n        layer_1 = self.pretrained.layer1(x)\n        layer_2 = self.pretrained.layer2(layer_1)\n        layer_3 = self.pretrained.layer3(layer_2)\n        layer_4 = self.pretrained.layer4(layer_3)\n\n        layer_1_rn = self.scratch.layer1_rn(layer_1)\n        layer_2_rn = self.scratch.layer2_rn(layer_2)\n        layer_3_rn = self.scratch.layer3_rn(layer_3)\n        layer_4_rn = self.scratch.layer4_rn(layer_4)\n\n        path_4 = self.scratch.refinenet4(layer_4_rn)\n        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)\n        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)\n        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)\n\n        out = self.scratch.output_conv(path_1)\n\n        return torch.squeeze(out, dim=1)\n"
  },
  {
    "path": "mvtracker/models/core/dpt/models.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom mvtracker.models.core.dpt.base_model import BaseModel\nfrom mvtracker.models.core.dpt.blocks import (\n    FeatureFusionBlock_custom,\n    Interpolate,\n    _make_encoder,\n)\nfrom mvtracker.models.core.dpt.vit import forward_vit\n\n\ndef _make_fusion_block(features, use_bn):\n    return FeatureFusionBlock_custom(\n        features,\n        nn.ReLU(False),\n        deconv=False,\n        bn=use_bn,\n        expand=False,\n        align_corners=True,\n    )\n\n\nclass DPT(BaseModel):\n    def __init__(\n            self,\n            head,\n            features=256,\n            backbone=\"vitb_rn50_384\",\n            readout=\"project\",\n            channels_last=False,\n            use_bn=True,\n            enable_attention_hooks=False,\n    ):\n\n        super(DPT, self).__init__()\n\n        self.channels_last = channels_last\n\n        hooks = {\n            \"vitb_rn50_384\": [0, 1, 8, 11],\n            \"vitb16_384\": [2, 5, 8, 11],\n            \"vitl16_384\": [5, 11, 17, 23],\n            \"vit_tiny_r_s16_p8_384\": [0, 1, 2, 3],\n        }\n\n        # Instantiate backbone and reassemble blocks\n        self.pretrained, self.scratch = _make_encoder(\n            backbone,\n            features,\n            False,  # Set to true of you want to train from scratch, uses ImageNet weights\n            groups=1,\n            expand=False,\n            exportable=False,\n            hooks=hooks[backbone],\n            use_readout=readout,\n            enable_attention_hooks=enable_attention_hooks,\n        )\n\n        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)\n        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)\n        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)\n        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)\n\n        self.scratch.output_conv = head\n\n        self.proj_out = nn.Sequential(\n            nn.Conv2d(\n                256 + 512 + 384 + 384,\n                256,\n                kernel_size=3,\n                padding=1,\n                padding_mode=\"zeros\",\n            ),\n            nn.BatchNorm2d(128 * 2),\n            nn.ReLU(True),\n            nn.Conv2d(\n                128 * 2,\n                128,\n                kernel_size=3,\n                padding=1,\n                padding_mode=\"zeros\",\n            )\n        )\n\n    def forward(self, x, only_enc=False):\n        if self.channels_last == True:\n            x.contiguous(memory_format=torch.channels_last)\n        if only_enc:\n            layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)\n            a = (layer_1)\n            b = (\n                F.interpolate(\n                    layer_2,\n                    scale_factor=2,\n                    mode=\"bilinear\",\n                    align_corners=True,\n                )\n            )\n            c = (\n                F.interpolate(\n                    layer_3,\n                    scale_factor=8,\n                    mode=\"bilinear\",\n                    align_corners=True,\n                )\n            )\n            d = (\n                F.interpolate(\n                    layer_4,\n                    scale_factor=16,\n                    mode=\"bilinear\",\n                    align_corners=True,\n                )\n            )\n            x = self.proj_out(torch.cat([a, b, c, d], dim=1))\n            return x\n        else:\n  
          layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)\n\n        layer_1_rn = self.scratch.layer1_rn(layer_1)\n        layer_2_rn = self.scratch.layer2_rn(layer_2)\n        layer_3_rn = self.scratch.layer3_rn(layer_3)\n        layer_4_rn = self.scratch.layer4_rn(layer_4)\n\n        path_4 = self.scratch.refinenet4(layer_4_rn)\n        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)\n        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)\n        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)\n\n        _, _, H_out, W_out = path_1.size()\n        path_2_up = F.interpolate(path_2, size=(H_out, W_out), mode=\"bilinear\", align_corners=True)\n        path_3_up = F.interpolate(path_3, size=(H_out, W_out), mode=\"bilinear\", align_corners=True)\n        path_4_up = F.interpolate(path_4, size=(H_out, W_out), mode=\"bilinear\", align_corners=True)\n\n        out = self.scratch.output_conv(path_1 + path_2_up + path_3_up + path_4_up)\n\n        return out\n\n\nclass DPTDepthModel(DPT):\n    def __init__(\n            self, path=None, non_negative=True, scale=1.0, shift=0.0, invert=False, **kwargs\n    ):\n        features = kwargs[\"features\"] if \"features\" in kwargs else 256\n\n        self.scale = scale\n        self.shift = shift\n        self.invert = invert\n\n        head = nn.Sequential(\n            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),\n            Interpolate(scale_factor=2, mode=\"bilinear\", align_corners=True),\n            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),\n            nn.ReLU(True),\n            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),\n            nn.ReLU(True) if non_negative else nn.Identity(),\n            nn.Identity(),\n        )\n\n        super().__init__(head, **kwargs)\n\n        if path is not None:\n            self.load(path)\n\n    def forward(self, x):\n        inv_depth = super().forward(x).squeeze(dim=1)\n\n        if self.invert:\n            depth = self.scale * inv_depth + self.shift\n            depth[depth < 1e-8] = 1e-8\n            depth = 1.0 / depth\n            return depth\n        else:\n            return inv_depth\n\n\nclass DPTEncoder(DPT):\n    def __init__(\n            self, path=None, non_negative=True, scale=1.0, shift=0.0, invert=False, **kwargs\n    ):\n        features = kwargs[\"features\"] if \"features\" in kwargs else 256\n\n        self.scale = scale\n        self.shift = shift\n\n        head = nn.Sequential(\n            nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),\n        )\n\n        super().__init__(head, **kwargs)\n\n        if path is not None:\n            self.load(path)\n\n    def forward(self, x):\n        features = super().forward(x, only_enc=True).squeeze(dim=1)\n\n        return features\n\n\nclass DPTSegmentationModel(DPT):\n    def __init__(self, num_classes, path=None, **kwargs):\n        features = kwargs[\"features\"] if \"features\" in kwargs else 256\n\n        kwargs[\"use_bn\"] = True\n\n        head = nn.Sequential(\n            nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False),\n            nn.BatchNorm2d(features),\n            nn.ReLU(True),\n            nn.Dropout(0.1, False),\n            nn.Conv2d(features, num_classes, kernel_size=1),\n            Interpolate(scale_factor=2, mode=\"bilinear\", align_corners=True),\n        )\n\n        super().__init__(head, **kwargs)\n\n        self.auxlayer = nn.Sequential(\n            nn.Conv2d(features, features, 
kernel_size=3, padding=1, bias=False),\n            nn.BatchNorm2d(features),\n            nn.ReLU(True),\n            nn.Dropout(0.1, False),\n            nn.Conv2d(features, num_classes, kernel_size=1),\n        )\n\n        if path is not None:\n            self.load(path)\n"
  },
  {
    "path": "mvtracker/models/core/dpt/transforms.py",
    "content": "import cv2\nimport math\nimport numpy as np\n\n\ndef apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):\n    \"\"\"Rezise the sample to ensure the given size. Keeps aspect ratio.\n\n    Args:\n        sample (dict): sample\n        size (tuple): image size\n\n    Returns:\n        tuple: new size\n    \"\"\"\n    shape = list(sample[\"disparity\"].shape)\n\n    if shape[0] >= size[0] and shape[1] >= size[1]:\n        return sample\n\n    scale = [0, 0]\n    scale[0] = size[0] / shape[0]\n    scale[1] = size[1] / shape[1]\n\n    scale = max(scale)\n\n    shape[0] = math.ceil(scale * shape[0])\n    shape[1] = math.ceil(scale * shape[1])\n\n    # resize\n    sample[\"image\"] = cv2.resize(\n        sample[\"image\"], tuple(shape[::-1]), interpolation=image_interpolation_method\n    )\n\n    sample[\"disparity\"] = cv2.resize(\n        sample[\"disparity\"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST\n    )\n    sample[\"mask\"] = cv2.resize(\n        sample[\"mask\"].astype(np.float32),\n        tuple(shape[::-1]),\n        interpolation=cv2.INTER_NEAREST,\n    )\n    sample[\"mask\"] = sample[\"mask\"].astype(bool)\n\n    return tuple(shape)\n\n\nclass Resize(object):\n    \"\"\"Resize sample to given size (width, height).\"\"\"\n\n    def __init__(\n            self,\n            width,\n            height,\n            resize_target=True,\n            keep_aspect_ratio=False,\n            ensure_multiple_of=1,\n            resize_method=\"lower_bound\",\n            image_interpolation_method=cv2.INTER_AREA,\n    ):\n        \"\"\"Init.\n\n        Args:\n            width (int): desired output width\n            height (int): desired output height\n            resize_target (bool, optional):\n                True: Resize the full sample (image, mask, target).\n                False: Resize image only.\n                Defaults to True.\n            keep_aspect_ratio (bool, optional):\n                True: Keep the aspect ratio of the input sample.\n                Output sample might not have the given width and height, and\n                resize behaviour depends on the parameter 'resize_method'.\n                Defaults to False.\n            ensure_multiple_of (int, optional):\n                Output width and height is constrained to be multiple of this parameter.\n                Defaults to 1.\n            resize_method (str, optional):\n                \"lower_bound\": Output will be at least as large as the given size.\n                \"upper_bound\": Output will be at max as large as the given size. (Output size might be smaller than given size.)\n                \"minimal\": Scale as least as possible.  
(Output size might be smaller than given size.)\n                Defaults to \"lower_bound\".\n        \"\"\"\n        self.__width = width\n        self.__height = height\n\n        self.__resize_target = resize_target\n        self.__keep_aspect_ratio = keep_aspect_ratio\n        self.__multiple_of = ensure_multiple_of\n        self.__resize_method = resize_method\n        self.__image_interpolation_method = image_interpolation_method\n\n    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):\n        y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)\n\n        if max_val is not None and y > max_val:\n            y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)\n\n        if y < min_val:\n            y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)\n\n        return y\n\n    def get_size(self, width, height):\n        # determine new height and width\n        scale_height = self.__height / height\n        scale_width = self.__width / width\n\n        if self.__keep_aspect_ratio:\n            if self.__resize_method == \"lower_bound\":\n                # scale such that output size is lower bound\n                if scale_width > scale_height:\n                    # fit width\n                    scale_height = scale_width\n                else:\n                    # fit height\n                    scale_width = scale_height\n            elif self.__resize_method == \"upper_bound\":\n                # scale such that output size is upper bound\n                if scale_width < scale_height:\n                    # fit width\n                    scale_height = scale_width\n                else:\n                    # fit height\n                    scale_width = scale_height\n            elif self.__resize_method == \"minimal\":\n                # scale as least as possbile\n                if abs(1 - scale_width) < abs(1 - scale_height):\n                    # fit width\n                    scale_height = scale_width\n                else:\n                    # fit height\n                    scale_width = scale_height\n            else:\n                raise ValueError(\n                    f\"resize_method {self.__resize_method} not implemented\"\n                )\n\n        if self.__resize_method == \"lower_bound\":\n            new_height = self.constrain_to_multiple_of(\n                scale_height * height, min_val=self.__height\n            )\n            new_width = self.constrain_to_multiple_of(\n                scale_width * width, min_val=self.__width\n            )\n        elif self.__resize_method == \"upper_bound\":\n            new_height = self.constrain_to_multiple_of(\n                scale_height * height, max_val=self.__height\n            )\n            new_width = self.constrain_to_multiple_of(\n                scale_width * width, max_val=self.__width\n            )\n        elif self.__resize_method == \"minimal\":\n            new_height = self.constrain_to_multiple_of(scale_height * height)\n            new_width = self.constrain_to_multiple_of(scale_width * width)\n        else:\n            raise ValueError(f\"resize_method {self.__resize_method} not implemented\")\n\n        return (new_width, new_height)\n\n    def __call__(self, sample):\n        width, height = self.get_size(\n            sample[\"image\"].shape[1], sample[\"image\"].shape[0]\n        )\n\n        # resize sample\n        sample[\"image\"] = cv2.resize(\n            sample[\"image\"],\n            (width, 
height),\n            interpolation=self.__image_interpolation_method,\n        )\n\n        if self.__resize_target:\n            if \"disparity\" in sample:\n                sample[\"disparity\"] = cv2.resize(\n                    sample[\"disparity\"],\n                    (width, height),\n                    interpolation=cv2.INTER_NEAREST,\n                )\n\n            if \"depth\" in sample:\n                sample[\"depth\"] = cv2.resize(\n                    sample[\"depth\"], (width, height), interpolation=cv2.INTER_NEAREST\n                )\n\n            sample[\"mask\"] = cv2.resize(\n                sample[\"mask\"].astype(np.float32),\n                (width, height),\n                interpolation=cv2.INTER_NEAREST,\n            )\n            sample[\"mask\"] = sample[\"mask\"].astype(bool)\n\n        return sample\n\n\nclass NormalizeImage(object):\n    \"\"\"Normlize image by given mean and std.\"\"\"\n\n    def __init__(self, mean, std):\n        self.__mean = mean\n        self.__std = std\n\n    def __call__(self, sample):\n        sample[\"image\"] = (sample[\"image\"] - self.__mean) / self.__std\n\n        return sample\n\n\nclass PrepareForNet(object):\n    \"\"\"Prepare sample for usage as network input.\"\"\"\n\n    def __init__(self):\n        pass\n\n    def __call__(self, sample):\n        image = np.transpose(sample[\"image\"], (2, 0, 1))\n        sample[\"image\"] = np.ascontiguousarray(image).astype(np.float32)\n\n        if \"mask\" in sample:\n            sample[\"mask\"] = sample[\"mask\"].astype(np.float32)\n            sample[\"mask\"] = np.ascontiguousarray(sample[\"mask\"])\n\n        if \"disparity\" in sample:\n            disparity = sample[\"disparity\"].astype(np.float32)\n            sample[\"disparity\"] = np.ascontiguousarray(disparity)\n\n        if \"depth\" in sample:\n            depth = sample[\"depth\"].astype(np.float32)\n            sample[\"depth\"] = np.ascontiguousarray(depth)\n\n        return sample\n"
  },
  {
    "path": "mvtracker/models/core/dpt/vit.py",
    "content": "import types\n\nimport math\nimport timm\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nactivations = {}\n\n\ndef get_activation(name):\n    def hook(model, input, output):\n        activations[name] = output\n\n    return hook\n\n\nattention = {}\n\n\ndef get_attention(name):\n    def hook(module, input, output):\n        x = input[0]\n        B, N, C = x.shape\n        qkv = (\n            module.qkv(x)\n            .reshape(B, N, 3, module.num_heads, C // module.num_heads)\n            .permute(2, 0, 3, 1, 4)\n        )\n        q, k, v = (\n            qkv[0],\n            qkv[1],\n            qkv[2],\n        )  # make torchscript happy (cannot use tensor as tuple)\n\n        attn = (q @ k.transpose(-2, -1)) * module.scale\n\n        attn = attn.softmax(dim=-1)  # [:,:,1,1:]\n        attention[name] = attn\n\n    return hook\n\n\ndef get_mean_attention_map(attn, token, shape):\n    attn = attn[:, :, token, 1:]\n    attn = attn.unflatten(2, torch.Size([shape[2] // 16, shape[3] // 16])).float()\n    attn = torch.nn.functional.interpolate(\n        attn, size=shape[2:], mode=\"bicubic\", align_corners=False\n    ).squeeze(0)\n\n    all_attn = torch.mean(attn, 0)\n\n    return all_attn\n\n\nclass Slice(nn.Module):\n    def __init__(self, start_index=1):\n        super(Slice, self).__init__()\n        self.start_index = start_index\n\n    def forward(self, x):\n        return x[:, self.start_index:]\n\n\nclass AddReadout(nn.Module):\n    def __init__(self, start_index=1):\n        super(AddReadout, self).__init__()\n        self.start_index = start_index\n\n    def forward(self, x):\n        if self.start_index == 2:\n            readout = (x[:, 0] + x[:, 1]) / 2\n        else:\n            readout = x[:, 0]\n        return x[:, self.start_index:] + readout.unsqueeze(1)\n\n\nclass ProjectReadout(nn.Module):\n    def __init__(self, in_features, start_index=1):\n        super(ProjectReadout, self).__init__()\n        self.start_index = start_index\n\n        self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())\n\n    def forward(self, x):\n        readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:])\n        features = torch.cat((x[:, self.start_index:], readout), -1)\n\n        return self.project(features)\n\n\nclass Transpose(nn.Module):\n    def __init__(self, dim0, dim1):\n        super(Transpose, self).__init__()\n        self.dim0 = dim0\n        self.dim1 = dim1\n\n    def forward(self, x):\n        x = x.transpose(self.dim0, self.dim1)\n        return x\n\n\ndef forward_vit(pretrained, x):\n    b, c, h, w = x.shape\n\n    glob = pretrained.model.forward_flex(x)\n\n    layer_1 = pretrained.activations[\"1\"]\n    layer_2 = pretrained.activations[\"2\"]\n    layer_3 = pretrained.activations[\"3\"]\n    layer_4 = pretrained.activations[\"4\"]\n\n    layer_1 = pretrained.act_postprocess1[0:2](layer_1)\n    layer_2 = pretrained.act_postprocess2[0:2](layer_2)\n    layer_3 = pretrained.act_postprocess3[0:2](layer_3)\n    layer_4 = pretrained.act_postprocess4[0:2](layer_4)\n\n    unflatten = nn.Sequential(\n        nn.Unflatten(\n            2,\n            torch.Size(\n                [\n                    h // pretrained.model.patch_size[1],\n                    w // pretrained.model.patch_size[0],\n                ]\n            ),\n        )\n    )\n\n    if layer_1.ndim == 3:\n        layer_1 = unflatten(layer_1)\n    if layer_2.ndim == 3:\n        layer_2 = unflatten(layer_2)\n    if layer_3.ndim == 
3:\n        layer_3 = unflatten(layer_3)\n    if layer_4.ndim == 3:\n        layer_4 = unflatten(layer_4)\n\n    layer_1 = pretrained.act_postprocess1[3: len(pretrained.act_postprocess1)](layer_1)\n    layer_2 = pretrained.act_postprocess2[3: len(pretrained.act_postprocess2)](layer_2)\n    layer_3 = pretrained.act_postprocess3[3: len(pretrained.act_postprocess3)](layer_3)\n    layer_4 = pretrained.act_postprocess4[3: len(pretrained.act_postprocess4)](layer_4)\n\n    return layer_1, layer_2, layer_3, layer_4\n\n\ndef _resize_pos_embed(self, posemb, gs_h, gs_w):\n    posemb_tok, posemb_grid = (\n        posemb[:, : self.start_index],\n        posemb[0, self.start_index:],\n    )\n\n    gs_old = int(math.sqrt(len(posemb_grid)))\n\n    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)\n    posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode=\"bilinear\")\n    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)\n\n    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)\n\n    return posemb\n\n\ndef forward_flex(self, x):\n    b, c, h, w = x.shape\n\n    pos_embed = self._resize_pos_embed(\n        self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]\n    )\n\n    B = x.shape[0]\n\n    if hasattr(self.patch_embed, \"backbone\"):\n        x = self.patch_embed.backbone(x)\n        if isinstance(x, (list, tuple)):\n            x = x[-1]  # last feature if backbone outputs list/tuple of features\n    x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)\n\n    if getattr(self, \"dist_token\", None) is not None:\n        cls_tokens = self.cls_token.expand(\n            B, -1, -1\n        )  # stole cls_tokens impl from Phil Wang, thanks\n        dist_token = self.dist_token.expand(B, -1, -1)\n        x = torch.cat((cls_tokens, dist_token, x), dim=1)\n    else:\n        cls_tokens = self.cls_token.expand(\n            B, -1, -1\n        )  # stole cls_tokens impl from Phil Wang, thanks\n        x = torch.cat((cls_tokens, x), dim=1)\n\n    x = x + pos_embed\n    x = self.pos_drop(x)\n\n    for blk in self.blocks:\n        x = blk(x)\n\n    x = self.norm(x)\n\n    return x\n\n\ndef get_readout_oper(vit_features, features, use_readout, start_index=1):\n    if use_readout == \"ignore\":\n        readout_oper = [Slice(start_index)] * len(features)\n    elif use_readout == \"add\":\n        readout_oper = [AddReadout(start_index)] * len(features)\n    elif use_readout == \"project\":\n        readout_oper = [\n            ProjectReadout(vit_features, start_index) for out_feat in features\n        ]\n    else:\n        assert (\n            False\n        ), \"wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'\"\n\n    return readout_oper\n\n\ndef _make_vit_b16_backbone(\n        model,\n        features=[96, 192, 384, 768],\n        size=[384, 384],\n        hooks=[2, 5, 8, 11],\n        vit_features=768,\n        use_readout=\"ignore\",\n        start_index=1,\n        enable_attention_hooks=False,\n):\n    pretrained = nn.Module()\n\n    pretrained.model = model\n    pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation(\"1\"))\n    pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation(\"2\"))\n    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation(\"3\"))\n    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation(\"4\"))\n\n    pretrained.activations = activations\n\n    if enable_attention_hooks:\n        
pretrained.model.blocks[hooks[0]].attn.register_forward_hook(\n            get_attention(\"attn_1\")\n        )\n        pretrained.model.blocks[hooks[1]].attn.register_forward_hook(\n            get_attention(\"attn_2\")\n        )\n        pretrained.model.blocks[hooks[2]].attn.register_forward_hook(\n            get_attention(\"attn_3\")\n        )\n        pretrained.model.blocks[hooks[3]].attn.register_forward_hook(\n            get_attention(\"attn_4\")\n        )\n        pretrained.attention = attention\n\n    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)\n\n    # 32, 48, 136, 384\n    pretrained.act_postprocess1 = nn.Sequential(\n        readout_oper[0],\n        Transpose(1, 2),\n        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),\n        nn.Conv2d(\n            in_channels=vit_features,\n            out_channels=features[0],\n            kernel_size=1,\n            stride=1,\n            padding=0,\n        ),\n        nn.ConvTranspose2d(\n            in_channels=features[0],\n            out_channels=features[0],\n            kernel_size=4,\n            stride=4,\n            padding=0,\n            bias=True,\n            dilation=1,\n            groups=1,\n        ),\n    )\n\n    pretrained.act_postprocess2 = nn.Sequential(\n        readout_oper[1],\n        Transpose(1, 2),\n        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),\n        nn.Conv2d(\n            in_channels=vit_features,\n            out_channels=features[1],\n            kernel_size=1,\n            stride=1,\n            padding=0,\n        ),\n        nn.ConvTranspose2d(\n            in_channels=features[1],\n            out_channels=features[1],\n            kernel_size=2,\n            stride=2,\n            padding=0,\n            bias=True,\n            dilation=1,\n            groups=1,\n        ),\n    )\n\n    pretrained.act_postprocess3 = nn.Sequential(\n        readout_oper[2],\n        Transpose(1, 2),\n        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),\n        nn.Conv2d(\n            in_channels=vit_features,\n            out_channels=features[2],\n            kernel_size=1,\n            stride=1,\n            padding=0,\n        ),\n    )\n\n    pretrained.act_postprocess4 = nn.Sequential(\n        readout_oper[3],\n        Transpose(1, 2),\n        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),\n        nn.Conv2d(\n            in_channels=vit_features,\n            out_channels=features[3],\n            kernel_size=1,\n            stride=1,\n            padding=0,\n        ),\n        nn.Conv2d(\n            in_channels=features[3],\n            out_channels=features[3],\n            kernel_size=3,\n            stride=2,\n            padding=1,\n        ),\n    )\n\n    pretrained.model.start_index = start_index\n    pretrained.model.patch_size = [16, 16]\n\n    # We inject this function into the VisionTransformer instances so that\n    # we can use it with interpolated position embeddings without modifying the library source.\n    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)\n    pretrained.model._resize_pos_embed = types.MethodType(\n        _resize_pos_embed, pretrained.model\n    )\n\n    return pretrained\n\n\ndef _make_vit_b_rn50_backbone(\n        model,\n        features=[256, 512, 768, 768],\n        size=[384, 384],\n        hooks=[0, 1, 8, 11],\n        vit_features=384,\n        use_vit_only=False,\n        use_readout=\"ignore\",\n        
start_index=1,\n        enable_attention_hooks=False,\n):\n    pretrained = nn.Module()\n    pretrained.model = model\n    pretrained.model.patch_size = [32, 32]\n    ps = pretrained.model.patch_size[0]\n    if use_vit_only == True:\n        pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation(\"1\"))\n        pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation(\"2\"))\n    else:\n        pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(\n            get_activation(\"1\")\n        )\n        pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(\n            get_activation(\"2\")\n        )\n\n    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation(\"3\"))\n    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation(\"4\"))\n\n    if enable_attention_hooks:\n        pretrained.model.blocks[2].attn.register_forward_hook(get_attention(\"attn_1\"))\n        pretrained.model.blocks[5].attn.register_forward_hook(get_attention(\"attn_2\"))\n        pretrained.model.blocks[8].attn.register_forward_hook(get_attention(\"attn_3\"))\n        pretrained.model.blocks[11].attn.register_forward_hook(get_attention(\"attn_4\"))\n        pretrained.attention = attention\n\n    pretrained.activations = activations\n\n    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)\n\n    if use_vit_only == True:\n        pretrained.act_postprocess1 = nn.Sequential(\n            readout_oper[0],\n            Transpose(1, 2),\n            nn.Unflatten(2, torch.Size([size[0] // ps, size[1] // ps])),\n            nn.Conv2d(\n                in_channels=vit_features,\n                out_channels=features[0],\n                kernel_size=1,\n                stride=1,\n                padding=0,\n            ),\n            nn.ConvTranspose2d(\n                in_channels=features[0],\n                out_channels=features[0],\n                kernel_size=4,\n                stride=4,\n                padding=0,\n                bias=True,\n                dilation=1,\n                groups=1,\n            ),\n        )\n\n        pretrained.act_postprocess2 = nn.Sequential(\n            readout_oper[1],\n            Transpose(1, 2),\n            nn.Unflatten(2, torch.Size([size[0] // ps, size[1] // ps])),\n            nn.Conv2d(\n                in_channels=vit_features,\n                out_channels=features[1],\n                kernel_size=1,\n                stride=1,\n                padding=0,\n            ),\n            nn.ConvTranspose2d(\n                in_channels=features[1],\n                out_channels=features[1],\n                kernel_size=2,\n                stride=2,\n                padding=0,\n                bias=True,\n                dilation=1,\n                groups=1,\n            ),\n        )\n    else:\n        pretrained.act_postprocess1 = nn.Sequential(\n            nn.Identity(), nn.Identity(), nn.Identity()\n        )\n        pretrained.act_postprocess2 = nn.Sequential(\n            nn.Identity(), nn.Identity(), nn.Identity()\n        )\n\n    pretrained.act_postprocess3 = nn.Sequential(\n        readout_oper[2],\n        Transpose(1, 2),\n        nn.Unflatten(2, torch.Size([size[0] // ps, size[1] // ps])),\n        nn.Conv2d(\n            in_channels=vit_features,\n            out_channels=features[2],\n            kernel_size=1,\n            stride=1,\n            padding=0,\n        ),\n    )\n\n    pretrained.act_postprocess4 = nn.Sequential(\n 
       readout_oper[3],\n        Transpose(1, 2),\n        nn.Unflatten(2, torch.Size([size[0] // ps, size[1] // ps])),\n        nn.Conv2d(\n            in_channels=vit_features,\n            out_channels=features[3],\n            kernel_size=1,\n            stride=1,\n            padding=0,\n        ),\n        nn.Conv2d(\n            in_channels=features[3],\n            out_channels=features[3],\n            kernel_size=3,\n            stride=2,\n            padding=1,\n        ),\n    )\n\n    pretrained.model.start_index = start_index\n    pretrained.model.patch_size = [32, 32]\n\n    # We inject this function into the VisionTransformer instances so that\n    # we can use it with interpolated position embeddings without modifying the library source.\n    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)\n\n    # We inject this function into the VisionTransformer instances so that\n    # we can use it with interpolated position embeddings without modifying the library source.\n    pretrained.model._resize_pos_embed = types.MethodType(\n        _resize_pos_embed, pretrained.model\n    )\n\n    return pretrained\n\n\ndef _make_pretrained_vitb_rn50_384(\n        pretrained,\n        use_readout=\"ignore\",\n        hooks=None,\n        use_vit_only=False,\n        enable_attention_hooks=False,\n):\n    # model = timm.create_model(\"vit_base_resnet50_384\", pretrained=pretrained)\n    # model = timm.create_model(\"vit_tiny_r_s16_p8_384\", pretrained=pretrained)\n    model = timm.create_model(\"vit_small_r26_s32_384\", pretrained=pretrained)\n    hooks = [0, 1, 8, 11] if hooks == None else hooks\n    return _make_vit_b_rn50_backbone(\n        model,\n        features=[128, 256, 384, 384],\n        size=[384, 384],\n        hooks=hooks,\n        use_vit_only=use_vit_only,\n        use_readout=use_readout,\n        enable_attention_hooks=enable_attention_hooks,\n    )\n\n\ndef _make_pretrained_vit_tiny(\n        pretrained,\n        use_readout=\"ignore\",\n        hooks=None,\n        use_vit_only=False,\n        enable_attention_hooks=False,\n):\n    # model = timm.create_model(\"vit_base_resnet50_384\", pretrained=pretrained)\n    model = timm.create_model(\"vit_tiny_r_s16_p8_384\", pretrained=pretrained)\n    hooks = [0, 1, 8, 11] if hooks == None else hooks\n    # NOTE: _make_vit_tiny_backbone is not defined in this file, so calling this helper will raise a NameError.\n    return _make_vit_tiny_backbone(\n        model,\n        features=[256, 512, 768, 768],\n        size=[384, 384],\n        hooks=hooks,\n        use_vit_only=use_vit_only,\n        use_readout=use_readout,\n        enable_attention_hooks=enable_attention_hooks,\n    )\n\n\ndef _make_pretrained_vitl16_384(\n        pretrained, use_readout=\"ignore\", hooks=None, enable_attention_hooks=False\n):\n    model = timm.create_model(\"vit_large_patch16_384\", pretrained=pretrained)\n\n    hooks = [5, 11, 17, 23] if hooks == None else hooks\n    return _make_vit_b16_backbone(\n        model,\n        features=[256, 512, 1024, 1024],\n        hooks=hooks,\n        vit_features=1024,\n        use_readout=use_readout,\n        enable_attention_hooks=enable_attention_hooks,\n    )\n\n\ndef _make_pretrained_vitb16_384(\n        pretrained, use_readout=\"ignore\", hooks=None, enable_attention_hooks=False\n):\n    model = timm.create_model(\"vit_base_patch16_384\", pretrained=pretrained)\n\n    hooks = [2, 5, 8, 11] if hooks == None else hooks\n    return _make_vit_b16_backbone(\n        model,\n        features=[96, 192, 384, 768],\n        hooks=hooks,\n        
use_readout=use_readout,\n        enable_attention_hooks=enable_attention_hooks,\n    )\n\n\ndef _make_pretrained_deitb16_384(\n        pretrained, use_readout=\"ignore\", hooks=None, enable_attention_hooks=False\n):\n    model = timm.create_model(\"vit_deit_base_patch16_384\", pretrained=pretrained)\n\n    hooks = [2, 5, 8, 11] if hooks == None else hooks\n    return _make_vit_b16_backbone(\n        model,\n        features=[96, 192, 384, 768],\n        hooks=hooks,\n        use_readout=use_readout,\n        enable_attention_hooks=enable_attention_hooks,\n    )\n\n\ndef _make_pretrained_deitb16_distil_384(\n        pretrained, use_readout=\"ignore\", hooks=None, enable_attention_hooks=False\n):\n    model = timm.create_model(\n        \"vit_deit_base_distilled_patch16_384\", pretrained=pretrained\n    )\n\n    hooks = [2, 5, 8, 11] if hooks == None else hooks\n    return _make_vit_b16_backbone(\n        model,\n        features=[96, 192, 384, 768],\n        hooks=hooks,\n        use_readout=use_readout,\n        start_index=2,\n        enable_attention_hooks=enable_attention_hooks,\n    )\n"
  },
  {
    "path": "mvtracker/models/core/dynamic3dgs/LICENSE.md",
    "content": "## Notes on license:\n\nThe code in this repository (except in external.py) is licensed under the MIT licence.\n\nHowever, for this code to run it uses the cuda rasterizer code\nfrom [here](https://github.com/JonathonLuiten/diff-gaussian-rasterization-w-depth),\nas well as various code in [external.py](./external.py) which has been taken or adapted\nfrom [here](https://github.com/graphdeco-inria/gaussian-splatting).\nThese are required for this project, and for these a much more restrictive license from Inria applies which can be\nfound [here](https://github.com/graphdeco-inria/gaussian-splatting/blob/main/LICENSE.md).\nThis requires express permission (licensing agreements) from Inria for use in any commercial application, but is\notherwise freely freely distributed for research and experimentation.\n\nMIT License for the code in this repository where it applies (see above) is below:\n\n## License:\n\nCopyright (c) 2023 Jonathon Luiten\n\nPermission is hereby granted, free of charge, to any person obtaining a copy of this software and associated\ndocumentation files (the “Software”), to deal in the Software without restriction, including without limitation the\nrights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit\npersons to whom the Software is furnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all copies or substantial portions of the\nSoftware.\n\nTHE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE\nWARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR\nCOPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR\nOTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.\n"
  },
  {
    "path": "mvtracker/models/core/dynamic3dgs/colormap.py",
    "content": "import numpy as np\n\ncolormap = np.array([\n    # 0     ,         0,         0,\n    0.5020, 0, 0,\n    0, 0.5020, 0,\n    0.5020, 0.5020, 0,\n    0, 0, 0.5020,\n    0.5020, 0, 0.5020,\n    0, 0.5020, 0.5020,\n    # 0.5020,    0.5020,    0.5020,\n    0.2510, 0, 0,\n    0.7529, 0, 0,\n    0.2510, 0.5020, 0,\n    0.7529, 0.5020, 0,\n    0.2510, 0, 0.5020,\n    0.7529, 0, 0.5020,\n    0.2510, 0.5020, 0.5020,\n    0.7529, 0.5020, 0.5020,\n    0, 0.2510, 0,\n    0.5020, 0.2510, 0,\n    0, 0.7529, 0,\n    0.5020, 0.7529, 0,\n    0, 0.2510, 0.5020,\n    0.5020, 0.2510, 0.5020,\n    0, 0.7529, 0.5020,\n    0.5020, 0.7529, 0.5020,\n    0.2510, 0.2510, 0,\n    0.7529, 0.2510, 0,\n    0.2510, 0.7529, 0,\n    0.7529, 0.7529, 0,\n    0.2510, 0.2510, 0.5020,\n    0.7529, 0.2510, 0.5020,\n    0.2510, 0.7529, 0.5020,\n    0.7529, 0.7529, 0.5020,\n    0, 0, 0.2510,\n    0.5020, 0, 0.2510,\n    0, 0.5020, 0.2510,\n    0.5020, 0.5020, 0.2510,\n    0, 0, 0.7529,\n    0.5020, 0, 0.7529,\n    0, 0.5020, 0.7529,\n    0.5020, 0.5020, 0.7529,\n    0.2510, 0, 0.2510,\n    0.7529, 0, 0.2510,\n    0.2510, 0.5020, 0.2510,\n    0.7529, 0.5020, 0.2510,\n    0.2510, 0, 0.7529,\n    0.7529, 0, 0.7529,\n    0.2510, 0.5020, 0.7529,\n    0.7529, 0.5020, 0.7529,\n    0, 0.2510, 0.2510,\n    0.5020, 0.2510, 0.2510,\n    0, 0.7529, 0.2510,\n    0.5020, 0.7529, 0.2510,\n    0, 0.2510, 0.7529,\n    0.5020, 0.2510, 0.7529,\n    0, 0.7529, 0.7529,\n    0.5020, 0.7529, 0.7529,\n    # 0.2510,    0.2510,    0.2510,\n    0.7529, 0.2510, 0.2510,\n    0.2510, 0.7529, 0.2510,\n    0.7529, 0.7529, 0.2510,\n    0.2510, 0.2510, 0.7529,\n    0.7529, 0.2510, 0.7529,\n    0.2510, 0.7529, 0.7529,\n    # 0.7529,    0.7529,    0.7529,\n    0.1255, 0, 0,\n    0.6275, 0, 0,\n    0.1255, 0.5020, 0,\n    0.6275, 0.5020, 0,\n    0.1255, 0, 0.5020,\n    0.6275, 0, 0.5020,\n    0.1255, 0.5020, 0.5020,\n    0.6275, 0.5020, 0.5020,\n    0.3765, 0, 0,\n    0.8784, 0, 0,\n    0.3765, 0.5020, 0,\n    0.8784, 0.5020, 0,\n    0.3765, 0, 0.5020,\n    0.8784, 0, 0.5020,\n    0.3765, 0.5020, 0.5020,\n    0.8784, 0.5020, 0.5020,\n    0.1255, 0.2510, 0,\n    0.6275, 0.2510, 0,\n    0.1255, 0.7529, 0,\n    0.6275, 0.7529, 0,\n    0.1255, 0.2510, 0.5020,\n    0.6275, 0.2510, 0.5020,\n    0.1255, 0.7529, 0.5020,\n    0.6275, 0.7529, 0.5020,\n    0.3765, 0.2510, 0,\n    0.8784, 0.2510, 0,\n    0.3765, 0.7529, 0,\n    0.8784, 0.7529, 0,\n    0.3765, 0.2510, 0.5020,\n    0.8784, 0.2510, 0.5020,\n    0.3765, 0.7529, 0.5020,\n    0.8784, 0.7529, 0.5020,\n    0.1255, 0, 0.2510,\n    0.6275, 0, 0.2510,\n    0.1255, 0.5020, 0.2510,\n    0.6275, 0.5020, 0.2510,\n    0.1255, 0, 0.7529,\n    0.6275, 0, 0.7529,\n    0.1255, 0.5020, 0.7529,\n    0.6275, 0.5020, 0.7529,\n    0.3765, 0, 0.2510,\n    0.8784, 0, 0.2510,\n    0.3765, 0.5020, 0.2510,\n    0.8784, 0.5020, 0.2510,\n    0.3765, 0, 0.7529,\n    0.8784, 0, 0.7529,\n    0.3765, 0.5020, 0.7529,\n    0.8784, 0.5020, 0.7529,\n    0.1255, 0.2510, 0.2510,\n    0.6275, 0.2510, 0.2510,\n    0.1255, 0.7529, 0.2510,\n    0.6275, 0.7529, 0.2510,\n    0.1255, 0.2510, 0.7529,\n    0.6275, 0.2510, 0.7529,\n    0.1255, 0.7529, 0.7529,\n    0.6275, 0.7529, 0.7529,\n    0.3765, 0.2510, 0.2510,\n    0.8784, 0.2510, 0.2510,\n    0.3765, 0.7529, 0.2510,\n    0.8784, 0.7529, 0.2510,\n    0.3765, 0.2510, 0.7529,\n    0.8784, 0.2510, 0.7529,\n    0.3765, 0.7529, 0.7529,\n    0.8784, 0.7529, 0.7529,\n    0, 0.1255, 0,\n    0.5020, 0.1255, 0,\n    0, 0.6275, 0,\n    0.5020, 0.6275, 0,\n    0, 0.1255, 0.5020,\n    0.5020, 0.1255, 
0.5020,\n    0, 0.6275, 0.5020,\n    0.5020, 0.6275, 0.5020,\n    0.2510, 0.1255, 0,\n    0.7529, 0.1255, 0,\n    0.2510, 0.6275, 0,\n    0.7529, 0.6275, 0,\n    0.2510, 0.1255, 0.5020,\n    0.7529, 0.1255, 0.5020,\n    0.2510, 0.6275, 0.5020,\n    0.7529, 0.6275, 0.5020,\n    0, 0.3765, 0,\n    0.5020, 0.3765, 0,\n    0, 0.8784, 0,\n    0.5020, 0.8784, 0,\n    0, 0.3765, 0.5020,\n    0.5020, 0.3765, 0.5020,\n    0, 0.8784, 0.5020,\n    0.5020, 0.8784, 0.5020,\n    0.2510, 0.3765, 0,\n    0.7529, 0.3765, 0,\n    0.2510, 0.8784, 0,\n    0.7529, 0.8784, 0,\n    0.2510, 0.3765, 0.5020,\n    0.7529, 0.3765, 0.5020,\n    0.2510, 0.8784, 0.5020,\n    0.7529, 0.8784, 0.5020,\n    0, 0.1255, 0.2510,\n    0.5020, 0.1255, 0.2510,\n    0, 0.6275, 0.2510,\n    0.5020, 0.6275, 0.2510,\n    0, 0.1255, 0.7529,\n    0.5020, 0.1255, 0.7529,\n    0, 0.6275, 0.7529,\n    0.5020, 0.6275, 0.7529,\n    0.2510, 0.1255, 0.2510,\n    0.7529, 0.1255, 0.2510,\n    0.2510, 0.6275, 0.2510,\n    0.7529, 0.6275, 0.2510,\n    0.2510, 0.1255, 0.7529,\n    0.7529, 0.1255, 0.7529,\n    0.2510, 0.6275, 0.7529,\n    0.7529, 0.6275, 0.7529,\n    0, 0.3765, 0.2510,\n    0.5020, 0.3765, 0.2510,\n    0, 0.8784, 0.2510,\n    0.5020, 0.8784, 0.2510,\n    0, 0.3765, 0.7529,\n    0.5020, 0.3765, 0.7529,\n    0, 0.8784, 0.7529,\n    0.5020, 0.8784, 0.7529,\n    0.2510, 0.3765, 0.2510,\n    0.7529, 0.3765, 0.2510,\n    0.2510, 0.8784, 0.2510,\n    0.7529, 0.8784, 0.2510,\n    0.2510, 0.3765, 0.7529,\n    0.7529, 0.3765, 0.7529,\n    0.2510, 0.8784, 0.7529,\n    0.7529, 0.8784, 0.7529,\n    0.1255, 0.1255, 0,\n    0.6275, 0.1255, 0,\n    0.1255, 0.6275, 0,\n    0.6275, 0.6275, 0,\n    0.1255, 0.1255, 0.5020,\n    0.6275, 0.1255, 0.5020,\n    0.1255, 0.6275, 0.5020,\n    0.6275, 0.6275, 0.5020,\n    0.3765, 0.1255, 0,\n    0.8784, 0.1255, 0,\n    0.3765, 0.6275, 0,\n    0.8784, 0.6275, 0,\n    0.3765, 0.1255, 0.5020,\n    0.8784, 0.1255, 0.5020,\n    0.3765, 0.6275, 0.5020,\n    0.8784, 0.6275, 0.5020,\n    0.1255, 0.3765, 0,\n    0.6275, 0.3765, 0,\n    0.1255, 0.8784, 0,\n    0.6275, 0.8784, 0,\n    0.1255, 0.3765, 0.5020,\n    0.6275, 0.3765, 0.5020,\n    0.1255, 0.8784, 0.5020,\n    0.6275, 0.8784, 0.5020,\n    0.3765, 0.3765, 0,\n    0.8784, 0.3765, 0,\n    0.3765, 0.8784, 0,\n    0.8784, 0.8784, 0,\n    0.3765, 0.3765, 0.5020,\n    0.8784, 0.3765, 0.5020,\n    0.3765, 0.8784, 0.5020,\n    0.8784, 0.8784, 0.5020,\n    0.1255, 0.1255, 0.2510,\n    0.6275, 0.1255, 0.2510,\n    0.1255, 0.6275, 0.2510,\n    0.6275, 0.6275, 0.2510,\n    0.1255, 0.1255, 0.7529,\n    0.6275, 0.1255, 0.7529,\n    0.1255, 0.6275, 0.7529,\n    0.6275, 0.6275, 0.7529,\n    0.3765, 0.1255, 0.2510,\n    0.8784, 0.1255, 0.2510,\n    0.3765, 0.6275, 0.2510,\n    0.8784, 0.6275, 0.2510,\n    0.3765, 0.1255, 0.7529,\n    0.8784, 0.1255, 0.7529,\n    0.3765, 0.6275, 0.7529,\n    0.8784, 0.6275, 0.7529,\n    0.1255, 0.3765, 0.2510,\n    0.6275, 0.3765, 0.2510,\n    0.1255, 0.8784, 0.2510,\n    0.6275, 0.8784, 0.2510,\n    0.1255, 0.3765, 0.7529,\n    0.6275, 0.3765, 0.7529,\n    0.1255, 0.8784, 0.7529,\n    0.6275, 0.8784, 0.7529,\n    0.3765, 0.3765, 0.2510,\n    0.8784, 0.3765, 0.2510,\n    0.3765, 0.8784, 0.2510,\n    0.8784, 0.8784, 0.2510,\n    0.3765, 0.3765, 0.7529,\n    0.8784, 0.3765, 0.7529,\n    0.3765, 0.8784, 0.7529,\n    0.8784, 0.8784, 0.7529,\n    # 1.0,       1.0,       1.0,\n]).reshape(-1, 3)\n"
  },
  {
    "path": "mvtracker/models/core/dynamic3dgs/export_depths_from_pretrained_checkpoint.py",
    "content": "import json\nimport os\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nfrom PIL import Image\nfrom diff_gaussian_rasterization import GaussianRasterizer as Renderer\nfrom tqdm import tqdm\n\nfrom .helpers import setup_camera\n\n\ndef load_scene_data(params_path, seg_as_col=False):\n    \"\"\"Load 3D scene data from file.\"\"\"\n    params = dict(np.load(params_path, allow_pickle=True))\n    params = {k: torch.tensor(v).cuda().float() for k, v in params.items()}\n    is_fg = params['seg_colors'][:, 0] > 0.5\n    scene_data = []\n    for t in range(len(params['means3D'])):\n        rendervar = {\n            'means3D': params['means3D'][t],\n            'colors_precomp': params['rgb_colors'][t] if not seg_as_col else params['seg_colors'],\n            'rotations': torch.nn.functional.normalize(params['unnorm_rotations'][t]),\n            'opacities': torch.sigmoid(params['logit_opacities']),\n            'scales': torch.exp(params['log_scales']),\n            'means2D': torch.zeros_like(params['means3D'][0], device=\"cuda\")\n        }\n        scene_data.append(rendervar)\n    return scene_data, is_fg\n\n\ndef render(w, h, k, w2c, timestep_data, near=0.01, far=100.0):\n    \"\"\"Render scene using Gaussian Rasterization.\"\"\"\n    with torch.no_grad():\n        cam = setup_camera(w, h, k, w2c, near, far)\n        im, _, depth = Renderer(raster_settings=cam)(**timestep_data)\n        return im, depth\n\n\ndef export_depth(scene_root, output_root, checkpoint_path):\n    scene_data, is_fg = load_scene_data(os.path.join(checkpoint_path, \"params.npz\"))\n    md_train = json.load(open(os.path.join(scene_root, \"train_meta.json\"), \"r\"))\n    md_test = json.load(open(os.path.join(scene_root, \"test_meta.json\"), \"r\"))\n\n    views = sorted(list(set(md_train[\"cam_id\"][0]) | set(md_test[\"cam_id\"][0])))\n    assert list(range(31)) == views, \"We expect exactly 31 views: from 0 to 30.\"\n    n_frames = len(md_train['fn'])\n    n_views = len(views)\n\n    # Check that the selected views are in the training set\n    view_paths = []\n    for view_idx in views:\n        view_path = scene_root / \"ims\" / f\"{view_idx}\"\n        assert view_path.exists()\n        view_paths.append(view_path)\n    frame_paths = [sorted(view_path.glob(\"*.jpg\")) for view_path in view_paths]\n    assert all(len(frame_paths[v]) == n_frames for v in range(n_views))\n    assert len(scene_data) == n_frames\n\n    # Load the camera parameters\n    fx, fy, cx, cy, extrinsics = [], [], [], [], []\n    for view_idx in views:\n        fx_current, fy_current, cx_current, cy_current, extrinsics_current = [], [], [], [], []\n        for t in range(n_frames):\n            if view_idx in md_train['cam_id'][t]:\n                md = md_train\n            elif view_idx in md_test['cam_id'][t]:\n                md = md_test\n            else:\n                raise ValueError(f\"Camera {view_idx} not found in any of the meta files\")\n\n            view_idx_in_array = md['cam_id'][t].index(view_idx)\n            k = md['k'][t][view_idx_in_array]\n            w2c = np.array(md['w2c'][t][view_idx_in_array])\n\n            fx_current.append(k[0][0])\n            fy_current.append(k[1][1])\n            cx_current.append(k[0][2])\n            cy_current.append(k[1][2])\n            extrinsics_current.append(w2c)\n\n        assert all(np.equal(fx_current[0], fx_current[t]).all() for t in range(1, n_frames))\n        assert all(np.equal(fy_current[0], fy_current[t]).all() for t in range(1, n_frames))\n      
  assert all(np.equal(cx_current[0], cx_current[t]).all() for t in range(1, n_frames))\n        assert all(np.equal(cy_current[0], cy_current[t]).all() for t in range(1, n_frames))\n        assert all(np.equal(extrinsics_current[0], extrinsics_current[t]).all() for t in range(1, n_frames))\n\n        fx.append(fx_current[0])\n        fy.append(fy_current[0])\n        cx.append(cx_current[0])\n        cy.append(cy_current[0])\n        extrinsics.append(extrinsics_current[0])\n\n    fx = torch.tensor(fx).float()\n    fy = torch.tensor(fy).float()\n    cx = torch.tensor(cx).float()\n    cy = torch.tensor(cy).float()\n    k = torch.eye(3).float()[None].repeat(n_views, 1, 1)\n    k[:, 0, 0] = fx\n    k[:, 1, 1] = fy\n    k[:, 0, 2] = cx\n    k[:, 1, 2] = cy\n    extrinsics = torch.from_numpy(np.stack(extrinsics)).float()\n\n    # Render and save the depths\n    os.makedirs(output_root, exist_ok=True)\n    rgbs = np.stack([\n        np.stack([\n            np.array(Image.open(frame_paths[v][t]))\n            for t in range(n_frames)\n        ])\n        for v in range(n_views)\n    ])\n    h, w = rgbs.shape[2], rgbs.shape[3]\n    for v, view_idx in enumerate(views):\n        depths = []\n        for t in range(n_frames):\n            im, depth = render(w, h, k[v].numpy(), extrinsics[v].numpy(), scene_data[t])\n            depths.append(depth.cpu().numpy()[0])\n        depths = np.stack(depths)\n        np.save(output_root / f\"depths_{view_idx:02d}.npy\", depths)\n\n\nif __name__ == \"__main__\":\n    print(\"Exporting depths from pretrained checkpoints\")\n    for sequence_name in tqdm([\"basketball\", \"boxes\", \"football\", \"juggle\", \"softball\", \"tennis\"]):\n        scene_root = Path(f\"./datasets/panoptic_d3dgs/{sequence_name}\")\n        output_path = Path(f\"./datasets/panoptic_d3dgs/{sequence_name}/dynamic3dgs_depth\")\n        checkpoint_path = Path(f\"./dynamic3dgs/output/pretrained/{sequence_name}\")\n        export_depth(scene_root, output_path, checkpoint_path)\n"
  },
  {
    "path": "mvtracker/models/core/dynamic3dgs/external.py",
    "content": "\"\"\"\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use\n# under the terms of the LICENSE.md file found here:\n# https://github.com/graphdeco-inria/gaussian-splatting/blob/main/LICENSE.md\n#\n# For inquiries contact  george.drettakis@inria.fr\n\n#######################################################################################################################\n##### NOTE: CODE IN THIS FILE IS NOT INCLUDED IN THE OVERALL PROJECT'S MIT LICENSE #####\n##### USE OF THIS CODE FOLLOWS THE COPYRIGHT NOTICE ABOVE #####\n#######################################################################################################################\n\"\"\"\n\nimport torch\nimport torch.nn.functional as func\nfrom math import exp\nfrom torch.autograd import Variable\n\n\ndef build_rotation(q):\n    norm = torch.sqrt(q[:, 0] * q[:, 0] + q[:, 1] * q[:, 1] + q[:, 2] * q[:, 2] + q[:, 3] * q[:, 3])\n    q = q / norm[:, None]\n    rot = torch.zeros((q.size(0), 3, 3), device='cuda')\n    r = q[:, 0]\n    x = q[:, 1]\n    y = q[:, 2]\n    z = q[:, 3]\n    rot[:, 0, 0] = 1 - 2 * (y * y + z * z)\n    rot[:, 0, 1] = 2 * (x * y - r * z)\n    rot[:, 0, 2] = 2 * (x * z + r * y)\n    rot[:, 1, 0] = 2 * (x * y + r * z)\n    rot[:, 1, 1] = 1 - 2 * (x * x + z * z)\n    rot[:, 1, 2] = 2 * (y * z - r * x)\n    rot[:, 2, 0] = 2 * (x * z - r * y)\n    rot[:, 2, 1] = 2 * (y * z + r * x)\n    rot[:, 2, 2] = 1 - 2 * (x * x + y * y)\n    return rot\n\n\ndef calc_mse(img1, img2):\n    return ((img1 - img2) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)\n\n\ndef calc_psnr(img1, img2):\n    mse = ((img1 - img2) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)\n    return 20 * torch.log10(1.0 / torch.sqrt(mse))\n\n\ndef gaussian(window_size, sigma):\n    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])\n    return gauss / gauss.sum()\n\n\ndef create_window(window_size, channel):\n    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)\n    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)\n    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())\n    return window\n\n\ndef calc_ssim(img1, img2, window_size=11, size_average=True):\n    channel = img1.size(-3)\n    window = create_window(window_size, channel)\n\n    if img1.is_cuda:\n        window = window.cuda(img1.get_device())\n    window = window.type_as(img1)\n\n    return _ssim(img1, img2, window, window_size, channel, size_average)\n\n\ndef _ssim(img1, img2, window, window_size, channel, size_average=True):\n    mu1 = func.conv2d(img1, window, padding=window_size // 2, groups=channel)\n    mu2 = func.conv2d(img2, window, padding=window_size // 2, groups=channel)\n\n    mu1_sq = mu1.pow(2)\n    mu2_sq = mu2.pow(2)\n    mu1_mu2 = mu1 * mu2\n\n    sigma1_sq = func.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq\n    sigma2_sq = func.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq\n    sigma12 = func.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2\n\n    c1 = 0.01 ** 2\n    c2 = 0.03 ** 2\n\n    ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / ((mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2))\n\n    if size_average:\n        return ssim_map.mean()\n    else:\n        return 
ssim_map.mean(1).mean(1).mean(1)\n\n\ndef accumulate_mean2d_gradient(variables):\n    variables['means2D_gradient_accum'][variables['seen']] += torch.norm(\n        variables['means2D'].grad[variables['seen'], :2], dim=-1)\n    variables['denom'][variables['seen']] += 1\n    return variables\n\n\ndef update_params_and_optimizer(new_params, params, optimizer):\n    for k, v in new_params.items():\n        group = [x for x in optimizer.param_groups if x[\"name\"] == k][0]\n        stored_state = optimizer.state.get(group['params'][0], None)\n\n        stored_state[\"exp_avg\"] = torch.zeros_like(v)\n        stored_state[\"exp_avg_sq\"] = torch.zeros_like(v)\n        del optimizer.state[group['params'][0]]\n\n        group[\"params\"][0] = torch.nn.Parameter(v.requires_grad_(True))\n        optimizer.state[group['params'][0]] = stored_state\n        params[k] = group[\"params\"][0]\n    return params\n\n\ndef cat_params_to_optimizer(new_params, params, optimizer):\n    for k, v in new_params.items():\n        group = [g for g in optimizer.param_groups if g['name'] == k][0]\n        stored_state = optimizer.state.get(group['params'][0], None)\n        if stored_state is not None:\n            stored_state[\"exp_avg\"] = torch.cat((stored_state[\"exp_avg\"], torch.zeros_like(v)), dim=0)\n            stored_state[\"exp_avg_sq\"] = torch.cat((stored_state[\"exp_avg_sq\"], torch.zeros_like(v)), dim=0)\n            del optimizer.state[group['params'][0]]\n            group[\"params\"][0] = torch.nn.Parameter(torch.cat((group[\"params\"][0], v), dim=0).requires_grad_(True))\n            optimizer.state[group['params'][0]] = stored_state\n            params[k] = group[\"params\"][0]\n        else:\n            group[\"params\"][0] = torch.nn.Parameter(torch.cat((group[\"params\"][0], v), dim=0).requires_grad_(True))\n            params[k] = group[\"params\"][0]\n    return params\n\n\ndef remove_points(to_remove, params, variables, optimizer):\n    to_keep = ~to_remove\n    keys = [k for k in params.keys() if k not in ['cam_m', 'cam_c']]\n    for k in keys:\n        group = [g for g in optimizer.param_groups if g['name'] == k][0]\n        stored_state = optimizer.state.get(group['params'][0], None)\n        if stored_state is not None:\n            stored_state[\"exp_avg\"] = stored_state[\"exp_avg\"][to_keep]\n            stored_state[\"exp_avg_sq\"] = stored_state[\"exp_avg_sq\"][to_keep]\n            del optimizer.state[group['params'][0]]\n            group[\"params\"][0] = torch.nn.Parameter((group[\"params\"][0][to_keep].requires_grad_(True)))\n            optimizer.state[group['params'][0]] = stored_state\n            params[k] = group[\"params\"][0]\n        else:\n            group[\"params\"][0] = torch.nn.Parameter(group[\"params\"][0][to_keep].requires_grad_(True))\n            params[k] = group[\"params\"][0]\n    variables['means2D_gradient_accum'] = variables['means2D_gradient_accum'][to_keep]\n    variables['denom'] = variables['denom'][to_keep]\n    variables['max_2D_radius'] = variables['max_2D_radius'][to_keep]\n    return params, variables\n\n\ndef inverse_sigmoid(x):\n    return torch.log(x / (1 - x))\n\n\ndef densify(params, variables, optimizer, i):\n    if i <= 5000:\n        variables = accumulate_mean2d_gradient(variables)\n        grad_thresh = 0.0002\n        if (i >= 500) and (i % 100 == 0):\n            grads = variables['means2D_gradient_accum'] / variables['denom']\n            grads[grads.isnan()] = 0.0\n            to_clone = torch.logical_and(grads >= grad_thresh, 
(\n                    torch.max(torch.exp(params['log_scales']), dim=1).values <= 0.01 * variables['scene_radius']))\n            new_params = {k: v[to_clone] for k, v in params.items() if k not in ['cam_m', 'cam_c']}\n            params = cat_params_to_optimizer(new_params, params, optimizer)\n            num_pts = params['means3D'].shape[0]\n\n            padded_grad = torch.zeros(num_pts, device=\"cuda\")\n            padded_grad[:grads.shape[0]] = grads\n            to_split = torch.logical_and(padded_grad >= grad_thresh,\n                                         torch.max(torch.exp(params['log_scales']), dim=1).values > 0.01 * variables[\n                                             'scene_radius'])\n            n = 2  # number to split into\n            new_params = {k: v[to_split].repeat(n, 1) for k, v in params.items() if k not in ['cam_m', 'cam_c']}\n            stds = torch.exp(params['log_scales'])[to_split].repeat(n, 1)\n            means = torch.zeros((stds.size(0), 3), device=\"cuda\")\n            samples = torch.normal(mean=means, std=stds)\n            rots = build_rotation(params['unnorm_rotations'][to_split]).repeat(n, 1, 1)\n            new_params['means3D'] += torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1)\n            new_params['log_scales'] = torch.log(torch.exp(new_params['log_scales']) / (0.8 * n))\n            params = cat_params_to_optimizer(new_params, params, optimizer)\n            num_pts = params['means3D'].shape[0]\n\n            variables['means2D_gradient_accum'] = torch.zeros(num_pts, device=\"cuda\")\n            variables['denom'] = torch.zeros(num_pts, device=\"cuda\")\n            variables['max_2D_radius'] = torch.zeros(num_pts, device=\"cuda\")\n            to_remove = torch.cat((to_split, torch.zeros(n * to_split.sum(), dtype=torch.bool, device=\"cuda\")))\n            params, variables = remove_points(to_remove, params, variables, optimizer)\n\n            remove_threshold = 0.25 if i == 5000 else 0.005\n            to_remove = (torch.sigmoid(params['logit_opacities']) < remove_threshold).squeeze()\n            if i >= 3000:\n                big_points_ws = torch.exp(params['log_scales']).max(dim=1).values > 0.1 * variables['scene_radius']\n                to_remove = torch.logical_or(to_remove, big_points_ws)\n            params, variables = remove_points(to_remove, params, variables, optimizer)\n\n            torch.cuda.empty_cache()\n\n        if i > 0 and i % 3000 == 0:\n            new_params = {'logit_opacities': inverse_sigmoid(torch.ones_like(params['logit_opacities']) * 0.01)}\n            params = update_params_and_optimizer(new_params, params, optimizer)\n\n    return params, variables\n"
  },
  {
    "path": "mvtracker/models/core/dynamic3dgs/helpers.py",
    "content": "import os\n\nimport numpy as np\nimport open3d as o3d\nimport torch\nfrom diff_gaussian_rasterization import GaussianRasterizationSettings as Camera\n\n\ndef setup_camera(w, h, k, w2c, near=0.01, far=100):\n    fx, fy, cx, cy = k[0][0], k[1][1], k[0][2], k[1][2]\n    w2c = torch.tensor(w2c).cuda().float()\n    cam_center = torch.inverse(w2c)[:3, 3]\n    w2c = w2c.unsqueeze(0).transpose(1, 2)\n    opengl_proj = torch.tensor([[2 * fx / w, 0.0, -(w - 2 * cx) / w, 0.0],\n                                [0.0, 2 * fy / h, -(h - 2 * cy) / h, 0.0],\n                                [0.0, 0.0, far / (far - near), -(far * near) / (far - near)],\n                                [0.0, 0.0, 1.0, 0.0]]).cuda().float().unsqueeze(0).transpose(1, 2)\n    full_proj = w2c.bmm(opengl_proj)\n    cam = Camera(\n        image_height=h,\n        image_width=w,\n        tanfovx=w / (2 * fx),\n        tanfovy=h / (2 * fy),\n        bg=torch.tensor([0, 0, 0], dtype=torch.float32, device=\"cuda\"),\n        scale_modifier=1.0,\n        viewmatrix=w2c,\n        projmatrix=full_proj,\n        sh_degree=0,\n        campos=cam_center,\n        prefiltered=False\n    )\n    return cam\n\n\ndef params2rendervar(params):\n    rendervar = {\n        'means3D': params['means3D'],\n        'colors_precomp': params['rgb_colors'],\n        'rotations': torch.nn.functional.normalize(params['unnorm_rotations']),\n        'opacities': torch.sigmoid(params['logit_opacities']),\n        'scales': torch.exp(params['log_scales']),\n        'means2D': torch.zeros_like(params['means3D'], requires_grad=True, device=\"cuda\") + 0\n    }\n    return rendervar\n\n\ndef l1_loss_v1(x, y):\n    return torch.abs((x - y)).mean()\n\n\ndef l1_loss_v2(x, y):\n    return (torch.abs(x - y).sum(-1)).mean()\n\n\ndef weighted_l2_loss_v1(x, y, w):\n    return torch.sqrt(((x - y) ** 2) * w + 1e-20).mean()\n\n\ndef weighted_l2_loss_v2(x, y, w):\n    return torch.sqrt(((x - y) ** 2).sum(-1) * w + 1e-20).mean()\n\n\ndef quat_mult(q1, q2):\n    w1, x1, y1, z1 = q1.T\n    w2, x2, y2, z2 = q2.T\n    w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2\n    x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2\n    y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2\n    z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2\n    return torch.stack([w, x, y, z]).T\n\n\ndef o3d_knn(pts, num_knn):\n    indices = []\n    sq_dists = []\n    pcd = o3d.geometry.PointCloud()\n    pcd.points = o3d.utility.Vector3dVector(np.ascontiguousarray(pts, np.float64))\n    pcd_tree = o3d.geometry.KDTreeFlann(pcd)\n    for p in pcd.points:\n        [_, i, d] = pcd_tree.search_knn_vector_3d(p, num_knn + 1)\n        indices.append(i[1:])\n        sq_dists.append(d[1:])\n    return np.array(sq_dists), np.array(indices)\n\n\ndef params2cpu(params, is_initial_timestep):\n    if is_initial_timestep:\n        res = {k: v.detach().cpu().contiguous().numpy() for k, v in params.items()}\n    else:\n        res = {k: v.detach().cpu().contiguous().numpy() for k, v in params.items() if\n               k in ['means3D', 'rgb_colors', 'unnorm_rotations']}\n    return res\n\n\ndef save_params(output_params, seq, exp):\n    to_save = {}\n    for k in output_params[0].keys():\n        if k in output_params[1].keys():\n            to_save[k] = np.stack([params[k] for params in output_params])\n        else:\n            to_save[k] = output_params[0][k]\n    os.makedirs(f\"./output/{exp}/{seq}\", exist_ok=True)\n    np.savez(f\"./output/{exp}/{seq}/params\", **to_save)\n"
  },
  {
    "path": "mvtracker/models/core/dynamic3dgs/merge_tapvid3d_per_camera_annotations.py",
    "content": "import json\nimport os\nimport warnings\nfrom pathlib import Path\n\nimport matplotlib\nimport numpy as np\nimport rerun as rr\nimport torch\nfrom PIL import Image\nfrom diff_gaussian_rasterization import GaussianRasterizer as Renderer\nfrom tqdm import tqdm\n\nfrom .helpers import setup_camera\nfrom .visualize import log_tracks_to_rerun\n\n\ndef to_homogeneous(x):\n    return np.concatenate([x, np.ones_like(x[..., :1])], axis=-1)\n\n\ndef from_homogeneous(x, assert_homogeneous_part_is_equal_to_1=False, eps=0.001):\n    if assert_homogeneous_part_is_equal_to_1:\n        assert np.allclose(x[..., -1:], 1, atol=eps), f\"Expected homogeneous part to be 1, got {x[..., -1:]}\"\n    return x[..., :-1] / x[..., -1:]\n\n\ndef load_scene_data(params_path, seg_as_col=False):\n    \"\"\"Load 3D scene data from file.\"\"\"\n    params = dict(np.load(params_path, allow_pickle=True))\n    params = {k: torch.tensor(v).cuda().float() for k, v in params.items()}\n    is_fg = params['seg_colors'][:, 0] > 0.5\n    scene_data = []\n    for t in range(len(params['means3D'])):\n        rendervar = {\n            'means3D': params['means3D'][t],\n            'colors_precomp': params['rgb_colors'][t] if not seg_as_col else params['seg_colors'],\n            'rotations': torch.nn.functional.normalize(params['unnorm_rotations'][t]),\n            'opacities': torch.sigmoid(params['logit_opacities']),\n            'scales': torch.exp(params['log_scales']),\n            'means2D': torch.zeros_like(params['means3D'][0], device=\"cuda\")\n        }\n        scene_data.append(rendervar)\n    return scene_data, is_fg\n\n\ndef render(h, w, k, w2c, timestep_data, near=0.01, far=100.0):\n    \"\"\"Render scene using Gaussian Rasterization.\"\"\"\n    with torch.no_grad():\n        cam = setup_camera(w, h, k, w2c, near, far)\n        im, _, depth = Renderer(raster_settings=cam)(**timestep_data)\n        return im, depth\n\n\ndef merge_annotations(\n        scene_root,\n        checkpoint_path,\n        tapvid3d_annotation_paths,\n        nearest_neighbor_distance_threshold_for_visibility=0.015,\n        skip_if_output_already_exists=False,\n\n        assert_query_points_project_to_trajectories_in_tapvid3d_annotation=False,\n\n        rerun_logging=False,\n        rerun_stream_only=False,\n        rerun_views_to_viz=(27, 16, 1),\n\n        rerun_log_rgb=True,\n        rerun_log_d3dgs_rgb=False,\n        rerun_log_d3dgs_depth=False,\n        rerun_log_d3dgs_point_cloud=True,\n        rerun_log_tracks=True,\n        rerun_log_n_skip_t=1,\n):\n    output_annotation_path = scene_root / \"tapvid3d_annotations.npz\"\n    if skip_if_output_already_exists and output_annotation_path.exists():\n        print(f\"Output file {output_annotation_path} already exists, skipping.\")\n        return\n\n    scene_data, is_fg = load_scene_data(os.path.join(checkpoint_path, \"params.npz\"))\n    md_train = json.load(open(os.path.join(scene_root, \"train_meta.json\"), \"r\"))\n    md_test = json.load(open(os.path.join(scene_root, \"test_meta.json\"), \"r\"))\n\n    views = sorted(list(set(md_train[\"cam_id\"][0]) | set(md_test[\"cam_id\"][0])))\n    assert list(range(31)) == views, \"We expect exactly 31 views: from 0 to 30.\"\n    n_frames = len(md_train['fn'])\n    n_views = len(views)\n\n    # Check that the selected views are in the training set\n    view_paths = []\n    for view_idx in views:\n        view_path = scene_root / \"ims\" / f\"{view_idx}\"\n        assert view_path.exists()\n        view_paths.append(view_path)\n    
frame_paths = [sorted(view_path.glob(\"*.jpg\")) for view_path in view_paths]\n    assert all(len(frame_paths[v]) == n_frames for v in range(n_views))\n    assert len(scene_data) == n_frames\n\n    # Load the camera parameters\n    fx, fy, cx, cy, extrinsics = [], [], [], [], []\n    for view_idx in views:\n        fx_current, fy_current, cx_current, cy_current, extrinsics_current = [], [], [], [], []\n        for t in range(n_frames):\n            if view_idx in md_train['cam_id'][t]:\n                md = md_train\n            elif view_idx in md_test['cam_id'][t]:\n                md = md_test\n            else:\n                raise ValueError(f\"Camera {view_idx} not found in any of the meta files\")\n\n            view_idx_in_array = md['cam_id'][t].index(view_idx)\n            k = md['k'][t][view_idx_in_array]\n            w2c = np.array(md['w2c'][t][view_idx_in_array])\n\n            fx_current.append(k[0][0])\n            fy_current.append(k[1][1])\n            cx_current.append(k[0][2])\n            cy_current.append(k[1][2])\n            extrinsics_current.append(w2c)\n\n        assert all(np.equal(fx_current[0], fx_current[t]).all() for t in range(1, n_frames))\n        assert all(np.equal(fy_current[0], fy_current[t]).all() for t in range(1, n_frames))\n        assert all(np.equal(cx_current[0], cx_current[t]).all() for t in range(1, n_frames))\n        assert all(np.equal(cy_current[0], cy_current[t]).all() for t in range(1, n_frames))\n        assert all(np.equal(extrinsics_current[0], extrinsics_current[t]).all() for t in range(1, n_frames))\n\n        fx.append(fx_current[0])\n        fy.append(fy_current[0])\n        cx.append(cx_current[0])\n        cy.append(cy_current[0])\n        extrinsics.append(extrinsics_current[0])\n\n    k = np.eye(3).astype(np.float64)[None].repeat(n_views, 0)\n    k[:, 0, 0] = fx\n    k[:, 1, 1] = fy\n    k[:, 0, 2] = cx\n    k[:, 1, 2] = cy\n    extrinsics = np.stack(extrinsics).astype(np.float64)\n    k_inv = np.linalg.inv(k)\n    extrinsics_inv = np.linalg.inv(extrinsics)\n\n    # Render imgs and depths\n    rgbs = np.stack([\n        np.stack([\n            np.array(Image.open(frame_paths[v][t]))\n            for t in range(n_frames)\n        ])\n        for v in range(n_views)\n    ])\n\n    h, w = rgbs.shape[2], rgbs.shape[3]\n    d3dgs_rgbs = []\n    d3dgs_depths = []\n    for v, view_idx in enumerate(views):\n        for t in range(n_frames):\n            im, depth = render(h, w, k[v], extrinsics[v], scene_data[t])\n            d3dgs_rgbs.append(im.cpu().numpy().transpose(1, 2, 0))\n            d3dgs_depths.append(depth.cpu().numpy()[0])\n    d3dgs_rgbs = np.stack(d3dgs_rgbs).reshape(n_views, n_frames, h, w, 3)\n    d3dgs_depths = np.stack(d3dgs_depths).reshape(n_views, n_frames, h, w)\n\n    assert rgbs.shape == (n_views, n_frames, h, w, 3)\n    assert d3dgs_rgbs.shape == (n_views, n_frames, h, w, 3)\n    assert d3dgs_depths.shape == (n_views, n_frames, h, w)\n\n    # Merge TAP-Vid3D annotations\n    merged_trajectories = []\n    merged_trajectories_pixelspace = []\n    merged_per_view_visibilities = []\n    merged_query_points_3d = []\n    for tapvid3d_annotation_path in tqdm(tapvid3d_annotation_paths):\n        annotation = np.load(tapvid3d_annotation_path)\n        queries_xyt = annotation[\"queries_xyt\"]\n        tracks_XYZ = annotation[\"tracks_XYZ\"]\n        visibility = annotation[\"visibility\"]\n        fx_fy_cx_cy = annotation[\"fx_fy_cx_cy\"]\n        images_jpeg_bytes = annotation[\"images_jpeg_bytes\"]\n\n        _, 
cam_id = os.path.basename(tapvid3d_annotation_path)[:-4].split(\"_\")\n        cam_id = int(cam_id)\n        assert cam_id == views.index(cam_id)\n\n        n_tracks, _ = queries_xyt.shape\n        assert cam_id in views\n        assert queries_xyt.shape == (n_tracks, 3)\n        assert fx_fy_cx_cy.shape == (4,)\n        assert images_jpeg_bytes.shape == (n_frames,)\n        assert tracks_XYZ.shape == (n_frames, n_tracks, 3)\n        assert visibility.shape == (n_frames, n_tracks)\n        assert np.allclose(fx_fy_cx_cy, [fx[cam_id], fy[cam_id], cx[cam_id], cy[cam_id]])\n\n        # Project the tracks to the world space\n        cam_coords_homo = to_homogeneous(tracks_XYZ)\n        world_coords_homo = np.einsum(\"ij,SNj->SNi\", extrinsics_inv[cam_id], cam_coords_homo)\n        world_coords = from_homogeneous(world_coords_homo, assert_homogeneous_part_is_equal_to_1=True)\n\n        # Project query points to 3D to verify we can reproduce the camera space points\n        qp_t = queries_xyt[:, 2].astype(np.int32)\n        qp_xy_pixel = queries_xyt[:, :2].astype(np.float32)\n        qp_depth = np.ones((n_tracks, 1), dtype=np.float32) * np.inf\n        qp_xyz_camera = np.ones((n_tracks, 3), dtype=np.float32) * np.inf\n        qp_xyz_world = np.ones((n_tracks, 3), dtype=np.float32) * np.inf\n        for t in range(n_frames):\n            qp_mask = qp_t == t\n            if qp_mask.sum() == 0:\n                continue\n\n            # V2 depth interpolation\n            x_nearest = qp_xy_pixel[qp_mask, 0].round().astype(np.int32).clip(0, w - 1)\n            y_nearest = qp_xy_pixel[qp_mask, 1].round().astype(np.int32).clip(0, h - 1)\n            depth_nearest = d3dgs_depths[cam_id, t].reshape(-1)[\n                (y_nearest * w + x_nearest).reshape(-1)]\n            depth_nearest = depth_nearest.reshape(-1, 1)\n            qp_depth[qp_mask] = depth_nearest\n\n            qp_xyz_pixel_t = np.concatenate([qp_xy_pixel[qp_mask], np.ones_like(qp_xy_pixel[qp_mask][..., :1])],\n                                            axis=1)\n            qp_xyz_camera_t = np.einsum(\"ij,Nj->Ni\", k_inv[cam_id], qp_xyz_pixel_t) * qp_depth[qp_mask]\n            qp_xyz_world_t = np.einsum(\"ij,Nj->Ni\", extrinsics_inv[cam_id],\n                                       np.concatenate([qp_xyz_camera_t, np.ones_like(qp_xyz_camera_t[..., :1])],\n                                                      axis=1))[:, :3]\n\n            qp_xyz_camera[qp_mask] = qp_xyz_camera_t\n            qp_xyz_world[qp_mask] = qp_xyz_world_t\n\n        assert np.all(np.isfinite(qp_depth))\n        assert np.all(np.isfinite(qp_xyz_camera))\n        assert np.all(np.isfinite(qp_xyz_world))\n\n        # Verify that the query points are close to the tracks in the world space\n        qp_projection_diff = np.linalg.norm(\n            qp_xyz_camera - tracks_XYZ[queries_xyt[:, 2].astype(np.int32), np.arange(n_tracks)], axis=1)\n        repro1 = np.percentile(qp_projection_diff, 80) < 1\n        repro2 = qp_projection_diff.mean() < 0.1\n        if not repro1 or not repro2:\n            warnings.warn(f\"Projecting query points to match tracks in camera space failed. 
\"\n                          f\"Differences: max={qp_projection_diff.max():0.3f}, \"\n                          f\"mean={qp_projection_diff.mean():0.3f}, \"\n                          f\"median={np.percentile(qp_projection_diff, 50):0.3f}, \"\n                          f\"p80={np.percentile(qp_projection_diff, 80):0.3f}\")\n        if assert_query_points_project_to_trajectories_in_tapvid3d_annotation:\n            assert repro1\n            assert repro2\n\n        # Verify that the projected tracks are close to the query points in pixel space\n        cam_coords_per_view = from_homogeneous(np.einsum(\"Vij,SNj->VSNi\", extrinsics, world_coords_homo), True)\n        pixel_coords_per_view = from_homogeneous(np.einsum(\"Vij,VSNj->VSNi\", k, cam_coords_per_view))\n        diff = np.linalg.norm(qp_xy_pixel - pixel_coords_per_view[cam_id][qp_t, np.arange(n_tracks)], axis=-1)\n        repro3 = np.percentile(diff, 80) < 0.1\n        # The xy pixel query from queries_xyz in the raw labels sometimes doesn't match the tracks_XYZ in camera space.\n        # In the merged labels, we will not use the queries_xyz, but just directly work with the tracks_XYZ and their\n        # projections (where pixel-space projections are needed).\n        if not repro3:\n            warnings.warn(f\"Projecting tracks to pixel space to match query points failed. \"\n                          f\"Max diff: {diff.max()}. Mean diff: {diff.mean()}. Median diff: {np.percentile(diff, 50)}. \"\n                          f\"Percentile 80: {np.percentile(diff, 80)}.\")\n        if assert_query_points_project_to_trajectories_in_tapvid3d_annotation:\n            assert repro3\n\n        # import matplotlib.pyplot as plt\n        # plt.imshow(rgbs[v, qp_t[0]])\n        # plt.scatter(qp_xy_pixel[0, 0], qp_xy_pixel[0, 1], color=\"red\")\n        # plt.scatter(pixel_coords_per_view[cam_id][qp_t[0], 0, 0], pixel_coords_per_view[cam_id][qp_t[0], 0, 1], color=\"green\")\n        # plt.show()\n\n        # Compute the distance from the trajectories to their nearest depthmap neighbors\n        depthmap_nearest_neighbor_distance = np.ones((n_views, n_frames, n_tracks), dtype=np.float32) * np.inf\n        k_inv_torch = torch.from_numpy(k_inv).cuda()\n        extrinsics_inv_torch = torch.from_numpy(extrinsics_inv).cuda()\n        pixel_coords_per_view_round_torch = torch.from_numpy(pixel_coords_per_view.round().astype(int)).cuda()\n        world_coords_torch = torch.from_numpy(world_coords).cuda()\n        for v, view_idx in enumerate(views):\n            for t in range(n_frames):\n                # Project depths to world space\n                # Pixel --> Camera --> World\n                pixel_xy = torch.stack(torch.meshgrid(torch.arange(w), torch.arange(h), indexing=\"xy\"), dim=-1).cuda()\n                pixel_xy = pixel_xy.type(k_inv_torch.dtype)\n                pixel_xy_homo = torch.cat([pixel_xy, torch.ones_like(pixel_xy[..., :1])], dim=-1)\n                depthmap_camera_xyz = torch.einsum(\"ij,hwj->hwi\", k_inv_torch[v], pixel_xy_homo)\n                depthmap_camera_xyz *= torch.tensor(d3dgs_depths[v, t], device=\"cuda\", dtype=torch.float32)[..., None]\n                depthmap_camera_xyz_homo = torch.cat(\n                    [depthmap_camera_xyz, torch.ones_like(depthmap_camera_xyz[..., :1])], dim=-1)\n                depthmap_world_xyz_homo = torch.einsum(\"ij,hwj->hwi\", extrinsics_inv_torch[v], depthmap_camera_xyz_homo)\n                depthmap_world_xyz = depthmap_world_xyz_homo[..., :-1] / depthmap_world_xyz_homo[..., 
-1:]\n\n                radius = 3\n                xmin = (pixel_coords_per_view_round_torch[v, t, :, 0] - radius).clip(min=0, max=w - 1 - 2 * radius)\n                ymin = (pixel_coords_per_view_round_torch[v, t, :, 1] - radius).clip(min=0, max=h - 1 - 2 * radius)\n                offsets = torch.arange(0, 2 * radius + 1, device=\"cuda\")\n                x_offsets, y_offsets = torch.meshgrid(offsets, offsets, indexing=\"ij\")\n                x_offsets = x_offsets.reshape(-1)\n                y_offsets = y_offsets.reshape(-1)\n                x_indices = (xmin[:, None] + x_offsets[None, :]).long()\n                y_indices = (ymin[:, None] + y_offsets[None, :]).long()\n                neighbors = depthmap_world_xyz[y_indices, x_indices]\n                nearest_dist = torch.linalg.norm(neighbors - world_coords_torch[t][:, None, :], dim=-1).min(dim=-1)[0]\n                depthmap_nearest_neighbor_distance[v, t, :] = nearest_dist.cpu().numpy()\n        assert not np.isinf(depthmap_nearest_neighbor_distance).any()\n\n        # Compute whether the projected trajectory is within the HxW frame of a view\n        within_frame = ((pixel_coords_per_view[..., 0] >= 0) & (pixel_coords_per_view[..., 0] < w)\n                        & (pixel_coords_per_view[..., 1] >= 0) & (pixel_coords_per_view[..., 1] < h))\n\n        # If nearest neighbor in depth is less than X cm away, consider the point as visible in that view\n        # Furthermore if the projected pixel space location is out of the frame, the point is not visible\n        per_view_visibility = depthmap_nearest_neighbor_distance <= nearest_neighbor_distance_threshold_for_visibility\n        per_view_visibility = per_view_visibility & within_frame\n\n        valid_tracks_mask = (per_view_visibility[cam_id] == visibility).mean(0) > 0.7\n        valid_tracks_indices = np.where(valid_tracks_mask)[0]\n        assert (per_view_visibility[cam_id] == visibility)[:, valid_tracks_mask].mean() > 0.8\n\n        query_points_3d_t = np.max(np.stack([qp_t, per_view_visibility[cam_id].argmax(0)], axis=1), axis=1)\n        query_points_3d_xyz = world_coords[query_points_3d_t, np.arange(n_tracks)]\n        query_points_3d = np.concatenate([query_points_3d_t[:, None], query_points_3d_xyz[:, :]], axis=1)\n\n        merged_trajectories.append(world_coords[:, valid_tracks_indices, :])\n        merged_trajectories_pixelspace.append(pixel_coords_per_view[:, :, valid_tracks_indices, :])\n        merged_per_view_visibilities.append(per_view_visibility[:, :, valid_tracks_indices])\n        merged_query_points_3d.append(query_points_3d[valid_tracks_indices])\n\n        # print(f\"VERBOSE LOGS: varying the distance threshold for cam_id={cam_id}\")\n        # for d in [0.001, 0.005, 0.01, 0.011, 0.012, 0.013, 0.014, 0.015, 0.016, 0.017, 0.019,\n        #           0.020, 0.021, 0.022, 0.023, 0.024, 0.025, 0.026, 0.027, 0.028, 0.029, 0.030, 0.035, 0.04, 0.05]:\n        #     per_view_visibility = (depthmap_nearest_neighbor_distance <= d) & within_frame\n        #     print(f\" --> dist={d:0.3f} \"\n        #           f\"v1={per_view_visibility[cam_id].mean() * 100:.1f} \"\n        #           f\"v2={visibility.mean() * 100:.1f} \"\n        #           f\"acc={(per_view_visibility[cam_id] == visibility).mean() * 100:.1f}\")\n        # per_view_visibility = depthmap_nearest_neighbor_distance <= nearest_neighbor_distance_threshold_for_visibility\n        # per_view_visibility = per_view_visibility & within_frame\n        # 
print(f\"dist={nearest_neighbor_distance_threshold_for_visibility:0.3f} \"\n        #       f\"v1={per_view_visibility[cam_id].mean() * 100:.1f} \"\n        #       f\"v2={visibility.mean() * 100:.1f} \"\n        #       f\"acc={(per_view_visibility[cam_id] == visibility).mean() * 100:.1f}\")\n        #\n        # if cam_id != 16:\n        #     continue\n        #\n        # rr.init(\"reconstruction\", recording_id=\"v0.1\")\n        # rr.connect_tcp()\n        # rr.log(\"/\", rr.ViewCoordinates.LEFT_HAND_Y_DOWN, static=True)\n        # rr.set_time_seconds(\"frame\", 0)\n        # rr.log(\"world/xyz\", rr.Arrows3D(vectors=[[1, 0, 0], [0, 2, 0], [0, 0, 3]],\n        #                                 colors=[[255, 0, 0], [0, 255, 0], [0, 0, 255]]))\n        #\n        # rr.log(f\"debug/qp_xyz_camera\",\n        #        rr.Points3D(world_coords[queries_xyt[:, 2].astype(np.int32), np.arange(n_tracks)],\n        #                    colors=np.ones_like(qp_xyz_camera) * [0, 1, 0], radii=0.01))\n        # rr.log(f\"debug/qp_xyz_camera_reproj\",\n        #        rr.Points3D(qp_xyz_world, colors=np.ones_like(qp_xyz_camera) * [0, 0, 1], radii=0.01))\n        # strips = np.stack([world_coords[queries_xyt[:, 2].astype(np.int32), np.arange(n_tracks)], qp_xyz_world], axis=1)\n        # rr.log(\"debug/qp_xyz_error_line\", rr.LineStrips3D(strips=strips, colors=np.array([1., 0, 0]), radii=0.003))\n        #\n        # seq = os.path.basename(scene_root)\n        # for t in range(0, n_frames, rerun_log_n_skip_t):\n        #     for v in rerun_views_to_viz:\n        #         rr.set_time_seconds(\"frame\", t / 30)\n        #         depth_values = d3dgs_depths[v, t].ravel()\n        #         valid_mask = depth_values > 0\n        #         y, x = np.indices((h, w))\n        #         homo_pixel_coords = np.stack([x.ravel(), y.ravel(), np.ones_like(x).ravel()], axis=1).T\n        #         cam_coords = (k_inv[v] @ homo_pixel_coords) * depth_values\n        #         cam_coords = np.vstack((cam_coords, np.ones((1, cam_coords.shape[1]))))\n        #         world_coords_ = (extrinsics_inv[v] @ cam_coords)[:3].T\n        #         world_coords_ = world_coords_[valid_mask]\n        #         rgb_colors = rgbs[v, t].reshape(-1, 3)[valid_mask].astype(np.uint8)\n        #         rr.log(f\"{seq}/dyn-3dgs-point-cloud/view-{v}\",\n        #                rr.Points3D(world_coords_, colors=rgb_colors, radii=0.004))\n        # cmap = matplotlib.colormaps[\"gist_rainbow\"]\n        # norm = matplotlib.colors.Normalize(vmin=world_coords[..., 0].min(), vmax=world_coords[..., 0].max())\n        # track_colors = cmap(norm(world_coords[-1, :, 0]))\n        # log_tracks_to_rerun(\n        #     tracks=world_coords,\n        #     visibles=visibility,\n        #     query_timestep=np.zeros(n_tracks, dtype=np.int32),\n        #     colors=track_colors,\n        #     track_names=[f\"track-{i:02d}\" for i in range(n_tracks)],\n        #     entity_format_str=f\"debug/tapvid3d-tracks-visGT/{{}}\",\n        #     invisible_color=[0.3, 0.3, 0.3],\n        # )\n        # log_tracks_to_rerun(\n        #     tracks=world_coords,\n        #     visibles=per_view_visibility[views.index(16)],\n        #     query_timestep=np.zeros(n_tracks, dtype=np.int32),\n        #     colors=track_colors,\n        #     track_names=[f\"track-{i:02d}\" for i in range(n_tracks)],\n        #     entity_format_str=f\"debug/tapvid3d-tracks-vis16-v2/{{}}\",\n        #     invisible_color=[0.3, 0.3, 0.3],\n        # )\n        # log_tracks_to_rerun(\n     
   #     tracks=world_coords,\n        #     visibles=per_view_visibility[views.index(27)],\n        #     query_timestep=np.zeros(n_tracks, dtype=np.int32),\n        #     colors=track_colors,\n        #     track_names=[f\"track-{i:02d}\" for i in range(n_tracks)],\n        #     entity_format_str=f\"debug/tapvid3d-tracks-vis27/{{}}\",\n        #     invisible_color=[0.3, 0.3, 0.3],\n        # )\n        # exit()\n    merged_trajectories = np.concatenate(merged_trajectories, axis=1)\n    merged_trajectories_pixelspace = np.concatenate(merged_trajectories_pixelspace, axis=2)\n    merged_per_view_visibilities = np.concatenate(merged_per_view_visibilities, axis=2)\n    merged_query_points_3d = np.concatenate(merged_query_points_3d, axis=0)\n\n    # Remove duplicates from the merged trajectories\n    from sklearn.cluster import DBSCAN\n    flat_trajectories = merged_trajectories.transpose(1, 0, 2).reshape(-1, n_frames * 3)\n    dbscan = DBSCAN(eps=0.01, min_samples=1, metric='euclidean')\n    labels = dbscan.fit_predict(flat_trajectories)\n    _, unique_indices = np.unique(labels, return_index=True)\n    unique_indices = np.sort(unique_indices)\n    merged_trajectories = merged_trajectories[:, unique_indices, :]\n    merged_trajectories_pixelspace = merged_trajectories_pixelspace[:, :, unique_indices, :]\n    merged_per_view_visibilities = merged_per_view_visibilities[:, :, unique_indices]\n    merged_query_points_3d = merged_query_points_3d[unique_indices, :]\n\n    n_tracks = merged_trajectories.shape[1]\n    assert merged_trajectories.shape == (n_frames, n_tracks, 3)\n    assert merged_trajectories_pixelspace.shape == (n_views, n_frames, n_tracks, 2)\n    assert merged_per_view_visibilities.shape == (n_views, n_frames, n_tracks)\n    assert merged_query_points_3d.shape == (n_tracks, 4)\n\n    # Shuffle the tracks\n    np.random.seed(72)\n    track_perm = np.random.permutation(n_tracks)\n    shuffled_trajectories = merged_trajectories[:, track_perm, :]\n    shuffled_trajectories_pixelspace = merged_trajectories_pixelspace[:, :, track_perm, :]\n    shuffled_per_view_visibilities = merged_per_view_visibilities[:, :, track_perm]\n    shuffled_query_points_3d = merged_query_points_3d[track_perm, :]\n\n    # Save the merged annotations\n    np.savez(\n        output_annotation_path,\n        trajectories=shuffled_trajectories,\n        trajectories_pixelspace=shuffled_trajectories_pixelspace,\n        per_view_visibilities=shuffled_per_view_visibilities,\n        query_points_3d=shuffled_query_points_3d,\n        intrinsics=k,\n        extrinsics=extrinsics,\n    )\n    print(f\"Saved merged annotations to {output_annotation_path}\")\n\n    if rerun_logging:\n        rr.init(\"reconstruction\", recording_id=\"v0.1\")\n        if rerun_stream_only:\n            rr.connect_tcp()\n        rr.set_time_seconds(\"frame\", 0)\n        rr.log(\"/\", rr.ViewCoordinates.LEFT_HAND_Y_DOWN, static=True)\n        rr.log(\"world/xyz\", rr.Arrows3D(vectors=[[1, 0, 0], [0, 2, 0], [0, 0, 3]],\n                                        colors=[[255, 0, 0], [0, 255, 0], [0, 0, 255]]))\n\n        seq = os.path.basename(scene_root)\n        for t in range(0, n_frames, rerun_log_n_skip_t):\n            for v in rerun_views_to_viz:\n                rr.set_time_seconds(\"frame\", t / 30)\n                if rerun_log_rgb:\n                    rr.log(f\"{seq}/rgb/view-{views[v]}/rgb\", rr.Image(rgbs[v, t]))\n                    rr.log(f\"{seq}/rgb/view-{views[v]}\", rr.Pinhole(image_from_camera=k[v], width=w, height=h))\n 
                   rr.log(f\"{seq}/rgb/view-{views[v]}\", rr.Transform3D(translation=extrinsics_inv[v, :3, 3],\n                                                                        mat3x3=extrinsics_inv[v, :3, :3]))\n                if rerun_log_d3dgs_rgb:\n                    rr.log(f\"{seq}/dyn-3dgs-rgb/view-{views[v]}/rgb\", rr.Image(d3dgs_rgbs[v, t]))\n                    rr.log(f\"{seq}/dyn-3dgs-rgb/view-{views[v]}\", rr.Pinhole(image_from_camera=k[v], width=w, height=h))\n                    rr.log(f\"{seq}/dyn-3dgs-rgb/view-{views[v]}\", rr.Transform3D(translation=extrinsics_inv[v, :3, 3],\n                                                                                 mat3x3=extrinsics_inv[v, :3, :3]))\n                if rerun_log_d3dgs_depth:\n                    rr.log(f\"{seq}/dyn-3dgs-depth/view-{views[v]}/depth\",\n                           rr.DepthImage(d3dgs_depths[v, t], point_fill_ratio=0.2))\n                    rr.log(f\"{seq}/dyn-3dgs-depth/view-{views[v]}\",\n                           rr.Pinhole(image_from_camera=k[v], width=w, height=h))\n                    rr.log(f\"{seq}/dyn-3dgs-depth/view-{views[v]}\",\n                           rr.Transform3D(translation=extrinsics_inv[v, :3, 3], mat3x3=extrinsics_inv[v, :3, :3]))\n\n                if rerun_log_d3dgs_point_cloud:\n                    depth_values = d3dgs_depths[v, t].ravel()\n                    valid_mask = depth_values > 0\n                    y, x = np.indices((h, w))\n                    homo_pixel_coords = np.stack([x.ravel(), y.ravel(), np.ones_like(x).ravel()], axis=1).T\n                    cam_coords = (k_inv[v] @ homo_pixel_coords) * depth_values\n                    cam_coords = np.vstack((cam_coords, np.ones((1, cam_coords.shape[1]))))\n                    world_coords = (extrinsics_inv[v] @ cam_coords)[:3].T\n                    world_coords = world_coords[valid_mask]\n                    rgb_colors = rgbs[v, t].reshape(-1, 3)[valid_mask].astype(np.uint8)\n                    rr.log(f\"{seq}/dyn-3dgs-point-cloud/view-{v}\",\n                           rr.Points3D(world_coords, colors=rgb_colors, radii=0.004))\n\n        if rerun_log_tracks:\n            raw_tracks = np.stack([data['means3D'][is_fg][::200].contiguous().cpu().numpy() for data in scene_data])\n            n_tracks_raw = raw_tracks.shape[1]\n            cmap = matplotlib.colormaps[\"gist_rainbow\"]\n            norm = matplotlib.colors.Normalize(vmin=raw_tracks[..., 0].min(), vmax=raw_tracks[..., 0].max())\n            track_colors = cmap(norm(raw_tracks[-1, :, 0]))\n            log_tracks_to_rerun(\n                tracks=raw_tracks,\n                visibles=np.ones((n_frames, n_tracks_raw), dtype=bool),\n                query_timestep=np.zeros(n_tracks_raw, dtype=np.int32),\n                colors=track_colors,\n                track_names=[f\"track-{i:02d}\" for i in range(n_tracks_raw)],\n                entity_format_str=f\"{seq}/dyn-3dgs-raw-tracks/{{}}\",\n                invisible_color=[0.3, 0.3, 0.3],\n            )\n\n            cmap = matplotlib.colormaps[\"gist_rainbow\"]\n            norm = matplotlib.colors.Normalize(vmin=shuffled_trajectories[..., 0].min(),\n                                               vmax=shuffled_trajectories[..., 0].max())\n            track_colors = cmap(norm(shuffled_trajectories[-1, :, 0]))\n            batch_size = 50\n            max_tracks = 500\n            for v in rerun_views_to_viz:\n                for tracks_batch_start in range(0, max_tracks, batch_size):\n                   
 tracks_batch_end = min(tracks_batch_start + batch_size, n_tracks)\n                    log_tracks_to_rerun(\n                        tracks=shuffled_trajectories[:, tracks_batch_start:tracks_batch_end],\n                        visibles=shuffled_per_view_visibilities[v, :, tracks_batch_start:tracks_batch_end],\n                        query_timestep=shuffled_query_points_3d[:, 0][tracks_batch_start:tracks_batch_end].astype(int),\n                        colors=track_colors[tracks_batch_start:tracks_batch_end],\n                        track_names=[f\"track-{i:02d}\" for i in range(tracks_batch_start, tracks_batch_end)],\n                        entity_format_str=f\"{seq}/tapvid3d-tracks/view-{v}-visiblity/{tracks_batch_start}-{tracks_batch_end}/{{}}\",\n                        invisible_color=[0.3, 0.3, 0.3],\n                    )\n\n        if not rerun_stream_only:\n            rr_rrd_path = scene_root / \"rerun_tapvid3d_labels.rrd\"\n            rr.save(rr_rrd_path)\n            print(f\"Saved Rerun recording to: {os.path.abspath(rr_rrd_path)}\")\n\n\nif __name__ == \"__main__\":\n    print(\"Merging TAP-Vid3D per-camera annotations.\")\n    for sequence_name in tqdm([\"basketball\", \"boxes\", \"football\", \"juggle\", \"softball\", \"tennis\"]):\n        scene_root = Path(f\"./datasets/panoptic_d3dgs/{sequence_name}\")\n        checkpoint_path = Path(f\"./dynamic3dgs/output/pretrained/{sequence_name}\")\n        tapvid3d_annotation_paths = list(Path(f\"./datasets/tapvid3d_dataset/pstudio\").glob(f\"{sequence_name}_*.npz\"))\n        merge_annotations(\n            scene_root,\n            checkpoint_path,\n            tapvid3d_annotation_paths,\n            skip_if_output_already_exists=True,\n            rerun_logging=True\n        )\n"
  },
  {
    "path": "mvtracker/models/core/dynamic3dgs/metadata_dexycb.py",
    "content": "import json\nimport os\nfrom collections import defaultdict\n\nimport numpy as np\n\n# Configurable parameters\nBASE_PATH = \".\"\nIMAGE_WIDTH = 640\nIMAGE_HEIGHT = 480\nSELECTED_CAMS = [0, 1, 2, 3]\nOUTPUT_NAME = \"0123_metadata\"\n\n# Filter sequences\nsequences = [f for f in os.listdir(BASE_PATH) if f.startswith(\"2020\")]\nprint(sequences)\n\nfor sequence in sequences:\n    sequence_path = os.path.join(BASE_PATH, sequence)\n    view_folders = [\n        f\n        for f in os.listdir(sequence_path)\n        if f.startswith(\"view_\") and f[-2:].isdigit()\n    ]\n\n    if not view_folders:\n        continue\n\n    example_view_path = os.path.join(sequence_path, view_folders[0])\n    frame_files = [\n        fname\n        for fname in os.listdir(example_view_path)\n        if fname.endswith(\".png\") and fname[:-4].isdigit()\n    ]\n    num_timesteps = len(frame_files)\n    print(f\"{sequence}: Found {num_timesteps} frames in {view_folders[0]}\")\n\n    combined_data = defaultdict(\n        lambda: defaultdict(\n            lambda: {\"cam_id\": 0, \"w\": 0, \"h\": 0, \"k\": [], \"w2c\": [], \"fn\": []}\n        )\n    )\n\n    for time_step in range(num_timesteps):\n        for view_folder in view_folders:\n            view_folder_path = os.path.join(sequence_path, view_folder)\n            if not os.path.exists(view_folder_path):\n                print(f\"Skipping {view_folder_path}\")\n                continue\n\n            cam_id = int(view_folder[-2:])\n\n            if SELECTED_CAMS != [] and cam_id not in SELECTED_CAMS:\n                continue\n\n            data_path = os.path.join(view_folder_path, \"intrinsics_extrinsics.npz\")\n\n            if not os.path.exists(data_path):\n                print(f\"Missing intrinsics_extrinsics.npz in {view_folder_path}\")\n                continue\n\n            data = np.load(data_path)\n            k = data[\"intrinsics\"][:3, :3]\n            w2c = data[\"extrinsics\"][:3, :]\n            w2c = np.vstack([w2c, np.array([0, 0, 0, 1])])\n\n            frame_name = f\"{cam_id}/{str(time_step).zfill(5)}.png\"\n\n            cam_info = combined_data[time_step][str(cam_id)]\n            cam_info[\"cam_id\"] = cam_id\n            cam_info[\"w\"] = IMAGE_WIDTH\n            cam_info[\"h\"] = IMAGE_HEIGHT\n            cam_info[\"k\"] = k.tolist()\n            cam_info[\"w2c\"] = w2c.tolist()\n            cam_info[\"fn\"] = frame_name\n\n    output_path = os.path.join(sequence_path, \"metadata.json\")\n    with open(output_path, \"w\") as f:\n        json.dump(dict(combined_data), f, indent=4)\n\n    print(f\"Saved metadata for {sequence}\")\n"
  },
  {
    "path": "mvtracker/models/core/dynamic3dgs/metadata_kubric.py",
    "content": "import json\nimport os\nfrom collections import defaultdict\n\nimport kornia\nimport numpy as np\nimport torch\n\nBASE_PATH = \".\"\nIMAGE_WIDTH = 512\nIMAGE_HEIGHT = 512\nNUM_TIMESTEPS = 24\nSELECTED_CAMS = [0, 1, 2, 3]\nOUTPUT_NAME = \"0123_metadata\"\n# Filter valid sequences\nsequences = [f for f in os.listdir(BASE_PATH)]\n\nfor sequence in sequences:\n    sequence_path = os.path.join(BASE_PATH, sequence)\n    view_folders = [\n        f\n        for f in os.listdir(sequence_path)\n        if f.startswith(\"view_\") and f[-1:].isdigit()\n    ]\n\n    combined_data = defaultdict(\n        lambda: defaultdict(\n            lambda: {\n                \"cam_id\": 0,\n                \"w\": 0,\n                \"h\": 0,\n                \"k\": [],\n                \"w2c\": [],\n                \"fn\": [],\n                \"sensor_width\": 0,\n                \"focal_length\": 0,\n            }\n        )\n    )\n\n    if not view_folders:\n        continue\n\n    first_valid_view = None\n    for vf in view_folders:\n        cam_id = int(vf[-1:])\n        if not SELECTED_CAMS or cam_id in SELECTED_CAMS:\n            first_valid_view = vf\n            break\n\n    if first_valid_view is None:\n        continue\n\n    example_path = os.path.join(sequence_path, first_valid_view)\n    all_frames = [\n        f\n        for f in os.listdir(example_path)\n        if f.endswith(\".png\") and f[:-4].isdigit() and f.startswith(\"rgba\")\n    ]\n    num_timesteps = len(all_frames)\n\n    for time_step in range(NUM_TIMESTEPS):\n        for view_folder in view_folders:\n            print(f\"Processing {sequence}/{view_folder}, time step {time_step}\")\n\n            view_folder_path = os.path.join(sequence_path, view_folder)\n            if not os.path.exists(view_folder_path):\n                continue\n\n            cam_id = int(view_folder[-1:])\n            if SELECTED_CAMS != [] and cam_id not in SELECTED_CAMS:\n                continue\n\n            with open(os.path.join(view_folder_path, \"metadata.json\"), \"r\") as f:\n                data = json.load(f)\n\n            cam_data = data[\"camera\"]\n            k = cam_data[\"K\"]\n\n            quaternions = torch.tensor(cam_data[\"quaternions\"])\n            positions = torch.tensor(cam_data[\"positions\"])\n\n            rot_matrices = kornia.geometry.quaternion_to_rotation_matrix(quaternions)\n\n            ext_inv = torch.eye(4).repeat(NUM_TIMESTEPS, 1, 1)\n            ext_inv[:, :3, :3] = rot_matrices\n            ext_inv[:, :3, 3] = positions\n\n            ext = ext_inv.inverse()[:, :3, :]\n            ext = np.diag([1, -1, -1]) @ ext.numpy()\n\n            w2c = ext[0].tolist()\n            w2c.append([0, 0, 0, 1])\n\n            intrinsics = (\n                    np.diag([IMAGE_WIDTH, IMAGE_HEIGHT, 1])\n                    @ np.array(k)\n                    @ np.diag([1, -1, -1])\n            )\n            frame_name = f\"{cam_id}/{str(time_step).zfill(5)}.png\"\n\n            cam_info = combined_data[time_step][str(cam_id)]\n            cam_info[\"cam_id\"] = cam_id\n            cam_info[\"w\"] = IMAGE_WIDTH\n            cam_info[\"h\"] = IMAGE_HEIGHT\n            cam_info[\"k\"] = intrinsics.tolist()\n            cam_info[\"w2c\"] = w2c\n            cam_info[\"fn\"] = frame_name\n            cam_info[\"sensor_width\"] = cam_data[\"sensor_width\"]\n            cam_info[\"focal_length\"] = cam_data[\"focal_length\"]\n\n    output_path = os.path.join(sequence_path, f\"{OUTPUT_NAME}.json\")\n    with 
open(output_path, \"w\") as f:\n        json.dump(dict(combined_data), f, indent=4)\n\n    print(f\"Saved metadata for {sequence}\")\n"
  },
  {
    "path": "mvtracker/models/core/dynamic3dgs/reorganize_dexycb.py",
    "content": "import os\n\nsource_roots = [f for f in os.listdir(\".\") if f.startswith(\"2020\")]\nimport os\nimport shutil\n\nsource_roots = [f for f in os.listdir(\".\") if f.startswith(\"2020\")]\nfor source_root in source_roots:\n    target_root = source_root\n    ims_target = os.path.join(target_root, \"ims\")\n    seg_target = os.path.join(target_root, \"seg\")\n    depths_target = os.path.join(target_root, \"depths\")\n\n    for target in [ims_target, seg_target, depths_target]:\n        os.makedirs(target, exist_ok=True)\n\n    for i in range(8):  # view_00 to view_07\n        view_folder = os.path.join(source_root, f\"view_{i:02d}\")\n\n        ims_source = os.path.join(view_folder, \"rgb\")\n        ims_dest = os.path.join(ims_target, str(i))\n        if os.path.exists(ims_source):\n            shutil.copytree(ims_source, ims_dest, dirs_exist_ok=True)\n\n        mask_source = os.path.join(view_folder, \"mask\")\n        seg_dest = os.path.join(seg_target, str(i))\n        if os.path.exists(mask_source):\n            shutil.copytree(mask_source, seg_dest, dirs_exist_ok=True)\n\n        depth_source = os.path.join(view_folder, \"depth\")\n        depth_dest = os.path.join(depths_target, str(i))\n        if os.path.exists(depth_source):\n            shutil.copytree(depth_source, depth_dest, dirs_exist_ok=True)\n\nprint(\"Copying complete!\")\n"
  },
  {
    "path": "mvtracker/models/core/dynamic3dgs/test.py",
    "content": "import json\nimport os\n\nimport numpy as np\nimport torch\nimport torchvision\nfrom PIL import Image\nfrom diff_gaussian_rasterization import GaussianRasterizer as Renderer\nfrom tqdm import tqdm\n\nfrom external import calc_psnr, calc_ssim\nfrom helpers import setup_camera\n\nTEST_CAMS = [0, 10, 15, 30]\n\n\ndef load_saved_params(seq, exp):\n    \"\"\"Load saved parameters for testing.\"\"\"\n    params_path = f\"./output/{exp}/{seq}/params.npz\"\n    params = np.load(params_path)\n    params = {k: torch.tensor(v).cuda().float() for k, v in params.items()}\n    return params\n\n\ndef prepare_test_dataset(t, md, seq, exclude_cam_ids):\n    \"\"\"Prepare dataset for the given timestep, excluding specific camera IDs.\"\"\"\n    dataset = []\n    used_cam_ids = []\n    for c in range(len(md[\"fn\"][t])):\n        cam_id = md[\"cam_id\"][t][c]\n        # if cam_id in exclude_cam_ids:\n        #     continue\n        # ONLY USE THE SPECIFIC CAMS\n        if cam_id not in TEST_CAMS:\n            continue\n        w, h, k, w2c = md[\"w\"], md[\"h\"], md[\"k\"][t][c], md[\"w2c\"][t][c]\n        cam = setup_camera(w, h, k, w2c, near=1.0, far=100)\n        fn = md[\"fn\"][t][c]\n        im_path = f\"./data/{seq}/ims/{fn}\"\n        im = np.array(Image.open(im_path)) / 255.0\n        im = torch.tensor(im).float().cuda().permute(2, 0, 1)\n        dataset.append({\"cam\": cam, \"im\": im, \"id\": cam_id})\n        used_cam_ids.append(cam_id)\n    return dataset, used_cam_ids\n\n\ndef render_image(cam, rendervar):\n    \"\"\"Render an image using the given camera and render variables.\"\"\"\n    with torch.no_grad():\n        im, _, _ = Renderer(raster_settings=cam)(**rendervar)\n    return im\n\n\ndef test(seq, exp, exclude_cam_ids=[]):\n    \"\"\"Test saved parameters on a dataset and report metrics.\"\"\"\n    print(f\"Testing sequence: {seq}, experiment: {exp}\")\n\n    # Load metadata and saved parameters\n    md = json.load(open(f\"./data/{seq}/test_meta.json\", \"r\"))  # metadata\n    params = load_saved_params(seq, exp)\n\n    # Prepare output paths\n    render_path = f\"./output/{exp}/{seq}/renders\"\n    results_path = f\"./output/{exp}_metrics_test.csv\"\n    os.makedirs(render_path, exist_ok=True)\n\n    if not os.path.exists(results_path):\n        with open(results_path, \"w\") as f:\n            f.write(\"Sequence,Experiment,Timestep,Camera ID,PSNR,SSIM\\n\")\n\n    num_timesteps = len(md[\"fn\"])\n    psnrs, ssims = [], []\n    used_cameras = []\n\n    for t in tqdm(range(num_timesteps), desc=\"Testing timesteps\"):\n        dataset, used_cam_ids = prepare_test_dataset(t, md, seq, exclude_cam_ids)\n        used_cameras.extend(used_cam_ids)\n        rendervar = {\n            \"means3D\": params[\"means3D\"][t],\n            \"colors_precomp\": params[\"rgb_colors\"][t],\n            \"rotations\": torch.nn.functional.normalize(params[\"unnorm_rotations\"][t]),\n            \"opacities\": torch.sigmoid(params[\"logit_opacities\"]),\n            \"scales\": torch.exp(params[\"log_scales\"]),\n            \"means2D\": torch.zeros_like(params[\"means3D\"][t], device=\"cuda\"),\n        }\n\n        for camera in dataset:\n            im_rendered = render_image(camera[\"cam\"], rendervar)\n            gt = camera[\"im\"]\n\n            # Save rendered and ground truth images\n            idx = camera[\"id\"]\n            torchvision.utils.save_image(\n                im_rendered, f\"{render_path}/t{t:03d}_c{idx:02d}_rendered.png\"\n            )\n            
torchvision.utils.save_image(\n                gt, f\"{render_path}/t{t:03d}_c{idx:02d}_gt.png\"\n            )\n\n            # Compute metrics\n            psnr_val = calc_psnr(im_rendered, gt).mean().item()\n            ssim_val = calc_ssim(im_rendered, gt).mean().item()\n            psnrs.append(psnr_val)\n            ssims.append(ssim_val)\n\n            # Save metrics\n            with open(results_path, \"a\") as f:\n                f.write(f\"{seq},{exp},{t},{idx},{psnr_val:.4f},{ssim_val:.4f}\\n\")\n\n    print(f\"Used cameras: {sorted(set(used_cameras))}\")\n    print(f\"Average PSNR: {np.mean(psnrs):.4f}, Average SSIM: {np.mean(ssims):.4f}\")\n\n\nif __name__ == \"__main__\":\n    exp_name = \"testing_init_pt\"\n    training_cam_ids = [1, 4, 7, 11, 17, 20, 23, 26, 29]  # Cameras used during training\n    # for sequence in [\"basketball\", \"boxes\", \"football\"]:\n    for sequence in [\"basketball\"]:\n        test(sequence, exp_name, exclude_cam_ids=training_cam_ids)\n"
  },
  {
    "path": "mvtracker/models/core/dynamic3dgs/track_2d.py",
    "content": "import json\nimport os\n\nimport numpy as np\nimport torch\nfrom diff_gaussian_rasterization import GaussianRasterizer as Renderer\nfrom tqdm import tqdm\n\nfrom external import build_rotation\nfrom helpers import setup_camera\n\nREMOVE_BACKGROUND = False\n\nw, h = 640, 360\nnear, far = 0.01, 100.0\n\n\ndef gaussian_influence(point, gaussians):\n    \"\"\"\n    Computes the most influential Gaussian for a given 3D point.\n\n    Args:\n        point (torch.Tensor): 3D point (shape: [3]).\n        gaussians (dict): Dictionary containing:\n            - \"means3D\": [N, 3] Gaussian means.\n            - \"scales\": [N, 3] Gaussian scales.\n            - \"opacities\": [N, 1] Gaussian opacities.\n            - \"rotations\": [N, 4] Gaussian quaternion rotations.\n\n    Returns:\n        int: Index of the most influential Gaussian.\n    \"\"\"\n    # print(f\"Query point: {point}\")\n\n    means = gaussians[\"means3D\"]  # [N, 3]\n    scales = gaussians[\"scales\"]  # [N, 3]\n    opacities = gaussians[\"opacities\"]  # [N, 1]\n    rotations = gaussians[\"rotations\"]  # [N, 4]\n\n    sigmoid_opacities = opacities.squeeze()\n\n    diff = point - means  # [N, 3]\n\n    R = build_rotation(rotations)  # [N, 3, 3]\n\n    S = torch.diag_embed(scales)  # [N, 3, 3]\n    cov = R @ S @ S.transpose(-1, -2) @ R.transpose(-1, -2)  # [N, 3, 3]\n\n    try:\n        cov_inv = torch.inverse(cov)  # [N, 3, 3]\n        diff = diff.unsqueeze(1)  # [N, 1, 3]\n        # -1/2 * (x - mu)^T * cov^-1 * (x - mu)\n        mahalanobis = (\n                -0.5\n                * torch.matmul(\n            diff, torch.matmul(cov_inv, diff.transpose(-1, -2))\n        ).squeeze()\n        )  # [N]\n\n        # Gaussian influences\n        influences = sigmoid_opacities * torch.exp(mahalanobis)  # [N]\n\n        most_influential_idx = torch.argmax(influences).item()\n\n        return most_influential_idx\n\n    except RuntimeError as e:\n        print(f\"Error in  computation: {e}\")\n        return -1\n\n\ndef render_depth(timestep_data, w2c, k):\n    \"\"\"\n    Renders a depth map using the Gaussian parameters.\n\n    Args:\n        timestep_data (dict): Scene data for the specific timestep.\n\n    Returns:\n        torch.Tensor: Depth map.\n    \"\"\"\n    with torch.no_grad():\n        cam = setup_camera(w, h, k, w2c, near, far)\n        (\n            im,\n            _,\n            depth,\n        ) = Renderer(raster_settings=cam)(**timestep_data)\n\n        if depth.dim() == 3 and depth.size(0) == 1:  # Shape (1, H, W)\n            depth = depth.squeeze(0)\n\n        return depth\n\n\ndef load_scene_data(seq, exp, seg_as_col=False):\n    params = dict(np.load(f\"./output/{exp}/{seq}/params.npz\"))\n    params = {k: torch.tensor(v).cuda().float() for k, v in params.items()}\n    is_fg = params[\"seg_colors\"][:, 0] > 0.5\n    scene_data = []\n    for t in range(len(params[\"means3D\"])):\n        rendervar = {\n            \"means3D\": params[\"means3D\"][t],\n            \"colors_precomp\": params[\"rgb_colors\"][t]\n            if not seg_as_col\n            else params[\"seg_colors\"],\n            \"rotations\": torch.nn.functional.normalize(params[\"unnorm_rotations\"][t]),\n            \"opacities\": torch.sigmoid(params[\"logit_opacities\"]),\n            \"scales\": torch.exp(params[\"log_scales\"]),\n            \"means2D\": torch.zeros_like(params[\"means3D\"][0], device=\"cuda\"),\n        }\n\n        if REMOVE_BACKGROUND:\n            rendervar = {k: v[is_fg] for k, v in 
rendervar.items()}\n        scene_data.append(rendervar)\n    if REMOVE_BACKGROUND:\n        is_fg = is_fg[is_fg]\n    return (\n        scene_data,\n        is_fg,\n    )\n\n\ndef unproject_2d_to_3d(query_pt, depth_map, intrinsics):\n    \"\"\"\n    Unproject a 2D point to 3D.\n    \"\"\"\n    x, y = query_pt\n    z = depth_map[y, x]\n    fx, fy = intrinsics[0, 0], intrinsics[1, 1]\n    cx, cy = intrinsics[0, 2], intrinsics[1, 2]\n    X = (x - cx) * z / fx\n    Y = (y - cy) * z / fy\n    Z = z\n\n    return torch.tensor([X, Y, Z], dtype=torch.float32).cuda()\n\n\ndef load_camera_params(dataset_path, seq, cam_id_g):\n    cam_params = f\"{dataset_path}/{seq}/merged_by_timestamp.json\"\n    with open(cam_params, \"r\") as f:\n        cam_params = json.load(f)\n    for timestamp, cameras in cam_params.items():\n        for cam_id, cam_data in cameras.items():\n            if int(cam_id) == int(cam_id_g):\n                return np.array(cam_data[\"w2c\"]), np.array(cam_data[\"k\"])\n    return None, None\n\n\ndef c2w_convert(point_3d, w2c):\n    point_3d_h = np.append(point_3d.cpu().numpy(), 1).reshape(4, 1)\n    c2w = np.linalg.inv(w2c)\n    point_cam = c2w @ point_3d_h\n    return torch.tensor(point_cam[:3].flatten(), dtype=torch.float32).cuda()\n\n\ndef w2c_convert(point_3d_h, w2c):\n    point_3d = np.append(point_3d_h.cpu().numpy(), 1).reshape(4, 1)\n    point_cam = w2c @ point_3d\n    return torch.tensor(point_cam[:3].flatten(), dtype=torch.float32).cuda()\n\n\ndef track_query_point(scene_data, query_point, depth_map, w2c, k, t_given=0):\n    \"\"\"\n    Tracks the 3D trajectory of a 2D query point across all frames.\n\n    Args:\n        scene_data (list): Scene data for all frames.\n        query_point (tuple): Initial 2D query point (x, y).\n        intrinsics (torch.Tensor): Camera intrinsics.\n        t_start (int): Starting frame index.\n\n    Returns:\n        list: A list of 3D points (numpy arrays) across all timestamps.\n    \"\"\"\n    trajectory = []\n    opacities = []\n    point_3d = unproject_2d_to_3d(query_point, depth_map, k)\n\n    point_3d_gaussian = c2w_convert(point_3d, w2c)\n    gaussians = scene_data[t_given]\n    gaussian_idx = gaussian_influence(point_3d_gaussian, gaussians)\n    for t in range(0, len(scene_data)):\n        gaussians = scene_data[t]\n        gaussian = {k: v[gaussian_idx] for k, v in gaussians.items()}\n        point_3d_gaussian = gaussian[\"means3D\"]\n        point_3d = w2c_convert(point_3d_gaussian, w2c)\n        trajectory.append(point_3d)\n        opacities.append(gaussian[\"opacities\"])\n    return trajectory\n\n\nif __name__ == \"__main__\":\n    exp = \"exp_init_1-7-14-20\"\n    exp = \"exp_merged_cleaned_pt_1-7-14-20\"\n    tapvid3d_dir = \"./datasets/tapvid3d_dataset/pstudio\"\n    dataset_path = \"./datasets/panoptic_d3dgs\"\n    # read the .npz files under directory\n    npz_files = [\n        f\n        for f in os.listdir(tapvid3d_dir)\n        if f.endswith(\".npz\") and \"basketball\" in f and \"_1.\" in f\n    ]\n\n    file_avg_distances = {}\n    # for each .npz file, it has following naming: {seq}_{cam_id}.npz\n    for npz_file in tqdm(npz_files):\n        seq, cam_id = npz_file.split(\".\")[0].split(\"_\")\n\n        # load tapvid3d\n        gt_file = f\"{tapvid3d_dir}/{npz_file}\"\n        print(f\"Loading {gt_file}\")\n        data = np.load(gt_file)\n        print(data.files)\n        queries_xyt = data[\"queries_xyt\"]\n        print(\"quries_xyt:\", queries_xyt)\n        gt_trajectories = data[\"tracks_XYZ\"]\n        
trajectories = []\n        for query in tqdm(queries_xyt):\n            # round to nearest integer\n            q_x = round(query[0])\n            q_y = round(query[1])\n            query_point = (q_x, q_y)\n            t_given = int(query[2]) - 1\n\n            # Load the scene data\n            scene_data, _ = load_scene_data(seq, exp)\n            w2c, k = load_camera_params(dataset_path, seq, cam_id)\n\n            depth_map = render_depth(scene_data[t_given], w2c, k)\n\n            # Track the query point across all timestamps\n            trajectory = track_query_point(\n                scene_data, query_point, depth_map, w2c, k, t_given=t_given\n            )\n\n            trajectories.append(torch.stack(trajectory).cpu().numpy())\n\n        # save the trajectories\n        # np.savez(\n        #     \"{exp}_{seq}_{cam_id}_trajectories.npz\",\n        #     trajectories=trajectories.cpu().numpy(),\n        # )\n        # print(f\"Trajectories for {seq}_{cam_id} saved.\")\n        distances = []\n        for i, query in enumerate(queries_xyt):\n            t_given = int(query[2])\n            gt_traj = gt_trajectories[\n                      :, i\n                      ]  # Extract ground truth trajectory for this query\n            exp_traj = trajectories[i]  # Our computed trajectory\n\n            # Compute Euclidean distances for each timestamp\n            per_frame_distances = np.linalg.norm(gt_traj - exp_traj, axis=1)\n            avg_distance = np.mean(per_frame_distances)\n            sum_distance = np.sum(per_frame_distances)\n            distances.append(avg_distance)\n        print(f\"avg distance for {npz_file}: {np.mean(distances)}\")\n        file_avg_distances[npz_file] = np.mean(distances)\n\n    print(\"Average distances per file:\")\n    print(file_avg_distances)\n    print(\"Overall average distance:\", np.mean(list(file_avg_distances.values())))\n"
  },
  {
    "path": "mvtracker/models/core/dynamic3dgs/track_3d.py",
    "content": "import os\n\nimport cv2\nimport numpy as np\nimport torch\nfrom tqdm import tqdm\n\nfrom external import build_rotation\n\nREMOVE_BACKGROUND = False\ndevice = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n\nw, h = 512, 512\nnear, far = 0.01, 100.0\n\nfrom mvtracker.evaluation.evaluator_3dpt import evaluate_3dpt\n\n\ndef load_scene_data(seq, exp, seg_as_col=False):\n    params = dict(np.load(f\"./output/{exp}/{seq}/params.npz\"))\n    params = {k: torch.tensor(v, device=device).float() for k, v in params.items()}\n\n    is_fg = params[\"seg_colors\"][:, 0] > 0.5\n    scene_data = []\n    for t in range(len(params[\"means3D\"])):\n        rendervar = {\n            \"means3D\": params[\"means3D\"][t],\n            \"colors_precomp\": params[\"rgb_colors\"][t]\n            if not seg_as_col\n            else params[\"seg_colors\"],\n            \"rotations\": params[\"unnorm_rotations\"][t],\n            \"opacities\": torch.sigmoid(params[\"logit_opacities\"]),\n            \"scales\": torch.exp(params[\"log_scales\"]),\n            \"means2D\": torch.zeros_like(params[\"means3D\"][0], device=device),\n        }\n\n        if REMOVE_BACKGROUND:\n            rendervar = {k: v[is_fg] for k, v in rendervar.items()}\n        scene_data.append(rendervar)\n    if REMOVE_BACKGROUND:\n        is_fg = is_fg[is_fg]\n    return scene_data, is_fg\n\n\ndef load_depth_maps(dataset_path, seq, cam_ids):\n    depth_maps = {}\n    for cam_id in cam_ids:\n        depth_dir = f\"{dataset_path}/{seq}/depths/{cam_id}/\"\n        depth_maps[cam_id] = []\n        for frame_idx in sorted(os.listdir(depth_dir)):\n            depth_path = os.path.join(depth_dir, frame_idx)\n            depth_map = (\n                    cv2.imread(depth_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 1000.0\n            )\n            depth_maps[cam_id].append(torch.tensor(depth_map, device=device))\n        depth_maps[cam_id] = torch.stack(depth_maps[cam_id])\n    return depth_maps\n\n\ndef preload_camera_data(dataset_path, seq, cam_ids):\n    cam_params_path = f\"{dataset_path}/{seq}/metadata.json\"\n    with open(cam_params_path, \"r\") as f:\n        cam_params = json.load(f)\n\n    preloaded_cameras = {}\n    for cam_id in cam_ids:\n        for timestamp, cameras in cam_params.items():\n            if str(cam_id) in cameras:\n                preloaded_cameras[cam_id] = (\n                    torch.tensor(\n                        cameras[str(cam_id)][\"w2c\"], dtype=torch.float32\n                    ).cuda(),\n                    torch.tensor(cameras[str(cam_id)][\"k\"], dtype=torch.float32).cuda(),\n                )\n                break  # We only need one instance per camera\n    return preloaded_cameras\n\n\ndef gaussian_influence(point, gaussians):\n    \"\"\"\n    Computes the most influential Gaussian for a given 3D point.\n\n    Args:\n        point (torch.Tensor): 3D point (shape: [3]).\n        gaussians (dict): Dictionary containing:\n            - \"means3D\": [N, 3] Gaussian means.\n            - \"scales\": [N, 3] Gaussian scales.\n            - \"opacities\": [N, 1] Gaussian opacities.\n            - \"rotations\": [N, 4] Gaussian quaternion rotations.\n\n    Returns:\n        int: Index of the most influential Gaussian.\n    \"\"\"\n    # print(f\"Query point: {point}\")\n\n    means = gaussians[\"means3D\"]  # [N, 3]\n    scales = gaussians[\"scales\"]  # [N, 3]\n    opacities = gaussians[\"opacities\"]  # [N, 1]\n    rotations = gaussians[\"rotations\"]  # [N, 4]\n\n    
sigmoid_opacities = opacities.squeeze()\n\n    diff = point - means  # [N, 3]\n\n    R = build_rotation(rotations)  # [N, 3, 3]\n\n    S = torch.diag_embed(scales)  # [N, 3, 3]\n    cov = R @ S @ S.transpose(-1, -2) @ R.transpose(-1, -2)  # [N, 3, 3]\n\n    try:\n        cov_inv = torch.inverse(cov)  # [N, 3, 3]\n        diff = diff.unsqueeze(1)  # [N, 1, 3]\n        # -1/2 * (x - mu)^T * cov^-1 * (x - mu)\n        mahalanobis = (\n                -0.5\n                * torch.matmul(\n            diff, torch.matmul(cov_inv, diff.transpose(-1, -2))\n        ).squeeze()\n        )  # [N]\n\n        # Gaussian influences\n        influences = sigmoid_opacities * torch.exp(mahalanobis)  # [N]\n\n        most_influential_idx = torch.argmax(influences).item()\n\n        print(\"Most influnce:\", influences[most_influential_idx])\n\n        return most_influential_idx, influences[most_influential_idx]\n\n    except RuntimeError as e:\n        print(f\"Error in  computation: {e}\")\n        return -1\n\n\ndef get_visibilities(\n        point_3d,\n        cam_ids,\n        t,\n        depth_maps,\n        preloaded_cameras,\n        th=0.02,\n):\n    visibilities = []\n    for cam_id in cam_ids:\n        if cam_id not in preloaded_cameras:\n            continue\n\n        w2c, intrinsics = preloaded_cameras[cam_id]\n        point_cam = torch.matmul(\n            w2c, torch.cat([point_3d, torch.tensor([1.0], device=point_3d.device)])\n        )[:3]\n        X, Y, Z = point_cam\n        if Z <= 0:\n            continue\n\n        x = int((X * intrinsics[0, 0]) / Z + intrinsics[0, 2])\n        y = int((Y * intrinsics[1, 1]) / Z + intrinsics[1, 2])\n        if not (\n                0 <= x < depth_maps[cam_id].shape[2]\n                and 0 <= y < depth_maps[cam_id].shape[1]\n        ):\n            continue\n\n        depth_at_pixel = depth_maps[cam_id][t, y, x]\n        depth_diff = Z - depth_at_pixel\n        visibilities.append(0 <= depth_diff <= th)\n    return visibilities\n\n\ndef track_query_point(\n        scene_data,\n        query_point,\n        cam_ids,\n        t_given,\n        depth_maps,\n        preloaded_cameras,\n        threshold=0.02,\n):\n    \"\"\"\n    Tracks the 3D trajectory of a 3D query point across all frames.\n\n    Args:\n        scene_data (list): Scene data for all frames.\n        query_point (tuple): Initial 2D query point (x, y).\n        intrinsics (torch.Tensor): Camera intrinsics.\n        t_start (int): Starting frame index.\n\n    Returns:\n        list: A list of 3D points (numpy arrays) across all timestamps.\n    \"\"\"\n    trajectory = []\n    visibilities = []\n\n    gaussians = scene_data[t_given]\n    gaussian_idx, influence = gaussian_influence(query_point, gaussians)\n\n    for t in range(0, len(scene_data)):\n        gaussians = scene_data[t]\n        gaussian = {k: v[gaussian_idx] for k, v in gaussians.items()}\n        point_3d_gaussian = gaussian[\"means3D\"]\n        trajectory.append(point_3d_gaussian)\n        visibility = get_visibilities(\n            point_3d_gaussian, cam_ids, t, depth_maps, preloaded_cameras, threshold\n        )\n        visibilities.append(torch.tensor(visibility))\n\n    # print ratio of visibilities for each cam: visibity has shape n_frames * cam\n    # print(\"Visibility ratio for each camera:\")\n    # print(np.array(visibility).sum(axis=0) / len(visibility))\n    return trajectory, visibilities\n\n\nif __name__ == \"__main__\":\n    exp = \"exp_use_duster_views_0123\"\n    sequences = [\n        
\"20200709-subject-01__20200709_141754\",\n        \"20200813-subject-02__20200813_145653\",\n        \"20200903-subject-04__20200903_104428\",\n        \"20200820-subject-03__20200820_135841\",\n        \"20200908-subject-05__20200908_144409\",\n        \"20200918-subject-06__20200918_114117\",\n        \"20200928-subject-07__20200928_144906\",\n        \"20201002-subject-08__20201002_110227\",\n        \"20201015-subject-09__20201015_144721\",\n        \"20201022-subject-10__20201022_112651\",\n    ]\n    dataset_path = \"./datasets/dex_formatted/neus_nsubsample-3\"\n    remove_hand = False\n    use_duster = True\n    cleaned_duster = False\n    views = \"0123\"\n    tracks_path = f\"seed-000072_remove-hand-{remove_hand}_tracks-384_use-duster-depths-{use_duster}_clean-duster-depths-{cleaned_duster}_views-{views}_duster-views-{views}.npz\"\n    # sequences = [\"basketball\"]\n    # cam_ids = [27, 16, 14, 8]\n    # cam_ids = [0, 1, 2, 3, 4, 5, 6, 7]\n    cam_ids = [0, 1, 2, 3]\n    for seq in sequences:\n        merged_path = f\"{dataset_path}/{seq}/{tracks_path}\"\n        # Load scene data\n        scene_data, is_fg = load_scene_data(seq, exp, s=1)\n        # scene_data = []\n        print(\"Scene data loaded.\")\n        depth_maps = load_depth_maps(dataset_path, seq, cam_ids)\n        preloaded_cameras = preload_camera_data(dataset_path, seq, cam_ids)\n\n        load_tapvid3d = np.load(merged_path)\n        query_points = load_tapvid3d[\"query_points_3d\"]\n        predictions_file = f\"./output/{exp}/{seq}/predictions.npz\"\n\n        if True:\n            THRESHOLD = 0.02\n            predictions = []\n            visibilities = []\n            for i, query_point in tqdm(enumerate(query_points), desc=\"Query points\"):\n                # print(\"Query point:\", query_point)\n                given_time = query_point[0]\n                # to int\n                # query_point = query_point.astype(int)\n                given_time = int(given_time)\n                qp = query_point[1:]\n                # convert it to Torch tensor\n                # torch.tensor([X, Y, Z], dtype=torch.float32).cuda()\n                qp = torch.tensor(qp, dtype=torch.float32).cuda()\n                trajectory, visiblity_d = track_query_point(\n                    scene_data,\n                    qp,\n                    cam_ids,\n                    given_time,\n                    depth_maps,\n                    preloaded_cameras,\n                    THRESHOLD,\n                )\n                # trajectory = trajectory.cpu().numpy()\n                predictions.append(torch.stack(trajectory).cpu().numpy())\n                visibilities.append(torch.stack(visiblity_d).cpu().numpy())\n\n            # pred shape is: n_queries, n_frames, 3\n            # convert it to n_frames, n_queries, 3\n            predictions = np.array(predictions)\n            predictions = np.transpose(predictions, (1, 0, 2))\n\n            visibilities = np.array(visibilities)\n            visibilities = np.transpose(visibilities, (2, 1, 0))\n\n            preds_file = f\"./output/{exp}/{seq}/predictions_threshold_{THRESHOLD}.npz\"\n            np.savez(\n                preds_file,\n                predictions=predictions,\n                visibilities=visibilities,\n            )\n            print(f\"Results saved for threshold {THRESHOLD} at: {preds_file}\")\n\n        # Load the ground truth\n        query_points = load_tapvid3d[\"query_points_3d\"]\n        query_points = query_points[None, ...]  
# batch * num tracks * 4\n        gt_visibilities = load_tapvid3d[\"per_view_visibilities\"]\n        gt_visibilities = gt_visibilities[\n            None, ...\n        ]  # batch * view * num frames * num tracks\n        # convert all of them to false\n        gt_tracks = load_tapvid3d[\"trajectories\"]\n        gt_tracks = gt_tracks[None, ...]  # batch * num frames * num tracks * 3\n        # pred_visibilities = visibilities[None, ...]\n        # pred_visibilities_t = visibilities_i[None, ...]\n        pred_tracks = predictions[None, ...]\n        # print all dimensions for debugging\n        print(\"query_points:\", query_points.shape)\n        print(\"gt_occluded:\", gt_visibilities.shape)\n        print(\"gt_tracks:\", gt_tracks.shape)\n        print(\"pred_occluded:\", gt_visibilities.shape)\n        print(\"pred_tracks:\", pred_tracks.shape)\n\n        gt_visibilities_any_view = gt_visibilities.any(axis=1)\n\n        pred_visibilities = visibilities[None, ...]\n        pred_visibilities_any_view = pred_visibilities.any(axis=1)\n        print(\"EXP: \", exp)\n        print(\"SEQ: \", seq)\n        print(\"Evaluating ... \")\n        metrics_2 = evaluate_3dpt(\n            gt_tracks[0],\n            gt_visibilities_any_view[0],\n            pred_tracks[0],\n            pred_visibilities_any_view[0],\n            evaluation_setting=\"dex-ycb-multiview\",\n            query_points=query_points[0],\n            track_upscaling_factor=1,\n            verbose=True,\n        )\n\n        # Save evaluation results\n        results_file = f\"./output/{exp}/{seq}/results_threshold_{THRESHOLD}.txt\"\n        with open(results_file, \"w\") as f:\n            f.write(f\"Exp: {exp}\\n\")\n            f.write(f\"Seq: {seq}\\n\")\n            f.write(f\"Threshold: {THRESHOLD}\\n\")\n            f.write(str(metrics_2))\n            f.write(\"\\n\")\n        print(f\"Results saved at: {results_file}\")\n\n        print(\"Done.\")\n"
  },
  {
    "path": "mvtracker/models/core/dynamic3dgs/train.py",
    "content": "import copy\nimport json\nimport os\nfrom random import randint\n\nimport numpy as np\nimport torch\nfrom PIL import Image\nfrom diff_gaussian_rasterization import GaussianRasterizer as Renderer\nfrom tqdm import tqdm\n\nfrom external import calc_ssim, calc_psnr, build_rotation, densify, update_params_and_optimizer\nfrom helpers import setup_camera, l1_loss_v1, l1_loss_v2, weighted_l2_loss_v1, weighted_l2_loss_v2, quat_mult, \\\n    o3d_knn, params2rendervar, params2cpu, save_params\n\n\ndef get_dataset(t, md, seq):\n    dataset = []\n    for c in range(len(md['fn'][t])):\n        w, h, k, w2c = md['w'], md['h'], md['k'][t][c], md['w2c'][t][c]\n        cam = setup_camera(w, h, k, w2c, near=1.0, far=100)\n        fn = md['fn'][t][c]\n        im = np.array(copy.deepcopy(Image.open(f\"./data/{seq}/ims/{fn}\")))\n        im = torch.tensor(im).float().cuda().permute(2, 0, 1) / 255\n        seg = np.array(copy.deepcopy(Image.open(f\"./data/{seq}/seg/{fn.replace('.jpg', '.png')}\"))).astype(np.float32)\n        seg = torch.tensor(seg).float().cuda()\n        seg_col = torch.stack((seg, torch.zeros_like(seg), 1 - seg))\n        dataset.append({'cam': cam, 'im': im, 'seg': seg_col, 'id': c})\n    return dataset\n\n\ndef get_batch(todo_dataset, dataset):\n    if not todo_dataset:\n        todo_dataset = dataset.copy()\n    curr_data = todo_dataset.pop(randint(0, len(todo_dataset) - 1))\n    return curr_data\n\n\ndef initialize_params(seq, md):\n    init_pt_cld = np.load(f\"./data/{seq}/init_pt_cld.npz\")[\"data\"]\n    seg = init_pt_cld[:, 6]\n    max_cams = 50\n    sq_dist, _ = o3d_knn(init_pt_cld[:, :3], 3)\n    mean3_sq_dist = sq_dist.mean(-1).clip(min=0.0000001)\n    params = {\n        'means3D': init_pt_cld[:, :3],\n        'rgb_colors': init_pt_cld[:, 3:6],\n        'seg_colors': np.stack((seg, np.zeros_like(seg), 1 - seg), -1),\n        'unnorm_rotations': np.tile([1, 0, 0, 0], (seg.shape[0], 1)),\n        'logit_opacities': np.zeros((seg.shape[0], 1)),\n        'log_scales': np.tile(np.log(np.sqrt(mean3_sq_dist))[..., None], (1, 3)),\n        'cam_m': np.zeros((max_cams, 3)),\n        'cam_c': np.zeros((max_cams, 3)),\n    }\n    params = {k: torch.nn.Parameter(torch.tensor(v).cuda().float().contiguous().requires_grad_(True)) for k, v in\n              params.items()}\n    cam_centers = np.linalg.inv(md['w2c'][0])[:, :3, 3]  # Get scene radius\n    scene_radius = 1.1 * np.max(np.linalg.norm(cam_centers - np.mean(cam_centers, 0)[None], axis=-1))\n    variables = {'max_2D_radius': torch.zeros(params['means3D'].shape[0]).cuda().float(),\n                 'scene_radius': scene_radius,\n                 'means2D_gradient_accum': torch.zeros(params['means3D'].shape[0]).cuda().float(),\n                 'denom': torch.zeros(params['means3D'].shape[0]).cuda().float()}\n    return params, variables\n\n\ndef initialize_optimizer(params, variables):\n    lrs = {\n        'means3D': 0.00016 * variables['scene_radius'],\n        'rgb_colors': 0.0025,\n        'seg_colors': 0.0,\n        'unnorm_rotations': 0.001,\n        'logit_opacities': 0.05,\n        'log_scales': 0.001,\n        'cam_m': 1e-4,\n        'cam_c': 1e-4,\n    }\n    param_groups = [{'params': [v], 'name': k, 'lr': lrs[k]} for k, v in params.items()]\n    return torch.optim.Adam(param_groups, lr=0.0, eps=1e-15)\n\n\ndef get_loss(params, curr_data, variables, is_initial_timestep):\n    losses = {}\n\n    rendervar = params2rendervar(params)\n    rendervar['means2D'].retain_grad()\n    im, radius, _, = 
Renderer(raster_settings=curr_data['cam'])(**rendervar)\n    curr_id = curr_data['id']\n    im = torch.exp(params['cam_m'][curr_id])[:, None, None] * im + params['cam_c'][curr_id][:, None, None]\n    losses['im'] = 0.8 * l1_loss_v1(im, curr_data['im']) + 0.2 * (1.0 - calc_ssim(im, curr_data['im']))\n    variables['means2D'] = rendervar['means2D']  # Gradient only accum from colour render for densification\n\n    segrendervar = params2rendervar(params)\n    segrendervar['colors_precomp'] = params['seg_colors']\n    seg, _, _, = Renderer(raster_settings=curr_data['cam'])(**segrendervar)\n    losses['seg'] = 0.8 * l1_loss_v1(seg, curr_data['seg']) + 0.2 * (1.0 - calc_ssim(seg, curr_data['seg']))\n\n    if not is_initial_timestep:\n        is_fg = (params['seg_colors'][:, 0] > 0.5).detach()\n        fg_pts = rendervar['means3D'][is_fg]\n        fg_rot = rendervar['rotations'][is_fg]\n\n        rel_rot = quat_mult(fg_rot, variables[\"prev_inv_rot_fg\"])\n        rot = build_rotation(rel_rot)\n        neighbor_pts = fg_pts[variables[\"neighbor_indices\"]]\n        curr_offset = neighbor_pts - fg_pts[:, None]\n        curr_offset_in_prev_coord = (rot.transpose(2, 1)[:, None] @ curr_offset[:, :, :, None]).squeeze(-1)\n        losses['rigid'] = weighted_l2_loss_v2(curr_offset_in_prev_coord, variables[\"prev_offset\"],\n                                              variables[\"neighbor_weight\"])\n\n        losses['rot'] = weighted_l2_loss_v2(rel_rot[variables[\"neighbor_indices\"]], rel_rot[:, None],\n                                            variables[\"neighbor_weight\"])\n\n        curr_offset_mag = torch.sqrt((curr_offset ** 2).sum(-1) + 1e-20)\n        losses['iso'] = weighted_l2_loss_v1(curr_offset_mag, variables[\"neighbor_dist\"], variables[\"neighbor_weight\"])\n\n        losses['floor'] = torch.clamp(fg_pts[:, 1], min=0).mean()\n\n        bg_pts = rendervar['means3D'][~is_fg]\n        bg_rot = rendervar['rotations'][~is_fg]\n        losses['bg'] = l1_loss_v2(bg_pts, variables[\"init_bg_pts\"]) + l1_loss_v2(bg_rot, variables[\"init_bg_rot\"])\n\n        losses['soft_col_cons'] = l1_loss_v2(params['rgb_colors'], variables[\"prev_col\"])\n\n    loss_weights = {'im': 1.0, 'seg': 3.0, 'rigid': 4.0, 'rot': 4.0, 'iso': 2.0, 'floor': 2.0, 'bg': 20.0,\n                    'soft_col_cons': 0.01}\n    loss = sum([loss_weights[k] * v for k, v in losses.items()])\n    seen = radius > 0\n    variables['max_2D_radius'][seen] = torch.max(radius[seen], variables['max_2D_radius'][seen])\n    variables['seen'] = seen\n    return loss, variables\n\n\ndef initialize_per_timestep(params, variables, optimizer):\n    pts = params['means3D']\n    rot = torch.nn.functional.normalize(params['unnorm_rotations'])\n    new_pts = pts + (pts - variables[\"prev_pts\"])\n    new_rot = torch.nn.functional.normalize(rot + (rot - variables[\"prev_rot\"]))\n\n    is_fg = params['seg_colors'][:, 0] > 0.5\n    prev_inv_rot_fg = rot[is_fg]\n    prev_inv_rot_fg[:, 1:] = -1 * prev_inv_rot_fg[:, 1:]\n    fg_pts = pts[is_fg]\n    prev_offset = fg_pts[variables[\"neighbor_indices\"]] - fg_pts[:, None]\n    variables['prev_inv_rot_fg'] = prev_inv_rot_fg.detach()\n    variables['prev_offset'] = prev_offset.detach()\n    variables[\"prev_col\"] = params['rgb_colors'].detach()\n    variables[\"prev_pts\"] = pts.detach()\n    variables[\"prev_rot\"] = rot.detach()\n\n    new_params = {'means3D': new_pts, 'unnorm_rotations': new_rot}\n    params = update_params_and_optimizer(new_params, params, optimizer)\n\n    return params, 
variables\n\n\ndef initialize_post_first_timestep(params, variables, optimizer, num_knn=20):\n    is_fg = params['seg_colors'][:, 0] > 0.5\n    init_fg_pts = params['means3D'][is_fg]\n    init_bg_pts = params['means3D'][~is_fg]\n    init_bg_rot = torch.nn.functional.normalize(params['unnorm_rotations'][~is_fg])\n    neighbor_sq_dist, neighbor_indices = o3d_knn(init_fg_pts.detach().cpu().numpy(), num_knn)\n    neighbor_weight = np.exp(-2000 * neighbor_sq_dist)\n    neighbor_dist = np.sqrt(neighbor_sq_dist)\n    variables[\"neighbor_indices\"] = torch.tensor(neighbor_indices).cuda().long().contiguous()\n    variables[\"neighbor_weight\"] = torch.tensor(neighbor_weight).cuda().float().contiguous()\n    variables[\"neighbor_dist\"] = torch.tensor(neighbor_dist).cuda().float().contiguous()\n\n    variables[\"init_bg_pts\"] = init_bg_pts.detach()\n    variables[\"init_bg_rot\"] = init_bg_rot.detach()\n    variables[\"prev_pts\"] = params['means3D'].detach()\n    variables[\"prev_rot\"] = torch.nn.functional.normalize(params['unnorm_rotations']).detach()\n    params_to_fix = ['logit_opacities', 'log_scales', 'cam_m', 'cam_c']\n    for param_group in optimizer.param_groups:\n        if param_group[\"name\"] in params_to_fix:\n            param_group['lr'] = 0.0\n    return variables\n\n\ndef report_progress(params, data, i, progress_bar, every_i=100):\n    if i % every_i == 0:\n        im, _, _, = Renderer(raster_settings=data['cam'])(**params2rendervar(params))\n        curr_id = data['id']\n        im = torch.exp(params['cam_m'][curr_id])[:, None, None] * im + params['cam_c'][curr_id][:, None, None]\n        psnr = calc_psnr(im, data['im']).mean()\n        progress_bar.set_postfix({\"train img 0 PSNR\": f\"{psnr:.{7}f}\"})\n        progress_bar.update(every_i)\n\n\ndef train(seq, exp):\n    if os.path.exists(f\"./output/{exp}/{seq}\"):\n        print(f\"Experiment '{exp}' for sequence '{seq}' already exists. 
Exiting.\")\n        return\n    md = json.load(open(f\"./data/{seq}/train_meta.json\", 'r'))  # metadata\n    num_timesteps = len(md['fn'])\n    params, variables = initialize_params(seq, md)\n    optimizer = initialize_optimizer(params, variables)\n    output_params = []\n    for t in range(num_timesteps):\n        dataset = get_dataset(t, md, seq)\n        todo_dataset = []\n        is_initial_timestep = (t == 0)\n        if not is_initial_timestep:\n            params, variables = initialize_per_timestep(params, variables, optimizer)\n        num_iter_per_timestep = 10000 if is_initial_timestep else 2000\n        progress_bar = tqdm(range(num_iter_per_timestep), desc=f\"timestep {t}\")\n        for i in range(num_iter_per_timestep):\n            curr_data = get_batch(todo_dataset, dataset)\n            loss, variables = get_loss(params, curr_data, variables, is_initial_timestep)\n            loss.backward()\n            with torch.no_grad():\n                report_progress(params, dataset[0], i, progress_bar)\n                if is_initial_timestep:\n                    params, variables = densify(params, variables, optimizer, i)\n                optimizer.step()\n                optimizer.zero_grad(set_to_none=True)\n        progress_bar.close()\n        output_params.append(params2cpu(params, is_initial_timestep))\n        if is_initial_timestep:\n            variables = initialize_post_first_timestep(params, variables, optimizer)\n    save_params(output_params, seq, exp)\n\n\nif __name__ == \"__main__\":\n    exp_name = \"exp1\"\n    for sequence in [\"basketball\", \"boxes\", \"football\", \"juggle\", \"softball\", \"tennis\"]:\n        train(sequence, exp_name)\n        torch.cuda.empty_cache()\n"
  },
  {
    "path": "mvtracker/models/core/dynamic3dgs/visualize.py",
    "content": "import json\nimport os\nfrom pathlib import Path\n\nimport matplotlib\nimport numpy as np\nimport rerun as rr\nimport torch\nfrom PIL import Image\nfrom diff_gaussian_rasterization import GaussianRasterizer as Renderer\n\nfrom .helpers import setup_camera\n\nRENDER_MODE = 'color'  # 'color', 'depth' or 'centers'\n# RENDER_MODE = 'depth'  # 'color', 'depth' or 'centers'\n# RENDER_MODE = 'centers'  # 'color', 'depth' or 'centers'\n\nREMOVE_BACKGROUND = False  # False or True\n# REMOVE_BACKGROUND = True  # False or True\n\nFORCE_LOOP = False  # False or True\n# FORCE_LOOP = True  # False or True\n\n\nw, h = 640, 360\nnear, far = 0.01, 100.0\ntraj_frac = 200  # 0.5% of points\n# VIEWS = [1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]\nVIEWS = [1, 14]\nlog_rgb = True\nlog_d3dgs_rgb = False\nlog_d3dgs_depth = False\nlog_d3dgs_point_cloud = True\nlog_tracks = True\nlog_n_skip_view = 1\nlog_n_skip_t = 1\n\n\ndef load_scene_data(params_path, seg_as_col=False):\n    \"\"\"Load 3D scene data from file.\"\"\"\n    params = dict(np.load(params_path, allow_pickle=True))\n    params = {k: torch.tensor(v).cuda().float() for k, v in params.items()}\n    is_fg = params['seg_colors'][:, 0] > 0.5\n    scene_data = []\n    for t in range(len(params['means3D'])):\n        rendervar = {\n            'means3D': params['means3D'][t],\n            'colors_precomp': params['rgb_colors'][t] if not seg_as_col else params['seg_colors'],\n            'rotations': torch.nn.functional.normalize(params['unnorm_rotations'][t]),\n            'opacities': torch.sigmoid(params['logit_opacities']),\n            'scales': torch.exp(params['log_scales']),\n            'means2D': torch.zeros_like(params['means3D'][0], device=\"cuda\")\n        }\n        if REMOVE_BACKGROUND:\n            rendervar = {k: v[is_fg] for k, v in rendervar.items()}\n        scene_data.append(rendervar)\n    if REMOVE_BACKGROUND:\n        is_fg = is_fg[is_fg]\n    return scene_data, is_fg\n\n\ndef render(w2c, k, timestep_data):\n    \"\"\"Render scene using Gaussian Rasterization.\"\"\"\n    with torch.no_grad():\n        cam = setup_camera(w, h, k, w2c, near, far)\n        im, _, depth = Renderer(raster_settings=cam)(**timestep_data)\n        return im, depth\n\n\ndef log_tracks_to_rerun(\n        tracks: np.ndarray,\n        visibles: np.ndarray,\n        query_timestep: np.ndarray,\n        colors: np.ndarray,\n        track_names=None,\n\n        entity_format_str=\"{}\",\n\n        log_points=True,\n        points_radii=0.01,\n        invisible_color=[0., 0., 0.],\n\n        log_line_strips=True,\n        max_strip_length_past=30,\n        max_strip_length_future=1,\n        strips_radii=0.001,\n\n        log_error_lines=False,\n        error_lines_radii=0.0042,\n        error_lines_color=[1., 0., 0.],\n        gt_for_error_lines=None,\n\n        fps=30,\n) -> None:\n    \"\"\"\n    Log tracks to Rerun.\n\n    Parameters:\n        tracks: Shape (T, N, 3), the 3D trajectories of points.\n        visibles: Shape (T, N), boolean visibility mask for each point at each timestep.\n        query_timestep: Shape (T, N), the frame index after which the tracks start.\n        colors: Shape (N, 4), RGBA colors for each point.\n        entity_prefix: String prefix for entity hierarchy in Rerun.\n        entity_suffix: String suffix for entity hierarchy in Rerun.\n    \"\"\"\n\n    T, N, _ = tracks.shape\n    assert tracks.shape == (T, N, 3)\n    assert visibles.shape == (T, N)\n    assert 
query_timestep.shape == (N,)\n    assert query_timestep.min() >= 0\n    assert query_timestep.max() < T\n    assert colors.shape == (N, 4)\n\n    for n in range(N):\n        track_name = track_names[n] if track_names is not None else f\"track-{n}\"\n        rr.log(entity_format_str.format(track_name), rr.Clear(recursive=True))\n        for t in range(query_timestep[n], T):\n            rr.set_time_seconds(\"frame\", t / fps)\n\n            # Log the point (special handling for invisible points)\n            if log_points:\n                rr.log(\n                    entity_format_str.format(f\"{track_name}/point\"),\n                    rr.Points3D(\n                        positions=[tracks[t, n]],\n                        colors=[colors[n, :3]] if visibles[t, n] else [invisible_color],\n                        radii=points_radii,\n                    ),\n                )\n\n            # Log line segments for visible tracks\n            if log_line_strips and t > query_timestep[n]:\n                strip_t_start = max(t - max_strip_length_past, query_timestep[n].item())\n                strip_t_end = min(t + max_strip_length_future, T - 1)\n\n                strips = np.stack([\n                    tracks[strip_t_start:strip_t_end, n],\n                    tracks[strip_t_start + 1:strip_t_end + 1, n],\n                ], axis=-2)\n                strips_visibility = visibles[strip_t_start + 1:strip_t_end + 1, n]\n                strips_colors = np.where(\n                    strips_visibility[:, None],\n                    colors[None, n, :3],\n                    [invisible_color],\n                )\n\n                rr.log(\n                    entity_format_str.format(f\"{track_name}/line\"),\n                    rr.LineStrips3D(strips=strips, colors=strips_colors, radii=strips_radii),\n                )\n\n            if log_error_lines:\n                assert gt_for_error_lines is not None\n                strips = np.stack([\n                    tracks[t, n],\n                    gt_for_error_lines[t, n],\n                ], axis=-2)\n                rr.log(\n                    entity_format_str.format(f\"{track_name}/error\"),\n                    rr.LineStrips3D(strips=strips, colors=error_lines_color, radii=error_lines_radii),\n                )\n\n\ndef visualize(seq, exp):\n    \"\"\"Visualize 3D Gaussian Splatting using Rerun.\"\"\"\n    scene_root = Path(f\"../datasets/panoptic_d3dgs/{seq}\")\n    output_root = Path(f\"./output/{exp}/{seq}\")\n    scene_data, is_fg = load_scene_data(os.path.join(output_root, \"params.npz\"))\n    md = json.load(open(os.path.join(scene_root, \"train_meta.json\"), \"r\"))\n\n    n_frames = len(md['fn'])\n    n_views = len(VIEWS)\n\n    # Check that the selected views are in the training set\n    view_paths = []\n    for view_idx in VIEWS:\n        view_path = scene_root / \"ims\" / f\"{view_idx}\"\n        assert view_idx in md[\"cam_id\"][0], f\"Camera {view_idx} is not in the training set\"\n        assert view_path.exists()\n        view_paths.append(view_path)\n    frame_paths = [sorted(view_path.glob(\"*.jpg\")) for view_path in view_paths]\n    assert all(len(frame_paths[v]) == n_frames for v in range(len(VIEWS)))\n    assert len(scene_data) == n_frames\n\n    # Create the output directory\n    views_selection_str = '-'.join(str(v) for v in VIEWS)\n    output_path = scene_root / f'dynamic3dgs-views-{views_selection_str}'\n    os.makedirs(output_path, exist_ok=True)\n\n    # Load the camera parameters\n    fx, fy, cx, cy, extrinsics 
= [], [], [], [], []\n    for view_idx in VIEWS:\n        fx_current, fy_current, cx_current, cy_current, extrinsics_current = [], [], [], [], []\n        for t in range(n_frames):\n            view_idx_in_array = md['cam_id'][t].index(view_idx)\n            k = md['k'][t][view_idx_in_array]\n            w2c = np.array(md['w2c'][t][view_idx_in_array])\n\n            fx_current.append(k[0][0])\n            fy_current.append(k[1][1])\n            cx_current.append(k[0][2])\n            cy_current.append(k[1][2])\n            extrinsics_current.append(w2c)\n\n        assert all(np.equal(fx_current[0], fx_current[t]).all() for t in range(1, n_frames))\n        assert all(np.equal(fy_current[0], fy_current[t]).all() for t in range(1, n_frames))\n        assert all(np.equal(cx_current[0], cx_current[t]).all() for t in range(1, n_frames))\n        assert all(np.equal(cy_current[0], cy_current[t]).all() for t in range(1, n_frames))\n        assert all(np.equal(extrinsics_current[0], extrinsics_current[t]).all() for t in range(1, n_frames))\n\n        fx.append(fx_current[0])\n        fy.append(fy_current[0])\n        cx.append(cx_current[0])\n        cy.append(cy_current[0])\n        extrinsics.append(extrinsics_current[0])\n\n    fx = torch.tensor(fx).float()\n    fy = torch.tensor(fy).float()\n    cx = torch.tensor(cx).float()\n    cy = torch.tensor(cy).float()\n    k = torch.eye(3).float()[None].repeat(n_views, 1, 1)\n    k[:, 0, 0] = fx\n    k[:, 1, 1] = fy\n    k[:, 0, 2] = cx\n    k[:, 1, 2] = cy\n    extrinsics = torch.from_numpy(np.stack(extrinsics)).float()\n    k_inv = torch.inverse(k)\n    extrinsics_inv = torch.inverse(extrinsics)\n\n    # Render the depths\n    rgbs = np.stack([\n        np.stack([\n            np.array(Image.open(frame_paths[v][t]))\n            for t in range(n_frames)\n        ])\n        for v in range(n_views)\n    ])\n    h, w = rgbs.shape[2], rgbs.shape[3]\n    d3dgs_rgbs = []\n    d3dgs_depths = []\n    for v, view_idx in enumerate(VIEWS):\n        for t in range(n_frames):\n            im, depth = render(extrinsics[v].numpy(), k[v].numpy(), scene_data[t])\n            d3dgs_rgbs.append(im.cpu().numpy().transpose(1, 2, 0))\n            d3dgs_depths.append(depth.cpu().numpy()[0])\n    d3dgs_rgbs = np.stack(d3dgs_rgbs).reshape(n_views, n_frames, h, w, 3)\n    d3dgs_depths = np.stack(d3dgs_depths).reshape(n_views, n_frames, h, w)\n\n    assert rgbs.shape == (n_views, n_frames, h, w, 3)\n    assert d3dgs_rgbs.shape == (n_views, n_frames, h, w, 3)\n    assert d3dgs_depths.shape == (n_views, n_frames, h, w)\n\n    gt_tracks = np.stack([data['means3D'][is_fg][::traj_frac].contiguous().cpu().numpy() for data in scene_data])\n    n_tracks = gt_tracks.shape[1]\n    gt_vis = np.ones((n_frames, n_tracks), dtype=bool)\n    query_timestep = gt_vis.argmin(0)\n    assert gt_tracks.shape == (n_frames, n_tracks, 3)\n    assert gt_vis.shape == (n_frames, n_tracks)\n\n    cmap = matplotlib.colormaps[\"gist_rainbow\"]\n    norm = matplotlib.colors.Normalize(vmin=gt_tracks[..., 0].min(), vmax=gt_tracks[..., 0].max())\n    track_colors = cmap(norm(gt_tracks[-1, :, 0]))\n    assert track_colors.shape == (n_tracks, 4)\n\n    rr.init(\"reconstruction\", recording_id=\"v0.1\")\n    rr.connect_tcp()\n    rr.set_time_seconds(\"frame\", 0)\n    rr.log(\"world\", rr.ViewCoordinates.RIGHT_HAND_Z_UP, static=True)\n    rr.log(\"world/xyz\", rr.Arrows3D(vectors=[[1, 0, 0], [0, 2, 0], [0, 0, 3]],\n                                    colors=[[255, 0, 0], [0, 255, 0], [0, 0, 255]]))\n    for t in 
range(0, n_frames, log_n_skip_t):\n        for v in range(0, n_views, log_n_skip_view):\n            rr.set_time_seconds(\"frame\", t / 30)\n            if log_rgb:\n                rr.log(f\"{seq}/rgb/view-{VIEWS[v]}/rgb\",\n                       rr.Image(rgbs[v, t]))\n                rr.log(f\"{seq}/rgb/view-{VIEWS[v]}\",\n                       rr.Pinhole(image_from_camera=k[v].numpy(), width=w, height=h))\n                rr.log(f\"{seq}/rgb/view-{VIEWS[v]}\",\n                       rr.Transform3D(translation=extrinsics_inv[v, :3, 3].numpy(),\n                                      mat3x3=extrinsics_inv[v, :3, :3].numpy()))\n            if log_d3dgs_rgb:\n                rr.log(f\"{seq}/dyn-3dgs-rgb/view-{VIEWS[v]}/rgb\",\n                       rr.Image(d3dgs_rgbs[v, t]))\n                rr.log(f\"{seq}/dyn-3dgs-rgb/view-{VIEWS[v]}\",\n                       rr.Pinhole(image_from_camera=k[v].numpy(), width=w, height=h))\n                rr.log(f\"{seq}/dyn-3dgs-rgb/view-{VIEWS[v]}\",\n                       rr.Transform3D(translation=extrinsics_inv[v, :3, 3].numpy(),\n                                      mat3x3=extrinsics_inv[v, :3, :3].numpy()))\n            if log_d3dgs_depth:\n                rr.log(f\"{seq}/dyn-3dgs-depth/view-{VIEWS[v]}/depth\",\n                       rr.DepthImage(d3dgs_depths[v, t], point_fill_ratio=0.2))\n                rr.log(f\"{seq}/dyn-3dgs-depth/view-{VIEWS[v]}\",\n                       rr.Pinhole(image_from_camera=k[v].numpy(), width=w, height=h))\n                rr.log(f\"{seq}/dyn-3dgs-depth/view-{VIEWS[v]}\",\n                       rr.Transform3D(translation=extrinsics_inv[v, :3, 3].numpy(),\n                                      mat3x3=extrinsics_inv[v, :3, :3].numpy()))\n            if log_d3dgs_point_cloud:\n                y, x = np.indices((h, w))\n                homo_pixel_coords = np.stack([x.ravel(), y.ravel(), np.ones_like(x).ravel()], axis=1).T\n                depth_values = d3dgs_depths[v, t].ravel()\n                cam_coords = (k_inv[v] @ homo_pixel_coords) * depth_values\n                cam_coords = np.vstack((cam_coords, np.ones((1, cam_coords.shape[1]))))\n                world_coords = (extrinsics_inv[v] @ cam_coords)[:3].T\n                valid_mask = depth_values > 0\n                world_coords = world_coords[valid_mask]\n                rgb_colors = rgbs[v, t].reshape(-1, 3)[valid_mask].astype(np.uint8)\n                rr.log(f\"{seq}/dyn-3dgs-point-cloud/view-{v}\", rr.Points3D(world_coords, colors=rgb_colors, radii=0.01))\n    if log_tracks:\n        for tracks_batch_start in range(0, n_tracks, 100):\n            tracks_batch_end = min(tracks_batch_start + 100, n_tracks)\n            log_tracks_to_rerun(\n                tracks=gt_tracks[:, tracks_batch_start:tracks_batch_end],\n                visibles=gt_vis[:, tracks_batch_start:tracks_batch_end],\n                query_timestep=query_timestep[tracks_batch_start:tracks_batch_end],\n                colors=track_colors[tracks_batch_start:tracks_batch_end],\n                track_names=[f\"track-{i:02d}\" for i in range(tracks_batch_start, tracks_batch_end)],\n                entity_format_str=f\"{seq}/dyn-3dgs-tracks/{tracks_batch_start}-{tracks_batch_end}/{{}}\",\n                invisible_color=[0.3, 0.3, 0.3],\n            )\n\n    print(\"Done with visualization.\")\n\n\nif __name__ == \"__main__\":\n    exp_name = \"pretrained\"\n    for sequence in [\"basketball\", \"boxes\", \"football\", \"juggle\", \"softball\", \"tennis\"]:\n        visualize(sequence, 
exp_name)\n"
  },
  {
    "path": "mvtracker/models/core/embeddings.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\nimport numpy as np\nimport torch\n\n\ndef get_3d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):\n    \"\"\"\n    grid_size: int of the grid height and width\n    return:\n    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)\n    \"\"\"\n    if isinstance(grid_size, tuple):\n        grid_size_h, grid_size_w = grid_size\n    else:\n        grid_size_h = grid_size_w = grid_size\n    grid_h = np.arange(grid_size_h, dtype=np.float32)\n    grid_w = np.arange(grid_size_w, dtype=np.float32)\n    grid = np.meshgrid(grid_w, grid_h)  # here w goes first\n    grid = np.stack(grid, axis=0)\n\n    grid = grid.reshape([2, 1, grid_size_h, grid_size_w])\n    pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, grid)\n    if cls_token and extra_tokens > 0:\n        pos_embed = np.concatenate(\n            [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0\n        )\n    return pos_embed\n\n\ndef get_3d_sincos_pos_embed_from_grid(embed_dim, grid):\n    assert embed_dim % 3 == 0\n\n    # use half of dimensions to encode grid_h\n    B, S, N, _ = grid.shape\n    gridx = grid[..., 0].view(B * S * N).detach().cpu().numpy()\n    gridy = grid[..., 1].view(B * S * N).detach().cpu().numpy()\n    gridz = grid[..., 2].view(B * S * N).detach().cpu().numpy()\n\n    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, gridx)  # (N, D/3)\n    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, gridy)  # (N, D/3)\n    emb_z = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, gridz)  # (N, D/3)\n\n    emb = np.concatenate([emb_h, emb_w, emb_z], axis=1)  # (N, D)\n    emb = torch.from_numpy(emb).to(grid.device)\n    return emb.view(B, S, N, embed_dim)\n\n\ndef get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):\n    \"\"\"\n    grid_size: int of the grid height and width\n    return:\n    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)\n    \"\"\"\n    if isinstance(grid_size, tuple):\n        grid_size_h, grid_size_w = grid_size\n    else:\n        grid_size_h = grid_size_w = grid_size\n    grid_h = np.arange(grid_size_h, dtype=np.float32)\n    grid_w = np.arange(grid_size_w, dtype=np.float32)\n    grid = np.meshgrid(grid_w, grid_h)  # here w goes first\n    grid = np.stack(grid, axis=0)\n\n    grid = grid.reshape([2, 1, grid_size_h, grid_size_w])\n    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)\n    if cls_token and extra_tokens > 0:\n        pos_embed = np.concatenate(\n            [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0\n        )\n    return pos_embed\n\n\ndef get_2d_sincos_pos_embed_from_grid(embed_dim, grid):\n    assert embed_dim % 2 == 0\n\n    # use half of dimensions to encode grid_h\n    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)\n    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)\n\n    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)\n    return emb\n\n\ndef get_1d_sincos_pos_embed_from_grid(embed_dim, pos):\n    \"\"\"\n    embed_dim: output dimension for each position\n    pos: a list of positions to be encoded: size (M,)\n    out: (M, D)\n    \"\"\"\n    assert embed_dim % 2 == 0\n    omega = 
np.arange(embed_dim // 2, dtype=np.float64)\n    omega /= embed_dim / 2.0\n    omega = 1.0 / 10000 ** omega  # (D/2,)\n\n    pos = pos.reshape(-1)  # (M,)\n    out = np.einsum(\"m,d->md\", pos, omega)  # (M, D/2), outer product\n\n    emb_sin = np.sin(out)  # (M, D/2)\n    emb_cos = np.cos(out)  # (M, D/2)\n\n    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)\n    return emb\n\n\ndef get_2d_embedding(xy, C, cat_coords=True):\n    B, N, D = xy.shape\n    assert D == 2\n\n    x = xy[:, :, 0:1]\n    y = xy[:, :, 1:2]\n    div_term = (\n            torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)\n    ).reshape(1, 1, int(C / 2))\n\n    pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)\n    pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)\n\n    pe_x[:, :, 0::2] = torch.sin(x * div_term)\n    pe_x[:, :, 1::2] = torch.cos(x * div_term)\n\n    pe_y[:, :, 0::2] = torch.sin(y * div_term)\n    pe_y[:, :, 1::2] = torch.cos(y * div_term)\n\n    pe = torch.cat([pe_x, pe_y], dim=2)  # B, N, C*3\n    if cat_coords:\n        pe = torch.cat([xy, pe], dim=2)  # B, N, C*3+3\n    return pe\n\n\ndef get_3d_embedding(xyz, C, cat_coords=True):\n    B, N, D = xyz.shape\n    assert D == 3\n\n    x = xyz[:, :, 0:1]\n    y = xyz[:, :, 1:2]\n    z = xyz[:, :, 2:3]\n    div_term = (\n            torch.arange(0, C, 2, device=xyz.device, dtype=torch.float32) * (1000.0 / C)\n    ).reshape(1, 1, int(C / 2))\n\n    pe_x = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)\n    pe_y = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)\n    pe_z = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)\n\n    pe_x[:, :, 0::2] = torch.sin(x * div_term)\n    pe_x[:, :, 1::2] = torch.cos(x * div_term)\n\n    pe_y[:, :, 0::2] = torch.sin(y * div_term)\n    pe_y[:, :, 1::2] = torch.cos(y * div_term)\n\n    pe_z[:, :, 0::2] = torch.sin(z * div_term)\n    pe_z[:, :, 1::2] = torch.cos(z * div_term)\n\n    pe = torch.cat([pe_x, pe_y, pe_z], dim=2)  # B, N, C*3\n    if cat_coords:\n        pe = torch.cat([pe, xyz], dim=2)  # B, N, C*3+3\n    return pe\n\n\ndef get_4d_embedding(xyzw, C, cat_coords=True):\n    B, N, D = xyzw.shape\n    assert D == 4\n\n    x = xyzw[:, :, 0:1]\n    y = xyzw[:, :, 1:2]\n    z = xyzw[:, :, 2:3]\n    w = xyzw[:, :, 3:4]\n    div_term = (\n            torch.arange(0, C, 2, device=xyzw.device, dtype=torch.float32) * (1000.0 / C)\n    ).reshape(1, 1, int(C / 2))\n\n    pe_x = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)\n    pe_y = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)\n    pe_z = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)\n    pe_w = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)\n\n    pe_x[:, :, 0::2] = torch.sin(x * div_term)\n    pe_x[:, :, 1::2] = torch.cos(x * div_term)\n\n    pe_y[:, :, 0::2] = torch.sin(y * div_term)\n    pe_y[:, :, 1::2] = torch.cos(y * div_term)\n\n    pe_z[:, :, 0::2] = torch.sin(z * div_term)\n    pe_z[:, :, 1::2] = torch.cos(z * div_term)\n\n    pe_w[:, :, 0::2] = torch.sin(w * div_term)\n    pe_w[:, :, 1::2] = torch.cos(w * div_term)\n\n    pe = torch.cat([pe_x, pe_y, pe_z, pe_w], dim=2)  # B, N, C*3\n    if cat_coords:\n        pe = torch.cat([pe, xyzw], dim=2)  # B, N, C*3+3\n    return pe\n\n\nimport torch.nn as nn\n\n\nclass Embedder_Fourier(nn.Module):\n    def __init__(self, input_dim, max_freq_log2, N_freqs,\n                 log_sampling=True, include_input=True,\n                 
periodic_fns=(torch.sin, torch.cos)):\n        '''\n        :param input_dim: dimension of input to be embedded\n        :param max_freq_log2: log2 of max freq; min freq is 1 by default\n        :param N_freqs: number of frequency bands\n        :param log_sampling: if True, frequency bands are linerly sampled in log-space\n        :param include_input: if True, raw input is included in the embedding\n        :param periodic_fns: periodic functions used to embed input\n        '''\n        super(Embedder_Fourier, self).__init__()\n\n        self.input_dim = input_dim\n        self.include_input = include_input\n        self.periodic_fns = periodic_fns\n\n        self.out_dim = 0\n        if self.include_input:\n            self.out_dim += self.input_dim\n\n        self.out_dim += self.input_dim * N_freqs * len(self.periodic_fns)\n\n        if log_sampling:\n            self.freq_bands = 2. ** torch.linspace(0., max_freq_log2, N_freqs)\n        else:\n            self.freq_bands = torch.linspace(\n                2. ** 0., 2. ** max_freq_log2, N_freqs)\n\n        self.freq_bands = self.freq_bands.numpy().tolist()\n\n    def forward(self,\n                input: torch.Tensor,\n                rescale: float = 1.0):\n        '''\n        :param input: tensor of shape [..., self.input_dim]\n        :return: tensor of shape [..., self.out_dim]\n        '''\n        assert (input.shape[-1] == self.input_dim)\n        out = []\n        if self.include_input:\n            out.append(input / rescale)\n\n        for i in range(len(self.freq_bands)):\n            freq = self.freq_bands[i]\n            for p_fn in self.periodic_fns:\n                out.append(p_fn(input.float() * freq).type_as(input))\n        out = torch.cat(out, dim=-1)\n\n        assert not input.isnan().any(), f\"Found NaN in input\"\n        assert not out.isnan().any(), f\"Found NaN in output\"\n\n        assert (out.shape[-1] == self.out_dim)\n        return out\n"
  },
  {
    "path": "mvtracker/models/core/loftr/__init__.py",
    "content": "from .transformer import LocalFeatureTransformer\n"
  },
  {
    "path": "mvtracker/models/core/loftr/linear_attention.py",
    "content": "\"\"\"\nLinear Transformer proposed in \"Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention\"\nModified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py\n\"\"\"\n\nimport torch\nfrom torch.nn import Module, Dropout\n\n\ndef elu_feature_map(x):\n    return torch.nn.functional.elu(x) + 1\n\n\nclass LinearAttention(Module):\n    def __init__(self, eps=1e-6):\n        super().__init__()\n        self.feature_map = elu_feature_map\n        self.eps = eps\n\n    def forward(self, queries, keys, values, q_mask=None, kv_mask=None):\n        \"\"\" Multi-Head linear attention proposed in \"Transformers are RNNs\"\n        Args:\n            queries: [N, L, H, D]\n            keys: [N, S, H, D]\n            values: [N, S, H, D]\n            q_mask: [N, L]\n            kv_mask: [N, S]\n        Returns:\n            queried_values: (N, L, H, D)\n        \"\"\"\n        Q = self.feature_map(queries)\n        K = self.feature_map(keys)\n\n        # set padded position to zero\n        if q_mask is not None:\n            Q = Q * q_mask[:, :, None, None]\n        if kv_mask is not None:\n            K = K * kv_mask[:, :, None, None]\n            values = values * kv_mask[:, :, None, None]\n\n        v_length = values.size(1)\n        values = values / v_length  # prevent fp16 overflow\n        KV = torch.einsum(\"nshd,nshv->nhdv\", K, values)  # (S,D)' @ S,V\n        Z = 1 / (torch.einsum(\"nlhd,nhd->nlh\", Q, K.sum(dim=1)) + self.eps)\n        queried_values = torch.einsum(\"nlhd,nhdv,nlh->nlhv\", Q, KV, Z) * v_length\n\n        return queried_values.contiguous()\n\n\nclass FullAttention(Module):\n    def __init__(self, use_dropout=False, attention_dropout=0.1):\n        super().__init__()\n        self.use_dropout = use_dropout\n        self.dropout = Dropout(attention_dropout)\n\n    def forward(self, queries, keys, values, q_mask=None, kv_mask=None):\n        \"\"\" Multi-head scaled dot-product attention, a.k.a full attention.\n        Args:\n            queries: [N, L, H, D]\n            keys: [N, S, H, D]\n            values: [N, S, H, D]\n            q_mask: [N, L]\n            kv_mask: [N, S]\n        Returns:\n            queried_values: (N, L, H, D)\n        \"\"\"\n\n        # Compute the unnormalized attention and apply the masks\n        QK = torch.einsum(\"nlhd,nshd->nlsh\", queries, keys)\n        if kv_mask is not None:\n            QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf'))\n\n        # Compute the attention and the weighted average\n        softmax_temp = 1. / queries.size(3) ** .5  # sqrt(D)\n        A = torch.softmax(softmax_temp * QK, dim=2)\n        if self.use_dropout:\n            A = self.dropout(A)\n\n        queried_values = torch.einsum(\"nlsh,nshd->nlhd\", A, values)\n\n        return queried_values.contiguous()\n"
  },
  {
    "path": "mvtracker/models/core/loftr/transformer.py",
    "content": "'''\nmodified from\nhttps://github.com/zju3dv/LoFTR/blob/master/src/loftr/loftr_module/transformer.py\n'''\nimport copy\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn import Module, Dropout\n\n\ndef elu_feature_map(x):\n    return torch.nn.functional.elu(x) + 1\n\n\nclass FullAttention(Module):\n    def __init__(self, use_dropout=False, attention_dropout=0.1):\n        super().__init__()\n        self.use_dropout = use_dropout\n        self.dropout = Dropout(attention_dropout)\n\n    def forward(self, queries, keys, values, q_mask=None, kv_mask=None):\n        \"\"\" Multi-head scaled dot-product attention, a.k.a full attention.\n        Args:\n            queries: [N, L, H, D]\n            keys: [N, S, H, D]\n            values: [N, S, H, D]\n            q_mask: [N, L]\n            kv_mask: [N, S]\n        Returns:\n            queried_values: (N, L, H, D)\n        \"\"\"\n\n        # Compute the unnormalized attention and apply the masks\n        # QK = torch.einsum(\"nlhd,nshd->nlsh\", queries, keys)\n        # if kv_mask is not None:\n        #     QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float(-1e12))\n        # softmax_temp = 1. / queries.size(3)**.5  # sqrt(D)\n        # A = torch.softmax(softmax_temp * QK, dim=2)\n        # if self.use_dropout:\n        #     A = self.dropout(A)\n        # queried_values_ = torch.einsum(\"nlsh,nshd->nlhd\", A, values)\n\n        # Compute the attention and the weighted average\n        input_args = [x.half().contiguous() for x in\n                      [queries.permute(0, 2, 1, 3), keys.permute(0, 2, 1, 3), values.permute(0, 2, 1, 3)]]\n        queried_values = F.scaled_dot_product_attention(*input_args).permute(0, 2, 1, 3).float()  # type: ignore\n\n        return queried_values.contiguous()\n\n\nclass TransformerEncoderLayer(nn.Module):\n    def __init__(self,\n                 d_model,\n                 nhead, ):\n        super(TransformerEncoderLayer, self).__init__()\n\n        self.dim = d_model // nhead\n        self.nhead = nhead\n\n        # multi-head attention\n        self.q_proj = nn.Linear(d_model, d_model, bias=False)\n        self.k_proj = nn.Linear(d_model, d_model, bias=False)\n        self.v_proj = nn.Linear(d_model, d_model, bias=False)\n        self.attention = FullAttention()\n        self.merge = nn.Linear(d_model, d_model, bias=False)\n\n        # feed-forward network\n        self.mlp = nn.Sequential(\n            nn.Linear(d_model * 2, d_model * 2, bias=False),\n            nn.ReLU(True),\n            nn.Linear(d_model * 2, d_model, bias=False),\n        )\n\n        # norm and dropout\n        self.norm1 = nn.LayerNorm(d_model)\n        self.norm2 = nn.LayerNorm(d_model)\n\n    def forward(self, x, source, x_mask=None, source_mask=None):\n        \"\"\"\n        Args:\n            x (torch.Tensor): [N, L, C]\n            source (torch.Tensor): [N, S, C]\n            x_mask (torch.Tensor): [N, L] (optional)\n            source_mask (torch.Tensor): [N, S] (optional)\n        \"\"\"\n        bs = x.size(0)\n        query, key, value = x, source, source\n\n        # multi-head attention\n        query = self.q_proj(query).view(bs, -1, self.nhead, self.dim)  # [N, L, (H, D)]\n        key = self.k_proj(key).view(bs, -1, self.nhead, self.dim)  # [N, S, (H, D)]\n        value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)\n        message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask)  # [N, L, (H, D)]\n       
 message = self.merge(message.view(bs, -1, self.nhead * self.dim))  # [N, L, C]\n        message = self.norm1(message)\n\n        # feed-forward network\n        message = self.mlp(torch.cat([x, message], dim=2))\n        message = self.norm2(message)\n\n        return x + message\n\n\nclass LocalFeatureTransformer(nn.Module):\n    \"\"\"A Local Feature Transformer module.\"\"\"\n\n    def __init__(self, config):\n        super(LocalFeatureTransformer, self).__init__()\n\n        self.config = config\n        self.d_model = config['d_model']\n        self.nhead = config['nhead']\n        self.layer_names = config['layer_names']\n        encoder_layer = TransformerEncoderLayer(config['d_model'], config['nhead'])\n        self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))])\n        self._reset_parameters()\n\n    def _reset_parameters(self):\n        for p in self.parameters():\n            if p.dim() > 1:\n                nn.init.xavier_uniform_(p)\n\n    def forward(self, feat0, feat1, mask0=None, mask1=None):\n        \"\"\"\n        Args:\n            feat0 (torch.Tensor): [N, L, C]\n            feat1 (torch.Tensor): [N, S, C]\n            mask0 (torch.Tensor): [N, L] (optional)\n            mask1 (torch.Tensor): [N, S] (optional)\n        \"\"\"\n\n        assert self.d_model == feat0.size(2), \"the feature number of src and transformer must be equal\"\n\n        for layer, name in zip(self.layers, self.layer_names):\n            if name == 'self':\n                feat0 = layer(feat0, feat0, mask0, mask0)\n                feat1 = layer(feat1, feat1, mask1, mask1)\n            elif name == 'cross':\n                feat0 = layer(feat0, feat1, mask0, mask1)\n                feat1 = layer(feat1, feat0, mask1, mask0)\n            else:\n                raise KeyError\n\n        return feat0, feat1\n"
  },
  {
    "path": "mvtracker/models/core/losses.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\nimport torch\nimport torch.nn.functional as F\n\nfrom mvtracker.models.core.model_utils import reduce_masked_mean\n\nEPS = 1e-6\n\nsigma = 3\nx_grid = torch.arange(-7, 8, 1)\ny_grid = torch.arange(-7, 8, 1)\nx_grid, y_grid = torch.meshgrid(x_grid, y_grid, indexing=\"ij\")\ngridxy = torch.stack([x_grid, y_grid], dim=-1).float()\ngs_kernel = torch.exp(-torch.sum(gridxy ** 2, dim=-1) / (2 * sigma ** 2))\n\n\ndef balanced_ce_loss(pred, gt, valid=None):\n    total_balanced_loss = 0.0\n    for j in range(len(gt)):\n        B, S, N = gt[j].shape\n        # pred and gt are the same shape\n        for (a, b) in zip(pred[j].size(), gt[j].size()):\n            assert a == b  # some shape mismatch!\n        # if valid is not None:\n        for (a, b) in zip(pred[j].size(), valid[j].size()):\n            assert a == b  # some shape mismatch!\n\n        pos = (gt[j] > 0.95).float()\n        neg = (gt[j] < 0.05).float()\n\n        label = pos * 2.0 - 1.0\n        a = -label * pred[j]\n        b = F.relu(a)\n        loss = b + torch.log(torch.exp(-b) + torch.exp(a - b))\n\n        pos_loss = reduce_masked_mean(loss, pos * valid[j])\n        neg_loss = reduce_masked_mean(loss, neg * valid[j])\n\n        balanced_loss = pos_loss + neg_loss\n        total_balanced_loss += balanced_loss\n    return total_balanced_loss\n\n\ndef sequence_loss_3d(flow_preds, flow_gt, vis, valids, gamma=0.8, dmin=0.1, dmax=65, Dz=128):\n    \"\"\"Loss function defined over sequence of flow predictions with z component post-processing\"\"\"\n    total_flow_loss = 0.0\n    J = len(flow_gt)\n    for j in range(J):\n        B, S, N, D = flow_gt[j].shape\n        assert D == 3\n        B, S1, N = vis[j].shape\n        B, S2, N = valids[j].shape\n        assert S == S1\n        assert S == S2\n        n_predictions = len(flow_preds[j])\n        flow_loss = 0.0\n        for i in range(n_predictions):\n            i_weight = gamma ** (n_predictions - i - 1)\n            flow_pred = flow_preds[j][i]\n            flow_gt_j = flow_gt[j].clone()\n            flow_pred[..., 2] = (flow_pred[..., 2] - dmin) / (dmax - dmin) * Dz\n            flow_gt_j[..., 2] = (flow_gt_j[..., 2] - dmin) / (dmax - dmin) * Dz\n            i_loss = (flow_pred - flow_gt_j).abs()  # B, S, N, 3\n            i_loss = torch.mean(i_loss, dim=3)  # B, S, N\n            flow_loss += i_weight * reduce_masked_mean(i_loss, valids[j])\n        flow_loss = flow_loss / n_predictions\n        total_flow_loss += flow_loss / float(J)\n    return total_flow_loss\n"
  },
  {
    "path": "mvtracker/models/core/model_utils.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\nimport warnings\nfrom typing import Tuple, Optional\n\nimport torch\nfrom easydict import EasyDict as edict\nfrom torch.nn import functional as F\n\nfrom mvtracker.utils.basic import to_homogeneous, from_homogeneous\n\nEPS = 1e-6\n\n\ndef smart_cat(tensor1, tensor2, dim):\n    if tensor1 is None:\n        return tensor2\n    return torch.cat([tensor1, tensor2], dim=dim)\n\n\ndef normalize_single(d):\n    # d is a whatever shape torch tensor\n    dmin = torch.min(d)\n    dmax = torch.max(d)\n    d = (d - dmin) / (EPS + (dmax - dmin))\n    return d\n\n\ndef normalize(d):\n    # d is B x whatever. normalize within each element of the batch\n    out = torch.zeros(d.size())\n    if d.is_cuda:\n        out = out.cuda()\n    B = list(d.size())[0]\n    for b in list(range(B)):\n        out[b] = normalize_single(d[b])\n    return out\n\n\ndef meshgrid2d(B, Y, X, stack=False, norm=False, device=\"cuda\"):\n    # returns a meshgrid sized B x Y x X\n\n    grid_y = torch.linspace(0.0, Y - 1, Y, device=torch.device(device))\n    grid_y = torch.reshape(grid_y, [1, Y, 1])\n    grid_y = grid_y.repeat(B, 1, X)\n\n    grid_x = torch.linspace(0.0, X - 1, X, device=torch.device(device))\n    grid_x = torch.reshape(grid_x, [1, 1, X])\n    grid_x = grid_x.repeat(B, Y, 1)\n\n    if stack:\n        # note we stack in xy order\n        # (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample)\n        grid = torch.stack([grid_x, grid_y], dim=-1)\n        return grid\n    else:\n        return grid_y, grid_x\n\n\ndef reduce_masked_mean(x, mask, dim=None, keepdim=False):\n    # x and mask are the same shape, or at least broadcastably so < actually it's safer if you disallow broadcasting\n    # returns shape-1\n    # axis can be a list of axes\n    for (a, b) in zip(x.size(), mask.size()):\n        assert a == b  # some shape mismatch!\n    prod = x * mask\n    if dim is None:\n        numer = torch.sum(prod)\n        denom = EPS + torch.sum(mask)\n    else:\n        numer = torch.sum(prod, dim=dim, keepdim=keepdim)\n        denom = EPS + torch.sum(mask, dim=dim, keepdim=keepdim)\n\n    mean = numer / denom\n    return mean\n\n\ndef bilinear_sample2d(im, x, y, return_inbounds=False):\n    # x and y are each B, N\n    # output is B, C, N\n    if len(im.shape) == 5:\n        B, N, C, H, W = list(im.shape)\n    else:\n        B, C, H, W = list(im.shape)\n    N = list(x.shape)[1]\n\n    x = x.float()\n    y = y.float()\n    H_f = torch.tensor(H, dtype=torch.float32)\n    W_f = torch.tensor(W, dtype=torch.float32)\n\n    # inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x<W_f+0.5).float()*(y<H_f+0.5).float()\n\n    max_y = (H_f - 1).int()\n    max_x = (W_f - 1).int()\n\n    x0 = torch.floor(x).int()\n    x1 = x0 + 1\n    y0 = torch.floor(y).int()\n    y1 = y0 + 1\n\n    x0_clip = torch.clamp(x0, 0, max_x)\n    x1_clip = torch.clamp(x1, 0, max_x)\n    y0_clip = torch.clamp(y0, 0, max_y)\n    y1_clip = torch.clamp(y1, 0, max_y)\n    dim2 = W\n    dim1 = W * H\n\n    base = torch.arange(0, B, dtype=torch.int64, device=x.device) * dim1\n    base = torch.reshape(base, [B, 1]).repeat([1, N])\n\n    base_y0 = base + y0_clip * dim2\n    base_y1 = base + y1_clip * dim2\n\n    idx_y0_x0 = base_y0 + x0_clip\n    idx_y0_x1 = base_y0 + x1_clip\n    idx_y1_x0 = base_y1 + x0_clip\n    
idx_y1_x1 = base_y1 + x1_clip\n\n    # use the indices to lookup pixels in the flat image\n    # im is B x C x H x W\n    # move C out to last dim\n    if len(im.shape) == 5:\n        im_flat = (im.permute(0, 3, 4, 1, 2)).reshape(B * H * W, N, C)\n        i_y0_x0 = torch.diagonal(im_flat[idx_y0_x0.long()], dim1=1, dim2=2).permute(0, 2, 1)\n        i_y0_x1 = torch.diagonal(im_flat[idx_y0_x1.long()], dim1=1, dim2=2).permute(0, 2, 1)\n        i_y1_x0 = torch.diagonal(im_flat[idx_y1_x0.long()], dim1=1, dim2=2).permute(0, 2, 1)\n        i_y1_x1 = torch.diagonal(im_flat[idx_y1_x1.long()], dim1=1, dim2=2).permute(0, 2, 1)\n    else:\n        im_flat = (im.permute(0, 2, 3, 1)).reshape(B * H * W, C)\n        i_y0_x0 = im_flat[idx_y0_x0.long()]\n        i_y0_x1 = im_flat[idx_y0_x1.long()]\n        i_y1_x0 = im_flat[idx_y1_x0.long()]\n        i_y1_x1 = im_flat[idx_y1_x1.long()]\n\n    # Finally calculate interpolated values.\n    x0_f = x0.float()\n    x1_f = x1.float()\n    y0_f = y0.float()\n    y1_f = y1.float()\n\n    w_y0_x0 = ((x1_f - x) * (y1_f - y)).unsqueeze(2)\n    w_y0_x1 = ((x - x0_f) * (y1_f - y)).unsqueeze(2)\n    w_y1_x0 = ((x1_f - x) * (y - y0_f)).unsqueeze(2)\n    w_y1_x1 = ((x - x0_f) * (y - y0_f)).unsqueeze(2)\n\n    output = (w_y0_x0 * i_y0_x0 + w_y0_x1 * i_y0_x1 + w_y1_x0 * i_y1_x0 + w_y1_x1 * i_y1_x1)\n    # output is B*N x C\n    output = output.view(B, -1, C)\n    output = output.permute(0, 2, 1)\n    # output is B x C x N\n\n    if return_inbounds:\n        x_valid = (x > -0.5).byte() & (x < float(W_f - 0.5)).byte()\n        y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte()\n        inbounds = (x_valid & y_valid).float()\n        inbounds = inbounds.reshape(\n            B, N\n        )  # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1)\n        return output, inbounds\n\n    return output  # B, C, N\n\n\ndef procrustes_analysis(X0, X1, Weight):  # [B,N,3]\n    # translation\n    t0 = X0.mean(dim=1, keepdim=True)\n    t1 = X1.mean(dim=1, keepdim=True)\n    X0c = X0 - t0\n    X1c = X1 - t1\n    # scale\n    # s0 = (X0c**2).sum(dim=-1).mean().sqrt()\n    # s1 = (X1c**2).sum(dim=-1).mean().sqrt()\n    # X0cs = X0c/s0\n    # X1cs = X1c/s1\n    # rotation (use double for SVD, float loses precision)\n    U, _, V = (X0c.t() @ X1c).double().svd(some=True)\n    R = (U @ V.t()).float()\n    if R.det() < 0: R[2] *= -1\n    # align X1 to X0: X1to0 = (X1-t1)/@R.t()+t0\n    se3 = edict(t0=t0[0], t1=t1[0], R=R)\n\n    return se3\n\n\ndef bilinear_sampler(input, coords, align_corners=True, padding_mode=\"border\"):\n    r\"\"\"Sample a tensor using bilinear interpolation\n\n    `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at\n    coordinates :attr:`coords` using bilinear interpolation. It is the same\n    as `torch.nn.functional.grid_sample()` but with a different coordinate\n    convention.\n\n    The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where\n    :math:`B` is the batch size, :math:`C` is the number of channels,\n    :math:`H` is the height of the image, and :math:`W` is the width of the\n    image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is\n    interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.\n\n    Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,\n    in which case sample points are triplets :math:`(t_i,x_i,y_i)`. 
Note\n    that in this case the order of the components is slightly different\n    from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.\n\n    If `align_corners` is `True`, the coordinate :math:`x` is assumed to be\n    in the range :math:`[0,W-1]`, with 0 corresponding to the center of the\n    left-most image pixel :math:`W-1` to the center of the right-most\n    pixel.\n\n    If `align_corners` is `False`, the coordinate :math:`x` is assumed to\n    be in the range :math:`[0,W]`, with 0 corresponding to the left edge of\n    the left-most pixel :math:`W` to the right edge of the right-most\n    pixel.\n\n    Similar conventions apply to the :math:`y` for the range\n    :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range\n    :math:`[0,T-1]` and :math:`[0,T]`.\n\n    Args:\n        input (Tensor): batch of input images.\n        coords (Tensor): batch of coordinates.\n        align_corners (bool, optional): Coordinate convention. Defaults to `True`.\n        padding_mode (str, optional): Padding mode. Defaults to `\"border\"`.\n\n    Returns:\n        Tensor: sampled points.\n    \"\"\"\n\n    sizes = input.shape[2:]\n\n    assert len(sizes) in [2, 3]\n\n    if len(sizes) == 3:\n        # t x y -> x y t to match dimensions T H W in grid_sample\n        coords = coords[..., [1, 2, 0]]\n\n    if align_corners:\n        coords = coords * torch.tensor(\n            [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device\n        )\n    else:\n        coords = coords * torch.tensor([2 / size for size in reversed(sizes)], device=coords.device)\n\n    coords -= 1\n\n    return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)\n\n\ndef sample_features4d(input, coords):\n    r\"\"\"Sample spatial features\n\n    `sample_features4d(input, coords)` samples the spatial features\n    :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.\n\n    The field is sampled at coordinates :attr:`coords` using bilinear\n    interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,\n    3)`, where each sample has the format :math:`(x_i, y_i)`. This uses the\n    same convention as :func:`bilinear_sampler` with `align_corners=True`.\n\n    The output tensor has one feature per point, and has shape :math:`(B,\n    R, C)`.\n\n    Args:\n        input (Tensor): spatial features.\n        coords (Tensor): points.\n\n    Returns:\n        Tensor: sampled features.\n    \"\"\"\n\n    B, _, _, _ = input.shape\n\n    # B R 2 -> B R 1 2\n    coords = coords.unsqueeze(2)\n\n    # B C R 1\n    feats = bilinear_sampler(input, coords)\n\n    return feats.permute(0, 2, 1, 3).view(\n        B, -1, feats.shape[1] * feats.shape[3]\n    )  # B C R 1 -> B R C\n\n\ndef sample_features5d(input, coords):\n    r\"\"\"Sample spatio-temporal features\n\n    `sample_features5d(input, coords)` works in the same way as\n    :func:`sample_features4d` but for spatio-temporal features and points:\n    :attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is\n    a :math:`(B, R1, R2, 3)` tensor of spatio-temporal point :math:`(t_i,\n    x_i, y_i)`. 
The output tensor has shape :math:`(B, R1, R2, C)`.\n\n    Args:\n        input (Tensor): spatio-temporal features.\n        coords (Tensor): spatio-temporal points.\n\n    Returns:\n        Tensor: sampled features.\n    \"\"\"\n\n    B, T, _, _, _ = input.shape\n\n    # B T C H W -> B C T H W\n    input = input.permute(0, 2, 1, 3, 4)\n\n    # B R1 R2 3 -> B R1 R2 1 3\n    coords = coords.unsqueeze(3)\n\n    # B C R1 R2 1\n    feats = bilinear_sampler(input, coords)\n\n    return feats.permute(0, 2, 3, 1, 4).view(\n        B, feats.shape[2], feats.shape[3], feats.shape[1]\n    )  # B C R1 R2 1 -> B R1 R2 C\n\n\ndef pixel_xy_and_camera_z_to_world_space(pixel_xy, camera_z, intrs_inv, extrs_inv):\n    num_frames, num_points, _ = pixel_xy.shape\n    assert pixel_xy.shape == (num_frames, num_points, 2)\n    assert camera_z.shape == (num_frames, num_points, 1)\n    assert intrs_inv.shape == (num_frames, 3, 3)\n    assert extrs_inv.shape == (num_frames, 4, 4)\n\n    pixel_xy_homo = torch.cat([pixel_xy, pixel_xy.new_ones(pixel_xy[..., :1].shape)], -1)\n    camera_xyz = torch.einsum('Aij,ABj->ABi', intrs_inv, pixel_xy_homo) * camera_z\n    camera_xyz_homo = torch.cat([camera_xyz, camera_xyz.new_ones(camera_xyz[..., :1].shape)], -1)\n    world_xyz_homo = torch.einsum('Aij,ABj->ABi', extrs_inv, camera_xyz_homo)\n    if not torch.allclose(\n            world_xyz_homo[..., -1],\n            world_xyz_homo.new_ones(world_xyz_homo[..., -1].shape),\n            atol=0.1,\n    ):\n        warnings.warn(f\"pixel_xy_and_camera_z_to_world_space found some homo coordinates not close to 1: \"\n                      f\"the homo values are in {world_xyz_homo[..., -1].min()} – {world_xyz_homo[..., -1].max()}\")\n    world_xyz = world_xyz_homo[..., :-1]\n\n    assert world_xyz.shape == (num_frames, num_points, 3)\n    return world_xyz\n\n\ndef world_space_to_pixel_xy_and_camera_z(world_xyz, intrs, extrs):\n    num_frames, num_points, _ = world_xyz.shape\n    assert world_xyz.shape == (num_frames, num_points, 3)\n    assert intrs.shape == (num_frames, 3, 3)\n    assert extrs.shape == (num_frames, 3, 4)\n\n    world_xyz_homo = torch.cat([world_xyz, world_xyz.new_ones(world_xyz[..., :1].shape)], -1)\n    camera_xyz = torch.einsum('Aij,ABj->ABi', extrs, world_xyz_homo)\n    camera_z = camera_xyz[..., -1:]\n    pixel_xy_homo = torch.einsum('Aij,ABj->ABi', intrs, camera_xyz)\n    pixel_xy = pixel_xy_homo[..., :2] / pixel_xy_homo[..., -1:]\n\n    assert pixel_xy.shape == (num_frames, num_points, 2)\n    assert camera_z.shape == (num_frames, num_points, 1)\n    return pixel_xy, camera_z\n\n\ndef get_points_on_a_grid(\n        size: int,\n        extent: Tuple[float, ...],\n        center: Optional[Tuple[float, ...]] = None,\n        device: Optional[torch.device] = torch.device(\"cpu\"),\n):\n    r\"\"\"Get a grid of points covering a rectangular region\n\n    `get_points_on_a_grid(size, extent)` generates a :attr:`size` by\n    :attr:`size` grid fo points distributed to cover a rectangular area\n    specified by `extent`.\n\n    The `extent` is a pair of integer :math:`(H,W)` specifying the height\n    and width of the rectangle.\n\n    Optionally, the :attr:`center` can be specified as a pair :math:`(c_y,c_x)`\n    specifying the vertical and horizontal center coordinates. 
The center\n    defaults to the middle of the extent.\n\n    Points are distributed uniformly within the rectangle leaving a margin\n    :math:`m=W/64` from the border.\n\n    It returns a :math:`(1, \\text{size} \\times \\text{size}, 2)` tensor of\n    points :math:`P_{ij}=(x_i, y_i)` where\n\n    .. math::\n        P_{ij} = \\left(\n             c_x + m -\\frac{W}{2} + \\frac{W - 2m}{\\text{size} - 1}\\, j,~\n             c_y + m -\\frac{H}{2} + \\frac{H - 2m}{\\text{size} - 1}\\, i\n        \\right)\n\n    Points are returned in row-major order.\n\n    Args:\n        size (int): grid size.\n        extent (tuple): height and with of the grid extent.\n        center (tuple, optional): grid center.\n        device (str, optional): Defaults to `\"cpu\"`.\n\n    Returns:\n        Tensor: grid.\n    \"\"\"\n    if size == 1:\n        return torch.tensor([extent[1] / 2, extent[0] / 2], device=device)[None, None]\n\n    if center is None:\n        center = [extent[0] / 2, extent[1] / 2]\n\n    margin = extent[1] / 64\n    range_y = (margin - extent[0] / 2 + center[0], extent[0] / 2 + center[0] - margin)\n    range_x = (margin - extent[1] / 2 + center[1], extent[1] / 2 + center[1] - margin)\n    grid_y, grid_x = torch.meshgrid(\n        torch.linspace(*range_y, size, device=device),\n        torch.linspace(*range_x, size, device=device),\n        indexing=\"ij\",\n    )\n    return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2)\n\n\ndef init_pointcloud_from_rgbd(\n        fmaps: torch.Tensor,\n        depths: torch.Tensor,\n        intrs: torch.Tensor,\n        extrs: torch.Tensor,\n        stride=4,\n        level=0,\n        depth_interp_mode='nearest',\n        return_validity_mask=False,\n):\n    B, V, S, C, H, W = fmaps.shape\n    assert fmaps.shape == (B, V, S, C, H, W)\n    assert depths.shape == (B, V, S, 1, H, W)\n    assert intrs.shape == (B, V, S, 3, 3)\n    assert extrs.shape == (B, V, S, 3, 4)\n\n    # Pool the fmaps and depths to the desired pyramid level\n    fmaps = fmaps.reshape(B * V * S, C, H, W)\n    depths = depths.reshape(B * V * S, 1, H, W)\n    for i in range(level):\n        fmaps = F.avg_pool2d(fmaps, 2, stride=2)\n        if depth_interp_mode == 'avg':\n            depths = F.avg_pool2d(depths, 2, stride=2)\n        elif depth_interp_mode == 'nearest':\n            depths = F.interpolate(depths, scale_factor=0.5, mode='nearest')\n        else:\n            raise NotImplementedError\n    H = H // 2 ** level\n    W = W // 2 ** level\n    fmaps = fmaps.reshape(B, V, S, C, H, W)\n    depths = depths.reshape(B, V, S, 1, H, W)\n    stride = stride * 2 ** level\n\n    # Invert intrinsics and extrinsics\n    intrs_inv = torch.inverse(intrs.float()).type(intrs.dtype)\n    extrs_square = torch.eye(4).to(extrs.device)[None].repeat(B, V, S, 1, 1)\n    extrs_square[:, :, :, :3, :] = extrs\n    extrs_inv = torch.inverse(extrs_square.float()).type(extrs.dtype)\n    assert intrs_inv.shape == (B, V, S, 3, 3)\n    assert extrs_inv.shape == (B, V, S, 4, 4)\n\n    # Pixel --> Camera --> World\n    pixel_xy = torch.stack(torch.meshgrid(\n        (torch.arange(0, H) + 0.5) * stride - 0.5,\n        (torch.arange(0, W) + 0.5) * stride - 0.5,\n        indexing=\"ij\",\n    )[::-1], dim=-1)\n    pixel_xy = pixel_xy.to(device=fmaps.device, dtype=fmaps.dtype)\n    pixel_xy_homo = to_homogeneous(pixel_xy)\n    depthmap_camera_xyz = torch.einsum('BVSij,HWj->BVSHWi', intrs_inv, pixel_xy_homo)\n    depthmap_camera_xyz = depthmap_camera_xyz * depths[..., 0, :, :, None]\n    
depthmap_camera_xyz_homo = to_homogeneous(depthmap_camera_xyz)\n    depthmap_world_xyz_homo = torch.einsum('BVSij,BVSHWj->BVSHWi', extrs_inv, depthmap_camera_xyz_homo)\n    depthmap_world_xyz = from_homogeneous(depthmap_world_xyz_homo)\n\n    pointcloud_xyz = depthmap_world_xyz.permute(0, 2, 1, 3, 4, 5).reshape(B * S, V * H * W, 3)\n    pointcloud_fvec = fmaps.permute(0, 2, 1, 4, 5, 3).reshape(B * S, V * H * W, C)\n\n    if return_validity_mask:\n        pointcloud_valid_mask = depths.permute(0, 2, 1, 3, 4, 5).reshape(B * S, V * H * W) > 0\n        return pointcloud_xyz, pointcloud_fvec, pointcloud_valid_mask\n\n    return pointcloud_xyz, pointcloud_fvec\n\n\ndef save_pointcloud_to_ply(filename, points, colors, edges=None):\n    with open(filename, 'w') as ply_file:\n        ply_file.write(\"ply\\nformat ascii 1.0\\n\")\n        ply_file.write(f\"element vertex {len(points)}\\n\")\n        ply_file.write(\"property float x\\nproperty float y\\nproperty float z\\n\")\n        ply_file.write(\"property uchar red\\nproperty uchar green\\nproperty uchar blue\\n\")\n\n        if edges is not None:\n            ply_file.write(f\"element edge {len(edges)}\\n\")\n            ply_file.write(\"property int vertex1\\nproperty int vertex2\\n\")\n\n        ply_file.write(\"end_header\\n\")\n\n        # Write vertices (points with colors)\n        for point, color in zip(points, colors):\n            ply_file.write(f\"{point[0]} {point[1]} {point[2]} {color[0]} {color[1]} {color[2]}\\n\")\n\n        # Write edges (if provided)\n        if edges is not None:\n            for edge in edges:\n                ply_file.write(f\"{edge[0]} {edge[1]}\\n\")\n"
  },
  {
    "path": "mvtracker/models/core/monocular_baselines.py",
    "content": "import logging\nimport sys\nimport warnings\nfrom typing import Tuple\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn as nn\n\nfrom mvtracker.datasets.utils import transform_scene\nfrom mvtracker.models.core.model_utils import bilinear_sample2d, pixel_xy_and_camera_z_to_world_space\nfrom mvtracker.utils.visualizer_mp4 import Visualizer\n\n\nclass CoTrackerOfflineWrapper(nn.Module):\n    def __init__(self, model_name=\"cotracker3_offline\", grid_size=10):\n        super(CoTrackerOfflineWrapper, self).__init__()\n        self.grid_size = grid_size\n        self.cotracker = torch.hub.load(\"facebookresearch/co-tracker\", model_name)\n\n    def forward(self, rgbs, queries, **kwargs):\n        T, _, H, W = rgbs.shape\n        N, _ = queries.shape\n\n        assert rgbs.shape == (T, 3, H, W)\n        assert queries.shape == (N, 3)\n\n        # Forward pass: https://github.com/facebookresearch/co-tracker/blob/82e02e8029753ad4ef13cf06be7f4fc5facdda4d/cotracker/predictor.py#L36\n        pred_tracks, pred_visibility = self.cotracker(\n            video=rgbs[None].float(),\n            queries=queries[None].float(),\n            grid_size=self.grid_size,\n        )\n\n        return {\"traj_2d\": pred_tracks[0], \"vis\": pred_visibility[0]}\n\n\nclass CoTrackerOnlineWrapper(nn.Module):\n    def __init__(self, model_name=\"cotracker3_online\", grid_size=10):\n        super(CoTrackerOnlineWrapper, self).__init__()\n        self.grid_size = grid_size\n        self.cotracker = torch.hub.load(\"facebookresearch/co-tracker\", model_name)\n\n    def forward(self, rgbs, queries, **kwargs):\n        T, _, H, W = rgbs.shape\n        N, _ = queries.shape\n\n        assert rgbs.shape == (T, 3, H, W)\n        assert queries.shape == (N, 3)\n\n        # Forward pass: https://github.com/facebookresearch/co-tracker/blob/82e02e8029753ad4ef13cf06be7f4fc5facdda4d/cotracker/predictor.py#L230\n        self.cotracker(\n            video_chunk=rgbs[None].float(),\n            queries=queries[None].float(),\n            grid_size=self.grid_size,\n            is_first_step=True,\n        )\n        for t in range(0, T - self.cotracker.step, self.cotracker.step):\n            pred_tracks, pred_visibility = self.cotracker(video_chunk=rgbs[None, t: t + self.cotracker.step * 2])\n\n        return {\"traj_2d\": pred_tracks[0], \"vis\": pred_visibility[0]}\n\n\nclass SpaTrackerV2Wrapper(nn.Module):\n    \"\"\"\n    Environment setup:\n    ```bash\n    git clone https://github.com/henry123-boy/SpaTrackerV2.git ../spatialtrackerv2\n    cd ../spatialtrackerv2\n    git checkout 1673230\n    git submodule update --init --recursive\n    pip install pycolmap==3.11.1\n    pip install git+https://github.com/EasternJournalist/utils3d.git#egg=utils3d\n    pip install pyceres==2.4\n\n    # Update the threshold for weighted_procrustes_torch from 1e-3 to 5e-3\n    sed -i 's/(torch.det(R) - 1).abs().max() < 1e-3/(torch.det(R) - 1).abs().max() < 5e-3/' ./models/SpaTrackV2/models/tracker3D/spatrack_modules/utils.py\n\n    # Verify the change: this should print a line with 5e-3\n    cat ./models/SpaTrackV2/models/tracker3D/spatrack_modules/utils.py | grep \"(torch.det(R) - 1).abs().max()\"\n    ```\n    \"\"\"\n\n    def __init__(\n            self,\n            model_type=\"offline\",\n            vo_points=756,\n    ):\n        super(SpaTrackerV2Wrapper, self).__init__()\n\n        sys.path.append(\"../spatialtrackerv2/\")\n        from models.SpaTrackV2.models.predictor import 
Predictor\n        if model_type == \"offline\":\n            self.model = Predictor.from_pretrained(\"Yuxihenry/SpatialTrackerV2-Offline\")\n        elif model_type == \"online\":\n            self.model = Predictor.from_pretrained(\"Yuxihenry/SpatialTrackerV2-Online\")\n        else:\n            raise ValueError(f\"Unknown model_type: {model_type}\")\n        self.model.spatrack.track_num = vo_points  # the track_num is the number of points in the grid\n        self.model.eval()\n        self.model.to(\"cuda\")\n\n    def forward(self, rgbs, depths, queries, queries_xyz_worldspace, intrs, extrs, **kwargs):\n        T, _, H, W = rgbs.shape\n        N, _ = queries.shape\n\n        assert rgbs.shape == (T, 3, H, W)\n        assert depths.shape == (T, 1, H, W)\n        assert intrs.shape == (T, 3, 3)\n        assert extrs.shape == (T, 3, 4)\n        assert queries.shape == (N, 3)\n        assert queries_xyz_worldspace.shape == (N, 4)\n\n        extrs_square = torch.eye(4).to(extrs.device)[None].repeat(T, 1, 1)\n        extrs_square[:, :3, :] = extrs\n\n        # Transform the extrinsics so that the camera is in the origin, and later revert.\n        transform = extrs_square[0]\n        transform_inv = torch.inverse(transform)\n        extrs, queries_xyz_worldspace = extrs.clone(), queries_xyz_worldspace.clone()\n        (\n            _, extrs, queries_xyz_worldspace, _, _\n        ) = transform_scene(1, transform[:3, :3], transform[:3, 3], None, extrs[None], queries_xyz_worldspace)\n        extrs = extrs[0]\n        extrs_square[:, :3, :] = extrs\n\n        # Check if the camera is fixed\n        extrs_delta = torch.linalg.norm(extrs - extrs[0], dim=(1, 2))\n        fixed_cam = (extrs_delta < 1e-3).all().item()\n\n        # Run inference\n        extrs_inv = torch.inverse(extrs_square)\n        with torch.amp.autocast(device_type=\"cuda\", dtype=torch.bfloat16):\n            (\n                c2w_traj, intrs, point_map, conf_depth,\n                track3d_pred, track2d_pred, vis_pred, conf_pred, video\n            ) = self.model.forward(rgbs.cpu(), depth=depths.squeeze(1).cpu().numpy(),\n                                   intrs=intrs.cpu(), extrs=extrs_inv.cpu().numpy(),\n                                   queries=queries.cpu().numpy(), queries_3d=queries_xyz_worldspace.cpu().numpy(),\n                                   fps=1, full_point=True, iters_track=4,\n                                   query_no_BA=True, fixed_cam=fixed_cam, stage=1, unc_metric=None,\n                                   support_frame=T - 1, replace_ratio=0.2)\n\n        trajectories_3d = (\n                torch.einsum(\"tij,tnj->tni\", c2w_traj[:, :3, :3].to(track3d_pred.device), track3d_pred[:, :, :3])\n                + c2w_traj[:, :3, 3][:, None, :].to(track3d_pred.device)\n        )\n        (\n            _, _, _, trajectories_3d, _\n        ) = transform_scene(1, transform_inv[:3, :3], transform_inv[:3, 3], None, None, None, trajectories_3d, None)\n        visibilities = vis_pred.squeeze(2)\n\n        assert trajectories_3d.shape == (T, N, 3)\n        assert visibilities.shape == (T, N)\n\n        return {\"traj_2d\": None, \"traj_3d_worldspace\": trajectories_3d, \"vis\": visibilities}\n\n\nclass LocoTrackWrapper(nn.Module):\n    \"\"\"\n    Environment setup:\n    ```sh\n    git clone https://github.com/cvlab-kaist/locotrack ../locotrack\n    cd ../locotrack\n    find ./locotrack_pytorch -type f -name \"*.py\" -exec sed -i 's/\\bimport models\\b/import locotrack_pytorch.models/g' {} \\;\n    find 
./locotrack_pytorch -type f -name \"*.py\" -exec sed -i 's/\\bfrom models\\b/from locotrack_pytorch.models/g' {} \\;\n    cd ../spatialtracker\n    ```\n    \"\"\"\n\n    def __init__(self, model_size=\"base\"):\n        super(LocoTrackWrapper, self).__init__()\n        sys.path.append(\"../locotrack\")\n        from locotrack_pytorch.models.locotrack_model import load_model\n        self.model = load_model(model_size=model_size).cuda()\n        self.model.eval()\n\n    def forward(self, rgbs, queries, **kwargs):\n        T, _, H, W = rgbs.shape\n        N, _ = queries.shape\n\n        assert (H, W) == (256, 256), f\"LocoTrack only supports (256, 256) images, but got ({H}, {W})\"\n        assert rgbs.shape == (T, 3, H, W)\n        assert queries.shape == (N, 3)\n\n        # Forward pass: https://github.com/cvlab-kaist/locotrack/blob/6f3f9cad46b06c3de9c38fbf21006271056baf45/locotrack_pytorch/models/locotrack_model.py#L323\n        video = (rgbs.permute(0, 2, 3, 1)[None] / 255.0) * 2 - 1\n        queries_tyx = torch.stack([queries[:, 0], queries[:, 2], queries[:, 1]], dim=1)[None]\n        # queries_tyx = queries_tyx / torch.tensor([1, H, W], dtype=queries_tyx.dtype, device=queries_tyx.device)\n\n        with torch.no_grad():\n            output = self.model(video=video, query_points=queries_tyx)\n        pred_occ = torch.sigmoid(output['occlusion'])\n        if 'expected_dist' in output:\n            pred_occ = 1 - (1 - pred_occ) * (1 - torch.sigmoid(output['expected_dist']))\n        pred_occ = (pred_occ > 0.5)[0]\n\n        trajectories_2d = output['tracks'][0].permute(1, 0, 2)\n        # trajectories_2d[..., 0] *= W\n        # trajectories_2d[..., 1] *= H\n        visibilities = ~pred_occ.permute(1, 0)\n\n        if torch.isnan(trajectories_2d).any():\n            warnings.warn(\n                f\"Found {torch.isnan(trajectories_2d).sum()}/{trajectories_2d.numel()} NaN values in trajectories_2d. Setting them to 0.\")\n            trajectories_2d[trajectories_2d.isnan()] = 0\n        if torch.isnan(visibilities).any():\n            warnings.warn(\n                f\"Found {torch.isnan(visibilities).sum()}/{visibilities.numel()} NaN values in visibilities. 
Setting them to 1.\")\n            visibilities[visibilities.isnan()] = 1\n\n        return {\"traj_2d\": trajectories_2d, \"vis\": visibilities}\n\n\nclass TAPTRWrapper(nn.Module):\n    pass\n\n\nclass TAPIRWrapper(nn.Module):\n    pass\n\n\nclass PIPSWrapper(nn.Module):\n    pass\n\n\nclass PIPSPlusPlusWrapper(nn.Module):\n    pass\n\n\nclass SceneTrackerWrapper(nn.Module):\n    \"\"\"\n    Environment setup:\n    ```sh\n    wget --directory-prefix=checkpoints https://huggingface.co/wwcreator/SceneTracker/resolve/main/scenetracker-odyssey-200k.pth\n    git clone https://github.com/wwsource/SceneTracker.git ../scenetracker\n\n    python eval.py experiment_path=logs/scenetracker model=scenetracker\n\n    ```\n    \"\"\"\n\n    def __init__(\n            self,\n            ckpt=\"checkpoints/scenetracker-odyssey-200k.pth\",\n            return_2d_track=False,\n    ):\n        super(SceneTrackerWrapper, self).__init__()\n\n        sys.path.append(\"../scenetracker/\")\n        from model.model_scenetracker import SceneTracker\n        model = SceneTracker()\n        pre_replace_list = [['module.', '']]\n        checkpoint = torch.load(ckpt)\n        for l in pre_replace_list:\n            checkpoint = {k.replace(l[0], l[1]): v for k, v in checkpoint.items()}\n        model.load_state_dict(checkpoint, strict=True)\n        model.eval().cuda()\n\n        self.return_2d_track = return_2d_track\n        self.model = model\n\n    def forward(self, rgbs, depths, queries_with_z, **kwargs):\n        T, _, H, W = rgbs.shape\n        N, _ = queries_with_z.shape\n\n        assert rgbs.shape == (T, 3, H, W)\n        assert depths.shape == (T, 1, H, W)\n        assert queries_with_z.shape == (N, 4)\n\n        trajs_uv_e, trajs_z_e, _, _ = self.model.infer(\n            self.model,\n            input_list=[\n                rgbs[None].float(),\n                depths[None].float(),\n                queries_with_z[None].float(),\n            ],\n            iters=4,\n            is_train=False,\n        )\n\n        trajectories_2d = trajs_uv_e[0].type(queries_with_z.dtype)\n        trajectories_z = trajs_z_e[0].type(queries_with_z.dtype)\n        visibilities = torch.zeros_like(trajectories_2d[..., 0], dtype=torch.bool)\n\n        if self.return_2d_track:\n            return {\"traj_2d\": trajectories_2d, \"vis\": visibilities}\n        else:\n            return {\"traj_2d\": trajectories_2d, \"traj_z\": trajectories_z, \"vis\": visibilities}\n\n\nclass DELTAWrapper(nn.Module):\n    \"\"\"\n    Environment setup:\n    ```sh\n    mkdir -p ./checkpoints/\n    gdown --fuzzy https://drive.google.com/file/d/18d5M3nl3AxbG4ZkT7wssvMXZXbmXrnjz/view?usp=sharing -O ./checkpoints/ # 3D ckpt\n    gdown --fuzzy https://drive.google.com/file/d/1S_T7DzqBXMtr0voRC_XUGn1VTnPk_7Rm/view?usp=sharing -O ./checkpoints/ # 2D ckpt\n    git clone --recursive https://github.com/snap-research/DELTA_densetrack3d ../delta\n    pip install jaxtyping\n\n    python eval.py experiment_path=logs/delta model=delta\n    ```\n    \"\"\"\n\n    def __init__(\n            self,\n            ckpt=\"checkpoints/densetrack3d.pth\",\n            upsample_factor=4,\n            grid_size=20,\n            return_2d_track=False,\n    ):\n        super(DELTAWrapper, self).__init__()\n\n        self.grid_size = grid_size\n        self.return_2d_track = return_2d_track\n\n        sys.path.append(\"../delta\")\n        from densetrack3d.models.densetrack3d.densetrack3d import DenseTrack3D\n        from densetrack3d.models.predictor.predictor import 
Predictor3D\n        model = DenseTrack3D(\n            stride=4,\n            window_len=16,\n            add_space_attn=True,\n            num_virtual_tracks=64,\n            model_resolution=(384, 512),\n            upsample_factor=upsample_factor\n        )\n        with open(ckpt, \"rb\") as f:\n            state_dict = torch.load(f, map_location=\"cpu\")\n            if \"model\" in state_dict:\n                state_dict = state_dict[\"model\"]\n        model.load_state_dict(state_dict, strict=False)\n        predictor = Predictor3D(model=model)\n        predictor = predictor.eval().cuda()\n        self.model = model\n        self.predictor = predictor\n\n    def forward(self, rgbs, depths, queries, **kwargs):\n        T, _, H, W = rgbs.shape\n        N, _ = queries.shape\n\n        assert rgbs.shape == (T, 3, H, W)\n        assert depths.shape == (T, 1, H, W)\n        assert queries.shape == (N, 3)\n\n        out_dict = self.predictor(\n            rgbs[None],\n            depths[None],\n            queries=queries[None],\n            segm_mask=None,\n            grid_size=self.grid_size,\n            grid_query_frame=0,\n            backward_tracking=False,\n            predefined_intrs=None\n        )\n\n        trajectories_2d = out_dict[\"trajs_uv\"][0]\n        trajectories_z = out_dict[\"trajs_depth\"][0]\n        trajectories_3d = out_dict[\"trajs_3d_dict\"][\"coords\"][0]\n        visibilities = out_dict[\"vis\"][0]\n\n        if self.return_2d_track:\n            return {\"traj_2d\": trajectories_2d, \"vis\": visibilities}\n        else:\n            return {\"traj_2d\": trajectories_2d, \"traj_z\": trajectories_z, \"vis\": visibilities}\n\n\nclass TAPIP3DWrapper(nn.Module):\n    \"\"\"\n    Environment setup:\n    ```sh\n    wget --directory-prefix=checkpoints https://huggingface.co/zbww/tapip3d/resolve/main/tapip3d_final.pth\n    git clone git@github.com:zbw001/TAPIP3D.git ../tapip3d\n    cd ../tapip3d\n    git checkout 9359ae236f16a58a103dc1c55ad1919360dc6f8b\n    cd third_party/pointops2\n    LIBRARY_PATH=$CONDA_PREFIX/lib:$LIBRARY_PATH python setup.py install\n    cd ../..\n    \"\"\"\n\n    def __init__(\n            self,\n            ckpt=\"checkpoints/tapip3d_final.pth\",\n            num_iters=6,\n            grid_size=8,\n            resolution_factor=2,\n            transform_to_camera_space=False,\n    ):\n        super(TAPIP3DWrapper, self).__init__()\n\n        self.num_iters = num_iters\n        self.support_grid_size = grid_size\n        self.resolution_factor = resolution_factor\n        self.transform_to_camera_space = transform_to_camera_space\n\n        sys.path.append(\"../tapip3d\")\n        from utils.inference_utils import load_model\n        self.model = load_model(ckpt)\n        self.model.cuda()\n\n        inference_res = (\n            int(self.model.image_size[0] * np.sqrt(self.resolution_factor)),\n            int(self.model.image_size[1] * np.sqrt(self.resolution_factor)),\n        )\n        self.model.set_image_size(inference_res)\n\n    def forward(self, rgbs, depths, intrs, extrs, queries_xyz_worldspace, **kwargs):\n        T, _, H, W = rgbs.shape\n        N, _ = queries_xyz_worldspace.shape\n\n        assert rgbs.shape == (T, 3, H, W)\n        assert depths.shape == (T, 1, H, W)\n        assert intrs.shape == (T, 3, 3)\n        assert extrs.shape == (T, 3, 4)\n        assert queries_xyz_worldspace.shape == (N, 4)\n\n        extrs_square = torch.eye(4).to(extrs.device)[None].repeat(T, 1, 1)\n        extrs_square[:, :3, :] = extrs\n\n     
   # Transform the extrinsics (and query points) so that\n        # the camera is at the origin, and later revert.\n        # But it's about the same performance either way.\n        if self.transform_to_camera_space:\n            world_to_cam0 = extrs_square[0]\n            cam0_to_world = torch.inverse(world_to_cam0)\n            extrs = extrs.clone()\n            (\n                _, extrs, queries_xyz_worldspace, _, _\n            ) = transform_scene(1, world_to_cam0[:3, :3], world_to_cam0[:3, 3], None, extrs[None], queries_xyz_worldspace, None, None)\n            extrs = extrs[0]\n            extrs_square[:, :3, :] = extrs\n\n        # Run inference\n        with torch.autocast(\"cuda\", dtype=torch.bfloat16):\n            trajectories_3d, visibilities = TAPIP3DWrapper.inference(\n                model=self.model,\n                video=rgbs / 255.0,\n                depths=depths.squeeze(1),\n                intrinsics=intrs,\n                extrinsics=extrs_square,\n                query_point=queries_xyz_worldspace,\n                num_iters=self.num_iters,\n                grid_size=self.support_grid_size,\n            )\n\n        if self.transform_to_camera_space:\n            (\n                _, _, _, trajectories_3d, _\n            ) = transform_scene(1, cam0_to_world[:3, :3], cam0_to_world[:3, 3], None, None, None, trajectories_3d, None)\n\n        if N == 1:\n            trajectories_3d = trajectories_3d.unsqueeze(1)\n            visibilities = visibilities.unsqueeze(1)\n        assert trajectories_3d.shape == (T, N, 3)\n        assert visibilities.shape == (T, N)\n\n        return {\"traj_2d\": None, \"traj_3d_worldspace\": trajectories_3d.clone(), \"vis\": visibilities.clone()}\n\n    @staticmethod\n    @torch.no_grad()\n    def inference(\n            *,\n            model: torch.nn.Module,\n            video: torch.Tensor,\n            depths: torch.Tensor,\n            intrinsics: torch.Tensor,\n            extrinsics: torch.Tensor,\n            query_point: torch.Tensor,\n            num_iters: int = 6,\n            grid_size: int = 8,\n            bidrectional: bool = True,\n            vis_threshold=None,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        from utils.inference_utils import _inference_with_grid\n        from einops import repeat\n\n        _depths = depths.clone()\n        _depths = _depths[_depths > 0].reshape(-1)\n        q25 = torch.kthvalue(_depths, int(0.25 * len(_depths))).values\n        q75 = torch.kthvalue(_depths, int(0.75 * len(_depths))).values\n        iqr = q75 - q25\n        _depth_roi = torch.tensor(\n            [1e-7, (q75 + 1.5 * iqr).item()],\n            dtype=torch.float32,\n            device=video.device\n        )\n\n        T, C, H, W = video.shape\n        assert depths.shape == (T, H, W)\n        N = query_point.shape[0]\n\n        model.set_image_size((H, W))\n\n        preds, _ = _inference_with_grid(\n            model=model,\n            video=video[None],\n            depths=depths[None],\n            intrinsics=intrinsics[None],\n            extrinsics=extrinsics[None],\n            query_point=query_point[None],\n            num_iters=num_iters,\n            depth_roi=_depth_roi,\n            grid_size=grid_size\n        )\n\n        if bidrectional and not model.bidirectional and (query_point[..., 0] > 0).any():\n            preds_backward, _ = _inference_with_grid(\n                model=model,\n                video=video[None].flip(dims=(1,)),\n                depths=depths[None].flip(dims=(1,)),\n                intrinsics=intrinsics[None].flip(dims=(1,)),\n                
extrinsics=extrinsics[None].flip(dims=(1,)),\n                query_point=torch.cat([T - 1 - query_point[..., :1], query_point[..., 1:]], dim=-1)[None],\n                num_iters=num_iters,\n                depth_roi=_depth_roi,\n                grid_size=grid_size,\n            )\n            preds.coords = torch.where(\n                repeat(torch.arange(T, device=video.device), 't -> b t n 3', b=1, n=N) < repeat(\n                    query_point[..., 0][None], 'b n -> b t n 3', t=T, n=N),\n                preds_backward.coords.flip(dims=(1,)),\n                preds.coords\n            )\n            preds.visibs = torch.where(\n                repeat(torch.arange(T, device=video.device), 't -> b t n', b=1, n=N) < repeat(\n                    query_point[..., 0][None], 'b n -> b t n', t=T, n=N),\n                preds_backward.visibs.flip(dims=(1,)),\n                preds.visibs\n            )\n\n        coords, visib_logits = preds.coords, preds.visibs\n        visibs = torch.sigmoid(visib_logits)\n        if vis_threshold is not None:\n            visibs = visibs >= vis_threshold\n        return coords.squeeze(), visibs.squeeze()\n\n\nclass MonocularToMultiViewAdapter(nn.Module):\n    def __init__(self, model, **kwargs):\n        super(MonocularToMultiViewAdapter, self).__init__()\n        self.model = model\n\n    def forward(\n            self,\n            rgbs,\n            depths,\n            query_points,\n            intrs,\n            extrs,\n            save_debug_logs=False,\n            debug_logs_path=\"\",\n            query_points_view=None,\n            **kwargs,\n    ):\n        batch_size, num_views, num_frames, _, height, width = rgbs.shape\n        _, num_points, _ = query_points.shape\n\n        assert rgbs.shape == (batch_size, num_views, num_frames, 3, height, width)\n        assert depths.shape == (batch_size, num_views, num_frames, 1, height, width)\n        assert query_points.shape == (batch_size, num_points, 4)\n        assert intrs.shape == (batch_size, num_views, num_frames, 3, 3)\n        assert extrs.shape == (batch_size, num_views, num_frames, 3, 4)\n\n        # Project the queries to each view\n        query_points_t = query_points[:, :, :1].long()\n        query_points_xyz_worldspace = query_points[:, :, 1:]\n\n        query_points_xy_pixelspace_per_view = query_points.new_zeros((batch_size, num_views, num_points, 2))\n        query_points_z_cameraspace_per_view = query_points.new_zeros((batch_size, num_views, num_points, 1))\n        for batch_idx in range(batch_size):\n            for t in query_points_t[batch_idx].unique():\n                query_points_t_mask = query_points_t[batch_idx].squeeze(-1) == t\n                point_3d_world = query_points_xyz_worldspace[batch_idx][query_points_t_mask]\n\n                # World to camera space\n                point_4d_world_homo = torch.cat(\n                    [point_3d_world, point_3d_world.new_ones(point_3d_world[..., :1].shape)], -1)\n                point_3d_camera = torch.einsum('Aij,Bj->ABi', extrs[batch_idx, :, t, :, :], point_4d_world_homo[:, :])\n\n                # Camera to pixel space\n                point_2d_pixel_homo = torch.einsum('Aij,ABj->ABi', intrs[batch_idx, :, t, :, :], point_3d_camera[:, :])\n                point_2d_pixel = point_2d_pixel_homo[..., :2] / point_2d_pixel_homo[..., 2:]\n\n                query_points_xy_pixelspace_per_view[batch_idx, :, query_points_t_mask] = point_2d_pixel\n                query_points_z_cameraspace_per_view[batch_idx, :, 
query_points_t_mask] = point_3d_camera[..., -1:]\n\n        # Estimate occlusion mask in each view based on depth maps\n        query_points_depth_in_view = query_points.new_zeros((batch_size, num_views, num_points, 1))\n        for batch_idx in range(batch_size):\n            for view_idx in range(num_views):\n                for t in query_points_t[batch_idx].unique():\n                    query_points_t_mask = query_points_t[batch_idx].squeeze(-1) == t\n                    interpolated_depth = bilinear_sample2d(\n                        im=depths[batch_idx, view_idx, t][None],\n                        x=query_points_xy_pixelspace_per_view[batch_idx, view_idx, query_points_t_mask, 0][None],\n                        y=query_points_xy_pixelspace_per_view[batch_idx, view_idx, query_points_t_mask, 1][None],\n                    )[0].permute(1, 0).type(query_points.dtype)\n                    query_points_depth_in_view[batch_idx, view_idx, query_points_t_mask] = interpolated_depth\n\n        query_points_depth_in_view_masked = query_points_depth_in_view.clone()\n        query_points_outside_of_view_box = (\n                (query_points_xy_pixelspace_per_view[..., 0] < 0) |\n                (query_points_xy_pixelspace_per_view[..., 0] >= width) |\n                (query_points_xy_pixelspace_per_view[..., 1] < 0) |\n                (query_points_xy_pixelspace_per_view[..., 1] >= height) |\n                (query_points_z_cameraspace_per_view[..., 0] < 0)\n        )\n        if query_points_outside_of_view_box.all(1).any():\n            warnings.warn(f\"There are some query points that are outside of the frame of every view: \"\n                          f\"{query_points_xy_pixelspace_per_view[query_points_outside_of_view_box.all(1)[:, None, :].repeat(1, num_views, 1)].reshape(num_views, -1, 2).permute(1, 0, 2)}\")\n        query_points_depth_in_view_masked[query_points_outside_of_view_box] = -1e4\n        query_points_best_visibility_view = (\n                query_points_depth_in_view_masked - query_points_z_cameraspace_per_view).argmax(1)\n        query_points_best_visibility_view = query_points_best_visibility_view.squeeze(-1)\n\n        if query_points_view is not None:\n            query_points_best_visibility_view = query_points_view\n            logging.info(f\"Using the provided query_points_view instead of the estimated one\")\n\n        assert batch_size == 1, \"Batch size > 1 is not supported yet\"\n        batch_idx = 0\n\n        # Call the 2D tracker for each view\n        traj_e_per_view = {}\n        vis_e_per_view = {}\n        for view_idx in range(num_views):\n            track_mask = query_points_best_visibility_view[batch_idx] == view_idx\n            if track_mask.sum() == 0:\n                continue\n\n            view_rgbs = rgbs[batch_idx, view_idx]\n            view_depths = depths[batch_idx, view_idx]\n            view_intrs = intrs[batch_idx, view_idx]\n            view_extrs = extrs[batch_idx, view_idx]\n            view_query_points = torch.concat([\n                query_points_t[batch_idx, :, :][track_mask],\n                query_points_xy_pixelspace_per_view[batch_idx, view_idx, :, :][track_mask],\n            ], dim=-1)\n            view_query_points_with_z = torch.concat([\n                query_points_t[batch_idx, :, :][track_mask],\n                query_points_xy_pixelspace_per_view[batch_idx, view_idx, :, :][track_mask],\n                query_points_z_cameraspace_per_view[batch_idx, view_idx, :][track_mask],\n            ], dim=-1)\n            
view_query_points_xyz_worldspace = torch.concat([\n                query_points_t[batch_idx, :, :][track_mask],\n                query_points_xyz_worldspace[batch_idx, :][track_mask],\n            ], dim=-1)\n\n            results = self.model(\n                rgbs=view_rgbs,\n                depths=view_depths,\n                intrs=view_intrs,\n                extrs=view_extrs,\n                queries=view_query_points,\n                queries_with_z=view_query_points_with_z,\n                queries_xyz_worldspace=view_query_points_xyz_worldspace,\n            )\n            view_traj_e = results[\"traj_2d\"]\n            view_vis_e = results[\"vis\"]\n\n            if save_debug_logs and view_traj_e is not None:\n                visualizer = Visualizer(\n                    save_dir=debug_logs_path,\n                    pad_value=16,\n                    fps=12,\n                    show_first_frame=0,\n                    tracks_leave_trace=3,\n                )\n                visualizer.visualize(\n                    video=view_rgbs[None].cpu(),\n                    tracks=view_traj_e[None].cpu(),\n                    visibility=view_vis_e[None].cpu(),\n                    filename=f\"view_{view_idx}.mp4\",\n                    query_frame=query_points_t[batch_idx, :, 0][track_mask][None],\n                    save_video=True,\n                )\n\n            # Project the trajectories to the world space\n            if \"traj_3d_worldspace\" in results:\n                view_traj_e = results[\"traj_3d_worldspace\"]\n            else:\n                if \"traj_z\" in results:\n                    view_camera_z = results[\"traj_z\"]\n                else:\n                    view_camera_z = bilinear_sampler(view_depths, view_traj_e.reshape(num_frames, -1, 1, 2))[:, 0, :, :]\n\n                view_intrs = intrs[batch_idx, view_idx]\n                view_extrs = extrs[batch_idx, view_idx]\n                intrs_inv = torch.inverse(view_intrs.float())\n                view_extrs_square = torch.eye(4).to(view_extrs.device)[None].repeat(num_frames, 1, 1)\n                view_extrs_square[:, :3, :] = view_extrs\n                extrs_inv = torch.inverse(view_extrs_square.float())\n                view_traj_e = pixel_xy_and_camera_z_to_world_space(\n                    pixel_xy=view_traj_e[..., :].float(),\n                    camera_z=view_camera_z.float(),\n                    intrs_inv=intrs_inv,\n                    extrs_inv=extrs_inv,\n                )\n\n            # Set the trajectory to (0,0,0) for the timesteps before the query timestep\n            for point_idx, t in enumerate(query_points_t[batch_idx, :, :].squeeze(-1)[track_mask]):\n                view_traj_e[:t, point_idx, :] = 0.0\n\n            traj_e_per_view[view_idx] = view_traj_e[None]\n            vis_e_per_view[view_idx] = view_vis_e[None]\n\n        # Merging the results from all views\n        views_to_keep = list(traj_e_per_view.keys())\n        traj_e = torch.cat([traj_e_per_view[view_idx] for view_idx in views_to_keep], dim=2)\n        vis_e = torch.cat([vis_e_per_view[view_idx] for view_idx in views_to_keep], dim=2)\n\n        # Sort the traj_e and vis_e based on the original indices, since concatenating the results from all views\n        # will first put the results from the first view, then the results from the second view, and so on.\n        # But we want to keep the trajectories order to match the original query points order.\n        sort_inds = []\n        for view_idx in views_to_keep:\n   
         track_mask = query_points_best_visibility_view[batch_idx] == view_idx\n            if track_mask.sum() == 0:\n                continue\n            global_indices = torch.nonzero(track_mask).squeeze(-1)\n            sort_inds += [global_indices]\n        sort_inds = torch.cat(sort_inds, dim=0)\n        inv_sort_inds = torch.argsort(sort_inds, dim=0)\n\n        # Use the inv_sort_inds to sort the traj_e and vis_e\n        traj_e = traj_e[:, :, inv_sort_inds]\n        vis_e = vis_e[:, :, inv_sort_inds]\n\n        # Save to results\n        results = {\"traj_e\": traj_e, \"vis_e\": vis_e}\n        return results\n\n\n# From https://github.com/facebookresearch/co-tracker/blob/82e02e8029753ad4ef13cf06be7f4fc5facdda4d/cotracker/models/core/model_utils.py#L286\ndef bilinear_sampler(input, coords, align_corners=True, padding_mode=\"border\"):\n    r\"\"\"Sample a tensor using bilinear interpolation\n\n    `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at\n    coordinates :attr:`coords` using bilinear interpolation. It is the same\n    as `torch.nn.functional.grid_sample()` but with a different coordinate\n    convention.\n\n    The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where\n    :math:`B` is the batch size, :math:`C` is the number of channels,\n    :math:`H` is the height of the image, and :math:`W` is the width of the\n    image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is\n    interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.\n\n    Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,\n    in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note\n    that in this case the order of the components is slightly different\n    from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.\n\n    If `align_corners` is `True`, the coordinate :math:`x` is assumed to be\n    in the range :math:`[0,W-1]`, with 0 corresponding to the center of the\n    left-most image pixel :math:`W-1` to the center of the right-most\n    pixel.\n\n    If `align_corners` is `False`, the coordinate :math:`x` is assumed to\n    be in the range :math:`[0,W]`, with 0 corresponding to the left edge of\n    the left-most pixel :math:`W` to the right edge of the right-most\n    pixel.\n\n    Similar conventions apply to the :math:`y` for the range\n    :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range\n    :math:`[0,T-1]` and :math:`[0,T]`.\n\n    Args:\n        input (Tensor): batch of input images.\n        coords (Tensor): batch of coordinates.\n        align_corners (bool, optional): Coordinate convention. Defaults to `True`.\n        padding_mode (str, optional): Padding mode. Defaults to `\"border\"`.\n\n    Returns:\n        Tensor: sampled points.\n    \"\"\"\n\n    sizes = input.shape[2:]\n\n    assert len(sizes) in [2, 3]\n\n    if len(sizes) == 3:\n        # t x y -> x y t to match dimensions T H W in grid_sample\n        coords = coords[..., [1, 2, 0]]\n\n    if align_corners:\n        coords = coords * torch.tensor(\n            [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device\n        )\n    else:\n        coords = coords * torch.tensor(\n            [2 / size for size in reversed(sizes)], device=coords.device\n        )\n\n    coords -= 1\n\n    return F.grid_sample(\n        input, coords, align_corners=align_corners, padding_mode=padding_mode\n    )\n"
  },
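For reference, a minimal usage sketch of the `MonocularToMultiViewAdapter` from `mvtracker/models/core/monocular_baselines.py`, here wrapping `CoTrackerOfflineWrapper`. The tensors and camera parameters below are random placeholders chosen only to satisfy the shapes the adapter asserts (this is an illustrative sketch, not a snippet from the repository), and the wrapped tracker is fetched via `torch.hub`, so the first run downloads weights:

```python
import torch

from mvtracker.models.core.monocular_baselines import (
    CoTrackerOfflineWrapper,
    MonocularToMultiViewAdapter,
)

# Placeholder inputs with the shapes the adapter asserts.
B, V, T, H, W, N = 1, 2, 12, 256, 256, 4
rgbs = torch.rand(B, V, T, 3, H, W) * 255.0                  # [B, V, T, 3, H, W], values in 0..255
depths = torch.rand(B, V, T, 1, H, W) + 1.0                  # [B, V, T, 1, H, W], metric depth
K = torch.tensor([[200.0, 0.0, W / 2], [0.0, 200.0, H / 2], [0.0, 0.0, 1.0]])
intrs = K.expand(B, V, T, 3, 3).contiguous()                 # [B, V, T, 3, 3]
extrs = torch.eye(4)[:3].expand(B, V, T, 3, 4).contiguous()  # [B, V, T, 3, 4], world-to-camera
query_points = torch.cat(
    [torch.zeros(N, 1),                 # query timestep t
     torch.rand(N, 2) * 0.2 - 0.1,      # world x, y (kept small so they project inside the frame)
     torch.rand(N, 1) + 1.0],           # world z, in front of the cameras
    dim=-1,
)[None]                                                      # [B, N, 4] = (t, x, y, z)

# The adapter projects each query into the view where it is least occluded, runs the
# wrapped monocular tracker there in 2D, and lifts the track back to world space.
tracker = MonocularToMultiViewAdapter(CoTrackerOfflineWrapper())
with torch.no_grad():
    results = tracker(rgbs=rgbs, depths=depths, query_points=query_points, intrs=intrs, extrs=extrs)
print(results["traj_e"].shape, results["vis_e"].shape)  # [B, T, N, 3], [B, T, N]
```

The adapter chooses the per-query view by comparing the depth map sampled at the projected query location with the query's camera-space depth, and it restores the original query ordering after merging the per-view results.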
  {
    "path": "mvtracker/models/core/mvtracker/__init__.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n"
  },
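The tracker implementation that follows (`mvtracker/models/core/mvtracker/mvtracker.py`) selects a kNN backend at import time: `pointops.knn_query` when the optional dependency is installed, otherwise a `torch.cdist` + `torch.topk` fallback. Both expose the same contract, `knn(k, xyz_ref, xyz_query)` with `(B, N, 3)` reference and `(B, M, 3)` query points returning `(B, M, k)` distances and indices. A small illustrative call with random placeholder points (the variable names below are not from the repository):

```python
import torch

from mvtracker.models.core.mvtracker.mvtracker import knn  # pointops backend if available, else torch fallback

xyz_ref = torch.rand(1, 4096, 3)   # (B, N, 3) reference point cloud
xyz_query = torch.rand(1, 128, 3)  # (B, M, 3) query points
dists, idx = knn(16, xyz_ref, xyz_query)  # both of shape (B, M, 16)

# Gather the coordinates of each query's 16 nearest reference points.
neighbors = xyz_ref[0][idx[0].long()]     # (M, 16, 3)
print(dists.shape, idx.shape, neighbors.shape)
```

On CPU tensors the pointops path itself falls back to the torch implementation, so the sketch runs either way.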
  {
    "path": "mvtracker/models/core/mvtracker/mvtracker.py",
    "content": "import logging\nimport os\nfrom collections import defaultdict\nfrom typing import Optional, Callable\n\nimport numpy as np\nimport pandas as pd\nimport torch\nfrom einops import rearrange\nfrom torch import nn as nn\n\nfrom mvtracker.datasets.utils import transform_scene\nfrom mvtracker.models.core.cotracker2.blocks import Attention, FlashAttention\nfrom mvtracker.models.core.cotracker2.blocks import EfficientUpdateFormer\nfrom mvtracker.models.core.embeddings import (\n    get_3d_sincos_pos_embed_from_grid,\n    get_1d_sincos_pos_embed_from_grid,\n    get_3d_embedding,\n)\nfrom mvtracker.models.core.model_utils import smart_cat, init_pointcloud_from_rgbd, save_pointcloud_to_ply\nfrom mvtracker.models.core.spatracker.blocks import BasicEncoder\nfrom mvtracker.utils.basic import time_now\n\n\n# ---------- KNN backends ----------\ndef _knn_pointops(k: int, xyz_ref: torch.Tensor, xyz_query: torch.Tensor):\n    \"\"\"\n    Efficient batched KNN using pointops library.\n\n    This is slightly faster than torch.cdist + torch.topk and uses less memory:\n\n    Example::\n\n        Benchmarking KNN with different methods (HALF_PRECISION=True):\n        torch.cdist+torch.topk   | Avg Time: 0.008380 s | Peak Memory: 1151.19 MB (min: 1151.19, max: 1151.19)\n        pointops.knn_query       | Avg Time: 0.007477 s | Peak Memory:  47.22 MB (min:  47.22, max:  47.22)\n\n        Benchmarking KNN with different methods (HALF_PRECISION=False):\n        torch.cdist+torch.topk   | Avg Time: 0.014090 s | Peak Memory: 2249.88 MB (min: 2249.88, max: 2249.88)\n        pointops.knn_query       | Avg Time: 0.007368 s | Peak Memory:  43.62 MB (min:  43.62, max:  43.62)\n\n    Args:\n        xyz_ref (Tensor): (B, N, 3)\n        xyz_query (Tensor): (B, M, 3)\n\n    Returns:\n        Tuple[Tensor, Tensor]:\n            - dist (Tensor): (B, M, k)\n            - idx (Tensor): (B, M, k) int32 — indices into dimension N\n    \"\"\"\n    # Fallback if tensors are not on CUDA\n    if not xyz_ref.is_cuda:\n        return _knn_torch(k, xyz_ref, xyz_query)\n\n    from pointops import knn_query\n    B, N, _ = xyz_ref.shape\n    _, M, _ = xyz_query.shape\n    orig_dtype = xyz_ref.dtype\n\n    xyz_ref_flat = xyz_ref.contiguous().view(B * N, 3).to(torch.float32)\n    xyz_query_flat = xyz_query.contiguous().view(B * M, 3).to(torch.float32)\n\n    offset = torch.arange(1, B + 1, device=xyz_ref.device) * N\n    new_offset = torch.arange(1, B + 1, device=xyz_query.device) * M\n    idx, dists = knn_query(k, xyz_ref_flat, offset, xyz_query_flat, new_offset)\n\n    # Remap global indices to local per-batch\n    idx = idx.view(B, M, k)\n    idx = idx - (torch.arange(B, device=idx.device).view(B, 1, 1) * N)\n    dists = dists.view(B, M, k).to(orig_dtype)\n\n    return dists, idx\n\n\ndef _knn_torch(k: int, xyz_ref: torch.Tensor, xyz_query: torch.Tensor):\n    \"\"\"Fallback KNN using torch.cdist + topk.\"\"\"\n    dists = torch.cdist(xyz_query, xyz_ref, p=2)  # (B, M, N)\n    sorted_dists, indices = torch.topk(dists, k, dim=-1, largest=False, sorted=True)\n    return sorted_dists, indices\n\n\n# Select backend once (safe if pointops missing).\ntry:\n    import importlib\n\n    importlib.import_module(\"pointops\")\n    knn = _knn_pointops\nexcept Exception:\n    logging.warning(\"pointops not found, falling back to slower KNN implementation.\")\n    knn = _knn_torch\n\n\nclass MVTracker(nn.Module):\n    def __init__(\n            self,\n            sliding_window_len=12,\n            stride=4,\n            
normalize_scene_in_fwd_pass=False,\n            fmaps_dim=128,\n            add_space_attn=True,\n            num_heads=6,\n            hidden_size=384,\n            space_depth=6,\n            time_depth=6,\n            num_virtual_tracks=64,\n            use_flash_attention=True,\n            corr_n_groups=1,\n            corr_n_levels=4,\n            corr_neighbors=16,\n            corr_add_neighbor_offset=True,\n            corr_add_neighbor_xyz=False,\n            corr_filter_invalid_depth=False,\n    ):\n        super().__init__()\n\n        self.S = sliding_window_len\n        self.stride = stride\n        self.normalize_scene_in_fwd_pass = normalize_scene_in_fwd_pass\n        self.latent_dim = fmaps_dim\n        self.flow_embed_dim = 64\n        self.b_latent_dim = self.latent_dim // 3\n        self.corr_n_groups = corr_n_groups\n        self.corr_n_levels = corr_n_levels\n        self.corr_neighbors = corr_neighbors\n        self.corr_pos_emb_size = 0\n        self.corr_add_neighbor_offset = corr_add_neighbor_offset\n        self.corr_add_neighbor_xyz = corr_add_neighbor_xyz\n        self.corr_filter_invalid_depth = corr_filter_invalid_depth\n        self.add_space_attn = add_space_attn\n        self.updateformer_input_dim = (\n            # The positional encoding of the 3D flow from t=i to t=0\n                + (self.flow_embed_dim + 1) * 3\n\n                # The correlation features (LRR) for the three planes (xy, yz, xz), concatenated\n                + self.corr_neighbors * self.corr_n_levels\n                * (self.corr_n_groups\n                   + 3 * self.corr_add_neighbor_offset\n                   + 3 * self.corr_add_neighbor_xyz\n                   + self.corr_pos_emb_size)\n\n                # The features of the tracked points, one for each of the three planes\n                + self.latent_dim\n\n                # The visibility mask\n                + 1\n\n                # The whether-the-point-is-tracked mask\n                + 1\n        )\n\n        # Feature encoder\n        self.fnet = BasicEncoder(\n            input_dim=3,\n            output_dim=self.latent_dim,\n            norm_fn=\"instance\",\n            dropout=0,\n            stride=self.stride,\n            Embed3D=False,\n        )\n\n        # Transformer for iterative updates\n        self.updateformer_hidden_size = hidden_size\n        self.updateformer = EfficientUpdateFormer(\n            space_depth=space_depth,\n            time_depth=time_depth,\n            input_dim=self.updateformer_input_dim,\n            hidden_size=hidden_size,\n            num_heads=num_heads,\n            output_dim=3 + self.latent_dim,\n            mlp_ratio=4.0,\n            add_space_attn=add_space_attn,\n            num_virtual_tracks=num_virtual_tracks,\n            attn_class=FlashAttention if use_flash_attention else Attention,\n            linear_layer_for_vis_conf=False,\n        )\n\n        # Feature update + visibility\n        self.ffeats_norm = nn.GroupNorm(1, self.latent_dim)\n        self.ffeats_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU())\n        self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))\n\n        self.stats_pyramid = None\n        self.stats_depth = None\n\n    def fnet_fwd(self, rgbs_normalized, image_features=None):\n        b, v, t, _, h, w = rgbs_normalized.shape\n        rgbs_normalized = rgbs_normalized.reshape(-1, 3, h, w)\n        return self.fnet(rgbs_normalized)\n\n    def init_stats(self):\n        self.stats_pyramid = 
defaultdict(list)\n        self.stats_depth = []\n\n    def consume_stats(self):\n        # Per-pyramid-level summary of neighbor distances\n        level_to_norms = defaultdict(list)\n        for (level, _), norm_lists in self.stats_pyramid.items():\n            level_to_norms[level].extend(norm_lists)\n        level_summary = []\n        for level, norm_lists in level_to_norms.items():\n            norms = np.concatenate(norm_lists).astype(float)\n            stats = pd.Series(norms).describe(percentiles=[.25, .5, .75])\n            level_summary.append({\n                \"level\": level,\n                \"count\": int(stats[\"count\"]),\n                \"mean\": round(float(stats[\"mean\"] * 100), 1),\n                \"std\": round(float(stats[\"std\"] * 100), 1),\n                \"min\": round(float(stats[\"min\"] * 100), 1),\n                \"25%\": round(float(stats[\"25%\"] * 100), 1),\n                \"50%\": round(float(stats[\"50%\"] * 100), 1),\n                \"75%\": round(float(stats[\"75%\"] * 100), 1),\n                \"max\": round(float(stats[\"max\"] * 100), 1),\n            })\n        df_level_summary = pd.DataFrame(level_summary).sort_values(\"level\")\n        logging.info(f\"Neighbor distances across pyramid levels:\\n{df_level_summary}\")\n\n        # Per-pyramid-level and per-iteration summary of neighbor distances\n        summary = []\n        for (level, it), norm_lists in self.stats_pyramid.items():\n            norms = np.concatenate(norm_lists).astype(float)\n            stats = pd.Series(norms).describe(percentiles=[.25, .5, .75])\n            summary.append({\n                \"level\": level,\n                \"iteration\": it,\n                \"count\": int(stats[\"count\"]),\n                \"mean\": round(float(stats[\"mean\"] * 100), 1),\n                \"std\": round(float(stats[\"std\"] * 100), 1),\n                \"min\": round(float(stats[\"min\"] * 100), 1),\n                \"25%\": round(float(stats[\"25%\"] * 100), 1),\n                \"50%\": round(float(stats[\"50%\"] * 100), 1),\n                \"75%\": round(float(stats[\"75%\"] * 100), 1),\n                \"max\": round(float(stats[\"max\"] * 100), 1),\n            })\n        df_summary = pd.DataFrame(summary).sort_values([\"level\", \"iteration\"])\n        logging.info(f\"Neighbor distances across pyramid levels and iterations (in cm):\\n{(df_summary)}\")\n\n        # Valid vs invalid depth stats\n        depth_stats = pd.Series(self.stats_depth).describe(percentiles=[.25, .5, .75]).astype(float).round(1)\n        logging.info(f\"Depth stats (valid vs invalid):\\n{depth_stats}\")\n\n        self.stats_pyramid = None\n        self.stats_depth = None\n\n    def forward_iteration(\n            self,\n            fmaps,\n            depths,\n            intrs,\n            extrs,\n            coords_init,\n            vis_init,\n            track_mask,\n            iters=4,\n            feat_init=None,\n            save_debug_logs=False,\n            debug_logs_path=\"\",\n            debug_logs_prefix=\"\",\n            debug_logs_window_idx=None,\n            save_rerun_logs: bool = False,\n            rerun_fmap_coloring_fn: Optional[Callable] = None,\n    ):\n        B, V, S, D, H, W = fmaps.shape\n        N = coords_init.shape[2]\n        device = fmaps.device\n        if coords_init.shape[1] < S:\n            coords = torch.cat([coords_init, coords_init[:, -1].repeat(1, S - coords_init.shape[1], 1, 1)], dim=1)\n            vis_init = torch.cat([vis_init, vis_init[:, 
-1].repeat(1, S - vis_init.shape[1], 1, 1)], dim=1)\n        else:\n            coords = coords_init.clone()\n        if track_mask.shape[1] < S:\n            track_mask = torch.cat([\n                track_mask,\n                torch.zeros_like(track_mask[:, 0]).repeat(1, S - track_mask.shape[1], 1, 1),\n            ], dim=1)\n        assert B == 1\n        assert D == self.latent_dim\n        assert fmaps.shape == (B, V, S, D, H, W)\n        assert depths.shape == (B, V, S, 1, H, W)\n        assert intrs.shape == (B, V, S, 3, 3)\n        assert extrs.shape == (B, V, S, 3, 4)\n        assert coords.shape == (B, S, N, 3)\n        assert vis_init.shape == (B, S, N, 1)\n        assert track_mask.shape == (B, S, N, 1)\n        assert feat_init is None or feat_init.shape == (B, S, N, self.latent_dim)\n\n        assert track_mask.any(1).all(), \"All points should be requested to be tracked at least for one frame\"\n\n        intrs_inv = torch.inverse(intrs.float()).type(intrs.dtype)\n        extrs_square = torch.eye(4).to(extrs.device)[None].repeat(B, V, S, 1, 1)\n        extrs_square[:, :, :, :3, :] = extrs\n        extrs_inv = torch.inverse(extrs_square.float()).type(extrs.dtype)\n        assert intrs_inv.shape == (B, V, S, 3, 3)\n        assert extrs_square.shape == (B, V, S, 4, 4)\n        assert extrs_inv.shape == (B, V, S, 4, 4)\n\n        fcorr_fns = {}\n        for lvl in range(self.corr_n_levels):\n            pc = init_pointcloud_from_rgbd(\n                fmaps=fmaps,\n                depths=depths,\n                intrs=intrs,\n                extrs=extrs,\n                stride=self.stride,\n                level=lvl,\n                return_validity_mask=self.corr_filter_invalid_depth or save_rerun_logs,\n            )\n            if self.corr_filter_invalid_depth or save_rerun_logs:\n                pc_xyz, pc_fvec, pc_valid = pc\n            else:\n                pc_xyz, pc_fvec = pc\n                pc_valid = None\n            fcorr_fns[lvl] = PointcloudCorrBlock(\n                k=self.corr_neighbors,\n                groups=self.corr_n_groups,\n                xyz=pc_xyz,\n                fvec=pc_fvec,\n                filter_invalid=self.corr_filter_invalid_depth,\n                valid=pc_valid,\n                corr_add_neighbor_offset=self.corr_add_neighbor_offset,\n                corr_add_neighbor_xyz=self.corr_add_neighbor_xyz,\n                rerun_fmap_coloring_fn=rerun_fmap_coloring_fn,\n            )\n\n        # Positional/time embeddings (keep shapes identical to before)\n        embed_dim = self.updateformer_input_dim\n        if embed_dim % 6 != 0:\n            embed_dim += 6 - (embed_dim % 6)\n        pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, coords[:, 0:1]).float()[:, 0].permute(0, 2, 1)\n        if embed_dim > self.updateformer_input_dim:\n            pos_embed = pos_embed[:, :self.updateformer_input_dim, :]\n        pos_embed = rearrange(pos_embed, \"b e n -> (b n) e\").unsqueeze(1)\n\n        times_ = torch.linspace(0, S - 1, S).reshape(1, S, 1) / S\n        embed_dim = self.updateformer_input_dim\n        if embed_dim % 2 != 0:\n            embed_dim += 2 - (embed_dim % 2)\n        times_embed = (\n            torch.from_numpy(get_1d_sincos_pos_embed_from_grid(embed_dim, times_[0]))[None]\n            .repeat(B, 1, 1)\n            .float()\n            .to(device)\n        )\n        if embed_dim > self.updateformer_input_dim:\n            times_embed = times_embed[:, :, :self.updateformer_input_dim]\n\n        coord_predictions = 
[]\n\n        ffeats = feat_init.clone()\n        track_mask_and_vis = torch.cat([track_mask, vis_init], dim=3).permute(0, 2, 1, 3).reshape(B * N, S, 2)\n        for it in range(iters):\n            coords = coords.detach()\n\n            # Sample correlation features around each point\n            fcorrs = []\n            for lvl in range(self.corr_n_levels):\n                fcorr_fn = fcorr_fns[lvl]\n                fcorrs_level = (\n                    fcorr_fn\n                    .corr_sample(\n                        targets=ffeats.reshape(B * S, N, self.latent_dim),\n                        coords_world_xyz=coords.reshape(B * S, N, 3),\n                        save_debug_logs=False,\n                        debug_logs_path=debug_logs_path,\n                        debug_logs_prefix=debug_logs_prefix + f\"__iter_{it}__pyramid_level_{lvl}\",\n                        save_rerun_logs=save_rerun_logs,\n                    )\n                    .reshape(B, S, N, -1)\n                )\n                fcorrs.append(fcorrs_level)\n                if self.stats_pyramid is not None:\n                    self.stats_pyramid[(lvl, it)] += [\n                        np.linalg.norm(fcorrs_level.reshape(-1, 4)[:, 1:].detach().cpu().numpy(), axis=-1)\n                    ]\n            fcorrs = torch.cat(fcorrs, dim=-1)\n            LRR = fcorrs.shape[3]\n            fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, LRR)\n\n            # Flow embedding\n            flows_ = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 3)\n            flows_ = get_3d_embedding(flows_, self.flow_embed_dim, cat_coords=True)\n\n            ffeats_ = ffeats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)\n\n            transformer_input = torch.cat([flows_, fcorrs_, ffeats_, track_mask_and_vis], dim=2)\n            assert transformer_input.shape[-1] == pos_embed.shape[-1]\n            x = transformer_input + pos_embed + times_embed\n            x = rearrange(x, \"(b n) t d -> b n t d\", b=B)\n\n            delta = self.updateformer(x)\n            delta = rearrange(delta, \" b n t d -> (b n) t d\")\n\n            d_coord = delta[:, :, :3].reshape(B, N, S, 3).permute(0, 2, 1, 3)\n\n            d_feats = delta[:, :, 3:self.latent_dim + 3]\n            d_feats = self.ffeats_norm(d_feats.view(-1, self.latent_dim))\n            d_feats = self.ffeats_updater(d_feats).view(B, N, S, self.latent_dim).permute(0, 2, 1, 3)\n\n            coords = coords + d_coord\n            ffeats = ffeats + d_feats\n\n            if torch.isnan(coords).any():\n                logging.error(\"Got NaN values in coords, perhaps the training exploded\")\n                import ipdb\n                ipdb.set_trace()\n\n            coord_predictions.append(coords.clone())\n\n        vis_e = self.vis_predictor(ffeats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)\n\n        return coord_predictions, vis_e, feat_init\n\n    def forward(\n            self,\n            rgbs,\n            depths,\n            query_points,\n            intrs,\n            extrs,\n            iters=4,\n            image_features=None,\n            is_train=False,\n            save_debug_logs=False,\n            debug_logs_path=\"\",\n            save_rerun_logs: bool = False,\n            save_rerun_logs_output_rrd_path: Optional[str] = None,\n            **kwargs,\n    ):\n        device = extrs.device\n        if save_debug_logs:\n            if kwargs:\n                logging.info(f\"Unused kwargs: {kwargs.keys()}\")\n\n        
batch_size, num_views, num_frames, _, height, width = rgbs.shape\n        _, num_points, _ = query_points.shape\n        logging.info(f\"FWD pass: {num_views=} {num_frames=} {num_points=} \"\n                     f\"{height=} {width=} {iters=} {rgbs.dtype=}\")\n\n        # I made a video tutorial here if it is easier to follow: https://www.youtube.com/watch?v=dQw4w9WgXcQ\n\n        assert rgbs.shape == (batch_size, num_views, num_frames, 3, height, width)\n        assert depths.shape == (batch_size, num_views, num_frames, 1, height, width)\n        assert query_points.shape == (batch_size, num_points, 4)\n        assert intrs.shape == (batch_size, num_views, num_frames, 3, 3)\n        assert extrs.shape == (batch_size, num_views, num_frames, 3, 4)\n\n        if save_debug_logs:\n            os.makedirs(debug_logs_path, exist_ok=True)\n\n        if save_rerun_logs:\n            assert save_rerun_logs_output_rrd_path is not None\n            import rerun as rr\n            rr.init(\"3dpt\", recording_id=\"v0.16\")\n            rr.set_time_seconds(\"frame\", 0)\n\n        if self.stats_depth is not None:\n            self.stats_depth += [(depths == 0).float().mean().item() * 100]\n\n        # Scene normalization (optional): similarity transform that centers the first camera and rescales the scene like VGGT\n        qp_range_before = np.stack([\n            query_points[0, :, 1:].min(0).values.cpu().numpy().round(2),\n            query_points[0, :, 1:].max(0).values.cpu().numpy().round(2),\n        ])\n        if self.normalize_scene_in_fwd_pass:\n            assert batch_size == 1, \"VGGT normalization assumes batch size 1\"\n            max_depth = 24\n            _d = depths.clone()\n            _d[_d > max_depth] = max_depth  # Cap far depths at max_depth before estimating the normalization transform\n            T_scale, T_rot, T_translation = compute_vggt_scene_normalization_transform(\n                _d[0], extrs[0].to(_d.device), intrs[0].to(_d.device)\n            )\n            T_scale_inv = 1 / T_scale\n            T_rot_inv = T_rot.transpose(0, 1)\n            T_translation_inv = -T_translation @ T_rot_inv\n\n            query_points, extrs = query_points[0], extrs[0]  # Remove batch dimension\n            _, extrs, query_points, _, _ = transform_scene(T_scale, T_rot, T_translation, None, extrs, query_points)\n            query_points, extrs = query_points[None], extrs[None]  # Add batch dimension\n        qp_range_after = np.stack([\n            query_points[0, :, 1:].min(0).values.cpu().numpy().round(2),\n            query_points[0, :, 1:].max(0).values.cpu().numpy().round(2),\n        ])\n        if save_debug_logs:\n            logging.info(f\"Query points range before normalization:\\n{qp_range_before}\")\n            logging.info(f\"Query points range after normalization: \\n{qp_range_after}\")\n\n        self.is_train = is_train\n\n        # Unpack the query points\n        query_points_t = query_points[:, :, :1].long()\n        query_points_xyz_worldspace = query_points[:, :, 1:]\n\n        # Invert intrinsics and extrinsics\n        intrs_inv = torch.inverse(intrs.float()).type(intrs.dtype)\n        extrs_square = torch.eye(4).to(extrs.device)[None].repeat(batch_size, num_views, num_frames, 1, 1)\n        extrs_square[:, :, :, :3, :] = extrs\n        extrs_inv = torch.inverse(extrs_square.float()).type(extrs.dtype)\n\n        # Interpolate the rgbs and depthmaps to the stride of the SpaTracker\n        strided_height = height // self.stride\n        strided_width = width // self.stride\n\n        # Filter the points that never appear during 1 - T\n        assert 
batch_size == 1, \"Batch size > 1 is not supported yet\"\n        query_points_t = query_points_t.squeeze(0).squeeze(-1)  # BN1 --> N\n        ind_array = torch.arange(num_frames, device=query_points.device)\n        ind_array = ind_array[None, :, None].repeat(batch_size, 1, num_points)\n        track_mask = (ind_array >= query_points_t[None, None, :]).unsqueeze(-1)  # TODO: >= or >?\n\n        # Prepare the initial coordinates and visibility\n        coords_init = query_points_xyz_worldspace.unsqueeze(1).repeat(1, self.S, 1, 1)\n        vis_init = query_points.new_ones((batch_size, self.S, num_points, 1)) * 10\n\n        # Sort the queries via their first appeared time\n        _, sort_inds = torch.sort(query_points_t, dim=0, descending=False)\n        inv_sort_inds = torch.argsort(sort_inds, dim=0)\n        assert torch.allclose(query_points_t, query_points_t[sort_inds][inv_sort_inds])\n\n        query_points_t_ = query_points_t[sort_inds]\n        query_points_xyz_worldspace_ = query_points_xyz_worldspace[..., sort_inds, :]\n        coords_init_ = coords_init[..., sort_inds, :].clone()\n        vis_init_ = vis_init[:, :, sort_inds].clone()\n        track_mask_ = track_mask[:, :, sort_inds].clone()\n\n        # Delete the unsorted variables (for safety)\n        del coords_init, vis_init, query_points_t, query_points, query_points_xyz_worldspace, track_mask\n\n        # Placeholders for the results (for the sorted points)\n        traj_e_ = coords_init_.new_zeros((batch_size, num_frames, num_points, 3))\n        vis_e_ = coords_init_.new_zeros((batch_size, num_frames, num_points))\n\n        w_idx_start = query_points_t_.min()\n        p_idx_start = 0\n        vis_predictions = []\n        coord_predictions = []\n        p_idx_end_list = []\n        fmaps_seq, depths_seq, feat_init, rerun_fmap_coloring_fn = None, None, None, None\n        while w_idx_start < num_frames - self.S // 2:\n            curr_wind_points = torch.nonzero(query_points_t_ < w_idx_start + self.S)\n            assert curr_wind_points.shape[0] > 0\n            p_idx_end = curr_wind_points[-1].item() + 1\n            p_idx_end_list.append(p_idx_end)\n\n            intrs_seq = intrs[:, :, w_idx_start:w_idx_start + self.S]\n            extrs_seq = extrs[:, :, w_idx_start:w_idx_start + self.S]\n\n            # Compute fmaps and interpolated depth on a rolling basis\n            # to reduce peak GPU memory consumption, but don't recompute\n            # for the overlapping part of a window\n            if fmaps_seq is None:\n                assert depths_seq is None\n                new_seq_t0 = w_idx_start\n            else:\n                fmaps_seq = fmaps_seq[:, :, self.S // 2:]\n                depths_seq = depths_seq[:, :, self.S // 2:]\n                new_seq_t0 = w_idx_start + self.S // 2\n            new_seq_t1 = w_idx_start + self.S\n\n            _depths_seq_new = nn.functional.interpolate(\n                input=depths[:, :, new_seq_t0:new_seq_t1].to(device).reshape(-1, 1, height, width),\n                scale_factor=1.0 / self.stride,\n                mode=\"nearest\",\n            ).reshape(batch_size, num_views, -1, 1, strided_height, strided_width)\n            depths_seq = smart_cat(depths_seq, _depths_seq_new, dim=2)\n\n            _fmaps_seq_new = self.fnet_fwd(\n                (2 * (rgbs[:, :, new_seq_t0: new_seq_t1].to(device) / 255.0) - 1.0),\n                image_features,\n            )\n            _fmaps_seq_new = nn.functional.interpolate(\n                input=_fmaps_seq_new,\n             
   size=(strided_height, strided_width),\n                mode=\"bilinear\",\n            ).reshape(batch_size, num_views, -1, self.latent_dim, strided_height, strided_width)\n            fmaps_seq = smart_cat(fmaps_seq, _fmaps_seq_new, dim=2)\n\n            if save_rerun_logs and rerun_fmap_coloring_fn is None:\n                valid_depths_mask = depths_seq.detach().cpu().squeeze(3) > 0\n                fvec_flat = fmaps_seq.detach().cpu().permute(0, 1, 2, 4, 5, 3)[valid_depths_mask].numpy()\n                from sklearn.decomposition import PCA\n                reducer = PCA(n_components=3)\n                reducer.fit(fvec_flat)\n                fvec_reduced = reducer.transform(fvec_flat)\n                reducer_min = fvec_reduced.min(axis=0)\n                reducer_max = fvec_reduced.max(axis=0)\n\n                def fvec_to_rgb(fvec):\n                    input_shape = fvec.shape\n                    assert input_shape[-1] == self.latent_dim\n                    fvec_reduced = reducer.transform(fvec.reshape(-1, self.latent_dim))\n                    fvec_reduced = np.clip(fvec_reduced, reducer_min[None, :], reducer_max[None, :])\n                    fvec_reduced_rescaled = (fvec_reduced - reducer_min) / (reducer_max - reducer_min)\n                    fvec_reduced_rgb = (fvec_reduced_rescaled * 255).astype(int)\n                    fvec_reduced_rgb = fvec_reduced_rgb.reshape(input_shape[:-1] + (3,))\n                    return fvec_reduced_rgb\n\n                rerun_fmap_coloring_fn = fvec_to_rgb\n\n            S_local = fmaps_seq.shape[2]\n            if S_local < self.S:\n                diff = self.S - S_local\n                fmaps_seq = torch.cat([fmaps_seq, fmaps_seq[:, :, -1:].repeat(1, 1, diff, 1, 1, 1)], 2)\n                depths_seq = torch.cat([depths_seq, depths_seq[:, :, -1:].repeat(1, 1, diff, 1, 1, 1)], 2)\n                intrs_seq = torch.cat([intrs_seq, intrs_seq[:, :, -1:].repeat(1, 1, diff, 1, 1)], 2)\n                extrs_seq = torch.cat([extrs_seq, extrs_seq[:, :, -1:].repeat(1, 1, diff, 1, 1)], 2)\n\n            # Compute the feature vector initialization for the new query points\n            if p_idx_end - p_idx_start > 0:\n                rgbd_xyz, rgbd_fvec = init_pointcloud_from_rgbd(\n                    fmaps=_fmaps_seq_new,\n                    depths=_depths_seq_new,\n                    intrs=intrs[:, :, new_seq_t0:new_seq_t1],\n                    extrs=extrs[:, :, new_seq_t0:new_seq_t1],\n                    stride=self.stride,\n                )\n\n                new_num_frames = _fmaps_seq_new.shape[2]\n                rgbd_xyz = rgbd_xyz.reshape(batch_size, new_num_frames, num_views, strided_height * strided_width, 3)\n                rgbd_fvec = rgbd_fvec.reshape(batch_size, new_num_frames, num_views, strided_height * strided_width,\n                                              self.latent_dim)\n\n                _feat_init_new = torch.zeros(batch_size, p_idx_end - p_idx_start, self.latent_dim,\n                                             device=_fmaps_seq_new.device, dtype=_fmaps_seq_new.dtype)\n                assert batch_size == 1\n                assert ((query_points_t_[p_idx_start:p_idx_end] > new_seq_t0)\n                        | (query_points_t_[p_idx_start:p_idx_end] < new_seq_t1)).all()\n                batch_idx = 0\n                for t in range(new_seq_t0, new_seq_t1):\n                    query_mask = query_points_t_[p_idx_start:p_idx_end] == t\n                    if query_mask.sum() == 0:\n                        
continue\n                    query_points_world = query_points_xyz_worldspace_[batch_idx, p_idx_start:p_idx_end][query_mask]\n\n                    rgbd_xyz_current = rgbd_xyz[batch_idx, t - new_seq_t0].reshape(-1, 3)  # Combine views for frame\n                    rgbd_fvec_current = rgbd_fvec[batch_idx, t - new_seq_t0].reshape(-1, self.latent_dim)\n\n                    k = 1\n                    neighbor_dists, neighbor_indices = knn(k, rgbd_xyz_current[None],\n                                                           query_points_world[None])\n                    assert k == 1, \"If k > 1, the code below should be modified to handle multiple neighbors -- how to combine the features of multiple neighbors?\"\n                    neighbor_xyz = rgbd_xyz_current[neighbor_indices[0, :, 0]]\n                    neighbor_fvec = rgbd_fvec_current[neighbor_indices[0, :, 0]]\n\n                    _feat_init_new[batch_idx, query_mask] = neighbor_fvec\n\n                feat_init = smart_cat(feat_init, _feat_init_new.repeat(1, self.S, 1, 1), dim=2)\n\n            # Update the initial coordinates and visibility for non-first windows\n            if p_idx_start > 0:\n                last_coords = coords[-1][:, self.S // 2:].clone()  # Take the predicted coords from the last window\n                coords_init_[:, : self.S // 2, :p_idx_start] = last_coords\n                coords_init_[:, self.S // 2:, :p_idx_start] = last_coords[:, -1].repeat(1, self.S // 2, 1, 1)\n\n                last_vis = vis[:, self.S // 2:][..., None]\n                vis_init_[:, : self.S // 2, :p_idx_start] = last_vis\n                vis_init_[:, self.S // 2:, :p_idx_start] = last_vis[:, -1].repeat(1, self.S // 2, 1, 1)\n\n            track_mask_current = track_mask_[:, w_idx_start: w_idx_start + self.S, :p_idx_end]\n            if S_local < self.S:\n                track_mask_current = torch.cat([\n                    track_mask_current,\n                    track_mask_current[:, -1:].repeat(1, self.S - S_local, 1, 1),\n                ], 1)\n\n            coords, vis, _ = self.forward_iteration(\n                fmaps=fmaps_seq,\n                depths=depths_seq,\n                intrs=intrs_seq,\n                extrs=extrs_seq,\n                coords_init=coords_init_[:, :, :p_idx_end],\n                feat_init=feat_init[:, :, :p_idx_end],\n                vis_init=vis_init_[:, :, :p_idx_end],\n                track_mask=track_mask_current,\n                iters=iters,\n                save_debug_logs=save_debug_logs,\n                debug_logs_path=debug_logs_path,\n                debug_logs_prefix=f\"__widx-{w_idx_start}_pidx-{p_idx_start}-{p_idx_end}\",\n                debug_logs_window_idx=w_idx_start,\n                save_rerun_logs=save_rerun_logs,\n                rerun_fmap_coloring_fn=rerun_fmap_coloring_fn,\n            )\n\n            if is_train:\n                coord_predictions.append([\n                    coord[:, :S_local]\n                    if not self.normalize_scene_in_fwd_pass\n                    else transform_scene(T_scale_inv, T_rot_inv, T_translation_inv,\n                                         None, None, None, coord[:, :S_local][0], None)[2][None]\n                    for coord in coords\n                ])\n                vis_predictions.append(vis[:, :S_local])\n\n            traj_e_[:, w_idx_start:w_idx_start + self.S, :p_idx_end] = coords[-1][:, :S_local]\n            vis_e_[:, w_idx_start:w_idx_start + self.S, :p_idx_end] = torch.sigmoid(vis[:, :S_local])\n\n          
  track_mask_[:, : w_idx_start + self.S, :p_idx_end] = 0.0\n            w_idx_start = w_idx_start + self.S // 2\n\n            p_idx_start = p_idx_end\n\n        if save_debug_logs:\n            import gpustat\n            torch.cuda.empty_cache()\n            logging.info(f\"Forward pass GPU usage: {gpustat.new_query()}\")\n\n        if save_rerun_logs:\n            import rerun as rr\n            rr.save(save_rerun_logs_output_rrd_path)\n            logging.info(f\"Saved Rerun recording to: {os.path.abspath(save_rerun_logs_output_rrd_path)}.\")\n\n        traj_e = traj_e_[:, :, inv_sort_inds]\n        vis_e = vis_e_[:, :, inv_sort_inds]\n\n        # Un-normalize the scene\n        if self.normalize_scene_in_fwd_pass:\n            traj_e = transform_scene(T_scale_inv, T_rot_inv, T_translation_inv,\n                                     None, None, None, traj_e[0], None)[2][None]\n\n        results = {\n            \"traj_e\": traj_e,\n            \"feat_init\": feat_init,\n            \"vis_e\": vis_e,\n        }\n        if self.is_train:\n            results[\"train_data\"] = {\n                \"vis_predictions\": vis_predictions,\n                \"coord_predictions\": coord_predictions,\n                \"attn_predictions\": None,\n                \"p_idx_end_list\": p_idx_end_list,\n                \"sort_inds\": sort_inds,\n                \"Rigid_ln_total\": None,\n            }\n        return results\n\n\ndef compute_vggt_scene_normalization_transform(depths, extrs, intrs):\n    V, T, _, H, W = depths.shape\n    device = depths.device\n\n    extrs_square = torch.eye(4, device=device)[None, None].repeat(V, T, 1, 1)\n    extrs_square[:, :, :3, :] = extrs\n    extrs_inv = torch.inverse(extrs_square.float()).type(extrs.dtype)\n\n    intrs_inv = torch.inverse(intrs.float()).type(intrs.dtype)\n\n    y, x = torch.meshgrid(\n        torch.arange(H, device=device),\n        torch.arange(W, device=device),\n        indexing=\"ij\"\n    )\n    homog = torch.stack([x, y, torch.ones_like(x)], dim=-1).float().reshape(-1, 3)\n    homog = homog[None].expand(V, -1, -1).type(depths.dtype)\n\n    cam_points = torch.einsum(\"vij,vnj->vni\", intrs_inv[:, 0], homog) * depths[:, 0].reshape(V, -1, 1)\n    cam_points_h = torch.cat([cam_points, torch.ones_like(cam_points[..., :1])], dim=-1)\n    world_points_h = torch.einsum(\"vij,vnj->vni\", extrs_inv[:, 0], cam_points_h)\n\n    world_points_in_first = torch.einsum(\"ij,vnj->vni\", extrs[0, 0], world_points_h)\n\n    mask = (depths[:, 0] > 0).reshape(V, -1)\n    valid_points = world_points_in_first[mask]\n    avg_dist = valid_points.norm(dim=1).mean()\n    scale = 1.0 / avg_dist\n\n    rot = extrs[0, 0, :3, :3]\n    translation = extrs[0, 0, :3, 3] * scale\n    return scale, rot, translation\n\n\nclass PointcloudCorrBlock:\n    def __init__(\n            self,\n            k: int,\n            groups,\n            xyz: torch.Tensor,\n            fvec: torch.Tensor,\n            corr_add_neighbor_offset: bool,\n            corr_add_neighbor_xyz: bool,\n            filter_invalid: bool = False,\n            valid: Optional[torch.Tensor] = None,\n            rerun_fmap_coloring_fn: Optional[Callable] = None,\n    ):\n        self.B, self.N, self.C = fvec.shape\n        assert xyz.shape == (self.B, self.N, 3)\n        assert fvec.shape == (self.B, self.N, self.C)\n        assert k <= self.N, \"k should be less than or equal to N\"\n        assert groups <= self.C, \"number of correlation groups should not be larger than the number of channels\"\n        
assert self.C % groups == 0, \"number of channels must be divisible by the number of groups (for convenience)\"\n        assert not filter_invalid or valid is not None\n\n        self.k = k\n        self.groups = groups\n        self.xyz = xyz\n        self.fvec = fvec\n        self.corr_add_neighbor_offset = corr_add_neighbor_offset\n        self.corr_add_neighbor_xyz = corr_add_neighbor_xyz\n        self.filter_invalid = filter_invalid\n        self.valid = valid\n        self.rerun_fmap_coloring_fn = rerun_fmap_coloring_fn\n\n    def corr_sample(\n            self,\n            targets: torch.Tensor,\n            coords_world_xyz: torch.Tensor,\n            save_debug_logs=False,\n            debug_logs_path=\".\",\n            debug_logs_prefix=\"corr\",\n            save_rerun_logs=False,\n    ):\n        # Check inputs\n        _, M, _ = targets.shape\n        assert targets.shape == (self.B, M, self.C)\n        assert coords_world_xyz.shape == (self.B, M, 3)\n\n        # Find neighbors for each of the N target points\n        if not self.filter_invalid:\n            neighbor_dists, neighbor_indices = knn(self.k, self.xyz, coords_world_xyz)\n        else:\n            neighbor_dists = []\n            neighbor_indices = []\n            for xyz_i, valid_i, coords_world_xyz_i in zip(self.xyz, self.valid, coords_world_xyz):\n                xyz_i = xyz_i[valid_i]\n                neighbor_dists_i, neighbor_indices_i = knn(self.k, xyz_i[None], coords_world_xyz_i[None])\n                neighbor_dists.append(neighbor_dists_i)\n                neighbor_indices.append(neighbor_indices_i)\n            neighbor_dists = torch.cat(neighbor_dists)\n            neighbor_indices = torch.cat(neighbor_indices)\n        batch_idx = torch.arange(self.B, device=self.xyz.device)[:, None, None]\n        neighbor_xyz = self.xyz[batch_idx, neighbor_indices]\n        neighbor_fvec = self.fvec[batch_idx, neighbor_indices]\n\n        # Compute the local correlations\n        targets_grouped = targets.view(self.B, M, self.groups, -1)\n        neighbor_fvec_grouped = neighbor_fvec.view(self.B, M, self.k, self.groups, -1)\n        corrs = torch.einsum('BMGc,BMKGc->BMKG', targets_grouped, neighbor_fvec_grouped)\n        corrs = corrs / ((self.C / self.groups) ** 0.5)\n\n        output = corrs\n\n        # Append the distance/direction features to the correlation\n        neighbor_offset_in_world_xyz = neighbor_xyz - coords_world_xyz[..., None, :]\n        if self.corr_add_neighbor_offset:\n            output = torch.cat([corrs, neighbor_offset_in_world_xyz], -1)\n\n        # Append the neighbor xyz to the correlation\n        if self.corr_add_neighbor_xyz:\n            output = torch.cat([output, neighbor_xyz], -1)\n\n        if save_debug_logs:\n\n            from sklearn.decomposition import PCA\n            fvec_flat = self.fvec.reshape(-1, self.C).detach().cpu().numpy()\n            reducer = PCA(n_components=3)\n            reducer.fit(fvec_flat)\n\n            fvec_reduced = reducer.transform(fvec_flat)\n            reducer_min = fvec_reduced.min(axis=0)\n            reducer_max = fvec_reduced.max(axis=0)\n\n            def fvec_to_rgb(fvec):\n                fvec_reduced = reducer.transform(fvec)\n                fvec_reduced_rescaled = (fvec_reduced - reducer_min) / (reducer_max - reducer_min)\n                fvec_reduced_rgb = (fvec_reduced_rescaled * 255).astype(int)\n                return fvec_reduced_rgb\n\n            for b in [0, self.B - 1]:\n                # Save all points\n                xyz 
= self.xyz[b].detach().cpu().numpy()\n                xyz_colors = fvec_to_rgb(self.fvec[b].detach().cpu().numpy())\n                save_pointcloud_to_ply(os.path.join(debug_logs_path, f\"{time_now()}{debug_logs_prefix}_all_b{b}.ply\"),\n                                       xyz, xyz_colors)\n\n                for n in range(3):\n                    neighbors = neighbor_xyz[b, n].detach().cpu().numpy()\n                    neighbors_colors = fvec_to_rgb(neighbor_fvec[b, n].detach().cpu().numpy())\n                    save_pointcloud_to_ply(\n                        os.path.join(debug_logs_path, f\"{time_now()}{debug_logs_prefix}_neighbors_b{b}_n{n}.ply\"),\n                        neighbors, neighbors_colors)\n\n                for n in range(3):\n                    neighbors = neighbor_xyz[b, n].detach().cpu().numpy()\n                    neighbors_colors = fvec_to_rgb(neighbor_fvec[b, n].detach().cpu().numpy())\n                    query_point = coords_world_xyz[b, n].detach().cpu().numpy()\n                    query_point_color = fvec_to_rgb(targets[b, n].detach().cpu().numpy().reshape(1, -1))\n                    combined_points = np.vstack([query_point, neighbors])\n                    combined_colors = np.vstack([query_point_color, neighbors_colors])\n                    query_point_index = 0\n                    neighbor_indices = np.arange(1, len(neighbors) + 1)\n                    edges = np.array([[query_point_index, i] for i in neighbor_indices])\n                    save_pointcloud_to_ply(os.path.join(debug_logs_path,\n                                                        f\"{time_now()}{debug_logs_prefix}_query_b{b}_n{n}_with_edges.ply\"),\n                                           combined_points, combined_colors, edges=edges)\n\n        # Visualize the results with rerun.io\n        if save_rerun_logs:\n            import rerun as rr\n            import re\n\n            assert self.C > 1\n            rerun_fps = 30\n            log_feature_maps = True\n            log_knn_neighbors = False\n            knn_line_coloring = \"static\"\n            knn_neighbors_to_log = 6\n\n            logging.info(f\"rerun for {debug_logs_prefix} started\")\n\n            ## Mask out target scene area\n            # xyz = self.xyz.detach().cpu().numpy()\n            # bbox = np.array([[-4, 4], [-3, 3.7], [1.2, 5.2]]) # Softball bbox\n            # mask = (\n            #         (xyz[..., 0] > bbox[0, 0])\n            #         & (xyz[..., 0] < bbox[0, 1])\n            #         & (xyz[..., 1] > bbox[1, 0])\n            #         & (xyz[..., 1] < bbox[1, 1])\n            #         & (xyz[..., 2] > bbox[2, 0])\n            #         & (xyz[..., 2] < bbox[2, 1])\n            # )\n            xyz = self.xyz.detach().cpu().numpy()\n            mask = np.ones_like(xyz[..., 0]).astype(bool)\n            if self.valid is not None:\n                mask = self.valid.detach().cpu().numpy()\n\n            # PCA-based feature coloring\n            if self.rerun_fmap_coloring_fn is None:\n                fvec_flat = self.fvec.detach().cpu().numpy()[mask]\n                from sklearn.decomposition import PCA\n                reducer = PCA(n_components=3)\n                reducer.fit(fvec_flat)\n                fvec_reduced = reducer.transform(fvec_flat)\n                reducer_min = fvec_reduced.min(axis=0)\n                reducer_max = fvec_reduced.max(axis=0)\n\n                def fvec_to_rgb(fvec):\n                    input_shape = fvec.shape\n                    assert input_shape[-1] == 
self.C\n                    fvec_reduced = reducer.transform(fvec.reshape(-1, self.C))\n                    fvec_reduced = np.clip(fvec_reduced, reducer_min[None, :], reducer_max[None, :])\n                    fvec_reduced_rescaled = (fvec_reduced - reducer_min) / (reducer_max - reducer_min)\n                    fvec_reduced_rgb = (fvec_reduced_rescaled * 255).astype(int)\n                    fvec_reduced_rgb = fvec_reduced_rgb.reshape(input_shape[:-1] + (3,))\n                    return fvec_reduced_rgb\n\n                self.rerun_fmap_coloring_fn = fvec_to_rgb\n\n            fvec_colors = self.rerun_fmap_coloring_fn(self.fvec.detach().cpu().numpy())\n            targets_colors = self.rerun_fmap_coloring_fn(targets.detach().cpu().numpy())\n            neighbor_fvec_colors = self.rerun_fmap_coloring_fn(neighbor_fvec.detach().cpu().numpy())\n\n            import re\n            pattern = r'__widx-(\\d+)_pidx-(\\d+)-(\\d+)__iter_(\\d+)__pyramid_level_(\\d+)'\n            match = re.search(pattern, debug_logs_prefix)\n            assert match\n            t_start = int(match.group(1))\n            pidx_start = int(match.group(2))\n            pidx_end = int(match.group(3))\n            iteration = int(match.group(4))\n            pyramid_level = int(match.group(5))\n\n            # # Log fmaps as images for the pipeline figure\n            # import os\n            # from PIL import Image\n            # png_outdir = os.path.join(debug_logs_path, \"feature_maps_pngs_2\")\n            # os.makedirs(png_outdir, exist_ok=True)\n            # if pyramid_level == 0 and iteration == 0:\n            #     for b in range(self.B):\n            #         t = t_start + b\n            #         for v in range(8):\n            #             fvec_rgb_uint8 = fvec_colors[b].reshape(8, 96, 128, 3)[v].astype(np.uint8)\n            #             fname = f\"fmap__view{v:02d}__frame{t:05d}.png\"\n            #             fpath = os.path.join(png_outdir, fname)\n            #             Image.fromarray(fvec_rgb_uint8).save(fpath)\n\n            # Log feature map points\n            # if log_feature_maps and pyramid_level in [0, 1, 2, 3] and iteration == 0:\n            if log_feature_maps and pyramid_level in [0] and iteration == 0:\n                if t_start > 0:\n                    bs = range(self.B)\n                else:\n                    bs = range(self.B // 2, self.B)\n                for b in bs:\n                    rr.set_time_seconds(\"frame\", (t_start + b) / rerun_fps)\n                    rr.log(f\"fmaps/pyramid-{pyramid_level}\", rr.Points3D(\n                        xyz[b][mask[b]],\n                        colors=fvec_colors[b][mask[b]],\n                        radii=0.042,\n                        # radii=-2.53,\n                    ))\n\n            # Log neighbors\n            if log_knn_neighbors and pyramid_level in [0, 1, 2, 3] and iteration in [0]:\n                for b in range(self.B):\n                    rr.set_time_seconds(\"frame\", (t_start + b) / rerun_fps)\n                    for n in range(min(neighbor_xyz.shape[1], knn_neighbors_to_log)):  # Iterate over queries\n                        prefix = f\"knn/track-{n:03d}/iter-{iteration}/pyramid-{pyramid_level}\"\n                        rr.log(f\"{prefix}/queries\", rr.Points3D(\n                            coords_world_xyz[b, n].cpu().numpy(),\n                            colors=targets_colors[b, n],\n                            radii=0.072,\n                            # radii=-9.0,\n                        ))\n\n         
               rr.log(f\"{prefix}/neighbors\", rr.Points3D(\n                            neighbor_xyz[b, n].cpu().numpy(),\n                            colors=neighbor_fvec_colors[b, n],\n                            radii=0.054,\n                            # radii=-5.0,\n                        ))\n\n                        if knn_line_coloring == \"correlation\":\n                            # Compute correlation strength for line coloring\n                            corr_strength = corrs[b, n,].squeeze(-1).cpu().numpy()\n                            corr_strength_normalized = (corr_strength / corr_strength.max()) * 1.0 + 0.0\n                            line_colors = (corr_strength_normalized[:, None] * np.array([9, 208, 239])).astype(int)\n                            line_colors = np.hstack([line_colors, np.full((line_colors.shape[0], 1), 204)])  # RGBA 80%\n\n                        elif knn_line_coloring == \"static\":\n                            # Make the lines sun flower yellow (241, 196, 15)\n                            line_colors = np.array([241, 196, 15])[None].repeat(self.k, 0).astype(int)\n\n                        # Draw edges between query and its neighbors\n                        strips = np.stack([\n                            coords_world_xyz[b, n].cpu().numpy()[None].repeat(neighbor_xyz.shape[2], axis=0),\n                            neighbor_xyz[b, n].cpu().numpy(),\n                        ], axis=-2)\n                        rr.log(f\"{prefix}/arrows\", rr.Arrows3D(\n                            origins=strips[:, 0],\n                            vectors=strips[:, 1] - strips[:, 0],\n                            colors=line_colors,\n                            radii=0.016,\n                            # radii=-1.2,\n                        ))\n            logging.info(f\"rerun for {debug_logs_prefix} done\")\n        return output\n"
  },
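  {
    "path": "examples/sliding_window_schedule_sketch.py",
    "content": "\"\"\"Hypothetical illustration, not part of the original codebase: a minimal sketch of the\nsliding-window schedule used by the forward pass above. Queries are sorted by the frame in\nwhich they first appear, windows of length S advance by S // 2 frames, and each window only\nprocesses the (sorted) queries that have appeared before the window ends.\"\"\"\n\nimport torch\n\n\ndef window_schedule(query_t: torch.Tensor, num_frames: int, S: int = 8):\n    # query_t: (N,) frame index at which each query first appears\n    sort_inds = torch.argsort(query_t)\n    query_t_sorted = query_t[sort_inds]\n\n    windows = []\n    w_start = int(query_t_sorted.min())\n    while w_start < num_frames - S // 2:\n        # number of sorted queries whose first frame falls before the end of this window\n        p_idx_end = int((query_t_sorted < w_start + S).sum())\n        windows.append((w_start, min(w_start + S, num_frames), p_idx_end))\n        w_start += S // 2\n    return windows, sort_inds\n\n\nif __name__ == \"__main__\":\n    # (window_start, window_end, num_active_sorted_queries) for a 16-frame clip\n    for window in window_schedule(torch.tensor([5, 0, 12, 3]), num_frames=16)[0]:\n        print(window)\n"
  },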
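  {
    "path": "examples/knn_correlation_sketch.py",
    "content": "\"\"\"Hypothetical illustration, not part of the original codebase: a minimal sketch of the\nkNN-based grouped feature correlation computed by PointcloudCorrBlock.corr_sample above,\nusing plain torch.cdist + topk in place of the pointops/knn helper.\"\"\"\n\nimport torch\n\n\ndef knn_grouped_correlation(xyz, fvec, query_xyz, query_feat, k=16, groups=4):\n    # xyz: (N, 3) point positions, fvec: (N, C) per-point features\n    # query_xyz: (M, 3) query positions, query_feat: (M, C) query features\n    M, C = query_feat.shape\n    assert C % groups == 0\n\n    # k nearest neighbours of each query in Euclidean space\n    dists = torch.cdist(query_xyz, xyz)            # (M, N)\n    _, idx = dists.topk(k, dim=-1, largest=False)  # (M, k)\n    neighbor_xyz, neighbor_feat = xyz[idx], fvec[idx]\n\n    # grouped dot-product correlation, scaled as in PointcloudCorrBlock\n    q = query_feat.view(M, groups, C // groups)\n    n = neighbor_feat.view(M, k, groups, C // groups)\n    corr = torch.einsum(\"mgc,mkgc->mkg\", q, n) / ((C / groups) ** 0.5)\n\n    # offsets from query to neighbours (the corr_add_neighbor_offset channels)\n    offsets = neighbor_xyz - query_xyz[:, None, :]  # (M, k, 3)\n    return torch.cat([corr, offsets], dim=-1)       # (M, k, groups + 3)\n\n\nif __name__ == \"__main__\":\n    out = knn_grouped_correlation(torch.randn(4096, 3), torch.randn(4096, 128),\n                                  torch.randn(8, 3), torch.randn(8, 128))\n    print(out.shape)  # torch.Size([8, 16, 7])\n"
  },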
  {
    "path": "mvtracker/models/core/ptv3/__init__.py",
    "content": ""
  },
  {
    "path": "mvtracker/models/core/ptv3/model.py",
    "content": "\"\"\"\r\nPoint Transformer - V3 Mode1\r\nPointcept detached version\r\n\r\nAuthor: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)\r\nPlease cite our work if the code is helpful to you.\r\n\"\"\"\r\n\r\nimport sys\r\nfrom collections import OrderedDict\r\nfrom functools import partial\r\n\r\nimport math\r\nimport spconv.pytorch as spconv\r\nimport torch\r\nimport torch.nn as nn\r\nimport torch_scatter\r\nfrom addict import Dict\r\nfrom timm.models.layers import DropPath\r\n\r\ntry:\r\n    import flash_attn\r\nexcept ImportError:\r\n    flash_attn = None\r\n\r\nfrom .serialization import encode\r\n\r\n\r\n@torch.inference_mode()\r\ndef offset2bincount(offset):\r\n    return torch.diff(\r\n        offset, prepend=torch.tensor([0], device=offset.device, dtype=torch.long)\r\n    )\r\n\r\n\r\n@torch.inference_mode()\r\ndef offset2batch(offset):\r\n    bincount = offset2bincount(offset)\r\n    return torch.arange(\r\n        len(bincount), device=offset.device, dtype=torch.long\r\n    ).repeat_interleave(bincount)\r\n\r\n\r\n@torch.inference_mode()\r\ndef batch2offset(batch):\r\n    return torch.cumsum(batch.bincount(), dim=0).long()\r\n\r\n\r\nclass Point(Dict):\r\n    \"\"\"\r\n    Point Structure of Pointcept\r\n\r\n    A Point (point cloud) in Pointcept is a dictionary that contains various properties of\r\n    a batched point cloud. The property with the following names have a specific definition\r\n    as follows:\r\n\r\n    - \"coord\": original coordinate of point cloud;\r\n    - \"grid_coord\": grid coordinate for specific grid size (related to GridSampling);\r\n    Point also support the following optional attributes:\r\n    - \"offset\": if not exist, initialized as batch size is 1;\r\n    - \"batch\": if not exist, initialized as batch size is 1;\r\n    - \"feat\": feature of point cloud, default input of model;\r\n    - \"grid_size\": Grid size of point cloud (related to GridSampling);\r\n    (related to Serialization)\r\n    - \"serialized_depth\": depth of serialization, 2 ** depth * grid_size describe the maximum of point cloud range;\r\n    - \"serialized_code\": a list of serialization codes;\r\n    - \"serialized_order\": a list of serialization order determined by code;\r\n    - \"serialized_inverse\": a list of inverse mapping determined by code;\r\n    (related to Sparsify: SpConv)\r\n    - \"sparse_shape\": Sparse shape for Sparse Conv Tensor;\r\n    - \"sparse_conv_feat\": SparseConvTensor init with information provide by Point;\r\n    \"\"\"\r\n\r\n    def __init__(self, *args, **kwargs):\r\n        super().__init__(*args, **kwargs)\r\n        # If one of \"offset\" or \"batch\" do not exist, generate by the existing one\r\n        if \"batch\" not in self.keys() and \"offset\" in self.keys():\r\n            self[\"batch\"] = offset2batch(self.offset)\r\n        elif \"offset\" not in self.keys() and \"batch\" in self.keys():\r\n            self[\"offset\"] = batch2offset(self.batch)\r\n\r\n    def serialization(self, order=\"z\", depth=None, shuffle_orders=False):\r\n        \"\"\"\r\n        Point Cloud Serialization\r\n\r\n        relay on [\"grid_coord\" or \"coord\" + \"grid_size\", \"batch\", \"feat\"]\r\n        \"\"\"\r\n        assert \"batch\" in self.keys()\r\n        if \"grid_coord\" not in self.keys():\r\n            # if you don't want to operate GridSampling in data augmentation,\r\n            # please add the following augmentation into your pipline:\r\n            # dict(type=\"Copy\", keys_dict={\"grid_size\": 0.01}),\r\n            # (adjust 
`grid_size` to what your want)\r\n            assert {\"grid_size\", \"coord\"}.issubset(self.keys())\r\n            self[\"grid_coord\"] = torch.div(\r\n                self.coord - self.coord.min(0)[0], self.grid_size, rounding_mode=\"trunc\"\r\n            ).int()\r\n\r\n        if depth is None:\r\n            # Adaptive measure the depth of serialization cube (length = 2 ^ depth)\r\n            depth = int(self.grid_coord.max()).bit_length()\r\n        self[\"serialized_depth\"] = depth\r\n        # Maximum bit length for serialization code is 63 (int64)\r\n        assert depth * 3 + len(self.offset).bit_length() <= 63\r\n        # Here we follow OCNN and set the depth limitation to 16 (48bit) for the point position.\r\n        # Although depth is limited to less than 16, we can encode a 655.36^3 (2^16 * 0.01) meter^3\r\n        # cube with a grid size of 0.01 meter. We consider it is enough for the current stage.\r\n        # We can unlock the limitation by optimizing the z-order encoding function if necessary.\r\n        assert depth <= 16\r\n\r\n        # The serialization codes are arranged as following structures:\r\n        # [Order1 ([n]),\r\n        #  Order2 ([n]),\r\n        #   ...\r\n        #  OrderN ([n])] (k, n)\r\n        code = [\r\n            encode(self.grid_coord, self.batch, depth, order=order_) for order_ in order\r\n        ]\r\n        code = torch.stack(code)\r\n        order = torch.argsort(code)\r\n        inverse = torch.zeros_like(order).scatter_(\r\n            dim=1,\r\n            index=order,\r\n            src=torch.arange(0, code.shape[1], device=order.device).repeat(\r\n                code.shape[0], 1\r\n            ),\r\n        )\r\n\r\n        if shuffle_orders:\r\n            perm = torch.randperm(code.shape[0])\r\n            code = code[perm]\r\n            order = order[perm]\r\n            inverse = inverse[perm]\r\n\r\n        self[\"serialized_code\"] = code\r\n        self[\"serialized_order\"] = order\r\n        self[\"serialized_inverse\"] = inverse\r\n\r\n    def sparsify(self, pad=96):\r\n        \"\"\"\r\n        Point Cloud Sparsification\r\n\r\n        Point cloud is sparse, here we use \"sparsify\" to specifically refer to\r\n        preparing \"spconv.SparseConvTensor\" for SpConv.\r\n\r\n        relay on [\"grid_coord\" or \"coord\" + \"grid_size\", \"batch\", \"feat\"]\r\n\r\n        pad: padding sparse for sparse shape.\r\n        \"\"\"\r\n        assert {\"feat\", \"batch\"}.issubset(self.keys())\r\n        if \"grid_coord\" not in self.keys():\r\n            # if you don't want to operate GridSampling in data augmentation,\r\n            # please add the following augmentation into your pipline:\r\n            # dict(type=\"Copy\", keys_dict={\"grid_size\": 0.01}),\r\n            # (adjust `grid_size` to what your want)\r\n            assert {\"grid_size\", \"coord\"}.issubset(self.keys())\r\n            self[\"grid_coord\"] = torch.div(\r\n                self.coord - self.coord.min(0)[0], self.grid_size, rounding_mode=\"trunc\"\r\n            ).int()\r\n        if \"sparse_shape\" in self.keys():\r\n            sparse_shape = self.sparse_shape\r\n        else:\r\n            sparse_shape = torch.add(\r\n                torch.max(self.grid_coord, dim=0).values, pad\r\n            ).tolist()\r\n        sparse_conv_feat = spconv.SparseConvTensor(\r\n            features=self.feat,\r\n            indices=torch.cat(\r\n                [self.batch.unsqueeze(-1).int(), self.grid_coord.int()], dim=1\r\n            
).contiguous(),\r\n            spatial_shape=sparse_shape,\r\n            batch_size=self.batch[-1].tolist() + 1,\r\n        )\r\n        self[\"sparse_shape\"] = sparse_shape\r\n        self[\"sparse_conv_feat\"] = sparse_conv_feat\r\n\r\n\r\nclass PointModule(nn.Module):\r\n    r\"\"\"PointModule\r\n    placeholder, all module subclass from this will take Point in PointSequential.\r\n    \"\"\"\r\n\r\n    def __init__(self, *args, **kwargs):\r\n        super().__init__(*args, **kwargs)\r\n\r\n\r\nclass PointSequential(PointModule):\r\n    r\"\"\"A sequential container.\r\n    Modules will be added to it in the order they are passed in the constructor.\r\n    Alternatively, an ordered dict of modules can also be passed in.\r\n    \"\"\"\r\n\r\n    def __init__(self, *args, **kwargs):\r\n        super().__init__()\r\n        if len(args) == 1 and isinstance(args[0], OrderedDict):\r\n            for key, module in args[0].items():\r\n                self.add_module(key, module)\r\n        else:\r\n            for idx, module in enumerate(args):\r\n                self.add_module(str(idx), module)\r\n        for name, module in kwargs.items():\r\n            if sys.version_info < (3, 6):\r\n                raise ValueError(\"kwargs only supported in py36+\")\r\n            if name in self._modules:\r\n                raise ValueError(\"name exists.\")\r\n            self.add_module(name, module)\r\n\r\n    def __getitem__(self, idx):\r\n        if not (-len(self) <= idx < len(self)):\r\n            raise IndexError(\"index {} is out of range\".format(idx))\r\n        if idx < 0:\r\n            idx += len(self)\r\n        it = iter(self._modules.values())\r\n        for i in range(idx):\r\n            next(it)\r\n        return next(it)\r\n\r\n    def __len__(self):\r\n        return len(self._modules)\r\n\r\n    def add(self, module, name=None):\r\n        if name is None:\r\n            name = str(len(self._modules))\r\n            if name in self._modules:\r\n                raise KeyError(\"name exists\")\r\n        self.add_module(name, module)\r\n\r\n    def forward(self, input):\r\n        for k, module in self._modules.items():\r\n            # Point module\r\n            if isinstance(module, PointModule):\r\n                input = module(input)\r\n            # Spconv module\r\n            elif spconv.modules.is_spconv_module(module):\r\n                if isinstance(input, Point):\r\n                    input.sparse_conv_feat = module(input.sparse_conv_feat)\r\n                    input.feat = input.sparse_conv_feat.features\r\n                else:\r\n                    input = module(input)\r\n            # PyTorch module\r\n            else:\r\n                if isinstance(input, Point):\r\n                    input.feat = module(input.feat)\r\n                    if \"sparse_conv_feat\" in input.keys():\r\n                        input.sparse_conv_feat = input.sparse_conv_feat.replace_feature(\r\n                            input.feat\r\n                        )\r\n                elif isinstance(input, spconv.SparseConvTensor):\r\n                    if input.indices.shape[0] != 0:\r\n                        input = input.replace_feature(module(input.features))\r\n                else:\r\n                    input = module(input)\r\n        return input\r\n\r\n\r\nclass PDNorm(PointModule):\r\n    def __init__(\r\n            self,\r\n            num_features,\r\n            norm_layer,\r\n            context_channels=256,\r\n            conditions=(\"ScanNet\", \"S3DIS\", 
\"Structured3D\"),\r\n            decouple=True,\r\n            adaptive=False,\r\n    ):\r\n        super().__init__()\r\n        self.conditions = conditions\r\n        self.decouple = decouple\r\n        self.adaptive = adaptive\r\n        if self.decouple:\r\n            self.norm = nn.ModuleList([norm_layer(num_features) for _ in conditions])\r\n        else:\r\n            self.norm = norm_layer\r\n        if self.adaptive:\r\n            self.modulation = nn.Sequential(\r\n                nn.SiLU(), nn.Linear(context_channels, 2 * num_features, bias=True)\r\n            )\r\n\r\n    def forward(self, point):\r\n        assert {\"feat\", \"condition\"}.issubset(point.keys())\r\n        if isinstance(point.condition, str):\r\n            condition = point.condition\r\n        else:\r\n            condition = point.condition[0]\r\n        if self.decouple:\r\n            assert condition in self.conditions\r\n            norm = self.norm[self.conditions.index(condition)]\r\n        else:\r\n            norm = self.norm\r\n        point.feat = norm(point.feat)\r\n        if self.adaptive:\r\n            assert \"context\" in point.keys()\r\n            shift, scale = self.modulation(point.context).chunk(2, dim=1)\r\n            point.feat = point.feat * (1.0 + scale) + shift\r\n        return point\r\n\r\n\r\nclass RPE(torch.nn.Module):\r\n    def __init__(self, patch_size, num_heads):\r\n        super().__init__()\r\n        self.patch_size = patch_size\r\n        self.num_heads = num_heads\r\n        self.pos_bnd = int((4 * patch_size) ** (1 / 3) * 2)\r\n        self.rpe_num = 2 * self.pos_bnd + 1\r\n        self.rpe_table = torch.nn.Parameter(torch.zeros(3 * self.rpe_num, num_heads))\r\n        torch.nn.init.trunc_normal_(self.rpe_table, std=0.02)\r\n\r\n    def forward(self, coord):\r\n        idx = (\r\n                coord.clamp(-self.pos_bnd, self.pos_bnd)  # clamp into bnd\r\n                + self.pos_bnd  # relative position to positive index\r\n                + torch.arange(3, device=coord.device) * self.rpe_num  # x, y, z stride\r\n        )\r\n        out = self.rpe_table.index_select(0, idx.reshape(-1))\r\n        out = out.view(idx.shape + (-1,)).sum(3)\r\n        out = out.permute(0, 3, 1, 2)  # (N, K, K, H) -> (N, H, K, K)\r\n        return out\r\n\r\n\r\nclass SerializedAttention(PointModule):\r\n    def __init__(\r\n            self,\r\n            channels,\r\n            num_heads,\r\n            patch_size,\r\n            qkv_bias=True,\r\n            qk_scale=None,\r\n            attn_drop=0.0,\r\n            proj_drop=0.0,\r\n            order_index=0,\r\n            enable_rpe=False,\r\n            enable_flash=True,\r\n            upcast_attention=True,\r\n            upcast_softmax=True,\r\n    ):\r\n        super().__init__()\r\n        assert channels % num_heads == 0\r\n        self.channels = channels\r\n        self.num_heads = num_heads\r\n        self.scale = qk_scale or (channels // num_heads) ** -0.5\r\n        self.order_index = order_index\r\n        self.upcast_attention = upcast_attention\r\n        self.upcast_softmax = upcast_softmax\r\n        self.enable_rpe = enable_rpe\r\n        self.enable_flash = enable_flash\r\n        if enable_flash:\r\n            assert (\r\n                    enable_rpe is False\r\n            ), \"Set enable_rpe to False when enable Flash Attention\"\r\n            assert (\r\n                    upcast_attention is False\r\n            ), \"Set upcast_attention to False when enable Flash Attention\"\r\n         
   assert (\r\n                    upcast_softmax is False\r\n            ), \"Set upcast_softmax to False when enable Flash Attention\"\r\n            assert flash_attn is not None, \"Make sure flash_attn is installed.\"\r\n            self.patch_size = patch_size\r\n            self.attn_drop = attn_drop\r\n        else:\r\n            # when disable flash attention, we still don't want to use mask\r\n            # consequently, patch size will auto set to the\r\n            # min number of patch_size_max and number of points\r\n            self.patch_size_max = patch_size\r\n            self.patch_size = 0\r\n            self.attn_drop = torch.nn.Dropout(attn_drop)\r\n\r\n        self.qkv = torch.nn.Linear(channels, channels * 3, bias=qkv_bias)\r\n        self.proj = torch.nn.Linear(channels, channels)\r\n        self.proj_drop = torch.nn.Dropout(proj_drop)\r\n        self.softmax = torch.nn.Softmax(dim=-1)\r\n        self.rpe = RPE(patch_size, num_heads) if self.enable_rpe else None\r\n\r\n    @torch.no_grad()\r\n    def get_rel_pos(self, point, order):\r\n        K = self.patch_size\r\n        rel_pos_key = f\"rel_pos_{self.order_index}\"\r\n        if rel_pos_key not in point.keys():\r\n            grid_coord = point.grid_coord[order]\r\n            grid_coord = grid_coord.reshape(-1, K, 3)\r\n            point[rel_pos_key] = grid_coord.unsqueeze(2) - grid_coord.unsqueeze(1)\r\n        return point[rel_pos_key]\r\n\r\n    @torch.no_grad()\r\n    def get_padding_and_inverse(self, point):\r\n        pad_key = \"pad\"\r\n        unpad_key = \"unpad\"\r\n        cu_seqlens_key = \"cu_seqlens_key\"\r\n        if (\r\n                pad_key not in point.keys()\r\n                or unpad_key not in point.keys()\r\n                or cu_seqlens_key not in point.keys()\r\n        ):\r\n            offset = point.offset\r\n            bincount = offset2bincount(offset)\r\n            bincount_pad = (\r\n                    torch.div(\r\n                        bincount + self.patch_size - 1,\r\n                        self.patch_size,\r\n                        rounding_mode=\"trunc\",\r\n                    )\r\n                    * self.patch_size\r\n            )\r\n            # only pad point when num of points larger than patch_size\r\n            mask_pad = bincount > self.patch_size\r\n            bincount_pad = ~mask_pad * bincount + mask_pad * bincount_pad\r\n            _offset = nn.functional.pad(offset, (1, 0))\r\n            _offset_pad = nn.functional.pad(torch.cumsum(bincount_pad, dim=0), (1, 0))\r\n            pad = torch.arange(_offset_pad[-1], device=offset.device)\r\n            unpad = torch.arange(_offset[-1], device=offset.device)\r\n            cu_seqlens = []\r\n            for i in range(len(offset)):\r\n                unpad[_offset[i]: _offset[i + 1]] += _offset_pad[i] - _offset[i]\r\n                if bincount[i] != bincount_pad[i]:\r\n                    pad[\r\n                    _offset_pad[i + 1]\r\n                    - self.patch_size\r\n                    + (bincount[i] % self.patch_size): _offset_pad[i + 1]\r\n                    ] = pad[\r\n                        _offset_pad[i + 1]\r\n                        - 2 * self.patch_size\r\n                        + (bincount[i] % self.patch_size): _offset_pad[i + 1]\r\n                                                           - self.patch_size\r\n                        ]\r\n                pad[_offset_pad[i]: _offset_pad[i + 1]] -= _offset_pad[i] - _offset[i]\r\n                cu_seqlens.append(\r\n     
               torch.arange(\r\n                        _offset_pad[i],\r\n                        _offset_pad[i + 1],\r\n                        step=self.patch_size,\r\n                        dtype=torch.int32,\r\n                        device=offset.device,\r\n                    )\r\n                )\r\n            point[pad_key] = pad\r\n            point[unpad_key] = unpad\r\n            point[cu_seqlens_key] = nn.functional.pad(\r\n                torch.concat(cu_seqlens), (0, 1), value=_offset_pad[-1]\r\n            )\r\n        return point[pad_key], point[unpad_key], point[cu_seqlens_key]\r\n\r\n    def forward(self, point):\r\n        if not self.enable_flash:\r\n            self.patch_size = min(\r\n                offset2bincount(point.offset).min().tolist(), self.patch_size_max\r\n            )\r\n\r\n        H = self.num_heads\r\n        K = self.patch_size\r\n        C = self.channels\r\n\r\n        pad, unpad, cu_seqlens = self.get_padding_and_inverse(point)\r\n\r\n        order = point.serialized_order[self.order_index][pad]\r\n        inverse = unpad[point.serialized_inverse[self.order_index]]\r\n\r\n        # padding and reshape feat and batch for serialized point patch\r\n        qkv = self.qkv(point.feat)[order]\r\n\r\n        if not self.enable_flash:\r\n            # encode and reshape qkv: (N', K, 3, H, C') => (3, N', H, K, C')\r\n            q, k, v = (\r\n                qkv.reshape(-1, K, 3, H, C // H).permute(2, 0, 3, 1, 4).unbind(dim=0)\r\n            )\r\n            # attn\r\n            if self.upcast_attention:\r\n                q = q.float()\r\n                k = k.float()\r\n            attn = (q * self.scale) @ k.transpose(-2, -1)  # (N', H, K, K)\r\n            if self.enable_rpe:\r\n                attn = attn + self.rpe(self.get_rel_pos(point, order))\r\n            if self.upcast_softmax:\r\n                attn = attn.float()\r\n            attn = self.softmax(attn)\r\n            attn = self.attn_drop(attn).to(qkv.dtype)\r\n            feat = (attn @ v).transpose(1, 2).reshape(-1, C)\r\n        else:\r\n            feat = flash_attn.flash_attn_varlen_qkvpacked_func(\r\n                qkv.half().reshape(-1, 3, H, C // H),\r\n                cu_seqlens,\r\n                max_seqlen=self.patch_size,\r\n                dropout_p=self.attn_drop if self.training else 0,\r\n                softmax_scale=self.scale,\r\n            ).reshape(-1, C)\r\n            feat = feat.to(qkv.dtype)\r\n        feat = feat[inverse]\r\n\r\n        # ffn\r\n        feat = self.proj(feat)\r\n        feat = self.proj_drop(feat)\r\n        point.feat = feat\r\n        return point\r\n\r\n\r\nclass MLP(nn.Module):\r\n    def __init__(\r\n            self,\r\n            in_channels,\r\n            hidden_channels=None,\r\n            out_channels=None,\r\n            act_layer=nn.GELU,\r\n            drop=0.0,\r\n    ):\r\n        super().__init__()\r\n        out_channels = out_channels or in_channels\r\n        hidden_channels = hidden_channels or in_channels\r\n        self.fc1 = nn.Linear(in_channels, hidden_channels)\r\n        self.act = act_layer()\r\n        self.fc2 = nn.Linear(hidden_channels, out_channels)\r\n        self.drop = nn.Dropout(drop)\r\n\r\n    def forward(self, x):\r\n        x = self.fc1(x)\r\n        x = self.act(x)\r\n        x = self.drop(x)\r\n        x = self.fc2(x)\r\n        x = self.drop(x)\r\n        return x\r\n\r\n\r\nclass Block(PointModule):\r\n    def __init__(\r\n            self,\r\n            channels,\r\n            
num_heads,\r\n            patch_size=48,\r\n            mlp_ratio=4.0,\r\n            qkv_bias=True,\r\n            qk_scale=None,\r\n            attn_drop=0.0,\r\n            proj_drop=0.0,\r\n            drop_path=0.0,\r\n            norm_layer=nn.LayerNorm,\r\n            act_layer=nn.GELU,\r\n            pre_norm=True,\r\n            order_index=0,\r\n            cpe_indice_key=None,\r\n            enable_rpe=False,\r\n            enable_flash=True,\r\n            upcast_attention=True,\r\n            upcast_softmax=True,\r\n    ):\r\n        super().__init__()\r\n        self.channels = channels\r\n        self.pre_norm = pre_norm\r\n\r\n        self.cpe = PointSequential(\r\n            spconv.SubMConv3d(\r\n                channels,\r\n                channels,\r\n                kernel_size=3,\r\n                bias=True,\r\n                indice_key=cpe_indice_key,\r\n            ),\r\n            nn.Linear(channels, channels),\r\n            norm_layer(channels),\r\n        )\r\n\r\n        self.norm1 = PointSequential(norm_layer(channels))\r\n        self.attn = SerializedAttention(\r\n            channels=channels,\r\n            patch_size=patch_size,\r\n            num_heads=num_heads,\r\n            qkv_bias=qkv_bias,\r\n            qk_scale=qk_scale,\r\n            attn_drop=attn_drop,\r\n            proj_drop=proj_drop,\r\n            order_index=order_index,\r\n            enable_rpe=enable_rpe,\r\n            enable_flash=enable_flash,\r\n            upcast_attention=upcast_attention,\r\n            upcast_softmax=upcast_softmax,\r\n        )\r\n        self.norm2 = PointSequential(norm_layer(channels))\r\n        self.mlp = PointSequential(\r\n            MLP(\r\n                in_channels=channels,\r\n                hidden_channels=int(channels * mlp_ratio),\r\n                out_channels=channels,\r\n                act_layer=act_layer,\r\n                drop=proj_drop,\r\n            )\r\n        )\r\n        self.drop_path = PointSequential(\r\n            DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\r\n        )\r\n\r\n    def forward(self, point: Point):\r\n        shortcut = point.feat\r\n        point = self.cpe(point)\r\n        point.feat = shortcut + point.feat\r\n        shortcut = point.feat\r\n        if self.pre_norm:\r\n            point = self.norm1(point)\r\n        point = self.drop_path(self.attn(point))\r\n        point.feat = shortcut + point.feat\r\n        if not self.pre_norm:\r\n            point = self.norm1(point)\r\n\r\n        shortcut = point.feat\r\n        if self.pre_norm:\r\n            point = self.norm2(point)\r\n        point = self.drop_path(self.mlp(point))\r\n        point.feat = shortcut + point.feat\r\n        if not self.pre_norm:\r\n            point = self.norm2(point)\r\n        point.sparse_conv_feat = point.sparse_conv_feat.replace_feature(point.feat)\r\n        return point\r\n\r\n\r\nclass SerializedPooling(PointModule):\r\n    def __init__(\r\n            self,\r\n            in_channels,\r\n            out_channels,\r\n            stride=2,\r\n            norm_layer=None,\r\n            act_layer=None,\r\n            reduce=\"max\",\r\n            shuffle_orders=True,\r\n            traceable=True,  # record parent and cluster\r\n    ):\r\n        super().__init__()\r\n        self.in_channels = in_channels\r\n        self.out_channels = out_channels\r\n\r\n        assert stride == 2 ** (math.ceil(stride) - 1).bit_length()  # 2, 4, 8\r\n        # TODO: add support to grid pool (any stride)\r\n       
 self.stride = stride\r\n        assert reduce in [\"sum\", \"mean\", \"min\", \"max\"]\r\n        self.reduce = reduce\r\n        self.shuffle_orders = shuffle_orders\r\n        self.traceable = traceable\r\n\r\n        self.proj = nn.Linear(in_channels, out_channels)\r\n        if norm_layer is not None:\r\n            self.norm = PointSequential(norm_layer(out_channels))\r\n        if act_layer is not None:\r\n            self.act = PointSequential(act_layer())\r\n\r\n    def forward(self, point: Point):\r\n        pooling_depth = (math.ceil(self.stride) - 1).bit_length()\r\n        if pooling_depth > point.serialized_depth:\r\n            pooling_depth = 0\r\n        assert {\r\n            \"serialized_code\",\r\n            \"serialized_order\",\r\n            \"serialized_inverse\",\r\n            \"serialized_depth\",\r\n        }.issubset(\r\n            point.keys()\r\n        ), \"Run point.serialization() point cloud before SerializedPooling\"\r\n\r\n        code = point.serialized_code >> pooling_depth * 3\r\n        code_, cluster, counts = torch.unique(\r\n            code[0],\r\n            sorted=True,\r\n            return_inverse=True,\r\n            return_counts=True,\r\n        )\r\n        # indices of point sorted by cluster, for torch_scatter.segment_csr\r\n        _, indices = torch.sort(cluster)\r\n        # index pointer for sorted point, for torch_scatter.segment_csr\r\n        idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)])\r\n        # head_indices of each cluster, for reduce attr e.g. code, batch\r\n        head_indices = indices[idx_ptr[:-1]]\r\n        # generate down code, order, inverse\r\n        code = code[:, head_indices]\r\n        order = torch.argsort(code)\r\n        inverse = torch.zeros_like(order).scatter_(\r\n            dim=1,\r\n            index=order,\r\n            src=torch.arange(0, code.shape[1], device=order.device).repeat(\r\n                code.shape[0], 1\r\n            ),\r\n        )\r\n\r\n        if self.shuffle_orders:\r\n            perm = torch.randperm(code.shape[0])\r\n            code = code[perm]\r\n            order = order[perm]\r\n            inverse = inverse[perm]\r\n\r\n        # collect information\r\n        point_dict = Dict(\r\n            feat=torch_scatter.segment_csr(\r\n                self.proj(point.feat)[indices], idx_ptr, reduce=self.reduce\r\n            ),\r\n            coord=torch_scatter.segment_csr(\r\n                point.coord[indices], idx_ptr, reduce=\"mean\"\r\n            ),\r\n            grid_coord=point.grid_coord[head_indices] >> pooling_depth,\r\n            serialized_code=code,\r\n            serialized_order=order,\r\n            serialized_inverse=inverse,\r\n            serialized_depth=point.serialized_depth - pooling_depth,\r\n            batch=point.batch[head_indices],\r\n        )\r\n\r\n        if \"condition\" in point.keys():\r\n            point_dict[\"condition\"] = point.condition\r\n        if \"context\" in point.keys():\r\n            point_dict[\"context\"] = point.context\r\n\r\n        if self.traceable:\r\n            point_dict[\"pooling_inverse\"] = cluster\r\n            point_dict[\"pooling_parent\"] = point\r\n        point = Point(point_dict)\r\n        if self.norm is not None:\r\n            point = self.norm(point)\r\n        if self.act is not None:\r\n            point = self.act(point)\r\n        point.sparsify()\r\n        return point\r\n\r\n\r\nclass SerializedUnpooling(PointModule):\r\n    def __init__(\r\n            
self,\r\n            in_channels,\r\n            skip_channels,\r\n            out_channels,\r\n            norm_layer=None,\r\n            act_layer=None,\r\n            traceable=False,  # record parent and cluster\r\n    ):\r\n        super().__init__()\r\n        self.proj = PointSequential(nn.Linear(in_channels, out_channels))\r\n        self.proj_skip = PointSequential(nn.Linear(skip_channels, out_channels))\r\n\r\n        if norm_layer is not None:\r\n            self.proj.add(norm_layer(out_channels))\r\n            self.proj_skip.add(norm_layer(out_channels))\r\n\r\n        if act_layer is not None:\r\n            self.proj.add(act_layer())\r\n            self.proj_skip.add(act_layer())\r\n\r\n        self.traceable = traceable\r\n\r\n    def forward(self, point):\r\n        assert \"pooling_parent\" in point.keys()\r\n        assert \"pooling_inverse\" in point.keys()\r\n        parent = point.pop(\"pooling_parent\")\r\n        inverse = point.pop(\"pooling_inverse\")\r\n        point = self.proj(point)\r\n        parent = self.proj_skip(parent)\r\n        parent.feat = parent.feat + point.feat[inverse]\r\n\r\n        if self.traceable:\r\n            parent[\"unpooling_parent\"] = point\r\n        return parent\r\n\r\n\r\nclass Embedding(PointModule):\r\n    def __init__(\r\n            self,\r\n            in_channels,\r\n            embed_channels,\r\n            norm_layer=None,\r\n            act_layer=None,\r\n    ):\r\n        super().__init__()\r\n        self.in_channels = in_channels\r\n        self.embed_channels = embed_channels\r\n\r\n        # TODO: check remove spconv\r\n        self.stem = PointSequential(\r\n            conv=spconv.SubMConv3d(\r\n                in_channels,\r\n                embed_channels,\r\n                kernel_size=5,\r\n                padding=1,\r\n                bias=False,\r\n                indice_key=\"stem\",\r\n            )\r\n        )\r\n        if norm_layer is not None:\r\n            self.stem.add(norm_layer(embed_channels), name=\"norm\")\r\n        if act_layer is not None:\r\n            self.stem.add(act_layer(), name=\"act\")\r\n\r\n    def forward(self, point: Point):\r\n        point = self.stem(point)\r\n        return point\r\n\r\n\r\nclass PointTransformerV3(PointModule):\r\n    def __init__(\r\n            self,\r\n            in_channels=6,\r\n            order=(\"z\", \"z-trans\", \"hilbert\", \"hilbert-trans\"),\r\n            stride=(2, 2, 2, 2),\r\n            enc_depths=(2, 2, 2, 6, 2),\r\n            enc_channels=(32, 64, 128, 256, 512),\r\n            enc_num_head=(2, 4, 8, 16, 32),\r\n            enc_patch_size=(1024, 1024, 1024, 1024, 1024),\r\n            dec_depths=(2, 2, 2, 2),\r\n            dec_channels=(64, 64, 128, 256),\r\n            dec_num_head=(4, 4, 8, 16),\r\n            dec_patch_size=(1024, 1024, 1024, 1024),\r\n            mlp_ratio=4,\r\n            qkv_bias=True,\r\n            qk_scale=None,\r\n            attn_drop=0.0,\r\n            proj_drop=0.0,\r\n            drop_path=0.3,\r\n            pre_norm=True,\r\n            shuffle_orders=True,\r\n            enable_rpe=False,\r\n            enable_flash=True,\r\n            upcast_attention=False,\r\n            upcast_softmax=False,\r\n            cls_mode=False,\r\n            pdnorm_bn=False,\r\n            pdnorm_ln=False,\r\n            pdnorm_decouple=True,\r\n            pdnorm_adaptive=False,\r\n            pdnorm_affine=True,\r\n            pdnorm_conditions=(\"ScanNet\", \"S3DIS\", \"Structured3D\"),\r\n    ):\r\n        
super().__init__()\r\n        self.num_stages = len(enc_depths)\r\n        self.order = [order] if isinstance(order, str) else order\r\n        self.cls_mode = cls_mode\r\n        self.shuffle_orders = shuffle_orders\r\n\r\n        assert self.num_stages == len(stride) + 1\r\n        assert self.num_stages == len(enc_depths)\r\n        assert self.num_stages == len(enc_channels)\r\n        assert self.num_stages == len(enc_num_head)\r\n        assert self.num_stages == len(enc_patch_size)\r\n        assert self.cls_mode or self.num_stages == len(dec_depths) + 1\r\n        assert self.cls_mode or self.num_stages == len(dec_channels) + 1\r\n        assert self.cls_mode or self.num_stages == len(dec_num_head) + 1\r\n        assert self.cls_mode or self.num_stages == len(dec_patch_size) + 1\r\n\r\n        # norm layers\r\n        if pdnorm_bn:\r\n            bn_layer = partial(\r\n                PDNorm,\r\n                norm_layer=partial(\r\n                    nn.BatchNorm1d, eps=1e-3, momentum=0.01, affine=pdnorm_affine\r\n                ),\r\n                conditions=pdnorm_conditions,\r\n                decouple=pdnorm_decouple,\r\n                adaptive=pdnorm_adaptive,\r\n            )\r\n        else:\r\n            bn_layer = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01)\r\n        if pdnorm_ln:\r\n            ln_layer = partial(\r\n                PDNorm,\r\n                norm_layer=partial(nn.LayerNorm, elementwise_affine=pdnorm_affine),\r\n                conditions=pdnorm_conditions,\r\n                decouple=pdnorm_decouple,\r\n                adaptive=pdnorm_adaptive,\r\n            )\r\n        else:\r\n            ln_layer = nn.LayerNorm\r\n        # activation layers\r\n        act_layer = nn.GELU\r\n\r\n        self.embedding = Embedding(\r\n            in_channels=in_channels,\r\n            embed_channels=enc_channels[0],\r\n            norm_layer=bn_layer,\r\n            act_layer=act_layer,\r\n        )\r\n\r\n        # encoder\r\n        enc_drop_path = [\r\n            x.item() for x in torch.linspace(0, drop_path, sum(enc_depths))\r\n        ]\r\n        self.enc = PointSequential()\r\n        for s in range(self.num_stages):\r\n            enc_drop_path_ = enc_drop_path[\r\n                             sum(enc_depths[:s]): sum(enc_depths[: s + 1])\r\n                             ]\r\n            enc = PointSequential()\r\n            if s > 0:\r\n                enc.add(\r\n                    SerializedPooling(\r\n                        in_channels=enc_channels[s - 1],\r\n                        out_channels=enc_channels[s],\r\n                        stride=stride[s - 1],\r\n                        norm_layer=bn_layer,\r\n                        act_layer=act_layer,\r\n                    ),\r\n                    name=\"down\",\r\n                )\r\n            for i in range(enc_depths[s]):\r\n                enc.add(\r\n                    Block(\r\n                        channels=enc_channels[s],\r\n                        num_heads=enc_num_head[s],\r\n                        patch_size=enc_patch_size[s],\r\n                        mlp_ratio=mlp_ratio,\r\n                        qkv_bias=qkv_bias,\r\n                        qk_scale=qk_scale,\r\n                        attn_drop=attn_drop,\r\n                        proj_drop=proj_drop,\r\n                        drop_path=enc_drop_path_[i],\r\n                        norm_layer=ln_layer,\r\n                        act_layer=act_layer,\r\n                        pre_norm=pre_norm,\r\n         
               order_index=i % len(self.order),\r\n                        cpe_indice_key=f\"stage{s}\",\r\n                        enable_rpe=enable_rpe,\r\n                        enable_flash=enable_flash,\r\n                        upcast_attention=upcast_attention,\r\n                        upcast_softmax=upcast_softmax,\r\n                    ),\r\n                    name=f\"block{i}\",\r\n                )\r\n            if len(enc) != 0:\r\n                self.enc.add(module=enc, name=f\"enc{s}\")\r\n\r\n        # decoder\r\n        if not self.cls_mode:\r\n            dec_drop_path = [\r\n                x.item() for x in torch.linspace(0, drop_path, sum(dec_depths))\r\n            ]\r\n            self.dec = PointSequential()\r\n            dec_channels = list(dec_channels) + [enc_channels[-1]]\r\n            for s in reversed(range(self.num_stages - 1)):\r\n                dec_drop_path_ = dec_drop_path[\r\n                                 sum(dec_depths[:s]): sum(dec_depths[: s + 1])\r\n                                 ]\r\n                dec_drop_path_.reverse()\r\n                dec = PointSequential()\r\n                dec.add(\r\n                    SerializedUnpooling(\r\n                        in_channels=dec_channels[s + 1],\r\n                        skip_channels=enc_channels[s],\r\n                        out_channels=dec_channels[s],\r\n                        norm_layer=bn_layer,\r\n                        act_layer=act_layer,\r\n                    ),\r\n                    name=\"up\",\r\n                )\r\n                for i in range(dec_depths[s]):\r\n                    dec.add(\r\n                        Block(\r\n                            channels=dec_channels[s],\r\n                            num_heads=dec_num_head[s],\r\n                            patch_size=dec_patch_size[s],\r\n                            mlp_ratio=mlp_ratio,\r\n                            qkv_bias=qkv_bias,\r\n                            qk_scale=qk_scale,\r\n                            attn_drop=attn_drop,\r\n                            proj_drop=proj_drop,\r\n                            drop_path=dec_drop_path_[i],\r\n                            norm_layer=ln_layer,\r\n                            act_layer=act_layer,\r\n                            pre_norm=pre_norm,\r\n                            order_index=i % len(self.order),\r\n                            cpe_indice_key=f\"stage{s}\",\r\n                            enable_rpe=enable_rpe,\r\n                            enable_flash=enable_flash,\r\n                            upcast_attention=upcast_attention,\r\n                            upcast_softmax=upcast_softmax,\r\n                        ),\r\n                        name=f\"block{i}\",\r\n                    )\r\n                self.dec.add(module=dec, name=f\"dec{s}\")\r\n\r\n    def forward(self, data_dict):\r\n        \"\"\"\r\n        A data_dict is a dictionary containing properties of a batched point cloud.\r\n        It should contain the following properties for PTv3:\r\n        1. \"feat\": feature of point cloud\r\n        2. \"grid_coord\": discrete coordinate after grid sampling (voxelization) or \"coord\" + \"grid_size\"\r\n        3. 
\"offset\" or \"batch\": https://github.com/Pointcept/Pointcept?tab=readme-ov-file#offset\r\n        \"\"\"\r\n        point = Point(data_dict)\r\n        point.serialization(order=self.order, shuffle_orders=self.shuffle_orders)\r\n        point.sparsify()\r\n\r\n        point = self.embedding(point)\r\n        point = self.enc(point)\r\n        if not self.cls_mode:\r\n            point = self.dec(point)\r\n        return point\r\n"
  },
  {
    "path": "mvtracker/models/core/ptv3/serialization/__init__.py",
    "content": "from .default import (\r\n    encode,\r\n    decode,\r\n    z_order_encode,\r\n    z_order_decode,\r\n    hilbert_encode,\r\n    hilbert_decode,\r\n)\r\n"
  },
  {
    "path": "mvtracker/models/core/ptv3/serialization/default.py",
    "content": "import torch\r\n\r\nfrom .hilbert import decode as hilbert_decode_\r\nfrom .hilbert import encode as hilbert_encode_\r\nfrom .z_order import key2xyz as z_order_decode_\r\nfrom .z_order import xyz2key as z_order_encode_\r\n\r\n\r\n@torch.inference_mode()\r\ndef encode(grid_coord, batch=None, depth=16, order=\"z\"):\r\n    assert order in {\"z\", \"z-trans\", \"hilbert\", \"hilbert-trans\"}\r\n    if order == \"z\":\r\n        code = z_order_encode(grid_coord, depth=depth)\r\n    elif order == \"z-trans\":\r\n        code = z_order_encode(grid_coord[:, [1, 0, 2]], depth=depth)\r\n    elif order == \"hilbert\":\r\n        code = hilbert_encode(grid_coord, depth=depth)\r\n    elif order == \"hilbert-trans\":\r\n        code = hilbert_encode(grid_coord[:, [1, 0, 2]], depth=depth)\r\n    else:\r\n        raise NotImplementedError\r\n    if batch is not None:\r\n        batch = batch.long()\r\n        code = batch << depth * 3 | code\r\n    return code\r\n\r\n\r\n@torch.inference_mode()\r\ndef decode(code, depth=16, order=\"z\"):\r\n    assert order in {\"z\", \"hilbert\"}\r\n    batch = code >> depth * 3\r\n    code = code & ((1 << depth * 3) - 1)\r\n    if order == \"z\":\r\n        grid_coord = z_order_decode(code, depth=depth)\r\n    elif order == \"hilbert\":\r\n        grid_coord = hilbert_decode(code, depth=depth)\r\n    else:\r\n        raise NotImplementedError\r\n    return grid_coord, batch\r\n\r\n\r\ndef z_order_encode(grid_coord: torch.Tensor, depth: int = 16):\r\n    x, y, z = grid_coord[:, 0].long(), grid_coord[:, 1].long(), grid_coord[:, 2].long()\r\n    # we block the support to batch, maintain batched code in Point class\r\n    code = z_order_encode_(x, y, z, b=None, depth=depth)\r\n    return code\r\n\r\n\r\ndef z_order_decode(code: torch.Tensor, depth):\r\n    x, y, z = z_order_decode_(code, depth=depth)\r\n    grid_coord = torch.stack([x, y, z], dim=-1)  # (N,  3)\r\n    return grid_coord\r\n\r\n\r\ndef hilbert_encode(grid_coord: torch.Tensor, depth: int = 16):\r\n    return hilbert_encode_(grid_coord, num_dims=3, num_bits=depth)\r\n\r\n\r\ndef hilbert_decode(code: torch.Tensor, depth: int = 16):\r\n    return hilbert_decode_(code, num_dims=3, num_bits=depth)\r\n"
  },
  {
    "path": "mvtracker/models/core/ptv3/serialization/hilbert.py",
    "content": "\"\"\"\r\nHilbert Order\r\nModified from https://github.com/PrincetonLIPS/numpy-hilbert-curve\r\n\r\nAuthor: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com), Kaixin Xu\r\nPlease cite our work if the code is helpful to you.\r\n\"\"\"\r\n\r\nimport torch\r\n\r\n\r\ndef right_shift(binary, k=1, axis=-1):\r\n    \"\"\"Right shift an array of binary values.\r\n\r\n    Parameters:\r\n    -----------\r\n     binary: An ndarray of binary values.\r\n\r\n     k: The number of bits to shift. Default 1.\r\n\r\n     axis: The axis along which to shift.  Default -1.\r\n\r\n    Returns:\r\n    --------\r\n     Returns an ndarray with zero prepended and the ends truncated, along\r\n     whatever axis was specified.\"\"\"\r\n\r\n    # If we're shifting the whole thing, just return zeros.\r\n    if binary.shape[axis] <= k:\r\n        return torch.zeros_like(binary)\r\n\r\n    # Determine the padding pattern.\r\n    # padding = [(0,0)] * len(binary.shape)\r\n    # padding[axis] = (k,0)\r\n\r\n    # Determine the slicing pattern to eliminate just the last one.\r\n    slicing = [slice(None)] * len(binary.shape)\r\n    slicing[axis] = slice(None, -k)\r\n    shifted = torch.nn.functional.pad(\r\n        binary[tuple(slicing)], (k, 0), mode=\"constant\", value=0\r\n    )\r\n\r\n    return shifted\r\n\r\n\r\ndef binary2gray(binary, axis=-1):\r\n    \"\"\"Convert an array of binary values into Gray codes.\r\n\r\n    This uses the classic X ^ (X >> 1) trick to compute the Gray code.\r\n\r\n    Parameters:\r\n    -----------\r\n     binary: An ndarray of binary values.\r\n\r\n     axis: The axis along which to compute the gray code. Default=-1.\r\n\r\n    Returns:\r\n    --------\r\n     Returns an ndarray of Gray codes.\r\n    \"\"\"\r\n    shifted = right_shift(binary, axis=axis)\r\n\r\n    # Do the X ^ (X >> 1) trick.\r\n    gray = torch.logical_xor(binary, shifted)\r\n\r\n    return gray\r\n\r\n\r\ndef gray2binary(gray, axis=-1):\r\n    \"\"\"Convert an array of Gray codes back into binary values.\r\n\r\n    Parameters:\r\n    -----------\r\n     gray: An ndarray of gray codes.\r\n\r\n     axis: The axis along which to perform Gray decoding. Default=-1.\r\n\r\n    Returns:\r\n    --------\r\n     Returns an ndarray of binary values.\r\n    \"\"\"\r\n\r\n    # Loop the log2(bits) number of times necessary, with shift and xor.\r\n    shift = 2 ** (torch.Tensor([gray.shape[axis]]).log2().ceil().int() - 1)\r\n    while shift > 0:\r\n        gray = torch.logical_xor(gray, right_shift(gray, shift))\r\n        shift = torch.div(shift, 2, rounding_mode=\"floor\")\r\n    return gray\r\n\r\n\r\ndef encode(locs, num_dims, num_bits):\r\n    \"\"\"Decode an array of locations in a hypercube into a Hilbert integer.\r\n\r\n    This is a vectorized-ish version of the Hilbert curve implementation by John\r\n    Skilling as described in:\r\n\r\n    Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference\r\n      Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.\r\n\r\n    Params:\r\n    -------\r\n     locs - An ndarray of locations in a hypercube of num_dims dimensions, in\r\n            which each dimension runs from 0 to 2**num_bits-1.  The shape can\r\n            be arbitrary, as long as the last dimension of the same has size\r\n            num_dims.\r\n\r\n     num_dims - The dimensionality of the hypercube. Integer.\r\n\r\n     num_bits - The number of bits for each dimension. 
Integer.\r\n\r\n    Returns:\r\n    --------\r\n     The output is an ndarray of uint64 integers with the same shape as the\r\n     input, excluding the last dimension, which needs to be num_dims.\r\n    \"\"\"\r\n\r\n    # Keep around the original shape for later.\r\n    orig_shape = locs.shape\r\n    bitpack_mask = 1 << torch.arange(0, 8).to(locs.device)\r\n    bitpack_mask_rev = bitpack_mask.flip(-1)\r\n\r\n    if orig_shape[-1] != num_dims:\r\n        raise ValueError(\r\n            \"\"\"\r\n      The shape of locs was surprising in that the last dimension was of size\r\n      %d, but num_dims=%d.  These need to be equal.\r\n      \"\"\"\r\n            % (orig_shape[-1], num_dims)\r\n        )\r\n\r\n    if num_dims * num_bits > 63:\r\n        raise ValueError(\r\n            \"\"\"\r\n      num_dims=%d and num_bits=%d for %d bits total, which can't be encoded\r\n      into a int64.  Are you sure you need that many points on your Hilbert\r\n      curve?\r\n      \"\"\"\r\n            % (num_dims, num_bits, num_dims * num_bits)\r\n        )\r\n\r\n    # Treat the location integers as 64-bit unsigned and then split them up into\r\n    # a sequence of uint8s.  Preserve the association by dimension.\r\n    locs_uint8 = locs.long().view(torch.uint8).reshape((-1, num_dims, 8)).flip(-1)\r\n\r\n    # Now turn these into bits and truncate to num_bits.\r\n    gray = (\r\n        locs_uint8.unsqueeze(-1)\r\n        .bitwise_and(bitpack_mask_rev)\r\n        .ne(0)\r\n        .byte()\r\n        .flatten(-2, -1)[..., -num_bits:]\r\n    )\r\n\r\n    # Run the decoding process the other way.\r\n    # Iterate forwards through the bits.\r\n    for bit in range(0, num_bits):\r\n        # Iterate forwards through the dimensions.\r\n        for dim in range(0, num_dims):\r\n            # Identify which ones have this bit active.\r\n            mask = gray[:, dim, bit]\r\n\r\n            # Where this bit is on, invert the 0 dimension for lower bits.\r\n            gray[:, 0, bit + 1:] = torch.logical_xor(\r\n                gray[:, 0, bit + 1:], mask[:, None]\r\n            )\r\n\r\n            # Where the bit is off, exchange the lower bits with the 0 dimension.\r\n            to_flip = torch.logical_and(\r\n                torch.logical_not(mask[:, None]).repeat(1, gray.shape[2] - bit - 1),\r\n                torch.logical_xor(gray[:, 0, bit + 1:], gray[:, dim, bit + 1:]),\r\n            )\r\n            gray[:, dim, bit + 1:] = torch.logical_xor(\r\n                gray[:, dim, bit + 1:], to_flip\r\n            )\r\n            gray[:, 0, bit + 1:] = torch.logical_xor(gray[:, 0, bit + 1:], to_flip)\r\n\r\n    # Now flatten out.\r\n    gray = gray.swapaxes(1, 2).reshape((-1, num_bits * num_dims))\r\n\r\n    # Convert Gray back to binary.\r\n    hh_bin = gray2binary(gray)\r\n\r\n    # Pad back out to 64 bits.\r\n    extra_dims = 64 - num_bits * num_dims\r\n    padded = torch.nn.functional.pad(hh_bin, (extra_dims, 0), \"constant\", 0)\r\n\r\n    # Convert binary values into uint8s.\r\n    hh_uint8 = (\r\n        (padded.flip(-1).reshape((-1, 8, 8)) * bitpack_mask)\r\n        .sum(2)\r\n        .squeeze()\r\n        .type(torch.uint8)\r\n    )\r\n\r\n    # Convert uint8s into uint64s.\r\n    hh_uint64 = hh_uint8.view(torch.int64).squeeze()\r\n\r\n    return hh_uint64\r\n\r\n\r\ndef decode(hilberts, num_dims, num_bits):\r\n    \"\"\"Decode an array of Hilbert integers into locations in a hypercube.\r\n\r\n    This is a vectorized-ish version of the Hilbert curve implementation by John\r\n    Skilling as 
described in:\r\n\r\n    Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference\r\n      Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.\r\n\r\n    Params:\r\n    -------\r\n     hilberts - An ndarray of Hilbert integers.  Must be an integer dtype and\r\n                cannot have fewer bits than num_dims * num_bits.\r\n\r\n     num_dims - The dimensionality of the hypercube. Integer.\r\n\r\n     num_bits - The number of bits for each dimension. Integer.\r\n\r\n    Returns:\r\n    --------\r\n     The output is an ndarray of unsigned integers with the same shape as hilberts\r\n     but with an additional dimension of size num_dims.\r\n    \"\"\"\r\n\r\n    if num_dims * num_bits > 64:\r\n        raise ValueError(\r\n            \"\"\"\r\n      num_dims=%d and num_bits=%d for %d bits total, which can't be encoded\r\n      into a uint64.  Are you sure you need that many points on your Hilbert\r\n      curve?\r\n      \"\"\"\r\n            % (num_dims, num_bits, num_dims * num_bits)\r\n        )\r\n\r\n    # Handle the case where we got handed a naked integer.\r\n    hilberts = torch.atleast_1d(hilberts)\r\n\r\n    # Keep around the shape for later.\r\n    orig_shape = hilberts.shape\r\n    bitpack_mask = 2 ** torch.arange(0, 8).to(hilberts.device)\r\n    bitpack_mask_rev = bitpack_mask.flip(-1)\r\n\r\n    # Treat each of the hilberts as a sequence of eight uint8s.\r\n    # This treats all of the inputs as uint64 and makes things uniform.\r\n    hh_uint8 = (\r\n        hilberts.ravel().type(torch.int64).view(torch.uint8).reshape((-1, 8)).flip(-1)\r\n    )\r\n\r\n    # Turn these lists of uints into lists of bits and then truncate to the size\r\n    # we actually need for using Skilling's procedure.\r\n    hh_bits = (\r\n        hh_uint8.unsqueeze(-1)\r\n        .bitwise_and(bitpack_mask_rev)\r\n        .ne(0)\r\n        .byte()\r\n        .flatten(-2, -1)[:, -num_dims * num_bits:]\r\n    )\r\n\r\n    # Take the sequence of bits and Gray-code it.\r\n    gray = binary2gray(hh_bits)\r\n\r\n    # There has got to be a better way to do this.\r\n    # I could index them differently, but the eventual packbits likes it this way.\r\n    gray = gray.reshape((-1, num_bits, num_dims)).swapaxes(1, 2)\r\n\r\n    # Iterate backwards through the bits.\r\n    for bit in range(num_bits - 1, -1, -1):\r\n        # Iterate backwards through the dimensions.\r\n        for dim in range(num_dims - 1, -1, -1):\r\n            # Identify which ones have this bit active.\r\n            mask = gray[:, dim, bit]\r\n\r\n            # Where this bit is on, invert the 0 dimension for lower bits.\r\n            gray[:, 0, bit + 1:] = torch.logical_xor(\r\n                gray[:, 0, bit + 1:], mask[:, None]\r\n            )\r\n\r\n            # Where the bit is off, exchange the lower bits with the 0 dimension.\r\n            to_flip = torch.logical_and(\r\n                torch.logical_not(mask[:, None]),\r\n                torch.logical_xor(gray[:, 0, bit + 1:], gray[:, dim, bit + 1:]),\r\n            )\r\n            gray[:, dim, bit + 1:] = torch.logical_xor(\r\n                gray[:, dim, bit + 1:], to_flip\r\n            )\r\n            gray[:, 0, bit + 1:] = torch.logical_xor(gray[:, 0, bit + 1:], to_flip)\r\n\r\n    # Pad back out to 64 bits.\r\n    extra_dims = 64 - num_bits\r\n    padded = torch.nn.functional.pad(gray, (extra_dims, 0), \"constant\", 0)\r\n\r\n    # Now chop these up into blocks of 8.\r\n    locs_chopped = padded.flip(-1).reshape((-1, num_dims, 8, 8))\r\n\r\n    # Take 
those blocks and turn them into uint8s.\r\n    locs_uint8 = (locs_chopped * bitpack_mask).sum(3).squeeze().type(torch.uint8)\r\n\r\n    # Finally, treat these as uint64s.\r\n    flat_locs = locs_uint8.view(torch.int64)\r\n\r\n    # Return them in the expected shape.\r\n    return flat_locs.reshape((*orig_shape, num_dims))\r\n"
  },
  {
    "path": "mvtracker/models/core/ptv3/serialization/z_order.py",
    "content": "# --------------------------------------------------------\r\n# Octree-based Sparse Convolutional Neural Networks\r\n# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>\r\n# Licensed under The MIT License [see LICENSE for details]\r\n# Written by Peng-Shuai Wang\r\n# --------------------------------------------------------\r\n\r\nfrom typing import Optional, Union\r\n\r\nimport torch\r\n\r\n\r\nclass KeyLUT:\r\n    def __init__(self):\r\n        r256 = torch.arange(256, dtype=torch.int64)\r\n        r512 = torch.arange(512, dtype=torch.int64)\r\n        zero = torch.zeros(256, dtype=torch.int64)\r\n        device = torch.device(\"cpu\")\r\n\r\n        self._encode = {\r\n            device: (\r\n                self.xyz2key(r256, zero, zero, 8),\r\n                self.xyz2key(zero, r256, zero, 8),\r\n                self.xyz2key(zero, zero, r256, 8),\r\n            )\r\n        }\r\n        self._decode = {device: self.key2xyz(r512, 9)}\r\n\r\n    def encode_lut(self, device=torch.device(\"cpu\")):\r\n        if device not in self._encode:\r\n            cpu = torch.device(\"cpu\")\r\n            self._encode[device] = tuple(e.to(device) for e in self._encode[cpu])\r\n        return self._encode[device]\r\n\r\n    def decode_lut(self, device=torch.device(\"cpu\")):\r\n        if device not in self._decode:\r\n            cpu = torch.device(\"cpu\")\r\n            self._decode[device] = tuple(e.to(device) for e in self._decode[cpu])\r\n        return self._decode[device]\r\n\r\n    def xyz2key(self, x, y, z, depth):\r\n        key = torch.zeros_like(x)\r\n        for i in range(depth):\r\n            mask = 1 << i\r\n            key = (\r\n                    key\r\n                    | ((x & mask) << (2 * i + 2))\r\n                    | ((y & mask) << (2 * i + 1))\r\n                    | ((z & mask) << (2 * i + 0))\r\n            )\r\n        return key\r\n\r\n    def key2xyz(self, key, depth):\r\n        x = torch.zeros_like(key)\r\n        y = torch.zeros_like(key)\r\n        z = torch.zeros_like(key)\r\n        for i in range(depth):\r\n            x = x | ((key & (1 << (3 * i + 2))) >> (2 * i + 2))\r\n            y = y | ((key & (1 << (3 * i + 1))) >> (2 * i + 1))\r\n            z = z | ((key & (1 << (3 * i + 0))) >> (2 * i + 0))\r\n        return x, y, z\r\n\r\n\r\n_key_lut = KeyLUT()\r\n\r\n\r\ndef xyz2key(\r\n        x: torch.Tensor,\r\n        y: torch.Tensor,\r\n        z: torch.Tensor,\r\n        b: Optional[Union[torch.Tensor, int]] = None,\r\n        depth: int = 16,\r\n):\r\n    r\"\"\"Encodes :attr:`x`, :attr:`y`, :attr:`z` coordinates to the shuffled keys\r\n    based on pre-computed look up tables. The speed of this function is much\r\n    faster than the method based on for-loop.\r\n\r\n    Args:\r\n      x (torch.Tensor): The x coordinate.\r\n      y (torch.Tensor): The y coordinate.\r\n      z (torch.Tensor): The z coordinate.\r\n      b (torch.Tensor or int): The batch index of the coordinates, and should be\r\n          smaller than 32768. 
If :attr:`b` is :obj:`torch.Tensor`, the size of\r\n          :attr:`b` must be the same as :attr:`x`, :attr:`y`, and :attr:`z`.\r\n      depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17).\r\n    \"\"\"\r\n\r\n    EX, EY, EZ = _key_lut.encode_lut(x.device)\r\n    x, y, z = x.long(), y.long(), z.long()\r\n\r\n    mask = 255 if depth > 8 else (1 << depth) - 1\r\n    key = EX[x & mask] | EY[y & mask] | EZ[z & mask]\r\n    if depth > 8:\r\n        mask = (1 << (depth - 8)) - 1\r\n        key16 = EX[(x >> 8) & mask] | EY[(y >> 8) & mask] | EZ[(z >> 8) & mask]\r\n        key = key16 << 24 | key\r\n\r\n    if b is not None:\r\n        b = b.long()\r\n        key = b << 48 | key\r\n\r\n    return key\r\n\r\n\r\ndef key2xyz(key: torch.Tensor, depth: int = 16):\r\n    r\"\"\"Decodes the shuffled key to :attr:`x`, :attr:`y`, :attr:`z` coordinates\r\n    and the batch index based on pre-computed look up tables.\r\n\r\n    Args:\r\n      key (torch.Tensor): The shuffled key.\r\n      depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17).\r\n    \"\"\"\r\n\r\n    DX, DY, DZ = _key_lut.decode_lut(key.device)\r\n    x, y, z = torch.zeros_like(key), torch.zeros_like(key), torch.zeros_like(key)\r\n\r\n    b = key >> 48\r\n    key = key & ((1 << 48) - 1)\r\n\r\n    n = (depth + 2) // 3\r\n    for i in range(n):\r\n        k = key >> (i * 9) & 511\r\n        x = x | (DX[k] << (i * 3))\r\n        y = y | (DY[k] << (i * 3))\r\n        z = z | (DZ[k] << (i * 3))\r\n\r\n    return x, y, z, b\r\n"
  },
  {
    "path": "mvtracker/models/core/shape-of-motion/.gitignore",
    "content": "*.pth\n*.npy\n*.mp4\noutputs/\nwork_dirs/\n*__pycache__*\n.vscode/\n.envrc\n.bak/\ndatasets/\n\npreproc/checkpoints\npreproc/checkpoints/\n"
  },
  {
    "path": "mvtracker/models/core/shape-of-motion/.gitmodules",
    "content": "[submodule \"preproc/tapnet\"]\n\tpath = preproc/tapnet\n\turl = https://github.com/google-deepmind/tapnet.git\n[submodule \"preproc/DROID-SLAM\"]\n\tpath = preproc/DROID-SLAM\n\turl = https://github.com/princeton-vl/DROID-SLAM.git\n[submodule \"preproc/UniDepth\"]\n\tpath = preproc/UniDepth\n\turl = https://github.com/lpiccinelli-eth/UniDepth.git\n"
  },
  {
    "path": "mvtracker/models/core/shape-of-motion/LICENSE",
    "content": "MIT License\n\nCopyright (c) 2024 Vickie Ye\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "mvtracker/models/core/shape-of-motion/README.md",
    "content": "# Shape of Motion: 4D Reconstruction from a Single Video\n**[Project Page](https://shape-of-motion.github.io/) | [Arxiv](https://arxiv.org/abs/2407.13764)**\n\n[Qianqian Wang](https://qianqianwang68.github.io/)<sup>1,2</sup>*, [Vickie Ye](https://people.eecs.berkeley.edu/~vye/)<sup>1</sup>\\*, [Hang Gao](https://hangg7.com/)<sup>1</sup>\\*, [Jake Austin](https://www.linkedin.com/in/jakeaustin4701)<sup>1</sup>, [Zhengqi Li](https://zhengqili.github.io/)<sup>2</sup>, [Angjoo Kanazawa](https://people.eecs.berkeley.edu/~kanazawa/)<sup>1</sup>\n\n<sup>1</sup>UC Berkeley   &nbsp;  <sup>2</sup>Google Research\n\n\\* Equal Contribution\n\n\n\n## Installation\n\n```\ngit clone --recurse-submodules https://github.com/vye16/shape-of-motion\ncd shape-of-motion/\nconda create -n som python=3.10\nconda activate som\n```\n\nUpdate `requirements.txt` with correct CUDA version for PyTorch and cuUML,\ni.e., replacing `cu122` and `cu12` with your CUDA version.\n```\n\npip install -r requirements.txt\npip install git+https://github.com/nerfstudio-project/gsplat.git\n```\n\n## Usage\n\n### Preprocessing\n\nWe depend on the third-party libraries in `preproc` to generate depth maps, object masks, camera estimates, and 2D tracks.\nPlease follow the guide in the [preprocessing README](./preproc/README.md).\n\n### Fitting to a Video\n\n```python\npython run_training.py \\\n  --work-dir <OUTPUT_DIR> \\\n  data:davis \\\n  --data.seq-name horsejump-low\n```\n\n## Evaluation on iPhone Dataset\nFirst, download our processed iPhone dataset from [this](https://drive.google.com/drive/folders/1xJaFS_3027crk7u36cue7BseAX80abRe?usp=sharing) link. To train on a sequence, e.g., *paper-windmill*, run:\n\n```python\npython run_training.py \\\n  --work-dir <OUTPUT_DIR> \\\n  --port <PORT> \\\n  data:iphone \\\n  --data.data-dir </path/to/paper-windmill/>\n```\n\nAfter optimization, the numerical result can be evaluated via:\n```\nPYTHONPATH='.' python scripts/evaluate_iphone.py \\\n  --data_dir </path/to/paper-windmill/> \\\n  --result_dir <OUTPUT_DIR> \\\n  --seq_names paper-windmill\n```\n\n\n## Citation\n```\n@inproceedings{som2024,\n  title     = {Shape of Motion: 4D Reconstruction from a Single Video},\n  author    = {Wang, Qianqian and Ye, Vickie and Gao, Hang and Austin, Jake and Li, Zhengqi and Kanazawa, Angjoo},\n  journal   = {arXiv preprint arXiv:2407.13764},\n  year      = {2024}\n}\n```\n"
  },
  {
    "path": "mvtracker/models/core/shape-of-motion/flow3d/__init__.py",
    "content": ""
  },
  {
    "path": "mvtracker/models/core/shape-of-motion/flow3d/configs.py",
    "content": "from dataclasses import dataclass\n\n\n@dataclass\nclass FGLRConfig:\n    means: float = 1.6e-4\n    opacities: float = 1e-2\n    scales: float = 5e-3\n    quats: float = 1e-3\n    colors: float = 1e-2\n    motion_coefs: float = 1e-2\n\n\n@dataclass\nclass BGLRConfig:\n    means: float = 1.6e-4\n    opacities: float = 5e-2\n    scales: float = 5e-3\n    quats: float = 1e-3\n    colors: float = 1e-2\n\n\n@dataclass\nclass MotionLRConfig:\n    rots: float = 1.6e-4\n    transls: float = 1.6e-4\n\n\n@dataclass\nclass SceneLRConfig:\n    fg: FGLRConfig\n    bg: BGLRConfig\n    motion_bases: MotionLRConfig\n\n\n@dataclass\nclass LossesConfig:\n    w_rgb: float = 1.0\n    w_depth_reg: float = 0.5\n    w_depth_const: float = 0.1\n    w_depth_grad: float = 1\n    w_track: float = 2.0\n    w_mask: float = 1.0\n    w_smooth_bases: float = 0.1\n    w_smooth_tracks: float = 2.0\n    w_scale_var: float = 0.01\n    w_z_accel: float = 1.0\n\n\n@dataclass\nclass OptimizerConfig:\n    max_steps: int = 5000\n    ## Adaptive gaussian control\n    warmup_steps: int = 200\n    control_every: int = 100\n    reset_opacity_every_n_controls: int = 30\n    stop_control_by_screen_steps: int = 4000\n    stop_control_steps: int = 4000\n    ### Densify.\n    densify_xys_grad_threshold: float = 0.0002\n    densify_scale_threshold: float = 0.01\n    densify_screen_threshold: float = 0.05\n    stop_densify_steps: int = 15000\n    ### Cull.\n    cull_opacity_threshold: float = 0.1\n    cull_scale_threshold: float = 0.5\n    cull_screen_threshold: float = 0.15\n"
  },
  {
    "path": "mvtracker/models/core/shape-of-motion/flow3d/data/__init__.py",
    "content": "from dataclasses import asdict, replace\n\nfrom torch.utils.data import Dataset\n\nfrom .base_dataset import BaseDataset\nfrom .panoptic_dataset import PanopticDataConfig, PanopticStudioDatasetSoM\nfrom .casual_dataset import CasualDataset, CustomDataConfig, DavisDataConfig\nfrom .iphone_dataset import (\n    iPhoneDataConfig,\n    iPhoneDataset,\n    iPhoneDatasetKeypointView,\n    iPhoneDatasetVideoView,\n)\n\n\ndef get_train_val_datasets(\n    data_cfg: iPhoneDataConfig | DavisDataConfig | CustomDataConfig, load_val: bool\n) -> tuple[BaseDataset, Dataset | None, Dataset | None, Dataset | None]:\n    train_video_view = None\n    val_img_dataset = None\n    val_kpt_dataset = None\n    if isinstance(data_cfg, iPhoneDataConfig):\n        train_dataset = iPhoneDataset(**asdict(data_cfg))\n        train_video_view = iPhoneDatasetVideoView(train_dataset)\n        if load_val:\n            val_img_dataset = (\n                iPhoneDataset(\n                    **asdict(replace(data_cfg, split=\"val\", load_from_cache=True))\n                )\n                if train_dataset.has_validation\n                else None\n            )\n            val_kpt_dataset = iPhoneDatasetKeypointView(train_dataset)\n    elif isinstance(data_cfg, DavisDataConfig) or isinstance(\n        data_cfg, CustomDataConfig\n    ):\n        train_dataset = CasualDataset(**asdict(data_cfg))\n\n    elif isinstance(data_cfg, PanopticDataConfig):\n        train_dataset = PanopticStudioDatasetSoM(**asdict(data_cfg))\n        print(\"PANOPTIC IS LOADED.\")\n\n    \n    else:\n        raise ValueError(f\"Unknown data config: {data_cfg}\")\n    return train_dataset, train_video_view, val_img_dataset, val_kpt_dataset\n"
  },
  {
    "path": "mvtracker/models/core/shape-of-motion/flow3d/data/base_dataset.py",
    "content": "from abc import abstractmethod\n\nimport torch\nfrom torch.utils.data import Dataset, default_collate\n\n\nclass BaseDataset(Dataset):\n    @property\n    @abstractmethod\n    def num_frames(self) -> int: ...\n\n    @property\n    def keyframe_idcs(self) -> torch.Tensor:\n        return torch.arange(self.num_frames)\n\n    @abstractmethod\n    def get_w2cs(self) -> torch.Tensor: ...\n\n    @abstractmethod\n    def get_Ks(self) -> torch.Tensor: ...\n\n    @abstractmethod\n    def get_image(self, index: int) -> torch.Tensor: ...\n\n    @abstractmethod\n    def get_depth(self, index: int) -> torch.Tensor: ...\n\n    @abstractmethod\n    def get_mask(self, index: int) -> torch.Tensor: ...\n\n    def get_img_wh(self) -> tuple[int, int]: ...\n\n    @abstractmethod\n    def get_tracks_3d(\n        self, num_samples: int, **kwargs\n    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Returns 3D tracks:\n            coordinates (N, T, 3),\n            visibles (N, T),\n            invisibles (N, T),\n            confidences (N, T),\n            colors (N, 3)\n        \"\"\"\n        ...\n\n    @abstractmethod\n    def get_bkgd_points(\n        self, num_samples: int, **kwargs\n    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Returns background points:\n            coordinates (N, 3),\n            normals (N, 3),\n            colors (N, 3)\n        \"\"\"\n        ...\n\n    # @staticmethod\n    # def train_collate_fn(batch):\n    #     collated = {}\n    #     for k in batch[0]:\n    #         if k not in [\n    #             \"query_tracks_2d\",\n    #             \"target_ts\",\n    #             \"target_w2cs\",\n    #             \"target_Ks\",\n    #             \"target_tracks_2d\",\n    #             \"target_visibles\",\n    #             \"target_track_depths\",\n    #             \"target_invisibles\",\n    #             \"target_confidences\",\n    #         ]:\n    #             collated[k] = default_collate([sample[k] for sample in batch])\n    #         else:\n    #             collated[k] = [sample[k] for sample in batch]\n    #     return collated\n\n    @staticmethod\n    def train_collate_fn(batch):\n        \"\"\"\n        Collate function that correctly batches data when each sample consists of multiple views.\n        \"\"\"\n\n        # Step 1: Transpose the batch to group by views\n        # If batch contains 4 views per sample, `batch` is a list of lists: [ [view_1, view_2, view_3, view_4], [view_1, view_2, view_3, view_4], ... 
]\n        # We want to group all view_1's together, all view_2's together, etc.\n        num_views = len(batch[0])  # Assumes each sample has the same number of views\n        batch_per_view = list(zip(*batch))  # Transposes list-of-lists structure\n\n        collated_views = []\n        \n        # Step 2: Collate each view separately\n        for view_batch in batch_per_view:\n            collated = {}\n            for k in view_batch[0]:  # Iterate over keys in the dictionary\n                if k not in [\n                    \"query_tracks_2d\",\n                    \"target_ts\",\n                    \"target_w2cs\",\n                    \"target_Ks\",\n                    \"target_tracks_2d\",\n                    \"target_visibles\",\n                    \"target_track_depths\",\n                    \"target_invisibles\",\n                    \"target_confidences\",\n                ]:\n                    collated[k] = default_collate([sample[k] for sample in view_batch])\n                else:\n                    collated[k] = [sample[k] for sample in view_batch]  # Keep list format\n            collated_views.append(collated)\n\n        return collated_views  # List of collated dictionaries, one per view"
  },
  {
    "path": "mvtracker/models/core/shape-of-motion/flow3d/data/casual_dataset.py",
    "content": "import os\nfrom dataclasses import dataclass\nfrom functools import partial\nfrom typing import Literal, cast\n\nimport cv2\nimport imageio\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport tyro\nfrom loguru import logger as guru\nfrom roma import roma\nfrom tqdm import tqdm\n\nfrom flow3d.data.base_dataset import BaseDataset\nfrom flow3d.data.utils import (\n    UINT16_MAX,\n    SceneNormDict,\n    get_tracks_3d_for_query_frame,\n    median_filter_2d,\n    normal_from_depth_image,\n    normalize_coords,\n    parse_tapir_track_info,\n)\nfrom flow3d.transforms import rt_to_mat4\n\n\n@dataclass\nclass DavisDataConfig:\n    seq_name: str\n    root_dir: str\n    start: int = 0\n    end: int = -1\n    res: str = \"480p\"\n    image_type: str = \"JPEGImages\"\n    mask_type: str = \"Annotations\"\n    depth_type: Literal[\n        \"aligned_depth_anything\",\n        \"aligned_depth_anything_v2\",\n        \"depth_anything\",\n        \"depth_anything_v2\",\n        \"unidepth_disp\",\n    ] = \"aligned_depth_anything\"\n    camera_type: Literal[\"droid_recon\"] = \"droid_recon\"\n    track_2d_type: Literal[\"bootstapir\", \"tapir\"] = \"bootstapir\"\n    mask_erosion_radius: int = 3\n    scene_norm_dict: tyro.conf.Suppress[SceneNormDict | None] = None\n    num_targets_per_frame: int = 4\n    load_from_cache: bool = False\n\n\n@dataclass\nclass CustomDataConfig:\n    seq_name: str\n    root_dir: str\n    start: int = 0\n    end: int = -1\n    res: str = \"\"\n    image_type: str = \"images\"\n    mask_type: str = \"masks\"\n    depth_type: Literal[\n        \"aligned_depth_anything\",\n        \"aligned_depth_anything_v2\",\n        \"depth_anything\",\n        \"depth_anything_v2\",\n        \"unidepth_disp\",\n    ] = \"aligned_depth_anything\"\n    camera_type: Literal[\"droid_recon\"] = \"droid_recon\"\n    track_2d_type: Literal[\"bootstapir\", \"tapir\"] = \"bootstapir\"\n    mask_erosion_radius: int = 7\n    scene_norm_dict: tyro.conf.Suppress[SceneNormDict | None] = None\n    num_targets_per_frame: int = 4\n    load_from_cache: bool = False\n\n\nclass CasualDataset(BaseDataset):\n    def __init__(\n        self,\n        seq_name: str,\n        root_dir: str,\n        start: int = 0,\n        end: int = -1,\n        res: str = \"480p\",\n        image_type: str = \"JPEGImages\",\n        mask_type: str = \"Annotations\",\n        depth_type: Literal[\n            \"aligned_depth_anything\",\n            \"aligned_depth_anything_v2\",\n            \"depth_anything\",\n            \"depth_anything_v2\",\n            \"unidepth_disp\",\n        ] = \"aligned_depth_anything\",\n        camera_type: Literal[\"droid_recon\"] = \"droid_recon\",\n        track_2d_type: Literal[\"bootstapir\", \"tapir\"] = \"bootstapir\",\n        mask_erosion_radius: int = 3,\n        scene_norm_dict: SceneNormDict | None = None,\n        num_targets_per_frame: int = 4,\n        load_from_cache: bool = False,\n        **_,\n    ):\n        super().__init__()\n\n        self.seq_name = seq_name\n        self.root_dir = root_dir\n        self.res = res\n        self.depth_type = depth_type\n        self.num_targets_per_frame = num_targets_per_frame\n        self.load_from_cache = load_from_cache\n        self.has_validation = False\n        self.mask_erosion_radius = mask_erosion_radius\n\n        self.img_dir = f\"{root_dir}/{image_type}/{res}/{seq_name}\"\n        self.img_ext = os.path.splitext(os.listdir(self.img_dir)[0])[1]\n        self.depth_dir = 
f\"{root_dir}/{depth_type}/{res}/{seq_name}\"\n        self.mask_dir = f\"{root_dir}/{mask_type}/{res}/{seq_name}\"\n        self.tracks_dir = f\"{root_dir}/{track_2d_type}/{res}/{seq_name}\"\n        self.cache_dir = f\"{root_dir}/flow3d_preprocessed/{res}/{seq_name}\"\n        #  self.cache_dir = f\"datasets/davis/flow3d_preprocessed/{res}/{seq_name}\"\n        frame_names = [os.path.splitext(p)[0] for p in sorted(os.listdir(self.img_dir))]\n\n        if end == -1:\n            end = len(frame_names)\n        self.start = start\n        self.end = end\n        self.frame_names = frame_names[start:end]\n\n        self.imgs: list[torch.Tensor | None] = [None for _ in self.frame_names]\n        self.depths: list[torch.Tensor | None] = [None for _ in self.frame_names]\n        self.masks: list[torch.Tensor | None] = [None for _ in self.frame_names]\n\n        # load cameras\n        if camera_type == \"droid_recon\":\n            img = self.get_image(0)\n            H, W = img.shape[:2]\n            w2cs, Ks, tstamps = load_cameras(\n                f\"{root_dir}/{camera_type}/{seq_name}.npy\", H, W\n            )\n        else:\n            raise ValueError(f\"Unknown camera type: {camera_type}\")\n        assert (\n            len(frame_names) == len(w2cs) == len(Ks)\n        ), f\"{len(frame_names)}, {len(w2cs)}, {len(Ks)}\"\n        self.w2cs = w2cs[start:end]\n        self.Ks = Ks[start:end]\n        tmask = (tstamps >= start) & (tstamps < end)\n        self._keyframe_idcs = tstamps[tmask] - start\n        self.scale = 1\n\n        if scene_norm_dict is None:\n            cached_scene_norm_dict_path = os.path.join(\n                self.cache_dir, \"scene_norm_dict.pth\"\n            )\n            if os.path.exists(cached_scene_norm_dict_path) and self.load_from_cache:\n                guru.info(\"loading cached scene norm dict...\")\n                scene_norm_dict = torch.load(\n                    os.path.join(self.cache_dir, \"scene_norm_dict.pth\")\n                )\n            else:\n                tracks_3d = self.get_tracks_3d(5000, step=self.num_frames // 10)[0]\n                scale, transfm = compute_scene_norm(tracks_3d, self.w2cs)\n                scene_norm_dict = SceneNormDict(scale=scale, transfm=transfm)\n                os.makedirs(self.cache_dir, exist_ok=True)\n                torch.save(scene_norm_dict, cached_scene_norm_dict_path)\n\n        # transform cameras\n        self.scene_norm_dict = cast(SceneNormDict, scene_norm_dict)\n        self.scale = self.scene_norm_dict[\"scale\"]\n        transform = self.scene_norm_dict[\"transfm\"]\n        guru.info(f\"scene norm {self.scale=}, {transform=}\")\n        self.w2cs = torch.einsum(\"nij,jk->nik\", self.w2cs, torch.linalg.inv(transform))\n        self.w2cs[:, :3, 3] /= self.scale\n\n    @property\n    def num_frames(self) -> int:\n        return len(self.frame_names)\n\n    @property\n    def keyframe_idcs(self) -> torch.Tensor:\n        return self._keyframe_idcs\n\n    def __len__(self):\n        return len(self.frame_names)\n\n    def get_w2cs(self) -> torch.Tensor:\n        return self.w2cs\n\n    def get_Ks(self) -> torch.Tensor:\n        return self.Ks\n\n    def get_img_wh(self) -> tuple[int, int]:\n        return self.get_image(0).shape[1::-1]\n\n    def get_image(self, index) -> torch.Tensor:\n        if self.imgs[index] is None:\n            self.imgs[index] = self.load_image(index)\n        img = cast(torch.Tensor, self.imgs[index])\n        return img\n\n    def get_mask(self, index) -> 
torch.Tensor:\n        if self.masks[index] is None:\n            self.masks[index] = self.load_mask(index)\n        mask = cast(torch.Tensor, self.masks[index])\n        return mask\n\n    def get_depth(self, index) -> torch.Tensor:\n        if self.depths[index] is None:\n            self.depths[index] = self.load_depth(index)\n        return self.depths[index] / self.scale\n\n    def load_image(self, index) -> torch.Tensor:\n        path = f\"{self.img_dir}/{self.frame_names[index]}{self.img_ext}\"\n        return torch.from_numpy(imageio.imread(path)).float() / 255.0\n\n    def load_mask(self, index) -> torch.Tensor:\n        path = f\"{self.mask_dir}/{self.frame_names[index]}.png\"\n        r = self.mask_erosion_radius\n        mask = imageio.imread(path)\n        fg_mask = mask.reshape((*mask.shape[:2], -1)).max(axis=-1) > 0\n        bg_mask = ~fg_mask\n        fg_mask_erode = cv2.erode(\n            fg_mask.astype(np.uint8), np.ones((r, r), np.uint8), iterations=1\n        )\n        bg_mask_erode = cv2.erode(\n            bg_mask.astype(np.uint8), np.ones((r, r), np.uint8), iterations=1\n        )\n        out_mask = np.zeros_like(fg_mask, dtype=np.float32)\n        out_mask[bg_mask_erode > 0] = -1\n        out_mask[fg_mask_erode > 0] = 1\n        return torch.from_numpy(out_mask).float()\n\n    def load_depth(self, index) -> torch.Tensor:\n        path = f\"{self.depth_dir}/{self.frame_names[index]}.npy\"\n        disp = np.load(path)\n        depth = 1.0 / np.clip(disp, a_min=1e-6, a_max=1e6)\n        depth = torch.from_numpy(depth).float()\n        depth = median_filter_2d(depth[None, None], 11, 1)[0, 0]\n        return depth\n\n    def load_target_tracks(\n        self, query_index: int, target_indices: list[int], dim: int = 1\n    ):\n        \"\"\"\n        tracks are 2d, occs and uncertainties\n        :param dim (int), default 1: dimension to stack the time axis\n        return (N, T, 4) if dim=1, (T, N, 4) if dim=0\n        \"\"\"\n        q_name = self.frame_names[query_index]\n        all_tracks = []\n        for ti in target_indices:\n            t_name = self.frame_names[ti]\n            path = f\"{self.tracks_dir}/{q_name}_{t_name}.npy\"\n            tracks = np.load(path).astype(np.float32)\n            all_tracks.append(tracks)\n        return torch.from_numpy(np.stack(all_tracks, axis=dim))\n\n    def get_tracks_3d(\n        self, num_samples: int, start: int = 0, end: int = -1, step: int = 1, **kwargs\n    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n        num_frames = self.num_frames\n        if end < 0:\n            end = num_frames + 1 + end\n        query_idcs = list(range(start, end, step))\n        target_idcs = list(range(start, end, step))\n        masks = torch.stack([self.get_mask(i) for i in target_idcs], dim=0)\n        fg_masks = (masks == 1).float()\n        depths = torch.stack([self.get_depth(i) for i in target_idcs], dim=0)\n        inv_Ks = torch.linalg.inv(self.Ks[target_idcs])\n        c2ws = torch.linalg.inv(self.w2cs[target_idcs])\n\n        num_per_query_frame = int(np.ceil(num_samples / len(query_idcs)))\n        cur_num = 0\n        tracks_all_queries = []\n        for q_idx in query_idcs:\n            # (N, T, 4)\n            tracks_2d = self.load_target_tracks(q_idx, target_idcs)\n            num_sel = int(\n                min(num_per_query_frame, num_samples - cur_num, len(tracks_2d))\n            )\n            if num_sel < len(tracks_2d):\n                sel_idcs = 
np.random.choice(len(tracks_2d), num_sel, replace=False)\n                tracks_2d = tracks_2d[sel_idcs]\n            cur_num += tracks_2d.shape[0]\n            img = self.get_image(q_idx)\n            tidx = target_idcs.index(q_idx)\n            tracks_tuple = get_tracks_3d_for_query_frame(\n                tidx, img, tracks_2d, depths, fg_masks, inv_Ks, c2ws\n            )\n            tracks_all_queries.append(tracks_tuple)\n        tracks_3d, colors, visibles, invisibles, confidences = map(\n            partial(torch.cat, dim=0), zip(*tracks_all_queries)\n        )\n        return tracks_3d, visibles, invisibles, confidences, colors\n\n    def get_bkgd_points(\n        self,\n        num_samples: int,\n        use_kf_tstamps: bool = True,\n        stride: int = 8,\n        down_rate: int = 8,\n        min_per_frame: int = 64,\n        **kwargs,\n    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        start = 0\n        end = self.num_frames\n        H, W = self.get_image(0).shape[:2]\n        grid = torch.stack(\n            torch.meshgrid(\n                torch.arange(0, W, dtype=torch.float32),\n                torch.arange(0, H, dtype=torch.float32),\n                indexing=\"xy\",\n            ),\n            dim=-1,\n        )\n\n        if use_kf_tstamps:\n            query_idcs = self.keyframe_idcs.tolist()\n        else:\n            num_query_frames = self.num_frames // stride\n            query_endpts = torch.linspace(start, end, num_query_frames + 1)\n            query_idcs = ((query_endpts[:-1] + query_endpts[1:]) / 2).long().tolist()\n\n        bg_geometry = []\n        print(f\"{query_idcs=}\")\n        for query_idx in tqdm(query_idcs, desc=\"Loading bkgd points\", leave=False):\n            img = self.get_image(query_idx)\n            depth = self.get_depth(query_idx)\n            bg_mask = self.get_mask(query_idx) < 0\n            bool_mask = (bg_mask * (depth > 0)).to(torch.bool)\n            w2c = self.w2cs[query_idx]\n            K = self.Ks[query_idx]\n\n            # get the bounding box of previous points that reproject into frame\n            # inefficient but works for now\n            bmax_x, bmax_y, bmin_x, bmin_y = 0, 0, W, H\n            for p3d, _, _ in bg_geometry:\n                if len(p3d) < 1:\n                    continue\n                # reproject into current frame\n                p2d = torch.einsum(\n                    \"ij,jk,pk->pi\", K, w2c[:3], F.pad(p3d, (0, 1), value=1.0)\n                )\n                p2d = p2d[:, :2] / p2d[:, 2:].clamp(min=1e-6)\n                xmin, xmax = p2d[:, 0].min().item(), p2d[:, 0].max().item()\n                ymin, ymax = p2d[:, 1].min().item(), p2d[:, 1].max().item()\n\n                bmin_x = min(bmin_x, int(xmin))\n                bmin_y = min(bmin_y, int(ymin))\n                bmax_x = max(bmax_x, int(xmax))\n                bmax_y = max(bmax_y, int(ymax))\n\n            # don't include points that are covered by previous points\n            bmin_x = max(0, bmin_x)\n            bmin_y = max(0, bmin_y)\n            bmax_x = min(W, bmax_x)\n            bmax_y = min(H, bmax_y)\n            overlap_mask = torch.ones_like(bool_mask)\n            overlap_mask[bmin_y:bmax_y, bmin_x:bmax_x] = 0\n\n            bool_mask &= overlap_mask\n            if bool_mask.sum() < min_per_frame:\n                guru.debug(f\"skipping {query_idx=}\")\n                continue\n\n            points = (\n                torch.einsum(\n                    \"ij,pj->pi\",\n                    
torch.linalg.inv(K),\n                    F.pad(grid[bool_mask], (0, 1), value=1.0),\n                )\n                * depth[bool_mask][:, None]\n            )\n            points = torch.einsum(\n                \"ij,pj->pi\", torch.linalg.inv(w2c)[:3], F.pad(points, (0, 1), value=1.0)\n            )\n            point_normals = normal_from_depth_image(depth, K, w2c)[bool_mask]\n            point_colors = img[bool_mask]\n\n            num_sel = max(len(points) // down_rate, min_per_frame)\n            sel_idcs = np.random.choice(len(points), num_sel, replace=False)\n            points = points[sel_idcs]\n            point_normals = point_normals[sel_idcs]\n            point_colors = point_colors[sel_idcs]\n            guru.debug(f\"{query_idx=} {points.shape=}\")\n            bg_geometry.append((points, point_normals, point_colors))\n\n        bg_points, bg_normals, bg_colors = map(\n            partial(torch.cat, dim=0), zip(*bg_geometry)\n        )\n        if len(bg_points) > num_samples:\n            sel_idcs = np.random.choice(len(bg_points), num_samples, replace=False)\n            bg_points = bg_points[sel_idcs]\n            bg_normals = bg_normals[sel_idcs]\n            bg_colors = bg_colors[sel_idcs]\n\n        return bg_points, bg_normals, bg_colors\n\n    def __getitem__(self, index: int):\n        index = np.random.randint(0, self.num_frames)\n        data = {\n            # ().\n            \"frame_names\": self.frame_names[index],\n            # ().\n            \"ts\": torch.tensor(index),\n            # (4, 4).\n            \"w2cs\": self.w2cs[index],\n            # (3, 3).\n            \"Ks\": self.Ks[index],\n            # (H, W, 3).\n            \"imgs\": self.get_image(index),\n            \"depths\": self.get_depth(index),\n        }\n        tri_mask = self.get_mask(index)\n        valid_mask = tri_mask != 0  # not fg or bg\n        mask = tri_mask == 1  # fg mask\n        data[\"masks\"] = mask.float()\n        data[\"valid_masks\"] = valid_mask.float()\n\n        # (P, 2)\n        query_tracks = self.load_target_tracks(index, [index])[:, 0, :2]\n        target_inds = torch.from_numpy(\n            np.random.choice(\n                self.num_frames, (self.num_targets_per_frame,), replace=False\n            )\n        )\n        # (N, P, 4)\n        target_tracks = self.load_target_tracks(index, target_inds.tolist(), dim=0)\n        data[\"query_tracks_2d\"] = query_tracks\n        data[\"target_ts\"] = target_inds\n        data[\"target_w2cs\"] = self.w2cs[target_inds]\n        data[\"target_Ks\"] = self.Ks[target_inds]\n        data[\"target_tracks_2d\"] = target_tracks[..., :2]\n        # (N, P).\n        (\n            data[\"target_visibles\"],\n            data[\"target_invisibles\"],\n            data[\"target_confidences\"],\n        ) = parse_tapir_track_info(target_tracks[..., 2], target_tracks[..., 3])\n        # (N, H, W)\n        target_depths = torch.stack([self.get_depth(i) for i in target_inds], dim=0)\n        H, W = target_depths.shape[-2:]\n        data[\"target_track_depths\"] = F.grid_sample(\n            target_depths[:, None],\n            normalize_coords(target_tracks[..., None, :2], H, W),\n            align_corners=True,\n            padding_mode=\"border\",\n        )[:, 0, :, 0]\n        return data\n\n\ndef load_cameras(\n    path: str, H: int, W: int\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n    assert os.path.exists(path), f\"Camera file {path} does not exist.\"\n    recon = np.load(path, allow_pickle=True).item()\n   
 guru.debug(f\"{recon.keys()=}\")\n    traj_c2w = recon[\"traj_c2w\"]  # (N, 4, 4)\n    h, w = recon[\"img_shape\"]\n    sy, sx = H / h, W / w\n    traj_w2c = np.linalg.inv(traj_c2w)\n    fx, fy, cx, cy = recon[\"intrinsics\"]  # (4,)\n    K = np.array([[fx * sx, 0, cx * sx], [0, fy * sy, cy * sy], [0, 0, 1]])  # (3, 3)\n    Ks = np.tile(K[None, ...], (len(traj_c2w), 1, 1))  # (N, 3, 3)\n    kf_tstamps = recon[\"tstamps\"].astype(\"int\")\n    return (\n        torch.from_numpy(traj_w2c).float(),\n        torch.from_numpy(Ks).float(),\n        torch.from_numpy(kf_tstamps),\n    )\n\n\ndef compute_scene_norm(\n    X: torch.Tensor, w2cs: torch.Tensor\n) -> tuple[float, torch.Tensor]:\n    \"\"\"\n    :param X: [N*T, 3]\n    :param w2cs: [N, 4, 4]\n    \"\"\"\n    X = X.reshape(-1, 3)\n    scene_center = X.mean(dim=0)\n    X = X - scene_center[None]\n    min_scale = X.quantile(0.05, dim=0)\n    max_scale = X.quantile(0.95, dim=0)\n    scale = (max_scale - min_scale).max().item() / 2.0\n    original_up = -F.normalize(w2cs[:, 1, :3].mean(0), dim=-1)\n    target_up = original_up.new_tensor([0.0, 0.0, 1.0])\n    R = roma.rotvec_to_rotmat(\n        F.normalize(original_up.cross(target_up), dim=-1)\n        * original_up.dot(target_up).acos_()\n    )\n    transfm = rt_to_mat4(R, torch.einsum(\"ij,j->i\", -R, scene_center))\n    return scale, transfm\n\n\nif __name__ == \"__main__\":\n    d = CasualDataset(\"bear\", \"/shared/vye/datasets/DAVIS\", camera_type=\"droid_recon\")\n"
  },
  {
    "path": "mvtracker/models/core/shape-of-motion/flow3d/data/colmap.py",
    "content": "import os\nimport struct\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import Dict, Union\n\nimport numpy as np\n\n\ndef get_colmap_camera_params(colmap_dir, img_files):\n    cameras = read_cameras_binary(colmap_dir + \"/cameras.bin\")\n    images = read_images_binary(colmap_dir + \"/images.bin\")\n    colmap_image_idcs = {v.name: k for k, v in images.items()}\n    img_names = [os.path.basename(img_file) for img_file in img_files]\n    num_imgs = len(img_names)\n    K_all = np.zeros((num_imgs, 4, 4))\n    extrinsics_all = np.zeros((num_imgs, 4, 4))\n    for idx, name in enumerate(img_names):\n        key = colmap_image_idcs[name]\n        image = images[key]\n        assert image.name == name\n        K, extrinsics = get_intrinsics_extrinsics(image, cameras)\n        K_all[idx] = K\n        extrinsics_all[idx] = extrinsics\n\n    return K_all, extrinsics_all\n\n\n@dataclass(frozen=True)\nclass CameraModel:\n    model_id: int\n    model_name: str\n    num_params: int\n\n\n@dataclass(frozen=True)\nclass Camera:\n    id: int\n    model: str\n    width: int\n    height: int\n    params: np.ndarray\n\n\n@dataclass(frozen=True)\nclass BaseImage:\n    id: int\n    qvec: np.ndarray\n    tvec: np.ndarray\n    camera_id: int\n    name: str\n    xys: np.ndarray\n    point3D_ids: np.ndarray\n\n\n@dataclass(frozen=True)\nclass Point3D:\n    id: int\n    xyz: np.ndarray\n    rgb: np.ndarray\n    error: Union[float, np.ndarray]\n    image_ids: np.ndarray\n    point2D_idxs: np.ndarray\n\n\nclass Image(BaseImage):\n    def qvec2rotmat(self):\n        return qvec2rotmat(self.qvec)\n\n\nCAMERA_MODELS = {\n    CameraModel(model_id=0, model_name=\"SIMPLE_PINHOLE\", num_params=3),\n    CameraModel(model_id=1, model_name=\"PINHOLE\", num_params=4),\n    CameraModel(model_id=2, model_name=\"SIMPLE_RADIAL\", num_params=4),\n    CameraModel(model_id=3, model_name=\"RADIAL\", num_params=5),\n    CameraModel(model_id=4, model_name=\"OPENCV\", num_params=8),\n    CameraModel(model_id=5, model_name=\"OPENCV_FISHEYE\", num_params=8),\n    CameraModel(model_id=6, model_name=\"FULL_OPENCV\", num_params=12),\n    CameraModel(model_id=7, model_name=\"FOV\", num_params=5),\n    CameraModel(model_id=8, model_name=\"SIMPLE_RADIAL_FISHEYE\", num_params=4),\n    CameraModel(model_id=9, model_name=\"RADIAL_FISHEYE\", num_params=5),\n    CameraModel(model_id=10, model_name=\"THIN_PRISM_FISHEYE\", num_params=12),\n}\nCAMERA_MODEL_IDS = dict(\n    [(camera_model.model_id, camera_model) for camera_model in CAMERA_MODELS]\n)\n\n\ndef read_next_bytes(fid, num_bytes, format_char_sequence, endian_character=\"<\"):\n    \"\"\"Read and unpack the next bytes from a binary file.\n    :param fid:\n    :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 
2, 6, 16, 30, etc.\n    :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.\n    :param endian_character: Any of {@, =, <, >, !}\n    :return: Tuple of read and unpacked values.\n    \"\"\"\n    data = fid.read(num_bytes)\n    return struct.unpack(endian_character + format_char_sequence, data)\n\n\ndef read_cameras_text(path: Union[str, Path]) -> Dict[int, Camera]:\n    \"\"\"\n    see: src/base/reconstruction.cc\n        void Reconstruction::WriteCamerasText(const std::string& path)\n        void Reconstruction::ReadCamerasText(const std::string& path)\n    \"\"\"\n    cameras = {}\n    with open(path, \"r\") as fid:\n        while True:\n            line = fid.readline()\n            if not line:\n                break\n            line = line.strip()\n            if len(line) > 0 and line[0] != \"#\":\n                elems = line.split()\n                camera_id = int(elems[0])\n                model = elems[1]\n                width = int(elems[2])\n                height = int(elems[3])\n                params = np.array(tuple(map(float, elems[4:])))\n                cameras[camera_id] = Camera(\n                    id=camera_id, model=model, width=width, height=height, params=params\n                )\n    return cameras\n\n\ndef read_cameras_binary(path_to_model_file: Union[str, Path]) -> Dict[int, Camera]:\n    \"\"\"\n    see: src/base/reconstruction.cc\n        void Reconstruction::WriteCamerasBinary(const std::string& path)\n        void Reconstruction::ReadCamerasBinary(const std::string& path)\n    \"\"\"\n    cameras = {}\n    with open(path_to_model_file, \"rb\") as fid:\n        num_cameras = read_next_bytes(fid, 8, \"Q\")[0]\n        for camera_line_index in range(num_cameras):\n            camera_properties = read_next_bytes(\n                fid, num_bytes=24, format_char_sequence=\"iiQQ\"\n            )\n            camera_id = camera_properties[0]\n            model_id = camera_properties[1]\n            model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name\n            width = camera_properties[2]\n            height = camera_properties[3]\n            num_params = CAMERA_MODEL_IDS[model_id].num_params\n            params = read_next_bytes(\n                fid, num_bytes=8 * num_params, format_char_sequence=\"d\" * num_params\n            )\n            cameras[camera_id] = Camera(\n                id=camera_id,\n                model=model_name,\n                width=width,\n                height=height,\n                params=np.array(params),\n            )\n        assert len(cameras) == num_cameras\n    return cameras\n\n\ndef read_images_text(path: Union[str, Path]) -> Dict[int, Image]:\n    \"\"\"\n    see: src/base/reconstruction.cc\n        void Reconstruction::ReadImagesText(const std::string& path)\n        void Reconstruction::WriteImagesText(const std::string& path)\n    \"\"\"\n    images = {}\n    with open(path, \"r\") as fid:\n        while True:\n            line = fid.readline()\n            if not line:\n                break\n            line = line.strip()\n            if len(line) > 0 and line[0] != \"#\":\n                elems = line.split()\n                image_id = int(elems[0])\n                qvec = np.array(tuple(map(float, elems[1:5])))\n                tvec = np.array(tuple(map(float, elems[5:8])))\n                camera_id = int(elems[8])\n                image_name = elems[9]\n                elems = fid.readline().split()\n                xys = np.column_stack(\n                    
[tuple(map(float, elems[0::3])), tuple(map(float, elems[1::3]))]\n                )\n                point3D_ids = np.array(tuple(map(int, elems[2::3])))\n                images[image_id] = Image(\n                    id=image_id,\n                    qvec=qvec,\n                    tvec=tvec,\n                    camera_id=camera_id,\n                    name=image_name,\n                    xys=xys,\n                    point3D_ids=point3D_ids,\n                )\n    return images\n\n\ndef read_images_binary(path_to_model_file: Union[str, Path]) -> Dict[int, Image]:\n    \"\"\"\n    see: src/base/reconstruction.cc\n        void Reconstruction::ReadImagesBinary(const std::string& path)\n        void Reconstruction::WriteImagesBinary(const std::string& path)\n    \"\"\"\n    images = {}\n    with open(path_to_model_file, \"rb\") as fid:\n        num_reg_images = read_next_bytes(fid, 8, \"Q\")[0]\n        for image_index in range(num_reg_images):\n            binary_image_properties = read_next_bytes(\n                fid, num_bytes=64, format_char_sequence=\"idddddddi\"\n            )\n            image_id = binary_image_properties[0]\n            qvec = np.array(binary_image_properties[1:5])\n            tvec = np.array(binary_image_properties[5:8])\n            camera_id = binary_image_properties[8]\n            image_name = \"\"\n            current_char = read_next_bytes(fid, 1, \"c\")[0]\n            while current_char != b\"\\x00\":  # look for the ASCII 0 entry\n                image_name += current_char.decode(\"utf-8\")\n                current_char = read_next_bytes(fid, 1, \"c\")[0]\n            num_points2D = read_next_bytes(fid, num_bytes=8, format_char_sequence=\"Q\")[\n                0\n            ]\n            x_y_id_s = read_next_bytes(\n                fid,\n                num_bytes=24 * num_points2D,\n                format_char_sequence=\"ddq\" * num_points2D,\n            )\n            xys = np.column_stack(\n                [tuple(map(float, x_y_id_s[0::3])), tuple(map(float, x_y_id_s[1::3]))]\n            )\n            point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))\n            images[image_id] = Image(\n                id=image_id,\n                qvec=qvec,\n                tvec=tvec,\n                camera_id=camera_id,\n                name=image_name,\n                xys=xys,\n                point3D_ids=point3D_ids,\n            )\n    return images\n\n\ndef read_points3D_text(path: Union[str, Path]):\n    \"\"\"\n    see: src/base/reconstruction.cc\n        void Reconstruction::ReadPoints3DText(const std::string& path)\n        void Reconstruction::WritePoints3DText(const std::string& path)\n    \"\"\"\n    points3D = {}\n    with open(path, \"r\") as fid:\n        while True:\n            line = fid.readline()\n            if not line:\n                break\n            line = line.strip()\n            if len(line) > 0 and line[0] != \"#\":\n                elems = line.split()\n                point3D_id = int(elems[0])\n                xyz = np.array(tuple(map(float, elems[1:4])))\n                rgb = np.array(tuple(map(int, elems[4:7])))\n                error = float(elems[7])\n                image_ids = np.array(tuple(map(int, elems[8::2])))\n                point2D_idxs = np.array(tuple(map(int, elems[9::2])))\n                points3D[point3D_id] = Point3D(\n                    id=point3D_id,\n                    xyz=xyz,\n                    rgb=rgb,\n                    error=error,\n                    image_ids=image_ids,\n   
                 point2D_idxs=point2D_idxs,\n                )\n    return points3D\n\n\ndef read_points3d_binary(path_to_model_file: Union[str, Path]) -> Dict[int, Point3D]:\n    \"\"\"\n    see: src/base/reconstruction.cc\n        void Reconstruction::ReadPoints3DBinary(const std::string& path)\n        void Reconstruction::WritePoints3DBinary(const std::string& path)\n    \"\"\"\n    points3D = {}\n    with open(path_to_model_file, \"rb\") as fid:\n        num_points = read_next_bytes(fid, 8, \"Q\")[0]\n        for point_line_index in range(num_points):\n            binary_point_line_properties = read_next_bytes(\n                fid, num_bytes=43, format_char_sequence=\"QdddBBBd\"\n            )\n            point3D_id = binary_point_line_properties[0]\n            xyz = np.array(binary_point_line_properties[1:4])\n            rgb = np.array(binary_point_line_properties[4:7])\n            error = np.array(binary_point_line_properties[7])\n            track_length = read_next_bytes(fid, num_bytes=8, format_char_sequence=\"Q\")[\n                0\n            ]\n            track_elems = read_next_bytes(\n                fid,\n                num_bytes=8 * track_length,\n                format_char_sequence=\"ii\" * track_length,\n            )\n            image_ids = np.array(tuple(map(int, track_elems[0::2])))\n            point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))\n            points3D[point3D_id] = Point3D(\n                id=point3D_id,\n                xyz=xyz,\n                rgb=rgb,\n                error=error,\n                image_ids=image_ids,\n                point2D_idxs=point2D_idxs,\n            )\n    return points3D\n\n\ndef qvec2rotmat(qvec):\n    return np.array(\n        [\n            [\n                1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2,\n                2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],\n                2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2],\n            ],\n            [\n                2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],\n                1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2,\n                2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1],\n            ],\n            [\n                2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],\n                2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],\n                1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2,\n            ],\n        ]\n    )\n\n\ndef get_intrinsics_extrinsics(img, cameras):\n    # world to cam transformation\n    R = qvec2rotmat(img.qvec)\n    # translation\n    t = img.tvec\n    cam = cameras[img.camera_id]\n\n    if cam.model in (\"SIMPLE_PINHOLE\", \"SIMPLE_RADIAL\", \"RADIAL\"):\n        fx = fy = cam.params[0]\n        cx = cam.params[1]\n        cy = cam.params[2]\n    elif cam.model in (\n        \"PINHOLE\",\n        \"OPENCV\",\n        \"OPENCV_FISHEYE\",\n        \"FULL_OPENCV\",\n    ):\n        fx = cam.params[0]\n        fy = cam.params[1]\n        cx = cam.params[2]\n        cy = cam.params[3]\n    else:\n        raise Exception(\"Camera model not supported\")\n\n    # intrinsics\n    K = np.identity(4)\n    K[0, 0] = fx\n    K[1, 1] = fy\n    K[0, 2] = cx\n    K[1, 2] = cy\n\n    extrinsics = np.eye(4)\n    extrinsics[:3, :3] = R\n    extrinsics[:3, 3] = t\n    return K, extrinsics\n"
  },
  {
    "path": "mvtracker/models/core/shape-of-motion/flow3d/data/iphone_dataset.py",
    "content": "import json\nimport os\nimport os.path as osp\nfrom dataclasses import dataclass\nfrom glob import glob\nfrom itertools import product\nfrom typing import Literal\n\nimport imageio.v3 as iio\nimport numpy as np\nimport roma\nimport torch\nimport torch.nn.functional as F\nimport tyro\nfrom loguru import logger as guru\nfrom torch.utils.data import Dataset\nfrom tqdm import tqdm\n\nfrom flow3d.data.base_dataset import BaseDataset\nfrom flow3d.data.colmap import get_colmap_camera_params\nfrom flow3d.data.utils import (\n    SceneNormDict,\n    masked_median_blur,\n    normal_from_depth_image,\n    normalize_coords,\n    parse_tapir_track_info,\n)\nfrom flow3d.transforms import rt_to_mat4\n\n\n@dataclass\nclass iPhoneDataConfig:\n    data_dir: str\n    start: int = 0\n    end: int = -1\n    split: Literal[\"train\", \"val\"] = \"train\"\n    depth_type: Literal[\n        \"midas\",\n        \"depth_anything\",\n        \"lidar\",\n        \"depth_anything_colmap\",\n    ] = \"depth_anything_colmap\"\n    camera_type: Literal[\"original\", \"refined\"] = \"refined\"\n    use_median_filter: bool = False\n    num_targets_per_frame: int = 4\n    scene_norm_dict: tyro.conf.Suppress[SceneNormDict | None] = None\n    load_from_cache: bool = False\n    skip_load_imgs: bool = False\n\n\nclass iPhoneDataset(BaseDataset):\n    def __init__(\n        self,\n        data_dir: str,\n        start: int = 0,\n        end: int = -1,\n        factor: int = 1,\n        split: Literal[\"train\", \"val\"] = \"train\",\n        depth_type: Literal[\n            \"midas\",\n            \"depth_anything\",\n            \"lidar\",\n            \"depth_anything_colmap\",\n        ] = \"depth_anything_colmap\",\n        camera_type: Literal[\"original\", \"refined\"] = \"refined\",\n        use_median_filter: bool = False,\n        num_targets_per_frame: int = 1,\n        scene_norm_dict: SceneNormDict | None = None,\n        load_from_cache: bool = False,\n        skip_load_imgs: bool = False,\n        **_,\n    ):\n        super().__init__()\n        print(skip_load_imgs)\n        self.data_dir = data_dir\n        self.training = split == \"train\"\n        self.split = split\n        self.factor = factor\n        self.start = start\n        self.end = end\n        self.depth_type = depth_type\n        self.camera_type = camera_type\n        self.use_median_filter = use_median_filter\n        self.num_targets_per_frame = num_targets_per_frame\n        self.scene_norm_dict = scene_norm_dict\n        self.load_from_cache = load_from_cache\n        self.cache_dir = osp.join(data_dir, \"flow3d_preprocessed\", \"cache\")\n        os.makedirs(self.cache_dir, exist_ok=True)\n\n        # Test if the current data has validation set.\n        with open(osp.join(data_dir, \"splits\", \"val.json\")) as f:\n            split_dict = json.load(f)\n        self.has_validation = len(split_dict[\"frame_names\"]) > 0\n\n        # Load metadata.\n        with open(osp.join(data_dir, \"splits\", f\"{split}.json\")) as f:\n            split_dict = json.load(f)\n        full_len = len(split_dict[\"frame_names\"])\n        end = min(end, full_len) if end > 0 else full_len\n        self.end = end\n        self.frame_names = split_dict[\"frame_names\"][start:end]\n        time_ids = [t for t in split_dict[\"time_ids\"] if t >= start and t < end]\n        self.time_ids = torch.tensor(time_ids) - start\n        guru.info(f\"{self.time_ids.min()=} {self.time_ids.max()=}\")\n        # with open(osp.join(data_dir, \"dataset.json\")) 
as f:\n        #     dataset_dict = json.load(f)\n        # self.num_frames = dataset_dict[\"num_exemplars\"]\n        guru.info(f\"{self.num_frames=}\")\n        with open(osp.join(data_dir, \"extra.json\")) as f:\n            extra_dict = json.load(f)\n        self.fps = float(extra_dict[\"fps\"])\n\n        # Load cameras.\n        if self.camera_type == \"original\":\n            Ks, w2cs = [], []\n            for frame_name in self.frame_names:\n                with open(osp.join(data_dir, \"camera\", f\"{frame_name}.json\")) as f:\n                    camera_dict = json.load(f)\n                focal_length = camera_dict[\"focal_length\"]\n                principal_point = camera_dict[\"principal_point\"]\n                Ks.append(\n                    [\n                        [focal_length, 0.0, principal_point[0]],\n                        [0.0, focal_length, principal_point[1]],\n                        [0.0, 0.0, 1.0],\n                    ]\n                )\n                orientation = np.array(camera_dict[\"orientation\"])\n                position = np.array(camera_dict[\"position\"])\n                w2cs.append(\n                    np.block(\n                        [\n                            [orientation, -orientation @ position[:, None]],\n                            [np.zeros((1, 3)), np.ones((1, 1))],\n                        ]\n                    ).astype(np.float32)\n                )\n            self.Ks = torch.tensor(Ks)\n            self.Ks[:, :2] /= factor\n            self.w2cs = torch.from_numpy(np.array(w2cs))\n        elif self.camera_type == \"refined\":\n            Ks, w2cs = get_colmap_camera_params(\n                osp.join(data_dir, \"flow3d_preprocessed/colmap/sparse/\"),\n                [frame_name + \".png\" for frame_name in self.frame_names],\n            )\n            self.Ks = torch.from_numpy(Ks[:, :3, :3].astype(np.float32))\n            self.Ks[:, :2] /= factor\n            self.w2cs = torch.from_numpy(w2cs.astype(np.float32))\n        if not skip_load_imgs:\n            # Load images.\n            imgs = torch.from_numpy(\n                np.array(\n                    [\n                        iio.imread(\n                            osp.join(self.data_dir, f\"rgb/{factor}x/{frame_name}.png\")\n                        )\n                        for frame_name in tqdm(\n                            self.frame_names,\n                            desc=f\"Loading {self.split} images\",\n                            leave=False,\n                        )\n                    ],\n                )\n            )\n            self.imgs = imgs[..., :3] / 255.0\n            self.valid_masks = imgs[..., 3] / 255.0\n            # Load masks.\n            self.masks = (\n                torch.from_numpy(\n                    np.array(\n                        [\n                            iio.imread(\n                                osp.join(\n                                    self.data_dir,\n                                    \"flow3d_preprocessed/track_anything/\",\n                                    f\"{factor}x/{frame_name}.png\",\n                                )\n                            )\n                            for frame_name in tqdm(\n                                self.frame_names,\n                                desc=f\"Loading {self.split} masks\",\n                                leave=False,\n                            )\n                        ],\n                    )\n                )\n                / 
255.0\n            )\n            if self.training:\n                # Load depths.\n                def load_depth(frame_name):\n                    if self.depth_type == \"lidar\":\n                        depth = np.load(\n                            osp.join(\n                                self.data_dir,\n                                f\"depth/{factor}x/{frame_name}.npy\",\n                            )\n                        )[..., 0]\n                    else:\n                        depth = np.load(\n                            osp.join(\n                                self.data_dir,\n                                f\"flow3d_preprocessed/aligned_{self.depth_type}/\",\n                                f\"{factor}x/{frame_name}.npy\",\n                            )\n                        )\n                        depth[depth < 1e-3] = 1e-3\n                        depth = 1.0 / depth\n                    return depth\n\n                self.depths = torch.from_numpy(\n                    np.array(\n                        [\n                            load_depth(frame_name)\n                            for frame_name in tqdm(\n                                self.frame_names,\n                                desc=f\"Loading {self.split} depths\",\n                                leave=False,\n                            )\n                        ],\n                        np.float32,\n                    )\n                )\n                max_depth_values_per_frame = self.depths.reshape(\n                    self.num_frames, -1\n                ).max(1)[0]\n                max_depth_value = max_depth_values_per_frame.median() * 2.5\n                print(\"max_depth_value\", max_depth_value)\n                self.depths = torch.clamp(self.depths, 0, max_depth_value)\n                # Median filter depths.\n                # NOTE(hangg): This operator is very expensive.\n                if self.use_median_filter:\n                    for i in tqdm(\n                        range(self.num_frames), desc=\"Processing depths\", leave=False\n                    ):\n                        depth = masked_median_blur(\n                            self.depths[[i]].unsqueeze(1).to(\"cuda\"),\n                            (\n                                self.masks[[i]]\n                                * self.valid_masks[[i]]\n                                * (self.depths[[i]] > 0)\n                            )\n                            .unsqueeze(1)\n                            .to(\"cuda\"),\n                        )[0, 0].cpu()\n                        self.depths[i] = depth * self.masks[i] + self.depths[i] * (\n                            1 - self.masks[i]\n                        )\n                # Load the query pixels from 2D tracks.\n                self.query_tracks_2d = [\n                    torch.from_numpy(\n                        np.load(\n                            osp.join(\n                                self.data_dir,\n                                \"flow3d_preprocessed/2d_tracks/\",\n                                f\"{factor}x/{frame_name}_{frame_name}.npy\",\n                            )\n                        ).astype(np.float32)\n                    )\n                    for frame_name in self.frame_names\n                ]\n                guru.info(\n                    f\"{len(self.query_tracks_2d)=} {self.query_tracks_2d[0].shape=}\"\n                )\n\n                # Load sam features.\n                # sam_feat_dir = 
osp.join(\n                #     data_dir, f\"flow3d_preprocessed/sam_features/{factor}x\"\n                # )\n                # assert osp.exists(sam_feat_dir), f\"SAM features not exist!\"\n                # sam_features, original_size, input_size = load_sam_features(\n                #     sam_feat_dir, self.frame_names\n                # )\n                # guru.info(f\"{sam_features.shape=} {original_size=} {input_size=}\")\n                # self.sam_features = sam_features\n                # self.sam_original_size = original_size\n                # self.sam_input_size = input_size\n            else:\n                # Load covisible masks.\n                self.covisible_masks = (\n                    torch.from_numpy(\n                        np.array(\n                            [\n                                iio.imread(\n                                    osp.join(\n                                        self.data_dir,\n                                        \"flow3d_preprocessed/covisible/\",\n                                        f\"{factor}x/{split}/{frame_name}.png\",\n                                    )\n                                )\n                                for frame_name in tqdm(\n                                    self.frame_names,\n                                    desc=f\"Loading {self.split} covisible masks\",\n                                    leave=False,\n                                )\n                            ],\n                        )\n                    )\n                    / 255.0\n                )\n\n        if self.scene_norm_dict is None:\n            cached_scene_norm_dict_path = osp.join(\n                self.cache_dir, \"scene_norm_dict.pth\"\n            )\n            if osp.exists(cached_scene_norm_dict_path) and self.load_from_cache:\n                print(\"loading cached scene norm dict...\")\n                self.scene_norm_dict = torch.load(\n                    osp.join(self.cache_dir, \"scene_norm_dict.pth\")\n                )\n            elif self.training:\n                # Compute the scene scale and transform for normalization.\n                # Normalize the scene based on the foreground 3D tracks.\n                subsampled_tracks_3d = self.get_tracks_3d(\n                    num_samples=10000, step=self.num_frames // 10, show_pbar=False\n                )[0]\n                scene_center = subsampled_tracks_3d.mean((0, 1))\n                tracks_3d_centered = subsampled_tracks_3d - scene_center\n                min_scale = tracks_3d_centered.quantile(0.05, dim=0)\n                max_scale = tracks_3d_centered.quantile(0.95, dim=0)\n                scale = torch.max(max_scale - min_scale).item() / 2.0\n                original_up = -F.normalize(self.w2cs[:, 1, :3].mean(0), dim=-1)\n                target_up = original_up.new_tensor([0.0, 0.0, 1.0])\n                R = roma.rotvec_to_rotmat(\n                    F.normalize(original_up.cross(target_up, dim=-1), dim=-1)\n                    * original_up.dot(target_up).acos_()\n                )\n                transfm = rt_to_mat4(R, torch.einsum(\"ij,j->i\", -R, scene_center))\n                self.scene_norm_dict = SceneNormDict(scale=scale, transfm=transfm)\n                torch.save(self.scene_norm_dict, cached_scene_norm_dict_path)\n            else:\n                raise ValueError(\"scene_norm_dict must be provided for validation.\")\n\n        # Normalize the scene.\n        scale = self.scene_norm_dict[\"scale\"]\n        
transfm = self.scene_norm_dict[\"transfm\"]\n        self.w2cs = self.w2cs @ torch.linalg.inv(transfm)\n        self.w2cs[:, :3, 3] /= scale\n        if self.training and not skip_load_imgs:\n            self.depths /= scale\n\n        if not skip_load_imgs:\n            guru.info(\n                f\"{self.imgs.shape=} {self.valid_masks.shape=} {self.masks.shape=}\"\n            )\n\n    @property\n    def num_frames(self) -> int:\n        return len(self.frame_names)\n\n    def __len__(self):\n        return self.imgs.shape[0]\n\n    def get_w2cs(self) -> torch.Tensor:\n        return self.w2cs\n\n    def get_Ks(self) -> torch.Tensor:\n        return self.Ks\n\n    def get_image(self, index: int) -> torch.Tensor:\n        return self.imgs[index]\n\n    def get_depth(self, index: int) -> torch.Tensor:\n        return self.depths[index]\n\n    def get_mask(self, index: int) -> torch.Tensor:\n        return self.masks[index]\n\n    def get_img_wh(self) -> tuple[int, int]:\n        return iio.imread(\n            osp.join(self.data_dir, f\"rgb/{self.factor}x/{self.frame_names[0]}.png\")\n        ).shape[1::-1]\n\n    # def get_sam_features(self) -> list[torch.Tensor, tuple[int, int], tuple[int, int]]:\n    #     return self.sam_features, self.sam_original_size, self.sam_input_size\n\n    def get_tracks_3d(\n        self, num_samples: int, step: int = 1, show_pbar: bool = True, **kwargs\n    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n        \"\"\"Get 3D tracks from the dataset.\n\n        Args:\n            num_samples (int | None): The number of samples to fetch. If None,\n                fetch all samples. If not None, fetch roughly a same number of\n                samples across each frame. Note that this might result in\n                number of samples less than what is specified.\n            step (int): The step to temporally subsample the track.\n        \"\"\"\n        assert (\n            self.split == \"train\"\n        ), \"fetch_tracks_3d is only available for the training split.\"\n        cached_track_3d_path = osp.join(self.cache_dir, f\"tracks_3d_{num_samples}.pth\")\n        if osp.exists(cached_track_3d_path) and step == 1 and self.load_from_cache:\n            print(\"loading cached 3d tracks data...\")\n            start, end = self.start, self.end\n            cached_track_3d_data = torch.load(cached_track_3d_path)\n            tracks_3d, visibles, invisibles, confidences, track_colors = (\n                cached_track_3d_data[\"tracks_3d\"][:, start:end],\n                cached_track_3d_data[\"visibles\"][:, start:end],\n                cached_track_3d_data[\"invisibles\"][:, start:end],\n                cached_track_3d_data[\"confidences\"][:, start:end],\n                cached_track_3d_data[\"track_colors\"],\n            )\n            return tracks_3d, visibles, invisibles, confidences, track_colors\n\n        # Load 2D tracks.\n        raw_tracks_2d = []\n        candidate_frames = list(range(0, self.num_frames, step))\n        num_sampled_frames = len(candidate_frames)\n        for i in (\n            tqdm(candidate_frames, desc=\"Loading 2D tracks\", leave=False)\n            if show_pbar\n            else candidate_frames\n        ):\n            curr_num_samples = self.query_tracks_2d[i].shape[0]\n            num_samples_per_frame = (\n                int(np.floor(num_samples / num_sampled_frames))\n                if i != candidate_frames[-1]\n                else num_samples\n                - (num_sampled_frames 
- 1)\n                * int(np.floor(num_samples / num_sampled_frames))\n            )\n            if num_samples_per_frame < curr_num_samples:\n                track_sels = np.random.choice(\n                    curr_num_samples, (num_samples_per_frame,), replace=False\n                )\n            else:\n                track_sels = np.arange(0, curr_num_samples)\n            curr_tracks_2d = []\n            for j in range(0, self.num_frames, step):\n                if i == j:\n                    target_tracks_2d = self.query_tracks_2d[i]\n                else:\n                    target_tracks_2d = torch.from_numpy(\n                        np.load(\n                            osp.join(\n                                self.data_dir,\n                                \"flow3d_preprocessed/2d_tracks/\",\n                                f\"{self.factor}x/\"\n                                f\"{self.frame_names[i]}_\"\n                                f\"{self.frame_names[j]}.npy\",\n                            )\n                        ).astype(np.float32)\n                    )\n                curr_tracks_2d.append(target_tracks_2d[track_sels])\n            raw_tracks_2d.append(torch.stack(curr_tracks_2d, dim=1))\n        guru.info(f\"{step=} {len(raw_tracks_2d)=} {raw_tracks_2d[0].shape=}\")\n\n        # Process 3D tracks.\n        inv_Ks = torch.linalg.inv(self.Ks)[::step]\n        c2ws = torch.linalg.inv(self.w2cs)[::step]\n        H, W = self.imgs.shape[1:3]\n        filtered_tracks_3d, filtered_visibles, filtered_track_colors = [], [], []\n        filtered_invisibles, filtered_confidences = [], []\n        masks = self.masks * self.valid_masks * (self.depths > 0)\n        masks = (masks > 0.5).float()\n        for i, tracks_2d in enumerate(raw_tracks_2d):\n            tracks_2d = tracks_2d.swapdims(0, 1)\n            tracks_2d, occs, dists = (\n                tracks_2d[..., :2],\n                tracks_2d[..., 2],\n                tracks_2d[..., 3],\n            )\n            # visibles = postprocess_occlusions(occs, dists)\n            visibles, invisibles, confidences = parse_tapir_track_info(occs, dists)\n            # Unproject 2D tracks to 3D.\n            track_depths = F.grid_sample(\n                self.depths[::step, None],\n                normalize_coords(tracks_2d[..., None, :], H, W),\n                align_corners=True,\n                padding_mode=\"border\",\n            )[:, 0]\n            tracks_3d = (\n                torch.einsum(\n                    \"nij,npj->npi\",\n                    inv_Ks,\n                    F.pad(tracks_2d, (0, 1), value=1.0),\n                )\n                * track_depths\n            )\n            tracks_3d = torch.einsum(\n                \"nij,npj->npi\", c2ws, F.pad(tracks_3d, (0, 1), value=1.0)\n            )[..., :3]\n            # Filter out out-of-mask tracks.\n            is_in_masks = (\n                F.grid_sample(\n                    masks[::step, None],\n                    normalize_coords(tracks_2d[..., None, :], H, W),\n                    align_corners=True,\n                ).squeeze()\n                == 1\n            )\n            visibles *= is_in_masks\n            invisibles *= is_in_masks\n            confidences *= is_in_masks.float()\n            # Get track's color from the query frame.\n            track_colors = (\n                F.grid_sample(\n                    self.imgs[i * step : i * step + 1].permute(0, 3, 1, 2),\n                    normalize_coords(tracks_2d[i : i + 1, None, 
:], H, W),\n                    align_corners=True,\n                    padding_mode=\"border\",\n                )\n                .squeeze()\n                .T\n            )\n            # at least visible 5% of the time, otherwise discard\n            visible_counts = visibles.sum(0)\n            valid = visible_counts >= min(\n                int(0.05 * self.num_frames),\n                visible_counts.float().quantile(0.1).item(),\n            )\n\n            filtered_tracks_3d.append(tracks_3d[:, valid])\n            filtered_visibles.append(visibles[:, valid])\n            filtered_invisibles.append(invisibles[:, valid])\n            filtered_confidences.append(confidences[:, valid])\n            filtered_track_colors.append(track_colors[valid])\n\n        filtered_tracks_3d = torch.cat(filtered_tracks_3d, dim=1).swapdims(0, 1)\n        filtered_visibles = torch.cat(filtered_visibles, dim=1).swapdims(0, 1)\n        filtered_invisibles = torch.cat(filtered_invisibles, dim=1).swapdims(0, 1)\n        filtered_confidences = torch.cat(filtered_confidences, dim=1).swapdims(0, 1)\n        filtered_track_colors = torch.cat(filtered_track_colors, dim=0)\n        if step == 1:\n            torch.save(\n                {\n                    \"tracks_3d\": filtered_tracks_3d,\n                    \"visibles\": filtered_visibles,\n                    \"invisibles\": filtered_invisibles,\n                    \"confidences\": filtered_confidences,\n                    \"track_colors\": filtered_track_colors,\n                },\n                cached_track_3d_path,\n            )\n        return (\n            filtered_tracks_3d,\n            filtered_visibles,\n            filtered_invisibles,\n            filtered_confidences,\n            filtered_track_colors,\n        )\n\n    def get_bkgd_points(\n        self, num_samples: int, **kwargs\n    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        H, W = self.imgs.shape[1:3]\n        grid = torch.stack(\n            torch.meshgrid(\n                torch.arange(W, dtype=torch.float32),\n                torch.arange(H, dtype=torch.float32),\n                indexing=\"xy\",\n            ),\n            dim=-1,\n        )\n        candidate_frames = list(range(self.num_frames))\n        num_sampled_frames = len(candidate_frames)\n        bkgd_points, bkgd_point_normals, bkgd_point_colors = [], [], []\n        for i in tqdm(candidate_frames, desc=\"Loading bkgd points\", leave=False):\n            img = self.imgs[i]\n            depth = self.depths[i]\n            bool_mask = ((1.0 - self.masks[i]) * self.valid_masks[i] * (depth > 0)).to(\n                torch.bool\n            )\n            w2c = self.w2cs[i]\n            K = self.Ks[i]\n            points = (\n                torch.einsum(\n                    \"ij,pj->pi\",\n                    torch.linalg.inv(K),\n                    F.pad(grid[bool_mask], (0, 1), value=1.0),\n                )\n                * depth[bool_mask][:, None]\n            )\n            points = torch.einsum(\n                \"ij,pj->pi\", torch.linalg.inv(w2c)[:3], F.pad(points, (0, 1), value=1.0)\n            )\n            point_normals = normal_from_depth_image(depth, K, w2c)[bool_mask]\n            point_colors = img[bool_mask]\n            curr_num_samples = points.shape[0]\n            num_samples_per_frame = (\n                int(np.floor(num_samples / num_sampled_frames))\n                if i != candidate_frames[-1]\n                else num_samples\n                - 
(num_sampled_frames - 1)\n                * int(np.floor(num_samples / num_sampled_frames))\n            )\n            if num_samples_per_frame < curr_num_samples:\n                point_sels = np.random.choice(\n                    curr_num_samples, (num_samples_per_frame,), replace=False\n                )\n            else:\n                point_sels = np.arange(0, curr_num_samples)\n            bkgd_points.append(points[point_sels])\n            bkgd_point_normals.append(point_normals[point_sels])\n            bkgd_point_colors.append(point_colors[point_sels])\n        bkgd_points = torch.cat(bkgd_points, dim=0)\n        bkgd_point_normals = torch.cat(bkgd_point_normals, dim=0)\n        bkgd_point_colors = torch.cat(bkgd_point_colors, dim=0)\n        return bkgd_points, bkgd_point_normals, bkgd_point_colors\n\n    def get_video_dataset(self) -> Dataset:\n        return iPhoneDatasetVideoView(self)\n\n    def __getitem__(self, index: int):\n        if self.training:\n            index = np.random.randint(0, self.num_frames)\n        data = {\n            # ().\n            \"frame_names\": self.frame_names[index],\n            # ().\n            \"ts\": self.time_ids[index],\n            # (4, 4).\n            \"w2cs\": self.w2cs[index],\n            # (3, 3).\n            \"Ks\": self.Ks[index],\n            # (H, W, 3).\n            \"imgs\": self.imgs[index],\n            # (H, W).\n            \"valid_masks\": self.valid_masks[index],\n            # (H, W).\n            \"masks\": self.masks[index],\n        }\n        if self.training:\n            # (H, W).\n            data[\"depths\"] = self.depths[index]\n            # (P, 2).\n            data[\"query_tracks_2d\"] = self.query_tracks_2d[index][:, :2]\n            target_inds = torch.from_numpy(\n                np.random.choice(\n                    self.num_frames, (self.num_targets_per_frame,), replace=False\n                )\n            )\n            # (N, P, 4).\n            target_tracks_2d = torch.stack(\n                [\n                    torch.from_numpy(\n                        np.load(\n                            osp.join(\n                                self.data_dir,\n                                \"flow3d_preprocessed/2d_tracks/\",\n                                f\"{self.factor}x/\"\n                                f\"{self.frame_names[index]}_\"\n                                f\"{self.frame_names[target_index.item()]}.npy\",\n                            )\n                        ).astype(np.float32)\n                    )\n                    for target_index in target_inds\n                ],\n                dim=0,\n            )\n            # (N,).\n            target_ts = self.time_ids[target_inds]\n            data[\"target_ts\"] = target_ts\n            # (N, 4, 4).\n            data[\"target_w2cs\"] = self.w2cs[target_ts]\n            # (N, 3, 3).\n            data[\"target_Ks\"] = self.Ks[target_ts]\n            # (N, P, 2).\n            data[\"target_tracks_2d\"] = target_tracks_2d[..., :2]\n            # (N, P).\n            (\n                data[\"target_visibles\"],\n                data[\"target_invisibles\"],\n                data[\"target_confidences\"],\n            ) = parse_tapir_track_info(\n                target_tracks_2d[..., 2], target_tracks_2d[..., 3]\n            )\n            # (N, P).\n            data[\"target_track_depths\"] = F.grid_sample(\n                self.depths[target_inds, None],\n                normalize_coords(\n                    
target_tracks_2d[..., None, :2],\n                    self.imgs.shape[1],\n                    self.imgs.shape[2],\n                ),\n                align_corners=True,\n                padding_mode=\"border\",\n            )[:, 0, :, 0]\n        else:\n            # (H, W).\n            data[\"covisible_masks\"] = self.covisible_masks[index]\n        return data\n\n    def preprocess(self, data):\n        return data\n\n\nclass iPhoneDatasetKeypointView(Dataset):\n    \"\"\"Return a dataset view of the annotated keypoints.\"\"\"\n\n    def __init__(self, dataset: iPhoneDataset):\n        super().__init__()\n        self.dataset = dataset\n        assert self.dataset.split == \"train\"\n        # Load 2D keypoints.\n        keypoint_paths = sorted(\n            glob(osp.join(self.dataset.data_dir, \"keypoint/2x/train/0_*.json\"))\n        )\n        keypoints = []\n        for keypoint_path in keypoint_paths:\n            with open(keypoint_path) as f:\n                keypoints.append(json.load(f))\n        time_ids = [\n            int(osp.basename(p).split(\"_\")[1].split(\".\")[0]) for p in keypoint_paths\n        ]\n        # only use time ids that are in the dataset.\n        start = self.dataset.start\n        time_ids = [t - start for t in time_ids if t - start in self.dataset.time_ids]\n        self.time_ids = torch.tensor(time_ids)\n        self.time_pairs = torch.tensor(list(product(self.time_ids, repeat=2)))\n        self.index_pairs = torch.tensor(\n            list(product(range(len(self.time_ids)), repeat=2))\n        )\n        self.keypoints = torch.tensor(keypoints, dtype=torch.float32)\n        self.keypoints[..., :2] *= 2.0 / self.dataset.factor\n\n    def __len__(self):\n        return len(self.time_pairs)\n\n    def __getitem__(self, index: int):\n        ts = self.time_pairs[index]\n        return {\n            \"ts\": ts,\n            \"w2cs\": self.dataset.w2cs[ts],\n            \"Ks\": self.dataset.Ks[ts],\n            \"imgs\": self.dataset.imgs[ts],\n            \"keypoints\": self.keypoints[self.index_pairs[index]],\n        }\n\n\nclass iPhoneDatasetVideoView(Dataset):\n    \"\"\"Return a dataset view of the video trajectory.\"\"\"\n\n    def __init__(self, dataset: iPhoneDataset):\n        super().__init__()\n        self.dataset = dataset\n        self.fps = self.dataset.fps\n        assert self.dataset.split == \"train\"\n\n    def __len__(self):\n        return self.dataset.num_frames\n\n    def __getitem__(self, index):\n        return {\n            \"frame_names\": self.dataset.frame_names[index],\n            \"ts\": index,\n            \"w2cs\": self.dataset.w2cs[index],\n            \"Ks\": self.dataset.Ks[index],\n            \"imgs\": self.dataset.imgs[index],\n            \"depths\": self.dataset.depths[index],\n            \"masks\": self.dataset.masks[index],\n        }\n\n\n\"\"\"\nclass iPhoneDataModule(BaseDataModule[iPhoneDataset]):\n    def __init__(\n        self,\n        data_dir: str,\n        factor: int = 1,\n        start: int = 0,\n        end: int = -1,\n        depth_type: Literal[\n            \"midas\",\n            \"depth_anything\",\n            \"lidar\",\n            \"depth_anything_colmap\",\n        ] = \"depth_anything_colmap\",\n        camera_type: Literal[\"original\", \"refined\"] = \"refined\",\n        use_median_filter: bool = False,\n        num_targets_per_frame: int = 1,\n        load_from_cache: bool = False,\n        **kwargs,\n    ):\n        super().__init__(dataset_cls=iPhoneDataset, **kwargs)\n       
 self.data_dir = data_dir\n        self.start = start\n        self.end = end\n        self.factor = factor\n        self.depth_type = depth_type\n        self.camera_type = camera_type\n        self.use_median_filter = use_median_filter\n        self.num_targets_per_frame = num_targets_per_frame\n        self.load_from_cache = load_from_cache\n\n        self.val_loader_tasks = [\"img\", \"keypoint\"]\n\n    def setup(self, *_, **__) -> None:\n        guru.info(\"Loading train dataset...\")\n        self.train_dataset = self.dataset_cls(\n            data_dir=self.data_dir,\n            training=True,\n            split=\"train\",\n            start=self.start,\n            end=self.end,\n            factor=self.factor,\n            depth_type=self.depth_type,  # type: ignore\n            camera_type=self.camera_type,  # type: ignore\n            use_median_filter=self.use_median_filter,\n            num_targets_per_frame=self.num_targets_per_frame,\n            max_steps=self.max_steps * self.batch_size,\n            load_from_cache=self.load_from_cache,\n        )\n        if self.train_dataset.has_validation:\n            guru.info(\"Loading val dataset...\")\n            self.val_dataset = self.dataset_cls(\n                data_dir=self.data_dir,\n                training=False,\n                split=\"val\",\n                start=self.start,\n                end=self.end,\n                factor=self.factor,\n                depth_type=self.depth_type,  # type: ignore\n                camera_type=self.camera_type,  # type: ignore\n                use_median_filter=self.use_median_filter,\n                scene_norm_dict=self.train_dataset.scene_norm_dict,\n                load_from_cache=self.load_from_cache,\n            )\n        else:\n            # Dummy validation set.\n            self.val_dataset = TensorDataset(torch.zeros(0))  # type: ignore\n        self.keypoint_dataset = iPhoneDatasetKeypointView(self.train_dataset)\n        self.video_dataset = self.train_dataset.get_video_dataset()\n        guru.success(\"Loading finished!\")\n\n    def train_dataloader(self) -> DataLoader:\n        return DataLoader(\n            self.train_dataset,\n            batch_size=self.batch_size,\n            num_workers=self.num_workers,\n            collate_fn=iPhoneDataset.train_collate_fn,\n        )\n\n    def val_dataloader(self) -> list[DataLoader]:\n        return [DataLoader(self.val_dataset), DataLoader(self.keypoint_dataset)]\n        \"\"\"\n"
  },
  {
    "path": "mvtracker/models/core/shape-of-motion/flow3d/data/panoptic_dataset.py",
    "content": "import os\nfrom dataclasses import dataclass\nfrom functools import partial\nfrom typing import Literal, cast\n\nimport cv2\nimport imageio\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport tyro\nfrom loguru import logger as guru\nfrom roma import roma\nfrom tqdm import tqdm\n\nfrom flow3d.data.base_dataset import BaseDataset\nfrom flow3d.data.utils import (\n    UINT16_MAX,\n    SceneNormDict,\n    get_tracks_3d_for_query_frame,\n    median_filter_2d,\n    normal_from_depth_image,\n    normalize_coords,\n    parse_tapir_track_info,\n)\nfrom flow3d.transforms import rt_to_mat4\n\nimport sys\nsys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), \"../../../\")))\n\nimport models.spatracker.datasets.utils as dataset_utils\nfrom models.spatracker.datasets.panoptic_studio_multiview_dataset import PanopticStudioMultiViewDataset\n\nfrom torch.utils.data import default_collate\n\n@dataclass\nclass PanopticDataConfig:\n    seq_name: str\n    root_dir: str\n    start: int = 0\n    end: int = -1\n    res: str = \"\"\n    image_type: str = \"images\"\n    mask_type: str = \"masks\"\n    depth_type: Literal[\n        \"aligned_depth_anything\",\n        \"aligned_depth_anything_v2\",\n        \"depth_anything\",\n        \"depth_anything_v2\",\n        \"unidepth_disp\",\n    ] = \"aligned_depth_anything\"\n    camera_type: Literal[\"droid_recon\"] = \"droid_recon\"\n    track_2d_type: Literal[\"bootstapir\", \"tapir\"] = \"bootstapir\"\n    mask_erosion_radius: int = 7\n    scene_norm_dict: tyro.conf.Suppress[SceneNormDict | None] = None\n    num_targets_per_frame: int = 4\n    load_from_cache: bool = False\n\nclass PanopticStudioDatasetSoM(BaseDataset):\n    def __init__(\n        self,\n        seq_name: str,\n        root_dir: str,\n        res: str = \"480p\",\n        depth_type: Literal[\n            \"aligned_depth_anything\",\n            \"aligned_depth_anything_v2\",\n            \"depth_anything\",\n            \"depth_anything_v2\",\n            \"unidepth_disp\",\n        ] = \"aligned_depth_anything\",\n        mask_erosion_radius: int = 0,\n        scene_norm_dict: SceneNormDict | None = None,\n        num_targets_per_frame: int = 4,\n        load_from_cache: bool = False,\n        **_,\n    ):\n        super().__init__()\n\n        self.seq_name = seq_name\n        self.root_dir = root_dir\n        self.res = res\n        self.depth_type = depth_type\n        self.num_targets_per_frame = num_targets_per_frame\n        self.load_from_cache = load_from_cache\n        self.has_validation = False\n        self.mask_erosion_radius = mask_erosion_radius\n\n        #######################################################################\n        self.views_to_return = [1, 7, 14, 20]\n\n        datasets_root = \"/cluster/scratch/egundogdu/datasets/\"\n        panoptic_kwargs = {\n            \"data_root\": os.path.join(datasets_root, \"panoptic_d3dgs\"), \"traj_per_sample\": 384, \"seed\": 72,\n            \"max_videos\": 1, \"perform_sanity_checks\": False, \"views_to_return\": [1, 7, 14, 20],\n            \"use_duster_depths\": False, \"clean_duster_depths\": False,\n        }\n        self.panoptic_spatial_dataset = PanopticStudioMultiViewDataset(**panoptic_kwargs)\n        \n        datapoint = self.panoptic_spatial_dataset.__getitem__(0)\n\n        if isinstance(datapoint, tuple):\n            datapoint, gotit = datapoint\n            assert gotit\n        if torch.cuda.is_available():\n            
dataset_utils.dataclass_to_cuda_(datapoint)\n            device = torch.device(\"cuda\")\n        else:\n            device = torch.device(\"cpu\")\n\n        self.img_dir_view_1 = os.path.join(datasets_root, \"panoptic_d3dgs\", \"basketball\", \"ims\", \"1\")\n        self.frame_names = [os.path.splitext(p)[0] for p in sorted(os.listdir(self.img_dir_view_1))]\n\n        # Per view data\n        self.rgbs = datapoint.video\n        self.depths = datapoint.videodepth\n        self.image_features = datapoint.feats\n        self.intrs = datapoint.intrs\n        self.extrs = datapoint.extrs\n        self.gt_trajectories_2d_pixelspace_w_z_cameraspace = datapoint.trajectory\n        self.gt_visibilities_per_view = datapoint.visibility\n        self.query_points_2d = (datapoint.query_points.clone().float().to(device)\n                           if datapoint.query_points is not None else None)\n        self.query_points_3d = datapoint.query_points_3d.clone().float().to(device)\n\n        # Non-per-view data\n        self.gt_trajectories_3d_worldspace = datapoint.trajectory_3d\n        self.valid_tracks_per_frame = datapoint.valid\n        self.track_upscaling_factor = datapoint.track_upscaling_factor\n\n        print(self.rgbs.shape)\n        num_views, num_frames, _, height, width = self.rgbs.shape\n        num_points = self.gt_trajectories_2d_pixelspace_w_z_cameraspace.shape[2]\n\n        self.rgbs = self.rgbs.permute(0, 1, 3, 4, 2).cpu()\n        self.depths = self.depths.permute(0, 1, 3, 4, 2).cpu()\n\n        # Assert shapes of per-view data\n        assert self.depths is not None, \"Depth is required for evaluation.\"\n        assert self.rgbs.shape == (num_views, num_frames, height, width, 3)\n        assert self.depths.shape == (num_views, num_frames, height, width, 1)\n        assert self.intrs.shape == (num_views, num_frames, 3, 3)\n        assert self.extrs.shape == (num_views, num_frames, 3, 4)\n        assert self.gt_trajectories_2d_pixelspace_w_z_cameraspace.shape == (\n            num_views, num_frames, num_points, 3)\n        assert self.gt_visibilities_per_view.shape == (num_views, num_frames, num_points)\n\n        # Assert shapes of non-per-view data\n        assert self.query_points_3d.shape == (num_points, 4)\n        assert self.gt_trajectories_3d_worldspace.shape == (num_frames, num_points, 3)\n        assert self.valid_tracks_per_frame.shape == (num_frames, num_points)\n\n        self.w2cs = torch.eye(4).expand(num_views, num_frames, 4, 4).clone()\n        self.w2cs[:, :, :3, :] = self.extrs.squeeze(0).cpu()  # (n_views, n_frames, 4, 4)\n        self.Ks = self.intrs.squeeze(0).cpu()                 # (n_views, n_frames, 3, 3)\n\n\n\n        \n        ###### normalization...\n        self.scale = 1\n\n        tracks_3d = self.get_tracks_3d(5000, step=num_frames // 10)[0]\n        scale, transfm = compute_scene_norm(tracks_3d, self.w2cs)\n        scene_norm_dict = SceneNormDict(scale=scale, transfm=transfm)\n\n        # transform cameras\n        self.scene_norm_dict = cast(SceneNormDict, scene_norm_dict)\n        self.scale = self.scene_norm_dict[\"scale\"]\n        transform = self.scene_norm_dict[\"transfm\"]\n        guru.info(f\"scene norm {self.scale=}, {transform=}\")\n        for v in range(num_views):\n            self.w2cs[v] = torch.einsum(\"nij,jk->nik\", self.w2cs[v], torch.linalg.inv(transform))\n            self.w2cs[v, :, :3, 3] /= self.scale\n\n\n        \n\n    @property\n    def num_frames(self) -> int:\n        return len(self.frame_names)\n\n    
@property\n    def keyframe_idcs(self) -> torch.Tensor:\n        # return self._keyframe_idcs\n        return np.array(range(10,140,10))\n\n    def __len__(self):\n        return len(self.frame_names)\n\n    def get_w2cs(self, view_index=0) -> torch.Tensor:\n        return self.w2cs[view_index].cpu().to(torch.float32)\n\n    def get_Ks(self, view_index=0) -> torch.Tensor:\n        return self.Ks[view_index].cpu().to(torch.float32)\n\n    def get_img_wh(self) -> tuple[int, int]:\n        return self.get_image(0).shape[1::-1]\n\n    def get_image(self, index, view_index=0) -> torch.Tensor:\n        return self.rgbs[view_index][index].cpu().to(torch.float32) / 255.0\n\n    def get_mask(self, index, view_index=0) -> torch.Tensor:\n        view = self.views_to_return[view_index]\n        mask = self.load_mask(index, view)\n        mask = cast(torch.Tensor, mask)\n        return mask.cpu().to(torch.float32)\n\n    def get_depth(self, index, view=0) -> torch.Tensor:\n        # return self.load_depth(index, view) / self.scales[view]\n        return self.load_depth(index, view).cpu().to(torch.float32) / self.scale\n\n    def load_mask(self, index, view=0) -> torch.Tensor:\n        # self.mask_dir = \"/cluster/scratch/egundogdu/datasets/panoptic_d3dgs/basketball/seg\"\n        self.mask_dir = \"/cluster/home/egundogdu/projects/vlg-lab/spatialtracker/shape-of-motion/panoptic_masks\"\n        path = f\"{self.mask_dir}/{view}/{self.frame_names[index]}.png\"\n        r = self.mask_erosion_radius\n        mask = imageio.imread(path)\n        fg_mask = mask.reshape((*mask.shape[:2], -1)).max(axis=-1) > 0\n        bg_mask = ~fg_mask\n        fg_mask_erode = cv2.erode(\n            fg_mask.astype(np.uint8), np.ones((r, r), np.uint8), iterations=1\n        )\n        bg_mask_erode = cv2.erode(\n            bg_mask.astype(np.uint8), np.ones((r, r), np.uint8), iterations=1\n        )\n        out_mask = np.zeros_like(fg_mask, dtype=np.float32)\n        out_mask[bg_mask_erode > 0] = -1\n        out_mask[fg_mask_erode > 0] = 1\n        return torch.from_numpy(out_mask).float()\n\n    def load_depth(self, index, view=0) -> torch.Tensor:\n        depth = self.depths[view][index]\n        depth = depth.permute(2, 0, 1).unsqueeze(0)\n        depth = median_filter_2d(depth, 11, 1)[0, 0]\n        return depth.squeeze(0)\n\n\n    #####################################\n    def get_foreground_points(\n        self,\n        num_samples: int,\n        use_kf_tstamps: bool = False,\n        stride: int = 4,\n        down_rate: int = 8,\n        min_per_frame: int = 64,\n        **kwargs,\n    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        start = 0\n        end = self.num_frames\n        H, W = self.rgbs.shape[2:4]  # Get height & width from rgbs shape\n\n        # Create pixel grid\n        grid = torch.stack(\n            torch.meshgrid(\n                torch.arange(0, W, dtype=torch.float32),\n                torch.arange(0, H, dtype=torch.float32),\n                indexing=\"xy\",\n            ),\n            dim=-1,\n        )  # Shape: (H, W, 2)\n\n        if use_kf_tstamps:\n            query_idcs = self.keyframe_idcs.tolist()\n        else:\n            num_query_frames = self.num_frames // stride\n            query_endpts = torch.linspace(start, end, num_query_frames + 1)\n            query_idcs = ((query_endpts[:-1] + query_endpts[1:]) / 2).long().tolist()\n\n        bg_geometry = []\n        print(f\"{query_idcs=}\")\n        \n        # for v in range(self.rgbs.shape[0]):  # Iterate over 
views\n        for query_idx in tqdm(query_idcs, desc=f\"Loading foreground points (view)\", leave=False):\n            for v in [0, 1, 2, 3]:\n\n                img = self.get_image(query_idx, v).cpu().numpy()  # Shape: (H, W, 3)\n                height, width = img.shape[0], img.shape[1]\n\n                depth = self.get_depth(query_idx, v).cpu().numpy()\n                mask = self.get_mask(query_idx, v).cpu().numpy() < 0  # Shape: (H, W)\n                valid_mask = (~mask * (depth > 0)).ravel()\n\n                w2c = self.w2cs[v, query_idx].cpu().numpy()\n                c2w = np.linalg.inv(w2c)\n                k = self.Ks[v, query_idx].cpu().numpy()\n                k_inv = np.linalg.inv(k)\n            \n\n                y, x = np.indices((height, width))\n                homo_pixel_coords = np.stack([x.ravel(), y.ravel(), np.ones_like(x).ravel()], axis=1).T\n                cam_coords = (k_inv @ homo_pixel_coords) * depth.ravel()\n                cam_coords = np.vstack((cam_coords, np.ones((1, cam_coords.shape[1]))))\n                world_coords = (c2w @ cam_coords)[:3].T\n                world_coords = world_coords[valid_mask]\n                rgb_colors = img.reshape(-1, 3)[valid_mask].astype(np.uint8)\n\n\n                bg_geometry.append((torch.from_numpy(world_coords), torch.from_numpy(world_coords), torch.from_numpy(rgb_colors)))\n\n                rr.set_time_seconds(\"frame\", query_idx / 30)\n                rr.log(f\"world/points/view_{v}_foreground\", rr.Points3D(positions=world_coords, colors=rgb_colors * 255.0))\n\n\n\n                # tmp_img = img.clone()\n                # tmp_img[~bool_mask] = 1\n                # img_8bit = (tmp_img.reshape(self.rgbs[v, query_idx].shape).cpu().numpy() * 255).astype(np.uint8)\n                # datasets_root = f\"/cluster/scratch/egundogdu/datasets/view{v}_frame{query_idx}.png\"\n                # cv2.imwrite(datasets_root, img_8bit[..., ::-1])\n                # print(f\"Saved {datasets_root}\")\n\n                # img_8bit = (depth.cpu().numpy() * 255).astype(np.uint8)\n                # datasets_root = f\"/cluster/scratch/egundogdu/datasets/depth_view{v}_frame{query_idx}.png\"\n                # cv2.imwrite(datasets_root, img_8bit[..., ::-1])\n                # print(f\"Saved {datasets_root}\")\n\n                # img_8bit = (bool_mask.cpu().numpy() * 255).astype(np.uint8)\n                # datasets_root = f\"/cluster/scratch/egundogdu/datasets/bool_mask_view{v}_frame{query_idx}.png\"\n                # cv2.imwrite(datasets_root, img_8bit[..., ::-1])\n                # print(f\"Saved {datasets_root}\")\n\n\n\n        bg_points, bg_normals, bg_colors = map(\n            partial(torch.cat, dim=0), zip(*bg_geometry)\n        )\n\n        # Final downsampling\n        # doesnt use texture-based prob sampling\n        # TODO: add texture information to sample from a probability\n        if len(bg_points) > num_samples:\n            sel_idcs = np.random.choice(len(bg_points), num_samples, replace=False)\n            bg_points = bg_points[sel_idcs]\n            bg_normals = bg_normals[sel_idcs]\n            bg_colors = bg_colors[sel_idcs]\n\n        return bg_points, bg_normals, bg_colors\n    \n\n    def get_bkgd_points(\n        self,\n        num_samples: int,\n        use_kf_tstamps: bool = False,\n        stride: int = 8,\n        down_rate: int = 8,\n        min_per_frame: int = 64,\n        **kwargs,\n    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        start = 0\n        end = self.num_frames\n       
 H, W = self.rgbs.shape[2:4]  # Get height & width from rgbs shape\n\n        # Create pixel grid\n        grid = torch.stack(\n            torch.meshgrid(\n                torch.arange(0, W, dtype=torch.float32),\n                torch.arange(0, H, dtype=torch.float32),\n                indexing=\"xy\",\n            ),\n            dim=-1,\n        )  # Shape: (H, W, 2)\n\n        if use_kf_tstamps:\n            query_idcs = self.keyframe_idcs.tolist()\n        else:\n            num_query_frames = self.num_frames // stride\n            query_endpts = torch.linspace(start, end, num_query_frames + 1)\n            query_idcs = ((query_endpts[:-1] + query_endpts[1:]) / 2).long().tolist()\n\n        bg_geometry = []\n        print(f\"{query_idcs=}\")\n        \n        view_index_list = [0, 1, 2, 3]\n\n        # for v in range(self.rgbs.shape[0]):  # Iterate over views\n        for query_idx in tqdm(query_idcs, desc=f\"Loading bkgd points (view)\", leave=False):\n            for v in view_index_list:\n\n                img = self.get_image(query_idx, v).cpu().numpy()  # Shape: (H, W, 3)\n                height, width = img.shape[0], img.shape[1]\n\n                depth = self.get_depth(query_idx, v).cpu().numpy()\n                mask = self.get_mask(query_idx, v).cpu().numpy() < 0  # Shape: (H, W)\n                valid_mask = (mask * (depth > 0)).ravel()\n                # valid_mask = depth.ravel() > 0\n\n                w2c = self.w2cs[v, query_idx].cpu().numpy()\n                c2w = np.linalg.inv(w2c)\n                k = self.Ks[v, query_idx].cpu().numpy()\n                k_inv = np.linalg.inv(k)\n            \n\n                y, x = np.indices((height, width))\n                homo_pixel_coords = np.stack([x.ravel(), y.ravel(), np.ones_like(x).ravel()], axis=1).T\n                cam_coords = (k_inv @ homo_pixel_coords) * depth.ravel()\n                cam_coords = np.vstack((cam_coords, np.ones((1, cam_coords.shape[1]))))\n                world_coords = (c2w @ cam_coords)[:3].T\n                world_coords = world_coords[valid_mask]\n                rgb_colors = img.reshape(-1, 3)[valid_mask]\n\n\n                bg_geometry.append((torch.from_numpy(world_coords).to(torch.float32), torch.from_numpy(world_coords).to(torch.float32), torch.from_numpy(rgb_colors).to(torch.float32)))\n\n\n        bg_points, bg_normals, bg_colors = map(\n            partial(torch.cat, dim=0), zip(*bg_geometry)\n        )\n\n        # Final downsampling\n        # doesnt use texture-based prob sampling\n        # TODO: add texture information to sample from a probability\n        if len(bg_points) > num_samples:\n            sel_idcs = np.random.choice(len(bg_points), num_samples, replace=False)\n            bg_points = bg_points[sel_idcs]\n            bg_normals = bg_normals[sel_idcs]\n            bg_colors = bg_colors[sel_idcs]\n\n        return bg_points, bg_normals, bg_colors\n\n    #####################################\n    def load_target_tracks(\n        self, query_index: int, target_indices: list[int], view_index=0, dim: int = 1\n    ):\n        \"\"\"\n        tracks are 2d, occs and uncertainties\n        :param dim (int), default 1: dimension to stack the time axis\n        return (N, T, 4) if dim=1, (T, N, 4) if dim=0\n        \"\"\"\n        view = self.views_to_return[view_index]\n\n        q_name = self.frame_names[query_index]\n        all_tracks = []\n        for ti in target_indices:\n            t_name = self.frame_names[ti]\n            # path = 
f\"/cluster/scratch/egundogdu/datasets/panoptic_d3dgs/basketball/tracks_tapvid_som/{view}/{q_name}_{t_name}.npy\"\n            path = f\"/cluster/home/egundogdu/projects/vlg-lab/spatialtracker/shape-of-motion/panoptic_tracks/{view}/{q_name}_{t_name}.npy\"\n            tracks = np.load(path).astype(np.float32)\n            all_tracks.append(tracks)\n        return torch.from_numpy(np.stack(all_tracks, axis=dim))\n    \n    def get_tracks_3d(\n        self, num_samples: int, start: int = 0, end: int = -1, step: int = 1, **kwargs\n    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n        num_frames = self.num_frames\n        if end < 0:\n            end = num_frames + 1 + end\n        query_idcs = list(range(start, end, step))\n        target_idcs = list(range(start, end, step))\n        \n        num_per_query_frame = int(np.ceil(num_samples / len(query_idcs) / 8))\n        cur_num = 0\n        tracks_all_queries = []\n\n\n        view_index_list = [0, 1, 2, 3]\n\n        precomputed_data = {}\n        for v in view_index_list:\n            masks = torch.stack([self.get_mask(i, v).cpu() for i in target_idcs], dim=0)\n            fg_masks = (masks == 1).float()\n            depths = torch.stack([self.get_depth(i, v).cpu() for i in target_idcs], dim=0)\n            inv_Ks = torch.linalg.inv(self.Ks[v][target_idcs].cpu())\n            c2ws = torch.linalg.inv(self.w2cs[v][target_idcs].cpu())\n            \n            precomputed_data[v] = (fg_masks, depths, inv_Ks, c2ws)\n\n        for q_idx in tqdm(query_idcs, desc=f\"Loading 3d tracks points\", leave=False):\n            for v in view_index_list:\n                # # masks = torch.stack([self.get_mask(i, v) for i in target_idcs], dim=0)\n                # # fg_masks = (masks == 1).float()\n                # # depths = torch.stack([self.get_depth(i, v) for i in target_idcs], dim=0)\n                # inv_Ks = torch.linalg.inv(self.Ks[v][target_idcs])\n                # c2ws = torch.linalg.inv(self.w2cs[v][target_idcs])\n                fg_masks, depths, inv_Ks, c2ws = precomputed_data[v]\n\n                # (N, T, 4)\n                # print(q_idx, len(query_idcs), \"cur: \", cur_num)\n                tracks_2d = self.load_target_tracks(q_idx, target_idcs, v).cpu()\n                num_sel = int(\n                    min(num_per_query_frame, num_samples - cur_num, len(tracks_2d))\n                )\n                if num_sel < len(tracks_2d):\n                    sel_idcs = np.random.choice(len(tracks_2d), num_sel, replace=False)\n                    tracks_2d = tracks_2d[sel_idcs]\n                cur_num += tracks_2d.shape[0]\n\n                img = self.get_image(q_idx, v).cpu()\n                tidx = target_idcs.index(q_idx)\n                tracks_tuple = get_tracks_3d_for_query_frame(\n                    tidx, img, tracks_2d, depths, fg_masks, inv_Ks, c2ws\n                )\n                tracks_all_queries.append(tracks_tuple)\n\n        tracks_3d, colors, visibles, invisibles, confidences = map(\n            partial(torch.cat, dim=0), zip(*tracks_all_queries)\n        )\n        return tracks_3d, visibles, invisibles, confidences, colors\n\n\n    \n    def train_collate_fn(self, batch):\n        \"\"\"\n        Collate function that correctly batches data when each sample consists of multiple views.\n        \"\"\"\n\n        # Step 1: Transpose the batch to group by views\n        # If batch contains 4 views per sample, `batch` is a list of lists: [ [view_1, view_2, view_3, view_4], 
[view_1, view_2, view_3, view_4], ... ]\n        # We want to group all view_1's together, all view_2's together, etc.\n        num_views = len(batch[0])  # Assumes each sample has the same number of views\n        batch_per_view = list(zip(*batch))  # Transposes list-of-lists structure\n\n        collated_views = []\n        \n        # Step 2: Collate each view separately\n        for view_batch in batch_per_view:\n            collated = {}\n            for k in view_batch[0]:  # Iterate over keys in the dictionary\n                if k not in [\n                    \"query_tracks_2d\",\n                    \"target_ts\",\n                    \"target_w2cs\",\n                    \"target_Ks\",\n                    \"target_tracks_2d\",\n                    \"target_visibles\",\n                    \"target_track_depths\",\n                    \"target_invisibles\",\n                    \"target_confidences\",\n                ]:\n                    collated[k] = default_collate([sample[k] for sample in view_batch])\n                else:\n                    collated[k] = [sample[k] for sample in view_batch]  # Keep list format\n            collated_views.append(collated)\n\n        return collated_views  # List of collated dictionaries, one per view\n    \n\n    # def __getitem__(self, index: int, view=0):\n    #     index = np.random.randint(0, self.num_frames)\n    #     data = {\n    #         # ().\n    #         \"frame_names\": self.frame_names[index],\n    #         # ().\n    #         \"ts\": torch.tensor(index),\n    #         # (4, 4).\n    #         \"w2cs\": self.w2cs[view][index],\n    #         # (3, 3).\n    #         \"Ks\": self.Ks[view][index],\n    #         # (H, W, 3).\n    #         \"imgs\": self.get_image(index, view),\n    #         \"depths\": self.get_depth(index, view),\n    #     }\n    #     tri_mask = self.get_mask(index, view)\n    #     valid_mask = tri_mask != 0  # not fg or bg\n    #     mask = tri_mask == 1  # fg mask\n    #     data[\"masks\"] = mask.float()\n    #     data[\"valid_masks\"] = valid_mask.float()\n\n    #     # (P, 2)\n    #     query_tracks = self.load_target_tracks(index, [index], view_index=view)[:, 0, :2]\n    #     target_inds = torch.from_numpy(\n    #         np.random.choice(\n    #             self.num_frames, (self.num_targets_per_frame,), replace=False\n    #         )\n    #     )\n    #     # (N, P, 4)\n    #     target_tracks = self.load_target_tracks(index, target_inds.tolist(), view_index=view, dim=0)\n    #     data[\"query_tracks_2d\"] = query_tracks\n    #     data[\"target_ts\"] = target_inds\n    #     data[\"target_w2cs\"] = self.w2cs[view][target_inds]\n    #     data[\"target_Ks\"] = self.Ks[view][target_inds]\n    #     data[\"target_tracks_2d\"] = target_tracks[..., :2]\n    #     # (N, P).\n    #     (\n    #         data[\"target_visibles\"],\n    #         data[\"target_invisibles\"],\n    #         data[\"target_confidences\"],\n    #     ) = parse_tapir_track_info(target_tracks[..., 2], target_tracks[..., 3])\n    #     # (N, H, W)\n    #     target_depths = torch.stack([self.get_depth(i, view) for i in target_inds], dim=0)\n    #     H, W = target_depths.shape[-2:]\n    #     data[\"target_track_depths\"] = F.grid_sample(\n    #         target_depths[:, None],\n    #         normalize_coords(target_tracks[..., None, :2], H, W),\n    #         align_corners=True,\n    #         padding_mode=\"border\",\n    #     )[:, 0, :, 0]\n    #     return data\n    \n    def get_batches(self, batch_size):\n       
 num_batches = self.num_frames // batch_size  # Determine number of batches\n        train_collated_merged_data = []\n\n        for _ in range(num_batches):\n            train_collated_merged_data.append(self.__getitem_as_batch__(batch_size))\n\n        return train_collated_merged_data\n\n    def __getitem_as_batch__(self, batch_size):\n        # Sample `batch_size` frame indices and return them pre-collated across the four views.\n        # index = np.random.randint(0, self.num_frames)\n        if batch_size > self.num_frames:\n            index = np.random.choice(self.num_frames, batch_size, replace=True)  # Sample with replacement\n        else:\n            index = np.random.choice(self.num_frames, batch_size, replace=False)  # Sample without replacement\n\n        merged_data = []\n        for i in tqdm(index):\n            view_data = []\n            for view in [0, 1, 2, 3]:\n                view_data.append(self.__getitem_single_view__(i, view))\n            merged_data.append(view_data)\n\n        return self.train_collate_fn(merged_data)\n\n    def __getitem_single_view__(self, index: int, view: int):\n        # Use the frame index as given so that all four views of a sample share the same\n        # frame; resampling it here would break multi-view consistency within the batch.\n        data = {\n            # ().\n            \"frame_names\": self.frame_names[index],\n            # ().\n            \"ts\": torch.tensor(index),\n            # (4, 4).\n            \"w2cs\": self.w2cs[view][index],\n            # (3, 3).\n            \"Ks\": self.Ks[view][index],\n            # (H, W, 3).\n            \"imgs\": self.get_image(index, view),\n            \"depths\": self.get_depth(index, view),\n        }\n        tri_mask = self.get_mask(index, view)\n        valid_mask = tri_mask != 0  # not fg or bg\n        mask = tri_mask == 1  # fg mask\n        data[\"masks\"] = mask.float()\n        data[\"valid_masks\"] = valid_mask.float()\n\n        # (P, 2)\n        query_tracks = self.load_target_tracks(index, [index], view_index=view)[:, 0, :2]\n        target_inds = torch.from_numpy(\n            np.random.choice(\n                self.num_frames, (self.num_targets_per_frame,), replace=False\n            )\n        )\n        # (N, P, 4)\n        target_tracks = self.load_target_tracks(index, target_inds.tolist(), view_index=view, dim=0)\n        data[\"query_tracks_2d\"] = query_tracks\n        data[\"target_ts\"] = target_inds\n        data[\"target_w2cs\"] = self.w2cs[view][target_inds]\n        data[\"target_Ks\"] = self.Ks[view][target_inds]\n        data[\"target_tracks_2d\"] = target_tracks[..., :2]\n        # (N, P).\n        (\n            data[\"target_visibles\"],\n            data[\"target_invisibles\"],\n            data[\"target_confidences\"],\n        ) = parse_tapir_track_info(target_tracks[..., 2], target_tracks[..., 3])\n        # (N, H, W)\n        target_depths = torch.stack([self.get_depth(i, view) for i in target_inds], dim=0)\n        H, W = target_depths.shape[-2:]\n        data[\"target_track_depths\"] = F.grid_sample(\n            target_depths[:, None],\n            normalize_coords(target_tracks[..., None, :2], H, W),\n            align_corners=True,\n            padding_mode=\"border\",\n        )[:, 0, :, 0]\n\n        return data\n\n    def __getitem__(self, index: int):\n        index = np.random.randint(0, self.num_frames)\n        merged_data = []\n        for view in [0, 1, 2, 3]:\n            data = {\n                # ().\n                \"frame_names\": self.frame_names[index],\n                # ().\n                \"ts\": torch.tensor(index),\n                # (4, 4).\n                \"w2cs\": 
self.w2cs[view][index],\n                # (3, 3).\n                \"Ks\": self.Ks[view][index],\n                # (H, W, 3).\n                \"imgs\": self.get_image(index, view),\n                \"depths\": self.get_depth(index, view),\n            }\n            tri_mask = self.get_mask(index, view)\n            valid_mask = tri_mask != 0  # not fg or bg\n            mask = tri_mask == 1  # fg mask\n            data[\"masks\"] = mask.float()\n            data[\"valid_masks\"] = valid_mask.float()\n\n            # (P, 2)\n            query_tracks = self.load_target_tracks(index, [index], view_index=view)[:, 0, :2]\n            target_inds = torch.from_numpy(\n                np.random.choice(\n                    self.num_frames, (self.num_targets_per_frame,), replace=False\n                )\n            )\n            # (N, P, 4)\n            target_tracks = self.load_target_tracks(index, target_inds.tolist(), view_index=view, dim=0)\n            data[\"query_tracks_2d\"] = query_tracks\n            data[\"target_ts\"] = target_inds\n            data[\"target_w2cs\"] = self.w2cs[view][target_inds]\n            data[\"target_Ks\"] = self.Ks[view][target_inds]\n            data[\"target_tracks_2d\"] = target_tracks[..., :2]\n            # (N, P).\n            (\n                data[\"target_visibles\"],\n                data[\"target_invisibles\"],\n                data[\"target_confidences\"],\n            ) = parse_tapir_track_info(target_tracks[..., 2], target_tracks[..., 3])\n            # (N, H, W)\n            target_depths = torch.stack([self.get_depth(i, view) for i in target_inds], dim=0)\n            H, W = target_depths.shape[-2:]\n            data[\"target_track_depths\"] = F.grid_sample(\n                target_depths[:, None],\n                normalize_coords(target_tracks[..., None, :2], H, W),\n                align_corners=True,\n                padding_mode=\"border\",\n            )[:, 0, :, 0] \n\n            merged_data.append(data)\n            \n        return merged_data\n\n\ndef compute_scene_norm(\n    X: torch.Tensor, w2cs: torch.Tensor\n) -> tuple[float, torch.Tensor]:\n    \"\"\"\n    :param X: [N*T, 3]\n    # :param w2cs: [N, 4, 4]\n    :param w2cs: [n_views, N, 4, 4]\n    \"\"\"\n    X = X.reshape(-1, 3)\n    scene_center = X.mean(dim=0)\n    X = X - scene_center[None]\n    min_scale = X.quantile(0.05, dim=0)\n    max_scale = X.quantile(0.95, dim=0)\n    scale = (max_scale - min_scale).max().item() / 2.0\n    \n    original_up = -F.normalize(w2cs[:, :, 1, :3].mean(dim=(0,1)), dim=-1)\n    target_up = original_up.new_tensor([0.0, 0.0, 1.0])\n\n    R = roma.rotvec_to_rotmat(\n        F.normalize(original_up.cross(target_up), dim=-1)\n        * original_up.dot(target_up).acos_()\n    )\n    transfm = rt_to_mat4(R, torch.einsum(\"ij,j->i\", -R, scene_center))\n    return scale, transfm\n\n\n\n\n\n\n# import rerun as rr\nif __name__ == \"__main__\":\n#     rr.init(\"3dpt\", recording_id=\"v0.1\")\n#     rr.connect_tcp(\"0.0.0.0:9876\")\n#     rr.set_time_seconds(\"frame\", 0)\n#     rr.log(\"world/xyz\", rr.Arrows3D(vectors=[[1, 0, 0], [0, 2, 0], [0, 0, 3]],\n#                                             colors=[[255, 0, 0], [0, 255, 0], [0, 0, 255]]))\n\n    d = PanopticStudioDatasetSoM(\"\", \"\", camera_type=\"\")\n    batch = d.__getitem_as_batch__(150)\n    import ipdb\n    ipdb.set_trace()\n\n\n    # print(d[\"imgs\"])\n\n#     # Get background points\n#     points, normals, colors = d.get_bkgd_points(num_samples=100_000)\n#     
print(points.dtype)\n\n#     rr.set_time_seconds(\"frame\", 0)\n#     rr.log(f\"world/points/final_background\", rr.Points3D(positions=points, colors=colors * 255.0))\n#     print(\"Done.\")\n\n#     # # Get foreground points\n#     points, normals, colors = d.get_foreground_points(num_samples=40_000)\n#     rr.set_time_seconds(\"frame\", 0)\n#     rr.log(f\"world/points/final_foreground\", rr.Points3D(positions=points, colors=colors * 255.0))\n#     print(\"Done.\")\n\n#     # tracks_2d = d.load_target_tracks(0, [0,1,2,3,4], 1)    \n#     # print(tracks_2d.dtype)\n#     # # tracks_3d, visibles, invisibles, confidences, colors = d.get_tracks_3d(40000)\n#     # # colors = (colors * 255.0)\n#     # # print(\n#     # #     f\"{tracks_3d.shape=} {visibles.shape=} \"\n#     # #     f\"{invisibles.shape=} {confidences.shape=} \"\n#     # #     f\"{colors.shape=}\"\n#     # # )\n\n#     # # # Loop through 150 frames and log the corresponding points\n#     # # num_frames = tracks_3d.shape[1]  # 150 frames\n#     # # for frame_idx in range(num_frames):\n#     # #     rr.set_time_seconds(\"frame\", frame_idx)\n\n#     # #     # Get the 3D positions for the current frame\n#     # #     frame_tracks = tracks_3d[:, frame_idx, :]  # Shape: (35418, 3)\n#     # #     frame_visibles = visibles[:, frame_idx]  # Visibility mask\n\n#     # #     # Filter only visible points\n#     # #     visible_tracks = frame_tracks[frame_visibles > 0]\n#     # #     visible_colors = colors[frame_visibles > 0]\n\n#     # #     rr.set_time_seconds(\"frame\", frame_idx / 30)\n#     # #     rr.log(f\"world/tracks_3d\", rr.Points3D(positions=visible_tracks, colors=visible_colors))\n"
  },
  {
    "path": "mvtracker/models/core/shape-of-motion/flow3d/data/utils.py",
    "content": "from typing import List, Optional, Tuple, TypedDict\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn.modules.utils import _pair, _quadruple\n\nUINT16_MAX = 65535\n\n\nclass SceneNormDict(TypedDict):\n    scale: float\n    transfm: torch.Tensor\n\n\ndef to_device(batch, device):\n    if isinstance(batch, dict):\n        return {k: to_device(v, device) for k, v in batch.items()}\n    if isinstance(batch, (list, tuple)):\n        return [to_device(v, device) for v in batch]\n    if isinstance(batch, torch.Tensor):\n        return batch.to(device)\n    return batch\n\n\ndef normalize_coords(coords, h, w):\n    assert coords.shape[-1] == 2\n    return coords / torch.tensor([w - 1.0, h - 1.0], device=coords.device) * 2 - 1.0\n\n\ndef postprocess_occlusions(occlusions, expected_dist):\n    \"\"\"Postprocess occlusions to boolean visible flag.\n\n    Args:\n      occlusions: [-inf, inf], np.float32\n      expected_dist:, [-inf, inf], np.float32\n\n    Returns:\n      visibles: bool\n    \"\"\"\n\n    def sigmoid(x):\n        if x.dtype == np.ndarray:\n            return 1 / (1 + np.exp(-x))\n        else:\n            return torch.sigmoid(x)\n\n    visibles = (1 - sigmoid(occlusions)) * (1 - sigmoid(expected_dist)) > 0.5\n    return visibles\n\n\ndef parse_tapir_track_info(occlusions, expected_dist):\n    \"\"\"\n    return:\n        valid_visible: mask of visible & confident points\n        valid_invisible: mask of invisible & confident points\n        confidence: clamped confidence scores (all < 0.5 -> 0)\n    \"\"\"\n    visiblility = 1 - F.sigmoid(occlusions)\n    confidence = 1 - F.sigmoid(expected_dist)\n    valid_visible = visiblility * confidence > 0.5\n    valid_invisible = (1 - visiblility) * confidence > 0.5\n    # set all confidence < 0.5 to 0\n    confidence = confidence * (valid_visible | valid_invisible).float()\n    return valid_visible, valid_invisible, confidence\n\n\ndef get_tracks_3d_for_query_frame(\n    query_index: int,\n    query_img: torch.Tensor,\n    tracks_2d: torch.Tensor,\n    depths: torch.Tensor,\n    masks: torch.Tensor,\n    inv_Ks: torch.Tensor,\n    c2ws: torch.Tensor,\n):\n    \"\"\"\n    :param query_index (int)\n    :param query_img [H, W, 3]\n    :param tracks_2d [N, T, 4]\n    :param depths [T, H, W]\n    :param masks [T, H, W]\n    :param inv_Ks [T, 3, 3]\n    :param c2ws [T, 4, 4]\n    returns (\n        tracks_3d [N, T, 3]\n        track_colors [N, 3]\n        visibles [N, T]\n        invisibles [N, T]\n        confidences [N, T]\n    )\n    \"\"\"\n    T, H, W = depths.shape\n    query_img = query_img[None].permute(0, 3, 1, 2)  # (1, 3, H, W)\n    tracks_2d = tracks_2d.swapaxes(0, 1)  # (T, N, 4)\n    tracks_2d, occs, dists = (\n        tracks_2d[..., :2],\n        tracks_2d[..., 2],\n        tracks_2d[..., 3],\n    )\n    # visibles = postprocess_occlusions(occs, dists)\n    # (T, N), (T, N), (T, N)\n    visibles, invisibles, confidences = parse_tapir_track_info(occs, dists)\n    # Unproject 2D tracks to 3D.\n    # (T, 1, H, W), (T, 1, N, 2) -> (T, 1, 1, N)\n    track_depths = F.grid_sample(\n        depths[:, None],\n        normalize_coords(tracks_2d[:, None], H, W),\n        align_corners=True,\n        padding_mode=\"border\",\n    )[:, 0, 0]\n    tracks_3d = (\n        torch.einsum(\n            \"nij,npj->npi\",\n            inv_Ks,\n            F.pad(tracks_2d, (0, 1), value=1.0),\n        )\n        * track_depths[..., None]\n    )\n    tracks_3d = 
torch.einsum(\"nij,npj->npi\", c2ws, F.pad(tracks_3d, (0, 1), value=1.0))[\n        ..., :3\n    ]\n    # Filter out out-of-mask tracks.\n    # (T, 1, H, W), (T, 1, N, 2) -> (T, 1, 1, N)\n    is_in_masks = (\n        F.grid_sample(\n            masks[:, None],\n            normalize_coords(tracks_2d[:, None], H, W),\n            align_corners=True,\n        )[:, 0, 0]\n        == 1\n    )\n    visibles *= is_in_masks\n    invisibles *= is_in_masks\n    confidences *= is_in_masks.float()\n\n    # valid if in the fg mask at least 40% of the time\n    # in_mask_counts = is_in_masks.sum(0)\n    # t = 0.25\n    # thresh = min(t * T, in_mask_counts.float().quantile(t).item())\n    # valid = in_mask_counts > thresh\n    valid = is_in_masks[query_index]\n    # valid if visible 5% of the time\n    visible_counts = visibles.sum(0)\n    valid = valid & (\n        visible_counts\n        >= min(\n            int(0.05 * T),\n            visible_counts.float().quantile(0.1).item(),\n        )\n    )\n\n    # Get track's color from the query frame.\n    # (1, 3, H, W), (1, 1, N, 2) -> (1, 3, 1, N) -> (N, 3)\n    track_colors = F.grid_sample(\n        query_img,\n        normalize_coords(tracks_2d[query_index : query_index + 1, None], H, W),\n        align_corners=True,\n        padding_mode=\"border\",\n    )[0, :, 0].T\n    return (\n        tracks_3d[:, valid].swapdims(0, 1),\n        track_colors[valid],\n        visibles[:, valid].swapdims(0, 1),\n        invisibles[:, valid].swapdims(0, 1),\n        confidences[:, valid].swapdims(0, 1),\n    )\n\n\ndef _get_padding(x, k, stride, padding, same: bool):\n    if same:\n        ih, iw = x.size()[2:]\n        if ih % stride[0] == 0:\n            ph = max(k[0] - stride[0], 0)\n        else:\n            ph = max(k[0] - (ih % stride[0]), 0)\n        if iw % stride[1] == 0:\n            pw = max(k[1] - stride[1], 0)\n        else:\n            pw = max(k[1] - (iw % stride[1]), 0)\n        pl = pw // 2\n        pr = pw - pl\n        pt = ph // 2\n        pb = ph - pt\n        padding = (pl, pr, pt, pb)\n    else:\n        padding = padding\n    return padding\n\n\ndef median_filter_2d(x, kernel_size=3, stride=1, padding=1, same: bool = True):\n    \"\"\"\n    :param x [B, C, H, W]\n    \"\"\"\n    k = _pair(kernel_size)\n    stride = _pair(stride)  # convert to tuple\n    padding = _quadruple(padding)  # convert to l, r, t, b\n    # using existing pytorch functions and tensor ops so that we get autograd,\n    # would likely be more efficient to implement from scratch at C/Cuda level\n    x = F.pad(x, _get_padding(x, k, stride, padding, same), mode=\"reflect\")\n    x = x.unfold(2, k[0], stride[0]).unfold(3, k[1], stride[1])\n    x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0]\n    return x\n\n\ndef masked_median_blur(image, mask, kernel_size=11):\n    \"\"\"\n    Args:\n        image: [B, C, H, W]\n        mask: [B, C, H, W]\n        kernel_size: int\n    \"\"\"\n    assert image.shape == mask.shape\n    if not isinstance(image, torch.Tensor):\n        raise TypeError(f\"Input type is not a torch.Tensor. Got {type(image)}\")\n\n    if not len(image.shape) == 4:\n        raise ValueError(f\"Invalid input shape, we expect BxCxHxW. 
Got: {image.shape}\")\n\n    padding: Tuple[int, int] = _compute_zero_padding((kernel_size, kernel_size))\n\n    # prepare kernel\n    kernel: torch.Tensor = get_binary_kernel2d((kernel_size, kernel_size)).to(image)\n    b, c, h, w = image.shape\n\n    # map the local window to single vector\n    features: torch.Tensor = F.conv2d(\n        image.reshape(b * c, 1, h, w), kernel, padding=padding, stride=1\n    )\n    masks: torch.Tensor = F.conv2d(\n        mask.reshape(b * c, 1, h, w), kernel, padding=padding, stride=1\n    )\n    features = features.view(b, c, -1, h, w).permute(\n        0, 1, 3, 4, 2\n    )  # BxCxxHxWx(K_h * K_w)\n    min_value, max_value = features.min(), features.max()\n    masks = masks.view(b, c, -1, h, w).permute(0, 1, 3, 4, 2)  # BxCxHxWx(K_h * K_w)\n    index_invalid = (1 - masks).nonzero(as_tuple=True)\n    index_b, index_c, index_h, index_w, index_k = index_invalid\n    features[(index_b[::2], index_c[::2], index_h[::2], index_w[::2], index_k[::2])] = (\n        min_value\n    )\n    features[\n        (index_b[1::2], index_c[1::2], index_h[1::2], index_w[1::2], index_k[1::2])\n    ] = max_value\n    # compute the median along the feature axis\n    median: torch.Tensor = torch.median(features, dim=-1)[0]\n\n    return median\n\n\ndef _compute_zero_padding(kernel_size: Tuple[int, int]) -> Tuple[int, int]:\n    r\"\"\"Utility function that computes zero padding tuple.\"\"\"\n    computed: List[int] = [(k - 1) // 2 for k in kernel_size]\n    return computed[0], computed[1]\n\n\ndef get_binary_kernel2d(\n    window_size: tuple[int, int] | int,\n    *,\n    device: Optional[torch.device] = None,\n    dtype: torch.dtype = torch.float32,\n) -> torch.Tensor:\n    \"\"\"\n    from kornia\n    Create a binary kernel to extract the patches.\n    If the window size is HxW will create a (H*W)x1xHxW kernel.\n    \"\"\"\n    ky, kx = _unpack_2d_ks(window_size)\n\n    window_range = kx * ky\n\n    kernel = torch.zeros((window_range, window_range), device=device, dtype=dtype)\n    idx = torch.arange(window_range, device=device)\n    kernel[idx, idx] += 1.0\n    return kernel.view(window_range, 1, ky, kx)\n\n\ndef _unpack_2d_ks(kernel_size: tuple[int, int] | int) -> tuple[int, int]:\n    if isinstance(kernel_size, int):\n        ky = kx = kernel_size\n    else:\n        assert len(kernel_size) == 2, \"2D Kernel size should have a length of 2.\"\n        ky, kx = kernel_size\n\n    ky = int(ky)\n    kx = int(kx)\n\n    return (ky, kx)\n\n\n## Functions from GaussianShader.\ndef ndc_2_cam(ndc_xyz, intrinsic, W, H):\n    inv_scale = torch.tensor([[W - 1, H - 1]], device=ndc_xyz.device)\n    cam_z = ndc_xyz[..., 2:3]\n    cam_xy = ndc_xyz[..., :2] * inv_scale * cam_z\n    cam_xyz = torch.cat([cam_xy, cam_z], dim=-1)\n    cam_xyz = cam_xyz @ torch.inverse(intrinsic[0, ...].t())\n    return cam_xyz\n\n\ndef depth2point_cam(sampled_depth, ref_intrinsic):\n    B, N, C, H, W = sampled_depth.shape\n    valid_z = sampled_depth\n    valid_x = torch.arange(W, dtype=torch.float32, device=sampled_depth.device) / (\n        W - 1\n    )\n    valid_y = torch.arange(H, dtype=torch.float32, device=sampled_depth.device) / (\n        H - 1\n    )\n    valid_y, valid_x = torch.meshgrid(valid_y, valid_x, indexing=\"ij\")\n    # B,N,H,W\n    valid_x = valid_x[None, None, None, ...].expand(B, N, C, -1, -1)\n    valid_y = valid_y[None, None, None, ...].expand(B, N, C, -1, -1)\n    ndc_xyz = torch.stack([valid_x, valid_y, valid_z], dim=-1).view(\n        B, N, C, H, W, 3\n    )  # 1, 1, 5, 512, 640, 3\n    
cam_xyz = ndc_2_cam(ndc_xyz, ref_intrinsic, W, H)  # 1, 1, 5, 512, 640, 3\n    return ndc_xyz, cam_xyz\n\n\ndef depth2point_world(depth_image, intrinsic_matrix, extrinsic_matrix):\n    # depth_image: (H, W), intrinsic_matrix: (3, 3), extrinsic_matrix: (4, 4)\n    _, xyz_cam = depth2point_cam(\n        depth_image[None, None, None, ...], intrinsic_matrix[None, ...]\n    )\n    xyz_cam = xyz_cam.reshape(-1, 3)\n    xyz_world = torch.cat(\n        [xyz_cam, torch.ones_like(xyz_cam[..., 0:1])], dim=-1\n    ) @ torch.inverse(extrinsic_matrix).transpose(0, 1)\n    xyz_world = xyz_world[..., :3]\n\n    return xyz_world\n\n\ndef depth_pcd2normal(xyz):\n    hd, wd, _ = xyz.shape\n    bottom_point = xyz[..., 2:hd, 1 : wd - 1, :]\n    top_point = xyz[..., 0 : hd - 2, 1 : wd - 1, :]\n    right_point = xyz[..., 1 : hd - 1, 2:wd, :]\n    left_point = xyz[..., 1 : hd - 1, 0 : wd - 2, :]\n    left_to_right = right_point - left_point\n    bottom_to_top = top_point - bottom_point\n    xyz_normal = torch.cross(left_to_right, bottom_to_top, dim=-1)\n    xyz_normal = torch.nn.functional.normalize(xyz_normal, p=2, dim=-1)\n    xyz_normal = torch.nn.functional.pad(\n        xyz_normal.permute(2, 0, 1), (1, 1, 1, 1), mode=\"constant\"\n    ).permute(1, 2, 0)\n    return xyz_normal\n\n\ndef normal_from_depth_image(depth, intrinsic_matrix, extrinsic_matrix):\n    # depth: (H, W), intrinsic_matrix: (3, 3), extrinsic_matrix: (4, 4)\n    # xyz_normal: (H, W, 3)\n    xyz_world = depth2point_world(depth, intrinsic_matrix, extrinsic_matrix)  # (HxW, 3)\n    xyz_world = xyz_world.reshape(*depth.shape, 3)\n    xyz_normal = depth_pcd2normal(xyz_world)\n\n    return xyz_normal\n"
  },
  {
    "path": "mvtracker/models/core/shape-of-motion/flow3d/init_utils.py",
    "content": "import time\nfrom typing import Literal\n\nimport cupy as cp\nimport imageio.v3 as iio\nimport numpy as np\n\n# from pytorch3d.ops import sample_farthest_points\nimport roma\nimport torch\nimport torch.nn.functional as F\nfrom cuml import HDBSCAN, KMeans\nfrom loguru import logger as guru\nfrom matplotlib.pyplot import get_cmap\nfrom tqdm import tqdm\nfrom viser import ViserServer\n\nfrom flow3d.loss_utils import (\n    compute_accel_loss,\n    compute_se3_smoothness_loss,\n    compute_z_acc_loss,\n    get_weights_for_procrustes,\n    knn,\n    masked_l1_loss,\n)\nfrom flow3d.params import GaussianParams, MotionBases\nfrom flow3d.tensor_dataclass import StaticObservations, TrackObservations\nfrom flow3d.transforms import cont_6d_to_rmat, rt_to_mat4, solve_procrustes\nfrom flow3d.vis.utils import draw_keypoints_video, get_server, project_2d_tracks\n\n\ndef init_fg_from_tracks_3d(\n    cano_t: int, tracks_3d: TrackObservations, motion_coefs: torch.Tensor\n) -> GaussianParams:\n    \"\"\"\n    using dataclasses individual tensors so we know they're consistent\n    and are always masked/filtered together\n    \"\"\"\n    num_fg = tracks_3d.xyz.shape[0]\n\n    # Initialize gaussian colors.\n    colors = torch.logit(tracks_3d.colors)\n    # Initialize gaussian scales: find the average of the three nearest\n    # neighbors in the first frame for each point and use that as the\n    # scale.\n    dists, _ = knn(tracks_3d.xyz[:, cano_t], 3)\n    dists = torch.from_numpy(dists)\n    scales = dists.mean(dim=-1, keepdim=True)\n    scales = scales.clamp(torch.quantile(scales, 0.05), torch.quantile(scales, 0.95))\n    scales = torch.log(scales.repeat(1, 3))\n    # Initialize gaussian means.\n    means = tracks_3d.xyz[:, cano_t]\n    # Initialize gaussian orientations as random.\n    quats = torch.rand(num_fg, 4)\n    # Initialize gaussian opacities.\n    opacities = torch.logit(torch.full((num_fg,), 0.7))\n    gaussians = GaussianParams(means, quats, scales, colors, opacities, motion_coefs)\n    return gaussians\n\n\ndef init_bg(\n    points: StaticObservations,\n) -> GaussianParams:\n    \"\"\"\n    using dataclasses instead of individual tensors so we know they're consistent\n    and are always masked/filtered together\n    \"\"\"\n    num_init_bg_gaussians = points.xyz.shape[0]\n    bg_scene_center = points.xyz.mean(0)\n    bg_points_centered = points.xyz - bg_scene_center\n    bg_min_scale = bg_points_centered.quantile(0.05, dim=0)\n    bg_max_scale = bg_points_centered.quantile(0.95, dim=0)\n    bg_scene_scale = torch.max(bg_max_scale - bg_min_scale).item() / 2.0\n    bkdg_colors = torch.logit(points.colors)\n\n    # Initialize gaussian scales: find the average of the three nearest\n    # neighbors in the first frame for each point and use that as the\n    # scale.\n    dists, _ = knn(points.xyz, 3)\n    dists = torch.from_numpy(dists)\n    bg_scales = dists.mean(dim=-1, keepdim=True)\n    bkdg_scales = torch.log(bg_scales.repeat(1, 3))\n\n    bg_means = points.xyz\n\n    # Initialize gaussian orientations by normals.\n    local_normals = points.normals.new_tensor([[0.0, 0.0, 1.0]]).expand_as(\n        points.normals\n    )\n\n    angles = torch.clamp((local_normals * points.normals).sum(-1, keepdim=True), -1.0, 1.0).acos_()\n\n    # bg_quats = roma.rotvec_to_unitquat(\n    #     F.normalize(local_normals.cross(points.normals), dim=-1)\n    #     * (local_normals * points.normals).sum(-1, keepdim=True).acos_()\n    # ).roll(1, dims=-1)\n    bg_quats = roma.rotvec_to_unitquat(\n       
 F.normalize(local_normals.cross(points.normals), dim=-1)\n        * angles\n    ).roll(1, dims=-1)\n    \n    bg_opacities = torch.logit(torch.full((num_init_bg_gaussians,), 0.7))\n    gaussians = GaussianParams(\n        bg_means,\n        bg_quats,\n        bkdg_scales,\n        bkdg_colors,\n        bg_opacities,\n        scene_center=bg_scene_center,\n        scene_scale=bg_scene_scale,\n    )\n    return gaussians\n\n\ndef init_motion_params_with_procrustes(\n    tracks_3d: TrackObservations,\n    num_bases: int,\n    rot_type: Literal[\"quat\", \"6d\"],\n    cano_t: int,\n    cluster_init_method: str = \"kmeans\",\n    min_mean_weight: float = 0.1,\n    vis: bool = False,\n    port: int | None = None,\n) -> tuple[MotionBases, torch.Tensor, TrackObservations]:\n    device = tracks_3d.xyz.device\n    num_frames = tracks_3d.xyz.shape[1]\n    # sample centers and get initial se3 motion bases by solving procrustes\n    means_cano = tracks_3d.xyz[:, cano_t].clone()  # [num_gaussians, 3]\n\n    # remove outliers\n    scene_center = means_cano.median(dim=0).values\n    print(f\"{scene_center=}\")\n    dists = torch.norm(means_cano - scene_center, dim=-1)\n    dists_th = torch.quantile(dists, 0.95)\n    valid_mask = dists < dists_th\n\n    # remove tracks that are not visible in any frame\n    valid_mask = valid_mask & tracks_3d.visibles.any(dim=1)\n    print(f\"{valid_mask.sum()=}\")\n\n    tracks_3d = tracks_3d.filter_valid(valid_mask)\n\n    if vis and port is not None:\n        server = get_server(port)\n        try:\n            pts = tracks_3d.xyz.cpu().numpy()\n            clrs = tracks_3d.colors.cpu().numpy()\n            while True:\n                for t in range(num_frames):\n                    server.scene.add_point_cloud(\"points\", pts[:, t], clrs)\n                    time.sleep(0.3)\n        except KeyboardInterrupt:\n            pass\n\n    means_cano = means_cano[valid_mask]\n\n    sampled_centers, num_bases, labels = sample_initial_bases_centers(\n        cluster_init_method, cano_t, tracks_3d, num_bases\n    )\n\n    # assign each point to the label to compute the cluster weight\n    ids, counts = labels.unique(return_counts=True)\n    ids = ids[counts > 100]\n    num_bases = len(ids)\n    sampled_centers = sampled_centers[:, ids]\n    print(f\"{num_bases=} {sampled_centers.shape=}\")\n\n    # compute basis weights from the distance to the cluster centers\n    dists2centers = torch.norm(means_cano[:, None] - sampled_centers, dim=-1)\n    motion_coefs = 10 * torch.exp(-dists2centers)\n\n    init_rots, init_ts = [], []\n\n    if rot_type == \"quat\":\n        id_rot = torch.tensor([1.0, 0.0, 0.0, 0.0], device=device)\n        rot_dim = 4\n    else:\n        id_rot = torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0], device=device)\n        rot_dim = 6\n\n    init_rots = id_rot.reshape(1, 1, rot_dim).repeat(num_bases, num_frames, 1)\n    init_ts = torch.zeros(num_bases, num_frames, 3, device=device)\n    errs_before = np.full((num_bases, num_frames), -1.0)\n    errs_after = np.full((num_bases, num_frames), -1.0)\n\n    tgt_ts = list(range(cano_t - 1, -1, -1)) + list(range(cano_t, num_frames))\n    print(f\"{tgt_ts=}\")\n    skipped_ts = {}\n    for n, cluster_id in enumerate(ids):\n        mask_in_cluster = labels == cluster_id\n        cluster = tracks_3d.xyz[mask_in_cluster].transpose(\n            0, 1\n        )  # [num_frames, n_pts, 3]\n        visibilities = tracks_3d.visibles[mask_in_cluster].swapaxes(\n            0, 1\n        )  # [num_frames, n_pts]\n        
confidences = tracks_3d.confidences[mask_in_cluster].swapaxes(\n            0, 1\n        )  # [num_frames, n_pts]\n        weights = get_weights_for_procrustes(cluster, visibilities)\n        prev_t = cano_t\n        cluster_skip_ts = []\n        for cur_t in tgt_ts:\n            # compute pairwise transform from cano_t\n            procrustes_weights = (\n                weights[cano_t]\n                * weights[cur_t]\n                * (confidences[cano_t] + confidences[cur_t])\n                / 2\n            )\n            if procrustes_weights.sum() < min_mean_weight * num_frames:\n                init_rots[n, cur_t] = init_rots[n, prev_t]\n                init_ts[n, cur_t] = init_ts[n, prev_t]\n                cluster_skip_ts.append(cur_t)\n            else:\n                se3, (err, err_before) = solve_procrustes(\n                    cluster[cano_t],\n                    cluster[cur_t],\n                    weights=procrustes_weights,\n                    enforce_se3=True,\n                    rot_type=rot_type,\n                )\n                init_rot, init_t, _ = se3\n                assert init_rot.shape[-1] == rot_dim\n                # double cover\n                if rot_type == \"quat\" and torch.linalg.norm(\n                    init_rot - init_rots[n][prev_t]\n                ) > torch.linalg.norm(-init_rot - init_rots[n][prev_t]):\n                    init_rot = -init_rot\n                init_rots[n, cur_t] = init_rot\n                init_ts[n, cur_t] = init_t\n                if np.isnan(err):  # a direct `== np.nan` comparison is always False\n                    print(f\"{cur_t=} {err=}\")\n                    print(f\"{procrustes_weights.isnan().sum()=}\")\n                if np.isnan(err_before):\n                    print(f\"{cur_t=} {err_before=}\")\n                    print(f\"{procrustes_weights.isnan().sum()=}\")\n                errs_after[n, cur_t] = err\n                errs_before[n, cur_t] = err_before\n            prev_t = cur_t\n        skipped_ts[cluster_id.item()] = cluster_skip_ts\n\n    guru.info(f\"{skipped_ts=}\")\n    guru.info(\n        \"procrustes init median error: {:.5f} => {:.5f}\".format(\n            np.median(errs_before[errs_before > 0]),\n            np.median(errs_after[errs_after > 0]),\n        )\n    )\n    guru.info(\n        \"procrustes init mean error: {:.5f} => {:.5f}\".format(\n            np.mean(errs_before[errs_before > 0]), np.mean(errs_after[errs_after > 0])\n        )\n    )\n    guru.info(f\"{init_rots.shape=}, {init_ts.shape=}, {motion_coefs.shape=}\")\n\n    if vis:\n        server = get_server(port)\n        center_idcs = torch.argmin(dists2centers, dim=0)\n        print(f\"{dists2centers.shape=} {center_idcs.shape=}\")\n        vis_se3_init_3d(server, init_rots, init_ts, means_cano[center_idcs])\n        vis_tracks_3d(server, tracks_3d.xyz[center_idcs].numpy(), name=\"center_tracks\")\n        import ipdb\n\n        ipdb.set_trace()\n\n    bases = MotionBases(init_rots, init_ts)\n    return bases, motion_coefs, tracks_3d\n\n\ndef run_initial_optim(\n    fg: GaussianParams,\n    bases: MotionBases,\n    tracks_3d: TrackObservations,\n    Ks: torch.Tensor,\n    w2cs: torch.Tensor,\n    num_iters: int = 1000,\n    use_depth_range_loss: bool = False,\n):\n    \"\"\"\n    :param motion_rots: [num_bases, num_frames, 4|6]\n    :param motion_transls: [num_bases, num_frames, 3]\n    :param motion_coefs: [num_bases, num_frames]\n    :param means: [num_gaussians, 3]\n    \"\"\"\n    optimizer = torch.optim.Adam(\n        [\n            {\"params\": 
bases.params[\"rots\"], \"lr\": 1e-2},\n            {\"params\": bases.params[\"transls\"], \"lr\": 3e-2},\n            {\"params\": fg.params[\"motion_coefs\"], \"lr\": 1e-2},\n            {\"params\": fg.params[\"means\"], \"lr\": 1e-3},\n        ],\n    )\n    scheduler = torch.optim.lr_scheduler.ExponentialLR(\n        optimizer, gamma=0.1 ** (1 / num_iters)\n    )\n    G = fg.params.means.shape[0]\n    num_frames = bases.num_frames\n    device = bases.params[\"rots\"].device\n\n    w_smooth_func = lambda i, min_v, max_v, th: (\n        min_v if i <= th else (max_v - min_v) * (i - th) / (num_iters - th) + min_v\n    )\n\n    gt_2d, gt_depth = project_2d_tracks(\n        tracks_3d.xyz.swapaxes(0, 1), Ks, w2cs, return_depth=True\n    )\n    # (G, T, 2)\n    gt_2d = gt_2d.swapaxes(0, 1)\n    # (G, T)\n    gt_depth = gt_depth.swapaxes(0, 1)\n\n    ts = torch.arange(0, num_frames, device=device)\n    ts_clamped = torch.clamp(ts, min=1, max=num_frames - 2)\n    ts_neighbors = torch.cat((ts_clamped - 1, ts_clamped, ts_clamped + 1))  # i (3B,)\n\n    pbar = tqdm(range(0, num_iters))\n    for i in pbar:\n        coefs = fg.get_coefs()\n        transfms = bases.compute_transforms(ts, coefs)\n        positions = torch.einsum(\n            \"pnij,pj->pni\",\n            transfms,\n            F.pad(fg.params[\"means\"], (0, 1), value=1.0),\n        )\n\n        loss = 0.0\n        track_3d_loss = masked_l1_loss(\n            positions,\n            tracks_3d.xyz,\n            (tracks_3d.visibles.float() * tracks_3d.confidences)[..., None],\n        )\n        loss += track_3d_loss * 1.0\n\n        pred_2d, pred_depth = project_2d_tracks(\n            positions.swapaxes(0, 1), Ks, w2cs, return_depth=True\n        )\n        pred_2d = pred_2d.swapaxes(0, 1)\n        pred_depth = pred_depth.swapaxes(0, 1)\n\n        loss_2d = (\n            masked_l1_loss(\n                pred_2d,\n                gt_2d,\n                (tracks_3d.invisibles.float() * tracks_3d.confidences)[..., None],\n                quantile=0.95,\n            )\n            / Ks[0, 0, 0]\n        )\n        loss += 0.5 * loss_2d\n\n        if use_depth_range_loss:\n            near_depths = torch.quantile(gt_depth, 0.0, dim=0, keepdim=True)\n            far_depths = torch.quantile(gt_depth, 0.98, dim=0, keepdim=True)\n            loss_depth_in_range = 0\n            if (pred_depth < near_depths).any():\n                loss_depth_in_range += (near_depths - pred_depth)[\n                    pred_depth < near_depths\n                ].mean()\n            if (pred_depth > far_depths).any():\n                loss_depth_in_range += (pred_depth - far_depths)[\n                    pred_depth > far_depths\n                ].mean()\n\n            loss += loss_depth_in_range * w_smooth_func(i, 0.05, 0.5, 400)\n\n        motion_coef_sparse_loss = 1 - (coefs**2).sum(dim=-1).mean()\n        loss += motion_coef_sparse_loss * 0.01\n\n        # motion basis should be smooth.\n        w_smooth = w_smooth_func(i, 0.01, 0.1, 400)\n        small_acc_loss = compute_se3_smoothness_loss(\n            bases.params[\"rots\"], bases.params[\"transls\"]\n        )\n        loss += small_acc_loss * w_smooth\n\n        small_acc_loss_tracks = compute_accel_loss(positions)\n        loss += small_acc_loss_tracks * w_smooth * 0.5\n\n        transfms_nbs = bases.compute_transforms(ts_neighbors, coefs)\n        means_nbs = torch.einsum(\n            \"pnij,pj->pni\", transfms_nbs, F.pad(fg.params[\"means\"], (0, 1), value=1.0)\n        )  # (G, 3n, 3)\n        
means_nbs = means_nbs.reshape(means_nbs.shape[0], 3, -1, 3)  # [G, 3, n, 3]\n        z_accel_loss = compute_z_acc_loss(means_nbs, w2cs)\n        loss += z_accel_loss * 0.1\n\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        scheduler.step()\n\n        pbar.set_description(\n            f\"{loss.item():.3f} \"\n            f\"{track_3d_loss.item():.3f} \"\n            f\"{motion_coef_sparse_loss.item():.3f} \"\n            f\"{small_acc_loss.item():.3f} \"\n            f\"{small_acc_loss_tracks.item():.3f} \"\n            f\"{z_accel_loss.item():.3f} \"\n        )\n\n\ndef random_quats(N: int) -> torch.Tensor:\n    u = torch.rand(N, 1)\n    v = torch.rand(N, 1)\n    w = torch.rand(N, 1)\n    quats = torch.cat(\n        [\n            torch.sqrt(1.0 - u) * torch.sin(2.0 * np.pi * v),\n            torch.sqrt(1.0 - u) * torch.cos(2.0 * np.pi * v),\n            torch.sqrt(u) * torch.sin(2.0 * np.pi * w),\n            torch.sqrt(u) * torch.cos(2.0 * np.pi * w),\n        ],\n        -1,\n    )\n    return quats\n\n\ndef compute_means(ts, fg: GaussianParams, bases: MotionBases):\n    transfms = bases.compute_transforms(ts, fg.get_coefs())\n    means = torch.einsum(\n        \"pnij,pj->pni\",\n        transfms,\n        F.pad(fg.params[\"means\"], (0, 1), value=1.0),\n    )\n    return means\n\n\ndef vis_init_params(\n    server,\n    fg: GaussianParams,\n    bases: MotionBases,\n    name=\"init_params\",\n    num_vis: int = 100,\n):\n    idcs = np.random.choice(fg.num_gaussians, num_vis)\n    labels = np.linspace(0, 1, num_vis)\n    ts = torch.arange(bases.num_frames, device=bases.params[\"rots\"].device)\n    with torch.no_grad():\n        pred_means = compute_means(ts, fg, bases)\n        vis_means = pred_means[idcs].detach().cpu().numpy()\n    vis_tracks_3d(server, vis_means, labels, name=name)\n\n\n@torch.no_grad()\ndef vis_se3_init_3d(server, init_rots, init_ts, basis_centers):\n    \"\"\"\n    :param init_rots: [num_bases, num_frames, 4|6]\n    :param init_ts: [num_bases, num_frames, 3]\n    :param basis_centers: [num_bases, 3]\n    \"\"\"\n    # visualize the initial centers across time\n    rot_dim = init_rots.shape[-1]\n    assert rot_dim in [4, 6]\n    num_bases = init_rots.shape[0]\n    assert init_ts.shape[0] == num_bases\n    assert basis_centers.shape[0] == num_bases\n    labels = np.linspace(0, 1, num_bases)\n    if rot_dim == 4:\n        quats = F.normalize(init_rots, dim=-1, p=2)\n        rmats = roma.unitquat_to_rotmat(quats.roll(-1, dims=-1))\n    else:\n        rmats = cont_6d_to_rmat(init_rots)\n    transls = init_ts\n    transfms = rt_to_mat4(rmats, transls)\n    center_tracks3d = torch.einsum(\n        \"bnij,bj->bni\", transfms, F.pad(basis_centers, (0, 1), value=1.0)\n    )[..., :3]\n    vis_tracks_3d(server, center_tracks3d.cpu().numpy(), labels, name=\"se3_centers\")\n\n\n@torch.no_grad()\ndef vis_tracks_2d_video(\n    path,\n    imgs: np.ndarray,\n    tracks_3d: np.ndarray,\n    Ks: np.ndarray,\n    w2cs: np.ndarray,\n    occs=None,\n    radius: int = 3,\n):\n    num_tracks = tracks_3d.shape[0]\n    labels = np.linspace(0, 1, num_tracks)\n    cmap = get_cmap(\"gist_rainbow\")\n    colors = cmap(labels)[:, :3]\n    tracks_2d = (\n        project_2d_tracks(tracks_3d.swapaxes(0, 1), Ks, w2cs).cpu().numpy()  # type: ignore\n    )\n    frames = np.asarray(\n        draw_keypoints_video(imgs, tracks_2d, colors, occs, radius=radius)\n    )\n    iio.imwrite(path, frames, fps=15)\n\n\ndef vis_tracks_3d(\n    server: ViserServer,\n    
vis_tracks: np.ndarray,\n    vis_label: np.ndarray | None = None,\n    name: str = \"tracks\",\n):\n    \"\"\"\n    :param vis_tracks (np.ndarray): (N, T, 3)\n    :param vis_label (np.ndarray): (N)\n    \"\"\"\n    cmap = get_cmap(\"gist_rainbow\")\n    if vis_label is None:\n        vis_label = np.linspace(0, 1, len(vis_tracks))\n    colors = cmap(np.asarray(vis_label))[:, :3]\n    guru.info(f\"{colors.shape=}, {vis_tracks.shape=}\")\n    N, T = vis_tracks.shape[:2]\n    vis_tracks = np.asarray(vis_tracks)\n    for i in range(N):\n        server.scene.add_spline_catmull_rom(\n            f\"/{name}/{i}/spline\", vis_tracks[i], color=colors[i], segments=T - 1\n        )\n        server.scene.add_point_cloud(\n            f\"/{name}/{i}/start\",\n            vis_tracks[i, [0]],\n            colors=colors[i : i + 1],\n            point_size=0.05,\n            point_shape=\"circle\",\n        )\n        server.scene.add_point_cloud(\n            f\"/{name}/{i}/end\",\n            vis_tracks[i, [-1]],\n            colors=colors[i : i + 1],\n            point_size=0.05,\n            point_shape=\"diamond\",\n        )\n\n\ndef sample_initial_bases_centers(\n    mode: str, cano_t: int, tracks_3d: TrackObservations, num_bases: int\n):\n    \"\"\"\n    :param mode: \"farthest\" | \"hdbscan\" | \"kmeans\"\n    :param tracks_3d: [G, T, 3]\n    :param cano_t: canonical index\n    :param num_bases: number of SE3 bases\n    \"\"\"\n    assert mode in [\"farthest\", \"hdbscan\", \"kmeans\"]\n    means_canonical = tracks_3d.xyz[:, cano_t].clone()\n    # if mode == \"farthest\":\n    #     vis_mask = tracks_3d.visibles[:, cano_t]\n    #     sampled_centers, _ = sample_farthest_points(\n    #         means_canonical[vis_mask][None],\n    #         K=num_bases,\n    #         random_start_point=True,\n    #     )  # [1, num_bases, 3]\n    #     dists2centers = torch.norm(means_canonical[:, None] - sampled_centers, dim=-1).T\n    #     return sampled_centers, num_bases, dists2centers\n\n    # linearly interpolate missing 3d points\n    xyz = cp.asarray(tracks_3d.xyz)\n    print(f\"{xyz.shape=}\")\n    visibles = cp.asarray(tracks_3d.visibles)\n\n    num_tracks = xyz.shape[0]\n    xyz_interp = batched_interp_masked(xyz, visibles)\n\n    # num_vis = 50\n    # server = get_server(port=8890)\n    # idcs = np.random.choice(num_tracks, num_vis)\n    # labels = np.linspace(0, 1, num_vis)\n    # vis_tracks_3d(server, tracks_3d.xyz[idcs].get(), labels, name=\"raw_tracks\")\n    # vis_tracks_3d(server, xyz_interp[idcs].get(), labels, name=\"interp_tracks\")\n    # import ipdb; ipdb.set_trace()\n\n    velocities = xyz_interp[:, 1:] - xyz_interp[:, :-1]\n    vel_dirs = (\n        velocities / (cp.linalg.norm(velocities, axis=-1, keepdims=True) + 1e-5)\n    ).reshape((num_tracks, -1))\n\n    # [num_bases, num_gaussians]\n    if mode == \"kmeans\":\n        model = KMeans(n_clusters=num_bases)\n    else:\n        model = HDBSCAN(min_cluster_size=20, max_cluster_size=num_tracks // 4)\n    model.fit(vel_dirs)\n    labels = model.labels_\n    num_bases = labels.max().item() + 1\n    sampled_centers = torch.stack(\n        [\n            means_canonical[torch.tensor(labels == i)].median(dim=0).values\n            for i in range(num_bases)\n        ]\n    )[None]\n    print(\"number of {} clusters: \".format(mode), num_bases)\n    return sampled_centers, num_bases, torch.tensor(labels)\n\n\ndef interp_masked(vals: cp.ndarray, mask: cp.ndarray, pad: int = 1) -> cp.ndarray:\n    \"\"\"\n    hacky way to interpolate batched with 
cupy\n    by concatenating the batches and pad with dummy values\n    :param vals: [B, M, *]\n    :param mask: [B, M]\n    \"\"\"\n    assert mask.ndim == 2\n    assert vals.shape[:2] == mask.shape\n\n    B, M = mask.shape\n\n    # get the first and last valid values for each track\n    sh = vals.shape[2:]\n    vals = vals.reshape((B, M, -1))\n    D = vals.shape[-1]\n    first_val_idcs = cp.argmax(mask, axis=-1)\n    last_val_idcs = M - 1 - cp.argmax(cp.flip(mask, axis=-1), axis=-1)\n    bidcs = cp.arange(B)\n\n    v0 = vals[bidcs, first_val_idcs][:, None]\n    v1 = vals[bidcs, last_val_idcs][:, None]\n    m0 = mask[bidcs, first_val_idcs][:, None]\n    m1 = mask[bidcs, last_val_idcs][:, None]\n    if pad > 1:\n        v0 = cp.tile(v0, [1, pad, 1])\n        v1 = cp.tile(v1, [1, pad, 1])\n        m0 = cp.tile(m0, [1, pad])\n        m1 = cp.tile(m1, [1, pad])\n\n    vals_pad = cp.concatenate([v0, vals, v1], axis=1)\n    mask_pad = cp.concatenate([m0, mask, m1], axis=1)\n\n    M_pad = vals_pad.shape[1]\n    vals_flat = vals_pad.reshape((B * M_pad, -1))\n    mask_flat = mask_pad.reshape((B * M_pad,))\n    idcs = cp.where(mask_flat)[0]\n\n    cx = cp.arange(B * M_pad)\n    out = cp.zeros((B * M_pad, D), dtype=vals_flat.dtype)\n    for d in range(D):\n        out[:, d] = cp.interp(cx, idcs, vals_flat[idcs, d])\n\n    out = out.reshape((B, M_pad, *sh))[:, pad:-pad]\n    return out\n\n\ndef batched_interp_masked(\n    vals: cp.ndarray, mask: cp.ndarray, batch_num: int = 4096, batch_time: int = 64\n):\n    assert mask.ndim == 2\n    B, M = mask.shape\n    out = cp.zeros_like(vals)\n    for b in tqdm(range(0, B, batch_num), leave=False):\n        for m in tqdm(range(0, M, batch_time), leave=False):\n            x = interp_masked(\n                vals[b : b + batch_num, m : m + batch_time],\n                mask[b : b + batch_num, m : m + batch_time],\n            )  # (batch_num, batch_time, *)\n            out[b : b + batch_num, m : m + batch_time] = x\n    return out\n"
  },
  {
    "path": "mvtracker/models/core/shape-of-motion/flow3d/loss_utils.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom sklearn.neighbors import NearestNeighbors\n\n\ndef masked_mse_loss(pred, gt, mask=None, normalize=True, quantile: float = 1.0):\n    if mask is None:\n        return trimmed_mse_loss(pred, gt, quantile)\n    else:\n        sum_loss = F.mse_loss(pred, gt, reduction=\"none\").mean(dim=-1, keepdim=True)\n        quantile_mask = (\n            (sum_loss < torch.quantile(sum_loss, quantile)).squeeze(-1)\n            if quantile < 1\n            else torch.ones_like(sum_loss, dtype=torch.bool).squeeze(-1)\n        )\n        ndim = sum_loss.shape[-1]\n        if normalize:\n            return torch.sum((sum_loss * mask)[quantile_mask]) / (\n                ndim * torch.sum(mask[quantile_mask]) + 1e-8\n            )\n        else:\n            return torch.mean((sum_loss * mask)[quantile_mask])\n\n\ndef masked_l1_loss(pred, gt, mask=None, normalize=True, quantile: float = 1.0):\n    if mask is None:\n        return trimmed_l1_loss(pred, gt, quantile)\n    else:\n        sum_loss = F.l1_loss(pred, gt, reduction=\"none\").mean(dim=-1, keepdim=True)\n        quantile_mask = (\n            (sum_loss < torch.quantile(sum_loss, quantile)).squeeze(-1)\n            if quantile < 1\n            else torch.ones_like(sum_loss, dtype=torch.bool).squeeze(-1)\n        )\n        ndim = sum_loss.shape[-1]\n        if normalize:\n            return torch.sum((sum_loss * mask)[quantile_mask]) / (\n                ndim * torch.sum(mask[quantile_mask]) + 1e-8\n            )\n        else:\n            return torch.mean((sum_loss * mask)[quantile_mask])\n\n\ndef masked_huber_loss(pred, gt, delta, mask=None, normalize=True):\n    if mask is None:\n        return F.huber_loss(pred, gt, delta=delta)\n    else:\n        sum_loss = F.huber_loss(pred, gt, delta=delta, reduction=\"none\")\n        ndim = sum_loss.shape[-1]\n        if normalize:\n            return torch.sum(sum_loss * mask) / (ndim * torch.sum(mask) + 1e-8)\n        else:\n            return torch.mean(sum_loss * mask)\n\n\ndef trimmed_mse_loss(pred, gt, quantile=0.9):\n    loss = F.mse_loss(pred, gt, reduction=\"none\").mean(dim=-1)\n    loss_at_quantile = torch.quantile(loss, quantile)\n    trimmed_loss = loss[loss < loss_at_quantile].mean()\n    return trimmed_loss\n\n\ndef trimmed_l1_loss(pred, gt, quantile=0.9):\n    loss = F.l1_loss(pred, gt, reduction=\"none\").mean(dim=-1)\n    loss_at_quantile = torch.quantile(loss, quantile)\n    trimmed_loss = loss[loss < loss_at_quantile].mean()\n    return trimmed_loss\n\n\ndef compute_gradient_loss(pred, gt, mask, quantile=0.98):\n    \"\"\"\n    Compute gradient loss\n    pred: (batch_size, H, W, D) or (batch_size, H, W)\n    gt: (batch_size, H, W, D) or (batch_size, H, W)\n    mask: (batch_size, H, W), bool or float\n    \"\"\"\n    # NOTE: messy need to be cleaned up\n    mask_x = mask[:, :, 1:] * mask[:, :, :-1]\n    mask_y = mask[:, 1:, :] * mask[:, :-1, :]\n    pred_grad_x = pred[:, :, 1:] - pred[:, :, :-1]\n    pred_grad_y = pred[:, 1:, :] - pred[:, :-1, :]\n    gt_grad_x = gt[:, :, 1:] - gt[:, :, :-1]\n    gt_grad_y = gt[:, 1:, :] - gt[:, :-1, :]\n    loss = masked_l1_loss(\n        pred_grad_x[mask_x][..., None], gt_grad_x[mask_x][..., None], quantile=quantile\n    ) + masked_l1_loss(\n        pred_grad_y[mask_y][..., None], gt_grad_y[mask_y][..., None], quantile=quantile\n    )\n    return loss\n\n\ndef knn(x: torch.Tensor, k: int) -> tuple[np.ndarray, np.ndarray]:\n    x = x.cpu().numpy()\n    knn_model = 
NearestNeighbors(\n        n_neighbors=k + 1, algorithm=\"auto\", metric=\"euclidean\"\n    ).fit(x)\n    distances, indices = knn_model.kneighbors(x)\n    return distances[:, 1:].astype(np.float32), indices[:, 1:].astype(np.float32)\n\n\ndef get_weights_for_procrustes(clusters, visibilities=None):\n    clusters_median = clusters.median(dim=-2, keepdim=True)[0]\n    dists2clusters_center = torch.norm(clusters - clusters_median, dim=-1)\n    dists2clusters_center /= dists2clusters_center.median(dim=-1, keepdim=True)[0]\n    weights = torch.exp(-dists2clusters_center)\n    weights /= weights.mean(dim=-1, keepdim=True) + 1e-6\n    if visibilities is not None:\n        weights *= visibilities.float() + 1e-6\n    invalid = dists2clusters_center > np.quantile(\n        dists2clusters_center.cpu().numpy(), 0.9\n    )\n    invalid |= torch.isnan(weights)\n    weights[invalid] = 0\n    return weights\n\n\ndef compute_z_acc_loss(means_ts_nb: torch.Tensor, w2cs: torch.Tensor):\n    \"\"\"\n    :param means_ts (G, 3, B, 3)\n    :param w2cs (B, 4, 4)\n    return (float)\n    \"\"\"\n    camera_center_t = torch.linalg.inv(w2cs)[:, :3, 3]  # (B, 3)\n    ray_dir = F.normalize(\n        means_ts_nb[:, 1] - camera_center_t, p=2.0, dim=-1\n    )  # [G, B, 3]\n    # acc = 2 * means[:, 1] - means[:, 0] - means[:, 2]  # [G, B, 3]\n    # acc_loss = (acc * ray_dir).sum(dim=-1).abs().mean()\n    acc_loss = (\n        ((means_ts_nb[:, 1] - means_ts_nb[:, 0]) * ray_dir).sum(dim=-1) ** 2\n    ).mean() + (\n        ((means_ts_nb[:, 2] - means_ts_nb[:, 1]) * ray_dir).sum(dim=-1) ** 2\n    ).mean()\n    return acc_loss\n\n\ndef compute_se3_smoothness_loss(\n    rots: torch.Tensor,\n    transls: torch.Tensor,\n    weight_rot: float = 1.0,\n    weight_transl: float = 2.0,\n):\n    \"\"\"\n    central differences\n    :param motion_transls (K, T, 3)\n    :param motion_rots (K, T, 6)\n    \"\"\"\n    r_accel_loss = compute_accel_loss(rots)\n    t_accel_loss = compute_accel_loss(transls)\n    return r_accel_loss * weight_rot + t_accel_loss * weight_transl\n\n\ndef compute_accel_loss(transls):\n    accel = 2 * transls[:, 1:-1] - transls[:, :-2] - transls[:, 2:]\n    loss = accel.norm(dim=-1).mean()\n    return loss\n"
  },
  {
    "path": "mvtracker/models/core/shape-of-motion/flow3d/metrics.py",
    "content": "from typing import Literal\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom torchmetrics.functional.image.lpips import _NoTrainLpips\nfrom torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure\nfrom torchmetrics.metric import Metric\nfrom torchmetrics.utilities import dim_zero_cat\nfrom torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE\n\n\ndef compute_psnr(\n    preds: torch.Tensor,\n    targets: torch.Tensor,\n    masks: torch.Tensor | None = None,\n) -> float:\n    \"\"\"\n    Args:\n        preds (torch.Tensor): (..., 3) predicted images in [0, 1].\n        targets (torch.Tensor): (..., 3) target images in [0, 1].\n        masks (torch.Tensor | None): (...,) optional binary masks where the\n            1-regions will be taken into account.\n\n    Returns:\n        psnr (float): Peak signal-to-noise ratio.\n    \"\"\"\n    if masks is None:\n        masks = torch.ones_like(preds[..., 0])\n    return (\n        -10.0\n        * torch.log(\n            F.mse_loss(\n                preds * masks[..., None],\n                targets * masks[..., None],\n                reduction=\"sum\",\n            )\n            / masks.sum().clamp(min=1.0)\n            / 3.0\n        )\n        / np.log(10.0)\n    ).item()\n\n\ndef compute_pose_errors(\n    preds: torch.Tensor, targets: torch.Tensor\n) -> tuple[float, float, float]:\n    \"\"\"\n    Args:\n        preds: (N, 4, 4) predicted camera poses.\n        targets: (N, 4, 4) target camera poses.\n\n    Returns:\n        ate (float): Absolute trajectory error.\n        rpe_t (float): Relative pose error in translation.\n        rpe_r (float): Relative pose error in rotation (degree).\n    \"\"\"\n    # Compute ATE.\n    ate = torch.linalg.norm(preds[:, :3, -1] - targets[:, :3, -1], dim=-1).mean().item()\n    # Compute RPE_t and RPE_r.\n    # NOTE(hangg): It's important to use numpy here for the accuracy of RPE_r.\n    # torch has numerical issues for acos when the value is close to 1.0, i.e.\n    # RPE_r is supposed to be very small, and will result in artificially large\n    # error.\n    preds = preds.detach().cpu().numpy()\n    targets = targets.detach().cpu().numpy()\n    pred_rels = np.linalg.inv(preds[:-1]) @ preds[1:]\n    pred_rels = np.linalg.inv(preds[:-1]) @ preds[1:]\n    target_rels = np.linalg.inv(targets[:-1]) @ targets[1:]\n    error_rels = np.linalg.inv(target_rels) @ pred_rels\n    traces = error_rels[:, :3, :3].trace(axis1=-2, axis2=-1)\n    rpe_t = np.linalg.norm(error_rels[:, :3, -1], axis=-1).mean().item()\n    rpe_r = (\n        np.arccos(np.clip((traces - 1.0) / 2.0, -1.0, 1.0)).mean().item()\n        / np.pi\n        * 180.0\n    )\n    return ate, rpe_t, rpe_r\n\n\nclass mPSNR(PeakSignalNoiseRatio):\n    sum_squared_error: list[torch.Tensor]\n    total: list[torch.Tensor]\n\n    def __init__(self, **kwargs) -> None:\n        super().__init__(\n            data_range=1.0,\n            base=10.0,\n            dim=None,\n            reduction=\"elementwise_mean\",\n            **kwargs,\n        )\n        self.add_state(\"sum_squared_error\", default=[], dist_reduce_fx=\"cat\")\n        self.add_state(\"total\", default=[], dist_reduce_fx=\"cat\")\n\n    def __len__(self) -> int:\n        return len(self.total)\n\n    def update(\n        self,\n        preds: torch.Tensor,\n        targets: torch.Tensor,\n        masks: torch.Tensor | None = None,\n    ):\n        \"\"\"Update state with predictions and targets.\n\n        Args:\n            
preds (torch.Tensor): (..., 3) float32 predicted images.\n            targets (torch.Tensor): (..., 3) float32 target images.\n            masks (torch.Tensor | None): (...,) optional binary masks where the\n                1-regions will be taken into account.\n        \"\"\"\n        if masks is None:\n            masks = torch.ones_like(preds[..., 0])\n        self.sum_squared_error.append(\n            torch.sum(torch.pow((preds - targets) * masks[..., None], 2))\n        )\n        self.total.append(masks.sum().to(torch.int64) * 3)\n\n    def compute(self) -> torch.Tensor:\n        \"\"\"Compute peak signal-to-noise ratio over state.\"\"\"\n        sum_squared_error = dim_zero_cat(self.sum_squared_error)\n        total = dim_zero_cat(self.total)\n        return -10.0 * torch.log(sum_squared_error / total).mean() / np.log(10.0)\n\n\nclass mSSIM(StructuralSimilarityIndexMeasure):\n    similarity: list\n\n    def __init__(self, **kwargs) -> None:\n        super().__init__(\n            reduction=None,\n            data_range=1.0,\n            return_full_image=False,\n            **kwargs,\n        )\n        assert isinstance(self.sigma, float)\n\n    def __len__(self) -> int:\n        return sum([s.shape[0] for s in self.similarity])\n\n    def update(\n        self,\n        preds: torch.Tensor,\n        targets: torch.Tensor,\n        masks: torch.Tensor | None = None,\n    ):\n        \"\"\"Update state with predictions and targets.\n\n        Args:\n            preds (torch.Tensor): (B, H, W, 3) float32 predicted images.\n            targets (torch.Tensor): (B, H, W, 3) float32 target images.\n            masks (torch.Tensor | None): (B, H, W) optional binary masks where\n                the 1-regions will be taken into account.\n        \"\"\"\n        if masks is None:\n            masks = torch.ones_like(preds[..., 0])\n\n        # Construct a 1D Gaussian blur filter.\n        assert isinstance(self.kernel_size, int)\n        hw = self.kernel_size // 2\n        shift = (2 * hw - self.kernel_size + 1) / 2\n        assert isinstance(self.sigma, float)\n        f_i = (\n            (torch.arange(self.kernel_size, device=preds.device) - hw + shift)\n            / self.sigma\n        ) ** 2\n        filt = torch.exp(-0.5 * f_i)\n        filt /= torch.sum(filt)\n\n        # Blur in x and y (faster than the 2D convolution).\n        def convolve2d(z, m, f):\n            # z: (B, H, W, C), m: (B, H, W), f: (Hf, Wf).\n            z = z.permute(0, 3, 1, 2)\n            m = m[:, None]\n            f = f[None, None].expand(z.shape[1], -1, -1, -1)\n            z_ = torch.nn.functional.conv2d(\n                z * m, f, padding=\"valid\", groups=z.shape[1]\n            )\n            m_ = torch.nn.functional.conv2d(m, torch.ones_like(f[:1]), padding=\"valid\")\n            return torch.where(\n                m_ != 0, z_ * torch.ones_like(f).sum() / (m_ * z.shape[1]), 0\n            ).permute(0, 2, 3, 1), (m_ != 0)[:, 0].to(z.dtype)\n\n        filt_fn1 = lambda z, m: convolve2d(z, m, filt[:, None])\n        filt_fn2 = lambda z, m: convolve2d(z, m, filt[None, :])\n        filt_fn = lambda z, m: filt_fn1(*filt_fn2(z, m))\n\n        mu0 = filt_fn(preds, masks)[0]\n        mu1 = filt_fn(targets, masks)[0]\n        mu00 = mu0 * mu0\n        mu11 = mu1 * mu1\n        mu01 = mu0 * mu1\n        sigma00 = filt_fn(preds**2, masks)[0] - mu00\n        sigma11 = filt_fn(targets**2, masks)[0] - mu11\n        sigma01 = filt_fn(preds * targets, masks)[0] - mu01\n\n        # Clip the variances and covariances 
to valid values.\n        # Variance must be non-negative:\n        sigma00 = sigma00.clamp(min=0.0)\n        sigma11 = sigma11.clamp(min=0.0)\n        sigma01 = torch.sign(sigma01) * torch.minimum(\n            torch.sqrt(sigma00 * sigma11), torch.abs(sigma01)\n        )\n\n        assert isinstance(self.data_range, float)\n        c1 = (self.k1 * self.data_range) ** 2\n        c2 = (self.k2 * self.data_range) ** 2\n        numer = (2 * mu01 + c1) * (2 * sigma01 + c2)\n        denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)\n        ssim_map = numer / denom\n\n        self.similarity.append(ssim_map.mean(dim=(1, 2, 3)))\n\n    def compute(self) -> torch.Tensor:\n        \"\"\"Compute final SSIM metric.\"\"\"\n        return torch.cat(self.similarity).mean()\n\n\nclass mLPIPS(Metric):\n    sum_scores: list[torch.Tensor]\n    total: list[torch.Tensor]\n\n    def __init__(\n        self,\n        net_type: Literal[\"vgg\", \"alex\", \"squeeze\"] = \"alex\",\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n\n        if not _TORCHVISION_AVAILABLE:\n            raise ModuleNotFoundError(\n                \"LPIPS metric requires that torchvision is installed.\"\n                \" Either install as `pip install torchmetrics[image]` or `pip install torchvision`.\"\n            )\n\n        valid_net_type = (\"vgg\", \"alex\", \"squeeze\")\n        if net_type not in valid_net_type:\n            raise ValueError(\n                f\"Argument `net_type` must be one of {valid_net_type}, but got {net_type}.\"\n            )\n        self.net = _NoTrainLpips(net=net_type, spatial=True)\n\n        self.add_state(\"sum_scores\", [], dist_reduce_fx=\"cat\")\n        self.add_state(\"total\", [], dist_reduce_fx=\"cat\")\n\n    def __len__(self) -> int:\n        return len(self.total)\n\n    def update(\n        self,\n        preds: torch.Tensor,\n        targets: torch.Tensor,\n        masks: torch.Tensor | None = None,\n    ):\n        \"\"\"Update internal states with lpips scores.\n\n        Args:\n            preds (torch.Tensor): (B, H, W, 3) float32 predicted images.\n            targets (torch.Tensor): (B, H, W, 3) float32 target images.\n            masks (torch.Tensor | None): (B, H, W) optional float32 binary\n                masks where the 1-regions will be taken into account.\n        \"\"\"\n        if masks is None:\n            masks = torch.ones_like(preds[..., 0])\n        scores = self.net(\n            (preds * masks[..., None]).permute(0, 3, 1, 2),\n            (targets * masks[..., None]).permute(0, 3, 1, 2),\n            normalize=True,\n        )\n        self.sum_scores.append((scores * masks[:, None]).sum())\n        self.total.append(masks.sum().to(torch.int64))\n\n    def compute(self) -> torch.Tensor:\n        \"\"\"Compute final perceptual similarity metric.\"\"\"\n        return (\n            torch.tensor(self.sum_scores, device=self.device)\n            / torch.tensor(self.total, device=self.device)\n        ).mean()\n\n\nclass PCK(Metric):\n    correct: list[torch.Tensor]\n    total: list[int]\n\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n        self.add_state(\"correct\", default=[], dist_reduce_fx=\"cat\")\n        self.add_state(\"total\", default=[], dist_reduce_fx=\"cat\")\n\n    def __len__(self) -> int:\n        return len(self.total)\n\n    def update(self, preds: torch.Tensor, targets: torch.Tensor, threshold: float):\n        \"\"\"Update internal states with PCK scores.\n\n        Args:\n            preds 
(torch.Tensor): (N, 2) predicted 2D keypoints.\n            targets (torch.Tensor): (N, 2) targets 2D keypoints.\n            threshold (float): PCK threshold.\n        \"\"\"\n\n        self.correct.append(\n            (torch.linalg.norm(preds - targets, dim=-1) < threshold).sum()\n        )\n        self.total.append(preds.shape[0])\n\n    def compute(self) -> torch.Tensor:\n        \"\"\"Compute PCK over state.\"\"\"\n        return (\n            torch.tensor(self.correct, device=self.device)\n            / torch.clamp(torch.tensor(self.total, device=self.device), min=1e-8)\n        ).mean()\n"
  },
  {
    "path": "mvtracker/models/core/shape-of-motion/flow3d/params.py",
    "content": "import math\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom flow3d.transforms import cont_6d_to_rmat\n\n\nclass GaussianParams(nn.Module):\n    def __init__(\n        self,\n        means: torch.Tensor,\n        quats: torch.Tensor,\n        scales: torch.Tensor,\n        colors: torch.Tensor,\n        opacities: torch.Tensor,\n        motion_coefs: torch.Tensor | None = None,\n        scene_center: torch.Tensor | None = None,\n        scene_scale: torch.Tensor | float = 1.0,\n    ):\n        super().__init__()\n        if not check_gaussian_sizes(\n            means, quats, scales, colors, opacities, motion_coefs\n        ):\n            import ipdb\n\n            ipdb.set_trace()\n        params_dict = {\n            \"means\": nn.Parameter(means),\n            \"quats\": nn.Parameter(quats),\n            \"scales\": nn.Parameter(scales),\n            \"colors\": nn.Parameter(colors),\n            \"opacities\": nn.Parameter(opacities),\n        }\n        if motion_coefs is not None:\n            params_dict[\"motion_coefs\"] = nn.Parameter(motion_coefs)\n        self.params = nn.ParameterDict(params_dict)\n        self.quat_activation = lambda x: F.normalize(x, dim=-1, p=2)\n        self.color_activation = torch.sigmoid\n        self.scale_activation = torch.exp\n        self.opacity_activation = torch.sigmoid\n        self.motion_coef_activation = lambda x: F.softmax(x, dim=-1)\n\n        if scene_center is None:\n            scene_center = torch.zeros(3, device=means.device)\n        self.register_buffer(\"scene_center\", scene_center)\n        self.register_buffer(\"scene_scale\", torch.as_tensor(scene_scale))\n\n    @staticmethod\n    def init_from_state_dict(state_dict, prefix=\"params.\"):\n        req_keys = [\"means\", \"quats\", \"scales\", \"colors\", \"opacities\"]\n        assert all(f\"{prefix}{k}\" in state_dict for k in req_keys)\n        args = {\n            \"motion_coefs\": None,\n            \"scene_center\": torch.zeros(3),\n            \"scene_scale\": torch.tensor(1.0),\n        }\n        for k in req_keys + list(args.keys()):\n            if f\"{prefix}{k}\" in state_dict:\n                args[k] = state_dict[f\"{prefix}{k}\"]\n        return GaussianParams(**args)\n\n    @property\n    def num_gaussians(self) -> int:\n        return self.params[\"means\"].shape[0]\n\n    def get_colors(self) -> torch.Tensor:\n        return self.color_activation(self.params[\"colors\"])\n\n    def get_scales(self) -> torch.Tensor:\n        return self.scale_activation(self.params[\"scales\"])\n\n    def get_opacities(self) -> torch.Tensor:\n        return self.opacity_activation(self.params[\"opacities\"])\n\n    def get_quats(self) -> torch.Tensor:\n        return self.quat_activation(self.params[\"quats\"])\n\n    def get_coefs(self) -> torch.Tensor:\n        assert \"motion_coefs\" in self.params\n        return self.motion_coef_activation(self.params[\"motion_coefs\"])\n\n    def densify_params(self, should_split, should_dup):\n        \"\"\"\n        densify gaussians\n        \"\"\"\n        updated_params = {}\n        for name, x in self.params.items():\n            x_dup = x[should_dup]\n            x_split = x[should_split].repeat([2] + [1] * (x.ndim - 1))\n            if name == \"scales\":\n                x_split -= math.log(1.6)\n            x_new = nn.Parameter(torch.cat([x[~should_split], x_dup, x_split], dim=0))\n            updated_params[name] = x_new\n            self.params[name] = x_new\n        return 
updated_params\n\n    def cull_params(self, should_cull):\n        \"\"\"\n        cull gaussians\n        \"\"\"\n        updated_params = {}\n        for name, x in self.params.items():\n            x_new = nn.Parameter(x[~should_cull])\n            updated_params[name] = x_new\n            self.params[name] = x_new\n        return updated_params\n\n    def reset_opacities(self, new_val):\n        \"\"\"\n        reset all opacities to new_val\n        \"\"\"\n        self.params[\"opacities\"].data.fill_(new_val)\n        updated_params = {\"opacities\": self.params[\"opacities\"]}\n        return updated_params\n\n\nclass MotionBases(nn.Module):\n    def __init__(self, rots, transls):\n        super().__init__()\n        self.num_frames = rots.shape[1]\n        self.num_bases = rots.shape[0]\n        assert check_bases_sizes(rots, transls)\n        self.params = nn.ParameterDict(\n            {\n                \"rots\": nn.Parameter(rots),\n                \"transls\": nn.Parameter(transls),\n            }\n        )\n\n    @staticmethod\n    def init_from_state_dict(state_dict, prefix=\"params.\"):\n        param_keys = [\"rots\", \"transls\"]\n        assert all(f\"{prefix}{k}\" in state_dict for k in param_keys)\n        args = {k: state_dict[f\"{prefix}{k}\"] for k in param_keys}\n        return MotionBases(**args)\n\n    def compute_transforms(self, ts: torch.Tensor, coefs: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        :param ts (B)\n        :param coefs (G, K)\n        returns transforms (G, B, 3, 4)\n        \"\"\"\n        transls = self.params[\"transls\"][:, ts]  # (K, B, 3)\n        rots = self.params[\"rots\"][:, ts]  # (K, B, 6)\n        transls = torch.einsum(\"pk,kni->pni\", coefs, transls)\n        rots = torch.einsum(\"pk,kni->pni\", coefs, rots)  # (G, B, 6)\n        rotmats = cont_6d_to_rmat(rots)  # (K, B, 3, 3)\n        return torch.cat([rotmats, transls[..., None]], dim=-1)\n\n\ndef check_gaussian_sizes(\n    means: torch.Tensor,\n    quats: torch.Tensor,\n    scales: torch.Tensor,\n    colors: torch.Tensor,\n    opacities: torch.Tensor,\n    motion_coefs: torch.Tensor | None = None,\n) -> bool:\n    dims = means.shape[:-1]\n    leading_dims_match = (\n        quats.shape[:-1] == dims\n        and scales.shape[:-1] == dims\n        and colors.shape[:-1] == dims\n        and opacities.shape == dims\n    )\n    if motion_coefs is not None and motion_coefs.numel() > 0:\n        leading_dims_match &= motion_coefs.shape[:-1] == dims\n    dims_correct = (\n        means.shape[-1] == 3\n        and (quats.shape[-1] == 4)\n        and (scales.shape[-1] == 3)\n        and (colors.shape[-1] == 3)\n    )\n    return leading_dims_match and dims_correct\n\n\ndef check_bases_sizes(motion_rots: torch.Tensor, motion_transls: torch.Tensor) -> bool:\n    return (\n        motion_rots.shape[-1] == 6\n        and motion_transls.shape[-1] == 3\n        and motion_rots.shape[:-2] == motion_transls.shape[:-2]\n    )\n"
  },
  {
    "path": "mvtracker/models/core/shape-of-motion/flow3d/renderer.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom loguru import logger as guru\nfrom nerfview import CameraState\n\nfrom flow3d.scene_model import SceneModel\nfrom flow3d.vis.utils import draw_tracks_2d_th, get_server\nfrom flow3d.vis.viewer import DynamicViewer\n\n\nclass Renderer:\n    def __init__(\n        self,\n        model: SceneModel,\n        device: torch.device,\n        # Logging.\n        work_dir: str,\n        port: int | None = None,\n    ):\n        self.device = device\n\n        self.model = model\n        self.num_frames = model.num_frames\n\n        self.work_dir = work_dir\n        self.global_step = 0\n        self.epoch = 0\n\n        self.viewer = None\n        if port is not None:\n            server = get_server(port=port)\n            self.viewer = DynamicViewer(\n                server, self.render_fn, model.num_frames, work_dir, mode=\"rendering\"\n            )\n\n        self.tracks_3d = self.model.compute_poses_fg(\n            #  torch.arange(max(0, t - 20), max(1, t), device=self.device),\n            torch.arange(self.num_frames, device=self.device),\n            inds=torch.arange(10, device=self.device),\n        )[0]\n\n    @staticmethod\n    def init_from_checkpoint(\n        path: str, device: torch.device, *args, **kwargs\n    ) -> \"Renderer\":\n        guru.info(f\"Loading checkpoint from {path}\")\n        ckpt = torch.load(path)\n        state_dict = ckpt[\"model\"]\n        model = SceneModel.init_from_state_dict(state_dict)\n        model = model.to(device)\n        renderer = Renderer(model, device, *args, **kwargs)\n        renderer.global_step = ckpt.get(\"global_step\", 0)\n        renderer.epoch = ckpt.get(\"epoch\", 0)\n        return renderer\n\n    @torch.inference_mode()\n    def render_fn(self, camera_state: CameraState, img_wh: tuple[int, int]):\n        if self.viewer is None:\n            return np.full((img_wh[1], img_wh[0], 3), 255, dtype=np.uint8)\n\n        W, H = img_wh\n\n        focal = 0.5 * H / np.tan(0.5 * camera_state.fov).item()\n        K = torch.tensor(\n            [[focal, 0.0, W / 2.0], [0.0, focal, H / 2.0], [0.0, 0.0, 1.0]],\n            device=self.device,\n        )\n        w2c = torch.linalg.inv(\n            torch.from_numpy(camera_state.c2w.astype(np.float32)).to(self.device)\n        )\n        t = (\n            int(self.viewer._playback_guis[0].value)\n            if not self.viewer._canonical_checkbox.value\n            else None\n        )\n        self.model.training = False\n        img = self.model.render(t, w2c[None], K[None], img_wh)[\"img\"][0]\n        if not self.viewer._render_track_checkbox.value:\n            img = (img.cpu().numpy() * 255.0).astype(np.uint8)\n        else:\n            assert t is not None\n            tracks_3d = self.tracks_3d[:, max(0, t - 20) : max(1, t)]\n            tracks_2d = torch.einsum(\n                \"ij,jk,nbk->nbi\", K, w2c[:3], F.pad(tracks_3d, (0, 1), value=1.0)\n            )\n            tracks_2d = tracks_2d[..., :2] / tracks_2d[..., 2:]\n            img = draw_tracks_2d_th(img, tracks_2d)\n        return img\n"
  },
  {
    "path": "mvtracker/models/core/shape-of-motion/flow3d/scene_model.py",
    "content": "import roma\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom gsplat.rendering import rasterization\nfrom torch import Tensor\n\nfrom flow3d.params import GaussianParams, MotionBases\n\n\nclass SceneModel(nn.Module):\n    def __init__(\n        self,\n        Ks: Tensor,\n        w2cs: Tensor,\n        fg_params: GaussianParams,\n        motion_bases: MotionBases,\n        bg_params: GaussianParams | None = None,\n    ):\n        super().__init__()\n        self.num_frames = motion_bases.num_frames\n        self.fg = fg_params\n        self.motion_bases = motion_bases\n        self.bg = bg_params\n        scene_scale = 1.0 if bg_params is None else bg_params.scene_scale\n        self.register_buffer(\"bg_scene_scale\", torch.as_tensor(scene_scale))\n        self.register_buffer(\"Ks\", Ks)\n        self.register_buffer(\"w2cs\", w2cs)\n\n        self._current_xys = None\n        self._current_radii = None\n        self._current_img_wh = None\n\n    @property\n    def num_gaussians(self) -> int:\n        return self.num_bg_gaussians + self.num_fg_gaussians\n\n    @property\n    def num_bg_gaussians(self) -> int:\n        return self.bg.num_gaussians if self.bg is not None else 0\n\n    @property\n    def num_fg_gaussians(self) -> int:\n        return self.fg.num_gaussians\n\n    @property\n    def num_motion_bases(self) -> int:\n        return self.motion_bases.num_bases\n\n    @property\n    def has_bg(self) -> bool:\n        return self.bg is not None\n\n    def compute_poses_bg(self) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Returns:\n            means: (G, B, 3)\n            quats: (G, B, 4)\n        \"\"\"\n        assert self.bg is not None\n        return self.bg.params[\"means\"], self.bg.get_quats()\n\n    def compute_transforms(\n        self, ts: torch.Tensor, inds: torch.Tensor | None = None\n    ) -> torch.Tensor:\n        coefs = self.fg.get_coefs()  # (G, K)\n        if inds is not None:\n            coefs = coefs[inds]\n        transfms = self.motion_bases.compute_transforms(ts, coefs)  # (G, B, 3, 4)\n        return transfms\n\n    def compute_poses_fg(\n        self, ts: torch.Tensor | None, inds: torch.Tensor | None = None\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        :returns means: (G, B, 3), quats: (G, B, 4)\n        \"\"\"\n        means = self.fg.params[\"means\"]  # (G, 3)\n        quats = self.fg.get_quats()  # (G, 4)\n        if inds is not None:\n            means = means[inds]\n            quats = quats[inds]\n        if ts is not None:\n            transfms = self.compute_transforms(ts, inds)  # (G, B, 3, 4)\n            means = torch.einsum(\n                \"pnij,pj->pni\",\n                transfms,\n                F.pad(means, (0, 1), value=1.0),\n            )\n            quats = roma.quat_xyzw_to_wxyz(\n                (\n                    roma.quat_product(\n                        roma.rotmat_to_unitquat(transfms[..., :3, :3]),\n                        roma.quat_wxyz_to_xyzw(quats[:, None]),\n                    )\n                )\n            )\n            quats = F.normalize(quats, p=2, dim=-1)\n        else:\n            means = means[:, None]\n            quats = quats[:, None]\n        return means, quats\n\n    def compute_poses_all(\n        self, ts: torch.Tensor | None\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        means, quats = self.compute_poses_fg(ts)\n        if self.has_bg:\n            bg_means, bg_quats = self.compute_poses_bg()\n       
     means = torch.cat(\n                [means, bg_means[:, None].expand(-1, means.shape[1], -1)], dim=0\n            ).contiguous()\n            quats = torch.cat(\n                [quats, bg_quats[:, None].expand(-1, means.shape[1], -1)], dim=0\n            ).contiguous()\n        return means, quats\n\n    def get_colors_all(self) -> torch.Tensor:\n        colors = self.fg.get_colors()\n        if self.bg is not None:\n            colors = torch.cat([colors, self.bg.get_colors()], dim=0).contiguous()\n        return colors\n\n    def get_scales_all(self) -> torch.Tensor:\n        scales = self.fg.get_scales()\n        if self.bg is not None:\n            scales = torch.cat([scales, self.bg.get_scales()], dim=0).contiguous()\n        return scales\n\n    def get_opacities_all(self) -> torch.Tensor:\n        \"\"\"\n        :returns colors: (G, 3), scales: (G, 3), opacities: (G, 1)\n        \"\"\"\n        opacities = self.fg.get_opacities()\n        if self.bg is not None:\n            opacities = torch.cat(\n                [opacities, self.bg.get_opacities()], dim=0\n            ).contiguous()\n        return opacities\n\n    @staticmethod\n    def init_from_state_dict(state_dict, prefix=\"\"):\n        fg = GaussianParams.init_from_state_dict(\n            state_dict, prefix=f\"{prefix}fg.params.\"\n        )\n        bg = None\n        if any(\"bg.\" in k for k in state_dict):\n            bg = GaussianParams.init_from_state_dict(\n                state_dict, prefix=f\"{prefix}bg.params.\"\n            )\n        motion_bases = MotionBases.init_from_state_dict(\n            state_dict, prefix=f\"{prefix}motion_bases.params.\"\n        )\n        Ks = state_dict[f\"{prefix}Ks\"]\n        w2cs = state_dict[f\"{prefix}w2cs\"]\n        return SceneModel(Ks, w2cs, fg, motion_bases, bg)\n\n    def render(\n        self,\n        # A single time instance for view rendering.\n        t: int | None,\n        w2cs: torch.Tensor,  # (C, 4, 4)\n        Ks: torch.Tensor,  # (C, 3, 3)\n        img_wh: tuple[int, int],\n        # Multiple time instances for track rendering: (B,).\n        target_ts: torch.Tensor | None = None,  # (B)\n        target_w2cs: torch.Tensor | None = None,  # (B, 4, 4)\n        bg_color: torch.Tensor | float = 1.0,\n        colors_override: torch.Tensor | None = None,\n        means: torch.Tensor | None = None,\n        quats: torch.Tensor | None = None,\n        target_means: torch.Tensor | None = None,\n        return_color: bool = True,\n        return_depth: bool = False,\n        return_mask: bool = False,\n        fg_only: bool = False,\n        filter_mask: torch.Tensor | None = None,\n    ) -> dict:\n        device = w2cs.device\n        C = w2cs.shape[0]\n\n        W, H = img_wh\n        pose_fnc = self.compute_poses_fg if fg_only else self.compute_poses_all\n        N = self.num_fg_gaussians if fg_only else self.num_gaussians\n\n        if means is None or quats is None:\n            means, quats = pose_fnc(\n                torch.tensor([t], device=device) if t is not None else None\n            )\n            means = means[:, 0]\n            quats = quats[:, 0]\n\n        if colors_override is None:\n            if return_color:\n                colors_override = (\n                    self.fg.get_colors() if fg_only else self.get_colors_all()\n                )\n            else:\n                colors_override = torch.zeros(N, 0, device=device)\n\n        D = colors_override.shape[-1]\n\n        scales = self.fg.get_scales() if fg_only else 
self.get_scales_all()\n        opacities = self.fg.get_opacities() if fg_only else self.get_opacities_all()\n\n        if isinstance(bg_color, float):\n            bg_color = torch.full((C, D), bg_color, device=device)\n        assert isinstance(bg_color, torch.Tensor)\n\n        mode = \"RGB\"\n        ds_expected = {\"img\": D}\n\n        if return_mask:\n            if self.has_bg and not fg_only:\n                mask_values = torch.zeros((self.num_gaussians, 1), device=device)\n                mask_values[: self.num_fg_gaussians] = 1.0\n            else:\n                mask_values = torch.ones((self.num_fg_gaussians, 1), device=device)\n            colors_override = torch.cat([colors_override, mask_values], dim=-1)\n            bg_color = torch.cat([bg_color, torch.zeros(C, 1, device=device)], dim=-1)\n            ds_expected[\"mask\"] = 1\n\n        B = 0\n        if target_ts is not None:\n            B = target_ts.shape[0]\n            if target_means is None:\n                target_means, _ = pose_fnc(target_ts)  # [G, B, 3]\n            if target_w2cs is not None:\n                target_means = torch.einsum(\n                    \"bij,pbj->pbi\",\n                    target_w2cs[:, :3],\n                    F.pad(target_means, (0, 1), value=1.0),\n                )\n            track_3d_vals = target_means.flatten(-2)  # (G, B * 3)\n            d_track = track_3d_vals.shape[-1]\n            colors_override = torch.cat([colors_override, track_3d_vals], dim=-1)\n            bg_color = torch.cat(\n                [bg_color, torch.zeros(C, track_3d_vals.shape[-1], device=device)],\n                dim=-1,\n            )\n            ds_expected[\"tracks_3d\"] = d_track\n\n        assert colors_override.shape[-1] == sum(ds_expected.values())\n        assert bg_color.shape[-1] == sum(ds_expected.values())\n\n        if return_depth:\n            mode = \"RGB+ED\"\n            ds_expected[\"depth\"] = 1\n\n        if filter_mask is not None:\n            assert filter_mask.shape == (N,)\n            means = means[filter_mask]\n            quats = quats[filter_mask]\n            scales = scales[filter_mask]\n            opacities = opacities[filter_mask]\n            colors_override = colors_override[filter_mask]\n\n        render_colors, alphas, info = rasterization(\n            means=means,\n            quats=quats,\n            scales=scales,\n            opacities=opacities,\n            colors=colors_override,\n            backgrounds=bg_color,\n            viewmats=w2cs,  # [C, 4, 4]\n            Ks=Ks,  # [C, 3, 3]\n            width=W,\n            height=H,\n            packed=False,\n            render_mode=mode,\n        )\n\n        # Populate the current data for adaptive gaussian control.\n        if self.training and info[\"means2d\"].requires_grad:\n            self._current_xys = info[\"means2d\"]\n            self._current_radii = info[\"radii\"]\n            self._current_img_wh = img_wh\n            # We want to be able to access to xys' gradients later in a\n            # torch.no_grad context.\n            self._current_xys.retain_grad()\n\n        assert render_colors.shape[-1] == sum(ds_expected.values())\n        outputs = torch.split(render_colors, list(ds_expected.values()), dim=-1)\n        out_dict = {}\n        for i, (name, dim) in enumerate(ds_expected.items()):\n            x = outputs[i]\n            assert x.shape[-1] == dim, f\"{x.shape[-1]=} != {dim=}\"\n            if name == \"tracks_3d\":\n                x = x.reshape(C, H, W, B, 3)\n         
   out_dict[name] = x\n        out_dict[\"acc\"] = alphas\n        return out_dict\n"
  },
  {
    "path": "mvtracker/models/core/shape-of-motion/flow3d/tensor_dataclass.py",
    "content": "from dataclasses import dataclass\nfrom typing import Callable, TypeVar\n\nimport torch\nfrom typing_extensions import Self\n\nTensorDataclassT = TypeVar(\"T\", bound=\"TensorDataclass\")\n\n\nclass TensorDataclass:\n    \"\"\"A lighter version of nerfstudio's TensorDataclass:\n    https://github.com/nerfstudio-project/nerfstudio/blob/main/nerfstudio/utils/tensor_dataclass.py\n    \"\"\"\n\n    def __getitem__(self, key) -> Self:\n        return self.map(lambda x: x[key])\n\n    def to(self, device: torch.device | str) -> Self:\n        \"\"\"Move the tensors in the dataclass to the given device.\n\n        Args:\n            device: The device to move to.\n\n        Returns:\n            A new dataclass.\n        \"\"\"\n        return self.map(lambda x: x.to(device))\n\n    def map(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Self:\n        \"\"\"Apply a function to all tensors in the dataclass.\n\n        Also recurses into lists, tuples, and dictionaries.\n\n        Args:\n            fn: The function to apply to each tensor.\n\n        Returns:\n            A new dataclass.\n        \"\"\"\n\n        MapT = TypeVar(\"MapT\")\n\n        def _map_impl(\n            fn: Callable[[torch.Tensor], torch.Tensor],\n            val: MapT,\n        ) -> MapT:\n            if isinstance(val, torch.Tensor):\n                return fn(val)\n            elif isinstance(val, TensorDataclass):\n                return type(val)(**_map_impl(fn, vars(val)))\n            elif isinstance(val, (list, tuple)):\n                return type(val)(_map_impl(fn, v) for v in val)\n            elif isinstance(val, dict):\n                assert type(val) is dict  # No subclass support.\n                return {k: _map_impl(fn, v) for k, v in val.items()}  # type: ignore\n            else:\n                return val\n\n        return _map_impl(fn, self)\n\n\n@dataclass\nclass TrackObservations(TensorDataclass):\n    xyz: torch.Tensor\n    visibles: torch.Tensor\n    invisibles: torch.Tensor\n    confidences: torch.Tensor\n    colors: torch.Tensor\n\n    def check_sizes(self) -> bool:\n        dims = self.xyz.shape[:-1]\n        return (\n            self.visibles.shape == dims\n            and self.invisibles.shape == dims\n            and self.confidences.shape == dims\n            and self.colors.shape[:-1] == dims[:-1]\n            and self.xyz.shape[-1] == 3\n            and self.colors.shape[-1] == 3\n        )\n\n    def filter_valid(self, valid_mask: torch.Tensor) -> Self:\n        return self.map(lambda x: x[valid_mask])\n\n\n@dataclass\nclass StaticObservations(TensorDataclass):\n    xyz: torch.Tensor\n    normals: torch.Tensor\n    colors: torch.Tensor\n\n    def check_sizes(self) -> bool:\n        dims = self.xyz.shape\n        return self.normals.shape == dims and self.colors.shape == dims\n\n    def filter_valid(self, valid_mask: torch.Tensor) -> Self:\n        return self.map(lambda x: x[valid_mask])\n"
  },
  {
    "path": "mvtracker/models/core/shape-of-motion/flow3d/trainer.py",
    "content": "import functools\nimport time\nfrom dataclasses import asdict\nfrom typing import cast\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom loguru import logger as guru\nfrom nerfview import CameraState\nfrom pytorch_msssim import SSIM\nfrom torch.utils.tensorboard import SummaryWriter  # type: ignore\n\nfrom flow3d.configs import LossesConfig, OptimizerConfig, SceneLRConfig\nfrom flow3d.loss_utils import (\n    compute_gradient_loss,\n    compute_se3_smoothness_loss,\n    compute_z_acc_loss,\n    masked_l1_loss,\n)\nfrom flow3d.metrics import PCK, mLPIPS, mPSNR, mSSIM\nfrom flow3d.scene_model import SceneModel\nfrom flow3d.vis.utils import get_server\nfrom flow3d.vis.viewer import DynamicViewer\n\n\nclass Trainer:\n    def __init__(\n        self,\n        model: SceneModel,\n        device: torch.device,\n        lr_cfg: SceneLRConfig,\n        losses_cfg: LossesConfig,\n        optim_cfg: OptimizerConfig,\n        # Logging.\n        work_dir: str,\n        port: int | None = None,\n        log_every: int = 10,\n        checkpoint_every: int = 200,\n        validate_every: int = 500,\n        validate_video_every: int = 1000,\n        validate_viewer_assets_every: int = 100,\n    ):\n        self.device = device\n        self.log_every = log_every\n        self.checkpoint_every = checkpoint_every\n        self.validate_every = validate_every\n        self.validate_video_every = validate_video_every\n        self.validate_viewer_assets_every = validate_viewer_assets_every\n\n        self.model = model\n        self.num_frames = model.num_frames\n\n        self.lr_cfg = lr_cfg\n        self.losses_cfg = losses_cfg\n        self.optim_cfg = optim_cfg\n\n        self.reset_opacity_every = (\n            self.optim_cfg.reset_opacity_every_n_controls * self.optim_cfg.control_every\n        )\n        self.optimizers, self.scheduler = self.configure_optimizers()\n\n        # running stats for adaptive density control\n        self.running_stats = {\n            \"xys_grad_norm_acc\": torch.zeros(self.model.num_gaussians, device=device),\n            \"vis_count\": torch.zeros(\n                self.model.num_gaussians, device=device, dtype=torch.int64\n            ),\n            \"max_radii\": torch.zeros(self.model.num_gaussians, device=device),\n        }\n\n        self.work_dir = work_dir\n        self.writer = SummaryWriter(log_dir=work_dir)\n        self.global_step = 0\n        self.epoch = 0\n\n        self.viewer = None\n        if port is not None:\n            server = get_server(port=port)\n            self.viewer = DynamicViewer(\n                server, self.render_fn, model.num_frames, work_dir, mode=\"training\"\n            )\n\n        # metrics\n        self.ssim = SSIM(data_range=1.0, size_average=True, channel=3)\n        self.psnr_metric = mPSNR()\n        self.ssim_metric = mSSIM()\n        self.lpips_metric = mLPIPS()\n        self.pck_metric = PCK()\n        self.bg_psnr_metric = mPSNR()\n        self.fg_psnr_metric = mPSNR()\n        self.bg_ssim_metric = mSSIM()\n        self.fg_ssim_metric = mSSIM()\n        self.bg_lpips_metric = mLPIPS()\n        self.fg_lpips_metric = mLPIPS()\n\n    def set_epoch(self, epoch: int):\n        self.epoch = epoch\n\n    def save_checkpoint(self, path: str):\n        model_dict = self.model.state_dict()\n        optimizer_dict = {k: v.state_dict() for k, v in self.optimizers.items()}\n        scheduler_dict = {k: v.state_dict() for k, v in self.scheduler.items()}\n        ckpt = {\n            
\"model\": model_dict,\n            \"optimizers\": optimizer_dict,\n            \"schedulers\": scheduler_dict,\n            \"global_step\": self.global_step,\n            \"epoch\": self.epoch,\n        }\n        torch.save(ckpt, path)\n        guru.info(f\"Saved checkpoint at {self.global_step=} to {path}\")\n\n    @staticmethod\n    def init_from_checkpoint(\n        path: str, device: torch.device, *args, **kwargs\n    ) -> tuple[\"Trainer\", int]:\n        guru.info(f\"Loading checkpoint from {path}\")\n        ckpt = torch.load(path)\n        state_dict = ckpt[\"model\"]\n        model = SceneModel.init_from_state_dict(state_dict)\n        model = model.to(device)\n        trainer = Trainer(model, device, *args, **kwargs)\n        if \"optimizers\" in ckpt:\n            trainer.load_checkpoint_optimizers(ckpt[\"optimizers\"])\n        if \"schedulers\" in ckpt:\n            trainer.load_checkpoint_schedulers(ckpt[\"schedulers\"])\n        trainer.global_step = ckpt.get(\"global_step\", 0)\n        start_epoch = ckpt.get(\"epoch\", 0)\n        trainer.set_epoch(start_epoch)\n        return trainer, start_epoch\n\n    def load_checkpoint_optimizers(self, opt_ckpt):\n        for k, v in self.optimizers.items():\n            v.load_state_dict(opt_ckpt[k])\n\n    def load_checkpoint_schedulers(self, sched_ckpt):\n        for k, v in self.scheduler.items():\n            v.load_state_dict(sched_ckpt[k])\n\n    @torch.inference_mode()\n    def render_fn(self, camera_state: CameraState, img_wh: tuple[int, int]):\n        W, H = img_wh\n\n        focal = 0.5 * H / np.tan(0.5 * camera_state.fov).item()\n        K = torch.tensor(\n            [[focal, 0.0, W / 2.0], [0.0, focal, H / 2.0], [0.0, 0.0, 1.0]],\n            device=self.device,\n        )\n        w2c = torch.linalg.inv(\n            torch.from_numpy(camera_state.c2w.astype(np.float32)).to(self.device)\n        )\n        t = 0\n        if self.viewer is not None:\n            t = (\n                int(self.viewer._playback_guis[0].value)\n                if not self.viewer._canonical_checkbox.value\n                else None\n            )\n        self.model.training = False\n        img = self.model.render(t, w2c[None], K[None], img_wh)[\"img\"][0]\n        return (img.cpu().numpy() * 255.0).astype(np.uint8)\n\n    def train_step(self, batch):\n        if self.viewer is not None:\n            while self.viewer.state.status == \"paused\":\n                time.sleep(0.1)\n            self.viewer.lock.acquire()\n\n        multi_loss = 0.0\n\n        # import ipdb\n        # ipdb.set_trace()\n\n        for view_index in [0, 1, 2, 3]:\n            view_data = batch[view_index]\n\n            loss, stats, num_rays_per_step, num_rays_per_sec = self.compute_losses(view_data)\n            if loss.isnan():\n                guru.info(f\"Loss is NaN at step {self.global_step}!!\")\n                import ipdb\n                ipdb.set_trace()\n\n            multi_loss += loss / 4\n\n        multi_loss.backward()\n        # loss.backward()\n\n        for opt in self.optimizers.values():\n            opt.step()\n            opt.zero_grad(set_to_none=True)\n        for sched in self.scheduler.values():\n            sched.step()\n\n        self.log_dict(stats)\n        self.global_step += 1\n        self.run_control_steps()\n\n        if self.viewer is not None:\n            self.viewer.lock.release()\n            self.viewer.state.num_train_rays_per_sec = num_rays_per_sec\n            if self.viewer.mode == \"training\":\n                
self.viewer.update(self.global_step, num_rays_per_step)\n\n        if self.global_step % self.checkpoint_every == 0:\n            self.save_checkpoint(f\"{self.work_dir}/checkpoints/last.ckpt\")\n\n        # return loss.item()\n        return multi_loss.item()\n\n    def compute_losses(self, batch):\n        self.model.training = True\n        B = batch[\"imgs\"].shape[0]\n        W, H = img_wh = batch[\"imgs\"].shape[2:0:-1]\n        N = batch[\"target_ts\"][0].shape[0]\n\n        # (B,).\n        ts = batch[\"ts\"]\n        # (B, 4, 4).\n        w2cs = batch[\"w2cs\"]\n        # (B, 3, 3).\n        Ks = batch[\"Ks\"]\n        # (B, H, W, 3).\n        imgs = batch[\"imgs\"]\n        # (B, H, W).\n        valid_masks = batch.get(\"valid_masks\", torch.ones_like(batch[\"imgs\"][..., 0]))\n        # (B, H, W).\n        masks = batch[\"masks\"]\n        masks *= valid_masks\n        # (B, H, W).\n        depths = batch[\"depths\"]\n        # [(P, 2), ...].\n        query_tracks_2d = batch[\"query_tracks_2d\"]\n        # [(N,), ...].\n        target_ts = batch[\"target_ts\"]\n        # [(N, 4, 4), ...].\n        target_w2cs = batch[\"target_w2cs\"]\n        # [(N, 3, 3), ...].\n        target_Ks = batch[\"target_Ks\"]\n        # [(N, P, 2), ...].\n        target_tracks_2d = batch[\"target_tracks_2d\"]\n        # [(N, P), ...].\n        target_visibles = batch[\"target_visibles\"]\n        # [(N, P), ...].\n        target_invisibles = batch[\"target_invisibles\"]\n        # [(N, P), ...].\n        target_confidences = batch[\"target_confidences\"]\n        # [(N, P), ...].\n        target_track_depths = batch[\"target_track_depths\"]\n\n        _tic = time.time()\n        # (B, G, 3).\n        means, quats = self.model.compute_poses_all(ts)  # (G, B, 3), (G, B, 4)\n        device = means.device\n        means = means.transpose(0, 1)\n        quats = quats.transpose(0, 1)\n        # [(N, G, 3), ...].\n        target_ts_vec = torch.cat(target_ts)\n        # (B * N, G, 3).\n        target_means, _ = self.model.compute_poses_all(target_ts_vec)\n        target_means = target_means.transpose(0, 1)\n        target_mean_list = target_means.split(N)\n        num_frames = self.model.num_frames\n\n        loss = 0.0\n\n        bg_colors = []\n        rendered_all = []\n        self._batched_xys = []\n        self._batched_radii = []\n        self._batched_img_wh = []\n        for i in range(B):\n            bg_color = torch.ones(1, 3, device=device)\n\n\n\n            # import ipdb\n            # ipdb.set_trace()\n\n            \n            rendered = self.model.render(\n                ts[i].item(),\n                w2cs[None, i],\n                Ks[None, i],\n                img_wh,\n                target_ts=target_ts[i],\n                target_w2cs=target_w2cs[i],\n                bg_color=bg_color,\n                means=means[i],\n                quats=quats[i],\n                target_means=target_mean_list[i].transpose(0, 1),\n                return_depth=True,\n                return_mask=self.model.has_bg,\n            )\n            rendered_all.append(rendered)\n            bg_colors.append(bg_color)\n            if (\n                self.model._current_xys is not None\n                and self.model._current_radii is not None\n                and self.model._current_img_wh is not None\n            ):\n                self._batched_xys.append(self.model._current_xys)\n                self._batched_radii.append(self.model._current_radii)\n                
self._batched_img_wh.append(self.model._current_img_wh)\n\n        # Necessary to make viewer work.\n        num_rays_per_step = H * W * B\n        num_rays_per_sec = num_rays_per_step / (time.time() - _tic)\n\n        # (B, H, W, N, *).\n        rendered_all = {\n            key: (\n                torch.cat([out_dict[key] for out_dict in rendered_all], dim=0)\n                if rendered_all[0][key] is not None\n                else None\n            )\n            for key in rendered_all[0]\n        }\n        bg_colors = torch.cat(bg_colors, dim=0)\n\n        # Compute losses.\n        # (B * N).\n        frame_intervals = (ts.repeat_interleave(N) - target_ts_vec).abs()\n        if not self.model.has_bg:\n            imgs = (\n                imgs * masks[..., None]\n                + (1.0 - masks[..., None]) * bg_colors[:, None, None]\n            )\n        else:\n            imgs = (\n                imgs * valid_masks[..., None]\n                + (1.0 - valid_masks[..., None]) * bg_colors[:, None, None]\n            )\n        # (P_all, 2).\n        tracks_2d = torch.cat([x.reshape(-1, 2) for x in target_tracks_2d], dim=0)\n        # (P_all,)\n        visibles = torch.cat([x.reshape(-1) for x in target_visibles], dim=0)\n        # (P_all,)\n        confidences = torch.cat([x.reshape(-1) for x in target_confidences], dim=0)\n\n        # RGB loss.\n        rendered_imgs = cast(torch.Tensor, rendered_all[\"img\"])\n        if self.model.has_bg:\n            rendered_imgs = (\n                rendered_imgs * valid_masks[..., None]\n                + (1.0 - valid_masks[..., None]) * bg_colors[:, None, None]\n            )\n\n\n\n        # import cv2\n        # print(imgs[0].shape)\n        # print(imgs[0].max())\n        # cv2.imwrite(\"/cluster/scratch/egundogdu/rendered_image.jpg\", ((rendered_imgs[0]*255.0).cpu().detach().numpy()).astype(np.uint8))\n\n        # if True:\n        #     import ipdb\n        #     ipdb.set_trace()\n\n\n\n        rgb_loss = 0.8 * F.l1_loss(rendered_imgs, imgs) + 0.2 * (\n            1 - self.ssim(rendered_imgs.permute(0, 3, 1, 2), imgs.permute(0, 3, 1, 2))\n        )\n        loss += rgb_loss * self.losses_cfg.w_rgb\n\n        # Mask loss.\n        if not self.model.has_bg:\n            mask_loss = F.mse_loss(rendered_all[\"acc\"], masks[..., None])  # type: ignore\n        else:\n            mask_loss = F.mse_loss(\n                rendered_all[\"acc\"], torch.ones_like(rendered_all[\"acc\"])  # type: ignore\n            ) + masked_l1_loss(\n                rendered_all[\"mask\"],\n                masks[..., None],\n                quantile=0.98,  # type: ignore\n            )\n        loss += mask_loss * self.losses_cfg.w_mask\n\n        # (B * N, H * W, 3).\n        pred_tracks_3d = (\n            rendered_all[\"tracks_3d\"].permute(0, 3, 1, 2, 4).reshape(-1, H * W, 3)  # type: ignore\n        )\n        pred_tracks_2d = torch.einsum(\n            \"bij,bpj->bpi\", torch.cat(target_Ks), pred_tracks_3d\n        )\n        # (B * N, H * W, 1).\n        mapped_depth = torch.clamp(pred_tracks_2d[..., 2:], min=1e-6)\n        # (B * N, H * W, 2).\n        pred_tracks_2d = pred_tracks_2d[..., :2] / mapped_depth\n\n        # (B * N).\n        w_interval = torch.exp(-2 * frame_intervals / num_frames)\n        # w_track_loss = min(1, (self.max_steps - self.global_step) / 6000)\n        track_weights = confidences[..., None] * w_interval\n\n        # (B, H, W).\n        masks_flatten = torch.zeros_like(masks)\n        for i in range(B):\n            # This 
takes advantage of the fact that the query 2D tracks are\n            # always on the grid.\n            query_pixels = query_tracks_2d[i].to(torch.int64)\n            masks_flatten[i, query_pixels[:, 1], query_pixels[:, 0]] = 1.0\n        # (B * N, H * W).\n        masks_flatten = (\n            masks_flatten.reshape(-1, H * W).tile(1, N).reshape(-1, H * W) > 0.5\n        )\n\n        track_2d_loss = masked_l1_loss(\n            pred_tracks_2d[masks_flatten][visibles],\n            tracks_2d[visibles],\n            mask=track_weights[visibles],\n            quantile=0.98,\n        ) / max(H, W)\n        loss += track_2d_loss * self.losses_cfg.w_track\n\n        depth_masks = (\n            masks[..., None] if not self.model.has_bg else valid_masks[..., None]\n        )\n\n        pred_depth = cast(torch.Tensor, rendered_all[\"depth\"])\n        pred_disp = 1.0 / (pred_depth + 1e-5)\n        tgt_disp = 1.0 / (depths[..., None] + 1e-5)\n        depth_loss = masked_l1_loss(\n            pred_disp,\n            tgt_disp,\n            mask=depth_masks,\n            quantile=0.98,\n        )\n        # depth_loss = cauchy_loss_with_uncertainty(\n        #     pred_disp.squeeze(-1),\n        #     tgt_disp.squeeze(-1),\n        #     depth_masks.squeeze(-1),\n        #     self.depth_uncertainty_activation(self.depth_uncertainties)[ts],\n        #     bias=1e-3,\n        # )\n        loss += depth_loss * self.losses_cfg.w_depth_reg\n\n        # mapped depth loss (using cached depth with EMA)\n        #  mapped_depth_loss = 0.0\n        mapped_depth_gt = torch.cat([x.reshape(-1) for x in target_track_depths], dim=0)\n        mapped_depth_loss = masked_l1_loss(\n            1 / (mapped_depth[masks_flatten][visibles] + 1e-5),\n            1 / (mapped_depth_gt[visibles, None] + 1e-5),\n            track_weights[visibles],\n        )\n\n        loss += mapped_depth_loss * self.losses_cfg.w_depth_const\n\n        #  depth_gradient_loss = 0.0\n        depth_gradient_loss = compute_gradient_loss(\n            pred_disp,\n            tgt_disp,\n            mask=depth_masks > 0.5,\n            quantile=0.95,\n        )\n        # depth_gradient_loss = compute_gradient_loss(\n        #     pred_disps,\n        #     ref_disps,\n        #     mask=depth_masks.squeeze(-1) > 0.5,\n        #     c=depth_uncertainty.detach(),\n        #     mode=\"l1\",\n        #     bias=1e-3,\n        # )\n        loss += depth_gradient_loss * self.losses_cfg.w_depth_grad\n\n        # bases should be smooth.\n        small_accel_loss = compute_se3_smoothness_loss(\n            self.model.motion_bases.params[\"rots\"],\n            self.model.motion_bases.params[\"transls\"],\n        )\n        loss += small_accel_loss * self.losses_cfg.w_smooth_bases\n\n        # tracks should be smooth\n        ts = torch.clamp(ts, min=1, max=num_frames - 2)\n        ts_neighbors = torch.cat((ts - 1, ts, ts + 1))\n        transfms_nbs = self.model.compute_transforms(ts_neighbors)  # (G, 3n, 3, 4)\n        means_fg_nbs = torch.einsum(\n            \"pnij,pj->pni\",\n            transfms_nbs,\n            F.pad(self.model.fg.params[\"means\"], (0, 1), value=1.0),\n        )\n        means_fg_nbs = means_fg_nbs.reshape(\n            means_fg_nbs.shape[0], 3, -1, 3\n        )  # [G, 3, n, 3]\n        if self.losses_cfg.w_smooth_tracks > 0:\n            small_accel_loss_tracks = 0.5 * (\n                (2 * means_fg_nbs[:, 1:-1] - means_fg_nbs[:, :-2] - means_fg_nbs[:, 2:])\n                .norm(dim=-1)\n                .mean()\n            
)\n            loss += small_accel_loss_tracks * self.losses_cfg.w_smooth_tracks\n\n        # Constrain the std of scales.\n        # TODO: do we want to penalize before or after exp?\n        loss += (\n            self.losses_cfg.w_scale_var\n            * torch.var(self.model.fg.params[\"scales\"], dim=-1).mean()\n        )\n        if self.model.bg is not None:\n            loss += (\n                self.losses_cfg.w_scale_var\n                * torch.var(self.model.bg.params[\"scales\"], dim=-1).mean()\n            )\n\n        # # sparsity loss\n        # loss += 0.01 * self.opacity_activation(self.opacities).abs().mean()\n\n        # Acceleration along ray direction should be small.\n        z_accel_loss = compute_z_acc_loss(means_fg_nbs, w2cs)\n        loss += self.losses_cfg.w_z_accel * z_accel_loss\n\n        # Prepare stats for logging.\n        stats = {\n            \"train/loss\": loss.item(),\n            \"train/rgb_loss\": rgb_loss.item(),\n            \"train/mask_loss\": mask_loss.item(),\n            \"train/depth_loss\": depth_loss.item(),\n            \"train/depth_gradient_loss\": depth_gradient_loss.item(),\n            \"train/mapped_depth_loss\": mapped_depth_loss.item(),\n            \"train/track_2d_loss\": track_2d_loss.item(),\n            \"train/small_accel_loss\": small_accel_loss.item(),\n            \"train/z_acc_loss\": z_accel_loss.item(),\n            \"train/num_gaussians\": self.model.num_gaussians,\n            \"train/num_fg_gaussians\": self.model.num_fg_gaussians,\n            \"train/num_bg_gaussians\": self.model.num_bg_gaussians,\n        }\n\n        # Compute metrics.\n        with torch.no_grad():\n            psnr = self.psnr_metric(\n                rendered_imgs, imgs, masks if not self.model.has_bg else valid_masks\n            )\n            self.psnr_metric.reset()\n            stats[\"train/psnr\"] = psnr\n            if self.model.has_bg:\n                bg_psnr = self.bg_psnr_metric(rendered_imgs, imgs, 1.0 - masks)\n                fg_psnr = self.fg_psnr_metric(rendered_imgs, imgs, masks)\n                self.bg_psnr_metric.reset()\n                self.fg_psnr_metric.reset()\n                stats[\"train/bg_psnr\"] = bg_psnr\n                stats[\"train/fg_psnr\"] = fg_psnr\n\n        stats.update(\n            **{\n                \"train/num_rays_per_sec\": num_rays_per_sec,\n                \"train/num_rays_per_step\": float(num_rays_per_step),\n            }\n        )\n\n        # print(stats)\n\n        return loss, stats, num_rays_per_step, num_rays_per_sec\n\n    def log_dict(self, stats: dict):\n        for k, v in stats.items():\n            self.writer.add_scalar(k, v, self.global_step)\n\n    def run_control_steps(self):\n        global_step = self.global_step\n        # Adaptive gaussian control.\n        cfg = self.optim_cfg\n        num_frames = self.model.num_frames\n        ready = self._prepare_control_step()\n        if (\n            ready\n            and global_step > cfg.warmup_steps\n            and global_step % cfg.control_every == 0\n            and global_step < cfg.stop_control_steps\n        ):\n            if (\n                global_step < cfg.stop_densify_steps\n                and global_step % self.reset_opacity_every > num_frames\n            ):\n                self._densify_control_step(global_step)\n            if global_step % self.reset_opacity_every > min(3 * num_frames, 1000):\n                self._cull_control_step(global_step)\n            if global_step % 
self.reset_opacity_every == 0:\n                self._reset_opacity_control_step()\n\n            # Reset stats after every control.\n            for k in self.running_stats:\n                self.running_stats[k].zero_()\n\n    @torch.no_grad()\n    def _prepare_control_step(self) -> bool:\n        # Prepare for adaptive gaussian control based on the current stats.\n        if not (\n            self.model._current_radii is not None\n            and self.model._current_xys is not None\n        ):\n            guru.warning(\"Model not training, skipping control step preparation\")\n            return False\n\n        batch_size = len(self._batched_xys)\n        # these quantities are for each rendered view and have shapes (C, G, *)\n        # must be aggregated over all views\n        for _current_xys, _current_radii, _current_img_wh in zip(\n            self._batched_xys, self._batched_radii, self._batched_img_wh\n        ):\n            sel = _current_radii > 0\n            gidcs = torch.where(sel)[1]\n            # normalize grads to [-1, 1] screen space\n            xys_grad = _current_xys.grad.clone()\n            xys_grad[..., 0] *= _current_img_wh[0] / 2.0 * batch_size\n            xys_grad[..., 1] *= _current_img_wh[1] / 2.0 * batch_size\n            self.running_stats[\"xys_grad_norm_acc\"].index_add_(\n                0, gidcs, xys_grad[sel].norm(dim=-1)\n            )\n            self.running_stats[\"vis_count\"].index_add_(\n                0, gidcs, torch.ones_like(gidcs, dtype=torch.int64)\n            )\n            max_radii = torch.maximum(\n                self.running_stats[\"max_radii\"].index_select(0, gidcs),\n                _current_radii[sel] / max(_current_img_wh),\n            )\n            self.running_stats[\"max_radii\"].index_put((gidcs,), max_radii)\n        return True\n\n    @torch.no_grad()\n    def _densify_control_step(self, global_step):\n        assert (self.running_stats[\"vis_count\"] > 0).any()\n\n        cfg = self.optim_cfg\n        xys_grad_avg = self.running_stats[\"xys_grad_norm_acc\"] / self.running_stats[\n            \"vis_count\"\n        ].clamp_min(1)\n        is_grad_too_high = xys_grad_avg > cfg.densify_xys_grad_threshold\n        # Split gaussians.\n        scales = self.model.get_scales_all()\n        is_scale_too_big = scales.amax(dim=-1) > cfg.densify_scale_threshold\n        if global_step < cfg.stop_control_by_screen_steps:\n            is_radius_too_big = (\n                self.running_stats[\"max_radii\"] > cfg.densify_screen_threshold\n            )\n        else:\n            is_radius_too_big = torch.zeros_like(is_grad_too_high, dtype=torch.bool)\n\n        should_split = is_grad_too_high & (is_scale_too_big | is_radius_too_big)\n        should_dup = is_grad_too_high & ~is_scale_too_big\n\n        num_fg = self.model.num_fg_gaussians\n        should_fg_split = should_split[:num_fg]\n        num_fg_splits = int(should_fg_split.sum().item())\n        should_fg_dup = should_dup[:num_fg]\n        num_fg_dups = int(should_fg_dup.sum().item())\n\n        should_bg_split = should_split[num_fg:]\n        num_bg_splits = int(should_bg_split.sum().item())\n        should_bg_dup = should_dup[num_fg:]\n        num_bg_dups = int(should_bg_dup.sum().item())\n\n        fg_param_map = self.model.fg.densify_params(should_fg_split, should_fg_dup)\n        for param_name, new_params in fg_param_map.items():\n            full_param_name = f\"fg.params.{param_name}\"\n            optimizer = self.optimizers[full_param_name]\n            
dup_in_optim(\n                optimizer,\n                [new_params],\n                should_fg_split,\n                num_fg_splits * 2 + num_fg_dups,\n            )\n\n        if self.model.bg is not None:\n            bg_param_map = self.model.bg.densify_params(should_bg_split, should_bg_dup)\n            for param_name, new_params in bg_param_map.items():\n                full_param_name = f\"bg.params.{param_name}\"\n                optimizer = self.optimizers[full_param_name]\n                dup_in_optim(\n                    optimizer,\n                    [new_params],\n                    should_bg_split,\n                    num_bg_splits * 2 + num_bg_dups,\n                )\n\n        # update running stats\n        for k, v in self.running_stats.items():\n            v_fg, v_bg = v[:num_fg], v[num_fg:]\n            new_v = torch.cat(\n                [\n                    v_fg[~should_fg_split],\n                    v_fg[should_fg_dup],\n                    v_fg[should_fg_split].repeat(2),\n                    v_bg[~should_bg_split],\n                    v_bg[should_bg_dup],\n                    v_bg[should_bg_split].repeat(2),\n                ],\n                dim=0,\n            )\n            self.running_stats[k] = new_v\n        guru.info(\n            f\"Split {should_split.sum().item()} gaussians, \"\n            f\"Duplicated {should_dup.sum().item()} gaussians, \"\n            f\"{self.model.num_gaussians} gaussians left\"\n        )\n\n    @torch.no_grad()\n    def _cull_control_step(self, global_step):\n        # Cull gaussians.\n        cfg = self.optim_cfg\n        opacities = self.model.get_opacities_all()\n        device = opacities.device\n        is_opacity_too_small = opacities < cfg.cull_opacity_threshold\n        is_radius_too_big = torch.zeros_like(is_opacity_too_small, dtype=torch.bool)\n        is_scale_too_big = torch.zeros_like(is_opacity_too_small, dtype=torch.bool)\n        cull_scale_threshold = (\n            torch.ones(len(is_scale_too_big), device=device) * cfg.cull_scale_threshold\n        )\n        num_fg = self.model.num_fg_gaussians\n        cull_scale_threshold[num_fg:] *= self.model.bg_scene_scale\n        if global_step > self.reset_opacity_every:\n            scales = self.model.get_scales_all()\n            is_scale_too_big = scales.amax(dim=-1) > cull_scale_threshold\n            if global_step < cfg.stop_control_by_screen_steps:\n                is_radius_too_big = (\n                    self.running_stats[\"max_radii\"] > cfg.cull_screen_threshold\n                )\n        should_cull = is_opacity_too_small | is_radius_too_big | is_scale_too_big\n        should_fg_cull = should_cull[:num_fg]\n        should_bg_cull = should_cull[num_fg:]\n\n        fg_param_map = self.model.fg.cull_params(should_fg_cull)\n        for param_name, new_params in fg_param_map.items():\n            full_param_name = f\"fg.params.{param_name}\"\n            optimizer = self.optimizers[full_param_name]\n            remove_from_optim(optimizer, [new_params], should_fg_cull)\n\n        if self.model.bg is not None:\n            bg_param_map = self.model.bg.cull_params(should_bg_cull)\n            for param_name, new_params in bg_param_map.items():\n                full_param_name = f\"bg.params.{param_name}\"\n                optimizer = self.optimizers[full_param_name]\n                remove_from_optim(optimizer, [new_params], should_bg_cull)\n\n        # update running stats\n        for k, v in self.running_stats.items():\n            
self.running_stats[k] = v[~should_cull]\n\n        guru.info(\n            f\"Culled {should_cull.sum().item()} gaussians, \"\n            f\"{self.model.num_gaussians} gaussians left\"\n        )\n\n    @torch.no_grad()\n    def _reset_opacity_control_step(self):\n        # Reset gaussian opacities.\n        new_val = torch.logit(torch.tensor(0.8 * self.optim_cfg.cull_opacity_threshold))\n        for part in [\"fg\", \"bg\"]:\n            part_params = getattr(self.model, part).reset_opacities(new_val)\n            # Modify optimizer states by new assignment.\n            for param_name, new_params in part_params.items():\n                full_param_name = f\"{part}.params.{param_name}\"\n                optimizer = self.optimizers[full_param_name]\n                reset_in_optim(optimizer, [new_params])\n        guru.info(\"Reset opacities\")\n\n    def configure_optimizers(self):\n        def _exponential_decay(step, *, lr_init, lr_final):\n            t = np.clip(step / self.optim_cfg.max_steps, 0.0, 1.0)\n            lr = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)\n            return lr / lr_init\n\n        lr_dict = asdict(self.lr_cfg)\n        optimizers = {}\n        schedulers = {}\n        # named parameters will be [part].params.[field]\n        # e.g. fg.params.means\n        # lr config is a nested dict for each fg/bg part\n        for name, params in self.model.named_parameters():\n            part, _, field = name.split(\".\")\n            lr = lr_dict[part][field]\n            optim = torch.optim.Adam([{\"params\": params, \"lr\": lr, \"name\": name}])\n\n            if \"scales\" in name:\n                fnc = functools.partial(_exponential_decay, lr_final=0.1 * lr)\n            else:\n                fnc = lambda _, **__: 1.0\n\n            optimizers[name] = optim\n            schedulers[name] = torch.optim.lr_scheduler.LambdaLR(\n                optim, functools.partial(fnc, lr_init=lr)\n            )\n        return optimizers, schedulers\n\n\ndef dup_in_optim(optimizer, new_params: list, should_dup: torch.Tensor, num_dups: int):\n    assert len(optimizer.param_groups) == len(new_params)\n    for i, p_new in enumerate(new_params):\n        old_params = optimizer.param_groups[i][\"params\"][0]\n        param_state = optimizer.state[old_params]\n        if len(param_state) == 0:\n            return\n        for key in param_state:\n            if key == \"step\":\n                continue\n            p = param_state[key]\n            param_state[key] = torch.cat(\n                [p[~should_dup], p.new_zeros(num_dups, *p.shape[1:])],\n                dim=0,\n            )\n        del optimizer.state[old_params]\n        optimizer.state[p_new] = param_state\n        optimizer.param_groups[i][\"params\"] = [p_new]\n        del old_params\n        torch.cuda.empty_cache()\n\n\ndef remove_from_optim(optimizer, new_params: list, _should_cull: torch.Tensor):\n    assert len(optimizer.param_groups) == len(new_params)\n    for i, p_new in enumerate(new_params):\n        old_params = optimizer.param_groups[i][\"params\"][0]\n        param_state = optimizer.state[old_params]\n        if len(param_state) == 0:\n            return\n        for key in param_state:\n            if key == \"step\":\n                continue\n            param_state[key] = param_state[key][~_should_cull]\n        del optimizer.state[old_params]\n        optimizer.state[p_new] = param_state\n        optimizer.param_groups[i][\"params\"] = [p_new]\n        del old_params\n        
torch.cuda.empty_cache()\n\n\ndef reset_in_optim(optimizer, new_params: list):\n    assert len(optimizer.param_groups) == len(new_params)\n    for i, p_new in enumerate(new_params):\n        old_params = optimizer.param_groups[i][\"params\"][0]\n        param_state = optimizer.state[old_params]\n        if len(param_state) == 0:\n            return\n        for key in param_state:\n            param_state[key] = torch.zeros_like(param_state[key])\n        del optimizer.state[old_params]\n        optimizer.state[p_new] = param_state\n        optimizer.param_groups[i][\"params\"] = [p_new]\n        del old_params\n        torch.cuda.empty_cache()\n"
  },
  {
    "path": "mvtracker/models/core/shape-of-motion/flow3d/trajectories.py",
    "content": "import numpy as np\nimport roma\nimport torch\nimport torch.nn.functional as F\n\nfrom .transforms import rt_to_mat4\n\n\ndef get_avg_w2c(w2cs: torch.Tensor):\n    c2ws = torch.linalg.inv(w2cs)\n    # 1. Compute the center\n    center = c2ws[:, :3, -1].mean(0)\n    # 2. Compute the z axis\n    z = F.normalize(c2ws[:, :3, 2].mean(0), dim=-1)\n    # 3. Compute axis y' (no need to normalize as it's not the final output)\n    y_ = c2ws[:, :3, 1].mean(0)  # (3)\n    # 4. Compute the x axis\n    x = F.normalize(torch.cross(y_, z, dim=-1), dim=-1)  # (3)\n    # 5. Compute the y axis (as z and x are normalized, y is already of norm 1)\n    y = torch.cross(z, x, dim=-1)  # (3)\n    avg_c2w = rt_to_mat4(torch.stack([x, y, z], 1), center)\n    avg_w2c = torch.linalg.inv(avg_c2w)\n    return avg_w2c\n\n\ndef get_lookat(origins: torch.Tensor, viewdirs: torch.Tensor) -> torch.Tensor:\n    \"\"\"Triangulate a set of rays to find a single lookat point.\n\n    Args:\n        origins (torch.Tensor): A (N, 3) array of ray origins.\n        viewdirs (torch.Tensor): A (N, 3) array of ray view directions.\n\n    Returns:\n        torch.Tensor: A (3,) lookat point.\n    \"\"\"\n\n    viewdirs = torch.nn.functional.normalize(viewdirs, dim=-1)\n    eye = torch.eye(3, device=origins.device, dtype=origins.dtype)[None]\n    # Calculate projection matrix I - rr^T\n    I_min_cov = eye - (viewdirs[..., None] * viewdirs[..., None, :])\n    # Compute sum of projections\n    sum_proj = I_min_cov.matmul(origins[..., None]).sum(dim=-3)\n    # Solve for the intersection point using least squares\n    lookat = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]\n    # Check NaNs.\n    assert not torch.any(torch.isnan(lookat))\n    return lookat\n\n\ndef get_lookat_w2cs(positions: torch.Tensor, lookat: torch.Tensor, up: torch.Tensor):\n    \"\"\"\n    Args:\n        positions: (N, 3) tensor of camera positions\n        lookat: (3,) tensor of lookat point\n        up: (3,) tensor of up vector\n\n    Returns:\n        w2cs: (N, 3, 3) tensor of world to camera rotation matrices\n    \"\"\"\n    forward_vectors = F.normalize(lookat - positions, dim=-1)\n    right_vectors = F.normalize(torch.cross(forward_vectors, up[None], dim=-1), dim=-1)\n    down_vectors = F.normalize(\n        torch.cross(forward_vectors, right_vectors, dim=-1), dim=-1\n    )\n    Rs = torch.stack([right_vectors, down_vectors, forward_vectors], dim=-1)\n    w2cs = torch.linalg.inv(rt_to_mat4(Rs, positions))\n    return w2cs\n\n\ndef get_arc_w2cs(\n    ref_w2c: torch.Tensor,\n    lookat: torch.Tensor,\n    up: torch.Tensor,\n    num_frames: int,\n    degree: float,\n    **_,\n) -> torch.Tensor:\n    ref_position = torch.linalg.inv(ref_w2c)[:3, 3]\n    thetas = (\n        torch.sin(\n            torch.linspace(0.0, torch.pi * 2.0, num_frames + 1, device=ref_w2c.device)[\n                :-1\n            ]\n        )\n        * (degree / 2.0)\n        / 180.0\n        * torch.pi\n    )\n    positions = torch.einsum(\n        \"nij,j->ni\",\n        roma.rotvec_to_rotmat(thetas[:, None] * up[None]),\n        ref_position - lookat,\n    )\n    return get_lookat_w2cs(positions, lookat, up)\n\n\ndef get_lemniscate_w2cs(\n    ref_w2c: torch.Tensor,\n    lookat: torch.Tensor,\n    up: torch.Tensor,\n    num_frames: int,\n    degree: float,\n    **_,\n) -> torch.Tensor:\n    ref_c2w = torch.linalg.inv(ref_w2c)\n    a = torch.linalg.norm(ref_c2w[:3, 3] - lookat) * np.tan(degree / 360 * np.pi)\n    # Lemniscate curve in camera space. 
Starting at the origin.\n    thetas = (\n        torch.linspace(0, 2 * torch.pi, num_frames + 1, device=ref_w2c.device)[:-1]\n        + torch.pi / 2\n    )\n    positions = torch.stack(\n        [\n            a * torch.cos(thetas) / (1 + torch.sin(thetas) ** 2),\n            a * torch.cos(thetas) * torch.sin(thetas) / (1 + torch.sin(thetas) ** 2),\n            torch.zeros(num_frames, device=ref_w2c.device),\n        ],\n        dim=-1,\n    )\n    # Transform to world space.\n    positions = torch.einsum(\n        \"ij,nj->ni\", ref_c2w[:3], F.pad(positions, (0, 1), value=1.0)\n    )\n    return get_lookat_w2cs(positions, lookat, up)\n\n\ndef get_spiral_w2cs(\n    ref_w2c: torch.Tensor,\n    lookat: torch.Tensor,\n    up: torch.Tensor,\n    num_frames: int,\n    rads: float | torch.Tensor,\n    zrate: float,\n    rots: int,\n    **_,\n) -> torch.Tensor:\n    ref_c2w = torch.linalg.inv(ref_w2c)\n    thetas = torch.linspace(\n        0, 2 * torch.pi * rots, num_frames + 1, device=ref_w2c.device\n    )[:-1]\n    # Spiral curve in camera space. Starting at the origin.\n    if isinstance(rads, torch.Tensor):\n        rads = rads.reshape(-1, 3).to(ref_w2c.device)\n    positions = (\n        torch.stack(\n            [\n                torch.cos(thetas),\n                -torch.sin(thetas),\n                -torch.sin(thetas * zrate),\n            ],\n            dim=-1,\n        )\n        * rads\n    )\n    # Transform to world space.\n    positions = torch.einsum(\n        \"ij,nj->ni\", ref_c2w[:3], F.pad(positions, (0, 1), value=1.0)\n    )\n    return get_lookat_w2cs(positions, lookat, up)\n\n\ndef get_wander_w2cs(ref_w2c, focal_length, num_frames, **_):\n    device = ref_w2c.device\n    c2w = np.linalg.inv(ref_w2c.detach().cpu().numpy())\n    max_disp = 48.0\n\n    max_trans = max_disp / focal_length\n    output_poses = []\n\n    for i in range(num_frames):\n        x_trans = max_trans * np.sin(2.0 * np.pi * float(i) / float(num_frames))\n        y_trans = 0.0\n        z_trans = max_trans * np.cos(2.0 * np.pi * float(i) / float(num_frames)) / 2.0\n\n        i_pose = np.concatenate(\n            [\n                np.concatenate(\n                    [\n                        np.eye(3),\n                        np.array([x_trans, y_trans, z_trans])[:, np.newaxis],\n                    ],\n                    axis=1,\n                ),\n                np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :],\n            ],\n            axis=0,\n        )\n\n        i_pose = np.linalg.inv(i_pose)\n\n        ref_pose = np.concatenate(\n            [c2w[:3, :4], np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]], axis=0\n        )\n\n        render_pose = np.dot(ref_pose, i_pose)\n        output_poses.append(render_pose)\n    output_poses = torch.from_numpy(np.array(output_poses, dtype=np.float32)).to(device)\n    w2cs = torch.linalg.inv(output_poses)\n\n    return w2cs\n"
  },
  {
    "path": "mvtracker/models/core/shape-of-motion/flow3d/transforms.py",
    "content": "from typing import Literal\n\nimport roma\nimport torch\nimport torch.nn.functional as F\n\n\ndef rt_to_mat4(\n    R: torch.Tensor, t: torch.Tensor, s: torch.Tensor | None = None\n) -> torch.Tensor:\n    \"\"\"\n    Args:\n        R (torch.Tensor): (..., 3, 3).\n        t (torch.Tensor): (..., 3).\n        s (torch.Tensor): (...,).\n\n    Returns:\n        torch.Tensor: (..., 4, 4)\n    \"\"\"\n    mat34 = torch.cat([R, t[..., None]], dim=-1)\n    if s is None:\n        bottom = (\n            mat34.new_tensor([[0.0, 0.0, 0.0, 1.0]])\n            .reshape((1,) * (mat34.dim() - 2) + (1, 4))\n            .expand(mat34.shape[:-2] + (1, 4))\n        )\n    else:\n        bottom = F.pad(1.0 / s[..., None, None], (3, 0), value=0.0)\n    mat4 = torch.cat([mat34, bottom], dim=-2)\n    return mat4\n\n\ndef rmat_to_cont_6d(matrix):\n    \"\"\"\n    :param matrix (*, 3, 3)\n    :returns 6d vector (*, 6)\n    \"\"\"\n    return torch.cat([matrix[..., 0], matrix[..., 1]], dim=-1)\n\n\ndef cont_6d_to_rmat(cont_6d):\n    \"\"\"\n    :param 6d vector (*, 6)\n    :returns matrix (*, 3, 3)\n    \"\"\"\n    x1 = cont_6d[..., 0:3]\n    y1 = cont_6d[..., 3:6]\n\n    x = F.normalize(x1, dim=-1)\n    y = F.normalize(y1 - (y1 * x).sum(dim=-1, keepdim=True) * x, dim=-1)\n    z = torch.linalg.cross(x, y, dim=-1)\n\n    return torch.stack([x, y, z], dim=-1)\n\n\ndef solve_procrustes(\n    src: torch.Tensor,\n    dst: torch.Tensor,\n    weights: torch.Tensor | None = None,\n    enforce_se3: bool = False,\n    rot_type: Literal[\"quat\", \"mat\", \"6d\"] = \"quat\",\n):\n    \"\"\"\n    Solve the Procrustes problem to align two point clouds, by solving the\n    following problem:\n\n    min_{s, R, t} || s * (src @ R.T + t) - dst ||_2, s.t. R.T @ R = I and det(R) = 1.\n\n    Args:\n        src (torch.Tensor): (N, 3).\n        dst (torch.Tensor): (N, 3).\n        weights (torch.Tensor | None): (N,), optional weights for alignment.\n        enforce_se3 (bool): Whether to enforce the transfm to be SE3.\n\n    Returns:\n        sim3 (tuple[torch.Tensor, torch.Tensor, torch.Tensor]):\n            q (torch.Tensor): (4,), rotation component in quaternion of WXYZ\n                format.\n            t (torch.Tensor): (3,), translation component.\n            s (torch.Tensor): (), scale component.\n        error (torch.Tensor): (), average L2 distance after alignment.\n    \"\"\"\n    # Compute weights.\n    if weights is None:\n        weights = src.new_ones(src.shape[0])\n    weights = weights[:, None] / weights.sum()\n    # Normalize point positions.\n    src_mean = (src * weights).sum(dim=0)\n    dst_mean = (dst * weights).sum(dim=0)\n    src_cent = src - src_mean\n    dst_cent = dst - dst_mean\n    # Normalize point scales.\n    if not enforce_se3:\n        src_scale = (src_cent**2 * weights).sum(dim=-1).mean().sqrt()\n        dst_scale = (dst_cent**2 * weights).sum(dim=-1).mean().sqrt()\n    else:\n        src_scale = dst_scale = src.new_tensor(1.0)\n    src_scaled = src_cent / src_scale\n    dst_scaled = dst_cent / dst_scale\n    # Compute the matrix for the singular value decomposition (SVD).\n    matrix = (weights * dst_scaled).T @ src_scaled\n    U, _, Vh = torch.linalg.svd(matrix)\n    # Special reflection case.\n    S = torch.eye(3, device=src.device)\n    if torch.det(U) * torch.det(Vh) < 0:\n        S[2, 2] = -1\n    R = U @ S @ Vh\n    # Compute the transformation.\n    if rot_type == \"quat\":\n        rot = roma.rotmat_to_unitquat(R).roll(1, dims=-1)\n    elif rot_type == \"6d\":\n        rot = 
rmat_to_cont_6d(R)\n    else:\n        rot = R\n    s = dst_scale / src_scale\n    t = dst_mean / s - src_mean @ R.T\n    sim3 = rot, t, s\n    # Debug: error.\n    procrustes_dst = torch.einsum(\n        \"ij,nj->ni\", rt_to_mat4(R, t, s), F.pad(src, (0, 1), value=1.0)\n    )\n    procrustes_dst = procrustes_dst[:, :3] / procrustes_dst[:, 3:]\n    error_before = (torch.linalg.norm(dst - src, dim=-1) * weights[:, 0]).sum()\n    error = (torch.linalg.norm(dst - procrustes_dst, dim=-1) * weights[:, 0]).sum()\n    # print(f\"Procrustes error: {error_before} -> {error}\")\n    # if error_before < error:\n    #     print(\"Something is wrong.\")\n    #     __import__(\"ipdb\").set_trace()\n    return sim3, (error.item(), error_before.item())\n"
  },
  {
    "path": "mvtracker/models/core/shape-of-motion/flow3d/validator.py",
    "content": "import functools\nimport os\nimport os.path as osp\nimport time\nfrom dataclasses import asdict\nfrom typing import cast\n\nimport imageio as iio\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom loguru import logger as guru\nfrom nerfview import CameraState, Viewer\nfrom pytorch_msssim import SSIM\nfrom torch.utils.data import DataLoader, Dataset\nfrom torch.utils.tensorboard import SummaryWriter\nfrom tqdm import tqdm\n\nfrom flow3d.configs import LossesConfig, OptimizerConfig, SceneLRConfig\nfrom flow3d.data.utils import normalize_coords, to_device\nfrom flow3d.metrics import PCK, mLPIPS, mPSNR, mSSIM\nfrom flow3d.scene_model import SceneModel\nfrom flow3d.vis.utils import (\n    apply_depth_colormap,\n    make_video_divisble,\n    plot_correspondences,\n)\n\n\nclass Validator:\n    def __init__(\n        self,\n        model: SceneModel,\n        device: torch.device,\n        train_loader: DataLoader | None,\n        val_img_loader: DataLoader | None,\n        val_kpt_loader: DataLoader | None,\n        save_dir: str,\n    ):\n        self.model = model\n        self.device = device\n        self.train_loader = train_loader\n        self.val_img_loader = val_img_loader\n        self.val_kpt_loader = val_kpt_loader\n        self.save_dir = save_dir\n        self.has_bg = self.model.has_bg\n\n        # metrics\n        self.ssim = SSIM(data_range=1.0, size_average=True, channel=3)\n        self.psnr_metric = mPSNR()\n        self.ssim_metric = mSSIM()\n        self.lpips_metric = mLPIPS().to(device)\n        self.fg_psnr_metric = mPSNR()\n        self.fg_ssim_metric = mSSIM()\n        self.fg_lpips_metric = mLPIPS().to(device)\n        self.bg_psnr_metric = mPSNR()\n        self.bg_ssim_metric = mSSIM()\n        self.bg_lpips_metric = mLPIPS().to(device)\n        self.pck_metric = PCK()\n\n    def reset_metrics(self):\n        self.psnr_metric.reset()\n        self.ssim_metric.reset()\n        self.lpips_metric.reset()\n        self.fg_psnr_metric.reset()\n        self.fg_ssim_metric.reset()\n        self.fg_lpips_metric.reset()\n        self.bg_psnr_metric.reset()\n        self.bg_ssim_metric.reset()\n        self.bg_lpips_metric.reset()\n        self.pck_metric.reset()\n\n    @torch.no_grad()\n    def validate(self):\n        self.reset_metrics()\n        metric_imgs = self.validate_imgs() or {}\n        metric_kpts = self.validate_keypoints() or {}\n        return {**metric_imgs, **metric_kpts}\n\n    @torch.no_grad()\n    def validate_imgs(self):\n        guru.info(\"rendering validation images...\")\n        if self.val_img_loader is None:\n            return\n\n        for batch in tqdm(self.val_img_loader, desc=\"render val images\"):\n            batch = to_device(batch, self.device)\n            frame_name = batch[\"frame_names\"][0]\n            t = batch[\"ts\"][0]\n            # (1, 4, 4).\n            w2c = batch[\"w2cs\"]\n            # (1, 3, 3).\n            K = batch[\"Ks\"]\n            # (1, H, W, 3).\n            img = batch[\"imgs\"]\n            # (1, H, W).\n            valid_mask = batch.get(\n                \"valid_masks\", torch.ones_like(batch[\"imgs\"][..., 0])\n            )\n            # (1, H, W).\n            fg_mask = batch[\"masks\"]\n\n            # (H, W).\n            covisible_mask = batch.get(\n                \"covisible_masks\",\n                torch.ones_like(fg_mask)[None],\n            )\n            W, H = img_wh = img[0].shape[-2::-1]\n            rendered = self.model.render(t, w2c, K, img_wh, 
return_depth=True)\n\n            # Compute metrics.\n            valid_mask *= covisible_mask\n            fg_valid_mask = fg_mask * valid_mask\n            bg_valid_mask = (1 - fg_mask) * valid_mask\n            main_valid_mask = valid_mask if self.has_bg else fg_valid_mask\n\n            self.psnr_metric.update(rendered[\"img\"], img, main_valid_mask)\n            self.ssim_metric.update(rendered[\"img\"], img, main_valid_mask)\n            self.lpips_metric.update(rendered[\"img\"], img, main_valid_mask)\n\n            if self.has_bg:\n                self.fg_psnr_metric.update(rendered[\"img\"], img, fg_valid_mask)\n                self.fg_ssim_metric.update(rendered[\"img\"], img, fg_valid_mask)\n                self.fg_lpips_metric.update(rendered[\"img\"], img, fg_valid_mask)\n\n                self.bg_psnr_metric.update(rendered[\"img\"], img, bg_valid_mask)\n                self.bg_ssim_metric.update(rendered[\"img\"], img, bg_valid_mask)\n                self.bg_lpips_metric.update(rendered[\"img\"], img, bg_valid_mask)\n\n            # Dump results.\n            results_dir = osp.join(self.save_dir, \"results\", \"rgb\")\n            os.makedirs(results_dir, exist_ok=True)\n            iio.imwrite(\n                osp.join(results_dir, f\"{frame_name}.png\"),\n                (rendered[\"img\"][0].cpu().numpy() * 255).astype(np.uint8),\n            )\n\n        return {\n            \"val/psnr\": self.psnr_metric.compute(),\n            \"val/ssim\": self.ssim_metric.compute(),\n            \"val/lpips\": self.lpips_metric.compute(),\n            \"val/fg_psnr\": self.fg_psnr_metric.compute(),\n            \"val/fg_ssim\": self.fg_ssim_metric.compute(),\n            \"val/fg_lpips\": self.fg_lpips_metric.compute(),\n            \"val/bg_psnr\": self.bg_psnr_metric.compute(),\n            \"val/bg_ssim\": self.bg_ssim_metric.compute(),\n            \"val/bg_lpips\": self.bg_lpips_metric.compute(),\n        }\n\n    @torch.no_grad()\n    def validate_keypoints(self):\n        if self.val_kpt_loader is None:\n            return\n        pred_keypoints_3d_all = []\n        time_ids = self.val_kpt_loader.dataset.time_ids.tolist()\n        h, w = self.val_kpt_loader.dataset.dataset.imgs.shape[1:3]\n        pred_train_depths = np.zeros((len(time_ids), h, w))\n\n        for batch in tqdm(self.val_kpt_loader, desc=\"render val keypoints\"):\n            batch = to_device(batch, self.device)\n            # (2,).\n            ts = batch[\"ts\"][0]\n            # (2, 4, 4).\n            w2cs = batch[\"w2cs\"][0]\n            # (2, 3, 3).\n            Ks = batch[\"Ks\"][0]\n            # (2, H, W, 3).\n            imgs = batch[\"imgs\"][0]\n            # (2, P, 3).\n            keypoints = batch[\"keypoints\"][0]\n            # (P,)\n            keypoint_masks = (keypoints[..., -1] > 0.5).all(dim=0)\n            src_keypoints, target_keypoints = keypoints[:, keypoint_masks, :2]\n            W, H = img_wh = imgs.shape[-2:0:-1]\n            rendered = self.model.render(\n                ts[0].item(),\n                w2cs[:1],\n                Ks[:1],\n                img_wh,\n                target_ts=ts[1:],\n                target_w2cs=w2cs[1:],\n                return_depth=True,\n            )\n            pred_tracks_3d = rendered[\"tracks_3d\"][0, ..., 0, :]\n            pred_tracks_2d = torch.einsum(\"ij,hwj->hwi\", Ks[1], pred_tracks_3d)\n            pred_tracks_2d = pred_tracks_2d[..., :2] / torch.clamp(\n                pred_tracks_2d[..., -1:], min=1e-6\n            )\n           
 pred_keypoints = F.grid_sample(\n                pred_tracks_2d[None].permute(0, 3, 1, 2),\n                normalize_coords(src_keypoints, H, W)[None, None],\n                align_corners=True,\n            ).permute(0, 2, 3, 1)[0, 0]\n\n            # Compute metrics.\n            self.pck_metric.update(pred_keypoints, target_keypoints, max(img_wh) * 0.05)\n\n            padded_keypoints_3d = torch.zeros_like(keypoints[0])\n            pred_keypoints_3d = F.grid_sample(\n                pred_tracks_3d[None].permute(0, 3, 1, 2),\n                normalize_coords(src_keypoints, H, W)[None, None],\n                align_corners=True,\n            ).permute(0, 2, 3, 1)[0, 0]\n            # Transform 3D keypoints back to world space.\n            pred_keypoints_3d = torch.einsum(\n                \"ij,pj->pi\",\n                torch.linalg.inv(w2cs[1])[:3],\n                F.pad(pred_keypoints_3d, (0, 1), value=1.0),\n            )\n            padded_keypoints_3d[keypoint_masks] = pred_keypoints_3d\n            # Cache predicted keypoints.\n            pred_keypoints_3d_all.append(padded_keypoints_3d.cpu().numpy())\n            pred_train_depths[time_ids.index(ts[0].item())] = (\n                rendered[\"depth\"][0, ..., 0].cpu().numpy()\n            )\n\n        # Dump unified results.\n        all_Ks = self.val_kpt_loader.dataset.dataset.Ks\n        all_w2cs = self.val_kpt_loader.dataset.dataset.w2cs\n\n        keypoint_result_dict = {\n            \"Ks\": all_Ks[time_ids].cpu().numpy(),\n            \"w2cs\": all_w2cs[time_ids].cpu().numpy(),\n            \"pred_keypoints_3d\": np.stack(pred_keypoints_3d_all, 0),\n            \"pred_train_depths\": pred_train_depths,\n        }\n\n        results_dir = osp.join(self.save_dir, \"results\")\n        os.makedirs(results_dir, exist_ok=True)\n        np.savez(\n            osp.join(results_dir, \"keypoints.npz\"),\n            **keypoint_result_dict,\n        )\n        guru.info(\n            f\"Dumped keypoint results to {results_dir=} {keypoint_result_dict['pred_keypoints_3d'].shape=}\"\n        )\n\n        return {\"val/pck\": self.pck_metric.compute()}\n\n    @torch.no_grad()\n    def save_train_videos(self, epoch: int):\n        if self.train_loader is None:\n            return\n        video_dir = osp.join(self.save_dir, \"videos\", f\"epoch_{epoch:04d}\")\n        os.makedirs(video_dir, exist_ok=True)\n        fps = getattr(self.train_loader.dataset.dataset, \"fps\", 15.0)\n        # Render video.\n        video = []\n        ref_pred_depths = []\n        masks = []\n        depth_min, depth_max = 1e6, 0\n        for batch_idx, batch in enumerate(\n            tqdm(self.train_loader, desc=\"Rendering video\", leave=False)\n        ):\n            batch = {\n                k: v.to(self.device) if isinstance(v, torch.Tensor) else v\n                for k, v in batch.items()\n            }\n            # ().\n            t = batch[\"ts\"][0]\n            # (4, 4).\n            w2c = batch[\"w2cs\"][0]\n            # (3, 3).\n            K = batch[\"Ks\"][0]\n            # (H, W, 3).\n            img = batch[\"imgs\"][0]\n            # (H, W).\n            depth = batch[\"depths\"][0]\n\n            img_wh = img.shape[-2::-1]\n            rendered = self.model.render(\n                t, w2c[None], K[None], img_wh, return_depth=True, return_mask=True\n            )\n            # Putting results onto CPU since it will consume unnecessarily\n            # large GPU memory for long sequence OW.\n            
video.append(torch.cat([img, rendered[\"img\"][0]], dim=1).cpu())\n            ref_pred_depth = torch.cat(\n                (depth[..., None], rendered[\"depth\"][0]), dim=1\n            ).cpu()\n            ref_pred_depths.append(ref_pred_depth)\n            depth_min = min(depth_min, ref_pred_depth.min().item())\n            depth_max = max(depth_max, ref_pred_depth.quantile(0.99).item())\n            if rendered[\"mask\"] is not None:\n                masks.append(rendered[\"mask\"][0].cpu().squeeze(-1))\n\n        # rgb video\n        video = torch.stack(video, dim=0)\n        iio.mimwrite(\n            osp.join(video_dir, \"rgbs.mp4\"),\n            make_video_divisble((video.numpy() * 255).astype(np.uint8)),\n            fps=fps,\n        )\n        # depth video\n        depth_video = torch.stack(\n            [\n                apply_depth_colormap(\n                    ref_pred_depth, near_plane=depth_min, far_plane=depth_max\n                )\n                for ref_pred_depth in ref_pred_depths\n            ],\n            dim=0,\n        )\n        iio.mimwrite(\n            osp.join(video_dir, \"depths.mp4\"),\n            make_video_divisble((depth_video.numpy() * 255).astype(np.uint8)),\n            fps=fps,\n        )\n        if len(masks) > 0:\n            # mask video\n            mask_video = torch.stack(masks, dim=0)\n            iio.mimwrite(\n                osp.join(video_dir, \"masks.mp4\"),\n                make_video_divisble((mask_video.numpy() * 255).astype(np.uint8)),\n                fps=fps,\n            )\n\n        # Render 2D track video.\n        tracks_2d, target_imgs = [], []\n        sample_interval = 10\n        batch0 = {\n            k: v.to(self.device) if isinstance(v, torch.Tensor) else v\n            for k, v in self.train_loader.dataset[0].items()\n        }\n        # ().\n        t = batch0[\"ts\"]\n        # (4, 4).\n        w2c = batch0[\"w2cs\"]\n        # (3, 3).\n        K = batch0[\"Ks\"]\n        # (H, W, 3).\n        img = batch0[\"imgs\"]\n        # (H, W).\n        bool_mask = batch0[\"masks\"] > 0.5\n        img_wh = img.shape[-2::-1]\n        for batch in tqdm(\n            self.train_loader, desc=\"Rendering 2D track video\", leave=False\n        ):\n            batch = {\n                k: v.to(self.device) if isinstance(v, torch.Tensor) else v\n                for k, v in batch.items()\n            }\n            # Putting results onto CPU since it will consume unnecessarily\n            # large GPU memory for long sequence OW.\n            # (1, H, W, 3).\n            target_imgs.append(batch[\"imgs\"].cpu())\n            # (1,).\n            target_ts = batch[\"ts\"]\n            # (1, 4, 4).\n            target_w2cs = batch[\"w2cs\"]\n            # (1, 3, 3).\n            target_Ks = batch[\"Ks\"]\n            rendered = self.model.render(\n                t,\n                w2c[None],\n                K[None],\n                img_wh,\n                target_ts=target_ts,\n                target_w2cs=target_w2cs,\n            )\n            pred_tracks_3d = rendered[\"tracks_3d\"][0][\n                ::sample_interval, ::sample_interval\n            ][bool_mask[::sample_interval, ::sample_interval]].swapaxes(0, 1)\n            pred_tracks_2d = torch.einsum(\"bij,bpj->bpi\", target_Ks, pred_tracks_3d)\n            pred_tracks_2d = pred_tracks_2d[..., :2] / torch.clamp(\n                pred_tracks_2d[..., 2:], min=1e-6\n            )\n            tracks_2d.append(pred_tracks_2d.cpu())\n        tracks_2d = 
torch.cat(tracks_2d, dim=0)\n        target_imgs = torch.cat(target_imgs, dim=0)\n        track_2d_video = plot_correspondences(\n            target_imgs.numpy(),\n            tracks_2d.numpy(),\n            query_id=cast(int, t),\n        )\n        iio.mimwrite(\n            osp.join(video_dir, \"tracks_2d.mp4\"),\n            make_video_divisble(np.stack(track_2d_video, 0)),\n            fps=fps,\n        )\n        # Render motion coefficient video.\n        with torch.random.fork_rng():\n            torch.random.manual_seed(0)\n            motion_coef_colors = torch.pca_lowrank(\n                self.model.fg.get_coefs()[None],\n                q=3,\n            )[0][0]\n        motion_coef_colors = (motion_coef_colors - motion_coef_colors.min(0)[0]) / (\n            motion_coef_colors.max(0)[0] - motion_coef_colors.min(0)[0]\n        )\n        motion_coef_colors = F.pad(\n            motion_coef_colors, (0, 0, 0, self.model.bg.num_gaussians), value=0.5\n        )\n        video = []\n        for batch in tqdm(\n            self.train_loader, desc=\"Rendering motion coefficient video\", leave=False\n        ):\n            batch = {\n                k: v.to(self.device) if isinstance(v, torch.Tensor) else v\n                for k, v in batch.items()\n            }\n            # ().\n            t = batch[\"ts\"][0]\n            # (4, 4).\n            w2c = batch[\"w2cs\"][0]\n            # (3, 3).\n            K = batch[\"Ks\"][0]\n            # (3, 3).\n            img = batch[\"imgs\"][0]\n            img_wh = img.shape[-2::-1]\n            rendered = self.model.render(\n                t, w2c[None], K[None], img_wh, colors_override=motion_coef_colors\n            )\n            # Putting results onto CPU since it will consume unnecessarily\n            # large GPU memory for long sequence OW.\n            video.append(torch.cat([img, rendered[\"img\"][0]], dim=1).cpu())\n        video = torch.stack(video, dim=0)\n        iio.mimwrite(\n            osp.join(video_dir, \"motion_coefs.mp4\"),\n            make_video_divisble((video.numpy() * 255).astype(np.uint8)),\n            fps=fps,\n        )\n"
  },
  {
    "path": "mvtracker/models/core/shape-of-motion/flow3d/vis/__init__.py",
    "content": ""
  },
  {
    "path": "mvtracker/models/core/shape-of-motion/flow3d/vis/playback_panel.py",
    "content": "import threading\nimport time\n\nimport viser\n\n\ndef add_gui_playback_group(\n    server: viser.ViserServer,\n    num_frames: int,\n    min_fps: float = 1.0,\n    max_fps: float = 60.0,\n    fps_step: float = 0.1,\n    initial_fps: float = 10.0,\n):\n    gui_timestep = server.gui.add_slider(\n        \"Timestep\",\n        min=0,\n        max=num_frames - 1,\n        step=1,\n        initial_value=0,\n        disabled=True,\n    )\n    gui_next_frame = server.gui.add_button(\"Next Frame\")\n    gui_prev_frame = server.gui.add_button(\"Prev Frame\")\n    gui_playing_pause = server.gui.add_button(\"Pause\")\n    gui_playing_pause.visible = False\n    gui_playing_resume = server.gui.add_button(\"Resume\")\n    gui_framerate = server.gui.add_slider(\n        \"FPS\", min=min_fps, max=max_fps, step=fps_step, initial_value=initial_fps\n    )\n\n    # Frame step buttons.\n    @gui_next_frame.on_click\n    def _(_) -> None:\n        gui_timestep.value = (gui_timestep.value + 1) % num_frames\n\n    @gui_prev_frame.on_click\n    def _(_) -> None:\n        gui_timestep.value = (gui_timestep.value - 1) % num_frames\n\n    # Disable frame controls when we're playing.\n    def _toggle_gui_playing(_):\n        gui_playing_pause.visible = not gui_playing_pause.visible\n        gui_playing_resume.visible = not gui_playing_resume.visible\n        gui_timestep.disabled = gui_playing_pause.visible\n        gui_next_frame.disabled = gui_playing_pause.visible\n        gui_prev_frame.disabled = gui_playing_pause.visible\n\n    gui_playing_pause.on_click(_toggle_gui_playing)\n    gui_playing_resume.on_click(_toggle_gui_playing)\n\n    # Create a thread to update the timestep indefinitely.\n    def _update_timestep():\n        while True:\n            if gui_playing_pause.visible:\n                gui_timestep.value = (gui_timestep.value + 1) % num_frames\n            time.sleep(1 / gui_framerate.value)\n\n    threading.Thread(target=_update_timestep, daemon=True).start()\n\n    return (\n        gui_timestep,\n        gui_next_frame,\n        gui_prev_frame,\n        gui_playing_pause,\n        gui_playing_resume,\n        gui_framerate,\n    )\n"
  },
  {
    "path": "mvtracker/models/core/shape-of-motion/flow3d/vis/render_panel.py",
    "content": "# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport colorsys\nimport dataclasses\nimport datetime\nimport json\nimport threading\nimport time\nfrom pathlib import Path\nfrom typing import Dict, List, Literal, Optional, Tuple\n\nimport numpy as np\nimport scipy\nimport splines\nimport splines.quaternion\nimport viser\nimport viser.transforms as tf\n\nVISER_SCALE_RATIO = 10.0\n\n\n@dataclasses.dataclass\nclass Keyframe:\n    time: float\n    position: np.ndarray\n    wxyz: np.ndarray\n    override_fov_enabled: bool\n    override_fov_rad: float\n    aspect: float\n    override_transition_enabled: bool\n    override_transition_sec: Optional[float]\n\n    @staticmethod\n    def from_camera(time: float, camera: viser.CameraHandle, aspect: float) -> Keyframe:\n        return Keyframe(\n            time,\n            camera.position,\n            camera.wxyz,\n            override_fov_enabled=False,\n            override_fov_rad=camera.fov,\n            aspect=aspect,\n            override_transition_enabled=False,\n            override_transition_sec=None,\n        )\n\n\nclass CameraPath:\n    def __init__(\n        self, server: viser.ViserServer, duration_element: viser.GuiInputHandle[float]\n    ):\n        self._server = server\n        self._keyframes: Dict[int, Tuple[Keyframe, viser.CameraFrustumHandle]] = {}\n        self._keyframe_counter: int = 0\n        self._spline_nodes: List[viser.SceneNodeHandle] = []\n        self._camera_edit_panel: Optional[viser.Gui3dContainerHandle] = None\n\n        self._orientation_spline: Optional[splines.quaternion.KochanekBartels] = None\n        self._position_spline: Optional[splines.KochanekBartels] = None\n        self._fov_spline: Optional[splines.KochanekBartels] = None\n        self._time_spline: Optional[splines.KochanekBartels] = None\n\n        self._keyframes_visible: bool = True\n\n        self._duration_element = duration_element\n\n        # These parameters should be overridden externally.\n        self.loop: bool = False\n        self.framerate: float = 30.0\n        self.tension: float = 0.5  # Tension / alpha term.\n        self.default_fov: float = 0.0\n        self.default_transition_sec: float = 0.0\n        self.show_spline: bool = True\n\n    def set_keyframes_visible(self, visible: bool) -> None:\n        self._keyframes_visible = visible\n        for keyframe in self._keyframes.values():\n            keyframe[1].visible = visible\n\n    def add_camera(\n        self, keyframe: Keyframe, keyframe_index: Optional[int] = None\n    ) -> None:\n        \"\"\"Add a new camera, or replace an old one if `keyframe_index` is passed in.\"\"\"\n        server = self._server\n\n        # Add a keyframe if we aren't replacing an existing one.\n        if keyframe_index is None:\n            keyframe_index = self._keyframe_counter\n            self._keyframe_counter 
+= 1\n\n        print(\n            f\"{keyframe.wxyz=} {keyframe.position=} {keyframe_index=} {keyframe.aspect=}\"\n        )\n        frustum_handle = server.scene.add_camera_frustum(\n            f\"/render_cameras/{keyframe_index}\",\n            fov=(\n                keyframe.override_fov_rad\n                if keyframe.override_fov_enabled\n                else self.default_fov\n            ),\n            aspect=keyframe.aspect,\n            scale=0.1,\n            color=(200, 10, 30),\n            wxyz=keyframe.wxyz,\n            position=keyframe.position,\n            visible=self._keyframes_visible,\n        )\n        self._server.scene.add_icosphere(\n            f\"/render_cameras/{keyframe_index}/sphere\",\n            radius=0.03,\n            color=(200, 10, 30),\n        )\n\n        @frustum_handle.on_click\n        def _(_) -> None:\n            if self._camera_edit_panel is not None:\n                self._camera_edit_panel.remove()\n                self._camera_edit_panel = None\n\n            with server.scene.add_3d_gui_container(\n                \"/camera_edit_panel\",\n                position=keyframe.position,\n            ) as camera_edit_panel:\n                self._camera_edit_panel = camera_edit_panel\n                override_fov = server.gui.add_checkbox(\n                    \"Override FOV\", initial_value=keyframe.override_fov_enabled\n                )\n                override_fov_degrees = server.gui.add_slider(\n                    \"Override FOV (degrees)\",\n                    5.0,\n                    175.0,\n                    step=0.1,\n                    initial_value=keyframe.override_fov_rad * 180.0 / np.pi,\n                    disabled=not keyframe.override_fov_enabled,\n                )\n                delete_button = server.gui.add_button(\n                    \"Delete\", color=\"red\", icon=viser.Icon.TRASH\n                )\n                go_to_button = server.gui.add_button(\"Go to\")\n                close_button = server.gui.add_button(\"Close\")\n\n            @override_fov.on_update\n            def _(_) -> None:\n                keyframe.override_fov_enabled = override_fov.value\n                override_fov_degrees.disabled = not override_fov.value\n                self.add_camera(keyframe, keyframe_index)\n\n            @override_fov_degrees.on_update\n            def _(_) -> None:\n                keyframe.override_fov_rad = override_fov_degrees.value / 180.0 * np.pi\n                self.add_camera(keyframe, keyframe_index)\n\n            @delete_button.on_click\n            def _(event: viser.GuiEvent) -> None:\n                assert event.client is not None\n                with event.client.gui.add_modal(\"Confirm\") as modal:\n                    event.client.gui.add_markdown(\"Delete keyframe?\")\n                    confirm_button = event.client.gui.add_button(\n                        \"Yes\", color=\"red\", icon=viser.Icon.TRASH\n                    )\n                    exit_button = event.client.gui.add_button(\"Cancel\")\n\n                    @confirm_button.on_click\n                    def _(_) -> None:\n                        assert camera_edit_panel is not None\n\n                        keyframe_id = None\n                        for i, keyframe_tuple in self._keyframes.items():\n                            if keyframe_tuple[1] is frustum_handle:\n                                keyframe_id = i\n                                break\n                        assert keyframe_id is not None\n\n     
                   self._keyframes.pop(keyframe_id)\n                        frustum_handle.remove()\n                        camera_edit_panel.remove()\n                        self._camera_edit_panel = None\n                        modal.close()\n                        self.update_spline()\n\n                    @exit_button.on_click\n                    def _(_) -> None:\n                        modal.close()\n\n            @go_to_button.on_click\n            def _(event: viser.GuiEvent) -> None:\n                assert event.client is not None\n                client = event.client\n                T_world_current = tf.SE3.from_rotation_and_translation(\n                    tf.SO3(client.camera.wxyz), client.camera.position\n                )\n                T_world_target = tf.SE3.from_rotation_and_translation(\n                    tf.SO3(keyframe.wxyz), keyframe.position\n                ) @ tf.SE3.from_translation(np.array([0.0, 0.0, -0.5]))\n\n                T_current_target = T_world_current.inverse() @ T_world_target\n\n                for j in range(10):\n                    T_world_set = T_world_current @ tf.SE3.exp(\n                        T_current_target.log() * j / 9.0\n                    )\n\n                    # Important bit: we atomically set both the orientation and the position\n                    # of the camera.\n                    with client.atomic():\n                        client.camera.wxyz = T_world_set.rotation().wxyz\n                        client.camera.position = T_world_set.translation()\n                    time.sleep(1.0 / 30.0)\n\n            @close_button.on_click\n            def _(_) -> None:\n                assert camera_edit_panel is not None\n                camera_edit_panel.remove()\n                self._camera_edit_panel = None\n\n        self._keyframes[keyframe_index] = (keyframe, frustum_handle)\n\n    def update_aspect(self, aspect: float) -> None:\n        for keyframe_index, frame in self._keyframes.items():\n            frame = dataclasses.replace(frame[0], aspect=aspect)\n            self.add_camera(frame, keyframe_index=keyframe_index)\n\n    def get_aspect(self) -> float:\n        \"\"\"Get W/H aspect ratio, which is shared across all keyframes.\"\"\"\n        assert len(self._keyframes) > 0\n        return next(iter(self._keyframes.values()))[0].aspect\n\n    def reset(self) -> None:\n        for frame in self._keyframes.values():\n            print(f\"removing {frame[1]}\")\n            frame[1].remove()\n        self._keyframes.clear()\n        self.update_spline()\n        print(\"camera path reset\")\n\n    def spline_t_from_t_sec(self, time: np.ndarray) -> np.ndarray:\n        \"\"\"From a time value in seconds, compute a t value for our geometric\n        spline interpolation. 
An increment of 1 for the latter will move the\n        camera forward by one keyframe.\n\n        We use a PCHIP spline here to guarantee monotonicity.\n        \"\"\"\n        transition_times_cumsum = self.compute_transition_times_cumsum()\n        spline_indices = np.arange(transition_times_cumsum.shape[0])\n\n        if self.loop:\n            # In the case of a loop, we pad the spline to match the start/end\n            # slopes.\n            interpolator = scipy.interpolate.PchipInterpolator(\n                x=np.concatenate(\n                    [\n                        [-(transition_times_cumsum[-1] - transition_times_cumsum[-2])],\n                        transition_times_cumsum,\n                        transition_times_cumsum[-1:] + transition_times_cumsum[1:2],\n                    ],\n                    axis=0,\n                ),\n                y=np.concatenate(\n                    [[-1], spline_indices, [spline_indices[-1] + 1]], axis=0\n                ),\n            )\n        else:\n            interpolator = scipy.interpolate.PchipInterpolator(\n                x=transition_times_cumsum, y=spline_indices\n            )\n\n        # Clip to account for floating point error.\n        return np.clip(interpolator(time), 0, spline_indices[-1])\n\n    def interpolate_pose_and_fov_rad(\n        self, normalized_t: float\n    ) -> Optional[Tuple[tf.SE3, float, float]]:\n        if len(self._keyframes) < 2:\n            return None\n\n        self._time_spline = splines.KochanekBartels(\n            [keyframe[0].time for keyframe in self._keyframes.values()],\n            tcb=(self.tension, 0.0, 0.0),\n            endconditions=\"closed\" if self.loop else \"natural\",\n        )\n\n        self._fov_spline = splines.KochanekBartels(\n            [\n                (\n                    keyframe[0].override_fov_rad\n                    if keyframe[0].override_fov_enabled\n                    else self.default_fov\n                )\n                for keyframe in self._keyframes.values()\n            ],\n            tcb=(self.tension, 0.0, 0.0),\n            endconditions=\"closed\" if self.loop else \"natural\",\n        )\n\n        assert self._orientation_spline is not None\n        assert self._position_spline is not None\n        assert self._fov_spline is not None\n        assert self._time_spline is not None\n\n        max_t = self.compute_duration()\n        t = max_t * normalized_t\n        spline_t = float(self.spline_t_from_t_sec(np.array(t)))\n\n        quat = self._orientation_spline.evaluate(spline_t)\n        assert isinstance(quat, splines.quaternion.UnitQuaternion)\n        return (\n            tf.SE3.from_rotation_and_translation(\n                tf.SO3(np.array([quat.scalar, *quat.vector])),\n                self._position_spline.evaluate(spline_t),\n            ),\n            float(self._fov_spline.evaluate(spline_t)),\n            float(self._time_spline.evaluate(spline_t)),\n        )\n\n    def update_spline(self) -> None:\n        num_frames = int(self.compute_duration() * self.framerate)\n        keyframes = list(self._keyframes.values())\n\n        if num_frames <= 0 or not self.show_spline or len(keyframes) < 2:\n            for node in self._spline_nodes:\n                node.remove()\n            self._spline_nodes.clear()\n            return\n\n        transition_times_cumsum = self.compute_transition_times_cumsum()\n\n        self._orientation_spline = splines.quaternion.KochanekBartels(\n            [\n                
splines.quaternion.UnitQuaternion.from_unit_xyzw(\n                    np.roll(keyframe[0].wxyz, shift=-1)\n                )\n                for keyframe in keyframes\n            ],\n            tcb=(self.tension, 0.0, 0.0),\n            endconditions=\"closed\" if self.loop else \"natural\",\n        )\n        self._position_spline = splines.KochanekBartels(\n            [keyframe[0].position for keyframe in keyframes],\n            tcb=(self.tension, 0.0, 0.0),\n            endconditions=\"closed\" if self.loop else \"natural\",\n        )\n\n        # Update visualized spline.\n        points_array = self._position_spline.evaluate(\n            self.spline_t_from_t_sec(\n                np.linspace(0, transition_times_cumsum[-1], num_frames)\n            )\n        )\n        colors_array = np.array(\n            [\n                colorsys.hls_to_rgb(h, 0.5, 1.0)\n                for h in np.linspace(0.0, 1.0, len(points_array))\n            ]\n        )\n\n        # Clear prior spline nodes.\n        for node in self._spline_nodes:\n            node.remove()\n        self._spline_nodes.clear()\n\n        self._spline_nodes.append(\n            self._server.scene.add_spline_catmull_rom(\n                \"/render_camera_spline\",\n                positions=points_array,\n                color=(220, 220, 220),\n                closed=self.loop,\n                line_width=1.0,\n                segments=points_array.shape[0] + 1,\n            )\n        )\n        self._spline_nodes.append(\n            self._server.scene.add_point_cloud(\n                \"/render_camera_spline/points\",\n                points=points_array,\n                colors=colors_array,\n                point_size=0.04,\n            )\n        )\n\n        def make_transition_handle(i: int) -> None:\n            assert self._position_spline is not None\n            transition_pos = self._position_spline.evaluate(\n                float(\n                    self.spline_t_from_t_sec(\n                        (transition_times_cumsum[i] + transition_times_cumsum[i + 1])\n                        / 2.0,\n                    )\n                )\n            )\n            transition_sphere = self._server.scene.add_icosphere(\n                f\"/render_camera_spline/transition_{i}\",\n                radius=0.04,\n                color=(255, 0, 0),\n                position=transition_pos,\n            )\n            self._spline_nodes.append(transition_sphere)\n\n            @transition_sphere.on_click\n            def _(_) -> None:\n                server = self._server\n\n                if self._camera_edit_panel is not None:\n                    self._camera_edit_panel.remove()\n                    self._camera_edit_panel = None\n\n                keyframe_index = (i + 1) % len(self._keyframes)\n                keyframe = keyframes[keyframe_index][0]\n\n                with server.scene.add_3d_gui_container(\n                    \"/camera_edit_panel\",\n                    position=transition_pos,\n                ) as camera_edit_panel:\n                    self._camera_edit_panel = camera_edit_panel\n                    override_transition_enabled = server.gui.add_checkbox(\n                        \"Override transition\",\n                        initial_value=keyframe.override_transition_enabled,\n                    )\n                    override_transition_sec = server.gui.add_number(\n                        \"Override transition (sec)\",\n                        initial_value=(\n                   
         keyframe.override_transition_sec\n                            if keyframe.override_transition_sec is not None\n                            else self.default_transition_sec\n                        ),\n                        min=0.001,\n                        max=30.0,\n                        step=0.001,\n                        disabled=not override_transition_enabled.value,\n                    )\n                    close_button = server.gui.add_button(\"Close\")\n\n                @override_transition_enabled.on_update\n                def _(_) -> None:\n                    keyframe.override_transition_enabled = (\n                        override_transition_enabled.value\n                    )\n                    override_transition_sec.disabled = (\n                        not override_transition_enabled.value\n                    )\n                    self._duration_element.value = self.compute_duration()\n\n                @override_transition_sec.on_update\n                def _(_) -> None:\n                    keyframe.override_transition_sec = override_transition_sec.value\n                    self._duration_element.value = self.compute_duration()\n\n                @close_button.on_click\n                def _(_) -> None:\n                    assert camera_edit_panel is not None\n                    camera_edit_panel.remove()\n                    self._camera_edit_panel = None\n\n        (num_transitions_plus_1,) = transition_times_cumsum.shape\n        for i in range(num_transitions_plus_1 - 1):\n            make_transition_handle(i)\n\n        # for i in range(transition_times.shape[0])\n\n    def compute_duration(self) -> float:\n        \"\"\"Compute the total duration of the trajectory.\"\"\"\n        total = 0.0\n        for i, (keyframe, frustum) in enumerate(self._keyframes.values()):\n            if i == 0 and not self.loop:\n                continue\n            del frustum\n            total += (\n                keyframe.override_transition_sec\n                if keyframe.override_transition_enabled\n                and keyframe.override_transition_sec is not None\n                else self.default_transition_sec\n            )\n        return total\n\n    def compute_transition_times_cumsum(self) -> np.ndarray:\n        \"\"\"Compute the total duration of the trajectory.\"\"\"\n        total = 0.0\n        out = [0.0]\n        for i, (keyframe, frustum) in enumerate(self._keyframes.values()):\n            if i == 0:\n                continue\n            del frustum\n            total += (\n                keyframe.override_transition_sec\n                if keyframe.override_transition_enabled\n                and keyframe.override_transition_sec is not None\n                else self.default_transition_sec\n            )\n            out.append(total)\n\n        if self.loop:\n            keyframe = next(iter(self._keyframes.values()))[0]\n            total += (\n                keyframe.override_transition_sec\n                if keyframe.override_transition_enabled\n                and keyframe.override_transition_sec is not None\n                else self.default_transition_sec\n            )\n            out.append(total)\n\n        return np.array(out)\n\n\n@dataclasses.dataclass\nclass RenderTabState:\n    \"\"\"Useful GUI handles exposed by the render tab.\"\"\"\n\n    preview_render: bool\n    preview_fov: float\n    preview_aspect: float\n    preview_camera_type: Literal[\"Perspective\", \"Fisheye\", \"Equirectangular\"]\n\n\ndef 
populate_render_tab(\n    server: viser.ViserServer,\n    datapath: Path,\n    gui_timestep_handle: viser.GuiInputHandle[int] | None,\n) -> RenderTabState:\n\n    render_tab_state = RenderTabState(\n        preview_render=False,\n        preview_fov=0.0,\n        preview_aspect=1.0,\n        preview_camera_type=\"Perspective\",\n    )\n\n    fov_degrees = server.gui.add_slider(\n        \"Default FOV\",\n        initial_value=75.0,\n        min=0.1,\n        max=175.0,\n        step=0.01,\n        hint=\"Field-of-view for rendering, which can also be overridden on a per-keyframe basis.\",\n    )\n\n    @fov_degrees.on_update\n    def _(_) -> None:\n        fov_radians = fov_degrees.value / 180.0 * np.pi\n        for client in server.get_clients().values():\n            client.camera.fov = fov_radians\n        camera_path.default_fov = fov_radians\n\n        # Updating the aspect ratio will also re-render the camera frustums.\n        # Could rethink this.\n        camera_path.update_aspect(resolution.value[0] / resolution.value[1])\n        compute_and_update_preview_camera_state()\n\n    resolution = server.gui.add_vector2(\n        \"Resolution\",\n        initial_value=(1920, 1080),\n        min=(50, 50),\n        max=(10_000, 10_000),\n        step=1,\n        hint=\"Render output resolution in pixels.\",\n    )\n\n    @resolution.on_update\n    def _(_) -> None:\n        camera_path.update_aspect(resolution.value[0] / resolution.value[1])\n        compute_and_update_preview_camera_state()\n\n    camera_type = server.gui.add_dropdown(\n        \"Camera type\",\n        (\"Perspective\", \"Fisheye\", \"Equirectangular\"),\n        initial_value=\"Perspective\",\n        hint=\"Camera model to render with. This is applied to all keyframes.\",\n    )\n    add_button = server.gui.add_button(\n        \"Add Keyframe\",\n        icon=viser.Icon.PLUS,\n        hint=\"Add a new keyframe at the current pose.\",\n    )\n\n    @add_button.on_click\n    def _(event: viser.GuiEvent) -> None:\n        assert event.client_id is not None\n        camera = server.get_clients()[event.client_id].camera\n        pose = tf.SE3.from_rotation_and_translation(\n            tf.SO3(camera.wxyz), camera.position\n        )\n        print(f\"client {event.client_id} at {camera.position} {camera.wxyz}\")\n        print(f\"camera pose {pose.as_matrix()}\")\n        if gui_timestep_handle is not None:\n            print(f\"timestep {gui_timestep_handle.value}\")\n\n        # Add this camera to the path.\n        time = 0\n        if gui_timestep_handle is not None:\n            time = gui_timestep_handle.value\n        camera_path.add_camera(\n            Keyframe.from_camera(\n                time,\n                camera,\n                aspect=resolution.value[0] / resolution.value[1],\n            ),\n        )\n        duration_number.value = camera_path.compute_duration()\n        camera_path.update_spline()\n\n    clear_keyframes_button = server.gui.add_button(\n        \"Clear Keyframes\",\n        icon=viser.Icon.TRASH,\n        hint=\"Remove all keyframes from the render path.\",\n    )\n\n    @clear_keyframes_button.on_click\n    def _(event: viser.GuiEvent) -> None:\n        assert event.client_id is not None\n        client = server.get_clients()[event.client_id]\n        with client.atomic(), client.gui.add_modal(\"Confirm\") as modal:\n            client.gui.add_markdown(\"Clear all keyframes?\")\n            confirm_button = client.gui.add_button(\n                \"Yes\", color=\"red\", 
icon=viser.Icon.TRASH\n            )\n            exit_button = client.gui.add_button(\"Cancel\")\n\n            @confirm_button.on_click\n            def _(_) -> None:\n                camera_path.reset()\n                modal.close()\n\n                duration_number.value = camera_path.compute_duration()\n\n                # Clear move handles.\n                if len(transform_controls) > 0:\n                    for t in transform_controls:\n                        t.remove()\n                    transform_controls.clear()\n                    return\n\n            @exit_button.on_click\n            def _(_) -> None:\n                modal.close()\n\n    loop = server.gui.add_checkbox(\n        \"Loop\", False, hint=\"Add a segment between the first and last keyframes.\"\n    )\n\n    @loop.on_update\n    def _(_) -> None:\n        camera_path.loop = loop.value\n        duration_number.value = camera_path.compute_duration()\n\n    tension_slider = server.gui.add_slider(\n        \"Spline tension\",\n        min=0.0,\n        max=1.0,\n        initial_value=0.0,\n        step=0.01,\n        hint=\"Tension parameter for adjusting smoothness of spline interpolation.\",\n    )\n\n    @tension_slider.on_update\n    def _(_) -> None:\n        camera_path.tension = tension_slider.value\n        camera_path.update_spline()\n\n    move_checkbox = server.gui.add_checkbox(\n        \"Move keyframes\",\n        initial_value=False,\n        hint=\"Toggle move handles for keyframes in the scene.\",\n    )\n\n    transform_controls: List[viser.SceneNodeHandle] = []\n\n    @move_checkbox.on_update\n    def _(event: viser.GuiEvent) -> None:\n        # Clear move handles when toggled off.\n        if move_checkbox.value is False:\n            for t in transform_controls:\n                t.remove()\n            transform_controls.clear()\n            return\n\n        def _make_transform_controls_callback(\n            keyframe: Tuple[Keyframe, viser.SceneNodeHandle],\n            controls: viser.TransformControlsHandle,\n        ) -> None:\n            @controls.on_update\n            def _(_) -> None:\n                keyframe[0].wxyz = controls.wxyz\n                keyframe[0].position = controls.position\n\n                keyframe[1].wxyz = controls.wxyz\n                keyframe[1].position = controls.position\n\n                camera_path.update_spline()\n\n        # Show move handles.\n        assert event.client is not None\n        for keyframe_index, keyframe in camera_path._keyframes.items():\n            controls = event.client.scene.add_transform_controls(\n                f\"/keyframe_move/{keyframe_index}\",\n                scale=0.4,\n                wxyz=keyframe[0].wxyz,\n                position=keyframe[0].position,\n            )\n            transform_controls.append(controls)\n            _make_transform_controls_callback(keyframe, controls)\n\n    show_keyframe_checkbox = server.gui.add_checkbox(\n        \"Show keyframes\",\n        initial_value=True,\n        hint=\"Show keyframes in the scene.\",\n    )\n\n    @show_keyframe_checkbox.on_update\n    def _(_: viser.GuiEvent) -> None:\n        camera_path.set_keyframes_visible(show_keyframe_checkbox.value)\n\n    show_spline_checkbox = server.gui.add_checkbox(\n        \"Show spline\",\n        initial_value=True,\n        hint=\"Show camera path spline in the scene.\",\n    )\n\n    @show_spline_checkbox.on_update\n    def _(_) -> None:\n        camera_path.show_spline = show_spline_checkbox.value\n        
camera_path.update_spline()\n\n    playback_folder = server.gui.add_folder(\"Playback\")\n    with playback_folder:\n        play_button = server.gui.add_button(\"Play\", icon=viser.Icon.PLAYER_PLAY)\n        pause_button = server.gui.add_button(\n            \"Pause\", icon=viser.Icon.PLAYER_PAUSE, visible=False\n        )\n        preview_render_button = server.gui.add_button(\n            \"Preview Render\", hint=\"Show a preview of the render in the viewport.\"\n        )\n        preview_render_stop_button = server.gui.add_button(\n            \"Exit Render Preview\", color=\"red\", visible=False\n        )\n\n        transition_sec_number = server.gui.add_number(\n            \"Transition (sec)\",\n            min=0.001,\n            max=30.0,\n            step=0.001,\n            initial_value=2.0,\n            hint=\"Time in seconds between each keyframe, which can also be overridden on a per-transition basis.\",\n        )\n        framerate_number = server.gui.add_number(\n            \"FPS\", min=0.1, max=240.0, step=1e-2, initial_value=30.0\n        )\n        framerate_buttons = server.gui.add_button_group(\"\", (\"24\", \"30\", \"60\"))\n        duration_number = server.gui.add_number(\n            \"Duration (sec)\",\n            min=0.0,\n            max=1e8,\n            step=0.001,\n            initial_value=0.0,\n            disabled=True,\n        )\n\n        @framerate_buttons.on_click\n        def _(_) -> None:\n            framerate_number.value = float(framerate_buttons.value)\n\n    @transition_sec_number.on_update\n    def _(_) -> None:\n        camera_path.default_transition_sec = transition_sec_number.value\n        duration_number.value = camera_path.compute_duration()\n\n    def get_max_frame_index() -> int:\n        return max(1, int(framerate_number.value * duration_number.value) - 1)\n\n    preview_camera_handle: Optional[viser.SceneNodeHandle] = None\n\n    def remove_preview_camera() -> None:\n        nonlocal preview_camera_handle\n        if preview_camera_handle is not None:\n            preview_camera_handle.remove()\n            preview_camera_handle = None\n\n    def compute_and_update_preview_camera_state() -> (\n        Optional[Tuple[tf.SE3, float, float]]\n    ):\n        \"\"\"Update the render tab state with the current preview camera pose.\n        Returns current camera pose + FOV if available.\"\"\"\n\n        if preview_frame_slider is None:\n            return\n        maybe_pose_and_fov_rad_and_time = camera_path.interpolate_pose_and_fov_rad(\n            preview_frame_slider.value / get_max_frame_index()\n        )\n        if maybe_pose_and_fov_rad_and_time is None:\n            remove_preview_camera()\n            return\n        pose, fov_rad, time = maybe_pose_and_fov_rad_and_time\n        render_tab_state.preview_fov = fov_rad\n        render_tab_state.preview_aspect = camera_path.get_aspect()\n        render_tab_state.preview_camera_type = camera_type.value\n        if gui_timestep_handle is not None:\n            gui_timestep_handle.value = int(time)\n        return pose, fov_rad, time\n\n    def add_preview_frame_slider() -> Optional[viser.GuiInputHandle[int]]:\n        \"\"\"Helper for creating the current frame # slider. 
This is removed and\n        re-added anytime the `max` value changes.\"\"\"\n\n        with playback_folder:\n            preview_frame_slider = server.gui.add_slider(\n                \"Preview frame\",\n                min=0,\n                max=get_max_frame_index(),\n                step=1,\n                initial_value=0,\n                # Place right after the pause button.\n                order=preview_render_stop_button.order + 0.01,\n                disabled=get_max_frame_index() == 1,\n            )\n            play_button.disabled = preview_frame_slider.disabled\n            preview_render_button.disabled = preview_frame_slider.disabled\n\n        @preview_frame_slider.on_update\n        def _(_) -> None:\n            nonlocal preview_camera_handle\n            maybe_pose_and_fov_rad_and_time = compute_and_update_preview_camera_state()\n            if maybe_pose_and_fov_rad_and_time is None:\n                return\n            pose, fov_rad, time = maybe_pose_and_fov_rad_and_time\n\n            preview_camera_handle = server.scene.add_camera_frustum(\n                \"/preview_camera\",\n                fov=fov_rad,\n                aspect=resolution.value[0] / resolution.value[1],\n                scale=0.35,\n                wxyz=pose.rotation().wxyz,\n                position=pose.translation(),\n                color=(10, 200, 30),\n            )\n            if render_tab_state.preview_render:\n                for client in server.get_clients().values():\n                    client.camera.wxyz = pose.rotation().wxyz\n                    client.camera.position = pose.translation()\n                if gui_timestep_handle is not None:\n                    gui_timestep_handle.value = int(time)\n\n        return preview_frame_slider\n\n    # We back up the camera poses before and after we start previewing renders.\n    camera_pose_backup_from_id: Dict[int, tuple] = {}\n\n    @preview_render_button.on_click\n    def _(_) -> None:\n        render_tab_state.preview_render = True\n        preview_render_button.visible = False\n        preview_render_stop_button.visible = True\n\n        maybe_pose_and_fov_rad_and_time = compute_and_update_preview_camera_state()\n        if maybe_pose_and_fov_rad_and_time is None:\n            remove_preview_camera()\n            return\n        pose, fov, time = maybe_pose_and_fov_rad_and_time\n        del fov\n\n        # Hide all scene nodes when we're previewing the render.\n        server.scene.set_global_visibility(True)\n\n        # Back up and then set camera poses.\n        for client in server.get_clients().values():\n            camera_pose_backup_from_id[client.client_id] = (\n                client.camera.position,\n                client.camera.look_at,\n                client.camera.up_direction,\n            )\n            client.camera.wxyz = pose.rotation().wxyz\n            client.camera.position = pose.translation()\n        if gui_timestep_handle is not None:\n            gui_timestep_handle.value = int(time)\n\n    @preview_render_stop_button.on_click\n    def _(_) -> None:\n        render_tab_state.preview_render = False\n        preview_render_button.visible = True\n        preview_render_stop_button.visible = False\n\n        # Revert camera poses.\n        for client in server.get_clients().values():\n            if client.client_id not in camera_pose_backup_from_id:\n                continue\n            cam_position, cam_look_at, cam_up = camera_pose_backup_from_id.pop(\n                client.client_id\n           
 )\n            client.camera.position = cam_position\n            client.camera.look_at = cam_look_at\n            client.camera.up_direction = cam_up\n            client.flush()\n\n        # Un-hide scene nodes.\n        server.scene.set_global_visibility(True)\n\n    preview_frame_slider = add_preview_frame_slider()\n\n    # Update the # of frames.\n    @duration_number.on_update\n    @framerate_number.on_update\n    def _(_) -> None:\n        remove_preview_camera()  # Will be re-added when slider is updated.\n\n        nonlocal preview_frame_slider\n        old = preview_frame_slider\n        assert old is not None\n\n        preview_frame_slider = add_preview_frame_slider()\n        if preview_frame_slider is not None:\n            old.remove()\n        else:\n            preview_frame_slider = old\n\n        camera_path.framerate = framerate_number.value\n        camera_path.update_spline()\n\n    # Play the camera trajectory when the play button is pressed.\n    @play_button.on_click\n    def _(_) -> None:\n        play_button.visible = False\n        pause_button.visible = True\n\n        def play() -> None:\n            while not play_button.visible:\n                max_frame = int(framerate_number.value * duration_number.value)\n                if max_frame > 0:\n                    assert preview_frame_slider is not None\n                    preview_frame_slider.value = (\n                        preview_frame_slider.value + 1\n                    ) % max_frame\n                time.sleep(1.0 / framerate_number.value)\n\n        threading.Thread(target=play).start()\n\n    # Play the camera trajectory when the play button is pressed.\n    @pause_button.on_click\n    def _(_) -> None:\n        play_button.visible = True\n        pause_button.visible = False\n\n    # add button for loading existing path\n    load_camera_path_button = server.gui.add_button(\n        \"Load Path\", icon=viser.Icon.FOLDER_OPEN, hint=\"Load an existing camera path.\"\n    )\n\n    @load_camera_path_button.on_click\n    def _(event: viser.GuiEvent) -> None:\n        assert event.client is not None\n        camera_path_dir = datapath.parent\n        camera_path_dir.mkdir(parents=True, exist_ok=True)\n        preexisting_camera_paths = list(camera_path_dir.glob(\"*.json\"))\n        preexisting_camera_filenames = [p.name for p in preexisting_camera_paths]\n\n        with event.client.gui.add_modal(\"Load Path\") as modal:\n            if len(preexisting_camera_filenames) == 0:\n                event.client.gui.add_markdown(\"No existing paths found\")\n            else:\n                event.client.gui.add_markdown(\"Select existing camera path:\")\n                camera_path_dropdown = event.client.gui.add_dropdown(\n                    label=\"Camera Path\",\n                    options=[str(p) for p in preexisting_camera_filenames],\n                    initial_value=str(preexisting_camera_filenames[0]),\n                )\n                load_button = event.client.gui.add_button(\"Load\")\n\n                @load_button.on_click\n                def _(_) -> None:\n                    # load the json file\n                    json_path = datapath / camera_path_dropdown.value\n                    with open(json_path, \"r\") as f:\n                        json_data = json.load(f)\n\n                    keyframes = json_data[\"keyframes\"]\n                    camera_path.reset()\n                    for i in range(len(keyframes)):\n                        frame = keyframes[i]\n                      
  pose = tf.SE3.from_matrix(\n                            np.array(frame[\"matrix\"]).reshape(4, 4)\n                        )\n                        # apply the x rotation by 180 deg\n                        pose = tf.SE3.from_rotation_and_translation(\n                            pose.rotation() @ tf.SO3.from_x_radians(np.pi),\n                            pose.translation(),\n                        )\n\n                        camera_path.add_camera(\n                            Keyframe(\n                                frame[\"time\"],\n                                position=pose.translation(),\n                                wxyz=pose.rotation().wxyz,\n                                # There are some floating point conversions between degrees and radians, so the fov and\n                                # default_Fov values will not be exactly matched.\n                                override_fov_enabled=abs(\n                                    frame[\"fov\"] - json_data.get(\"default_fov\", 0.0)\n                                )\n                                > 1e-3,\n                                override_fov_rad=frame[\"fov\"] / 180.0 * np.pi,\n                                aspect=frame[\"aspect\"],\n                                override_transition_enabled=frame.get(\n                                    \"override_transition_enabled\", None\n                                ),\n                                override_transition_sec=frame.get(\n                                    \"override_transition_sec\", None\n                                ),\n                            )\n                        )\n\n                    transition_sec_number.value = json_data.get(\n                        \"default_transition_sec\", 0.5\n                    )\n\n                    # update the render name\n                    camera_path_name.value = json_path.stem\n                    camera_path.update_spline()\n                    modal.close()\n\n            cancel_button = event.client.gui.add_button(\"Cancel\")\n\n            @cancel_button.on_click\n            def _(_) -> None:\n                modal.close()\n\n    # set the initial value to the current date-time string\n    now = datetime.datetime.now()\n    camera_path_name = server.gui.add_text(\n        \"Camera path name\",\n        initial_value=now.strftime(\"%Y-%m-%d %H:%M:%S\"),\n        hint=\"Name of the render\",\n    )\n\n    save_path_button = server.gui.add_button(\n        \"Save Camera Path\",\n        color=\"green\",\n        icon=viser.Icon.FILE_EXPORT,\n        hint=\"Save the camera path to json.\",\n    )\n\n    reset_up_button = server.gui.add_button(\n        \"Reset Up Direction\",\n        icon=viser.Icon.ARROW_BIG_UP_LINES,\n        color=\"gray\",\n        hint=\"Set the up direction of the camera orbit controls to the camera's current up direction.\",\n    )\n\n    @reset_up_button.on_click\n    def _(event: viser.GuiEvent) -> None:\n        assert event.client is not None\n        event.client.camera.up_direction = tf.SO3(event.client.camera.wxyz) @ np.array(\n            [0.0, -1.0, 0.0]\n        )\n\n    @save_path_button.on_click\n    def _(event: viser.GuiEvent) -> None:\n        assert event.client is not None\n        num_frames = int(framerate_number.value * duration_number.value)\n        json_data = {}\n        # json data has the properties:\n        # keyframes: list of keyframes with\n        #     matrix : flattened 4x4 matrix\n        #     fov: float in degrees\n        
#     aspect: float\n        # camera_type: string of camera type\n        # render_height: int\n        # render_width: int\n        # fps: int\n        # seconds: float\n        # is_cycle: bool\n        # smoothness_value: float\n        # camera_path: list of frames with properties\n        # camera_to_world: flattened 4x4 matrix\n        # fov: float in degrees\n        # aspect: float\n        # first populate the keyframes:\n        keyframes = []\n        for keyframe, dummy in camera_path._keyframes.values():\n            pose = tf.SE3.from_rotation_and_translation(\n                tf.SO3(keyframe.wxyz), keyframe.position\n            )\n            keyframes.append(\n                {\n                    \"matrix\": pose.as_matrix().flatten().tolist(),\n                    \"fov\": (\n                        np.rad2deg(keyframe.override_fov_rad)\n                        if keyframe.override_fov_enabled\n                        else fov_degrees.value\n                    ),\n                    \"aspect\": keyframe.aspect,\n                    \"override_transition_enabled\": keyframe.override_transition_enabled,\n                    \"override_transition_sec\": keyframe.override_transition_sec,\n                }\n            )\n        json_data[\"default_fov\"] = fov_degrees.value\n        json_data[\"default_transition_sec\"] = transition_sec_number.value\n        json_data[\"keyframes\"] = keyframes\n        json_data[\"camera_type\"] = camera_type.value.lower()\n        json_data[\"render_height\"] = resolution.value[1]\n        json_data[\"render_width\"] = resolution.value[0]\n        json_data[\"fps\"] = framerate_number.value\n        json_data[\"seconds\"] = duration_number.value\n        json_data[\"is_cycle\"] = loop.value\n        json_data[\"smoothness_value\"] = tension_slider.value\n\n        def get_intrinsics(W, H, fov):\n            focal = 0.5 * H / np.tan(0.5 * fov)\n            return np.array(\n                [[focal, 0.0, 0.5 * W], [0.0, focal, 0.5 * H], [0.0, 0.0, 1.0]]\n            )\n\n        # now populate the camera path:\n        camera_path_list = []\n        for i in range(num_frames):\n            maybe_pose_and_fov_and_time = camera_path.interpolate_pose_and_fov_rad(\n                i / num_frames\n            )\n            if maybe_pose_and_fov_and_time is None:\n                return\n            pose, fov, time = maybe_pose_and_fov_and_time\n            H = resolution.value[1]\n            W = resolution.value[0]\n            K = get_intrinsics(W, H, fov)\n            # rotate the axis of the camera 180 about x axis\n            w2c = pose.inverse().as_matrix()\n            camera_path_list.append(\n                {\n                    \"time\": time,\n                    \"w2c\": w2c.flatten().tolist(),\n                    \"K\": K.flatten().tolist(),\n                    \"img_wh\": (W, H),\n                }\n            )\n        json_data[\"camera_path\"] = camera_path_list\n\n        # now write the json file\n        out_name = camera_path_name.value\n        json_outfile = datapath / f\"{out_name}.json\"\n        datapath.mkdir(parents=True, exist_ok=True)\n        print(f\"writing to {json_outfile}\")\n        with open(json_outfile.absolute(), \"w\") as outfile:\n            json.dump(json_data, outfile)\n\n    camera_path = CameraPath(server, duration_number)\n    camera_path.default_fov = fov_degrees.value / 180.0 * np.pi\n    camera_path.default_transition_sec = transition_sec_number.value\n\n    return 
render_tab_state\n\n\nif __name__ == \"__main__\":\n    populate_render_tab(\n        server=viser.ViserServer(),\n        datapath=Path(\".\"),\n        gui_timestep_handle=None,\n    )\n    while True:\n        time.sleep(10.0)\n"
  },
  {
    "path": "mvtracker/models/core/shape-of-motion/flow3d/vis/utils.py",
    "content": "import colorsys\nfrom typing import cast\n\nimport cv2\nimport numpy as np\n\nimport nvdiffrast.torch as dr\nimport torch\nimport torch.nn.functional as F\nfrom matplotlib import colormaps\nfrom viser import ViserServer\n\n\nclass Singleton(type):\n    _instances = {}\n\n    def __call__(cls, *args, **kwargs):\n        if cls not in cls._instances:\n            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)\n        return cls._instances[cls]\n\n\nclass VisManager(metaclass=Singleton):\n    _servers = {}\n\n\ndef get_server(port: int | None = None) -> ViserServer:\n    manager = VisManager()\n    if port is None:\n        avail_ports = list(manager._servers.keys())\n        port = avail_ports[0] if len(avail_ports) > 0 else 8890\n    if port not in manager._servers:\n        manager._servers[port] = ViserServer(port=port, verbose=False)\n    return manager._servers[port]\n\n\ndef project_2d_tracks(tracks_3d_w, Ks, T_cw, return_depth=False):\n    \"\"\"\n    :param tracks_3d_w (torch.Tensor): (T, N, 3)\n    :param Ks (torch.Tensor): (T, 3, 3)\n    :param T_cw (torch.Tensor): (T, 4, 4)\n    :returns tracks_2d (torch.Tensor): (T, N, 2)\n    \"\"\"\n    tracks_3d_c = torch.einsum(\n        \"tij,tnj->tni\", T_cw, F.pad(tracks_3d_w, (0, 1), value=1)\n    )[..., :3]\n    tracks_3d_v = torch.einsum(\"tij,tnj->tni\", Ks, tracks_3d_c)\n    if return_depth:\n        return (\n            tracks_3d_v[..., :2] / torch.clamp(tracks_3d_v[..., 2:], min=1e-5),\n            tracks_3d_v[..., 2],\n        )\n    return tracks_3d_v[..., :2] / torch.clamp(tracks_3d_v[..., 2:], min=1e-5)\n\n\ndef draw_keypoints_video(\n    imgs, kps, colors=None, occs=None, cmap: str = \"gist_rainbow\", radius: int = 3\n):\n    \"\"\"\n    :param imgs (np.ndarray): (T, H, W, 3) uint8 [0, 255]\n    :param kps (np.ndarray): (N, T, 2)\n    :param colors (np.ndarray): (N, 3) float [0, 1]\n    :param occ (np.ndarray): (N, T) bool\n    return out_frames (T, H, W, 3)\n    \"\"\"\n    if colors is None:\n        label = np.linspace(0, 1, kps.shape[0])\n        colors = np.asarray(colormaps.get_cmap(cmap)(label))[..., :3]\n    out_frames = []\n    for t in range(len(imgs)):\n        occ = occs[:, t] if occs is not None else None\n        vis = draw_keypoints_cv2(imgs[t], kps[:, t], colors, occ, radius=radius)\n        out_frames.append(vis)\n    return out_frames\n\n\ndef draw_keypoints_cv2(img, kps, colors=None, occs=None, radius=3):\n    \"\"\"\n    :param img (H, W, 3)\n    :param kps (N, 2)\n    :param occs (N)\n    :param colors (N, 3) from 0 to 1\n    \"\"\"\n    out_img = img.copy()\n    kps = kps.round().astype(\"int\").tolist()\n    if colors is not None:\n        colors = (255 * colors).astype(\"int\").tolist()\n    for n in range(len(kps)):\n        kp = kps[n]\n        color = colors[n] if colors is not None else (255, 0, 0)\n        thickness = -1 if occs is None or occs[n] == 0 else 1\n        out_img = cv2.circle(out_img, kp, radius, color, thickness, cv2.LINE_AA)\n    return out_img\n\n\ndef draw_tracks_2d(\n    img: torch.Tensor,\n    tracks_2d: torch.Tensor,\n    track_point_size: int = 2,\n    track_line_width: int = 1,\n    cmap_name: str = \"gist_rainbow\",\n):\n    cmap = colormaps.get_cmap(cmap_name)\n    # (H, W, 3).\n    img_np = (img.cpu().numpy() * 255.0).astype(np.uint8)\n    # (P, N, 2).\n    tracks_2d_np = tracks_2d.cpu().numpy()\n\n    num_tracks, num_frames = tracks_2d_np.shape[:2]\n\n    canvas = img_np.copy()\n    for i in range(num_frames - 1):\n        alpha = 
max(1 - 0.9 * ((num_frames - 1 - i) / (num_frames * 0.99)), 0.1)\n        img_curr = canvas.copy()\n        for j in range(num_tracks):\n            color = tuple(np.array(cmap(j / max(1, float(num_tracks - 1)))[:3]) * 255)\n            color_alpha = 1\n            hsv = colorsys.rgb_to_hsv(color[0], color[1], color[2])\n            color = colorsys.hsv_to_rgb(hsv[0], hsv[1] * color_alpha, hsv[2])\n            pt1 = tracks_2d_np[j, i]\n            pt2 = tracks_2d_np[j, i + 1]\n            p1 = (int(round(pt1[0])), int(round(pt1[1])))\n            p2 = (int(round(pt2[0])), int(round(pt2[1])))\n            img_curr = cv2.line(\n                img_curr,\n                p1,\n                p2,\n                color,\n                thickness=track_line_width,\n                lineType=cv2.LINE_AA,\n            )\n        canvas = cv2.addWeighted(img_curr, alpha, canvas, 1 - alpha, 0)\n\n    for j in range(num_tracks):\n        color = tuple(np.array(cmap(j / max(1, float(num_tracks - 1)))[:3]) * 255)\n        pt = tracks_2d_np[j, -1]\n        pt = (int(round(pt[0])), int(round(pt[1])))\n        canvas = cv2.circle(\n            canvas,\n            pt,\n            track_point_size,\n            color,\n            thickness=-1,\n            lineType=cv2.LINE_AA,\n        )\n\n    return canvas\n\n\ndef generate_line_verts_faces(starts, ends, line_width):\n    \"\"\"\n    Args:\n        starts: (P, N, 2).\n        ends: (P, N, 2).\n        line_width: int.\n\n    Returns:\n        verts: (P * N * 4, 2).\n        faces: (P * N * 2, 3).\n    \"\"\"\n    P, N, _ = starts.shape\n\n    directions = F.normalize(ends - starts, dim=-1)\n    deltas = (\n        torch.cat([-directions[..., 1:], directions[..., :1]], dim=-1)\n        * line_width\n        / 2.0\n    )\n    v0 = starts + deltas\n    v1 = starts - deltas\n    v2 = ends + deltas\n    v3 = ends - deltas\n    verts = torch.stack([v0, v1, v2, v3], dim=-2)\n    verts = verts.reshape(-1, 2)\n\n    faces = []\n    for p in range(P):\n        for n in range(N):\n            base_index = p * N * 4 + n * 4\n            # Two triangles per rectangle: (0, 1, 2) and (2, 1, 3)\n            faces.append([base_index, base_index + 1, base_index + 2])\n            faces.append([base_index + 2, base_index + 1, base_index + 3])\n    faces = torch.as_tensor(faces, device=starts.device)\n\n    return verts, faces\n\n\ndef generate_point_verts_faces(points, point_size, num_segments=10):\n    \"\"\"\n    Args:\n        points: (P, 2).\n        point_size: int.\n        num_segments: int.\n\n    Returns:\n        verts: (P * (num_segments + 1), 2).\n        faces: (P * num_segments, 3).\n    \"\"\"\n    P, _ = points.shape\n\n    angles = torch.linspace(0, 2 * torch.pi, num_segments + 1, device=points.device)[\n        ..., :-1\n    ]\n    unit_circle = torch.stack([torch.cos(angles), torch.sin(angles)], dim=1)\n    scaled_circles = (point_size / 2.0) * unit_circle\n    scaled_circles = scaled_circles[None].repeat(P, 1, 1)\n    verts = points[:, None] + scaled_circles\n    verts = torch.cat([verts, points[:, None]], dim=1)\n    verts = verts.reshape(-1, 2)\n\n    faces = F.pad(\n        torch.as_tensor(\n            [[i, (i + 1) % num_segments] for i in range(num_segments)],\n            device=points.device,\n        ),\n        (0, 1),\n        value=num_segments,\n    )\n    faces = faces[None, :] + torch.arange(P, device=points.device)[:, None, None] * (\n        num_segments + 1\n    )\n    faces = faces.reshape(-1, 3)\n\n    return verts, faces\n\n\ndef 
pixel_to_verts_clip(pixels, img_wh, z: float | torch.Tensor = 0.0, w=1.0):\n    verts_clip = pixels / pixels.new_tensor(img_wh) * 2.0 - 1.0\n    w = torch.full_like(verts_clip[..., :1], w)\n    verts_clip = torch.cat([verts_clip, z * w, w], dim=-1)\n    return verts_clip\n\n\ndef draw_tracks_2d_th(\n    img: torch.Tensor,\n    tracks_2d: torch.Tensor,\n    track_point_size: int = 5,\n    track_point_segments: int = 16,\n    track_line_width: int = 2,\n    cmap_name: str = \"gist_rainbow\",\n):\n    cmap = colormaps.get_cmap(cmap_name)\n    CTX = dr.RasterizeCudaContext()\n\n    W, H = img.shape[1], img.shape[0]\n    if W % 8 != 0 or H % 8 != 0:\n        # Make sure img is divisible by 8.\n        img = F.pad(\n            img,\n            (\n                0,\n                0,\n                0,\n                8 - W % 8 if W % 8 != 0 else 0,\n                0,\n                8 - H % 8 if H % 8 != 0 else 0,\n            ),\n            value=0.0,\n        )\n    num_tracks, num_frames = tracks_2d.shape[:2]\n\n    track_colors = torch.tensor(\n        [cmap(j / max(1, float(num_tracks - 1)))[:3] for j in range(num_tracks)],\n        device=img.device,\n    ).float()\n\n    # Generate line verts.\n    verts_l, faces_l = generate_line_verts_faces(\n        tracks_2d[:, :-1], tracks_2d[:, 1:], track_line_width\n    )\n    # Generate point verts.\n    verts_p, faces_p = generate_point_verts_faces(\n        tracks_2d[:, -1], track_point_size, track_point_segments\n    )\n\n    verts = torch.cat([verts_l, verts_p], dim=0)\n    faces = torch.cat([faces_l, faces_p + len(verts_l)], dim=0)\n    vert_colors = torch.cat(\n        [\n            (\n                track_colors[:, None]\n                .repeat_interleave(4 * (num_frames - 1), dim=1)\n                .reshape(-1, 3)\n            ),\n            (\n                track_colors[:, None]\n                .repeat_interleave(track_point_segments + 1, dim=1)\n                .reshape(-1, 3)\n            ),\n        ],\n        dim=0,\n    )\n    track_zs = torch.linspace(0.0, 1.0, num_tracks, device=img.device)[:, None]\n    vert_zs = torch.cat(\n        [\n            (\n                track_zs[:, None]\n                .repeat_interleave(4 * (num_frames - 1), dim=1)\n                .reshape(-1, 1)\n            ),\n            (\n                track_zs[:, None]\n                .repeat_interleave(track_point_segments + 1, dim=1)\n                .reshape(-1, 1)\n            ),\n        ],\n        dim=0,\n    )\n    track_alphas = torch.linspace(\n        max(0.1, 1.0 - (num_frames - 1) * 0.1), 1.0, num_frames, device=img.device\n    )\n    vert_alphas = torch.cat(\n        [\n            (\n                track_alphas[None, :-1, None]\n                .repeat_interleave(num_tracks, dim=0)\n                .repeat_interleave(4, dim=-2)\n                .reshape(-1, 1)\n            ),\n            (\n                track_alphas[None, -1:, None]\n                .repeat_interleave(num_tracks, dim=0)\n                .repeat_interleave(track_point_segments + 1, dim=-2)\n                .reshape(-1, 1)\n            ),\n        ],\n        dim=0,\n    )\n\n    # Small trick to always render one track in front of the other.\n    verts_clip = pixel_to_verts_clip(verts, (img.shape[1], img.shape[0]), vert_zs)\n    faces_int32 = faces.to(torch.int32)\n\n    rast, _ = cast(\n        tuple,\n        dr.rasterize(CTX, verts_clip[None], faces_int32, (img.shape[0], img.shape[1])),\n    )\n    rgba = cast(\n        torch.Tensor,\n        
dr.interpolate(\n            torch.cat([vert_colors, vert_alphas], dim=-1).contiguous(),\n            rast,\n            faces_int32,\n        ),\n    )[0]\n    rgba = cast(torch.Tensor, dr.antialias(rgba, rast, verts_clip, faces_int32))[\n        0\n    ].clamp(0, 1)\n    # Compose.\n    color = rgba[..., :-1] * rgba[..., -1:] + (1.0 - rgba[..., -1:]) * img\n\n    # Unpad.\n    color = color[:H, :W]\n\n    return (color.cpu().numpy() * 255.0).astype(np.uint8)\n\n\ndef make_video_divisble(\n    video: torch.Tensor | np.ndarray, block_size=16\n) -> torch.Tensor | np.ndarray:\n    H, W = video.shape[1:3]\n    H_new = H - H % block_size\n    W_new = W - W % block_size\n    return video[:, :H_new, :W_new]\n\n\ndef apply_float_colormap(img: torch.Tensor, colormap: str = \"turbo\") -> torch.Tensor:\n    \"\"\"Convert single channel to a color img.\n\n    Args:\n        img (torch.Tensor): (..., 1) float32 single channel image.\n        colormap (str): Colormap for img.\n\n    Returns:\n        (..., 3) colored img with colors in [0, 1].\n    \"\"\"\n    img = torch.nan_to_num(img, 0)\n    if colormap == \"gray\":\n        return img.repeat(1, 1, 3)\n    img_long = (img * 255).long()\n    img_long_min = torch.min(img_long)\n    img_long_max = torch.max(img_long)\n    assert img_long_min >= 0, f\"the min value is {img_long_min}\"\n    assert img_long_max <= 255, f\"the max value is {img_long_max}\"\n    return torch.tensor(\n        colormaps[colormap].colors,  # type: ignore\n        device=img.device,\n    )[img_long[..., 0]]\n\n\ndef apply_depth_colormap(\n    depth: torch.Tensor,\n    acc: torch.Tensor | None = None,\n    near_plane: float | None = None,\n    far_plane: float | None = None,\n) -> torch.Tensor:\n    \"\"\"Converts a depth image to color for easier analysis.\n\n    Args:\n        depth (torch.Tensor): (..., 1) float32 depth.\n        acc (torch.Tensor | None): (..., 1) optional accumulation mask.\n        near_plane: Closest depth to consider. If None, use min image value.\n        far_plane: Furthest depth to consider. 
If None, use max image value.\n\n    Returns:\n        (..., 3) colored depth image with colors in [0, 1].\n    \"\"\"\n    near_plane = near_plane or float(torch.min(depth))\n    far_plane = far_plane or float(torch.max(depth))\n    depth = (depth - near_plane) / (far_plane - near_plane + 1e-10)\n    depth = torch.clip(depth, 0.0, 1.0)\n    img = apply_float_colormap(depth, colormap=\"turbo\")\n    if acc is not None:\n        img = img * acc + (1.0 - acc)\n    return img\n\n\ndef float2uint8(x):\n    return (255.0 * x).astype(np.uint8)\n\n\ndef uint82float(img):\n    return np.ascontiguousarray(img) / 255.0\n\n\ndef drawMatches(\n    img1,\n    img2,\n    kp1,\n    kp2,\n    num_vis=200,\n    center=None,\n    idx_vis=None,\n    radius=2,\n    seed=1234,\n    mask=None,\n):\n    num_pts = len(kp1)\n    if idx_vis is None:\n        if num_vis < num_pts:\n            rng = np.random.RandomState(seed)\n            idx_vis = rng.choice(num_pts, num_vis, replace=False)\n        else:\n            idx_vis = np.arange(num_pts)\n\n    kp1_vis = kp1[idx_vis]\n    kp2_vis = kp2[idx_vis]\n\n    h1, w1 = img1.shape[:2]\n    h2, w2 = img2.shape[:2]\n\n    kp1_vis[:, 0] = np.clip(kp1_vis[:, 0], a_min=0, a_max=w1 - 1)\n    kp1_vis[:, 1] = np.clip(kp1_vis[:, 1], a_min=0, a_max=h1 - 1)\n\n    kp2_vis[:, 0] = np.clip(kp2_vis[:, 0], a_min=0, a_max=w2 - 1)\n    kp2_vis[:, 1] = np.clip(kp2_vis[:, 1], a_min=0, a_max=h2 - 1)\n\n    img1 = float2uint8(img1)\n    img2 = float2uint8(img2)\n\n    if center is None:\n        center = np.median(kp1, axis=0)\n\n    set_max = range(128)\n    colors = {m: i for i, m in enumerate(set_max)}\n    hsv = colormaps.get_cmap(\"hsv\")\n    colors = {\n        m: (255 * np.array(hsv(i / float(len(colors))))[:3][::-1]).astype(np.int32)\n        for m, i in colors.items()\n    }\n\n    if mask is not None:\n        ind = np.argsort(mask)[::-1]\n        kp1_vis = kp1_vis[ind]\n        kp2_vis = kp2_vis[ind]\n        mask = mask[ind]\n\n    for i, (pt1, pt2) in enumerate(zip(kp1_vis, kp2_vis)):\n        # random_color = tuple(np.random.randint(low=0, high=255, size=(3,)).tolist())\n        coord_angle = np.arctan2(pt1[1] - center[1], pt1[0] - center[0])\n        corr_color = np.int32(64 * coord_angle / np.pi) % 128\n        color = tuple(colors[corr_color].tolist())\n\n        if (\n            (pt1[0] <= w1 - 1)\n            and (pt1[0] >= 0)\n            and (pt1[1] <= h1 - 1)\n            and (pt1[1] >= 0)\n        ):\n            img1 = cv2.circle(\n                img1, (int(pt1[0]), int(pt1[1])), radius, color, -1, cv2.LINE_AA\n            )\n        if (\n            (pt2[0] <= w2 - 1)\n            and (pt2[0] >= 0)\n            and (pt2[1] <= h2 - 1)\n            and (pt2[1] >= 0)\n        ):\n            if mask is not None and mask[i]:\n                continue\n                # img2 = cv2.drawMarker(img2, (int(pt2[0]), int(pt2[1])), color, markerType=cv2.MARKER_CROSS,\n                #                       markerSize=int(5*radius), thickness=int(radius/2), line_type=cv2.LINE_AA)\n            else:\n                img2 = cv2.circle(\n                    img2, (int(pt2[0]), int(pt2[1])), radius, color, -1, cv2.LINE_AA\n                )\n\n    out = np.concatenate([img1, img2], axis=1)\n    return out\n\n\ndef plot_correspondences(\n    rgbs, kpts, query_id=0, masks=None, num_vis=1000000, radius=3, seed=1234\n):\n    num_rgbs = len(rgbs)\n    rng = np.random.RandomState(seed)\n    permutation = rng.permutation(kpts.shape[1])\n    kpts = kpts[:, permutation, :][:, 
:num_vis]\n    if masks is not None:\n        masks = masks[:, permutation][:, :num_vis]\n\n    rgbq = rgbs[query_id]  # [h, w, 3]\n    kptsq = kpts[query_id]  # [n, 2]\n\n    frames = []\n    for i in range(num_rgbs):\n        rgbi = rgbs[i]\n        kptsi = kpts[i]\n        if masks is not None:\n            maski = masks[i]\n        else:\n            maski = None\n        frame = drawMatches(\n            rgbq,\n            rgbi,\n            kptsq,\n            kptsi,\n            mask=maski,\n            num_vis=num_vis,\n            radius=radius,\n            seed=seed,\n        )\n        frames.append(frame)\n    return frames\n"
  },
  {
    "path": "mvtracker/models/core/shape-of-motion/flow3d/vis/viewer.py",
    "content": "from pathlib import Path\nfrom typing import Callable, Literal, Optional, Tuple, Union\n\nimport numpy as np\nfrom jaxtyping import Float32, UInt8\nfrom nerfview import CameraState, Viewer\nfrom viser import Icon, ViserServer\n\nfrom flow3d.vis.playback_panel import add_gui_playback_group\nfrom flow3d.vis.render_panel import populate_render_tab\n\n\nclass DynamicViewer(Viewer):\n    def __init__(\n        self,\n        server: ViserServer,\n        render_fn: Callable[\n            [CameraState, Tuple[int, int]],\n            Union[\n                UInt8[np.ndarray, \"H W 3\"],\n                Tuple[UInt8[np.ndarray, \"H W 3\"], Optional[Float32[np.ndarray, \"H W\"]]],\n            ],\n        ],\n        num_frames: int,\n        work_dir: str,\n        mode: Literal[\"rendering\", \"training\"] = \"rendering\",\n    ):\n        self.num_frames = num_frames\n        self.work_dir = Path(work_dir)\n        super().__init__(server, render_fn, mode)\n\n    def _define_guis(self):\n        super()._define_guis()\n        server = self.server\n        self._time_folder = server.gui.add_folder(\"Time\")\n        with self._time_folder:\n            self._playback_guis = add_gui_playback_group(\n                server,\n                num_frames=self.num_frames,\n                initial_fps=15.0,\n            )\n            self._playback_guis[0].on_update(self.rerender)\n            self._canonical_checkbox = server.gui.add_checkbox(\"Canonical\", False)\n            self._canonical_checkbox.on_update(self.rerender)\n\n            _cached_playback_disabled = []\n\n            def _toggle_gui_playing(event):\n                if event.target.value:\n                    nonlocal _cached_playback_disabled\n                    _cached_playback_disabled = [\n                        gui.disabled for gui in self._playback_guis\n                    ]\n                    target_disabled = [True] * len(self._playback_guis)\n                else:\n                    target_disabled = _cached_playback_disabled\n                for gui, disabled in zip(self._playback_guis, target_disabled):\n                    gui.disabled = disabled\n\n            self._canonical_checkbox.on_update(_toggle_gui_playing)\n\n        self._render_track_checkbox = server.gui.add_checkbox(\"Render tracks\", False)\n        self._render_track_checkbox.on_update(self.rerender)\n\n        tabs = server.gui.add_tab_group()\n        with tabs.add_tab(\"Render\", Icon.CAMERA):\n            self.render_tab_state = populate_render_tab(\n                server, Path(self.work_dir) / \"camera_paths\", self._playback_guis[0]\n            )\n"
  },
  {
    "path": "mvtracker/models/core/shape-of-motion/launch_davis.py",
    "content": "import os\nimport subprocess\nfrom concurrent.futures import ProcessPoolExecutor\nimport tyro\n\n\ndef main(\n    devices: list[int],\n    seqs: list[str] | None,\n    work_root: str,\n    davis_root: str = \"/shared/vye/datasets/DAVIS\",\n    image_name: str = \"JPEGImages\",\n    res: str = \"480p\",\n    depth_type: str = \"aligned_depth_anything\",\n):\n    img_dir = f\"{davis_root}/{image_name}/{res}\"\n    if seqs is None:\n        seqs = sorted(os.listdir(img_dir))\n    with ProcessPoolExecutor() as exc:\n        for i, seq_name in enumerate(seqs):\n            device = devices[i % len(devices)]\n            cmd = (\n                f\"CUDA_VISIBLE_DEVICES={device} python run_training.py \"\n                f\"--work-dir {work_root}/{seq_name} data:davis \"\n                f\"--data.seq_name {seq_name} --data.root_dir {davis_root} \"\n                f\"--data.res {res} --data.depth_type {depth_type}\"\n            )\n            print(cmd)\n            exc.submit(subprocess.call, cmd, shell=True)\n\n\nif __name__ == \"__main__\":\n    tyro.cli(main)\n"
  },
  {
    "path": "mvtracker/models/core/spatracker/__init__.py",
    "content": ""
  },
  {
    "path": "mvtracker/models/core/spatracker/blocks.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\nimport collections\nfrom itertools import repeat\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange\n\n\n# From PyTorch internals\ndef _ntuple(n):\n    def parse(x):\n        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):\n            return tuple(x)\n        return tuple(repeat(x, n))\n\n    return parse\n\n\ndef exists(val):\n    return val is not None\n\n\ndef default(val, d):\n    return val if exists(val) else d\n\n\nto_2tuple = _ntuple(2)\n\n\nclass Mlp(nn.Module):\n    \"\"\"MLP as used in Vision Transformer, MLP-Mixer and related networks\"\"\"\n\n    def __init__(\n            self,\n            in_features,\n            hidden_features=None,\n            out_features=None,\n            act_layer=nn.GELU,\n            bias=True,\n            drop=0.0,\n    ):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        bias = to_2tuple(bias)\n        drop_probs = to_2tuple(drop)\n\n        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])\n        self.act = act_layer()\n        self.drop1 = nn.Dropout(drop_probs[0])\n        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])\n        self.drop2 = nn.Dropout(drop_probs[1])\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop1(x)\n        x = self.fc2(x)\n        x = self.drop2(x)\n        return x\n\n\nclass ResidualBlock(nn.Module):\n    def __init__(self, in_planes, planes, norm_fn=\"group\", stride=1):\n        super(ResidualBlock, self).__init__()\n\n        self.conv1 = nn.Conv2d(\n            in_planes,\n            planes,\n            kernel_size=3,\n            padding=1,\n            stride=stride,\n            padding_mode=\"zeros\",\n        )\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, padding_mode=\"zeros\")\n        self.relu = nn.ReLU(inplace=True)\n\n        num_groups = planes // 8\n\n        if norm_fn == \"group\":\n            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)\n            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)\n            if not stride == 1:\n                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)\n\n        elif norm_fn == \"batch\":\n            self.norm1 = nn.BatchNorm2d(planes)\n            self.norm2 = nn.BatchNorm2d(planes)\n            if not stride == 1:\n                self.norm3 = nn.BatchNorm2d(planes)\n\n        elif norm_fn == \"instance\":\n            self.norm1 = nn.InstanceNorm2d(planes)\n            self.norm2 = nn.InstanceNorm2d(planes)\n            if not stride == 1:\n                self.norm3 = nn.InstanceNorm2d(planes)\n\n        elif norm_fn == \"none\":\n            self.norm1 = nn.Sequential()\n            self.norm2 = nn.Sequential()\n            if not stride == 1:\n                self.norm3 = nn.Sequential()\n\n        if stride == 1:\n            self.downsample = None\n\n        else:\n            self.downsample = nn.Sequential(\n                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3\n            )\n\n    def forward(self, x):\n        y = x\n        y = 
self.relu(self.norm1(self.conv1(y)))\n        y = self.relu(self.norm2(self.conv2(y)))\n\n        if self.downsample is not None:\n            x = self.downsample(x)\n\n        return self.relu(x + y)\n\n\nclass BasicEncoder(nn.Module):\n    def __init__(\n            self, input_dim=3, output_dim=128, stride=8, norm_fn=\"batch\", dropout=0.0,\n            Embed3D=False\n    ):\n        super(BasicEncoder, self).__init__()\n        self.stride = stride\n        self.norm_fn = norm_fn\n        self.in_planes = 64\n\n        if self.norm_fn == \"group\":\n            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes)\n            self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2)\n\n        elif self.norm_fn == \"batch\":\n            self.norm1 = nn.BatchNorm2d(self.in_planes)\n            self.norm2 = nn.BatchNorm2d(output_dim * 2)\n\n        elif self.norm_fn == \"instance\":\n            self.norm1 = nn.InstanceNorm2d(self.in_planes)\n            self.norm2 = nn.InstanceNorm2d(output_dim * 2)\n\n        elif self.norm_fn == \"none\":\n            self.norm1 = nn.Sequential()\n\n        self.conv1 = nn.Conv2d(\n            input_dim,\n            self.in_planes,\n            kernel_size=7,\n            stride=2,\n            padding=3,\n            padding_mode=\"zeros\",\n        )\n        self.relu1 = nn.ReLU(inplace=True)\n\n        self.shallow = False\n        if self.shallow:\n            self.layer1 = self._make_layer(64, stride=1)\n            self.layer2 = self._make_layer(96, stride=2)\n            self.layer3 = self._make_layer(128, stride=2)\n            self.conv2 = nn.Conv2d(128 + 96 + 64, output_dim, kernel_size=1)\n        else:\n            if Embed3D:\n                self.conv_fuse = nn.Conv2d(64 + 63,\n                                           self.in_planes, kernel_size=3, padding=1)\n            self.layer1 = self._make_layer(64, stride=1)\n            self.layer2 = self._make_layer(96, stride=2)\n            self.layer3 = self._make_layer(128, stride=2)\n            self.layer4 = self._make_layer(128, stride=2)\n            # TODO: Add 2 layers.\n            # self.layer5 = self._make_layer(128, stride=1)\n            # self.layer6 = self._make_layer(128, stride=1)\n            self.conv2 = nn.Conv2d(\n                128 + 128 + 96 + 64,\n                output_dim * 2,\n                kernel_size=3,\n                padding=1,\n                padding_mode=\"zeros\",\n            )\n            self.relu2 = nn.ReLU(inplace=True)\n            self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)\n\n        self.dropout = None\n        if dropout > 0:\n            self.dropout = nn.Dropout2d(p=dropout)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode=\"fan_out\",\n                                        nonlinearity=\"relu\")\n            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):\n                if m.weight is not None:\n                    nn.init.constant_(m.weight, 1)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n\n    def _make_layer(self, dim, stride=1):\n        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)\n        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)\n        layers = (layer1, layer2)\n\n        self.in_planes = dim\n        return nn.Sequential(*layers)\n\n    def forward(self, x, feat_PE=None):\n        
_, _, H, W = x.shape\n\n        x = self.conv1(x)\n        x = self.norm1(x)\n        x = self.relu1(x)\n\n        if self.shallow:\n            a = self.layer1(x)\n            b = self.layer2(a)\n            c = self.layer3(b)\n            a = F.interpolate(\n                a,\n                (H // self.stride, W // self.stride),\n                mode=\"bilinear\",\n                align_corners=True,\n            )\n            b = F.interpolate(\n                b,\n                (H // self.stride, W // self.stride),\n                mode=\"bilinear\",\n                align_corners=True,\n            )\n            c = F.interpolate(\n                c,\n                (H // self.stride, W // self.stride),\n                mode=\"bilinear\",\n                align_corners=True,\n            )\n            x = self.conv2(torch.cat([a, b, c], dim=1))\n        else:\n            if feat_PE is not None:\n                x = self.conv_fuse(torch.cat([x, feat_PE], dim=1))\n                a = self.layer1(x)\n            else:\n                a = self.layer1(x)\n            b = self.layer2(a)\n            c = self.layer3(b)\n            d = self.layer4(c)\n            a = F.interpolate(\n                a,\n                (H // self.stride, W // self.stride),\n                mode=\"bilinear\",\n                align_corners=True,\n            )\n            b = F.interpolate(\n                b,\n                (H // self.stride, W // self.stride),\n                mode=\"bilinear\",\n                align_corners=True,\n            )\n            c = F.interpolate(\n                c,\n                (H // self.stride, W // self.stride),\n                mode=\"bilinear\",\n                align_corners=True,\n            )\n            d = F.interpolate(\n                d,\n                (H // self.stride, W // self.stride),\n                mode=\"bilinear\",\n                align_corners=True,\n            )\n            x = self.conv2(torch.cat([a, b, c, d], dim=1))\n            x = self.norm2(x)\n            x = self.relu2(x)\n            x = self.conv3(x)\n\n        if self.training and self.dropout is not None:\n            x = self.dropout(x)\n        return x\n\n\nclass DeeperBasicEncoder(nn.Module):\n    def __init__(\n            self, input_dim=3, output_dim=128, stride=8, norm_fn=\"batch\", dropout=0.0\n    ):\n        super(DeeperBasicEncoder, self).__init__()\n        self.stride = stride\n        self.norm_fn = norm_fn\n        self.in_planes = 64\n\n        if self.norm_fn == \"group\":\n            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes)\n            self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2)\n\n        elif self.norm_fn == \"batch\":\n            self.norm1 = nn.BatchNorm2d(self.in_planes)\n            self.norm2 = nn.BatchNorm2d(output_dim * 2)\n\n        elif self.norm_fn == \"instance\":\n            self.norm1 = nn.InstanceNorm2d(self.in_planes)\n            self.norm2 = nn.InstanceNorm2d(output_dim * 2)\n\n        elif self.norm_fn == \"none\":\n            self.norm1 = nn.Sequential()\n\n        self.conv1 = nn.Conv2d(\n            input_dim,\n            self.in_planes,\n            kernel_size=7,\n            stride=2,\n            padding=3,\n            padding_mode=\"zeros\",\n        )\n        self.relu1 = nn.ReLU(inplace=True)\n\n        self.layer1 = self._make_layer(64, stride=1)\n        self.layer2 = self._make_layer(96, stride=2)\n        self.layer3 = self._make_layer(128, stride=2)\n  
      self.layer4 = self._make_layer(128, stride=2)\n        self.layer5 = self._make_layer(128, stride=1)\n        self.layer6 = self._make_layer(64, stride=2)\n\n        self.conv2 = nn.Conv2d(\n            64 + 128 + 128 + 128 + 96 + 64,\n            output_dim * 2,\n            kernel_size=3,\n            padding=1,\n            padding_mode=\"zeros\",\n        )\n        self.relu2 = nn.ReLU(inplace=True)\n        self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)\n\n        self.dropout = None\n        if dropout > 0:\n            self.dropout = nn.Dropout2d(p=dropout)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode=\"fan_out\",\n                                        nonlinearity=\"relu\")\n            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):\n                if m.weight is not None:\n                    nn.init.constant_(m.weight, 1)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n\n    def _make_layer(self, dim, stride=1):\n        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)\n        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)\n        layers = (layer1, layer2)\n\n        self.in_planes = dim\n        return nn.Sequential(*layers)\n\n    def forward(self, x, feat_PE=None):\n        _, _, H, W = x.shape\n\n        x = self.conv1(x)\n        x = self.norm1(x)\n        x = self.relu1(x)\n\n        if feat_PE is not None:\n            x = self.conv_fuse(torch.cat([x, feat_PE], dim=1))\n            a = self.layer1(x)\n        else:\n            a = self.layer1(x)\n        b = self.layer2(a)\n        c = self.layer3(b)\n        d = self.layer4(c)\n        e = self.layer5(d)\n        f = self.layer6(e)\n        a = F.interpolate(\n            a,\n            (H // self.stride, W // self.stride),\n            mode=\"bilinear\",\n            align_corners=True,\n        )\n        b = F.interpolate(\n            b,\n            (H // self.stride, W // self.stride),\n            mode=\"bilinear\",\n            align_corners=True,\n        )\n        c = F.interpolate(\n            c,\n            (H // self.stride, W // self.stride),\n            mode=\"bilinear\",\n            align_corners=True,\n        )\n        d = F.interpolate(\n            d,\n            (H // self.stride, W // self.stride),\n            mode=\"bilinear\",\n            align_corners=True,\n        )\n        e = F.interpolate(\n            e,\n            (H // self.stride, W // self.stride),\n            mode=\"bilinear\",\n            align_corners=True,\n        )\n        f = F.interpolate(\n            f,\n            (H // self.stride, W // self.stride),\n            mode=\"bilinear\",\n            align_corners=True,\n        )\n        x = self.conv2(torch.cat([a, b, c, d, e, f], dim=1))\n        x = self.norm2(x)\n        x = self.relu2(x)\n        x = self.conv3(x)\n\n        if self.training and self.dropout is not None:\n            x = self.dropout(x)\n        return x\n\n\nclass CorrBlock:\n    def __init__(self, fmaps, num_levels=4, radius=4, depths_dnG=None):\n        B, S, C, H_prev, W_prev = fmaps.shape\n        self.S, self.C, self.H, self.W = S, C, H_prev, W_prev\n\n        self.num_levels = num_levels\n        self.radius = radius\n        self.fmaps_pyramid = []\n        self.depth_pyramid = []\n        self.fmaps_pyramid.append(fmaps)\n        if depths_dnG is not None:\n        
    self.depth_pyramid.append(depths_dnG)\n        for i in range(self.num_levels - 1):\n            if depths_dnG is not None:\n                depths_dnG_ = depths_dnG.reshape(B * S, 1, H_prev, W_prev)\n                depths_dnG_ = F.avg_pool2d(depths_dnG_, 2, stride=2)\n                _, _, H, W = depths_dnG_.shape\n                depths_dnG = depths_dnG_.reshape(B, S, 1, H, W)\n                self.depth_pyramid.append(depths_dnG)\n            fmaps_ = fmaps.reshape(B * S, C, H_prev, W_prev)\n            fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)\n            _, _, H, W = fmaps_.shape\n            fmaps = fmaps_.reshape(B, S, C, H, W)\n            H_prev = H\n            W_prev = W\n            self.fmaps_pyramid.append(fmaps)\n\n    def sample(self, coords):\n        r = self.radius\n        B, S, N, D = coords.shape\n        assert D == 2\n\n        H, W = self.H, self.W\n        out_pyramid = []\n        for i in range(self.num_levels):\n            corrs = self.corrs_pyramid[i]  # B, S, N, H, W\n            _, _, _, H, W = corrs.shape\n\n            dx = torch.linspace(-r, r, 2 * r + 1)\n            dy = torch.linspace(-r, r, 2 * r + 1)\n            delta = torch.stack(torch.meshgrid(dy, dx, indexing=\"ij\"), axis=-1).to(\n                coords.device\n            )\n            centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i\n            delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)\n            coords_lvl = centroid_lvl + delta_lvl\n            corrs = bilinear_sampler(corrs.reshape(B * S * N, 1, H, W), coords_lvl)\n            corrs = corrs.view(B, S, N, -1)\n            out_pyramid.append(corrs)\n\n        out = torch.cat(out_pyramid, dim=-1)  # B, S, N, LRR*2\n        return out.contiguous().float()\n\n    def corr(self, targets):\n        B, S, N, C = targets.shape\n        assert C == self.C\n        assert S == self.S\n\n        fmap1 = targets\n\n        self.corrs_pyramid = []\n        for fmaps in self.fmaps_pyramid:\n            _, _, _, H, W = fmaps.shape\n            fmap2s = fmaps.view(B, S, C, H * W)\n            corrs = torch.matmul(fmap1, fmap2s)\n            corrs = corrs.view(B, S, N, H, W)\n            corrs = corrs / torch.sqrt(torch.tensor(C).float())\n            self.corrs_pyramid.append(corrs)\n\n    def corr_sample(self, targets, coords, coords_dp=None):\n        B, S, N, C = targets.shape\n        r = self.radius\n        Dim_c = (2 * r + 1) ** 2\n        assert C == self.C\n        assert S == self.S\n\n        out_pyramid = []\n        out_pyramid_dp = []\n        for i in range(self.num_levels):\n            dx = torch.linspace(-r, r, 2 * r + 1)\n            dy = torch.linspace(-r, r, 2 * r + 1)\n            delta = torch.stack(torch.meshgrid(dy, dx, indexing=\"ij\"), axis=-1).to(\n                coords.device\n            )\n            centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i\n            delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)\n            coords_lvl = centroid_lvl + delta_lvl\n            fmaps = self.fmaps_pyramid[i]\n            _, _, _, H, W = fmaps.shape\n            fmap2s = fmaps.view(B * S, C, H, W)\n            if len(self.depth_pyramid) > 0:\n                depths_dnG_i = self.depth_pyramid[i]\n                depths_dnG_i = depths_dnG_i.view(B * S, 1, H, W)\n                dnG_sample = bilinear_sampler(depths_dnG_i, coords_lvl.view(B * S, 1, N * Dim_c, 2))\n                dp_corrs = (dnG_sample.view(B * S, N, -1) - coords_dp[0]).abs() / coords_dp[0]\n                
out_pyramid_dp.append(dp_corrs)\n            fmap2s_sample = bilinear_sampler(fmap2s, coords_lvl.view(B * S, 1, N * Dim_c, 2))\n            fmap2s_sample = fmap2s_sample.permute(0, 3, 1, 2)  # B*S, N*Dim_c, C, -1\n            corrs = torch.matmul(targets.reshape(B * S * N, 1, -1),\n                                 fmap2s_sample.reshape(B * S * N, Dim_c, -1).permute(0, 2, 1))\n            corrs = corrs / torch.sqrt(torch.tensor(C).float())\n            corrs = corrs.view(B, S, N, -1)\n            out_pyramid.append(corrs)\n\n        out = torch.cat(out_pyramid, dim=-1)  # B, S, N, LRR*2\n        if len(self.depth_pyramid) > 0:\n            out_dp = torch.cat(out_pyramid_dp, dim=-1)\n            self.fcorrD = out_dp.contiguous().float()\n        else:\n            self.fcorrD = torch.zeros_like(out).contiguous().float()\n        return out.contiguous().float()\n\n\nclass Attention(nn.Module):\n    def __init__(self, query_dim, num_heads=8, dim_head=48, qkv_bias=False, flash=False):\n        super().__init__()\n        inner_dim = self.inner_dim = dim_head * num_heads\n        self.scale = dim_head ** -0.5\n        self.heads = num_heads\n        self.flash = flash\n\n        self.qkv = nn.Linear(query_dim, inner_dim * 3, bias=qkv_bias)\n        self.proj = nn.Linear(inner_dim, query_dim)\n\n    def forward(self, x, attn_bias=None):\n        B, N1, _ = x.shape\n        C = self.inner_dim\n        h = self.heads\n\n        qkv = self.qkv(x).reshape(B, N1, 3, h, C // h)\n        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]\n        N2 = x.shape[1]\n\n        k = k.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)\n        v = v.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)\n        q = q.reshape(B, N1, h, C // h).permute(0, 2, 1, 3)\n        if self.flash == False:\n            sim = (q @ k.transpose(-2, -1)) * self.scale\n            if attn_bias is not None:\n                sim = sim + attn_bias\n            attn = sim.softmax(dim=-1)\n            x = (attn @ v).transpose(1, 2).reshape(B, N1, C)\n        else:\n            input_args = [x.half().contiguous() for x in [q, k, v]]\n            x = F.scaled_dot_product_attention(*input_args).permute(0, 2, 1, 3).reshape(B, N1, -1)  # type: ignore\n\n        return self.proj(x.float())\n\n\nclass AttnBlock(nn.Module):\n    \"\"\"\n    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.\n    \"\"\"\n\n    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0,\n                 flash=False, **block_kwargs):\n        super().__init__()\n        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)\n        self.flash = flash\n\n        self.attn = Attention(\n            hidden_size, num_heads=num_heads, qkv_bias=True, flash=flash,\n            **block_kwargs\n        )\n\n        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)\n        mlp_hidden_dim = int(hidden_size * mlp_ratio)\n        approx_gelu = lambda: nn.GELU(approximate=\"tanh\")\n        self.mlp = Mlp(\n            in_features=hidden_size,\n            hidden_features=mlp_hidden_dim,\n            act_layer=approx_gelu,\n            drop=0,\n        )\n\n    def forward(self, x):\n        x = x + self.attn(self.norm1(x))\n        x = x + self.mlp(self.norm2(x))\n        return x\n\n\ndef bilinear_sampler(img, coords, mode=\"bilinear\", mask=False):\n    \"\"\"Wrapper for grid_sample, uses pixel coordinates\"\"\"\n    H, W = img.shape[-2:]\n    xgrid, ygrid = coords.split([1, 1], dim=-1)\n    # go to 0,1 then 0,2 then 
-1,1\n    xgrid = 2 * xgrid / (W - 1) - 1\n    ygrid = 2 * ygrid / (H - 1) - 1\n\n    grid = torch.cat([xgrid, ygrid], dim=-1)\n    img = F.grid_sample(img, grid, align_corners=True)\n\n    if mask:\n        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)\n        return img, mask.float()\n\n    return img\n\n\nclass EUpdateFormer(nn.Module):\n    \"\"\"\n    Transformer model that updates track estimates.\n    \"\"\"\n\n    def __init__(\n            self,\n            space_depth=12,\n            time_depth=12,\n            input_dim=320,\n            hidden_size=384,\n            num_heads=8,\n            output_dim=130,\n            mlp_ratio=4.0,\n            vq_depth=3,\n            add_space_attn=True,\n            add_time_attn=True,\n            flash=True\n    ):\n        super().__init__()\n        self.out_channels = 2\n        self.num_heads = num_heads\n        self.hidden_size = hidden_size\n        self.add_space_attn = add_space_attn\n        self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)\n        self.flash = flash\n        self.flow_head = nn.Sequential(\n            nn.Linear(hidden_size, output_dim, bias=True),\n            nn.ReLU(inplace=True),\n            nn.Linear(output_dim, output_dim, bias=True),\n            nn.ReLU(inplace=True),\n            nn.Linear(output_dim, output_dim, bias=True)\n        )\n\n        cross_attn_kwargs = {\n            \"d_model\": self.hidden_size,\n            \"nhead\": 4,\n            \"layer_names\": ['self', 'cross'] * 3,\n        }\n        from mvtracker.models.core.loftr import LocalFeatureTransformer\n        self.gnn = LocalFeatureTransformer(cross_attn_kwargs)\n\n        # Attention Modules in the temporal dimension         \n        self.time_blocks = nn.ModuleList(\n            [\n                AttnBlock(\n                    hidden_size,\n                    num_heads,\n                    mlp_ratio=mlp_ratio,\n                    flash=flash,\n                ) if add_time_attn else nn.Identity()\n                for _ in range(time_depth)\n            ]\n        )\n\n        if add_space_attn:\n            self.space_blocks = nn.ModuleList(\n                [\n                    AttnBlock(\n                        hidden_size,\n                        num_heads,\n                        mlp_ratio=mlp_ratio,\n                        flash=flash,\n                    )\n                    for _ in range(space_depth)\n                ]\n            )\n            assert len(self.time_blocks) >= len(self.space_blocks)\n\n        self.initialize_weights()\n\n    def initialize_weights(self):\n        def _basic_init(module):\n            if isinstance(module, nn.Linear):\n                torch.nn.init.xavier_uniform_(module.weight)\n                if module.bias is not None:\n                    nn.init.constant_(module.bias, 0)\n\n        self.apply(_basic_init)\n\n    def forward(self, input_tensor, se3_feature):\n        \"\"\" Updating with Transformer\n\n        Args:\n            input_tensor: B, N, T, C\n            arap_embed: B, N, T, C\n        \"\"\"\n        B, N, T, C = input_tensor.shape\n        x = self.input_transform(input_tensor)\n        tokens = x\n        K = 0\n        j = 0\n        for i in range(len(self.time_blocks)):\n            tokens_time = rearrange(tokens, \"b n t c -> (b n) t c\", b=B, t=T, n=N + K)\n            tokens_time = self.time_blocks[i](tokens_time)\n            tokens = rearrange(tokens_time, \"(b n) t c -> b n t c \", b=B, t=T, n=N + 
K)\n            if self.add_space_attn and (\n                    i % (len(self.time_blocks) // len(self.space_blocks)) == 0\n            ):\n                tokens_space = rearrange(tokens, \"b n t c -> (b t) n c \", b=B, t=T, n=N)\n                tokens_space = self.space_blocks[j](tokens_space)\n                tokens = rearrange(tokens_space, \"(b t) n c -> b n t c  \", b=B, t=T, n=N)\n                j += 1\n\n        B, N, S, _ = tokens.shape\n\n        feat0, feat1 = self.gnn(tokens.view(B * N * S, -1)[None, ...], se3_feature[None, ...])\n        flow = self.flow_head(feat0.view(B, N, S, -1))\n\n        return flow, feat1\n\n\ndef pix2cam(coords,\n            intr):\n    \"\"\"\n    Args:\n        coords: [B, T, N, 3]\n        intr: [B, T, 3, 3]\n    \"\"\"\n    B, S, N, _, = coords.shape\n    assert coords.shape == (B, S, N, 3)\n    assert intr.shape == (B, S, 3, 3)\n\n    coords = coords.detach()\n    xy_src = coords.reshape(B * S * N, 3)\n    intr = intr[:, :, None, ...].repeat(1, 1, N, 1, 1).reshape(B * S * N, 3, 3)\n    xy_src = torch.cat([xy_src[..., :2], torch.ones_like(xy_src[..., :1])], dim=-1)\n    xyz_src = (torch.inverse(intr) @ xy_src[..., None])[..., 0]\n    dp_pred = coords[..., 2]\n    xyz_src_ = (xyz_src * (dp_pred.reshape(B * S * N, 1)))\n    xyz_src_ = xyz_src_.reshape(B, S, N, 3)\n    return xyz_src_\n\n\ndef cam2pix(coords,\n            intr):\n    \"\"\"\n    Args:\n        coords: [B, T, N, 3]\n        intr: [B, T, 3, 3]\n    \"\"\"\n    coords = coords.detach()\n    B, S, N, _, = coords.shape\n    xy_src = coords.reshape(B * S * N, 3).clone()\n    intr = intr[:, :, None, ...].repeat(1, 1, N, 1, 1).reshape(B * S * N, 3, 3)\n    xy_src = xy_src / (xy_src[..., 2:] + 1e-5)\n    xyz_src = (intr @ xy_src[..., None])[..., 0]\n    dp_pred = coords[..., 2]\n    xyz_src[..., 2] *= dp_pred.reshape(S * N)\n    xyz_src = xyz_src.reshape(B, S, N, 3)\n    return xyz_src\n"
  },
  {
    "path": "mvtracker/models/core/spatracker/softsplat.py",
    "content": "#!/usr/bin/env python\n\n\"\"\"The code of softsplat function is modified from:\nhttps://github.com/sniklaus/softmax-splatting/blob/master/softsplat.py\n\n\"\"\"\n\nimport collections\nimport os\nimport re\nimport typing\n\nimport cupy\nimport torch\n\nobjCudacache = {}\n\n\ndef cuda_int32(intIn: int):\n    return cupy.int32(intIn)\n\n\ndef cuda_float32(fltIn: float):\n    return cupy.float32(fltIn)\n\n\ndef cuda_kernel(strFunction: str, strKernel: str, objVariables: typing.Dict):\n    if 'device' not in objCudacache:\n        objCudacache['device'] = torch.cuda.get_device_name()\n\n    strKey = strFunction\n\n    for strVariable in objVariables:\n        objValue = objVariables[strVariable]\n\n        strKey += strVariable\n\n        if objValue is None:\n            continue\n\n        elif type(objValue) == int:\n            strKey += str(objValue)\n\n        elif type(objValue) == float:\n            strKey += str(objValue)\n\n        elif type(objValue) == bool:\n            strKey += str(objValue)\n\n        elif type(objValue) == str:\n            strKey += objValue\n\n        elif type(objValue) == torch.Tensor:\n            strKey += str(objValue.dtype)\n            strKey += str(objValue.shape)\n            strKey += str(objValue.stride())\n\n        elif True:\n            print(strVariable, type(objValue))\n            assert (False)\n\n    strKey += objCudacache['device']\n\n    if strKey not in objCudacache:\n        for strVariable in objVariables:\n            objValue = objVariables[strVariable]\n\n            if objValue is None:\n                continue\n\n            elif type(objValue) == int:\n                strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))\n\n            elif type(objValue) == float:\n                strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))\n\n            elif type(objValue) == bool:\n                strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))\n\n            elif type(objValue) == str:\n                strKernel = strKernel.replace('{{' + strVariable + '}}', objValue)\n\n            elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8:\n                strKernel = strKernel.replace('{{type}}', 'unsigned char')\n\n            elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16:\n                strKernel = strKernel.replace('{{type}}', 'half')\n\n            elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32:\n                strKernel = strKernel.replace('{{type}}', 'float')\n\n            elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64:\n                strKernel = strKernel.replace('{{type}}', 'double')\n\n            elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32:\n                strKernel = strKernel.replace('{{type}}', 'int')\n\n            elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64:\n                strKernel = strKernel.replace('{{type}}', 'long')\n\n            elif type(objValue) == torch.Tensor:\n                print(strVariable, objValue.dtype)\n                assert (False)\n\n            elif True:\n                print(strVariable, type(objValue))\n                assert (False)\n\n        while True:\n            objMatch = re.search('(SIZE_)([0-4])(\\()([^\\)]*)(\\))', strKernel)\n\n            if objMatch is None:\n                break\n\n            intArg = int(objMatch.group(2))\n\n            
strTensor = objMatch.group(4)\n            intSizes = objVariables[strTensor].size()\n\n            strKernel = strKernel.replace(objMatch.group(), str(\n                intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item()))\n\n        while True:\n            objMatch = re.search('(OFFSET_)([0-4])(\\()', strKernel)\n\n            if objMatch is None:\n                break\n\n            intStart = objMatch.span()[1]\n            intStop = objMatch.span()[1]\n            intParentheses = 1\n\n            while True:\n                intParentheses += 1 if strKernel[intStop] == '(' else 0\n                intParentheses -= 1 if strKernel[intStop] == ')' else 0\n\n                if intParentheses == 0:\n                    break\n\n                intStop += 1\n\n            intArgs = int(objMatch.group(2))\n            strArgs = strKernel[intStart:intStop].split(',')\n\n            assert (intArgs == len(strArgs) - 1)\n\n            strTensor = strArgs[0]\n            intStrides = objVariables[strTensor].stride()\n\n            strIndex = []\n\n            for intArg in range(intArgs):\n                strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(\n                    intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[\n                        intArg].item()) + ')')\n\n            strKernel = strKernel.replace('OFFSET_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')',\n                                          '(' + str.join('+', strIndex) + ')')\n\n        while True:\n            objMatch = re.search('(VALUE_)([0-4])(\\()', strKernel)\n\n            if objMatch is None:\n                break\n\n            intStart = objMatch.span()[1]\n            intStop = objMatch.span()[1]\n            intParentheses = 1\n\n            while True:\n                intParentheses += 1 if strKernel[intStop] == '(' else 0\n                intParentheses -= 1 if strKernel[intStop] == ')' else 0\n\n                if intParentheses == 0:\n                    break\n\n                intStop += 1\n\n            intArgs = int(objMatch.group(2))\n            strArgs = strKernel[intStart:intStop].split(',')\n\n            assert (intArgs == len(strArgs) - 1)\n\n            strTensor = strArgs[0]\n            intStrides = objVariables[strTensor].stride()\n\n            strIndex = []\n\n            for intArg in range(intArgs):\n                strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(\n                    intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[\n                        intArg].item()) + ')')\n\n            strKernel = strKernel.replace('VALUE_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')',\n                                          strTensor + '[' + str.join('+', strIndex) + ']')\n\n        objCudacache[strKey] = {\n            'strFunction': strFunction,\n            'strKernel': strKernel\n        }\n\n    return strKey\n\n\n@cupy.memoize(for_each_device=True)\ndef cuda_launch(strKey: str):\n    if 'CUDA_HOME' not in os.environ:\n        os.environ['CUDA_HOME'] = cupy.cuda.get_cuda_path()\n\n    return cupy.cuda.compile_with_cache(objCudacache[strKey]['strKernel'], tuple(\n        ['-I ' + os.environ['CUDA_HOME'], '-I ' + os.environ['CUDA_HOME'] + '/include'])).get_function(\n        
objCudacache[strKey]['strFunction'])\n\n\n##########################################################\n\n\ndef softsplat(\n        tenIn: torch.Tensor,\n        tenFlow: torch.Tensor,\n        tenMetric: typing.Optional[torch.Tensor],\n        strMode: str,\n        tenoutH=None,\n        tenoutW=None,\n        use_pointcloud_splatting=False,\n        return_normalization_tensor=False,\n):\n    assert (strMode.split('-')[0] in ['sum', 'avg', 'linear', 'soft'])\n\n    if strMode == 'sum': assert (tenMetric is None)\n    if strMode == 'avg': assert (tenMetric is None)\n    if strMode.split('-')[0] == 'linear': assert (tenMetric is not None)\n    if strMode.split('-')[0] == 'soft': assert (tenMetric is not None)\n\n    if strMode == 'avg':\n        tenIn = torch.cat([tenIn, tenIn.new_ones([tenIn.shape[0], 1, tenIn.shape[2], tenIn.shape[3]])], 1)\n\n    elif strMode.split('-')[0] == 'linear':\n        tenIn = torch.cat([tenIn * tenMetric, tenMetric], 1)\n\n    elif strMode.split('-')[0] == 'soft':\n        tenIn = torch.cat([tenIn * tenMetric.exp(), tenMetric.exp()], 1)\n\n    # If tenIn only contains a HW grid where each position in the grid will be\n    # taken into account for splatting as (grid_x + flow_x, grid_y + flow_y),\n    # then we use the original softsplat function which was designed for this.\n    # Otherwise, we assume the positions of the points in the grid do not matter\n    # and only the flow should be taken into account as (flow_x, flow_y)\n    # to determine the splatted position.\n    if use_pointcloud_splatting:\n        tenOut = softsplat_pointcloud_func.apply(tenIn, tenFlow, tenoutH, tenoutW)\n    else:\n        tenOut = softsplat_func.apply(tenIn, tenFlow, tenoutH, tenoutW)\n\n    if strMode.split('-')[0] in ['avg', 'linear', 'soft']:\n        tenNormalize = tenOut[:, -1:, :, :]\n\n        if len(strMode.split('-')) == 1:\n            tenNormalize = tenNormalize + 0.0001\n\n        elif strMode.split('-')[1] == 'addeps':\n            tenNormalize = tenNormalize + 0.0001\n\n        elif strMode.split('-')[1] == 'zeroeps':\n            tenNormalize[tenNormalize == 0.0] = 1.0\n\n        elif strMode.split('-')[1] == 'clipeps':\n            tenNormalize = tenNormalize.clip(0.0001, None)\n\n        tenOut = tenOut[:, :-1, :, :] / tenNormalize\n\n    if return_normalization_tensor:\n        return tenOut, tenNormalize\n    else:\n        return tenOut\n\n\nclass softsplat_func(torch.autograd.Function):\n    @staticmethod\n    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)\n    def forward(self, tenIn, tenFlow, H=None, W=None):\n        if H is None:\n            tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]])\n        else:\n            tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], H, W])\n\n        if tenIn.is_cuda == True:\n            cuda_launch(cuda_kernel('softsplat_out', '''\n                extern \"C\" __global__ void __launch_bounds__(512) softsplat_out(\n                    const long long int n,\n                    const {{type}}* __restrict__ tenIn,\n                    const {{type}}* __restrict__ tenFlow,\n                    {{type}}* __restrict__ tenOut\n                ) { for (long long int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {\n                    const int intN = ( intIndex / SIZE_3(tenIn) / SIZE_2(tenIn) / SIZE_1(tenIn) ) % SIZE_0(tenIn);\n                    const int intC = ( intIndex / SIZE_3(tenIn) / SIZE_2(tenIn)          
        ) % SIZE_1(tenIn);\n                    const int intY = ( intIndex / SIZE_3(tenIn)                                   ) % SIZE_2(tenIn);\n                    const int intX = ( intIndex                                                    ) % SIZE_3(tenIn);\n\n                    assert(SIZE_1(tenFlow) == 2);\n\n                    {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);\n                    {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);\n\n                    if (isfinite(fltX) == false) { return; }\n                    if (isfinite(fltY) == false) { return; }\n\n                    {{type}} fltIn = VALUE_4(tenIn, intN, intC, intY, intX);\n\n                    int intNorthwestX = (int) (floor(fltX));\n                    int intNorthwestY = (int) (floor(fltY));\n                    int intNortheastX = intNorthwestX + 1;\n                    int intNortheastY = intNorthwestY;\n                    int intSouthwestX = intNorthwestX;\n                    int intSouthwestY = intNorthwestY + 1;\n                    int intSoutheastX = intNorthwestX + 1;\n                    int intSoutheastY = intNorthwestY + 1;\n\n                    {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY);\n                    {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY);\n                    {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY));\n                    {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY));\n\n                    if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOut)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOut))) {\n                        atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNorthwestY, intNorthwestX)], fltIn * fltNorthwest);\n                    }\n\n                    if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOut)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOut))) {\n                        atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNortheastY, intNortheastX)], fltIn * fltNortheast);\n                    }\n\n                    if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOut)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOut))) {\n                        atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSouthwestY, intSouthwestX)], fltIn * fltSouthwest);\n                    }\n\n                    if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOut)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOut))) {\n                        atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSoutheastY, intSoutheastX)], fltIn * fltSoutheast);\n                    }\n                } }\n            ''', {\n                'tenIn': tenIn,\n                'tenFlow': tenFlow,\n                'tenOut': tenOut\n            }))(\n                grid=tuple([int((tenIn.nelement() + 512 - 1) / 512), 1, 1]),\n                block=tuple([512, 1, 1]),\n                args=[cuda_int32(tenIn.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOut.data_ptr()],\n                stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)\n            )\n\n        elif tenIn.is_cuda != True:\n            assert (False)\n\n        self.save_for_backward(tenIn, tenFlow)\n\n        return tenOut\n\n    @staticmethod\n    
@torch.cuda.amp.custom_bwd\n    def backward(self, tenOutgrad):\n        tenIn, tenFlow = self.saved_tensors\n\n        tenOutgrad = tenOutgrad.contiguous();\n        assert (tenOutgrad.is_cuda == True)\n\n        tenIngrad = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) if \\\n            self.needs_input_grad[0] == True else None\n        tenFlowgrad = tenFlow.new_zeros([tenFlow.shape[0], tenFlow.shape[1], tenFlow.shape[2], tenFlow.shape[3]]) if \\\n            self.needs_input_grad[1] == True else None\n        Hgrad = None\n        Wgrad = None\n\n        if tenIngrad is not None:\n            cuda_launch(cuda_kernel('softsplat_ingrad', '''\n                extern \"C\" __global__ void __launch_bounds__(512) softsplat_ingrad(\n                    const long long int n,\n                    const {{type}}* __restrict__ tenIn,\n                    const {{type}}* __restrict__ tenFlow,\n                    const {{type}}* __restrict__ tenOutgrad,\n                    {{type}}* __restrict__ tenIngrad,\n                    {{type}}* __restrict__ tenFlowgrad\n                ) { for (long long int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {\n                    const int intN = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) / SIZE_1(tenIngrad) ) % SIZE_0(tenIngrad);\n                    const int intC = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad)                     ) % SIZE_1(tenIngrad);\n                    const int intY = ( intIndex / SIZE_3(tenIngrad)                                         ) % SIZE_2(tenIngrad);\n                    const int intX = ( intIndex                                                             ) % SIZE_3(tenIngrad);\n\n                    assert(SIZE_1(tenFlow) == 2);\n\n                    {{type}} fltIngrad = 0.0f;\n\n                    {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);\n                    {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);\n\n                    if (isfinite(fltX) == false) { return; }\n                    if (isfinite(fltY) == false) { return; }\n\n                    int intNorthwestX = (int) (floor(fltX));\n                    int intNorthwestY = (int) (floor(fltY));\n                    int intNortheastX = intNorthwestX + 1;\n                    int intNortheastY = intNorthwestY;\n                    int intSouthwestX = intNorthwestX;\n                    int intSouthwestY = intNorthwestY + 1;\n                    int intSoutheastX = intNorthwestX + 1;\n                    int intSoutheastY = intNorthwestY + 1;\n\n                    {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY);\n                    {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY);\n                    {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY));\n                    {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY));\n\n                    if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) {\n                        fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest;\n                    }\n\n                    if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && 
(intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) {\n                        fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNortheastY, intNortheastX) * fltNortheast;\n                    }\n\n                    if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) {\n                        fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest;\n                    }\n\n                    if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) {\n                        fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast;\n                    }\n\n                    tenIngrad[intIndex] = fltIngrad;\n                } }\n            ''', {\n                'tenIn': tenIn,\n                'tenFlow': tenFlow,\n                'tenOutgrad': tenOutgrad,\n                'tenIngrad': tenIngrad,\n                'tenFlowgrad': tenFlowgrad\n            }))(\n                grid=tuple([int((tenIngrad.nelement() + 512 - 1) / 512), 1, 1]),\n                block=tuple([512, 1, 1]),\n                args=[cuda_int32(tenIngrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(),\n                      tenIngrad.data_ptr(), None],\n                stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)\n            )\n\n        if tenFlowgrad is not None:\n            cuda_launch(cuda_kernel('softsplat_flowgrad', '''\n                extern \"C\" __global__ void __launch_bounds__(512) softsplat_flowgrad(\n                    const long long int n,\n                    const {{type}}* __restrict__ tenIn,\n                    const {{type}}* __restrict__ tenFlow,\n                    const {{type}}* __restrict__ tenOutgrad,\n                    {{type}}* __restrict__ tenIngrad,\n                    {{type}}* __restrict__ tenFlowgrad\n                ) { for (long long int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {\n                    const int intN = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) / SIZE_1(tenFlowgrad) ) % SIZE_0(tenFlowgrad);\n                    const int intC = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad)                       ) % SIZE_1(tenFlowgrad);\n                    const int intY = ( intIndex / SIZE_3(tenFlowgrad)                                             ) % SIZE_2(tenFlowgrad);\n                    const int intX = ( intIndex                                                                   ) % SIZE_3(tenFlowgrad);\n\n                    assert(SIZE_1(tenFlow) == 2);\n\n                    {{type}} fltFlowgrad = 0.0f;\n\n                    {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);\n                    {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);\n\n                    if (isfinite(fltX) == false) { return; }\n                    if (isfinite(fltY) == false) { return; }\n\n                    int intNorthwestX = (int) (floor(fltX));\n                    int intNorthwestY = (int) (floor(fltY));\n                    int intNortheastX = intNorthwestX + 1;\n                    int intNortheastY = intNorthwestY;\n                    int intSouthwestX = intNorthwestX;\n                    int intSouthwestY = intNorthwestY + 1;\n                    int 
intSoutheastX = intNorthwestX + 1;\n                    int intSoutheastY = intNorthwestY + 1;\n\n                    {{type}} fltNorthwest = 0.0f;\n                    {{type}} fltNortheast = 0.0f;\n                    {{type}} fltSouthwest = 0.0f;\n                    {{type}} fltSoutheast = 0.0f;\n\n                    if (intC == 0) {\n                        fltNorthwest = (({{type}}) (-1.0f)) * (({{type}}) (intSoutheastY) - fltY);\n                        fltNortheast = (({{type}}) (+1.0f)) * (({{type}}) (intSouthwestY) - fltY);\n                        fltSouthwest = (({{type}}) (-1.0f)) * (fltY - ({{type}}) (intNortheastY));\n                        fltSoutheast = (({{type}}) (+1.0f)) * (fltY - ({{type}}) (intNorthwestY));\n\n                    } else if (intC == 1) {\n                        fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (-1.0f));\n                        fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (-1.0f));\n                        fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (({{type}}) (+1.0f));\n                        fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (({{type}}) (+1.0f));\n\n                    }\n\n                    for (int intChannel = 0; intChannel < SIZE_1(tenOutgrad); intChannel += 1) {\n                        {{type}} fltIn = VALUE_4(tenIn, intN, intChannel, intY, intX);\n\n                        if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) {\n                            fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNorthwestY, intNorthwestX) * fltIn * fltNorthwest;\n                        }\n\n                        if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) {\n                            fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNortheastY, intNortheastX) * fltIn * fltNortheast;\n                        }\n\n                        if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) {\n                            fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSouthwestY, intSouthwestX) * fltIn * fltSouthwest;\n                        }\n\n                        if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) {\n                            fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSoutheastY, intSoutheastX) * fltIn * fltSoutheast;\n                        }\n                    }\n\n                    tenFlowgrad[intIndex] = fltFlowgrad;\n                } }\n            ''', {\n                'tenIn': tenIn,\n                'tenFlow': tenFlow,\n                'tenOutgrad': tenOutgrad,\n                'tenIngrad': tenIngrad,\n                'tenFlowgrad': tenFlowgrad\n            }))(\n                grid=tuple([int((tenFlowgrad.nelement() + 512 - 1) / 512), 1, 1]),\n                block=tuple([512, 1, 1]),\n                args=[cuda_int32(tenFlowgrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(),\n                      None, tenFlowgrad.data_ptr()],\n                stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)\n            )\n        return tenIngrad, tenFlowgrad, Hgrad, Wgrad\n\n\ndef cuda_int64(intIn: int):\n    return 
cupy.int64(intIn)\n\n\ndef cuda_kernel_longlong(strFunction: str, strKernel: str, objVariables: typing.Dict):\n    if 'device' not in objCudacache:\n        objCudacache['device'] = torch.cuda.get_device_name()\n\n    strKey = strFunction\n\n    for strVariable in objVariables:\n        objValue = objVariables[strVariable]\n\n        strKey += strVariable\n\n        if objValue is None:\n            continue\n\n        elif type(objValue) == int:\n            strKey += str(objValue)\n\n        elif type(objValue) == float:\n            strKey += str(objValue)\n\n        elif type(objValue) == bool:\n            strKey += str(objValue)\n\n        elif type(objValue) == str:\n            strKey += objValue\n\n        elif type(objValue) == torch.Tensor:\n            strKey += str(objValue.dtype)\n            strKey += str(objValue.shape)\n            strKey += str(objValue.stride())\n\n        elif True:\n            print(strVariable, type(objValue))\n            assert (False)\n\n    strKey += objCudacache['device']\n\n    if strKey not in objCudacache:\n        for strVariable in objVariables:\n            objValue = objVariables[strVariable]\n\n            if objValue is None:\n                continue\n\n            elif type(objValue) == int:\n                strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))\n\n            elif type(objValue) == float:\n                strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))\n\n            elif type(objValue) == bool:\n                strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))\n\n            elif type(objValue) == str:\n                strKernel = strKernel.replace('{{' + strVariable + '}}', objValue)\n\n            elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8:\n                strKernel = strKernel.replace('{{type}}', 'unsigned char')\n\n            elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16:\n                strKernel = strKernel.replace('{{type}}', 'half')\n\n            elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32:\n                strKernel = strKernel.replace('{{type}}', 'float')\n\n            elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64:\n                strKernel = strKernel.replace('{{type}}', 'double')\n\n            elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32:\n                strKernel = strKernel.replace('{{type}}', 'int')\n\n            elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64:\n                strKernel = strKernel.replace('{{type}}', 'long')\n\n            elif type(objValue) == torch.Tensor:\n                print(strVariable, objValue.dtype)\n                assert (False)\n\n            elif True:\n                print(strVariable, type(objValue))\n                assert (False)\n\n        while True:\n            objMatch = re.search('(SIZE_)([0-4])(\\()([^\\)]*)(\\))', strKernel)\n\n            if objMatch is None:\n                break\n\n            intArg = int(objMatch.group(2))\n\n            strTensor = objMatch.group(4)\n            intSizes = objVariables[strTensor].size()\n\n            strKernel = strKernel.replace(objMatch.group(), str(\n                intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item()))\n\n        while True:\n            objMatch = re.search('(OFFSET_)([0-4])(\\()', strKernel)\n\n            if objMatch is None:\n    
            break\n\n            intStart = objMatch.span()[1]\n            intStop = objMatch.span()[1]\n            intParentheses = 1\n\n            while True:\n                intParentheses += 1 if strKernel[intStop] == '(' else 0\n                intParentheses -= 1 if strKernel[intStop] == ')' else 0\n\n                if intParentheses == 0:\n                    break\n\n                intStop += 1\n\n            intArgs = int(objMatch.group(2))\n            strArgs = strKernel[intStart:intStop].split(',')\n\n            assert (intArgs == len(strArgs) - 1)\n\n            strTensor = strArgs[0]\n            intStrides = objVariables[strTensor].stride()\n\n            strIndex = []\n\n            for intArg in range(intArgs):\n                idx_expr = strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip()\n                stride_val = (\n                    intStrides[intArg] if not torch.is_tensor(intStrides[intArg]) else intStrides[intArg].item()\n                )\n                strIndex.append(\n                    '(static_cast<long long int>(' + idx_expr + ') * ' + str(stride_val) + ')'\n                )\n\n            strKernel = strKernel.replace(\n                'OFFSET_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')',\n                '(' + ' + '.join(strIndex) + ')'\n            )\n\n        while True:\n            objMatch = re.search('(VALUE_)([0-4])(\\()', strKernel)\n\n            if objMatch is None:\n                break\n\n            intStart = objMatch.span()[1]\n            intStop = objMatch.span()[1]\n            intParentheses = 1\n\n            while True:\n                intParentheses += 1 if strKernel[intStop] == '(' else 0\n                intParentheses -= 1 if strKernel[intStop] == ')' else 0\n\n                if intParentheses == 0:\n                    break\n\n                intStop += 1\n\n            intArgs = int(objMatch.group(2))\n            strArgs = strKernel[intStart:intStop].split(',')\n\n            assert (intArgs == len(strArgs) - 1)\n\n            strTensor = strArgs[0]\n            intStrides = objVariables[strTensor].stride()\n\n            strIndex = []\n\n            for intArg in range(intArgs):\n                idx_expr = strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip()\n                stride_val = (\n                    intStrides[intArg] if not torch.is_tensor(intStrides[intArg]) else intStrides[intArg].item()\n                )\n                strIndex.append(\n                    '(static_cast<long long int>(' + idx_expr + ') * ' + str(stride_val) + ')'\n                )\n\n            strKernel = strKernel.replace(\n                'VALUE_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')',\n                strTensor + '[' + ' + '.join(strIndex) + ']'\n            )\n\n        objCudacache[strKey] = {\n            'strFunction': strFunction,\n            'strKernel': strKernel\n        }\n\n    return strKey\n\n\nclass softsplat_pointcloud_func(torch.autograd.Function):\n    @staticmethod\n    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)\n    def forward(self, tenIn, tenFlow, H=None, W=None):\n        if H is None:\n            tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]])\n        else:\n            tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], H, W])\n\n        if tenIn.is_cuda == True:\n            cuda_launch(cuda_kernel_longlong('softsplat_pointcloud_out', '''\n                extern \"C\" 
__global__ void __launch_bounds__(512) softsplat_pointcloud_out(\n                    const long long int n,\n                    const {{type}}* __restrict__ tenIn,\n                    const {{type}}* __restrict__ tenFlow,\n                    {{type}}* __restrict__ tenOut\n                ) { for (long long int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {\n                    const int intN = ( intIndex / SIZE_3(tenIn) / SIZE_2(tenIn) / SIZE_1(tenIn) ) % SIZE_0(tenIn);\n                    const int intC = ( intIndex / SIZE_3(tenIn) / SIZE_2(tenIn)                  ) % SIZE_1(tenIn);\n                    const int intY = ( intIndex / SIZE_3(tenIn)                                   ) % SIZE_2(tenIn);\n                    const int intX = ( intIndex                                                    ) % SIZE_3(tenIn);\n\n                    assert(SIZE_1(tenFlow) == 2);\n\n                    {{type}} fltX = ({{type}}) VALUE_4(tenFlow, intN, 0, intY, intX);\n                    {{type}} fltY = ({{type}}) VALUE_4(tenFlow, intN, 1, intY, intX);\n\n                    if (isfinite(fltX) == false) { return; }\n                    if (isfinite(fltY) == false) { return; }\n\n                    {{type}} fltIn = VALUE_4(tenIn, intN, intC, intY, intX);\n\n                    int intNorthwestX = (int) (floor(fltX));\n                    int intNorthwestY = (int) (floor(fltY));\n                    int intNortheastX = intNorthwestX + 1;\n                    int intNortheastY = intNorthwestY;\n                    int intSouthwestX = intNorthwestX;\n                    int intSouthwestY = intNorthwestY + 1;\n                    int intSoutheastX = intNorthwestX + 1;\n                    int intSoutheastY = intNorthwestY + 1;\n\n                    {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY);\n                    {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY);\n                    {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY));\n                    {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY));\n\n                    if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOut)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOut))) {\n                        atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNorthwestY, intNorthwestX)], fltIn * fltNorthwest);\n                    }\n\n                    if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOut)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOut))) {\n                        atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNortheastY, intNortheastX)], fltIn * fltNortheast);\n                    }\n\n                    if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOut)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOut))) {\n                        atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSouthwestY, intSouthwestX)], fltIn * fltSouthwest);\n                    }\n\n                    if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOut)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOut))) {\n                        atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSoutheastY, intSoutheastX)], fltIn * fltSoutheast);\n                    }\n                } }\n            ''', {\n                'tenIn': tenIn,\n       
         'tenFlow': tenFlow,\n                'tenOut': tenOut\n            }))(\n                grid=tuple([int((tenIn.nelement() + 512 - 1) / 512), 1, 1]),\n                block=tuple([512, 1, 1]),\n                args=[cuda_int64(tenIn.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOut.data_ptr()],\n                stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)\n            )\n\n        elif tenIn.is_cuda != True:\n            assert (False)\n\n        self.save_for_backward(tenIn, tenFlow)\n\n        return tenOut\n\n    @staticmethod\n    @torch.cuda.amp.custom_bwd\n    def backward(self, tenOutgrad):\n        tenIn, tenFlow = self.saved_tensors\n\n        tenOutgrad = tenOutgrad.contiguous();\n        assert (tenOutgrad.is_cuda == True)\n\n        tenIngrad = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) if \\\n            self.needs_input_grad[0] == True else None\n        tenFlowgrad = tenFlow.new_zeros([tenFlow.shape[0], tenFlow.shape[1], tenFlow.shape[2], tenFlow.shape[3]]) if \\\n            self.needs_input_grad[1] == True else None\n        Hgrad = None\n        Wgrad = None\n\n        if tenIngrad is not None:\n            cuda_launch(cuda_kernel_longlong('softsplat_pointcloud_ingrad', '''\n                extern \"C\" __global__ void __launch_bounds__(512) softsplat_pointcloud_ingrad(\n                    const long long int n,\n                    const {{type}}* __restrict__ tenIn,\n                    const {{type}}* __restrict__ tenFlow,\n                    const {{type}}* __restrict__ tenOutgrad,\n                    {{type}}* __restrict__ tenIngrad,\n                    {{type}}* __restrict__ tenFlowgrad\n                ) { for (long long int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {\n                    const int intN = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) / SIZE_1(tenIngrad) ) % SIZE_0(tenIngrad);\n                    const int intC = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad)                     ) % SIZE_1(tenIngrad);\n                    const int intY = ( intIndex / SIZE_3(tenIngrad)                                         ) % SIZE_2(tenIngrad);\n                    const int intX = ( intIndex                                                             ) % SIZE_3(tenIngrad);\n\n                    assert(SIZE_1(tenFlow) == 2);\n\n                    {{type}} fltIngrad = 0.0f;\n\n                    {{type}} fltX = ({{type}}) VALUE_4(tenFlow, intN, 0, intY, intX);\n                    {{type}} fltY = ({{type}}) VALUE_4(tenFlow, intN, 1, intY, intX);\n\n                    if (isfinite(fltX) == false) { return; }\n                    if (isfinite(fltY) == false) { return; }\n\n                    int intNorthwestX = (int) (floor(fltX));\n                    int intNorthwestY = (int) (floor(fltY));\n                    int intNortheastX = intNorthwestX + 1;\n                    int intNortheastY = intNorthwestY;\n                    int intSouthwestX = intNorthwestX;\n                    int intSouthwestY = intNorthwestY + 1;\n                    int intSoutheastX = intNorthwestX + 1;\n                    int intSoutheastY = intNorthwestY + 1;\n\n                    {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY);\n                    {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY);\n                
    {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY));\n                    {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY));\n\n                    if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) {\n                        fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest;\n                    }\n\n                    if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) {\n                        fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNortheastY, intNortheastX) * fltNortheast;\n                    }\n\n                    if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) {\n                        fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest;\n                    }\n\n                    if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) {\n                        fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast;\n                    }\n\n                    tenIngrad[intIndex] = fltIngrad;\n                } }\n            ''', {\n                'tenIn': tenIn,\n                'tenFlow': tenFlow,\n                'tenOutgrad': tenOutgrad,\n                'tenIngrad': tenIngrad,\n                'tenFlowgrad': tenFlowgrad\n            }))(\n                grid=tuple([int((tenIngrad.nelement() + 512 - 1) / 512), 1, 1]),\n                block=tuple([512, 1, 1]),\n                args=[cuda_int64(tenIngrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(),\n                      tenIngrad.data_ptr(), None],\n                stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)\n            )\n\n        if tenFlowgrad is not None:\n            cuda_launch(cuda_kernel_longlong('softsplat_pointcloud_flowgrad', '''\n                extern \"C\" __global__ void __launch_bounds__(512) softsplat_flowgrad(\n                    const long long int n,\n                    const {{type}}* __restrict__ tenIn,\n                    const {{type}}* __restrict__ tenFlow,\n                    const {{type}}* __restrict__ tenOutgrad,\n                    {{type}}* __restrict__ tenIngrad,\n                    {{type}}* __restrict__ tenFlowgrad\n                ) { for (long long int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {\n                    const int intN = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) / SIZE_1(tenFlowgrad) ) % SIZE_0(tenFlowgrad);\n                    const int intC = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad)                       ) % SIZE_1(tenFlowgrad);\n                    const int intY = ( intIndex / SIZE_3(tenFlowgrad)                                             ) % SIZE_2(tenFlowgrad);\n                    const int intX = ( intIndex                                                                   ) % SIZE_3(tenFlowgrad);\n\n                    assert(SIZE_1(tenFlow) == 2);\n\n                    {{type}} fltFlowgrad = 0.0f;\n\n                    {{type}} fltX = ({{type}}) VALUE_4(tenFlow, intN, 
0, intY, intX);\n                    {{type}} fltY = ({{type}}) VALUE_4(tenFlow, intN, 1, intY, intX);\n\n                    if (isfinite(fltX) == false) { return; }\n                    if (isfinite(fltY) == false) { return; }\n\n                    int intNorthwestX = (int) (floor(fltX));\n                    int intNorthwestY = (int) (floor(fltY));\n                    int intNortheastX = intNorthwestX + 1;\n                    int intNortheastY = intNorthwestY;\n                    int intSouthwestX = intNorthwestX;\n                    int intSouthwestY = intNorthwestY + 1;\n                    int intSoutheastX = intNorthwestX + 1;\n                    int intSoutheastY = intNorthwestY + 1;\n\n                    {{type}} fltNorthwest = 0.0f;\n                    {{type}} fltNortheast = 0.0f;\n                    {{type}} fltSouthwest = 0.0f;\n                    {{type}} fltSoutheast = 0.0f;\n\n                    if (intC == 0) {\n                        fltNorthwest = (({{type}}) (-1.0f)) * (({{type}}) (intSoutheastY) - fltY);\n                        fltNortheast = (({{type}}) (+1.0f)) * (({{type}}) (intSouthwestY) - fltY);\n                        fltSouthwest = (({{type}}) (-1.0f)) * (fltY - ({{type}}) (intNortheastY));\n                        fltSoutheast = (({{type}}) (+1.0f)) * (fltY - ({{type}}) (intNorthwestY));\n\n                    } else if (intC == 1) {\n                        fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (-1.0f));\n                        fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (-1.0f));\n                        fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (({{type}}) (+1.0f));\n                        fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (({{type}}) (+1.0f));\n\n                    }\n\n                    for (int intChannel = 0; intChannel < SIZE_1(tenOutgrad); intChannel += 1) {\n                        {{type}} fltIn = VALUE_4(tenIn, intN, intChannel, intY, intX);\n\n                        if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) {\n                            fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNorthwestY, intNorthwestX) * fltIn * fltNorthwest;\n                        }\n\n                        if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) {\n                            fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNortheastY, intNortheastX) * fltIn * fltNortheast;\n                        }\n\n                        if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) {\n                            fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSouthwestY, intSouthwestX) * fltIn * fltSouthwest;\n                        }\n\n                        if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) {\n                            fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSoutheastY, intSoutheastX) * fltIn * fltSoutheast;\n                        }\n                    }\n\n                    tenFlowgrad[intIndex] = fltFlowgrad;\n                } }\n            ''', {\n                'tenIn': tenIn,\n                'tenFlow': tenFlow,\n                'tenOutgrad': tenOutgrad,\n                
'tenIngrad': tenIngrad,\n                'tenFlowgrad': tenFlowgrad\n            }))(\n                grid=tuple([int((tenFlowgrad.nelement() + 512 - 1) / 512), 1, 1]),\n                block=tuple([512, 1, 1]),\n                args=[cuda_int64(tenFlowgrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(),\n                      None, tenFlowgrad.data_ptr()],\n                stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)\n            )\n        return tenIngrad, tenFlowgrad, Hgrad, Wgrad\n"
  },
  {
    "path": "mvtracker/models/core/spatracker/spatracker_monocular.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\nimport logging\nimport warnings\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom einops import rearrange\nfrom torch import nn as nn\n\nfrom mvtracker.models.core.embeddings import (\n    get_3d_embedding,\n    get_1d_sincos_pos_embed_from_grid,\n    get_2d_sincos_pos_embed,\n    get_3d_sincos_pos_embed_from_grid,\n    Embedder_Fourier,\n)\nfrom mvtracker.models.core.model_utils import (\n    bilinear_sample2d, smart_cat, sample_features5d, pixel_xy_and_camera_z_to_world_space\n)\nfrom mvtracker.models.core.spatracker.blocks import (\n    BasicEncoder,\n    CorrBlock,\n    EUpdateFormer,\n    pix2cam,\n    cam2pix\n)\nfrom mvtracker.models.core.spatracker.softsplat import softsplat\n\n\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\n\ndef sample_pos_embed(grid_size, embed_dim, coords):\n    if coords.shape[-1] == 2:\n        pos_embed = get_2d_sincos_pos_embed(embed_dim=embed_dim,\n                                            grid_size=grid_size)\n        pos_embed = (\n            torch.from_numpy(pos_embed)\n            .reshape(grid_size[0], grid_size[1], embed_dim)\n            .float()\n            .unsqueeze(0)\n            .to(coords.device)\n        )\n        sampled_pos_embed = bilinear_sample2d(\n            pos_embed.permute(0, 3, 1, 2),\n            coords[:, 0, :, 0], coords[:, 0, :, 1]\n        )\n    elif coords.shape[-1] == 3:\n        sampled_pos_embed = get_3d_sincos_pos_embed_from_grid(\n            embed_dim, coords[:, :1, ...]\n        ).float()[:, 0, ...].permute(0, 2, 1)\n\n    return sampled_pos_embed\n\n\nclass SpaTracker(nn.Module):\n    def __init__(\n            self,\n            sliding_window_len=8,\n            stride=8,\n            add_space_attn=True,\n            num_heads=8,\n            hidden_size=384,\n            space_depth=12,\n            time_depth=12,\n            triplane_zres=128,\n    ):\n        super(SpaTracker, self).__init__()\n\n        self.S = sliding_window_len\n        self.stride = stride\n        self.hidden_dim = 256\n        self.latent_dim = latent_dim = 128\n        self.b_latent_dim = self.latent_dim // 3\n        self.corr_levels = 4\n        self.corr_radius = 3\n        self.add_space_attn = add_space_attn\n        self.triplane_zres = triplane_zres\n\n        # @Encoder\n        self.fnet = BasicEncoder(input_dim=3,\n                                 output_dim=self.latent_dim, norm_fn=\"instance\", dropout=0,\n                                 stride=stride, Embed3D=False\n                                 )\n\n        # conv head for the tri-plane features\n        self.headyz = nn.Sequential(\n            nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1))\n\n        self.headxz = nn.Sequential(\n            nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1))\n\n        # @UpdateFormer\n        self.updateformer = EUpdateFormer(\n            space_depth=space_depth,\n            time_depth=time_depth,\n            input_dim=456,\n            hidden_size=hidden_size,\n            num_heads=num_heads,\n            output_dim=latent_dim + 3,\n            mlp_ratio=4.0,\n            
add_space_attn=add_space_attn,\n            flash=True\n        )\n        self.support_features = torch.zeros(100, 384).to(\"cuda\") + 0.1\n\n        self.norm = nn.GroupNorm(1, self.latent_dim)\n\n        self.ffeat_updater = nn.Sequential(\n            nn.Linear(self.latent_dim, self.latent_dim),\n            nn.GELU(),\n        )\n        self.ffeatyz_updater = nn.Sequential(\n            nn.Linear(self.latent_dim, self.latent_dim),\n            nn.GELU(),\n        )\n        self.ffeatxz_updater = nn.Sequential(\n            nn.Linear(self.latent_dim, self.latent_dim),\n            nn.GELU(),\n        )\n\n        # TODO @NeuralArap: optimize the arap\n        self.embed_traj = Embedder_Fourier(\n            input_dim=5, max_freq_log2=5.0, N_freqs=3, include_input=True\n        )\n        self.embed3d = Embedder_Fourier(\n            input_dim=3, max_freq_log2=10.0, N_freqs=10, include_input=True\n        )\n        self.embedConv = nn.Conv2d(self.latent_dim + 63,\n                                   self.latent_dim, 3, padding=1)\n\n        # @Vis_predictor\n        self.vis_predictor = nn.Sequential(\n            nn.Linear(128, 1),\n        )\n\n        self.embedProj = nn.Linear(63, 456)\n        self.zeroMLPflow = nn.Linear(195, 130)\n\n    def prepare_track(self, rgbds, queries):\n        \"\"\"\n        NOTE:\n        Normalized the rgbs and sorted the queries via their first appeared time\n        Args:\n            rgbds: the input rgbd images (B T 4 H W)\n            queries: the input queries (B N 4)\n        Return:\n            rgbds: the normalized rgbds (B T 4 H W)\n            queries: the sorted queries (B N 4)\n            track_mask:\n        \"\"\"\n        assert (rgbds.shape[2] == 4) and (queries.shape[2] == 4)\n        # Step1: normalize the rgbs input\n        device = rgbds.device\n        rgbds[:, :, :3, ...] = 2 * (rgbds[:, :, :3, ...] 
/ 255.0) - 1.0\n        B, T, C, H, W = rgbds.shape\n        B, N, __ = queries.shape\n        self.traj_e = torch.zeros((B, T, N, 3), device=device)\n        self.vis_e = torch.zeros((B, T, N), device=device)\n\n        # Step2: sort the points via their first appeared time\n        first_positive_inds = queries[0, :, 0].long()\n        __, sort_inds = torch.sort(first_positive_inds, dim=0, descending=False)\n        inv_sort_inds = torch.argsort(sort_inds, dim=0)\n        first_positive_sorted_inds = first_positive_inds[sort_inds]\n        # check if can be inverse\n        assert torch.allclose(\n            first_positive_inds, first_positive_inds[sort_inds][inv_sort_inds]\n        )\n\n        # filter those points never appear points during 1 - T\n        ind_array = torch.arange(T, device=device)\n        ind_array = ind_array[None, :, None].repeat(B, 1, N)\n        track_mask = (ind_array >=\n                      first_positive_inds[None, None, :]).unsqueeze(-1)\n\n        # scale the coords_init\n        coords_init = queries[:, :, 1:].reshape(B, 1, N, 3).repeat(\n            1, self.S, 1, 1\n        )\n        coords_init[..., :2] /= float(self.stride)\n\n        # Step3: initial the regular grid\n        gridx = torch.linspace(0, W // self.stride - 1, W // self.stride)\n        gridy = torch.linspace(0, H // self.stride - 1, H // self.stride)\n        gridx, gridy = torch.meshgrid(gridx, gridy, indexing=\"ij\")\n        gridxy = torch.stack([gridx, gridy], dim=-1).to(rgbds.device).permute(\n            2, 1, 0\n        )\n        vis_init = torch.ones((B, self.S, N, 1), device=device).float() * 10\n\n        # Step4: initial traj for neural arap\n        T_series = torch.linspace(0, 5, T).reshape(1, T, 1, 1).cuda()  # 1 T 1 1\n        T_series = T_series.repeat(B, 1, N, 1)\n        # get the 3d traj in the camera coordinates\n        intr_init = self.intrs[:, queries[0, :, 0].long()]\n        Traj_series = pix2cam(queries[:, :, None, 1:].double(), intr_init.double())\n        # torch.inverse(intr_init.double())@queries[:,:,1:,None].double() # B N 3 1\n        Traj_series = Traj_series.repeat(1, 1, T, 1).permute(0, 2, 1, 3).float()\n        Traj_series = torch.cat([T_series, Traj_series], dim=-1)\n        # get the indicator for the neural arap\n        Traj_mask = -1e2 * torch.ones_like(T_series)\n        Traj_series = torch.cat([Traj_series, Traj_mask], dim=-1)\n\n        return (\n            rgbds,\n            first_positive_inds,\n            first_positive_sorted_inds,\n            sort_inds, inv_sort_inds,\n            track_mask, gridxy, coords_init[..., sort_inds, :].clone(),\n            vis_init, Traj_series[..., sort_inds, :].clone()\n        )\n\n    def sample_trifeat(self, t,\n                       coords,\n                       featMapxy,\n                       featMapyz,\n                       featMapxz):\n        \"\"\"\n        Sample the features from the 5D triplane feature map 3*(B S C H W)\n        Args:\n            t: the time index\n            coords: the coordinates of the points B S N 3\n            featMapxy: the feature map B S C Hx Wy\n            featMapyz: the feature map B S C Hy Wz\n            featMapxz: the feature map B S C Hx Wz\n        \"\"\"\n        # get xy_t yz_t xz_t\n        queried_t = t.reshape(1, 1, -1, 1)\n        xy_t = torch.cat(\n            [queried_t, coords[..., [0, 1]]],\n            dim=-1\n        )\n        yz_t = torch.cat(\n            [queried_t, coords[..., [1, 2]]],\n            dim=-1\n        )\n        
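# (t, x, z) sampling coordinates for the XZ plane, analogous to xy_t and yz_t above\n        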
xz_t = torch.cat(\n            [queried_t, coords[..., [0, 2]]],\n            dim=-1\n        )\n        featxy_init = sample_features5d(featMapxy, xy_t)\n\n        featyz_init = sample_features5d(featMapyz, yz_t)\n        featxz_init = sample_features5d(featMapxz, xz_t)\n\n        featxy_init = featxy_init.repeat(1, self.S, 1, 1)\n        featyz_init = featyz_init.repeat(1, self.S, 1, 1)\n        featxz_init = featxz_init.repeat(1, self.S, 1, 1)\n\n        return featxy_init, featyz_init, featxz_init\n\n    def neural_arap(self, coords, Traj_arap, intrs_S, T_mark):\n        \"\"\" calculate the ARAP embedding and offset\n        Args:\n            coords: the coordinates of the current points   1 S N' 3\n            Traj_arap: the trajectory of the points   1 T N' 5\n            intrs_S: the camera intrinsics B S 3 3\n\n        \"\"\"\n        coords_out = coords.clone()\n        coords_out[..., :2] *= float(self.stride)\n        coords_out[..., 2] = coords_out[..., 2] / self.Dz\n        coords_out[..., 2] = coords_out[..., 2] * (self.d_far - self.d_near) + self.d_near\n        intrs_S = intrs_S[:, :, None, ...].repeat(1, 1, coords_out.shape[2], 1, 1)\n        B, S, N, D = coords_out.shape\n        if S != intrs_S.shape[1]:\n            intrs_S = torch.cat(\n                [intrs_S, intrs_S[:, -1:].repeat(1, S - intrs_S.shape[1], 1, 1, 1)], dim=1\n            )\n            T_mark = torch.cat(\n                [T_mark, T_mark[:, -1:].repeat(1, S - T_mark.shape[1], 1)], dim=1\n            )\n        xyz_ = pix2cam(coords_out.double(), intrs_S.double()[:, :, 0])\n        xyz_ = xyz_.float()\n        xyz_embed = torch.cat([T_mark[..., None], xyz_,\n                               torch.zeros_like(T_mark[..., None])], dim=-1)\n\n        xyz_embed = self.embed_traj(xyz_embed)\n        Traj_arap_embed = self.embed_traj(Traj_arap)\n        d_xyz, traj_feat = self.arapFormer(xyz_embed, Traj_arap_embed)\n        # update in camera coordinate\n        xyz_ = xyz_ + d_xyz.clamp(-5, 5)\n        # project back to the image plane\n        coords_out = cam2pix(xyz_.double(), intrs_S[:, :, 0].double()).float()\n        # resize back\n        coords_out[..., :2] /= float(self.stride)\n        coords_out[..., 2] = (coords_out[..., 2] - self.d_near) / (self.d_far - self.d_near)\n        coords_out[..., 2] *= self.Dz\n\n        return xyz_, coords_out, traj_feat\n\n    def gradient_arap(self, coords, aff_avg=None, aff_std=None, aff_f_sg=None,\n                      iter=0, iter_num=4, neigh_idx=None, intr=None, msk_track=None):\n        with torch.enable_grad():\n            coords.requires_grad_(True)\n            y = self.ARAP_ln(coords, aff_f_sg=aff_f_sg, neigh_idx=neigh_idx,\n                             iter=iter, iter_num=iter_num, intr=intr, msk_track=msk_track)\n            d_output = torch.ones_like(y, requires_grad=False, device=y.device)\n            gradients = torch.autograd.grad(\n                outputs=y,\n                inputs=coords,\n                grad_outputs=d_output,\n                create_graph=True,\n                retain_graph=True,\n                only_inputs=True, allow_unused=True)[0]\n\n        return gradients.detach()\n\n    def forward_iteration(\n            self,\n            fmapXY,\n            fmapYZ,\n            fmapXZ,\n            coords_init,\n            feat_init=None,\n            vis_init=None,\n            track_mask=None,\n            iters=4,\n            intrs_S=None,\n    ):\n        B, S_init, N, D = coords_init.shape\n        assert D == 3\n        
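# coords are in triplane units here: x and y are downscaled by self.stride and z is normalized into [0, Dz];\n        # the iterative refinement below currently assumes a single sequence per batch\n        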
assert B == 1\n        B, S, __, H8, W8 = fmapXY.shape\n        device = fmapXY.device\n\n        if S_init < S:\n            coords = torch.cat(\n                [coords_init, coords_init[:, -1].repeat(1, S - S_init, 1, 1)],\n                dim=1\n            )\n            vis_init = torch.cat(\n                [vis_init, vis_init[:, -1].repeat(1, S - S_init, 1, 1)], dim=1\n            )\n            intrs_S = torch.cat(\n                [intrs_S, intrs_S[:, -1].repeat(1, S - S_init, 1, 1)], dim=1\n            )\n        else:\n            coords = coords_init.clone()\n\n        fcorr_fnXY = CorrBlock(\n            fmapXY, num_levels=self.corr_levels, radius=self.corr_radius\n        )\n        fcorr_fnYZ = CorrBlock(\n            fmapYZ, num_levels=self.corr_levels, radius=self.corr_radius\n        )\n        fcorr_fnXZ = CorrBlock(\n            fmapXZ, num_levels=self.corr_levels, radius=self.corr_radius\n        )\n\n        ffeats = torch.split(feat_init.clone(), dim=-1, split_size_or_sections=1)\n        ffeats = [f.squeeze(-1) for f in ffeats]\n\n        times_ = torch.linspace(0, S - 1, S).reshape(1, S, 1)\n        pos_embed = sample_pos_embed(\n            grid_size=(H8, W8),\n            embed_dim=456,\n            coords=coords[..., :2],\n        )\n        pos_embed = rearrange(pos_embed, \"b e n -> (b n) e\").unsqueeze(1)\n\n        times_embed = (\n            torch.from_numpy(get_1d_sincos_pos_embed_from_grid(456, times_[0]))[None]\n            .repeat(B, 1, 1)\n            .float()\n            .to(device)\n        )\n        coord_predictions = []\n        attn_predictions = []\n        Rot_ln = 0\n        support_feat = self.support_features\n\n        for __ in range(iters):\n            coords = coords.detach()\n            # if self.args.if_ARAP == True:\n            #     # refine the track with arap\n            #     xyz_pred, coords, flows_cat0 = self.neural_arap(coords.detach(),\n            #                                                    Traj_arap.detach(),\n            #                                                    intrs_S, T_mark)\n            fcorrsXY = fcorr_fnXY.corr_sample(ffeats[0], coords[..., :2])\n            fcorrsYZ = fcorr_fnYZ.corr_sample(ffeats[1], coords[..., [1, 2]])\n            fcorrsXZ = fcorr_fnXZ.corr_sample(ffeats[2], coords[..., [0, 2]])\n            # fcorrs = fcorrsXY\n            fcorrs = fcorrsXY + fcorrsYZ + fcorrsXZ\n            LRR = fcorrs.shape[3]\n            fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, LRR)\n\n            flows_ = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 3)\n            flows_cat = get_3d_embedding(flows_, 64, cat_coords=True)\n            flows_cat = self.zeroMLPflow(flows_cat)\n\n            ffeats_xy = ffeats[0].permute(0,\n                                          2, 1, 3).reshape(B * N, S, self.latent_dim)\n            ffeats_yz = ffeats[1].permute(0,\n                                          2, 1, 3).reshape(B * N, S, self.latent_dim)\n            ffeats_xz = ffeats[2].permute(0,\n                                          2, 1, 3).reshape(B * N, S, self.latent_dim)\n            ffeats_ = ffeats_xy + ffeats_yz + ffeats_xz\n\n            if track_mask.shape[1] < vis_init.shape[1]:\n                track_mask = torch.cat(\n                    [\n                        track_mask,\n                        torch.zeros_like(track_mask[:, 0]).repeat(\n                            1, vis_init.shape[1] - track_mask.shape[1], 1, 1\n                        ),\n           
         ],\n                    dim=1,\n                )\n            concat = (\n                torch.cat([track_mask, vis_init], dim=2)\n                .permute(0, 2, 1, 3)\n                .reshape(B * N, S, 2)\n            )\n\n            transformer_input = torch.cat([flows_cat, fcorrs_, ffeats_, concat], dim=2)\n\n            if transformer_input.shape[-1] < pos_embed.shape[-1]:\n                # padding the transformer_input to the same dimension as pos_embed\n                transformer_input = F.pad(\n                    transformer_input, (0, pos_embed.shape[-1] - transformer_input.shape[-1]),\n                    \"constant\", 0\n                )\n\n            x = transformer_input + pos_embed + times_embed\n            x = rearrange(x, \"(b n) t d -> b n t d\", b=B)\n\n            delta, delta_se3F = self.updateformer(x, support_feat)\n            support_feat = support_feat + delta_se3F[0] / 100\n            delta = rearrange(delta, \" b n t d -> (b n) t d\")\n            d_coord = delta[:, :, :3]\n            d_feats = delta[:, :, 3:]\n\n            ffeats_xy = self.ffeat_updater(self.norm(d_feats.view(-1, self.latent_dim))) + ffeats_xy.reshape(-1,\n                                                                                                             self.latent_dim)\n            ffeats_yz = self.ffeatyz_updater(self.norm(d_feats.view(-1, self.latent_dim))) + ffeats_yz.reshape(-1,\n                                                                                                               self.latent_dim)\n            ffeats_xz = self.ffeatxz_updater(self.norm(d_feats.view(-1, self.latent_dim))) + ffeats_xz.reshape(-1,\n                                                                                                               self.latent_dim)\n            ffeats[0] = ffeats_xy.reshape(B, N, S, self.latent_dim).permute(\n                0, 2, 1, 3\n            )  # B,S,N,C\n            ffeats[1] = ffeats_yz.reshape(B, N, S, self.latent_dim).permute(\n                0, 2, 1, 3\n            )  # B,S,N,C\n            ffeats[2] = ffeats_xz.reshape(B, N, S, self.latent_dim).permute(\n                0, 2, 1, 3\n            )  # B,S,N,C\n            coords = coords + d_coord.reshape(B, N, S, 3).permute(0, 2, 1, 3)\n            if torch.isnan(coords).any():\n                # import ipdb;\n                # ipdb.set_trace()\n                logging.error(\"nan in coords\")\n\n            coords_out = coords.clone()\n            coords_out[..., :2] *= float(self.stride)\n\n            coords_out[..., 2] = coords_out[..., 2] / self.Dz\n            coords_out[..., 2] = coords_out[..., 2] * (self.d_far - self.d_near) + self.d_near\n\n            coord_predictions.append(coords_out)\n\n        ffeats_f = ffeats[0] + ffeats[1] + ffeats[2]\n        vis_e = self.vis_predictor(ffeats_f.reshape(B * S * N, self.latent_dim)).reshape(\n            B, S, N\n        )\n        self.support_features = support_feat.detach()\n        return coord_predictions, attn_predictions, vis_e, feat_init, Rot_ln\n\n    def forward(self, rgbds, queries, iters=4, feat_init=None, is_train=False, intrs=None):\n        self.support_features = torch.zeros(100, 384).to(\"cuda\") + 0.1\n        self.is_train = is_train\n        B, T, C, H, W = rgbds.shape\n        # set the intrinsic or simply initialized\n        if intrs is None:\n            intrs = torch.from_numpy(np.array([[W, 0.0, W // 2],\n                                               [0.0, W, H // 2],\n                                       
        [0.0, 0.0, 1.0]]))\n            intrs = intrs[None,\n            None, ...].repeat(B, T, 1, 1).float().to(rgbds.device)\n        self.intrs = intrs\n\n        # prepare the input for tracking\n        (\n            rgbds,\n            first_positive_inds,\n            first_positive_sorted_inds, sort_inds,\n            inv_sort_inds, timestep_should_be_estimated_mask, gridxy,\n            coords_init, vis_init, Traj_arap\n        ) = self.prepare_track(rgbds.clone(), queries)\n        coords_init_ = coords_init.clone()\n        vis_init_ = vis_init[:, :, sort_inds].clone()\n\n        depth_all = rgbds[:, :, 3, ...]\n        d_near = self.d_near = depth_all[depth_all > 0.01].min().item()\n        d_far = self.d_far = depth_all[depth_all > 0.01].max().item()\n\n        B, N, __ = queries.shape\n        self.Dz = Dz = self.triplane_zres\n        w_idx_start = 0\n        p_idx_end = 0\n        p_idx_start = 0\n        fmaps_ = None\n        vis_predictions = []\n        coord_predictions = []\n        attn_predictions = []\n        p_idx_end_list = []\n        Rigid_ln_total = 0\n        while w_idx_start < T - self.S // 2:\n            curr_wind_points = torch.nonzero(\n                first_positive_sorted_inds < w_idx_start + self.S)\n            if curr_wind_points.shape[0] == 0:\n                w_idx_start = w_idx_start + self.S // 2\n                logging.info(f\"No points in window {w_idx_start}-{w_idx_start + self.S}; adding empty results to list\")\n                p_idx_end_list.append(torch.zeros((1,), dtype=torch.int64, device=first_positive_sorted_inds.device))\n                if is_train:\n                    vis_predictions.append(torch.zeros((B, self.S, 0), device=rgbds.device))\n                    coord_predictions.append(\n                        [torch.zeros((B, self.S, 0, 3), device=rgbds.device) for _ in range(iters)])\n                    attn_predictions.append([-1 for _ in range(iters)])\n                continue\n            p_idx_end = curr_wind_points[-1] + 1\n            p_idx_end_list.append(p_idx_end)\n            # the T may not be divided by self.S\n            rgbds_seq = rgbds[:, w_idx_start:w_idx_start + self.S].clone()\n            S = S_local = rgbds_seq.shape[1]\n            if S < self.S:\n                rgbds_seq = torch.cat(\n                    [rgbds_seq,\n                     rgbds_seq[:, -1, None].repeat(1, self.S - S, 1, 1, 1)],\n                    dim=1,\n                )\n                S = rgbds_seq.shape[1]\n\n            rgbs_ = rgbds_seq.reshape(B * S, C, H, W)[:, :3]\n            depths = rgbds_seq.reshape(B * S, C, H, W)[:, 3:].clone()\n            # open the mask\n            # Traj_arap[:, w_idx_start:w_idx_start + self.S, :p_idx_end, -1] = 0\n            # step1: normalize the depth map\n\n            depths = (depths - d_near) / (d_far - d_near)\n            depths_dn = nn.functional.interpolate(\n                depths, scale_factor=1.0 / self.stride, mode=\"nearest\")\n            depths_dnG = depths_dn * Dz\n\n            # step2: normalize the coordinate\n            coords_init_[:, :, p_idx_start:p_idx_end, 2] = (\n                                                                   coords_init[:, :, p_idx_start:p_idx_end, 2] - d_near\n                                                           ) / (d_far - d_near)\n            coords_init_[:, :, p_idx_start:p_idx_end, 2] *= Dz\n\n            # efficient triplane splatting\n            gridxyz = torch.cat([gridxy[None, ...].repeat(\n                
depths_dn.shape[0], 1, 1, 1), depths_dnG], dim=1)\n            Fxy2yz = gridxyz[:, [1, 2], ...] - gridxyz[:, :2]\n            Fxy2xz = gridxyz[:, [0, 2], ...] - gridxyz[:, :2]\n            gridxyz_nm = gridxyz.clone()\n            gridxyz_nm[:, 0, ...] = (gridxyz_nm[:, 0, ...] - gridxyz_nm[:, 0, ...].min()) / (\n                    gridxyz_nm[:, 0, ...].max() - gridxyz_nm[:, 0, ...].min())\n            gridxyz_nm[:, 1, ...] = (gridxyz_nm[:, 1, ...] - gridxyz_nm[:, 1, ...].min()) / (\n                    gridxyz_nm[:, 1, ...].max() - gridxyz_nm[:, 1, ...].min())\n            gridxyz_nm[:, 2, ...] = (gridxyz_nm[:, 2, ...] - gridxyz_nm[:, 2, ...].min()) / (\n                    gridxyz_nm[:, 2, ...].max() - gridxyz_nm[:, 2, ...].min())\n            gridxyz_nm = 2 * (gridxyz_nm - 0.5)\n            _, _, h4, w4 = gridxyz_nm.shape\n            gridxyz_nm = gridxyz_nm.permute(0, 2, 3, 1).reshape(S * h4 * w4, 3)\n            featPE = self.embed3d(gridxyz_nm).view(S, h4, w4, -1).permute(0, 3, 1, 2)\n            if fmaps_ is None:\n                fmaps_ = torch.cat([self.fnet(rgbs_), featPE], dim=1)\n                fmaps_ = self.embedConv(fmaps_)\n            else:\n                fmaps_new = torch.cat([self.fnet(rgbs_[self.S // 2:]), featPE[self.S // 2:]], dim=1)\n                fmaps_new = self.embedConv(fmaps_new)\n                fmaps_ = torch.cat(\n                    [fmaps_[self.S // 2:], fmaps_new], dim=0\n                )\n\n            fmapXY = fmaps_[:, :self.latent_dim].reshape(\n                B, S, self.latent_dim, H // self.stride, W // self.stride\n            )\n\n            fmapYZ = softsplat(fmapXY[0], Fxy2yz, None,\n                               strMode=\"avg\", tenoutH=self.Dz, tenoutW=H // self.stride)\n            fmapXZ = softsplat(fmapXY[0], Fxy2xz, None,\n                               strMode=\"avg\", tenoutH=self.Dz, tenoutW=W // self.stride)\n\n            fmapYZ = self.headyz(fmapYZ)[None, ...]\n            fmapXZ = self.headxz(fmapXZ)[None, ...]\n\n            if p_idx_end - p_idx_start > 0:\n                queried_t = (first_positive_sorted_inds[p_idx_start:p_idx_end]\n                             - w_idx_start)\n                (featxy_init,\n                 featyz_init,\n                 featxz_init) = self.sample_trifeat(\n                    t=queried_t, featMapxy=fmapXY,\n                    featMapyz=fmapYZ, featMapxz=fmapXZ,\n                    coords=coords_init_[:, :1, p_idx_start:p_idx_end]\n                )\n                # T, S, N, C, 3\n                feat_init_curr = torch.stack([featxy_init,\n                                              featyz_init, featxz_init], dim=-1)\n                feat_init = smart_cat(feat_init, feat_init_curr, dim=2)\n\n            if p_idx_start > 0:\n                # preprocess the coordinates of last windows\n                last_coords = coords[-1][:, self.S // 2:].clone()\n                last_coords[..., :2] /= float(self.stride)\n                last_coords[..., 2:] = (last_coords[..., 2:] - d_near) / (d_far - d_near)\n                last_coords[..., 2:] = last_coords[..., 2:] * Dz\n\n                coords_init_[:, : self.S // 2, :p_idx_start] = last_coords\n                coords_init_[:, self.S // 2:, :p_idx_start] = last_coords[\n                                                              :, -1\n                                                              ].repeat(1, self.S // 2, 1, 1)\n\n                last_vis = vis[:, self.S // 2:].unsqueeze(-1)\n                vis_init_[:, : self.S // 2, 
:p_idx_start] = last_vis\n                vis_init_[:, self.S // 2:, :p_idx_start] = last_vis[:, -1].repeat(\n                    1, self.S // 2, 1, 1\n                )\n\n            coords, attns, vis, __, Rigid_ln = self.forward_iteration(\n                fmapXY=fmapXY,\n                fmapYZ=fmapYZ,\n                fmapXZ=fmapXZ,\n                coords_init=coords_init_[:, :, :p_idx_end],\n                feat_init=feat_init[:, :, :p_idx_end],\n                vis_init=vis_init_[:, :, :p_idx_end],\n                track_mask=timestep_should_be_estimated_mask[:, w_idx_start: w_idx_start + self.S, :p_idx_end],\n                iters=iters,\n                intrs_S=self.intrs[:, w_idx_start: w_idx_start + self.S],\n            )\n\n            Rigid_ln_total += Rigid_ln\n\n            if is_train:\n                vis_predictions.append(vis[:, :S_local])\n                coord_predictions.append([coord[:, :S_local] for coord in coords])\n                attn_predictions.append(attns)\n\n            self.traj_e[:, w_idx_start:w_idx_start + self.S, :p_idx_end] = coords[-1][:, :S_local]\n            self.vis_e[:, w_idx_start:w_idx_start + self.S, :p_idx_end] = vis[:, :S_local]\n\n            timestep_should_be_estimated_mask[:, : w_idx_start + self.S, :p_idx_end] = 0.0\n            w_idx_start = w_idx_start + self.S // 2\n\n            p_idx_start = p_idx_end\n\n        self.traj_e = self.traj_e[:, :, inv_sort_inds]\n        self.vis_e = self.vis_e[:, :, inv_sort_inds]\n\n        self.vis_e = torch.sigmoid(self.vis_e)\n        train_data = (\n            (vis_predictions, coord_predictions, attn_predictions,\n             p_idx_end_list, sort_inds, Rigid_ln_total)\n        )\n        if self.is_train:\n            return self.traj_e, feat_init, self.vis_e, train_data\n        else:\n            return self.traj_e, feat_init, self.vis_e\n\n\nclass SpaTrackerMultiViewAdapter(nn.Module):\n    def __init__(self, **kwargs):\n        super(SpaTrackerMultiViewAdapter, self).__init__()\n        self.spatracker = SpaTracker(**kwargs)\n\n    def forward(\n            self,\n            rgbs,\n            depths,\n            query_points,\n            intrs,\n            extrs,\n            iters=4,\n            feat_init=None,\n            is_train=False,\n            save_debug_logs=False,\n            debug_logs_path=\"\",\n            query_points_view=None,\n            **kwargs,\n    ):\n        batch_size, num_views, num_frames, _, height, width = rgbs.shape\n        _, num_points, _ = query_points.shape\n\n        depths = depths.clamp(max=36.0)\n\n        assert rgbs.shape == (batch_size, num_views, num_frames, 3, height, width)\n        assert depths.shape == (batch_size, num_views, num_frames, 1, height, width)\n        assert query_points.shape == (batch_size, num_points, 4)\n        assert intrs.shape == (batch_size, num_views, num_frames, 3, 3)\n        assert extrs.shape == (batch_size, num_views, num_frames, 3, 4)\n\n        if feat_init is not None:\n            raise NotImplementedError(\"feat_init is not supported yet\")\n\n        # Project the queries to each view\n        query_points_t = query_points[:, :, :1].long()\n        query_points_xyz_worldspace = query_points[:, :, 1:]\n\n        query_points_xy_pixelspace_per_view = query_points.new_zeros((batch_size, num_views, num_points, 2))\n        query_points_z_cameraspace_per_view = query_points.new_zeros((batch_size, num_views, num_points, 1))\n        for batch_idx in range(batch_size):\n            for t in 
query_points_t[batch_idx].unique():\n                query_points_t_mask = query_points_t[batch_idx].squeeze(-1) == t\n                point_3d_world = query_points_xyz_worldspace[batch_idx][query_points_t_mask]\n\n                # World to camera space\n                point_4d_world_homo = torch.cat(\n                    [point_3d_world, point_3d_world.new_ones(point_3d_world[..., :1].shape)], -1)\n                point_3d_camera = torch.einsum('Aij,Bj->ABi', extrs[batch_idx, :, t, :, :], point_4d_world_homo[:, :])\n\n                # Camera to pixel space\n                point_2d_pixel_homo = torch.einsum('Aij,ABj->ABi', intrs[batch_idx, :, t, :, :], point_3d_camera[:, :])\n                point_2d_pixel = point_2d_pixel_homo[..., :2] / point_2d_pixel_homo[..., 2:]\n\n                query_points_xy_pixelspace_per_view[batch_idx, :, query_points_t_mask] = point_2d_pixel\n                query_points_z_cameraspace_per_view[batch_idx, :, query_points_t_mask] = point_3d_camera[..., -1:]\n\n        # Estimate occlusion mask in each view based on depth maps\n        query_points_depth_in_view = query_points.new_zeros((batch_size, num_views, num_points, 1))\n        for batch_idx in range(batch_size):\n            for view_idx in range(num_views):\n                for t in query_points_t[batch_idx].unique():\n                    query_points_t_mask = query_points_t[batch_idx].squeeze(-1) == t\n                    interpolated_depth = bilinear_sample2d(\n                        im=depths[batch_idx, view_idx, t][None],\n                        x=query_points_xy_pixelspace_per_view[batch_idx, view_idx, query_points_t_mask, 0][None],\n                        y=query_points_xy_pixelspace_per_view[batch_idx, view_idx, query_points_t_mask, 1][None],\n                    )[0].permute(1, 0).type(query_points.dtype)\n                    query_points_depth_in_view[batch_idx, view_idx, query_points_t_mask] = interpolated_depth\n\n        query_points_depth_in_view_masked = query_points_depth_in_view.clone()\n        query_points_outside_of_view_box = (\n                (query_points_xy_pixelspace_per_view[..., 0] < 0) |\n                (query_points_xy_pixelspace_per_view[..., 0] >= width) |\n                (query_points_xy_pixelspace_per_view[..., 1] < 0) |\n                (query_points_xy_pixelspace_per_view[..., 1] >= height) |\n                (query_points_z_cameraspace_per_view[..., 0] < 0)\n        )\n        if query_points_outside_of_view_box.all(1).any():\n            warnings.warn(f\"There are some query points that are outside of the frame of every view: \"\n                          f\"{query_points_xy_pixelspace_per_view[query_points_outside_of_view_box.all(1)[:, None, :].repeat(1, num_views, 1)].reshape(num_views, -1, 2).permute(1, 0, 2)}\")\n        query_points_depth_in_view_masked[query_points_outside_of_view_box] = -1e4\n        # query_points_occluded_by_depthmap = (query_points_depth_in_view * 1.1 < query_points_z_cameraspace_per_view)\n        # query_points_depth_in_view_masked[query_points_occluded_by_depthmap] = -1e3\n\n        query_points_best_visibility_view = (\n                query_points_depth_in_view_masked - query_points_z_cameraspace_per_view).argmax(1)\n        query_points_best_visibility_view = query_points_best_visibility_view.squeeze(-1)\n\n        if query_points_view is not None:\n            query_points_best_visibility_view = query_points_view\n            logging.info(f\"Using the provided query_points_view instead of the estimated one\")\n\n        
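# Note: the heuristic above assigns each query to the view maximizing (sampled depth-map depth\n        # minus query camera-space depth), i.e., the view where the query is least occluded; projections\n        # falling outside the image or behind the camera were masked to -1e4 and cannot win the argmax.\n        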
assert batch_size == 1, \"Batch size > 1 is not supported yet\"\n        batch_idx = 0\n\n        results = {}\n\n        # Call the SpaTracker for each view\n        traj_e_per_view = {}\n        feat_init_per_view = {}\n        vis_e_per_view = {}\n        train_data_per_view = {}\n        for view_idx in range(num_views):\n            track_mask = query_points_best_visibility_view[batch_idx] == view_idx\n            if track_mask.sum() == 0:\n                continue\n\n            view_query_points = torch.concat([\n                query_points_t[batch_idx, :, :][track_mask],\n                query_points_xy_pixelspace_per_view[batch_idx, view_idx, :, :][track_mask],\n                query_points_z_cameraspace_per_view[batch_idx, view_idx, :, :][track_mask],\n            ], dim=-1)\n\n            view_rgbds = torch.concat([rgbs[batch_idx, view_idx], depths[batch_idx, view_idx]], dim=1)\n            view_intrs = intrs[batch_idx, view_idx]\n            view_extrs = extrs[batch_idx, view_idx]\n\n            output_tuple = self.spatracker(\n                rgbds=view_rgbds[None],\n                queries=view_query_points[None],\n                intrs=view_intrs[None],\n                iters=iters,\n                feat_init=None,\n                is_train=is_train,\n            )\n            if is_train:\n                view_traj_e, view_feat_init, view_vis_e, view_train_data = output_tuple\n            else:\n                view_traj_e, view_feat_init, view_vis_e = output_tuple\n\n            # Project points to the world space\n            intrs_inv = torch.inverse(view_intrs.float())\n            view_extrs_square = torch.eye(4).to(view_extrs.device)[None].repeat(num_frames, 1, 1)\n            view_extrs_square[:, :3, :] = view_extrs\n            extrs_inv = torch.inverse(view_extrs_square.float())\n            view_traj_e = pixel_xy_and_camera_z_to_world_space(\n                pixel_xy=view_traj_e[0, ..., :-1].float(),\n                camera_z=view_traj_e[0, ..., -1:].float(),\n                intrs_inv=intrs_inv,\n                extrs_inv=extrs_inv,\n            )[None]\n            if is_train:\n                num_windows = len(view_train_data[1])\n                num_iterations = len(view_train_data[1][0])\n                coord_predictions = view_train_data[1]\n                window_start_t = 0\n                while window_start_t < num_frames - self.spatracker.S // 2:\n                    window_idx = window_start_t // (self.spatracker.S // 2)\n                    for iteration_idx in range(num_iterations):\n                        coord_predictions[window_idx][iteration_idx] = pixel_xy_and_camera_z_to_world_space(\n                            pixel_xy=coord_predictions[window_idx][iteration_idx][0, ..., :-1].float(),\n                            camera_z=coord_predictions[window_idx][iteration_idx][0, ..., -1:].float(),\n                            intrs_inv=intrs_inv[window_start_t:window_start_t + self.spatracker.S],\n                            extrs_inv=extrs_inv[window_start_t:window_start_t + self.spatracker.S],\n                        )[None]\n                    window_start_t = window_start_t + (self.spatracker.S // 2)\n                assert window_idx == num_windows - 1, \"The last window should be the last one\"\n                assert view_train_data[1] == coord_predictions, \"The view_train_data[1] should be updated in-place\"\n\n            # Set the trajectory to (0,0,0) for the timesteps before the query timestep\n            for point_idx, t in 
enumerate(query_points_t[batch_idx, :, :].squeeze(-1)[track_mask]):\n                view_traj_e[0, :t, point_idx, :] = 0.0\n\n            traj_e_per_view[view_idx] = view_traj_e\n            feat_init_per_view[view_idx] = view_feat_init\n            vis_e_per_view[view_idx] = view_vis_e\n            if is_train:\n                train_data_per_view[view_idx] = view_train_data\n\n        # Merging the results from all views\n        views_to_keep = list(traj_e_per_view.keys())\n        traj_e = torch.cat([traj_e_per_view[view_idx] for view_idx in views_to_keep], dim=2)\n        vis_e = torch.cat([vis_e_per_view[view_idx] for view_idx in views_to_keep], dim=2)\n        feat_init = torch.cat([feat_init_per_view[view_idx] for view_idx in views_to_keep], dim=2)\n\n        # Sort the traj_e and vis_e based on the original indices, since concatenating the results from all views\n        # will first put the results from the first view, then the results from the second view, and so on.\n        # But we want to keep the trajectories order to match the original query points order.\n        sort_inds = []\n        for view_idx in views_to_keep:\n            track_mask = query_points_best_visibility_view[batch_idx] == view_idx\n            if track_mask.sum() == 0:\n                continue\n            global_indices = torch.nonzero(track_mask).squeeze(-1)\n            sort_inds += [global_indices]\n        sort_inds = torch.cat(sort_inds, dim=0)\n        inv_sort_inds = torch.argsort(sort_inds, dim=0)\n\n        # Use the inv_sort_inds to sort the traj_e and vis_e\n        traj_e = traj_e[:, :, inv_sort_inds]\n        vis_e = vis_e[:, :, inv_sort_inds]\n        feat_init = None  # Not supported yet, correct sorting needs to be implemented\n\n        # Delete the intermediate variables to avoid confusion with the later variables\n        del sort_inds, inv_sort_inds\n\n        # # Sanity check that the sorted traj_e have about similar values for the query points\n        # # The forward pass is expected to tweak the values a bit, but they would probably stay close\n        # pred_xyz_for_query = traj_e[0][query_points_t[batch_idx].squeeze(-1), torch.arange(num_points)]\n        # pred_xyz_for_query = pred_xyz_for_query.type(query_points_xyz_worldspace.dtype)\n        # assert torch.allclose(pred_xyz_for_query, query_points_xyz_worldspace[batch_idx], atol=1)\n        # # But, an untrained model might not be able to predict the query points exactly\n\n        # # Also check that the query points are visible\n        # pred_visibility_for_query = vis_e[0][query_points_t[batch_idx].squeeze(-1), torch.arange(num_points)]\n        # assert torch.all(pred_visibility_for_query > 0.5)\n        # # But, for some points the model might predict the query points to be occluded\n\n        if not is_train:\n            if torch.isnan(traj_e).any():\n                warnings.warn(\n                    f\"Found {torch.isnan(traj_e).sum()}/{traj_e.numel()} NaN values in traj_e. Setting them to 0.\")\n                traj_e[traj_e.isnan()] = 0\n            if torch.isnan(vis_e).any():\n                warnings.warn(\n                    f\"Found {torch.isnan(vis_e).sum()}/{vis_e.numel()} NaN values in visibilities. 
Setting them to 1.\")\n                vis_e[vis_e.isnan()] = 1\n\n        # Save to results\n        results[\"traj_e\"] = traj_e\n        results[\"feat_init\"] = feat_init\n        results[\"vis_e\"] = vis_e\n\n        # If training mode, we need to merge the results from all views.\n        # Those merged results are used in the backward pass to compute the loss.\n        # train_data is a tuple of (vis_pred, coord_pred, attn_pred, p_idx_end_list, sort_inds, Rigid_ln_total)\n        if is_train:\n            # SpaTracker is using sliding windows, and for each window, it is using multiple iterations.\n            num_windows = len(train_data_per_view[views_to_keep[0]][0])\n            num_iterations = len(train_data_per_view[views_to_keep[0]][1][0])\n\n            sort_inds = []\n            vis_predictions = [[] for _ in range(num_windows)]\n            coord_predictions = [[[] for _ in range(num_iterations)] for _ in range(num_windows)]\n            for window_idx in range(num_windows):\n                for view_idx in views_to_keep:\n                    # What points will be tracked in this view\n                    track_mask = query_points_best_visibility_view[batch_idx] == view_idx\n                    if track_mask.sum() == 0:\n                        # This view does not track any points at all\n                        continue\n\n                    # Get the indices of points that appeared in this window (from the points tracked in this view)\n                    try:\n                        start_idx = 0 if window_idx == 0 else train_data_per_view[view_idx][3][window_idx - 1].item()\n                        end_idx = train_data_per_view[view_idx][3][window_idx].item()\n                        if end_idx == 0:\n                            # No points from this view were tracked in this window\n                            continue\n                    except Exception as e:\n                        logging.error(f\"Error: {e}\")\n                        logging.error(f\"view_idx: {view_idx}, window_idx: {window_idx}\")\n                        logging.error(f\"train_data_per_view[view_idx][3]: {train_data_per_view[view_idx][3]}\")\n                        raise e\n\n                    # Convert the view-specific sorted indices to \"global\" indices\n                    # that say which trajectory/query the point originally belonged to\n                    indices_in_view = train_data_per_view[view_idx][4][start_idx:end_idx]\n                    global_indices = torch.nonzero(track_mask).squeeze(-1)[indices_in_view]\n\n                    # Sorted indices are saying how the original trajectories were reordered/sorted\n                    # in the return results. This is because in the forward passes, we want to group\n                    # the points that will appear in the same window together. The points that haven't\n                    # appeared in a window will not be used in the forward pass for that window.\n                    # For each new window, points can only be added, not removed, and they will be added\n                    # if they have just appeared in that window. Since we are merging the results from\n                    # all views, we will first take all the points that appeared in the first window from\n                    # all views, then all the points that appeared in the second window from all views,\n                    # and so on. 
This is why we do a for loop over the windows first, then over the views\n                    # and merge the indices in the next line:\n                    sort_inds.append(global_indices)\n                    # The indices are now sorted in the order that they will appear in the merged results.\n                    # This can be illustrated as follows:\n                    #   Final sorted indices for the merged results: [\n                    #     view 1 new points from window 1\n                    #     view 2 new points from window 1\n                    #     view ... new points from window 1\n                    #     view 1 new points from window 2\n                    #     view 2 new points from window 2\n                    #     view ... new points from window 2\n                    #     ...\n                    #   ]\n\n                    # This also means that the results from each view need to be carefully merged to match\n                    # the expected ordering/sorting. To illustrate this, the merged results for the vis_predictions\n                    # and coord_predictions will look like this:\n                    #   Window 1 results: [\n                    #     view 1 new points from window 1\n                    #     view 2 new points from window 1\n                    #     view ... new points from window 1\n                    #   ]\n                    #   Window 2 results: [\n                    #     view 1 new points from window 1\n                    #     view 2 new points from window 1\n                    #     view ... new points from window 1\n                    #     view 1 new points from window 2\n                    #     view 2 new points from window 2\n                    #     view ... new points from window 2\n                    #   ]\n                    #   Window ...\n\n                    # Below we will merge the results from all views for each window as illustrated above\n                    for window_idx_inner in range(num_windows):\n                        vis_predictions[window_idx_inner].append(\n                            train_data_per_view[view_idx][0][window_idx_inner][:, :, start_idx:end_idx]\n                        )\n                        for iteration_idx in range(num_iterations):\n                            coord_predictions[window_idx_inner][iteration_idx].append(\n                                train_data_per_view[view_idx][1][window_idx_inner][iteration_idx][\n                                :, :, start_idx:end_idx, :]\n                            )\n\n            # Concatenate the merged results correctly\n            sort_inds = torch.cat(sort_inds, dim=0)\n            vis_predictions = [\n                torch.cat(vis_predictions[window_idx], dim=2)\n                for window_idx in range(num_windows)\n            ]\n            coord_predictions = [\n                [\n                    torch.cat(coord_predictions[window_idx][iteration_idx], dim=2)\n                    for iteration_idx in range(num_iterations)\n                ]\n                for window_idx in range(num_windows)\n            ]\n\n            # Compute the p_idx_end_list for each window, it is the sum of the number of points\n            # that appeared in each view for that window as this is the way we have merged the results.\n            p_idx_end_list = [\n                torch.stack([\n                    train_data_per_view[view_idx][3][window_idx]\n                    for view_idx in views_to_keep\n                ], 
dim=1).sum(dim=1)\n                for window_idx in range(num_windows)\n            ]\n\n            # Compute the attn_predictions and Rigid_ln_total\n            attn_predictions = None  # Not supported yet\n            Rigid_ln_total = None  # Not supported yet\n\n            # Sanity check that using the computed sort_inds gives the same results as the merged traj_e and vis_e\n            traj_e_reproduced = traj_e.new_zeros(traj_e.shape)\n            vis_e_reproduced = vis_e.new_zeros(vis_e.shape)\n            window_start_t = 0\n            while window_start_t < num_frames - self.spatracker.S // 2:\n                window_idx = window_start_t // (self.spatracker.S // 2)\n                p_idx_end = p_idx_end_list[window_idx]\n                if p_idx_end == 0:\n                    continue\n                wind_coords = coord_predictions[window_idx][-1]\n                wind_vis = vis_predictions[window_idx]\n                traj_e_reproduced[:, window_start_t:window_start_t + self.spatracker.S, :p_idx_end] = wind_coords\n                vis_e_reproduced[:, window_start_t:window_start_t + self.spatracker.S, :p_idx_end] = wind_vis\n                window_start_t = window_start_t + (self.spatracker.S // 2)\n            inv_sort_inds = torch.argsort(sort_inds, dim=0)\n            traj_e_reproduced = traj_e_reproduced[:, :, inv_sort_inds]\n            vis_e_reproduced = torch.sigmoid(vis_e_reproduced[:, :, inv_sort_inds])\n\n            # Set the trajectory to (0,0,0) for the timesteps before the query timestep\n            for point_idx, t in enumerate(query_points_t[batch_idx, :, :].squeeze(-1)):\n                traj_e_reproduced[0, :t, point_idx, :] = 0.0\n\n            assert torch.allclose(traj_e, traj_e_reproduced, atol=1e-3)\n            assert torch.allclose(vis_e, vis_e_reproduced, atol=1e-3)\n\n            # Save to results\n            results[\"train_data\"] = {\n                \"vis_predictions\": vis_predictions,\n                \"coord_predictions\": coord_predictions,\n                \"attn_predictions\": attn_predictions,\n                \"p_idx_end_list\": p_idx_end_list,\n                \"sort_inds\": sort_inds,\n                \"Rigid_ln_total\": Rigid_ln_total,\n            }\n\n        return results\n"
  },
  {
    "path": "mvtracker/models/core/spatracker/spatracker_multiview.py",
    "content": "import logging\nimport os\nimport warnings\n\nimport cv2\nimport numpy as np\nimport torch\nfrom einops import rearrange\nfrom matplotlib import pyplot as plt\nfrom torch import nn as nn\n\nfrom mvtracker.models.core.embeddings import Embedder_Fourier, get_3d_sincos_pos_embed_from_grid, \\\n    get_1d_sincos_pos_embed_from_grid, get_3d_embedding\nfrom mvtracker.models.core.model_utils import sample_features5d, smart_cat\nfrom mvtracker.models.core.spatracker.blocks import BasicEncoder, EUpdateFormer, CorrBlock\nfrom mvtracker.models.core.spatracker.softsplat import softsplat\nfrom mvtracker.models.core.spatracker.spatracker_monocular import sample_pos_embed\nfrom mvtracker.utils.basic import to_homogeneous, from_homogeneous, time_now\n\n\nclass MultiViewSpaTracker(nn.Module):\n    \"\"\"\n    Multi-view Spatial Tracker: A 3D Multi-View Tracker with\n    Transformer-based Iterative Flow Updates. This version computes\n    local correlation in a global triplane space that is aligned with\n    the world coordinate planes. However, this leaves most of the triplane\n    space empty since it is difficult to create one plane that covers all the\n    relevant areas of interest.\n    \"\"\"\n\n    def __init__(\n            self,\n            sliding_window_len=8,\n            stride=8,\n            add_space_attn=True,\n            use_3d_pos_embed=True,\n            remove_zeromlpflow=True,\n            concat_triplane_features=True,\n            num_heads=8,\n            hidden_size=384,\n            space_depth=12,\n            time_depth=12,\n            fmaps_dim=128,\n            triplane_xres=128,\n            triplane_yres=128,\n            triplane_zres=128,\n    ):\n        super(MultiViewSpaTracker, self).__init__()\n\n        self.S = sliding_window_len\n        self.stride = stride\n        self.hidden_dim = 256\n        self.latent_dim = fmaps_dim\n        self.flow_embed_dim = 64\n        self.b_latent_dim = self.latent_dim // 3\n        self.corr_levels = 4\n        self.corr_radius = 3\n        self.add_space_attn = add_space_attn\n        self.use_3d_pos_embed = use_3d_pos_embed\n        self.remove_zeromlpflow = remove_zeromlpflow\n        self.concat_triplane_features = concat_triplane_features\n        self.updateformer_input_dim = (\n\n            # The positional encoding of the 3D flow from t=i to t=0\n                + (self.flow_embed_dim + 1) * (3 if self.remove_zeromlpflow else 2)\n\n                # The correlation features (LRR) for the three planes (xy, yz, xz), concatenated\n                + 196 * (3 if self.concat_triplane_features else 1)\n\n                # The features of the tracked points, one for each of the three planes\n                + self.latent_dim * (3 if self.concat_triplane_features else 1)\n\n                # The visibility mask\n                + 1\n\n                # The whether-the-point-is-tracked mask\n                + 1\n\n        )\n        self.triplane_xres = triplane_xres\n        self.triplane_yres = triplane_yres\n        self.triplane_zres = triplane_zres\n\n        # Feature encoder\n        self.fnet = BasicEncoder(\n            input_dim=3,\n            output_dim=self.latent_dim,\n            norm_fn=\"instance\",\n            dropout=0,\n            stride=stride,\n            Embed3D=False,\n        )\n\n        # Convolutional heads for the tri-plane features\n        self.headxy = nn.Sequential(\n            nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1),\n            nn.ReLU(inplace=True),\n    
        nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1),\n        )\n        self.headyz = nn.Sequential(\n            nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1),\n        )\n        self.headxz = nn.Sequential(\n            nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1),\n        )\n\n        # Transformer for the iterative flow updates\n        self.support_features = torch.zeros(100, 384).to(\"cuda\") + 0.1\n        self.updateformer = EUpdateFormer(\n            space_depth=space_depth,\n            time_depth=time_depth,\n            input_dim=self.updateformer_input_dim,\n            hidden_size=hidden_size,\n            num_heads=num_heads,\n            output_dim=3 + self.latent_dim * 3,\n            mlp_ratio=4.0,\n            add_space_attn=add_space_attn,\n            flash=True,\n        )\n\n        # Updater of the features of the tracked points\n        self.norm_xy = nn.GroupNorm(1, self.latent_dim)\n        self.norm_yz = nn.GroupNorm(1, self.latent_dim)\n        self.norm_xz = nn.GroupNorm(1, self.latent_dim)\n        self.ffeatxy_updater = nn.Sequential(\n            nn.Linear(self.latent_dim, self.latent_dim),\n            nn.GELU(),\n        )\n        self.ffeatyz_updater = nn.Sequential(\n            nn.Linear(self.latent_dim, self.latent_dim),\n            nn.GELU(),\n        )\n        self.ffeatxz_updater = nn.Sequential(\n            nn.Linear(self.latent_dim, self.latent_dim),\n            nn.GELU(),\n        )\n\n        # Embedders\n        self.embed_traj = Embedder_Fourier(input_dim=5, max_freq_log2=5.0, N_freqs=3, include_input=True)\n        self.embed3d = Embedder_Fourier(input_dim=3, max_freq_log2=10.0, N_freqs=10, include_input=True)\n        self.embedConv = nn.Conv2d(self.latent_dim + 63, self.latent_dim, 3, padding=1)\n\n        # Predictor of the visibility of the tracked points\n        self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim * (3 if self.concat_triplane_features else 1), 1))\n        self.zeroMLPflow = nn.Linear(195, 130)\n\n    def sample_trifeat(self, t, coords, featMapxy, featMapyz, featMapxz):\n        \"\"\"\n        Sample the features from the 5D triplane feature map 3*(B S C H W)\n        Args:\n            t: the time index\n            coords: the coordinates of the points B S N 3\n            featMapxy: the feature map B S C Hx Wy\n            featMapyz: the feature map B S C Hy Wz\n            featMapxz: the feature map B S C Hx Wz\n        \"\"\"\n        # get xy_t yz_t xz_t\n        queried_t = t.reshape(1, 1, -1, 1)\n        xy_t = torch.cat(\n            [queried_t, coords[..., [0, 1]]],\n            dim=-1\n        )\n        yz_t = torch.cat(\n            [queried_t, coords[..., [1, 2]]],\n            dim=-1\n        )\n        xz_t = torch.cat(\n            [queried_t, coords[..., [0, 2]]],\n            dim=-1\n        )\n        featxy_init = sample_features5d(featMapxy, xy_t)\n\n        featyz_init = sample_features5d(featMapyz, yz_t)\n        featxz_init = sample_features5d(featMapxz, xz_t)\n\n        featxy_init = featxy_init.repeat(1, self.S, 1, 1)\n        featyz_init = featyz_init.repeat(1, self.S, 1, 1)\n        featxz_init = featxz_init.repeat(1, self.S, 1, 1)\n\n        return featxy_init, featyz_init, featxz_init\n\n    def forward_iteration(\n      
      self,\n            fmapXY,\n            fmapYZ,\n            fmapXZ,\n            coords_init,\n            vis_init,\n            track_mask,\n            iters=4,\n            feat_init=None,\n    ):\n        N = coords_init.shape[2]\n        B, S, fmap_dim, triplane_H, triplane_W = fmapXY.shape\n        triplane_D = fmapXZ.shape[-2]\n        device = fmapXY.device\n\n        if coords_init.shape[1] < S:\n            coords = torch.cat([coords_init, coords_init[:, -1].repeat(1, S - coords_init.shape[1], 1, 1)], dim=1)\n            vis_init = torch.cat([vis_init, vis_init[:, -1].repeat(1, S - coords_init.shape[1], 1, 1)], dim=1)\n        else:\n            coords = coords_init.clone()\n\n        assert B == 1\n        assert fmapXY.shape == (B, S, fmap_dim, triplane_H, triplane_W)\n        assert fmapYZ.shape == (B, S, fmap_dim, triplane_D, triplane_H)\n        assert fmapXZ.shape == (B, S, fmap_dim, triplane_D, triplane_W)\n        assert coords.shape == (B, S, N, 3)\n        assert vis_init.shape == (B, S, N, 1)\n        assert track_mask.shape == (B, S, N, 1)\n        assert feat_init is None or feat_init.shape == (B, S, N, self.latent_dim, 3)\n\n        fcorr_fnXY = CorrBlock(fmapXY, num_levels=self.corr_levels, radius=self.corr_radius)\n        fcorr_fnYZ = CorrBlock(fmapYZ, num_levels=self.corr_levels, radius=self.corr_radius)\n        fcorr_fnXZ = CorrBlock(fmapXZ, num_levels=self.corr_levels, radius=self.corr_radius)\n\n        ffeats = torch.split(feat_init.clone(), dim=-1, split_size_or_sections=1)\n        ffeats = [f.squeeze(-1) for f in ffeats]\n\n        grid_size = coords.new_tensor([triplane_H, triplane_W, triplane_D])\n        # @Single-view-difference:\n        #     Instead of computing 2D positional embeddings in the XY plane of the single-view triplane\n        #     (which is aligned with the monocular view used in the single-view SpatialTracker), I will\n        #     compute 3D positional embeddings in the 3D grid of the triplane. This could allow the model\n        #     to more easily learn the 3D spatial relationships between the points in the triplane.\n        # pos_embed = sample_pos_embed(\n        #     grid_size=(H8, W8),\n        #     embed_dim=456,\n        #     coords=coords[..., :2],\n        # )\n        embed_dim = self.updateformer_input_dim\n        if self.use_3d_pos_embed:\n            # Ours\n            if embed_dim % 3 != 0:\n                # Make sure that the embed_dim is divisible by 3\n                embed_dim += 3 - (embed_dim % 3)\n            pos_embed = get_3d_sincos_pos_embed_from_grid(\n                embed_dim=embed_dim,\n                # Normalize the coordinates so that the grid ranges over [-128,128]\n                grid=((coords[:, :1, ...] 
/ grid_size) * 2 - 1) * 128,\n            ).float()[:, 0, ...].permute(0, 2, 1)\n        else:\n            # Original\n            if embed_dim % 4 != 0:\n                # Make sure that the embed_dim is divisible by 4\n                embed_dim += 4 - (embed_dim % 4)\n            pos_embed = sample_pos_embed(\n                grid_size=(triplane_H, triplane_W),\n                embed_dim=embed_dim,\n                coords=coords[..., :2],\n            )\n        if embed_dim > self.updateformer_input_dim:\n            # If the embed_dim was increased for divisibility, then remove the extra dimensions\n            pos_embed = pos_embed[:, :self.updateformer_input_dim, :]\n        pos_embed = rearrange(pos_embed, \"b e n -> (b n) e\").unsqueeze(1)\n\n        times_ = torch.linspace(0, S - 1, S).reshape(1, S, 1)\n        embed_dim = self.updateformer_input_dim\n        if embed_dim % 2 != 0:\n            # Make sure that the embed_dim is divisible by 2\n            embed_dim += 2 - (embed_dim % 2)\n        times_embed = (\n            torch.from_numpy(get_1d_sincos_pos_embed_from_grid(embed_dim, times_[0]))[None]\n            .repeat(B, 1, 1)\n            .float()\n            .to(device)\n        )\n        if embed_dim > self.updateformer_input_dim:\n            # If the embed_dim was increased to be divisible by 2, then remove the extra dimensions\n            times_embed = times_embed[:, :, :self.updateformer_input_dim]\n\n        coord_predictions = []\n        support_feat = self.support_features\n\n        for _ in range(iters):\n            coords = coords.detach()\n            fcorrsXY = fcorr_fnXY.corr_sample(ffeats[0], coords[..., :2])\n            fcorrsYZ = fcorr_fnYZ.corr_sample(ffeats[1], coords[..., [1, 2]])\n            fcorrsXZ = fcorr_fnXZ.corr_sample(ffeats[2], coords[..., [0, 2]])\n            # @Single-view-difference:\n            #     Instead of summing the correlations for different planes, I will concatenate them so that the model\n            #     can learn to differentiate between the correlations of different planes. Summing the correlations up\n            #     can make it very difficult for the model to differentiate between the correlations of different\n            #     planes unless, e.g., it learns to create the feature maps in a way that they are orthogonal\n            #     to each other. 
But rather than relying on the model to learn this, I believe that it is better\n            #     to provide the model with the information that the correlations are from different planes explicitly.\n            #     Note that this change will increase the dimension of the correlation features that are given to the\n            #     transformer: 196 * 3 = 588, instead of 196.\n            # fcorrs = fcorrsXY + fcorrsYZ + fcorrsXZ\n            if self.concat_triplane_features:\n                # Ours\n                fcorrs = torch.cat([fcorrsXY, fcorrsYZ, fcorrsXZ], dim=-1)\n            else:\n                # Original\n                fcorrs = fcorrsXY + fcorrsYZ + fcorrsXZ\n            LRR = fcorrs.shape[3]\n            fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, LRR)\n\n            flows_ = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 3)\n            flows_cat = get_3d_embedding(flows_, self.flow_embed_dim, cat_coords=True)\n            # @Single-view-difference:\n            #     I have removed the zeroMLPflow linear layer which was added to project the flow embedding\n            #     from a 195-dimensional vector to a 130-dimensional to have a cleaner architecture.\n            #     I believe that the authors have added this layer just to match the 130 that the original\n            #     CoTracker implementation had used, but this can introduce confusion in the architecture's design.\n            # flows_cat = self.zeroMLPflow(flows_cat)\n            if self.remove_zeromlpflow:\n                # Ours\n                pass\n            else:\n                # Original\n                flows_cat = self.zeroMLPflow(flows_cat)\n\n            ffeats_xy = ffeats[0].permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)\n            ffeats_yz = ffeats[1].permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)\n            ffeats_xz = ffeats[2].permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)\n            # @Single-view-difference:\n            #     Instead of summing the features for different planes, I will concatenate them so that the model\n            #     can learn to differentiate between the features of different planes. Summing the features up\n            #     can make it very difficult for the model to differentiate between the features of different\n            #     planes. I believe that it is better to provide the model with the information that the features\n            #     are from different planes explicitly. 
Note that this change will increase the dimension of the\n            #     feature embeddings that are given to the transformer: 128 * 3 = 384, instead of 128.\n            # ffeats_ = ffeats_xy + ffeats_yz + ffeats_xz\n            if self.concat_triplane_features:\n                # Ours\n                ffeats_ = torch.cat([ffeats_xy, ffeats_yz, ffeats_xz], dim=-1)\n            else:\n                # Original\n                ffeats_ = ffeats_xy + ffeats_yz + ffeats_xz\n\n            if track_mask.shape[1] < vis_init.shape[1]:\n                track_mask = torch.cat([\n                    track_mask,\n                    torch.zeros_like(track_mask[:, 0]).repeat(1, vis_init.shape[1] - track_mask.shape[1], 1, 1),\n                ], dim=1)\n            track_mask_and_vis = torch.cat([track_mask, vis_init], dim=2).permute(0, 2, 1, 3).reshape(B * N, S, 2)\n\n            transformer_input = torch.cat([flows_cat, fcorrs_, ffeats_, track_mask_and_vis], dim=2)\n            assert transformer_input.shape[-1] == pos_embed.shape[-1]\n\n            x = transformer_input + pos_embed + times_embed\n            x = rearrange(x, \"(b n) t d -> b n t d\", b=B)\n\n            delta, delta_se3F = self.updateformer(x, support_feat)\n            support_feat = support_feat + delta_se3F[0] / 100\n            delta = rearrange(delta, \" b n t d -> (b n) t d\")\n            d_coord = delta[:, :, :3]\n            d_feats_xy = delta[:, :, 3:self.latent_dim + 3]\n            d_feats_yz = delta[:, :, self.latent_dim + 3:self.latent_dim * 2 + 3]\n            d_feats_xz = delta[:, :, self.latent_dim * 2 + 3:]\n            d_feats_xy_norm = self.norm_xy(d_feats_xy.view(-1, self.latent_dim))\n            d_feats_yz_norm = self.norm_yz(d_feats_yz.view(-1, self.latent_dim))\n            d_feats_xz_norm = self.norm_xz(d_feats_xz.view(-1, self.latent_dim))\n            ffeats_xy = ffeats_xy.reshape(-1, self.latent_dim) + self.ffeatxy_updater(d_feats_xy_norm)\n            ffeats_yz = ffeats_yz.reshape(-1, self.latent_dim) + self.ffeatyz_updater(d_feats_yz_norm)\n            ffeats_xz = ffeats_xz.reshape(-1, self.latent_dim) + self.ffeatxz_updater(d_feats_xz_norm)\n            ffeats[0] = ffeats_xy.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3)\n            ffeats[1] = ffeats_yz.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3)\n            ffeats[2] = ffeats_xz.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3)\n            coords = coords + d_coord.reshape(B, N, S, 3).permute(0, 2, 1, 3)\n            if torch.isnan(coords).any():\n                logging.error(\"Got NaN values in coords, perhaps the training exploded\")\n                import ipdb;\n                ipdb.set_trace()\n\n            coord_predictions.append(coords.clone())\n\n        # @Single-view-difference:\n        #     Instead of summing the features for different planes,\n        #     I will concatenate before inputting them to the shallow visibility predictor.\n        # ffeats_f = ffeats[0] + ffeats[1] + ffeats[2]\n        if self.concat_triplane_features:\n            ffeats_f = torch.cat(ffeats, dim=-1)\n            vis_e = self.vis_predictor(ffeats_f.reshape(B * S * N, self.latent_dim * 3)).reshape(B, S, N)\n        else:\n            ffeats_f = ffeats[0] + ffeats[1] + ffeats[2]\n            vis_e = self.vis_predictor(ffeats_f.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)\n\n        self.support_features = support_feat.detach()\n\n        return coord_predictions, vis_e, feat_init\n\n    def forward(\n            
self,\n            rgbs,\n            depths,\n            query_points,\n            intrs,\n            extrs,\n            iters=4,\n            feat_init=None,\n            is_train=False,\n            save_debug_logs=False,\n            debug_logs_path=\"\",\n            **kwargs,\n    ):\n        batch_size, num_views, num_frames, _, height, width = rgbs.shape\n        _, num_points, _ = query_points.shape\n\n        assert rgbs.shape == (batch_size, num_views, num_frames, 3, height, width)\n        assert depths.shape == (batch_size, num_views, num_frames, 1, height, width)\n        assert query_points.shape == (batch_size, num_points, 4)\n        assert intrs.shape == (batch_size, num_views, num_frames, 3, 3)\n        assert extrs.shape == (batch_size, num_views, num_frames, 3, 4)\n\n        if feat_init is not None:\n            raise NotImplementedError(\"feat_init is not supported yet\")\n\n        if save_debug_logs:\n            os.makedirs(debug_logs_path, exist_ok=True)\n            if kwargs:\n                warnings.warn(f\"Received unexpected kwargs: {kwargs.keys()}\")\n\n        self.support_features = torch.zeros(100, 384).to(rgbs.device) + 0.1\n        self.is_train = is_train\n\n        # Unpack the query points\n        query_points_t = query_points[:, :, :1].long()\n        query_points_xyz_worldspace = query_points[:, :, 1:]\n\n        # Invert intrinsics and extrinsics\n        intrs_inv = torch.inverse(intrs.float())\n        extrs_square = torch.eye(4).to(extrs.device)[None].repeat(batch_size, num_views, num_frames, 1, 1)\n        extrs_square[:, :, :, :3, :] = extrs\n        extrs_inv = torch.inverse(extrs_square.float())\n\n        # Interpolate the rgbs and depthmaps to the stride of the SpaTracker\n        strided_height = height // self.stride\n        strided_width = width // self.stride\n\n        strided_depths = nn.functional.interpolate(\n            input=depths.reshape(-1, 1, height, width),\n            scale_factor=1.0 / self.stride,\n            mode=\"nearest\",\n        ).reshape(batch_size, num_views, num_frames, 1, strided_height, strided_width)\n\n        strided_rgbs = nn.functional.interpolate(\n            input=rgbs.reshape(-1, 3, height, width),\n            scale_factor=1.0 / self.stride,\n            mode=\"bilinear\",\n        ).reshape(batch_size, num_views, num_frames, 3, strided_height, strided_width)\n\n        # Un-project strided depthmaps back to world coordinates\n        pixel_xy = torch.stack(torch.meshgrid(\n            (torch.arange(0, height / self.stride) + 0.5) * self.stride - 0.5,\n            (torch.arange(0, width / self.stride) + 0.5) * self.stride - 0.5,\n            indexing=\"ij\",\n        )[::-1], dim=-1)\n        pixel_xy = pixel_xy.to(device=rgbs.device, dtype=rgbs.dtype)\n        pixel_xy_homo = to_homogeneous(pixel_xy)\n        depthmap_camera_xyz = torch.einsum('BVTij,HWj->BVTHWi', intrs_inv, pixel_xy_homo)\n        depthmap_camera_xyz = depthmap_camera_xyz * strided_depths[..., 0, :, :, None]\n        depthmap_camera_xyz_homo = to_homogeneous(depthmap_camera_xyz)\n        depthmap_world_xyz_homo = torch.einsum('BVTij,BVTHWj->BVTHWi', extrs_inv, depthmap_camera_xyz_homo)\n        depthmap_world_xyz = from_homogeneous(depthmap_world_xyz_homo)\n\n        if save_debug_logs:\n            t = 0\n            n_skip = 4\n            xyz = depthmap_world_xyz[0, :, t, ::n_skip, ::n_skip, :].reshape(-1, 3).cpu().numpy()\n            c = strided_rgbs.permute(0, 1, 2, 4, 5, 3)[0, :, t, ::n_skip, ::n_skip].reshape(-1, 
3).cpu().numpy() / 255\n            filename = time_now() + \"__rgbd_with_queries\"\n            qp = query_points_xyz_worldspace[0].cpu().numpy()\n            qc = np.array([[1, 0, 0]] * query_points_xyz_worldspace.shape[1])\n            self._plot_pointcloud(debug_logs_path, filename, xyz, c, qp, qc, show=False)\n\n        # Put the three planes along the YX, ZX, and ZY axes\n        # TODO: Hardcode the xyz ranges for the triplanes,\n        #       as taking the whole range would make the\n        #       central object of interest very tiny and\n        #       the grid would be wasted in representing\n        #       vast background.\n        x_range = [-14, 14]\n        y_range = [-14, 14]\n        z_range = [-1, 10]\n        query_points_outside_of_triplane_range = (\n                (query_points_xyz_worldspace[..., 0].flatten() < x_range[0]) |\n                (query_points_xyz_worldspace[..., 0].flatten() > x_range[1]) |\n                (query_points_xyz_worldspace[..., 1].flatten() < y_range[0]) |\n                (query_points_xyz_worldspace[..., 1].flatten() > y_range[1]) |\n                (query_points_xyz_worldspace[..., 2].flatten() < z_range[0]) |\n                (query_points_xyz_worldspace[..., 2].flatten() > z_range[1])\n        )\n        if query_points_outside_of_triplane_range.any():\n            warnings.warn(f\"Some query points are outside of the triplane range. \"\n                          f\"x_range={x_range}, y_range={y_range}, z_range={z_range}. \"\n                          f\"query_points_xyz_worldspace={query_points_xyz_worldspace[:, query_points_outside_of_triplane_range]}\")\n\n        kwargs = {\"device\": depthmap_world_xyz.device, \"dtype\": depthmap_world_xyz.dtype}\n        triplane_xyz_min = torch.tensor([x_range[0], y_range[0], z_range[0]], **kwargs)\n        triplane_xyz_max = torch.tensor([x_range[1], y_range[1], z_range[1]], **kwargs)\n        triplane_grid_dims = torch.tensor([self.triplane_xres, self.triplane_yres, self.triplane_zres], **kwargs)\n\n        if save_debug_logs:\n            t = 0\n            n_skip = 1\n            xyz = depthmap_world_xyz[0, :, t, ::n_skip, ::n_skip, :].reshape(-1, 3).cpu().numpy()\n            c = strided_rgbs.permute(0, 1, 2, 4, 5, 3)[0, :, t, ::n_skip, ::n_skip, :].reshape(-1,\n                                                                                               3).cpu().numpy() / 255\n            mask = (\n                    (xyz[:, 0] >= x_range[0]) & (xyz[:, 0] <= x_range[1]) &\n                    (xyz[:, 1] >= y_range[0]) & (xyz[:, 1] <= y_range[1]) &\n                    (xyz[:, 2] >= z_range[0]) & (xyz[:, 2] <= z_range[1])\n            )\n            xyz_in_range = xyz[mask]\n            c_in_range = c[mask]\n\n            qp = query_points_xyz_worldspace[0].cpu().numpy()\n            qc = np.array([[1, 0, 0]] * query_points_xyz_worldspace.shape[1])\n            mask = (\n                    (qp[:, 0] >= x_range[0]) & (qp[:, 0] <= x_range[1]) &\n                    (qp[:, 1] >= y_range[0]) & (qp[:, 1] <= y_range[1]) &\n                    (qp[:, 2] >= z_range[0]) & (qp[:, 2] <= z_range[1])\n            )\n            qp_in_range = qp[mask]\n            qc_in_range = qc[mask]\n\n            filename = time_now() + \"__rgbd_with_queries_within_triplane_range\"\n            self._plot_pointcloud(debug_logs_path, filename, xyz_in_range, c_in_range, qp_in_range, qc_in_range,\n                                  show=False)\n\n        # Pre-compute the per-view feature maps\n        
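# Each view and frame is encoded independently by the 2D feature encoder (fnet): RGB values are first\n        # rescaled from [0, 255] to [-1, 1], and the encoder outputs feature maps at 1/stride resolution,\n        # i.e. (strided_height, strided_width), matching the strided depth maps un-projected above.\n        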
rgbs_normalized = 2 * (rgbs / 255.0) - 1.0\n        fnet_fmaps = self.fnet(rgbs_normalized.reshape(-1, 3, height, width))\n        fnet_fmaps = fnet_fmaps.reshape(\n            batch_size, num_views, num_frames, self.latent_dim, strided_height, strided_width,\n        )\n\n        # Add Positional 3D Embeddings/Encodings\n        def world_to_triplane(points, inverse=False):\n            assert points.shape[-1] == 3\n            if inverse:\n                return points * (triplane_xyz_max - triplane_xyz_min) / (triplane_grid_dims - 1) + triplane_xyz_min\n            else:\n                return (points - triplane_xyz_min) / (triplane_xyz_max - triplane_xyz_min) * (triplane_grid_dims - 1)\n\n        depthmap_world_xyz_normalized = (depthmap_world_xyz - triplane_xyz_min) / (triplane_xyz_max - triplane_xyz_min)\n        positional_encoding_3d = self.embed3d(2 * depthmap_world_xyz_normalized.reshape(-1, 3) - 1)\n        positional_encoding_3d = (\n            positional_encoding_3d\n            .reshape(batch_size, num_views, num_frames, strided_height, strided_width, -1)\n            .permute(0, 1, 2, 5, 3, 4)  # HWC --> CHW\n        )\n        fmaps = torch.cat([fnet_fmaps, positional_encoding_3d], dim=-3)\n        fmaps = fmaps.reshape(-1, self.latent_dim + self.embed3d.out_dim, strided_height, strided_width)\n        fmaps = self.embedConv(fmaps)\n        fmaps = fmaps.reshape(batch_size, num_views, num_frames, self.latent_dim, strided_height, strided_width)\n\n        # Compute the flows from each depthmap to the triplane\n        # The flows are needed to splat the features from the depthmap to the triplane\n        # The flow defines how one 2D plane is transformed to another 2D plane\n        # In our case, the first plane will be of ... TODO describe the planes more\n        depthmap_world_xyz_normalized_to_triplane_grid = depthmap_world_xyz_normalized * (triplane_grid_dims - 1)\n        depthmap_world_xyz_reproduced = world_to_triplane(\n            points=depthmap_world_xyz_normalized_to_triplane_grid,\n            inverse=True,\n        )\n        if not depthmap_world_xyz_reproduced.allclose(depthmap_world_xyz, atol=0.72):\n            logging.info(\"depthmap_world_xyz_reproduced: %s\", depthmap_world_xyz_reproduced)\n            logging.info(\"depthmap_world_xyz: %s\", depthmap_world_xyz)\n            warnings.warn(f\"Applying the inverse of world_to_triplane did not reproduce depthmap_world_xyz... 
\"\n                          f\"The maximum difference is {torch.max(torch.abs(depthmap_world_xyz_reproduced - depthmap_world_xyz))}\")\n\n        flow_pointcloud_to_xy = depthmap_world_xyz_normalized_to_triplane_grid[..., [0, 1]]\n        flow_pointcloud_to_yz = depthmap_world_xyz_normalized_to_triplane_grid[..., [1, 2]]\n        flow_pointcloud_to_xz = depthmap_world_xyz_normalized_to_triplane_grid[..., [0, 2]]\n        flow_pointcloud_to_xy = (\n            flow_pointcloud_to_xy\n            .permute(0, 2, 5, 3, 1, 4)\n            .reshape(batch_size * num_frames, 2, strided_height, num_views * strided_width)\n        )\n        flow_pointcloud_to_yz = (\n            flow_pointcloud_to_yz\n            .permute(0, 2, 5, 3, 1, 4)\n            .reshape(batch_size * num_frames, 2, strided_height, num_views * strided_width)\n        )\n        flow_pointcloud_to_xz = (\n            flow_pointcloud_to_xz\n            .permute(0, 2, 5, 3, 1, 4)\n            .reshape(batch_size * num_frames, 2, strided_height, num_views * strided_width)\n        )\n\n        # Compute the triplane features by splatting the per-view features following the flows\n        def splat_fmaps(fmaps, flow_xy, flow_yz, flow_xz, out_shape):\n            dtype = fmaps.dtype\n            fmaps = fmaps.float()\n            flow_xy = flow_xy.float()\n            flow_yz = flow_yz.float()\n            flow_xz = flow_xz.float()\n            fmap_xy, fmap_xy_norm = softsplat(\n                tenIn=fmaps,\n                tenFlow=flow_xy,\n                tenMetric=None,\n                strMode=\"avg\",\n                tenoutH=out_shape[1],\n                tenoutW=out_shape[0],\n                use_pointcloud_splatting=True,\n                return_normalization_tensor=True,\n            )\n            fmap_yz, fmap_yz_norm = softsplat(\n                tenIn=fmaps,\n                tenFlow=flow_yz,\n                tenMetric=None,\n                strMode=\"avg\",\n                tenoutH=out_shape[2],\n                tenoutW=out_shape[1],\n                use_pointcloud_splatting=True,\n                return_normalization_tensor=True,\n            )\n            fmap_xz, fmap_xz_norm = softsplat(\n                tenIn=fmaps,\n                tenFlow=flow_xz,\n                tenMetric=None,\n                strMode=\"avg\",\n                tenoutH=out_shape[2],\n                tenoutW=out_shape[0],\n                use_pointcloud_splatting=True,\n                return_normalization_tensor=True,\n            )\n            if dtype != fmaps.dtype:\n                fmap_xy = fmap_xy.to(dtype)\n                fmap_yz = fmap_yz.to(dtype)\n                fmap_xz = fmap_xz.to(dtype)\n                fmap_xy_norm = fmap_xy_norm.to(dtype)\n                fmap_yz_norm = fmap_yz_norm.to(dtype)\n                fmap_xz_norm = fmap_xz_norm.to(dtype)\n            return fmap_xy, fmap_yz, fmap_xz, fmap_xy_norm, fmap_yz_norm, fmap_xz_norm\n\n        fmaps = (\n            fmaps\n            .permute(0, 2, 3, 4, 1, 5)\n            .reshape(batch_size * num_frames, self.latent_dim, strided_height, num_views * strided_width)\n        )\n        fmap_xy, fmap_yz, fmap_xz, fmap_xy_norm, fmap_yz_norm, fmap_xz_norm = splat_fmaps(\n            fmaps=fmaps,\n            flow_xy=flow_pointcloud_to_xy,\n            flow_yz=flow_pointcloud_to_yz,\n            flow_xz=flow_pointcloud_to_xz,\n            out_shape=(self.triplane_xres, self.triplane_yres, self.triplane_zres),\n        )\n\n        if save_debug_logs and (self.triplane_xres == 
self.triplane_yres == self.triplane_zres):\n            # Visualize how the splatting would look like if the strided_rgbs would be directly splatted instead of feature maps\n            rgbs_fmaps = (\n                strided_rgbs\n                .permute(0, 2, 3, 4, 1, 5)\n                .reshape(batch_size * num_frames, 3, strided_height, num_views * strided_width)\n            )\n            rgbs_fmap_xy, rgbs_fmap_yz, rgbs_fmap_xz, rgbs_fmap_xy_norm, rgbs_fmap_yz_norm, rgbs_fmap_xz_norm = splat_fmaps(\n                fmaps=rgbs_fmaps,\n                flow_xy=flow_pointcloud_to_xy,\n                flow_yz=flow_pointcloud_to_yz,\n                flow_xz=flow_pointcloud_to_xz,\n                out_shape=(self.triplane_xres, self.triplane_yres, self.triplane_zres),\n            )\n            rgbs_fmap_xy_yz_xz_concat = torch.concat([rgbs_fmap_xy, rgbs_fmap_yz, rgbs_fmap_xz], -1)\n            rgbs_fmap_norm_xy_yz_xz_concat = torch.concat([rgbs_fmap_xy_norm, rgbs_fmap_yz_norm, rgbs_fmap_xz_norm], -1)\n            self._plot_featuremaps(\n                logs_path=debug_logs_path,\n                filename=time_now() + \"__splatted_rgbs\",\n                fmaps_before_splatting=rgbs_fmaps,\n                splatted_fmaps=rgbs_fmap_xy_yz_xz_concat,\n                splat_normalization=rgbs_fmap_norm_xy_yz_xz_concat,\n                chosen_channels=(0, 1, 2),\n            )\n\n        if save_debug_logs and (self.triplane_xres == self.triplane_yres == self.triplane_zres):\n            # Also splat only the first view RGBs to see how the splatting would look like\n            rgbs_fmaps = strided_rgbs[0, 0]\n            rgbs_fmap_xy, rgbs_fmap_yz, rgbs_fmap_xz, rgbs_fmap_xy_norm, rgbs_fmap_yz_norm, rgbs_fmap_xz_norm = splat_fmaps(\n                fmaps=rgbs_fmaps,\n                flow_xy=flow_pointcloud_to_xy[:, :, :, :strided_width],\n                flow_yz=flow_pointcloud_to_yz[:, :, :, :strided_width],\n                flow_xz=flow_pointcloud_to_xz[:, :, :, :strided_width],\n                out_shape=(self.triplane_xres, self.triplane_yres, self.triplane_zres),\n            )\n            rgbs_fmap_xy_yz_xz_concat = torch.concat([rgbs_fmap_xy, rgbs_fmap_yz, rgbs_fmap_xz], -1)\n            rgbs_fmap_norm_xy_yz_xz_concat = torch.concat([rgbs_fmap_xy_norm, rgbs_fmap_yz_norm, rgbs_fmap_xz_norm],\n                                                          -1)\n            self._plot_featuremaps(\n                logs_path=debug_logs_path,\n                filename=time_now() + \"__splatted_rgbs_first_view_only\",\n                fmaps_before_splatting=rgbs_fmaps,\n                splatted_fmaps=rgbs_fmap_xy_yz_xz_concat,\n                splat_normalization=rgbs_fmap_norm_xy_yz_xz_concat,\n                chosen_channels=(0, 1, 2),\n            )\n            xyz = to_homogeneous(\n                flow_pointcloud_to_xy[0, :, :, :strided_width].permute(1, 2, 0).reshape(-1, 2)).cpu().numpy()\n            c = strided_rgbs[0, 0, 0, :, :].permute(1, 2, 0).reshape(-1, 3).cpu().numpy() / 255\n            self._plot_pointcloud(debug_logs_path, time_now() + \"__flow_xy_debug\", xyz, c, show=False)\n\n        if save_debug_logs and (self.triplane_xres == self.triplane_yres == self.triplane_zres):\n            if not (self.triplane_xres == self.triplane_yres == self.triplane_zres):\n                raise NotImplementedError(\"Current implementation assumed these, otherwise needs some padding/interp.\")\n            fmap_xy_yz_xz_concat = torch.concat([fmap_xy, fmap_yz, fmap_xz], dim=-1)\n          
  fmap_norm_xy_yz_xz_concat = torch.concat([fmap_xy_norm, fmap_yz_norm, fmap_xz_norm], dim=-1)\n            self._plot_featuremaps(\n                logs_path=debug_logs_path,\n                filename=time_now() + \"__fmaps\",\n                fmaps_before_splatting=fmaps,\n                splatted_fmaps=fmap_xy_yz_xz_concat,\n                splat_normalization=fmap_norm_xy_yz_xz_concat,\n                chosen_channels=(0, 1, 2),\n            )\n\n        fmap_xy = self.headxy(fmap_xy)\n        fmap_yz = self.headyz(fmap_yz)\n        fmap_xz = self.headxz(fmap_xz)\n\n        fmap_xy = fmap_xy.reshape(batch_size, num_frames, self.latent_dim, self.triplane_yres, self.triplane_xres)\n        fmap_yz = fmap_yz.reshape(batch_size, num_frames, self.latent_dim, self.triplane_zres, self.triplane_yres)\n        fmap_xz = fmap_xz.reshape(batch_size, num_frames, self.latent_dim, self.triplane_zres, self.triplane_xres)\n\n        if save_debug_logs and (self.triplane_xres == self.triplane_yres == self.triplane_zres):\n            if not (self.triplane_xres == self.triplane_yres == self.triplane_zres):\n                raise NotImplementedError(\"Current implementation assumed these, otherwise needs some padding/interp.\")\n            fmap_xy_yz_xz_concat = torch.concat([fmap_xy[0], fmap_yz[0], fmap_xz[0]], dim=-1)\n            fmap_norm_xy_yz_xz_concat = torch.concat([fmap_xy_norm, fmap_yz_norm, fmap_xz_norm], dim=-1)\n            self._plot_featuremaps(\n                logs_path=debug_logs_path,\n                filename=time_now() + \"__fmaps_after_head\",\n                fmaps_before_splatting=fmaps,\n                splatted_fmaps=fmap_xy_yz_xz_concat,\n                splat_normalization=fmap_norm_xy_yz_xz_concat,\n                chosen_channels=(-3, -2, -1),\n            )\n\n        # Filter the points that never appear during 1 - T\n        assert batch_size == 1, \"Batch size > 1 is not supported yet\"\n        query_points_t = query_points_t.squeeze(0).squeeze(-1)  # BN1 --> N\n        ind_array = torch.arange(num_frames, device=query_points.device)\n        ind_array = ind_array[None, :, None].repeat(batch_size, 1, num_points)\n        track_mask = (ind_array >= query_points_t[None, None, :]).unsqueeze(-1)\n\n        # Prepare the initial coordinates and visibility\n        coords_init = query_points_xyz_worldspace.unsqueeze(1).repeat(1, self.S, 1, 1)\n        coords_init = world_to_triplane(coords_init)\n        vis_init = query_points.new_ones((batch_size, self.S, num_points, 1)) * 10\n\n        # Sort the queries via their first appeared time\n        _, sort_inds = torch.sort(query_points_t, dim=0, descending=False)\n        inv_sort_inds = torch.argsort(sort_inds, dim=0)\n        assert torch.allclose(query_points_t, query_points_t[sort_inds][inv_sort_inds])\n\n        query_points_t_ = query_points_t[sort_inds]\n        coords_init_ = coords_init[..., sort_inds, :].clone()\n        vis_init_ = vis_init[:, :, sort_inds].clone()\n        track_mask_ = track_mask[:, :, sort_inds].clone()\n\n        # Placeholders for the results (for the sorted points)\n        traj_e_ = query_points.new_zeros((batch_size, num_frames, num_points, 3))\n        vis_e_ = query_points.new_zeros((batch_size, num_frames, num_points))\n\n        # Perform the iterative forward pass of the SpaTracker as usual,\n        # but make sure to use the pre-computed triplane features\n        w_idx_start = 0\n        p_idx_start = 0\n        vis_predictions = []\n        coord_predictions = []\n        
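# Sliding-window inference: windows of self.S frames are advanced by self.S // 2 frames at a time.\n        # A query joins a window once its start frame falls before the window's end, and the second half of\n        # each window's predictions initializes the coordinates and visibility of the next window.\n        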
p_idx_end_list = []\n        while w_idx_start < num_frames - self.S // 2:\n            curr_wind_points = torch.nonzero(query_points_t_ < w_idx_start + self.S)\n            if curr_wind_points.shape[0] == 0:\n                w_idx_start = w_idx_start + self.S // 2\n                continue\n            p_idx_end = curr_wind_points[-1] + 1\n            p_idx_end_list.append(p_idx_end)\n\n            # TODO: Is cloning necessary here – I don't think so?\n            fmap_xy_seq = fmap_xy[:, w_idx_start:w_idx_start + self.S].clone()\n            fmap_yz_seq = fmap_yz[:, w_idx_start:w_idx_start + self.S].clone()\n            fmap_xz_seq = fmap_xz[:, w_idx_start:w_idx_start + self.S].clone()\n\n            # the number of frames may not be divisible by self.S\n            S_local = fmap_xy_seq.shape[1]\n            if S_local < self.S:\n                fmap_xy_seq = torch.cat([fmap_xy_seq, fmap_xy_seq[:, -1, None].repeat(1, self.S - S_local, 1, 1, 1)], 1)\n                fmap_yz_seq = torch.cat([fmap_yz_seq, fmap_yz_seq[:, -1, None].repeat(1, self.S - S_local, 1, 1, 1)], 1)\n                fmap_xz_seq = torch.cat([fmap_xz_seq, fmap_xz_seq[:, -1, None].repeat(1, self.S - S_local, 1, 1, 1)], 1)\n\n            if p_idx_end - p_idx_start > 0:\n                queried_t = (query_points_t_[p_idx_start:p_idx_end] - w_idx_start)\n                featxy_init, featyz_init, featxz_init = self.sample_trifeat(\n                    t=queried_t,\n                    featMapxy=fmap_xy_seq,\n                    featMapyz=fmap_yz_seq,\n                    featMapxz=fmap_xz_seq,\n                    coords=coords_init_[:, :1, p_idx_start:p_idx_end],\n                )\n                feat_init_curr = torch.stack([featxy_init, featyz_init, featxz_init], dim=-1)\n                feat_init = smart_cat(feat_init, feat_init_curr, dim=2)\n\n            # Update the initial coordinates and visibility for non-first windows\n            if p_idx_start > 0:\n                last_coords = coords[-1][:, self.S // 2:].clone()  # Take the predicted coords from the last window\n                coords_init_[:, : self.S // 2, :p_idx_start] = last_coords\n                coords_init_[:, self.S // 2:, :p_idx_start] = last_coords[:, -1].repeat(1, self.S // 2, 1, 1)\n\n                last_vis = vis[:, self.S // 2:][..., None]\n                vis_init_[:, : self.S // 2, :p_idx_start] = last_vis\n                vis_init_[:, self.S // 2:, :p_idx_start] = last_vis[:, -1].repeat(1, self.S // 2, 1, 1)\n\n            track_mask_current = track_mask_[:, w_idx_start: w_idx_start + self.S, :p_idx_end]\n            if S_local < self.S:\n                track_mask_current = torch.cat([\n                    track_mask_current,\n                    track_mask_current[:, -1:].repeat(1, self.S - S_local, 1, 1),\n                ], 1)\n\n            coords, vis, _ = self.forward_iteration(\n                fmapXY=fmap_xy_seq,\n                fmapYZ=fmap_yz_seq,\n                fmapXZ=fmap_xz_seq,\n                coords_init=coords_init_[:, :, :p_idx_end],\n                feat_init=feat_init[:, :, :p_idx_end],\n                vis_init=vis_init_[:, :, :p_idx_end],\n                track_mask=track_mask_current,\n                iters=iters,\n            )\n\n            coords_in_worldspace = [world_to_triplane(coord, inverse=True) for coord in coords]\n\n            if is_train:\n                coord_predictions.append([coord[:, :S_local] for coord in coords_in_worldspace])\n                vis_predictions.append(vis[:, :S_local])\n\n     
       traj_e_[:, w_idx_start:w_idx_start + self.S, :p_idx_end] = coords_in_worldspace[-1][:, :S_local]\n            vis_e_[:, w_idx_start:w_idx_start + self.S, :p_idx_end] = torch.sigmoid(vis[:, :S_local])\n\n            track_mask_[:, : w_idx_start + self.S, :p_idx_end] = 0.0\n            w_idx_start = w_idx_start + self.S // 2\n\n            p_idx_start = p_idx_end\n\n        traj_e = traj_e_[:, :, inv_sort_inds]\n        vis_e = vis_e_[:, :, inv_sort_inds]\n\n        results = {\n            \"traj_e\": traj_e,\n            \"feat_init\": feat_init,\n            \"vis_e\": vis_e,\n        }\n        if self.is_train:\n            results[\"train_data\"] = {\n                \"vis_predictions\": vis_predictions,\n                \"coord_predictions\": coord_predictions,\n                \"attn_predictions\": None,\n                \"p_idx_end_list\": p_idx_end_list,\n                \"sort_inds\": sort_inds,\n                \"Rigid_ln_total\": None,\n            }\n        return results\n\n    @staticmethod\n    def _plot_pointcloud(logs_path, filename, xyz, c, q_xyz=None, q_c=None,\n                         elevations=(0, 30, 90), azimuths=(0, 45, 90), show=False):\n        fig = plt.figure(figsize=(len(azimuths) * 4.8, len(elevations) * 4.8))\n        fig.suptitle(filename)\n        for i, elev_ in enumerate(elevations):\n            for j, azim in enumerate(azimuths):\n                ax = fig.add_subplot(len(elevations), len(azimuths), i * len(azimuths) + j + 1, projection='3d')\n                ax.scatter(xyz[:, 0], xyz[:, 1], xyz[:, 2], c=c, s=1, marker=\".\", label=\"RGBD pointcloud\")\n                if q_xyz is not None:\n                    ax.scatter(q_xyz[:, 0], q_xyz[:, 1], q_xyz[:, 2], c=q_c, s=3, marker=\"^\", label=\"Query Points\")\n                ax.set_xlabel('x')\n                ax.set_ylabel('y')\n                ax.set_zlabel('z')\n                ax.legend()\n                ax.view_init(elev=elev_, azim=azim)\n        plt.tight_layout(pad=0)\n        plt.savefig(os.path.join(logs_path, f\"{filename}.png\"))\n        if show:\n            plt.show()\n        plt.close()\n\n    @staticmethod\n    def _plot_featuremaps(\n            logs_path,\n            filename,\n            fmaps_before_splatting,\n            splatted_fmaps,\n            splat_normalization,\n            chosen_channels=(-3, -2, -1),\n    ):\n        num_frames, n_channels, height_before, width_before = fmaps_before_splatting.shape\n        _, _, height_after, width_after = splatted_fmaps.shape\n\n        assert fmaps_before_splatting.shape == (num_frames, n_channels, height_before, width_before)\n        assert splatted_fmaps.shape == (num_frames, n_channels, height_after, width_after)\n        assert splat_normalization.shape == (num_frames, 1, height_after, width_after)\n\n        fmaps_before_splatting = fmaps_before_splatting.detach().cpu().float().numpy()\n        splatted_fmaps = splatted_fmaps.detach().cpu().float().numpy()\n        splat_normalization = splat_normalization.detach().cpu().float().numpy()\n\n        # Extract the chosen channels and normalize them\n        fmaps_before_splatting = fmaps_before_splatting[:, chosen_channels, :, :]\n        splatted_fmaps = splatted_fmaps[:, chosen_channels, :, :]\n\n        ch_min = fmaps_before_splatting.min(axis=(0, 2, 3), keepdims=True)\n        ch_max = fmaps_before_splatting.max(axis=(0, 2, 3), keepdims=True)\n        fmaps_before_splatting = (fmaps_before_splatting - ch_min) / (ch_max - ch_min)\n        splatted_fmaps = 
(splatted_fmaps - ch_min) / (ch_max - ch_min)\n\n        # Normalize the normalization as well ( ͡° ͜ʖ ͡°)\n        splat_normalization = splat_normalization / splat_normalization.max()\n\n        # Pad the shorter side to match the longer side\n        if width_before != width_after:\n            if width_after > width_before:\n                fmaps_before_splatting = np.pad(\n                    fmaps_before_splatting,\n                    ((0, 0), (0, 0), (0, 0), (0, width_after - width_before)),\n                    mode='constant',\n                    constant_values=0\n                )\n            else:\n                splatted_fmaps = np.pad(\n                    splatted_fmaps,\n                    ((0, 0), (0, 0), (0, 0), (0, width_before - width_after)),\n                    mode='constant',\n                    constant_values=0\n                )\n                splat_normalization = np.pad(\n                    splat_normalization,\n                    ((0, 0), (0, 0), (0, 0), (0, width_before - width_after)),\n                    mode='constant',\n                    constant_values=0\n                )\n\n        # Concatenate images along the height dimension\n        splat_normalization = np.repeat(splat_normalization, 3, axis=1)\n        imgs = [\n            np.concatenate([\n                fmaps_before_splatting[t],\n                splatted_fmaps[t],\n                splat_normalization[t]\n            ], axis=1).transpose(1, 2, 0)[..., ::-1]\n            for t in range(num_frames)\n        ]\n\n        video = cv2.VideoWriter(\n            os.path.join(logs_path, f\"{filename}.mp4\"),\n            cv2.VideoWriter_fourcc(*\"mp4v\"),\n            12,\n            (imgs[0].shape[1], imgs[0].shape[0]),\n        )\n\n        for img in imgs:\n            video.write((img * 255).astype(np.uint8))\n\n        video.release()\n        logging.info(f\"Saved the featuremap video to {os.path.abspath(os.path.join(logs_path, f'{filename}.mp4'))}\")\n"
  },
  {
    "path": "mvtracker/models/core/vggt/__init__.py",
    "content": ""
  },
  {
    "path": "mvtracker/models/core/vggt/heads/camera_head.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\nimport math\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom ..layers import Mlp\nfrom ..layers.block import Block\nfrom ..heads.head_act import activate_pose\n\n\nclass CameraHead(nn.Module):\n    \"\"\"\n    CameraHead predicts camera parameters from token representations using iterative refinement.\n\n    It applies a series of transformer blocks (the \"trunk\") to dedicated camera tokens.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim_in: int = 2048,\n        trunk_depth: int = 4,\n        pose_encoding_type: str = \"absT_quaR_FoV\",\n        num_heads: int = 16,\n        mlp_ratio: int = 4,\n        init_values: float = 0.01,\n        trans_act: str = \"linear\",\n        quat_act: str = \"linear\",\n        fl_act: str = \"relu\",  # Field of view activations: ensures FOV values are positive.\n    ):\n        super().__init__()\n\n        if pose_encoding_type == \"absT_quaR_FoV\":\n            self.target_dim = 9\n        else:\n            raise ValueError(f\"Unsupported camera encoding type: {pose_encoding_type}\")\n\n        self.trans_act = trans_act\n        self.quat_act = quat_act\n        self.fl_act = fl_act\n        self.trunk_depth = trunk_depth\n\n        # Build the trunk using a sequence of transformer blocks.\n        self.trunk = nn.Sequential(\n            *[\n                Block(\n                    dim=dim_in,\n                    num_heads=num_heads,\n                    mlp_ratio=mlp_ratio,\n                    init_values=init_values,\n                )\n                for _ in range(trunk_depth)\n            ]\n        )\n\n        # Normalizations for camera token and trunk output.\n        self.token_norm = nn.LayerNorm(dim_in)\n        self.trunk_norm = nn.LayerNorm(dim_in)\n\n        # Learnable empty camera pose token.\n        self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))\n        self.embed_pose = nn.Linear(self.target_dim, dim_in)\n\n        # Module for producing modulation parameters: shift, scale, and a gate.\n        self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))\n\n        # Adaptive layer normalization without affine parameters.\n        self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)\n        self.pose_branch = Mlp(\n            in_features=dim_in,\n            hidden_features=dim_in // 2,\n            out_features=self.target_dim,\n            drop=0,\n        )\n\n    def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:\n        \"\"\"\n        Forward pass to predict camera parameters.\n\n        Args:\n            aggregated_tokens_list (list): List of token tensors from the network;\n                the last tensor is used for prediction.\n            num_iterations (int, optional): Number of iterative refinement steps. 
Defaults to 4.\n\n        Returns:\n            list: A list of predicted camera encodings (post-activation) from each iteration.\n        \"\"\"\n        # Use tokens from the last block for camera prediction.\n        tokens = aggregated_tokens_list[-1]\n\n        # Extract the camera tokens\n        pose_tokens = tokens[:, :, 0]\n        pose_tokens = self.token_norm(pose_tokens)\n\n        pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)\n        return pred_pose_enc_list\n\n    def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:\n        \"\"\"\n        Iteratively refine camera pose predictions.\n\n        Args:\n            pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C].\n            num_iterations (int): Number of refinement iterations.\n\n        Returns:\n            list: List of activated camera encodings from each iteration.\n        \"\"\"\n        B, S, C = pose_tokens.shape  # S is expected to be 1.\n        pred_pose_enc = None\n        pred_pose_enc_list = []\n\n        for _ in range(num_iterations):\n            # Use a learned empty pose for the first iteration.\n            if pred_pose_enc is None:\n                module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))\n            else:\n                # Detach the previous prediction to avoid backprop through time.\n                pred_pose_enc = pred_pose_enc.detach()\n                module_input = self.embed_pose(pred_pose_enc)\n\n            # Generate modulation parameters and split them into shift, scale, and gate components.\n            shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)\n\n            # Adaptive layer normalization and modulation.\n            pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)\n            pose_tokens_modulated = pose_tokens_modulated + pose_tokens\n\n            pose_tokens_modulated = self.trunk(pose_tokens_modulated)\n            # Compute the delta update for the pose encoding.\n            pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))\n\n            if pred_pose_enc is None:\n                pred_pose_enc = pred_pose_enc_delta\n            else:\n                pred_pose_enc = pred_pose_enc + pred_pose_enc_delta\n\n            # Apply final activation functions for translation, quaternion, and field-of-view.\n            activated_pose = activate_pose(\n                pred_pose_enc,\n                trans_act=self.trans_act,\n                quat_act=self.quat_act,\n                fl_act=self.fl_act,\n            )\n            pred_pose_enc_list.append(activated_pose)\n\n        return pred_pose_enc_list\n\n\ndef modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Modulate the input tensor using scaling and shifting parameters.\n    \"\"\"\n    # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19\n    return x * (1 + scale) + shift\n"
  },
  {
    "path": "mvtracker/models/core/vggt/heads/dpt_head.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\n\n# Inspired by https://github.com/DepthAnything/Depth-Anything-V2\n\n\nimport os\nfrom typing import List, Dict, Tuple, Union\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom .head_act import activate_head\nfrom .utils import create_uv_grid, position_grid_to_embed\n\n\nclass DPTHead(nn.Module):\n    \"\"\"\n    DPT  Head for dense prediction tasks.\n\n    This implementation follows the architecture described in \"Vision Transformers for Dense Prediction\"\n    (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer\n    backbone and produces dense predictions by fusing multi-scale features.\n\n    Args:\n        dim_in (int): Input dimension (channels).\n        patch_size (int, optional): Patch size. Default is 14.\n        output_dim (int, optional): Number of output channels. Default is 4.\n        activation (str, optional): Activation type. Default is \"inv_log\".\n        conf_activation (str, optional): Confidence activation type. Default is \"expp1\".\n        features (int, optional): Feature channels for intermediate representations. Default is 256.\n        out_channels (List[int], optional): Output channels for each intermediate layer.\n        intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.\n        pos_embed (bool, optional): Whether to use positional embedding. Default is True.\n        feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.\n        down_ratio (int, optional): Downscaling factor for the output resolution. 
Default is 1.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim_in: int,\n        patch_size: int = 14,\n        output_dim: int = 4,\n        activation: str = \"inv_log\",\n        conf_activation: str = \"expp1\",\n        features: int = 256,\n        out_channels: List[int] = [256, 512, 1024, 1024],\n        intermediate_layer_idx: List[int] = [4, 11, 17, 23],\n        pos_embed: bool = True,\n        feature_only: bool = False,\n        down_ratio: int = 1,\n    ) -> None:\n        super(DPTHead, self).__init__()\n        self.patch_size = patch_size\n        self.activation = activation\n        self.conf_activation = conf_activation\n        self.pos_embed = pos_embed\n        self.feature_only = feature_only\n        self.down_ratio = down_ratio\n        self.intermediate_layer_idx = intermediate_layer_idx\n\n        self.norm = nn.LayerNorm(dim_in)\n\n        # Projection layers for each output channel from tokens.\n        self.projects = nn.ModuleList(\n            [\n                nn.Conv2d(\n                    in_channels=dim_in,\n                    out_channels=oc,\n                    kernel_size=1,\n                    stride=1,\n                    padding=0,\n                )\n                for oc in out_channels\n            ]\n        )\n\n        # Resize layers for upsampling feature maps.\n        self.resize_layers = nn.ModuleList(\n            [\n                nn.ConvTranspose2d(\n                    in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0\n                ),\n                nn.ConvTranspose2d(\n                    in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0\n                ),\n                nn.Identity(),\n                nn.Conv2d(\n                    in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1\n                ),\n            ]\n        )\n\n        self.scratch = _make_scratch(\n            out_channels,\n            features,\n            expand=False,\n        )\n\n        # Attach additional modules to scratch.\n        self.scratch.stem_transpose = None\n        self.scratch.refinenet1 = _make_fusion_block(features)\n        self.scratch.refinenet2 = _make_fusion_block(features)\n        self.scratch.refinenet3 = _make_fusion_block(features)\n        self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)\n\n        head_features_1 = features\n        head_features_2 = 32\n\n        if feature_only:\n            self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1)\n        else:\n            self.scratch.output_conv1 = nn.Conv2d(\n                head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1\n            )\n            conv2_in_channels = head_features_1 // 2\n\n            self.scratch.output_conv2 = nn.Sequential(\n                nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),\n                nn.ReLU(inplace=True),\n                nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),\n            )\n\n    def forward(\n        self,\n        aggregated_tokens_list: List[torch.Tensor],\n        images: torch.Tensor,\n        patch_start_idx: int,\n        frames_chunk_size: int = 8,\n        inference_feature_only: bool = False,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        \"\"\"\n        
Forward pass through the DPT head, supports processing by chunking frames.\n        Args:\n            aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.\n            images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].\n            patch_start_idx (int): Starting index for patch tokens in the token sequence.\n                Used to separate patch tokens from other tokens (e.g., camera or register tokens).\n            frames_chunk_size (int, optional): Number of frames to process in each chunk.\n                If None or larger than S, all frames are processed at once. Default: 8.\n\n        Returns:\n            Tensor or Tuple[Tensor, Tensor]:\n                - If feature_only=True: Feature maps with shape [B, S, C, H, W]\n                - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W]\n        \"\"\"\n        B, S, _, H, W = images.shape\n        # If frames_chunk_size is not specified or greater than S, process all frames at once\n        if frames_chunk_size is None or frames_chunk_size >= S:\n            return self._forward_impl(aggregated_tokens_list, images, patch_start_idx, inference_feature_only = inference_feature_only)\n\n        # Otherwise, process frames in chunks to manage memory usage\n        assert frames_chunk_size > 0\n\n        # Process frames in batches\n        all_preds = []\n        all_conf = []\n\n        for frames_start_idx in range(0, S, frames_chunk_size):\n            frames_end_idx = min(frames_start_idx + frames_chunk_size, S)\n\n            # Process batch of frames\n            # if self.feature_only or inference_feature_only:\n            #     chunk_output = self._forward_impl(\n            #         aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx, inference_feature_only = inference_feature_only\n            #     )\n            #     all_preds.append(chunk_output)\n            # else:\n            #     chunk_preds, chunk_conf = self._forward_impl(\n            #         aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx, inference_feature_only = inference_feature_only\n            #     )\n            #     all_preds.append(chunk_preds)\n            #     all_conf.append(chunk_conf)\n            chunk_preds, chunk_conf = self._forward_impl(\n                aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx, inference_feature_only = inference_feature_only\n            )\n            all_preds.append(chunk_preds)\n            all_conf.append(chunk_conf)\n\n        # Concatenate results along the sequence dimension\n        # if self.feature_only or inference_feature_only:\n        #     return torch.cat(all_preds, dim=1)\n        # else:\n        #     return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)\n        return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)\n\n    def _forward_impl(\n        self,\n        aggregated_tokens_list: List[torch.Tensor],\n        images: torch.Tensor,\n        patch_start_idx: int,\n        frames_start_idx: int = None,\n        frames_end_idx: int = None,\n        inference_feature_only: bool = False,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        \"\"\"\n        Implementation of the forward pass through the DPT head.\n\n        This method processes a specific chunk of frames from the sequence.\n\n        Args:\n            aggregated_tokens_list (List[Tensor]): List 
of token tensors from different transformer layers.\n            images (Tensor): Input images with shape [B, S, 3, H, W].\n            patch_start_idx (int): Starting index for patch tokens.\n            frames_start_idx (int, optional): Starting index for frames to process.\n            frames_end_idx (int, optional): Ending index for frames to process.\n\n        Returns:\n            Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).\n        \"\"\"\n        if frames_start_idx is not None and frames_end_idx is not None:\n            images = images[:, frames_start_idx:frames_end_idx].contiguous()\n\n        B, S, _, H, W = images.shape\n\n        patch_h, patch_w = H // self.patch_size, W // self.patch_size\n\n        out = []\n        dpt_idx = 0\n\n        for layer_idx in self.intermediate_layer_idx:\n            x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]\n\n            # Select frames if processing a chunk\n            if frames_start_idx is not None and frames_end_idx is not None:\n                x = x[:, frames_start_idx:frames_end_idx]\n\n            x = x.view(B * S, -1, x.shape[-1])\n\n            x = self.norm(x)\n\n            x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))\n\n            x = self.projects[dpt_idx](x)\n            if self.pos_embed:\n                x = self._apply_pos_embed(x, W, H)\n            x = self.resize_layers[dpt_idx](x)\n\n            out.append(x)\n            dpt_idx += 1\n\n        # Fuse features from multiple layers.\n        out = self.scratch_forward(out)\n        # Interpolate fused output to match target image resolution.\n        out = custom_interpolate(\n            out,\n            (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)),\n            mode=\"bilinear\",\n            align_corners=True,\n        )\n\n        if self.pos_embed:\n            out = self._apply_pos_embed(out, W, H)\n        if self.feature_only or inference_feature_only:\n            feature_output = out.view(B, S, *out.shape[1:])\n            # return out.view(B, S, *out.shape[1:])\n\n        out = self.scratch.output_conv2(out)\n        preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation)\n\n        preds = preds.view(B, S, *preds.shape[1:])\n        conf = conf.view(B, S, *conf.shape[1:])\n\n        if self.feature_only or inference_feature_only:\n            return feature_output, conf\n        else:\n            return preds, conf\n\n    def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:\n        \"\"\"\n        Apply positional embedding to tensor x.\n        \"\"\"\n        patch_w = x.shape[-1]\n        patch_h = x.shape[-2]\n        pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)\n        pos_embed = position_grid_to_embed(pos_embed, x.shape[1])\n        pos_embed = pos_embed * ratio\n        pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)\n        return x + pos_embed\n\n    def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:\n        \"\"\"\n        Forward pass through the fusion blocks.\n\n        Args:\n            features (List[Tensor]): List of feature maps from different layers.\n\n        Returns:\n            Tensor: Fused feature map.\n        \"\"\"\n        layer_1, layer_2, layer_3, layer_4 = features\n\n        layer_1_rn = 
self.scratch.layer1_rn(layer_1)\n        layer_2_rn = self.scratch.layer2_rn(layer_2)\n        layer_3_rn = self.scratch.layer3_rn(layer_3)\n        layer_4_rn = self.scratch.layer4_rn(layer_4)\n\n        out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])\n        del layer_4_rn, layer_4\n\n        out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])\n        del layer_3_rn, layer_3\n\n        out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])\n        del layer_2_rn, layer_2\n\n        out = self.scratch.refinenet1(out, layer_1_rn)\n        del layer_1_rn, layer_1\n\n        out = self.scratch.output_conv1(out)\n        return out\n\n\n################################################################################\n# Modules\n################################################################################\n\n\ndef _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:\n    return FeatureFusionBlock(\n        features,\n        nn.ReLU(inplace=True),\n        deconv=False,\n        bn=False,\n        expand=False,\n        align_corners=True,\n        size=size,\n        has_residual=has_residual,\n        groups=groups,\n    )\n\n\ndef _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:\n    scratch = nn.Module()\n    out_shape1 = out_shape\n    out_shape2 = out_shape\n    out_shape3 = out_shape\n    if len(in_shape) >= 4:\n        out_shape4 = out_shape\n\n    if expand:\n        out_shape1 = out_shape\n        out_shape2 = out_shape * 2\n        out_shape3 = out_shape * 4\n        if len(in_shape) >= 4:\n            out_shape4 = out_shape * 8\n\n    scratch.layer1_rn = nn.Conv2d(\n        in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups\n    )\n    scratch.layer2_rn = nn.Conv2d(\n        in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups\n    )\n    scratch.layer3_rn = nn.Conv2d(\n        in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups\n    )\n    if len(in_shape) >= 4:\n        scratch.layer4_rn = nn.Conv2d(\n            in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups\n        )\n    return scratch\n\n\nclass ResidualConvUnit(nn.Module):\n    \"\"\"Residual convolution module.\"\"\"\n\n    def __init__(self, features, activation, bn, groups=1):\n        \"\"\"Init.\n\n        Args:\n            features (int): number of features\n        \"\"\"\n        super().__init__()\n\n        self.bn = bn\n        self.groups = groups\n        self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)\n        self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)\n\n        self.norm1 = None\n        self.norm2 = None\n\n        self.activation = activation\n        self.skip_add = nn.quantized.FloatFunctional()\n\n    def forward(self, x):\n        \"\"\"Forward pass.\n\n        Args:\n            x (tensor): input\n\n        Returns:\n            tensor: output\n        \"\"\"\n\n        out = self.activation(x)\n        out = self.conv1(out)\n        if self.norm1 is not None:\n            out = self.norm1(out)\n\n        out = self.activation(out)\n        out = self.conv2(out)\n        if self.norm2 is not None:\n            out = self.norm2(out)\n\n        return 
self.skip_add.add(out, x)\n\n\nclass FeatureFusionBlock(nn.Module):\n    \"\"\"Feature fusion block.\"\"\"\n\n    def __init__(\n        self,\n        features,\n        activation,\n        deconv=False,\n        bn=False,\n        expand=False,\n        align_corners=True,\n        size=None,\n        has_residual=True,\n        groups=1,\n    ):\n        \"\"\"Init.\n\n        Args:\n            features (int): number of features\n        \"\"\"\n        super(FeatureFusionBlock, self).__init__()\n\n        self.deconv = deconv\n        self.align_corners = align_corners\n        self.groups = groups\n        self.expand = expand\n        out_features = features\n        if self.expand == True:\n            out_features = features // 2\n\n        self.out_conv = nn.Conv2d(\n            features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups\n        )\n\n        if has_residual:\n            self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)\n\n        self.has_residual = has_residual\n        self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)\n\n        self.skip_add = nn.quantized.FloatFunctional()\n        self.size = size\n\n    def forward(self, *xs, size=None):\n        \"\"\"Forward pass.\n\n        Returns:\n            tensor: output\n        \"\"\"\n        output = xs[0]\n\n        if self.has_residual:\n            res = self.resConfUnit1(xs[1])\n            output = self.skip_add.add(output, res)\n\n        output = self.resConfUnit2(output)\n\n        if (size is None) and (self.size is None):\n            modifier = {\"scale_factor\": 2}\n        elif size is None:\n            modifier = {\"size\": self.size}\n        else:\n            modifier = {\"size\": size}\n\n        output = custom_interpolate(output, **modifier, mode=\"bilinear\", align_corners=self.align_corners)\n        output = self.out_conv(output)\n\n        return output\n\n\ndef custom_interpolate(\n    x: torch.Tensor,\n    size: Tuple[int, int] = None,\n    scale_factor: float = None,\n    mode: str = \"bilinear\",\n    align_corners: bool = True,\n) -> torch.Tensor:\n    \"\"\"\n    Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.\n    \"\"\"\n    if size is None:\n        size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))\n\n    INT_MAX = 1610612736\n\n    input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]\n\n    if input_elements > INT_MAX:\n        chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)\n        interpolated_chunks = [\n            nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks\n        ]\n        x = torch.cat(interpolated_chunks, dim=0)\n        return x.contiguous()\n    else:\n        return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)\n"
  },
  {
    "path": "mvtracker/models/core/vggt/heads/head_act.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\n\nimport torch\nimport torch.nn.functional as F\n\n\ndef activate_pose(pred_pose_enc, trans_act=\"linear\", quat_act=\"linear\", fl_act=\"linear\"):\n    \"\"\"\n    Activate pose parameters with specified activation functions.\n\n    Args:\n        pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length]\n        trans_act: Activation type for translation component\n        quat_act: Activation type for quaternion component\n        fl_act: Activation type for focal length component\n\n    Returns:\n        Activated pose parameters tensor\n    \"\"\"\n    T = pred_pose_enc[..., :3]\n    quat = pred_pose_enc[..., 3:7]\n    fl = pred_pose_enc[..., 7:]  # or fov\n\n    T = base_pose_act(T, trans_act)\n    quat = base_pose_act(quat, quat_act)\n    fl = base_pose_act(fl, fl_act)  # or fov\n\n    pred_pose_enc = torch.cat([T, quat, fl], dim=-1)\n\n    return pred_pose_enc\n\n\ndef base_pose_act(pose_enc, act_type=\"linear\"):\n    \"\"\"\n    Apply basic activation function to pose parameters.\n\n    Args:\n        pose_enc: Tensor containing encoded pose parameters\n        act_type: Activation type (\"linear\", \"inv_log\", \"exp\", \"relu\")\n\n    Returns:\n        Activated pose parameters\n    \"\"\"\n    if act_type == \"linear\":\n        return pose_enc\n    elif act_type == \"inv_log\":\n        return inverse_log_transform(pose_enc)\n    elif act_type == \"exp\":\n        return torch.exp(pose_enc)\n    elif act_type == \"relu\":\n        return F.relu(pose_enc)\n    else:\n        raise ValueError(f\"Unknown act_type: {act_type}\")\n\n\ndef activate_head(out, activation=\"norm_exp\", conf_activation=\"expp1\"):\n    \"\"\"\n    Process network output to extract 3D points and confidence values.\n\n    Args:\n        out: Network output tensor (B, C, H, W)\n        activation: Activation type for 3D points\n        conf_activation: Activation type for confidence values\n\n    Returns:\n        Tuple of (3D points tensor, confidence tensor)\n    \"\"\"\n    # Move channels from last dim to the 4th dimension => (B, H, W, C)\n    fmap = out.permute(0, 2, 3, 1)  # B,H,W,C expected\n\n    # Split into xyz (first C-1 channels) and confidence (last channel)\n    xyz = fmap[:, :, :, :-1]\n    conf = fmap[:, :, :, -1]\n\n    if activation == \"norm_exp\":\n        d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)\n        xyz_normed = xyz / d\n        pts3d = xyz_normed * torch.expm1(d)\n    elif activation == \"norm\":\n        pts3d = xyz / xyz.norm(dim=-1, keepdim=True)\n    elif activation == \"exp\":\n        pts3d = torch.exp(xyz)\n    elif activation == \"relu\":\n        pts3d = F.relu(xyz)\n    elif activation == \"inv_log\":\n        pts3d = inverse_log_transform(xyz)\n    elif activation == \"xy_inv_log\":\n        xy, z = xyz.split([2, 1], dim=-1)\n        z = inverse_log_transform(z)\n        pts3d = torch.cat([xy * z, z], dim=-1)\n    elif activation == \"sigmoid\":\n        pts3d = torch.sigmoid(xyz)\n    elif activation == \"linear\":\n        pts3d = xyz\n    else:\n        raise ValueError(f\"Unknown activation: {activation}\")\n\n    if conf_activation == \"expp1\":\n        conf_out = 1 + conf.exp()\n    elif conf_activation == \"expp0\":\n        conf_out = conf.exp()\n    elif conf_activation == \"sigmoid\":\n     
   conf_out = torch.sigmoid(conf)\n    else:\n        raise ValueError(f\"Unknown conf_activation: {conf_activation}\")\n\n    return pts3d, conf_out\n\n\ndef inverse_log_transform(y):\n    \"\"\"\n    Apply inverse log transform: sign(y) * (exp(|y|) - 1)\n\n    Args:\n        y: Input tensor\n\n    Returns:\n        Transformed tensor\n    \"\"\"\n    return torch.sign(y) * (torch.expm1(torch.abs(y)))\n"
  },
  {
    "path": "mvtracker/models/core/vggt/heads/track_head.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\nimport torch.nn as nn\nfrom .dpt_head import DPTHead\nfrom .track_modules.base_track_predictor import BaseTrackerPredictor\n\n\nclass TrackHead(nn.Module):\n    \"\"\"\n    Track head that uses DPT head to process tokens and BaseTrackerPredictor for tracking.\n    The tracking is performed iteratively, refining predictions over multiple iterations.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim_in,\n        patch_size=14,\n        features=128,\n        iters=4,\n        predict_conf=True,\n        stride=2,\n        corr_levels=7,\n        corr_radius=4,\n        hidden_size=384,\n    ):\n        \"\"\"\n        Initialize the TrackHead module.\n\n        Args:\n            dim_in (int): Input dimension of tokens from the backbone.\n            patch_size (int): Size of image patches used in the vision transformer.\n            features (int): Number of feature channels in the feature extractor output.\n            iters (int): Number of refinement iterations for tracking predictions.\n            predict_conf (bool): Whether to predict confidence scores for tracked points.\n            stride (int): Stride value for the tracker predictor.\n            corr_levels (int): Number of correlation pyramid levels\n            corr_radius (int): Radius for correlation computation, controlling the search area.\n            hidden_size (int): Size of hidden layers in the tracker network.\n        \"\"\"\n        super().__init__()\n\n        self.patch_size = patch_size\n\n        # Feature extractor based on DPT architecture\n        # Processes tokens into feature maps for tracking\n        self.feature_extractor = DPTHead(\n            dim_in=dim_in,\n            patch_size=patch_size,\n            features=features,\n            feature_only=True,  # Only output features, no activation\n            down_ratio=2,  # Reduces spatial dimensions by factor of 2\n            pos_embed=False,\n        )\n\n        # Tracker module that predicts point trajectories\n        # Takes feature maps and predicts coordinates and visibility\n        self.tracker = BaseTrackerPredictor(\n            latent_dim=features,  # Match the output_dim of feature extractor\n            predict_conf=predict_conf,\n            stride=stride,\n            corr_levels=corr_levels,\n            corr_radius=corr_radius,\n            hidden_size=hidden_size,\n        )\n\n        self.iters = iters\n\n    def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None, iters=None):\n        \"\"\"\n        Forward pass of the TrackHead.\n\n        Args:\n            aggregated_tokens_list (list): List of aggregated tokens from the backbone.\n            images (torch.Tensor): Input images of shape (B, S, C, H, W) where:\n                                   B = batch size, S = sequence length.\n            patch_start_idx (int): Starting index for patch tokens.\n            query_points (torch.Tensor, optional): Initial query points to track.\n                                                  If None, points are initialized by the tracker.\n            iters (int, optional): Number of refinement iterations. 
If None, uses self.iters.\n\n        Returns:\n            tuple:\n                - coord_preds (torch.Tensor): Predicted coordinates for tracked points.\n                - vis_scores (torch.Tensor): Visibility scores for tracked points.\n                - conf_scores (torch.Tensor): Confidence scores for tracked points (if predict_conf=True).\n        \"\"\"\n        B, S, _, H, W = images.shape\n\n        # Extract features from tokens\n        # feature_maps has shape (B, S, C, H//2, W//2) due to down_ratio=2\n        feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx)\n\n        # Use default iterations if not specified\n        if iters is None:\n            iters = self.iters\n\n        # Perform tracking using the extracted features\n        coord_preds, vis_scores, conf_scores = self.tracker(\n            query_points=query_points,\n            fmaps=feature_maps,\n            iters=iters,\n        )\n\n        return coord_preds, vis_scores, conf_scores\n"
  },
  {
    "path": "mvtracker/models/core/vggt/heads/track_modules/__init__.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n"
  },
  {
    "path": "mvtracker/models/core/vggt/heads/track_modules/base_track_predictor.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\nimport torch\nimport torch.nn as nn\nfrom einops import rearrange, repeat\n\n\nfrom .blocks import EfficientUpdateFormer, CorrBlock\nfrom .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed\nfrom .modules import Mlp\n\n\nclass BaseTrackerPredictor(nn.Module):\n    def __init__(\n        self,\n        stride=1,\n        corr_levels=5,\n        corr_radius=4,\n        latent_dim=128,\n        hidden_size=384,\n        use_spaceatt=True,\n        depth=6,\n        max_scale=518,\n        predict_conf=True,\n    ):\n        super(BaseTrackerPredictor, self).__init__()\n        \"\"\"\n        The base template to create a track predictor\n        \n        Modified from https://github.com/facebookresearch/co-tracker/\n        and https://github.com/facebookresearch/vggsfm\n        \"\"\"\n\n        self.stride = stride\n        self.latent_dim = latent_dim\n        self.corr_levels = corr_levels\n        self.corr_radius = corr_radius\n        self.hidden_size = hidden_size\n        self.max_scale = max_scale\n        self.predict_conf = predict_conf\n\n        self.flows_emb_dim = latent_dim // 2\n\n        self.corr_mlp = Mlp(\n            in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2,\n            hidden_features=self.hidden_size,\n            out_features=self.latent_dim,\n        )\n\n        self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4\n\n        self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim))\n\n        space_depth = depth if use_spaceatt else 0\n        time_depth = depth\n\n        self.updateformer = EfficientUpdateFormer(\n            space_depth=space_depth,\n            time_depth=time_depth,\n            input_dim=self.transformer_dim,\n            hidden_size=self.hidden_size,\n            output_dim=self.latent_dim + 2,\n            mlp_ratio=4.0,\n            add_space_attn=use_spaceatt,\n        )\n\n        self.fmap_norm = nn.LayerNorm(self.latent_dim)\n        self.ffeat_norm = nn.GroupNorm(1, self.latent_dim)\n\n        # A linear layer to update track feats at each iteration\n        self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU())\n\n        self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))\n\n        if predict_conf:\n            self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))\n\n    def forward(self, query_points, fmaps=None, iters=6, return_feat=False, down_ratio=1, apply_sigmoid=True):\n        \"\"\"\n        query_points: B x N x 2, the number of batches, tracks, and xy\n        fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension.\n                note HH and WW is the size of feature maps instead of original images\n        \"\"\"\n        B, N, D = query_points.shape\n        B, S, C, HH, WW = fmaps.shape\n\n        assert D == 2, \"Input points must be 2D coordinates\"\n\n        # apply a layernorm to fmaps here\n        fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2))\n        fmaps = fmaps.permute(0, 1, 4, 2, 3)\n\n        # Scale the input query_points because we may downsample the images\n        # by down_ratio or self.stride\n        # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map\n        # its 
query_points should be query_points/4\n        if down_ratio > 1:\n            query_points = query_points / float(down_ratio)\n\n        query_points = query_points / float(self.stride)\n\n        # Init with coords as the query points\n        # It means the search will start from the position of query points at the reference frames\n        coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1)\n\n        # Sample/extract the features of the query points in the query frame\n        query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0])\n\n        # init track feats by query feats\n        track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1)  # B, S, N, C\n        # back up the init coords\n        coords_backup = coords.clone()\n\n        fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius)\n\n        coord_preds = []\n\n        # Iterative Refinement\n        for _ in range(iters):\n            # Detach the gradients from the last iteration\n            # (in my experience, not very important for performance)\n            coords = coords.detach()\n\n            fcorrs = fcorr_fn.corr_sample(track_feats, coords)\n\n            corr_dim = fcorrs.shape[3]\n            fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim)\n            fcorrs_ = self.corr_mlp(fcorrs_)\n\n            # Movement of current coords relative to query points\n            flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)\n\n            flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False)\n\n            # (In my trials, it is also okay to just add the flows_emb instead of concat)\n            flows_emb = torch.cat([flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1)\n\n            track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)\n\n            # Concatenate them as the input for the transformers\n            transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2)\n\n            # 2D positional embed\n            # TODO: this can be much simplified\n            pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device)\n            sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0])\n\n            sampled_pos_emb = rearrange(sampled_pos_emb, \"b n c -> (b n) c\").unsqueeze(1)\n\n            x = transformer_input + sampled_pos_emb\n\n            # Add the query ref token to the track feats\n            query_ref_token = torch.cat(\n                [self.query_ref_token[:, 0:1], self.query_ref_token[:, 1:2].expand(-1, S - 1, -1)], dim=1\n            )\n            x = x + query_ref_token.to(x.device).to(x.dtype)\n\n            # B, N, S, C\n            x = rearrange(x, \"(b n) s d -> b n s d\", b=B)\n\n            # Compute the delta coordinates and delta track features\n            delta, _ = self.updateformer(x)\n\n            # BN, S, C\n            delta = rearrange(delta, \" b n s d -> (b n) s d\", b=B)\n            delta_coords_ = delta[:, :, :2]\n            delta_feats_ = delta[:, :, 2:]\n\n            track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim)\n            delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)\n\n            # Update the track features\n            track_feats_ = self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_\n\n            track_feats = track_feats_.reshape(B, N, S, 
self.latent_dim).permute(0, 2, 1, 3)  # BxSxNxC\n\n            # B x S x N x 2\n            coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)\n\n            # Force coord0 as query\n            # because we assume the query points should not be changed\n            coords[:, 0] = coords_backup[:, 0]\n\n            # The predicted tracks are in the original image scale\n            if down_ratio > 1:\n                coord_preds.append(coords * self.stride * down_ratio)\n            else:\n                coord_preds.append(coords * self.stride)\n\n        # B, S, N\n        vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)\n        if apply_sigmoid:\n            vis_e = torch.sigmoid(vis_e)\n\n        if self.predict_conf:\n            conf_e = self.conf_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)\n            if apply_sigmoid:\n                conf_e = torch.sigmoid(conf_e)\n        else:\n            conf_e = None\n\n        if return_feat:\n            return coord_preds, vis_e, track_feats, query_track_feat, conf_e\n        else:\n            return coord_preds, vis_e, conf_e\n"
  },
  {
    "path": "mvtracker/models/core/vggt/heads/track_modules/blocks.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\n\n# Modified from https://github.com/facebookresearch/co-tracker/\n\nimport math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .utils import bilinear_sampler\nfrom .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock\n\n\nclass EfficientUpdateFormer(nn.Module):\n    \"\"\"\n    Transformer model that updates track estimates.\n    \"\"\"\n\n    def __init__(\n        self,\n        space_depth=6,\n        time_depth=6,\n        input_dim=320,\n        hidden_size=384,\n        num_heads=8,\n        output_dim=130,\n        mlp_ratio=4.0,\n        add_space_attn=True,\n        num_virtual_tracks=64,\n    ):\n        super().__init__()\n\n        self.out_channels = 2\n        self.num_heads = num_heads\n        self.hidden_size = hidden_size\n        self.add_space_attn = add_space_attn\n\n        # Add input LayerNorm before linear projection\n        self.input_norm = nn.LayerNorm(input_dim)\n        self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)\n\n        # Add output LayerNorm before final projection\n        self.output_norm = nn.LayerNorm(hidden_size)\n        self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)\n        self.num_virtual_tracks = num_virtual_tracks\n\n        if self.add_space_attn:\n            self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size))\n        else:\n            self.virual_tracks = None\n\n        self.time_blocks = nn.ModuleList(\n            [\n                AttnBlock(\n                    hidden_size,\n                    num_heads,\n                    mlp_ratio=mlp_ratio,\n                    attn_class=nn.MultiheadAttention,\n                )\n                for _ in range(time_depth)\n            ]\n        )\n\n        if add_space_attn:\n            self.space_virtual_blocks = nn.ModuleList(\n                [\n                    AttnBlock(\n                        hidden_size,\n                        num_heads,\n                        mlp_ratio=mlp_ratio,\n                        attn_class=nn.MultiheadAttention,\n                    )\n                    for _ in range(space_depth)\n                ]\n            )\n            self.space_point2virtual_blocks = nn.ModuleList(\n                [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]\n            )\n            self.space_virtual2point_blocks = nn.ModuleList(\n                [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]\n            )\n            assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)\n        self.initialize_weights()\n\n    def initialize_weights(self):\n        def _basic_init(module):\n            if isinstance(module, nn.Linear):\n                torch.nn.init.xavier_uniform_(module.weight)\n                if module.bias is not None:\n                    nn.init.constant_(module.bias, 0)\n            torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)\n\n        self.apply(_basic_init)\n\n    def forward(self, input_tensor, mask=None):\n        # Apply input LayerNorm\n        input_tensor = self.input_norm(input_tensor)\n        tokens = self.input_transform(input_tensor)\n\n        
init_tokens = tokens\n\n        B, _, T, _ = tokens.shape\n\n        if self.add_space_attn:\n            virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)\n            tokens = torch.cat([tokens, virtual_tokens], dim=1)\n\n        _, N, _, _ = tokens.shape\n\n        j = 0\n        for i in range(len(self.time_blocks)):\n            time_tokens = tokens.contiguous().view(B * N, T, -1)  # B N T C -> (B N) T C\n\n            time_tokens = self.time_blocks[i](time_tokens)\n\n            tokens = time_tokens.view(B, N, T, -1)  # (B N) T C -> B N T C\n            if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0):\n                space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)  # B N T C -> (B T) N C\n                point_tokens = space_tokens[:, : N - self.num_virtual_tracks]\n                virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]\n\n                virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask)\n                virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)\n                point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask)\n\n                space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)\n                tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3)  # (B T) N C -> B N T C\n                j += 1\n\n        if self.add_space_attn:\n            tokens = tokens[:, : N - self.num_virtual_tracks]\n\n        tokens = tokens + init_tokens\n\n        # Apply output LayerNorm before final projection\n        tokens = self.output_norm(tokens)\n        flow = self.flow_head(tokens)\n\n        return flow, None\n\n\nclass CorrBlock:\n    def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode=\"zeros\"):\n        \"\"\"\n        Build a pyramid of feature maps from the input.\n\n        fmaps: Tensor (B, S, C, H, W)\n        num_levels: number of pyramid levels (each downsampled by factor 2)\n        radius: search radius for sampling correlation\n        multiple_track_feats: if True, split the target features per pyramid level\n        padding_mode: passed to grid_sample / bilinear_sampler\n        \"\"\"\n        B, S, C, H, W = fmaps.shape\n        self.S, self.C, self.H, self.W = S, C, H, W\n        self.num_levels = num_levels\n        self.radius = radius\n        self.padding_mode = padding_mode\n        self.multiple_track_feats = multiple_track_feats\n\n        # Build pyramid: each level is half the spatial resolution of the previous\n        self.fmaps_pyramid = [fmaps]  # level 0 is full resolution\n        current_fmaps = fmaps\n        for i in range(num_levels - 1):\n            B, S, C, H, W = current_fmaps.shape\n            # Merge batch & sequence dimensions\n            current_fmaps = current_fmaps.reshape(B * S, C, H, W)\n            # Avg pool down by factor 2\n            current_fmaps = F.avg_pool2d(current_fmaps, kernel_size=2, stride=2)\n            _, _, H_new, W_new = current_fmaps.shape\n            current_fmaps = current_fmaps.reshape(B, S, C, H_new, W_new)\n            self.fmaps_pyramid.append(current_fmaps)\n\n        # Precompute a delta grid (of shape (2r+1, 2r+1, 2)) for sampling.\n        # This grid is added to the (scaled) coordinate centroids.\n        r = self.radius\n        dx = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)\n        dy = torch.linspace(-r, r, 2 * r 
+ 1, device=fmaps.device, dtype=fmaps.dtype)\n        # delta: for every (dy,dx) displacement (i.e. Δx, Δy)\n        self.delta = torch.stack(torch.meshgrid(dy, dx, indexing=\"ij\"), dim=-1)  # shape: (2r+1, 2r+1, 2)\n\n    def corr_sample(self, targets, coords):\n        \"\"\"\n        Instead of storing the entire correlation pyramid, we compute each level's correlation\n        volume, sample it immediately, then discard it. This saves GPU memory.\n\n        Args:\n          targets: Tensor (B, S, N, C) — features for the current targets.\n          coords: Tensor (B, S, N, 2) — coordinates at full resolution.\n\n        Returns:\n          Tensor (B, S, N, L) where L = num_levels * (2*radius+1)**2 (concatenated sampled correlations)\n        \"\"\"\n        B, S, N, C = targets.shape\n\n        # If you have multiple track features, split them per level.\n        if self.multiple_track_feats:\n            targets_split = torch.split(targets, C // self.num_levels, dim=-1)\n\n        out_pyramid = []\n        for i, fmaps in enumerate(self.fmaps_pyramid):\n            # Get current spatial resolution H, W for this pyramid level.\n            B, S, C, H, W = fmaps.shape\n            # Reshape feature maps for correlation computation:\n            # fmap2s: (B, S, C, H*W)\n            fmap2s = fmaps.view(B, S, C, H * W)\n            # Choose appropriate target features.\n            fmap1 = targets_split[i] if self.multiple_track_feats else targets  # shape: (B, S, N, C)\n\n            # Compute correlation directly\n            corrs = compute_corr_level(fmap1, fmap2s, C)\n            corrs = corrs.view(B, S, N, H, W)\n\n            # Prepare sampling grid:\n            # Scale down the coordinates for the current level.\n            centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / (2**i)\n            # Make sure our precomputed delta grid is on the same device/dtype.\n            delta_lvl = self.delta.to(coords.device).to(coords.dtype)\n            # Now the grid for grid_sample is:\n            # coords_lvl = centroid_lvl + delta_lvl   (broadcasted over grid)\n            coords_lvl = centroid_lvl + delta_lvl.view(1, 2 * self.radius + 1, 2 * self.radius + 1, 2)\n\n            # Sample from the correlation volume using bilinear interpolation.\n            # We reshape corrs to (B * S * N, 1, H, W) so grid_sample acts over each target.\n            corrs_sampled = bilinear_sampler(\n                corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode\n            )\n            # The sampled output is (B * S * N, 1, 2r+1, 2r+1). Flatten the last two dims.\n            corrs_sampled = corrs_sampled.view(B, S, N, -1)  # Now shape: (B, S, N, (2r+1)^2)\n            out_pyramid.append(corrs_sampled)\n\n        # Concatenate all levels along the last dimension.\n        out = torch.cat(out_pyramid, dim=-1).contiguous()\n        return out\n\n\ndef compute_corr_level(fmap1, fmap2s, C):\n    # fmap1: (B, S, N, C)\n    # fmap2s: (B, S, C, H*W)\n    corrs = torch.matmul(fmap1, fmap2s)  # (B, S, N, H*W)\n    corrs = corrs.view(fmap1.shape[0], fmap1.shape[1], fmap1.shape[2], -1)  # (B, S, N, H*W)\n    return corrs / math.sqrt(C)\n"
  },
  {
    "path": "mvtracker/models/core/vggt/heads/track_modules/modules.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom functools import partial\nfrom typing import Callable\nimport collections\nfrom torch import Tensor\nfrom itertools import repeat\n\n\n# From PyTorch internals\ndef _ntuple(n):\n    def parse(x):\n        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):\n            return tuple(x)\n        return tuple(repeat(x, n))\n\n    return parse\n\n\ndef exists(val):\n    return val is not None\n\n\ndef default(val, d):\n    return val if exists(val) else d\n\n\nto_2tuple = _ntuple(2)\n\n\nclass ResidualBlock(nn.Module):\n    \"\"\"\n    ResidualBlock: construct a block of two conv layers with residual connections\n    \"\"\"\n\n    def __init__(self, in_planes, planes, norm_fn=\"group\", stride=1, kernel_size=3):\n        super(ResidualBlock, self).__init__()\n\n        self.conv1 = nn.Conv2d(\n            in_planes,\n            planes,\n            kernel_size=kernel_size,\n            padding=1,\n            stride=stride,\n            padding_mode=\"zeros\",\n        )\n        self.conv2 = nn.Conv2d(\n            planes,\n            planes,\n            kernel_size=kernel_size,\n            padding=1,\n            padding_mode=\"zeros\",\n        )\n        self.relu = nn.ReLU(inplace=True)\n\n        num_groups = planes // 8\n\n        if norm_fn == \"group\":\n            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)\n            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)\n            if not stride == 1:\n                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)\n\n        elif norm_fn == \"batch\":\n            self.norm1 = nn.BatchNorm2d(planes)\n            self.norm2 = nn.BatchNorm2d(planes)\n            if not stride == 1:\n                self.norm3 = nn.BatchNorm2d(planes)\n\n        elif norm_fn == \"instance\":\n            self.norm1 = nn.InstanceNorm2d(planes)\n            self.norm2 = nn.InstanceNorm2d(planes)\n            if not stride == 1:\n                self.norm3 = nn.InstanceNorm2d(planes)\n\n        elif norm_fn == \"none\":\n            self.norm1 = nn.Sequential()\n            self.norm2 = nn.Sequential()\n            if not stride == 1:\n                self.norm3 = nn.Sequential()\n        else:\n            raise NotImplementedError\n\n        if stride == 1:\n            self.downsample = None\n        else:\n            self.downsample = nn.Sequential(\n                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride),\n                self.norm3,\n            )\n\n    def forward(self, x):\n        y = x\n        y = self.relu(self.norm1(self.conv1(y)))\n        y = self.relu(self.norm2(self.conv2(y)))\n\n        if self.downsample is not None:\n            x = self.downsample(x)\n\n        return self.relu(x + y)\n\n\nclass Mlp(nn.Module):\n    \"\"\"MLP as used in Vision Transformer, MLP-Mixer and related networks\"\"\"\n\n    def __init__(\n        self,\n        in_features,\n        hidden_features=None,\n        out_features=None,\n        act_layer=nn.GELU,\n        norm_layer=None,\n        bias=True,\n        drop=0.0,\n        use_conv=False,\n    ):\n        super().__init__()\n        out_features = out_features or in_features\n        
hidden_features = hidden_features or in_features\n        bias = to_2tuple(bias)\n        drop_probs = to_2tuple(drop)\n        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear\n\n        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])\n        self.act = act_layer()\n        self.drop1 = nn.Dropout(drop_probs[0])\n        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])\n        self.drop2 = nn.Dropout(drop_probs[1])\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop1(x)\n        x = self.fc2(x)\n        x = self.drop2(x)\n        return x\n\n\nclass AttnBlock(nn.Module):\n    def __init__(\n        self,\n        hidden_size,\n        num_heads,\n        attn_class: Callable[..., nn.Module] = nn.MultiheadAttention,\n        mlp_ratio=4.0,\n        **block_kwargs\n    ):\n        \"\"\"\n        Self attention block\n        \"\"\"\n        super().__init__()\n\n        self.norm1 = nn.LayerNorm(hidden_size)\n        self.norm2 = nn.LayerNorm(hidden_size)\n\n        self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs)\n\n        mlp_hidden_dim = int(hidden_size * mlp_ratio)\n\n        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)\n\n    def forward(self, x, mask=None):\n        # Prepare the mask for PyTorch's attention (it expects a different format)\n        # attn_mask = mask if mask is not None else None\n        # Normalize before attention\n        x = self.norm1(x)\n\n        # PyTorch's MultiheadAttention returns attn_output, attn_output_weights\n        # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask)\n\n        attn_output, _ = self.attn(x, x, x)\n\n        # Add & Norm\n        x = x + attn_output\n        x = x + self.mlp(self.norm2(x))\n        return x\n\n\nclass CrossAttnBlock(nn.Module):\n    def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):\n        \"\"\"\n        Cross attention block\n        \"\"\"\n        super().__init__()\n\n        self.norm1 = nn.LayerNorm(hidden_size)\n        self.norm_context = nn.LayerNorm(hidden_size)\n        self.norm2 = nn.LayerNorm(hidden_size)\n\n        self.cross_attn = nn.MultiheadAttention(\n            embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs\n        )\n\n        mlp_hidden_dim = int(hidden_size * mlp_ratio)\n\n        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)\n\n    def forward(self, x, context, mask=None):\n        # Normalize inputs\n        x = self.norm1(x)\n        context = self.norm_context(context)\n\n        # Apply cross attention\n        # Note: nn.MultiheadAttention returns attn_output, attn_output_weights\n        attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask)\n\n        # Add & Norm\n        x = x + attn_output\n        x = x + self.mlp(self.norm2(x))\n        return x\n"
  },
  {
    "path": "mvtracker/models/core/vggt/heads/track_modules/utils.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\n# Modified from https://github.com/facebookresearch/vggsfm\n# and https://github.com/facebookresearch/co-tracker/tree/main\n\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom typing import Optional, Tuple, Union\n\n\ndef get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor:\n    \"\"\"\n    This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.\n    It is a wrapper of get_2d_sincos_pos_embed_from_grid.\n    Args:\n    - embed_dim: The embedding dimension.\n    - grid_size: The grid size.\n    Returns:\n    - pos_embed: The generated 2D positional embedding.\n    \"\"\"\n    if isinstance(grid_size, tuple):\n        grid_size_h, grid_size_w = grid_size\n    else:\n        grid_size_h = grid_size_w = grid_size\n    grid_h = torch.arange(grid_size_h, dtype=torch.float)\n    grid_w = torch.arange(grid_size_w, dtype=torch.float)\n    grid = torch.meshgrid(grid_w, grid_h, indexing=\"xy\")\n    grid = torch.stack(grid, dim=0)\n    grid = grid.reshape([2, 1, grid_size_h, grid_size_w])\n    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)\n    if return_grid:\n        return (\n            pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2),\n            grid,\n        )\n    return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)\n\n\ndef get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    This function generates a 2D positional embedding from a given grid using sine and cosine functions.\n\n    Args:\n    - embed_dim: The embedding dimension.\n    - grid: The grid to generate the embedding from.\n\n    Returns:\n    - emb: The generated 2D positional embedding.\n    \"\"\"\n    assert embed_dim % 2 == 0\n\n    # use half of dimensions to encode grid_h\n    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)\n    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)\n\n    emb = torch.cat([emb_h, emb_w], dim=2)  # (H*W, D)\n    return emb\n\n\ndef get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    This function generates a 1D positional embedding from a given grid using sine and cosine functions.\n\n    Args:\n    - embed_dim: The embedding dimension.\n    - pos: The position to generate the embedding from.\n\n    Returns:\n    - emb: The generated 1D positional embedding.\n    \"\"\"\n    assert embed_dim % 2 == 0\n    omega = torch.arange(embed_dim // 2, dtype=torch.double)\n    omega /= embed_dim / 2.0\n    omega = 1.0 / 10000**omega  # (D/2,)\n\n    pos = pos.reshape(-1)  # (M,)\n    out = torch.einsum(\"m,d->md\", pos, omega)  # (M, D/2), outer product\n\n    emb_sin = torch.sin(out)  # (M, D/2)\n    emb_cos = torch.cos(out)  # (M, D/2)\n\n    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)\n    return emb[None].float()\n\n\ndef get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:\n    \"\"\"\n    This function generates a 2D positional embedding from given coordinates using sine and cosine functions.\n\n    Args:\n    - xy: The coordinates to generate the embedding from.\n    - C: The 
size of the embedding.\n    - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.\n\n    Returns:\n    - pe: The generated 2D positional embedding.\n    \"\"\"\n    B, N, D = xy.shape\n    assert D == 2\n\n    x = xy[:, :, 0:1]\n    y = xy[:, :, 1:2]\n    div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2))\n\n    pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)\n    pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)\n\n    pe_x[:, :, 0::2] = torch.sin(x * div_term)\n    pe_x[:, :, 1::2] = torch.cos(x * div_term)\n\n    pe_y[:, :, 0::2] = torch.sin(y * div_term)\n    pe_y[:, :, 1::2] = torch.cos(y * div_term)\n\n    pe = torch.cat([pe_x, pe_y], dim=2)  # (B, N, C*3)\n    if cat_coords:\n        pe = torch.cat([xy, pe], dim=2)  # (B, N, C*3+3)\n    return pe\n\n\ndef bilinear_sampler(input, coords, align_corners=True, padding_mode=\"border\"):\n    r\"\"\"Sample a tensor using bilinear interpolation\n\n    `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at\n    coordinates :attr:`coords` using bilinear interpolation. It is the same\n    as `torch.nn.functional.grid_sample()` but with a different coordinate\n    convention.\n\n    The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where\n    :math:`B` is the batch size, :math:`C` is the number of channels,\n    :math:`H` is the height of the image, and :math:`W` is the width of the\n    image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is\n    interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.\n\n    Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,\n    in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note\n    that in this case the order of the components is slightly different\n    from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.\n\n    If `align_corners` is `True`, the coordinate :math:`x` is assumed to be\n    in the range :math:`[0,W-1]`, with 0 corresponding to the center of the\n    left-most image pixel :math:`W-1` to the center of the right-most\n    pixel.\n\n    If `align_corners` is `False`, the coordinate :math:`x` is assumed to\n    be in the range :math:`[0,W]`, with 0 corresponding to the left edge of\n    the left-most pixel :math:`W` to the right edge of the right-most\n    pixel.\n\n    Similar conventions apply to the :math:`y` for the range\n    :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range\n    :math:`[0,T-1]` and :math:`[0,T]`.\n\n    Args:\n        input (Tensor): batch of input images.\n        coords (Tensor): batch of coordinates.\n        align_corners (bool, optional): Coordinate convention. Defaults to `True`.\n        padding_mode (str, optional): Padding mode. 
Defaults to `\"border\"`.\n\n    Returns:\n        Tensor: sampled points.\n    \"\"\"\n    coords = coords.detach().clone()\n    ############################################################\n    # IMPORTANT:\n    coords = coords.to(input.device).to(input.dtype)\n    ############################################################\n\n    sizes = input.shape[2:]\n\n    assert len(sizes) in [2, 3]\n\n    if len(sizes) == 3:\n        # t x y -> x y t to match dimensions T H W in grid_sample\n        coords = coords[..., [1, 2, 0]]\n\n    if align_corners:\n        scale = torch.tensor(\n            [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device, dtype=coords.dtype\n        )\n    else:\n        scale = torch.tensor([2 / size for size in reversed(sizes)], device=coords.device, dtype=coords.dtype)\n\n    coords.mul_(scale)  # coords = coords * scale\n    coords.sub_(1)  # coords = coords - 1\n\n    return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)\n\n\ndef sample_features4d(input, coords):\n    r\"\"\"Sample spatial features\n\n    `sample_features4d(input, coords)` samples the spatial features\n    :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.\n\n    The field is sampled at coordinates :attr:`coords` using bilinear\n    interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,\n    2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the\n    same convention as :func:`bilinear_sampler` with `align_corners=True`.\n\n    The output tensor has one feature per point, and has shape :math:`(B,\n    R, C)`.\n\n    Args:\n        input (Tensor): spatial features.\n        coords (Tensor): points.\n\n    Returns:\n        Tensor: sampled features.\n    \"\"\"\n\n    B, _, _, _ = input.shape\n\n    # B R 2 -> B R 1 2\n    coords = coords.unsqueeze(2)\n\n    # B C R 1\n    feats = bilinear_sampler(input, coords)\n\n    return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3])  # B C R 1 -> B R C\n"
  },
  {
    "path": "mvtracker/models/core/vggt/heads/utils.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\nimport torch\nimport torch.nn as nn\n\n\ndef position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:\n    \"\"\"\n    Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC)\n\n    Args:\n        pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates\n        embed_dim: Output channel dimension for embeddings\n\n    Returns:\n        Tensor of shape (H, W, embed_dim) with positional embeddings\n    \"\"\"\n    H, W, grid_dim = pos_grid.shape\n    assert grid_dim == 2\n    pos_flat = pos_grid.reshape(-1, grid_dim)  # Flatten to (H*W, 2)\n\n    # Process x and y coordinates separately\n    emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0)  # [1, H*W, D/2]\n    emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0)  # [1, H*W, D/2]\n\n    # Combine and reshape\n    emb = torch.cat([emb_x, emb_y], dim=-1)  # [1, H*W, D]\n\n    return emb.view(H, W, embed_dim)  # [H, W, D]\n\n\ndef make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:\n    \"\"\"\n    This function generates a 1D positional embedding from a given grid using sine and cosine functions.\n\n    Args:\n    - embed_dim: The embedding dimension.\n    - pos: The position to generate the embedding from.\n\n    Returns:\n    - emb: The generated 1D positional embedding.\n    \"\"\"\n    assert embed_dim % 2 == 0\n    omega = torch.arange(embed_dim // 2, dtype=torch.double, device=pos.device)\n    omega /= embed_dim / 2.0\n    omega = 1.0 / omega_0**omega  # (D/2,)\n\n    pos = pos.reshape(-1)  # (M,)\n    out = torch.einsum(\"m,d->md\", pos, omega)  # (M, D/2), outer product\n\n    emb_sin = torch.sin(out)  # (M, D/2)\n    emb_cos = torch.cos(out)  # (M, D/2)\n\n    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)\n    return emb.float()\n\n\n# Inspired by https://github.com/microsoft/moge\n\n\ndef create_uv_grid(\n    width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None\n) -> torch.Tensor:\n    \"\"\"\n    Create a normalized UV grid of shape (width, height, 2).\n\n    The grid spans horizontally and vertically according to an aspect ratio,\n    ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right\n    corner is at (x_span, y_span), normalized by the diagonal of the plane.\n\n    Args:\n        width (int): Number of points horizontally.\n        height (int): Number of points vertically.\n        aspect_ratio (float, optional): Width-to-height ratio. 
Defaults to width/height.\n        dtype (torch.dtype, optional): Data type of the resulting tensor.\n        device (torch.device, optional): Device on which the tensor is created.\n\n    Returns:\n        torch.Tensor: A (width, height, 2) tensor of UV coordinates.\n    \"\"\"\n    # Derive aspect ratio if not explicitly provided\n    if aspect_ratio is None:\n        aspect_ratio = float(width) / float(height)\n\n    # Compute normalized spans for X and Y\n    diag_factor = (aspect_ratio**2 + 1.0) ** 0.5\n    span_x = aspect_ratio / diag_factor\n    span_y = 1.0 / diag_factor\n\n    # Establish the linspace boundaries\n    left_x = -span_x * (width - 1) / width\n    right_x = span_x * (width - 1) / width\n    top_y = -span_y * (height - 1) / height\n    bottom_y = span_y * (height - 1) / height\n\n    # Generate 1D coordinates\n    x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)\n    y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)\n\n    # Create 2D meshgrid (width x height) and stack into UV\n    uu, vv = torch.meshgrid(x_coords, y_coords, indexing=\"xy\")\n    uv_grid = torch.stack((uu, vv), dim=-1)\n\n    return uv_grid\n"
  },
  {
    "path": "mvtracker/models/core/vggt/layers/__init__.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\nfrom .mlp import Mlp\nfrom .patch_embed import PatchEmbed\nfrom .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused\nfrom .block import NestedTensorBlock\nfrom .attention import MemEffAttention\n"
  },
  {
    "path": "mvtracker/models/core/vggt/layers/attention.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version 2.0\n# found in the LICENSE file in the root directory of this source tree.\n\n# References:\n#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py\n#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py\n\nimport logging\nimport os\nimport warnings\n\nfrom torch import Tensor\nfrom torch import nn\nimport torch.nn.functional as F\n\nXFORMERS_AVAILABLE = False\n\n\nclass Attention(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        num_heads: int = 8,\n        qkv_bias: bool = True,\n        proj_bias: bool = True,\n        attn_drop: float = 0.0,\n        proj_drop: float = 0.0,\n        norm_layer: nn.Module = nn.LayerNorm,\n        qk_norm: bool = False,\n        fused_attn: bool = True,  # use F.scaled_dot_product_attention or not\n        rope=None,\n    ) -> None:\n        super().__init__()\n        assert dim % num_heads == 0, \"dim should be divisible by num_heads\"\n        self.num_heads = num_heads\n        self.head_dim = dim // num_heads\n        self.scale = self.head_dim**-0.5\n        self.fused_attn = fused_attn\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()\n        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim, bias=proj_bias)\n        self.proj_drop = nn.Dropout(proj_drop)\n        self.rope = rope\n\n    def forward(self, x: Tensor, pos=None) -> Tensor:\n        B, N, C = x.shape\n        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv.unbind(0)\n        q, k = self.q_norm(q), self.k_norm(k)\n\n        if self.rope is not None:\n            q = self.rope(q, pos)\n            k = self.rope(k, pos)\n\n        if self.fused_attn:\n            x = F.scaled_dot_product_attention(\n                q,\n                k,\n                v,\n                dropout_p=self.attn_drop.p if self.training else 0.0,\n            )\n        else:\n            q = q * self.scale\n            attn = q @ k.transpose(-2, -1)\n            attn = attn.softmax(dim=-1)\n            attn = self.attn_drop(attn)\n            x = attn @ v\n\n        x = x.transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass MemEffAttention(Attention):\n    def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor:\n        assert pos is None\n        if not XFORMERS_AVAILABLE:\n            if attn_bias is not None:\n                raise AssertionError(\"xFormers is required for using nested tensors\")\n            return super().forward(x)\n\n        B, N, C = x.shape\n        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)\n\n        q, k, v = unbind(qkv, 2)\n\n        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)\n        x = x.reshape([B, N, C])\n\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n"
  },
  {
    "path": "mvtracker/models/core/vggt/layers/block.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version 2.0\n# found in the LICENSE file in the root directory of this source tree.\n\n# References:\n#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py\n#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py\n\nimport logging\nimport os\nfrom typing import Callable, List, Any, Tuple, Dict\nimport warnings\n\nimport torch\nfrom torch import nn, Tensor\n\nfrom .attention import Attention\nfrom .drop_path import DropPath\nfrom .layer_scale import LayerScale\nfrom .mlp import Mlp\n\n\nXFORMERS_AVAILABLE = False\n\n\nclass Block(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        num_heads: int,\n        mlp_ratio: float = 4.0,\n        qkv_bias: bool = True,\n        proj_bias: bool = True,\n        ffn_bias: bool = True,\n        drop: float = 0.0,\n        attn_drop: float = 0.0,\n        init_values=None,\n        drop_path: float = 0.0,\n        act_layer: Callable[..., nn.Module] = nn.GELU,\n        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,\n        attn_class: Callable[..., nn.Module] = Attention,\n        ffn_layer: Callable[..., nn.Module] = Mlp,\n        qk_norm: bool = False,\n        fused_attn: bool = True,  # use F.scaled_dot_product_attention or not\n        rope=None,\n    ) -> None:\n        super().__init__()\n\n        self.norm1 = norm_layer(dim)\n\n        self.attn = attn_class(\n            dim,\n            num_heads=num_heads,\n            qkv_bias=qkv_bias,\n            proj_bias=proj_bias,\n            attn_drop=attn_drop,\n            proj_drop=drop,\n            qk_norm=qk_norm,\n            fused_attn=fused_attn,\n            rope=rope,\n        )\n\n        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()\n        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = ffn_layer(\n            in_features=dim,\n            hidden_features=mlp_hidden_dim,\n            act_layer=act_layer,\n            drop=drop,\n            bias=ffn_bias,\n        )\n        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()\n        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n\n        self.sample_drop_ratio = drop_path\n\n    def forward(self, x: Tensor, pos=None) -> Tensor:\n        def attn_residual_func(x: Tensor, pos=None) -> Tensor:\n            return self.ls1(self.attn(self.norm1(x), pos=pos))\n\n        def ffn_residual_func(x: Tensor) -> Tensor:\n            return self.ls2(self.mlp(self.norm2(x)))\n\n        if self.training and self.sample_drop_ratio > 0.1:\n            # the overhead is compensated only for a drop path rate larger than 0.1\n            x = drop_add_residual_stochastic_depth(\n                x,\n                pos=pos,\n                residual_func=attn_residual_func,\n                sample_drop_ratio=self.sample_drop_ratio,\n            )\n            x = drop_add_residual_stochastic_depth(\n                x,\n                residual_func=ffn_residual_func,\n                sample_drop_ratio=self.sample_drop_ratio,\n            )\n        elif self.training and self.sample_drop_ratio > 0.0:\n            x = x + self.drop_path1(attn_residual_func(x, pos=pos))\n            x = x + 
self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2\n        else:\n            x = x + attn_residual_func(x, pos=pos)\n            x = x + ffn_residual_func(x)\n        return x\n\n\ndef drop_add_residual_stochastic_depth(\n    x: Tensor,\n    residual_func: Callable[[Tensor], Tensor],\n    sample_drop_ratio: float = 0.0,\n    pos=None,\n) -> Tensor:\n    # 1) extract subset using permutation\n    b, n, d = x.shape\n    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)\n    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]\n    x_subset = x[brange]\n\n    # 2) apply residual_func to get residual\n    if pos is not None:\n        # if necessary, apply rope to the subset\n        pos = pos[brange]\n        residual = residual_func(x_subset, pos=pos)\n    else:\n        residual = residual_func(x_subset)\n\n    x_flat = x.flatten(1)\n    residual = residual.flatten(1)\n\n    residual_scale_factor = b / sample_subset_size\n\n    # 3) add the residual\n    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)\n    return x_plus_residual.view_as(x)\n\n\ndef get_branges_scales(x, sample_drop_ratio=0.0):\n    b, n, d = x.shape\n    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)\n    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]\n    residual_scale_factor = b / sample_subset_size\n    return brange, residual_scale_factor\n\n\ndef add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):\n    if scaling_vector is None:\n        x_flat = x.flatten(1)\n        residual = residual.flatten(1)\n        x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)\n    else:\n        x_plus_residual = scaled_index_add(\n            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor\n        )\n    return x_plus_residual\n\n\nattn_bias_cache: Dict[Tuple, Any] = {}\n\n\ndef get_attn_bias_and_cat(x_list, branges=None):\n    \"\"\"\n    this will perform the index select, cat the tensors, and provide the attn_bias from cache\n    \"\"\"\n    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]\n    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))\n    if all_shapes not in attn_bias_cache.keys():\n        seqlens = []\n        for b, x in zip(batch_sizes, x_list):\n            for _ in range(b):\n                seqlens.append(x.shape[1])\n        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)\n        attn_bias._batch_sizes = batch_sizes\n        attn_bias_cache[all_shapes] = attn_bias\n\n    if branges is not None:\n        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])\n    else:\n        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)\n        cat_tensors = torch.cat(tensors_bs1, dim=1)\n\n    return attn_bias_cache[all_shapes], cat_tensors\n\n\ndef drop_add_residual_stochastic_depth_list(\n    x_list: List[Tensor],\n    residual_func: Callable[[Tensor, Any], Tensor],\n    sample_drop_ratio: float = 0.0,\n    scaling_vector=None,\n) -> Tensor:\n    # 1) generate random set of indices for dropping samples in the batch\n    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]\n    branges = [s[0] for s in branges_scales]\n    residual_scale_factors = [s[1] for s in 
branges_scales]\n\n    # 2) get attention bias and index+concat the tensors\n    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)\n\n    # 3) apply residual_func to get residual, and split the result\n    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore\n\n    outputs = []\n    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):\n        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))\n    return outputs\n\n\nclass NestedTensorBlock(Block):\n    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:\n        \"\"\"\n        x_list contains a list of tensors to nest together and run\n        \"\"\"\n        assert isinstance(self.attn, MemEffAttention)\n\n        if self.training and self.sample_drop_ratio > 0.0:\n\n            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:\n                return self.attn(self.norm1(x), attn_bias=attn_bias)\n\n            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:\n                return self.mlp(self.norm2(x))\n\n            x_list = drop_add_residual_stochastic_depth_list(\n                x_list,\n                residual_func=attn_residual_func,\n                sample_drop_ratio=self.sample_drop_ratio,\n                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,\n            )\n            x_list = drop_add_residual_stochastic_depth_list(\n                x_list,\n                residual_func=ffn_residual_func,\n                sample_drop_ratio=self.sample_drop_ratio,\n                scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,\n            )\n            return x_list\n        else:\n\n            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:\n                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))\n\n            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:\n                return self.ls2(self.mlp(self.norm2(x)))\n\n            attn_bias, x = get_attn_bias_and_cat(x_list)\n            x = x + attn_residual_func(x, attn_bias=attn_bias)\n            x = x + ffn_residual_func(x)\n            return attn_bias.split(x)\n\n    def forward(self, x_or_x_list):\n        if isinstance(x_or_x_list, Tensor):\n            return super().forward(x_or_x_list)\n        elif isinstance(x_or_x_list, list):\n            if not XFORMERS_AVAILABLE:\n                raise AssertionError(\"xFormers is required for using nested tensors\")\n            return self.forward_nested(x_or_x_list)\n        else:\n            raise AssertionError\n"
  },
  {
    "path": "mvtracker/models/core/vggt/layers/drop_path.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version 2.0\n# found in the LICENSE file in the root directory of this source tree.\n\n# References:\n#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py\n#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py\n\n\nfrom torch import nn\n\n\ndef drop_path(x, drop_prob: float = 0.0, training: bool = False):\n    if drop_prob == 0.0 or not training:\n        return x\n    keep_prob = 1 - drop_prob\n    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets\n    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)\n    if keep_prob > 0.0:\n        random_tensor.div_(keep_prob)\n    output = x * random_tensor\n    return output\n\n\nclass DropPath(nn.Module):\n    \"\"\"Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).\"\"\"\n\n    def __init__(self, drop_prob=None):\n        super(DropPath, self).__init__()\n        self.drop_prob = drop_prob\n\n    def forward(self, x):\n        return drop_path(x, self.drop_prob, self.training)\n"
  },
  {
    "path": "mvtracker/models/core/vggt/layers/layer_scale.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version 2.0\n# found in the LICENSE file in the root directory of this source tree.\n\n# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110\n\nfrom typing import Union\n\nimport torch\nfrom torch import Tensor\nfrom torch import nn\n\n\nclass LayerScale(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        init_values: Union[float, Tensor] = 1e-5,\n        inplace: bool = False,\n    ) -> None:\n        super().__init__()\n        self.inplace = inplace\n        self.gamma = nn.Parameter(init_values * torch.ones(dim))\n\n    def forward(self, x: Tensor) -> Tensor:\n        return x.mul_(self.gamma) if self.inplace else x * self.gamma\n"
  },
  {
    "path": "mvtracker/models/core/vggt/layers/mlp.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version 2.0\n# found in the LICENSE file in the root directory of this source tree.\n\n# References:\n#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py\n#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py\n\n\nfrom typing import Callable, Optional\n\nfrom torch import Tensor, nn\n\n\nclass Mlp(nn.Module):\n    def __init__(\n        self,\n        in_features: int,\n        hidden_features: Optional[int] = None,\n        out_features: Optional[int] = None,\n        act_layer: Callable[..., nn.Module] = nn.GELU,\n        drop: float = 0.0,\n        bias: bool = True,\n    ) -> None:\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x: Tensor) -> Tensor:\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n"
  },
  {
    "path": "mvtracker/models/core/vggt/layers/patch_embed.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version 2.0\n# found in the LICENSE file in the root directory of this source tree.\n\n# References:\n#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py\n#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py\n\nfrom typing import Callable, Optional, Tuple, Union\n\nfrom torch import Tensor\nimport torch.nn as nn\n\n\ndef make_2tuple(x):\n    if isinstance(x, tuple):\n        assert len(x) == 2\n        return x\n\n    assert isinstance(x, int)\n    return (x, x)\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\"\n    2D image to patch embedding: (B,C,H,W) -> (B,N,D)\n\n    Args:\n        img_size: Image size.\n        patch_size: Patch token size.\n        in_chans: Number of input image channels.\n        embed_dim: Number of linear projection output channels.\n        norm_layer: Normalization layer.\n    \"\"\"\n\n    def __init__(\n        self,\n        img_size: Union[int, Tuple[int, int]] = 224,\n        patch_size: Union[int, Tuple[int, int]] = 16,\n        in_chans: int = 3,\n        embed_dim: int = 768,\n        norm_layer: Optional[Callable] = None,\n        flatten_embedding: bool = True,\n    ) -> None:\n        super().__init__()\n\n        image_HW = make_2tuple(img_size)\n        patch_HW = make_2tuple(patch_size)\n        patch_grid_size = (\n            image_HW[0] // patch_HW[0],\n            image_HW[1] // patch_HW[1],\n        )\n\n        self.img_size = image_HW\n        self.patch_size = patch_HW\n        self.patches_resolution = patch_grid_size\n        self.num_patches = patch_grid_size[0] * patch_grid_size[1]\n\n        self.in_chans = in_chans\n        self.embed_dim = embed_dim\n\n        self.flatten_embedding = flatten_embedding\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)\n        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()\n\n    def forward(self, x: Tensor) -> Tensor:\n        _, _, H, W = x.shape\n        patch_H, patch_W = self.patch_size\n\n        assert H % patch_H == 0, f\"Input image height {H} is not a multiple of patch height {patch_H}\"\n        assert W % patch_W == 0, f\"Input image width {W} is not a multiple of patch width: {patch_W}\"\n\n        x = self.proj(x)  # B C H W\n        H, W = x.size(2), x.size(3)\n        x = x.flatten(2).transpose(1, 2)  # B HW C\n        x = self.norm(x)\n        if not self.flatten_embedding:\n            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C\n        return x\n\n    def flops(self) -> float:\n        Ho, Wo = self.patches_resolution\n        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])\n        if self.norm is not None:\n            flops += Ho * Wo * self.embed_dim\n        return flops\n"
  },
  {
    "path": "mvtracker/models/core/vggt/layers/rope.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version 2.0\n# found in the LICENSE file in the root directory of this source tree.\n\n\n# Implementation of 2D Rotary Position Embeddings (RoPE).\n\n# This module provides a clean implementation of 2D Rotary Position Embeddings,\n# which extends the original RoPE concept to handle 2D spatial positions.\n\n# Inspired by:\n#         https://github.com/meta-llama/codellama/blob/main/llama/model.py\n#         https://github.com/naver-ai/rope-vit\n\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom typing import Dict, Tuple\n\n\nclass PositionGetter:\n    \"\"\"Generates and caches 2D spatial positions for patches in a grid.\n\n    This class efficiently manages the generation of spatial coordinates for patches\n    in a 2D grid, caching results to avoid redundant computations.\n\n    Attributes:\n        position_cache: Dictionary storing precomputed position tensors for different\n            grid dimensions.\n    \"\"\"\n\n    def __init__(self):\n        \"\"\"Initializes the position generator with an empty cache.\"\"\"\n        self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}\n\n    def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:\n        \"\"\"Generates spatial positions for a batch of patches.\n\n        Args:\n            batch_size: Number of samples in the batch.\n            height: Height of the grid in patches.\n            width: Width of the grid in patches.\n            device: Target device for the position tensor.\n\n        Returns:\n            Tensor of shape (batch_size, height*width, 2) containing y,x coordinates\n            for each position in the grid, repeated for each batch item.\n        \"\"\"\n        if (height, width) not in self.position_cache:\n            y_coords = torch.arange(height, device=device)\n            x_coords = torch.arange(width, device=device)\n            positions = torch.cartesian_prod(y_coords, x_coords)\n            self.position_cache[height, width] = positions\n\n        cached_positions = self.position_cache[height, width]\n        return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone()\n\n\nclass RotaryPositionEmbedding2D(nn.Module):\n    \"\"\"2D Rotary Position Embedding implementation.\n\n    This module applies rotary position embeddings to input tokens based on their\n    2D spatial positions. It handles the position-dependent rotation of features\n    separately for vertical and horizontal dimensions.\n\n    Args:\n        frequency: Base frequency for the position embeddings. Default: 100.0\n        scaling_factor: Scaling factor for frequency computation. 
Default: 1.0\n\n    Attributes:\n        base_frequency: Base frequency for computing position embeddings.\n        scaling_factor: Factor to scale the computed frequencies.\n        frequency_cache: Cache for storing precomputed frequency components.\n    \"\"\"\n\n    def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):\n        \"\"\"Initializes the 2D RoPE module.\"\"\"\n        super().__init__()\n        self.base_frequency = frequency\n        self.scaling_factor = scaling_factor\n        self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}\n\n    def _compute_frequency_components(\n        self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"Computes frequency components for rotary embeddings.\n\n        Args:\n            dim: Feature dimension (must be even).\n            seq_len: Maximum sequence length.\n            device: Target device for computations.\n            dtype: Data type for the computed tensors.\n\n        Returns:\n            Tuple of (cosine, sine) tensors for frequency components.\n        \"\"\"\n        cache_key = (dim, seq_len, device, dtype)\n        if cache_key not in self.frequency_cache:\n            # Compute frequency bands\n            exponents = torch.arange(0, dim, 2, device=device).float() / dim\n            inv_freq = 1.0 / (self.base_frequency**exponents)\n\n            # Generate position-dependent frequencies\n            positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)\n            angles = torch.einsum(\"i,j->ij\", positions, inv_freq)\n\n            # Compute and cache frequency components\n            angles = angles.to(dtype)\n            angles = torch.cat((angles, angles), dim=-1)\n            cos_components = angles.cos().to(dtype)\n            sin_components = angles.sin().to(dtype)\n            self.frequency_cache[cache_key] = (cos_components, sin_components)\n\n        return self.frequency_cache[cache_key]\n\n    @staticmethod\n    def _rotate_features(x: torch.Tensor) -> torch.Tensor:\n        \"\"\"Performs feature rotation by splitting and recombining feature dimensions.\n\n        Args:\n            x: Input tensor to rotate.\n\n        Returns:\n            Rotated feature tensor.\n        \"\"\"\n        feature_dim = x.shape[-1]\n        x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]\n        return torch.cat((-x2, x1), dim=-1)\n\n    def _apply_1d_rope(\n        self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor\n    ) -> torch.Tensor:\n        \"\"\"Applies 1D rotary position embeddings along one dimension.\n\n        Args:\n            tokens: Input token features.\n            positions: Position indices.\n            cos_comp: Cosine components for rotation.\n            sin_comp: Sine components for rotation.\n\n        Returns:\n            Tokens with applied rotary position embeddings.\n        \"\"\"\n        # Embed positions with frequency components\n        cos = F.embedding(positions, cos_comp)[:, None, :, :]\n        sin = F.embedding(positions, sin_comp)[:, None, :, :]\n\n        # Apply rotation\n        return (tokens * cos) + (self._rotate_features(tokens) * sin)\n\n    def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:\n        \"\"\"Applies 2D rotary position embeddings to input tokens.\n\n        Args:\n            tokens: Input tensor of shape (batch_size, 
n_heads, n_tokens, dim).\n                   The feature dimension (dim) must be divisible by 4.\n            positions: Position tensor of shape (batch_size, n_tokens, 2) containing\n                      the y and x coordinates for each token.\n\n        Returns:\n            Tensor of same shape as input with applied 2D rotary position embeddings.\n\n        Raises:\n            AssertionError: If input dimensions are invalid or positions are malformed.\n        \"\"\"\n        # Validate inputs\n        assert tokens.size(-1) % 2 == 0, \"Feature dimension must be even\"\n        assert positions.ndim == 3 and positions.shape[-1] == 2, \"Positions must have shape (batch_size, n_tokens, 2)\"\n\n        # Compute feature dimension for each spatial direction\n        feature_dim = tokens.size(-1) // 2\n\n        # Get frequency components\n        max_position = int(positions.max()) + 1\n        cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype)\n\n        # Split features for vertical and horizontal processing\n        vertical_features, horizontal_features = tokens.chunk(2, dim=-1)\n\n        # Apply RoPE separately for each dimension\n        vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp)\n        horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp)\n\n        # Combine processed features\n        return torch.cat((vertical_features, horizontal_features), dim=-1)\n"
  },
  {
    "path": "mvtracker/models/core/vggt/layers/swiglu_ffn.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version 2.0\n# found in the LICENSE file in the root directory of this source tree.\n\nimport os\nfrom typing import Callable, Optional\nimport warnings\n\nfrom torch import Tensor, nn\nimport torch.nn.functional as F\n\n\nclass SwiGLUFFN(nn.Module):\n    def __init__(\n        self,\n        in_features: int,\n        hidden_features: Optional[int] = None,\n        out_features: Optional[int] = None,\n        act_layer: Callable[..., nn.Module] = None,\n        drop: float = 0.0,\n        bias: bool = True,\n    ) -> None:\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)\n        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)\n\n    def forward(self, x: Tensor) -> Tensor:\n        x12 = self.w12(x)\n        x1, x2 = x12.chunk(2, dim=-1)\n        hidden = F.silu(x1) * x2\n        return self.w3(hidden)\n\n\nXFORMERS_ENABLED = os.environ.get(\"XFORMERS_DISABLED\") is None\n# try:\n#     if XFORMERS_ENABLED:\n#         from xformers.ops import SwiGLU\n\n#         XFORMERS_AVAILABLE = True\n#         warnings.warn(\"xFormers is available (SwiGLU)\")\n#     else:\n#         warnings.warn(\"xFormers is disabled (SwiGLU)\")\n#         raise ImportError\n# except ImportError:\nSwiGLU = SwiGLUFFN\nXFORMERS_AVAILABLE = False\n\n# warnings.warn(\"xFormers is not available (SwiGLU)\")\n\n\nclass SwiGLUFFNFused(SwiGLU):\n    def __init__(\n        self,\n        in_features: int,\n        hidden_features: Optional[int] = None,\n        out_features: Optional[int] = None,\n        act_layer: Callable[..., nn.Module] = None,\n        drop: float = 0.0,\n        bias: bool = True,\n    ) -> None:\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8\n        super().__init__(\n            in_features=in_features,\n            hidden_features=hidden_features,\n            out_features=out_features,\n            bias=bias,\n        )\n"
  },
  {
    "path": "mvtracker/models/core/vggt/layers/vision_transformer.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version 2.0\n# found in the LICENSE file in the root directory of this source tree.\n\n# References:\n#   https://github.com/facebookresearch/dino/blob/main/vision_transformer.py\n#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py\n\nfrom functools import partial\nimport math\nimport logging\nfrom typing import Sequence, Tuple, Union, Callable\n\nimport torch\nimport torch.nn as nn\nfrom torch.utils.checkpoint import checkpoint\nfrom torch.nn.init import trunc_normal_\nfrom . import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block\n\nlogger = logging.getLogger(\"dinov2\")\n\n\ndef named_apply(fn: Callable, module: nn.Module, name=\"\", depth_first=True, include_root=False) -> nn.Module:\n    if not depth_first and include_root:\n        fn(module=module, name=name)\n    for child_name, child_module in module.named_children():\n        child_name = \".\".join((name, child_name)) if name else child_name\n        named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)\n    if depth_first and include_root:\n        fn(module=module, name=name)\n    return module\n\n\nclass BlockChunk(nn.ModuleList):\n    def forward(self, x):\n        for b in self:\n            x = b(x)\n        return x\n\n\nclass DinoVisionTransformer(nn.Module):\n    def __init__(\n        self,\n        img_size=224,\n        patch_size=16,\n        in_chans=3,\n        embed_dim=768,\n        depth=12,\n        num_heads=12,\n        mlp_ratio=4.0,\n        qkv_bias=True,\n        ffn_bias=True,\n        proj_bias=True,\n        drop_path_rate=0.0,\n        drop_path_uniform=False,\n        init_values=None,  # for layerscale: None or 0 => no layerscale\n        embed_layer=PatchEmbed,\n        act_layer=nn.GELU,\n        block_fn=Block,\n        ffn_layer=\"mlp\",\n        block_chunks=1,\n        num_register_tokens=0,\n        interpolate_antialias=False,\n        interpolate_offset=0.1,\n        qk_norm=False,\n    ):\n        \"\"\"\n        Args:\n            img_size (int, tuple): input image size\n            patch_size (int, tuple): patch size\n            in_chans (int): number of input channels\n            embed_dim (int): embedding dimension\n            depth (int): depth of transformer\n            num_heads (int): number of attention heads\n            mlp_ratio (int): ratio of mlp hidden dim to embedding dim\n            qkv_bias (bool): enable bias for qkv if True\n            proj_bias (bool): enable bias for proj in attn if True\n            ffn_bias (bool): enable bias for ffn if True\n            drop_path_rate (float): stochastic depth rate\n            drop_path_uniform (bool): apply uniform drop rate across blocks\n            weight_init (str): weight init scheme\n            init_values (float): layer-scale init values\n            embed_layer (nn.Module): patch embedding layer\n            act_layer (nn.Module): MLP activation layer\n            block_fn (nn.Module): transformer block class\n            ffn_layer (str): \"mlp\", \"swiglu\", \"swiglufused\" or \"identity\"\n            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap\n            num_register_tokens: (int) number of extra cls tokens (so-called \"registers\")\n            interpolate_antialias: (str) flag to apply anti-aliasing when interpolating 
positional embeddings\n            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings\n        \"\"\"\n        super().__init__()\n        norm_layer = partial(nn.LayerNorm, eps=1e-6)\n\n        # tricky but makes it work\n        self.use_checkpoint = False\n        #\n\n        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models\n        self.num_tokens = 1\n        self.n_blocks = depth\n        self.num_heads = num_heads\n        self.patch_size = patch_size\n        self.num_register_tokens = num_register_tokens\n        self.interpolate_antialias = interpolate_antialias\n        self.interpolate_offset = interpolate_offset\n\n        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)\n        num_patches = self.patch_embed.num_patches\n\n        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))\n        assert num_register_tokens >= 0\n        self.register_tokens = (\n            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None\n        )\n\n        if drop_path_uniform is True:\n            dpr = [drop_path_rate] * depth\n        else:\n            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule\n\n        if ffn_layer == \"mlp\":\n            logger.info(\"using MLP layer as FFN\")\n            ffn_layer = Mlp\n        elif ffn_layer == \"swiglufused\" or ffn_layer == \"swiglu\":\n            logger.info(\"using SwiGLU layer as FFN\")\n            ffn_layer = SwiGLUFFNFused\n        elif ffn_layer == \"identity\":\n            logger.info(\"using Identity layer as FFN\")\n\n            def f(*args, **kwargs):\n                return nn.Identity()\n\n            ffn_layer = f\n        else:\n            raise NotImplementedError\n\n        blocks_list = [\n            block_fn(\n                dim=embed_dim,\n                num_heads=num_heads,\n                mlp_ratio=mlp_ratio,\n                qkv_bias=qkv_bias,\n                proj_bias=proj_bias,\n                ffn_bias=ffn_bias,\n                drop_path=dpr[i],\n                norm_layer=norm_layer,\n                act_layer=act_layer,\n                ffn_layer=ffn_layer,\n                init_values=init_values,\n                qk_norm=qk_norm,\n            )\n            for i in range(depth)\n        ]\n        if block_chunks > 0:\n            self.chunked_blocks = True\n            chunked_blocks = []\n            chunksize = depth // block_chunks\n            for i in range(0, depth, chunksize):\n                # this is to keep the block index consistent if we chunk the block list\n                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])\n            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])\n        else:\n            self.chunked_blocks = False\n            self.blocks = nn.ModuleList(blocks_list)\n\n        self.norm = norm_layer(embed_dim)\n        self.head = nn.Identity()\n\n        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))\n\n        self.init_weights()\n\n    def init_weights(self):\n        trunc_normal_(self.pos_embed, std=0.02)\n        nn.init.normal_(self.cls_token, std=1e-6)\n        if self.register_tokens is not None:\n            
nn.init.normal_(self.register_tokens, std=1e-6)\n        named_apply(init_weights_vit_timm, self)\n\n    def interpolate_pos_encoding(self, x, w, h):\n        previous_dtype = x.dtype\n        npatch = x.shape[1] - 1\n        N = self.pos_embed.shape[1] - 1\n        if npatch == N and w == h:\n            return self.pos_embed\n        pos_embed = self.pos_embed.float()\n        class_pos_embed = pos_embed[:, 0]\n        patch_pos_embed = pos_embed[:, 1:]\n        dim = x.shape[-1]\n        w0 = w // self.patch_size\n        h0 = h // self.patch_size\n        M = int(math.sqrt(N))  # Recover the number of patches in each dimension\n        assert N == M * M\n        kwargs = {}\n        if self.interpolate_offset:\n            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8\n            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors\n            sx = float(w0 + self.interpolate_offset) / M\n            sy = float(h0 + self.interpolate_offset) / M\n            kwargs[\"scale_factor\"] = (sx, sy)\n        else:\n            # Simply specify an output size instead of a scale factor\n            kwargs[\"size\"] = (w0, h0)\n        patch_pos_embed = nn.functional.interpolate(\n            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),\n            mode=\"bicubic\",\n            antialias=self.interpolate_antialias,\n            **kwargs,\n        )\n        assert (w0, h0) == patch_pos_embed.shape[-2:]\n        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)\n        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)\n\n    def prepare_tokens_with_masks(self, x, masks=None):\n        B, nc, w, h = x.shape\n        x = self.patch_embed(x)\n        if masks is not None:\n            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)\n\n        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)\n        x = x + self.interpolate_pos_encoding(x, w, h)\n\n        if self.register_tokens is not None:\n            x = torch.cat(\n                (\n                    x[:, :1],\n                    self.register_tokens.expand(x.shape[0], -1, -1),\n                    x[:, 1:],\n                ),\n                dim=1,\n            )\n\n        return x\n\n    def forward_features_list(self, x_list, masks_list):\n        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]\n\n        for blk in self.blocks:\n            if self.use_checkpoint:\n                x = checkpoint(blk, x, use_reentrant=self.use_reentrant)\n            else:\n                x = blk(x)\n\n        all_x = x\n        output = []\n        for x, masks in zip(all_x, masks_list):\n            x_norm = self.norm(x)\n            output.append(\n                {\n                    \"x_norm_clstoken\": x_norm[:, 0],\n                    \"x_norm_regtokens\": x_norm[:, 1 : self.num_register_tokens + 1],\n                    \"x_norm_patchtokens\": x_norm[:, self.num_register_tokens + 1 :],\n                    \"x_prenorm\": x,\n                    \"masks\": masks,\n                }\n            )\n        return output\n\n    def forward_features(self, x, masks=None):\n        if isinstance(x, list):\n            return self.forward_features_list(x, masks)\n\n        x = self.prepare_tokens_with_masks(x, 
masks)\n\n        for blk in self.blocks:\n            if self.use_checkpoint:\n                x = checkpoint(blk, x, use_reentrant=self.use_reentrant)\n            else:\n                x = blk(x)\n\n        x_norm = self.norm(x)\n        return {\n            \"x_norm_clstoken\": x_norm[:, 0],\n            \"x_norm_regtokens\": x_norm[:, 1 : self.num_register_tokens + 1],\n            \"x_norm_patchtokens\": x_norm[:, self.num_register_tokens + 1 :],\n            \"x_prenorm\": x,\n            \"masks\": masks,\n        }\n\n    def _get_intermediate_layers_not_chunked(self, x, n=1):\n        x = self.prepare_tokens_with_masks(x)\n        # If n is an int, take the n last blocks. If it's a list, take them\n        output, total_block_len = [], len(self.blocks)\n        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n\n        for i, blk in enumerate(self.blocks):\n            x = blk(x)\n            if i in blocks_to_take:\n                output.append(x)\n        assert len(output) == len(blocks_to_take), f\"only {len(output)} / {len(blocks_to_take)} blocks found\"\n        return output\n\n    def _get_intermediate_layers_chunked(self, x, n=1):\n        x = self.prepare_tokens_with_masks(x)\n        output, i, total_block_len = [], 0, len(self.blocks[-1])\n        # If n is an int, take the n last blocks. If it's a list, take them\n        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n\n        for block_chunk in self.blocks:\n            for blk in block_chunk[i:]:  # Passing the nn.Identity()\n                x = blk(x)\n                if i in blocks_to_take:\n                    output.append(x)\n                i += 1\n        assert len(output) == len(blocks_to_take), f\"only {len(output)} / {len(blocks_to_take)} blocks found\"\n        return output\n\n    def get_intermediate_layers(\n        self,\n        x: torch.Tensor,\n        n: Union[int, Sequence] = 1,  # Layers or n last layers to take\n        reshape: bool = False,\n        return_class_token: bool = False,\n        norm=True,\n    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:\n        if self.chunked_blocks:\n            outputs = self._get_intermediate_layers_chunked(x, n)\n        else:\n            outputs = self._get_intermediate_layers_not_chunked(x, n)\n        if norm:\n            outputs = [self.norm(out) for out in outputs]\n        class_tokens = [out[:, 0] for out in outputs]\n        outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]\n        if reshape:\n            B, _, w, h = x.shape\n            outputs = [\n                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()\n                for out in outputs\n            ]\n        if return_class_token:\n            return tuple(zip(outputs, class_tokens))\n        return tuple(outputs)\n\n    def forward(self, *args, is_training=True, **kwargs):\n        ret = self.forward_features(*args, **kwargs)\n        if is_training:\n            return ret\n        else:\n            return self.head(ret[\"x_norm_clstoken\"])\n\n\ndef init_weights_vit_timm(module: nn.Module, name: str = \"\"):\n    \"\"\"ViT weight initialization, original timm impl (for reproducibility)\"\"\"\n    if isinstance(module, nn.Linear):\n        trunc_normal_(module.weight, std=0.02)\n        if module.bias is not None:\n            nn.init.zeros_(module.bias)\n\n\ndef vit_small(patch_size=16, num_register_tokens=0, 
**kwargs):\n    model = DinoVisionTransformer(\n        patch_size=patch_size,\n        embed_dim=384,\n        depth=12,\n        num_heads=6,\n        mlp_ratio=4,\n        block_fn=partial(Block, attn_class=MemEffAttention),\n        num_register_tokens=num_register_tokens,\n        **kwargs,\n    )\n    return model\n\n\ndef vit_base(patch_size=16, num_register_tokens=0, **kwargs):\n    model = DinoVisionTransformer(\n        patch_size=patch_size,\n        embed_dim=768,\n        depth=12,\n        num_heads=12,\n        mlp_ratio=4,\n        block_fn=partial(Block, attn_class=MemEffAttention),\n        num_register_tokens=num_register_tokens,\n        **kwargs,\n    )\n    return model\n\n\ndef vit_large(patch_size=16, num_register_tokens=0, **kwargs):\n    model = DinoVisionTransformer(\n        patch_size=patch_size,\n        embed_dim=1024,\n        depth=24,\n        num_heads=16,\n        mlp_ratio=4,\n        block_fn=partial(Block, attn_class=MemEffAttention),\n        num_register_tokens=num_register_tokens,\n        **kwargs,\n    )\n    return model\n\n\ndef vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):\n    \"\"\"\n    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64\n    \"\"\"\n    model = DinoVisionTransformer(\n        patch_size=patch_size,\n        embed_dim=1536,\n        depth=40,\n        num_heads=24,\n        mlp_ratio=4,\n        block_fn=partial(Block, attn_class=MemEffAttention),\n        num_register_tokens=num_register_tokens,\n        **kwargs,\n    )\n    return model\n"
  },
  {
    "path": "mvtracker/models/core/vggt/models/aggregator.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\nimport logging\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom typing import Optional, Tuple, Union, List, Dict, Any\n\nfrom ..layers import PatchEmbed\nfrom ..layers.block import Block\nfrom ..layers.rope import RotaryPositionEmbedding2D, PositionGetter\nfrom ..layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2\n\nlogger = logging.getLogger(__name__)\n\n_RESNET_MEAN = [0.485, 0.456, 0.406]\n_RESNET_STD = [0.229, 0.224, 0.225]\n\n\nclass Aggregator(nn.Module):\n    \"\"\"\n    The Aggregator applies alternating-attention over input frames,\n    as described in VGGT: Visual Geometry Grounded Transformer.\n\n\n    Args:\n        img_size (int): Image size in pixels.\n        patch_size (int): Size of each patch for PatchEmbed.\n        embed_dim (int): Dimension of the token embeddings.\n        depth (int): Number of blocks.\n        num_heads (int): Number of attention heads.\n        mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.\n        num_register_tokens (int): Number of register tokens.\n        block_fn (nn.Module): The block type used for attention (Block by default).\n        qkv_bias (bool): Whether to include bias in QKV projections.\n        proj_bias (bool): Whether to include bias in the output projection.\n        ffn_bias (bool): Whether to include bias in MLP layers.\n        patch_embed (str): Type of patch embed. e.g., \"conv\" or \"dinov2_vitl14_reg\".\n        aa_order (list[str]): The order of alternating attention, e.g. [\"frame\", \"global\"].\n        aa_block_size (int): How many blocks to group under each attention type before switching. If not necessary, set to 1.\n        qk_norm (bool): Whether to apply QK normalization.\n        rope_freq (int): Base frequency for rotary embedding. 
-1 to disable.\n        init_values (float): Init scale for layer scale.\n    \"\"\"\n\n    def __init__(\n        self,\n        img_size=518,\n        patch_size=14,\n        embed_dim=1024,\n        depth=24,\n        num_heads=16,\n        mlp_ratio=4.0,\n        num_register_tokens=4,\n        block_fn=Block,\n        qkv_bias=True,\n        proj_bias=True,\n        ffn_bias=True,\n        patch_embed=\"dinov2_vitl14_reg\",\n        aa_order=[\"frame\", \"global\"],\n        aa_block_size=1,\n        qk_norm=True,\n        rope_freq=100,\n        init_values=0.01,\n    ):\n        super().__init__()\n\n        self.__build_patch_embed__(patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim)\n\n        # Initialize rotary position embedding if frequency > 0\n        self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None\n        self.position_getter = PositionGetter() if self.rope is not None else None\n\n        self.frame_blocks = nn.ModuleList(\n            [\n                block_fn(\n                    dim=embed_dim,\n                    num_heads=num_heads,\n                    mlp_ratio=mlp_ratio,\n                    qkv_bias=qkv_bias,\n                    proj_bias=proj_bias,\n                    ffn_bias=ffn_bias,\n                    init_values=init_values,\n                    qk_norm=qk_norm,\n                    rope=self.rope,\n                )\n                for _ in range(depth)\n            ]\n        )\n\n        self.global_blocks = nn.ModuleList(\n            [\n                block_fn(\n                    dim=embed_dim,\n                    num_heads=num_heads,\n                    mlp_ratio=mlp_ratio,\n                    qkv_bias=qkv_bias,\n                    proj_bias=proj_bias,\n                    ffn_bias=ffn_bias,\n                    init_values=init_values,\n                    qk_norm=qk_norm,\n                    rope=self.rope,\n                )\n                for _ in range(depth)\n            ]\n        )\n\n        self.depth = depth\n        self.aa_order = aa_order\n        self.patch_size = patch_size\n        self.aa_block_size = aa_block_size\n\n        # Validate that depth is divisible by aa_block_size\n        if self.depth % self.aa_block_size != 0:\n            raise ValueError(f\"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})\")\n\n        self.aa_block_num = self.depth // self.aa_block_size\n\n        # Note: We have two camera tokens, one for the first frame and one for the rest\n        # The same applies for register tokens\n        self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim))\n        self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim))\n\n        # The patch tokens start after the camera and register tokens\n        self.patch_start_idx = 1 + num_register_tokens\n\n        # Initialize parameters with small values\n        nn.init.normal_(self.camera_token, std=1e-6)\n        nn.init.normal_(self.register_token, std=1e-6)\n\n        # Register normalization constants as buffers\n        for name, value in (\n            (\"_resnet_mean\", _RESNET_MEAN),\n            (\"_resnet_std\", _RESNET_STD),\n        ):\n            self.register_buffer(\n                name,\n                torch.FloatTensor(value).view(1, 1, 3, 1, 1),\n                persistent=False,\n            )\n\n    def __build_patch_embed__(\n        self,\n        patch_embed,\n        img_size,\n        patch_size,\n        
num_register_tokens,\n        interpolate_antialias=True,\n        interpolate_offset=0.0,\n        block_chunks=0,\n        init_values=1.0,\n        embed_dim=1024,\n    ):\n        \"\"\"\n        Build the patch embed layer. If 'conv', we use a\n        simple PatchEmbed conv layer. Otherwise, we use a vision transformer.\n        \"\"\"\n\n        if \"conv\" in patch_embed:\n            self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim)\n        else:\n            vit_models = {\n                \"dinov2_vitl14_reg\": vit_large,\n                \"dinov2_vitb14_reg\": vit_base,\n                \"dinov2_vits14_reg\": vit_small,\n                \"dinov2_vitg2_reg\": vit_giant2,\n            }\n\n            self.patch_embed = vit_models[patch_embed](\n                img_size=img_size,\n                patch_size=patch_size,\n                num_register_tokens=num_register_tokens,\n                interpolate_antialias=interpolate_antialias,\n                interpolate_offset=interpolate_offset,\n                block_chunks=block_chunks,\n                init_values=init_values,\n            )\n\n            # Disable gradient updates for mask token\n            if hasattr(self.patch_embed, \"mask_token\"):\n                self.patch_embed.mask_token.requires_grad_(False)\n\n    def forward(\n        self,\n        images: torch.Tensor,\n    ) -> Tuple[List[torch.Tensor], int]:\n        \"\"\"\n        Args:\n            images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].\n                B: batch size, S: sequence length, 3: RGB channels, H: height, W: width\n\n        Returns:\n            (list[torch.Tensor], int):\n                The list of outputs from the attention blocks,\n                and the patch_start_idx indicating where patch tokens begin.\n        \"\"\"\n        B, S, C_in, H, W = images.shape\n\n        if C_in != 3:\n            raise ValueError(f\"Expected 3 input channels, got {C_in}\")\n\n        # Normalize images and reshape for patch embed\n        images = (images - self._resnet_mean) / self._resnet_std\n\n        # Reshape to [B*S, C, H, W] for patch embedding\n        images = images.view(B * S, C_in, H, W)\n        patch_tokens = self.patch_embed(images)\n\n        if isinstance(patch_tokens, dict):\n            patch_tokens = patch_tokens[\"x_norm_patchtokens\"]\n\n        _, P, C = patch_tokens.shape\n\n        # Expand camera and register tokens to match batch size and sequence length\n        camera_token = slice_expand_and_flatten(self.camera_token, B, S)\n        register_token = slice_expand_and_flatten(self.register_token, B, S)\n\n        # Concatenate special tokens with patch tokens\n        tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1)\n\n        pos = None\n        if self.rope is not None:\n            pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device)\n\n        if self.patch_start_idx > 0:\n            # do not use position embedding for special tokens (camera and register tokens)\n            # so set pos to 0 for the special tokens\n            pos = pos + 1\n            pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(images.device).to(pos.dtype)\n            pos = torch.cat([pos_special, pos], dim=1)\n\n        # update P because we added special tokens\n        _, P, C = tokens.shape\n\n        frame_idx = 0\n        global_idx = 0\n        output_list = []\n\n 
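        # Alternating attention: each of the aa_block_num outer iterations runs aa_block_size frame-attention\n        # blocks (tokens viewed as [B*S, P, C], attending within each frame) and aa_block_size global-attention\n        # blocks (tokens viewed as [B, S*P, C], attending across all frames), in the order given by aa_order.\n        # The paired frame/global outputs are concatenated channel-wise into [B, S, P, 2C] and collected in output_list.\n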
       for _ in range(self.aa_block_num):\n            for attn_type in self.aa_order:\n                if attn_type == \"frame\":\n                    tokens, frame_idx, frame_intermediates = self._process_frame_attention(\n                        tokens, B, S, P, C, frame_idx, pos=pos\n                    )\n                elif attn_type == \"global\":\n                    tokens, global_idx, global_intermediates = self._process_global_attention(\n                        tokens, B, S, P, C, global_idx, pos=pos\n                    )\n                else:\n                    raise ValueError(f\"Unknown attention type: {attn_type}\")\n\n            for i in range(len(frame_intermediates)):\n                # concat frame and global intermediates, [B x S x P x 2C]\n                concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)\n                output_list.append(concat_inter)\n\n        del concat_inter\n        del frame_intermediates\n        del global_intermediates\n        return output_list, self.patch_start_idx\n\n    def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):\n        \"\"\"\n        Process frame attention blocks. We keep tokens in shape (B*S, P, C).\n        \"\"\"\n        # If needed, reshape tokens or positions:\n        if tokens.shape != (B * S, P, C):\n            tokens = tokens.view(B, S, P, C).view(B * S, P, C)\n\n        if pos is not None and pos.shape != (B * S, P, 2):\n            pos = pos.view(B, S, P, 2).view(B * S, P, 2)\n\n        intermediates = []\n\n        # by default, self.aa_block_size=1, which processes one block at a time\n        for _ in range(self.aa_block_size):\n            tokens = self.frame_blocks[frame_idx](tokens, pos=pos)\n            frame_idx += 1\n            intermediates.append(tokens.view(B, S, P, C))\n\n        return tokens, frame_idx, intermediates\n\n    def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None):\n        \"\"\"\n        Process global attention blocks. 
We keep tokens in shape (B, S*P, C).\n        \"\"\"\n        if tokens.shape != (B, S * P, C):\n            tokens = tokens.view(B, S, P, C).view(B, S * P, C)\n\n        if pos is not None and pos.shape != (B, S * P, 2):\n            pos = pos.view(B, S, P, 2).view(B, S * P, 2)\n\n        intermediates = []\n\n        # by default, self.aa_block_size=1, which processes one block at a time\n        for _ in range(self.aa_block_size):\n            tokens = self.global_blocks[global_idx](tokens, pos=pos)\n            global_idx += 1\n            intermediates.append(tokens.view(B, S, P, C))\n\n        return tokens, global_idx, intermediates\n\n\ndef slice_expand_and_flatten(token_tensor, B, S):\n    \"\"\"\n    Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing:\n    1) Uses the first position (index=0) for the first frame only\n    2) Uses the second position (index=1) for all remaining frames (S-1 frames)\n    3) Expands both to match batch size B\n    4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token\n       followed by (S-1) second-position tokens\n    5) Flattens to (B*S, X, C) for processing\n\n    Returns:\n        torch.Tensor: Processed tokens with shape (B*S, X, C)\n    \"\"\"\n\n    # Slice out the \"query\" tokens => shape (1, 1, ...)\n    query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:])\n    # Slice out the \"other\" tokens => shape (1, S-1, ...)\n    others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:])\n    # Concatenate => shape (B, S, ...)\n    combined = torch.cat([query, others], dim=1)\n\n    # Finally flatten => shape (B*S, ...)\n    combined = combined.view(B * S, *combined.shape[2:])\n    return combined\n"
  },
  {
    "path": "mvtracker/models/core/vggt/models/vggt.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\nimport torch\nimport torch.nn as nn\nfrom huggingface_hub import PyTorchModelHubMixin  # used for model hub\n\nfrom ..models.aggregator import Aggregator\nfrom ..heads.camera_head import CameraHead\nfrom ..heads.dpt_head import DPTHead\nfrom ..heads.track_head import TrackHead\n\n\nclass VGGT(nn.Module, PyTorchModelHubMixin):\n    def __init__(self, img_size=518, patch_size=14, embed_dim=1024):\n        super().__init__()\n\n        self.aggregator = Aggregator(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)\n        self.camera_head = CameraHead(dim_in=2 * embed_dim)\n        self.point_head = DPTHead(dim_in=2 * embed_dim, output_dim=4, activation=\"inv_log\", conf_activation=\"expp1\")\n        self.depth_head = DPTHead(dim_in=2 * embed_dim, output_dim=2, activation=\"exp\", conf_activation=\"expp1\")\n        self.track_head = TrackHead(dim_in=2 * embed_dim, patch_size=patch_size)\n\n    def forward(\n        self,\n        images: torch.Tensor,\n        query_points: torch.Tensor = None,\n    ):\n        \"\"\"\n        Forward pass of the VGGT model.\n\n        Args:\n            images (torch.Tensor): Input images with shape [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1].\n                B: batch size, S: sequence length, 3: RGB channels, H: height, W: width\n            query_points (torch.Tensor, optional): Query points for tracking, in pixel coordinates.\n                Shape: [N, 2] or [B, N, 2], where N is the number of query points.\n                Default: None\n\n        Returns:\n            dict: A dictionary containing the following predictions:\n                - pose_enc (torch.Tensor): Camera pose encoding with shape [B, S, 9] (from the last iteration)\n                - depth (torch.Tensor): Predicted depth maps with shape [B, S, H, W, 1]\n                - depth_conf (torch.Tensor): Confidence scores for depth predictions with shape [B, S, H, W]\n                - world_points (torch.Tensor): 3D world coordinates for each pixel with shape [B, S, H, W, 3]\n                - world_points_conf (torch.Tensor): Confidence scores for world points with shape [B, S, H, W]\n                - images (torch.Tensor): Original input images, preserved for visualization\n\n                If query_points is provided, also includes:\n                - track (torch.Tensor): Point tracks with shape [B, S, N, 2] (from the last iteration), in pixel coordinates\n                - vis (torch.Tensor): Visibility scores for tracked points with shape [B, S, N]\n                - conf (torch.Tensor): Confidence scores for tracked points with shape [B, S, N]\n        \"\"\"\n\n        # If without batch dimension, add it\n        if len(images.shape) == 4:\n            images = images.unsqueeze(0)\n        if query_points is not None and len(query_points.shape) == 2:\n            query_points = query_points.unsqueeze(0)\n\n        aggregated_tokens_list, patch_start_idx = self.aggregator(images)\n\n        predictions = {}\n\n        with torch.cuda.amp.autocast(enabled=False):\n            if self.camera_head is not None:\n                pose_enc_list = self.camera_head(aggregated_tokens_list)\n                predictions[\"pose_enc\"] = pose_enc_list[-1]  # pose encoding of the last iteration\n\n            if self.depth_head is not None:\n                
depth, depth_conf = self.depth_head(\n                    aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx\n                )\n                predictions[\"depth\"] = depth\n                predictions[\"depth_conf\"] = depth_conf\n\n            if self.point_head is not None:\n                pts3d, pts3d_conf = self.point_head(\n                    aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx\n                )\n                predictions[\"world_points\"] = pts3d\n                predictions[\"world_points_conf\"] = pts3d_conf\n\n        if self.track_head is not None and query_points is not None:\n            track_list, vis, conf = self.track_head(\n                aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx, query_points=query_points\n            )\n            predictions[\"track\"] = track_list[-1]  # track of the last iteration\n            predictions[\"vis\"] = vis\n            predictions[\"conf\"] = conf\n\n        predictions[\"images\"] = images\n\n        return predictions\n"
  },
  {
    "path": "mvtracker/models/core/vggt/utils/geometry.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\nimport os\nimport torch\nimport numpy as np\n\n\ndef unproject_depth_map_to_point_map(\n    depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray\n) -> np.ndarray:\n    \"\"\"\n    Unproject a batch of depth maps to 3D world coordinates.\n\n    Args:\n        depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W)\n        extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4)\n        intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3)\n\n    Returns:\n        np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3)\n    \"\"\"\n    if isinstance(depth_map, torch.Tensor):\n        depth_map = depth_map.cpu().numpy()\n    if isinstance(extrinsics_cam, torch.Tensor):\n        extrinsics_cam = extrinsics_cam.cpu().numpy()\n    if isinstance(intrinsics_cam, torch.Tensor):\n        intrinsics_cam = intrinsics_cam.cpu().numpy()\n\n    world_points_list = []\n    for frame_idx in range(depth_map.shape[0]):\n        cur_world_points, _, _ = depth_to_world_coords_points(\n            depth_map[frame_idx].squeeze(-1), extrinsics_cam[frame_idx], intrinsics_cam[frame_idx]\n        )\n        world_points_list.append(cur_world_points)\n    world_points_array = np.stack(world_points_list, axis=0)\n\n    return world_points_array\n\n\ndef depth_to_world_coords_points(\n    depth_map: np.ndarray,\n    extrinsic: np.ndarray,\n    intrinsic: np.ndarray,\n    eps=1e-8,\n) -> tuple[np.ndarray, np.ndarray, np.ndarray]:\n    \"\"\"\n    Convert a depth map to world coordinates.\n\n    Args:\n        depth_map (np.ndarray): Depth map of shape (H, W).\n        intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).\n        extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4). 
OpenCV camera coordinate convention, cam from world.\n\n    Returns:\n        tuple[np.ndarray, np.ndarray]: World coordinates (H, W, 3) and valid depth mask (H, W).\n    \"\"\"\n    if depth_map is None:\n        return None, None, None\n\n    # Valid depth mask\n    point_mask = depth_map > eps\n\n    # Convert depth map to camera coordinates\n    cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic)\n\n    # Multiply with the inverse of extrinsic matrix to transform to world coordinates\n    # extrinsic_inv is 4x4 (note closed_form_inverse_OpenCV is batched, the output is (N, 4, 4))\n    cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0]\n\n    R_cam_to_world = cam_to_world_extrinsic[:3, :3]\n    t_cam_to_world = cam_to_world_extrinsic[:3, 3]\n\n    # Apply the rotation and translation to the camera coordinates\n    world_coords_points = np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world  # HxWx3, 3x3 -> HxWx3\n    # world_coords_points = np.einsum(\"ij,hwj->hwi\", R_cam_to_world, cam_coords_points) + t_cam_to_world\n\n    return world_coords_points, cam_coords_points, point_mask\n\n\ndef depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndarray) -> tuple[np.ndarray, np.ndarray]:\n    \"\"\"\n    Convert a depth map to camera coordinates.\n\n    Args:\n        depth_map (np.ndarray): Depth map of shape (H, W).\n        intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).\n\n    Returns:\n        tuple[np.ndarray, np.ndarray]: Camera coordinates (H, W, 3)\n    \"\"\"\n    H, W = depth_map.shape\n    assert intrinsic.shape == (3, 3), \"Intrinsic matrix must be 3x3\"\n    assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, \"Intrinsic matrix must have zero skew\"\n\n    # Intrinsic parameters\n    fu, fv = intrinsic[0, 0], intrinsic[1, 1]\n    cu, cv = intrinsic[0, 2], intrinsic[1, 2]\n\n    # Generate grid of pixel coordinates\n    u, v = np.meshgrid(np.arange(W), np.arange(H))\n\n    # Unproject to camera coordinates\n    x_cam = (u - cu) * depth_map / fu\n    y_cam = (v - cv) * depth_map / fv\n    z_cam = depth_map\n\n    # Stack to form camera coordinates\n    cam_coords = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)\n\n    return cam_coords\n\n\ndef closed_form_inverse_se3(se3, R=None, T=None):\n    \"\"\"\n    Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.\n\n    If `R` and `T` are provided, they must correspond to the rotation and translation\n    components of `se3`. 
Otherwise, they will be extracted from `se3`.\n\n    Args:\n        se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.\n        R (optional): Nx3x3 array or tensor of rotation matrices.\n        T (optional): Nx3x1 array or tensor of translation vectors.\n\n    Returns:\n        Inverted SE3 matrices with the same type and device as `se3`.\n\n    Shapes:\n        se3: (N, 4, 4)\n        R: (N, 3, 3)\n        T: (N, 3, 1)\n    \"\"\"\n    # Check if se3 is a numpy array or a torch tensor\n    is_numpy = isinstance(se3, np.ndarray)\n\n    # Validate shapes\n    if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4):\n        raise ValueError(f\"se3 must be of shape (N,4,4), got {se3.shape}.\")\n\n    # Extract R and T if not provided\n    if R is None:\n        R = se3[:, :3, :3]  # (N,3,3)\n    if T is None:\n        T = se3[:, :3, 3:]  # (N,3,1)\n\n    # Transpose R\n    if is_numpy:\n        # Compute the transpose of the rotation for NumPy\n        R_transposed = np.transpose(R, (0, 2, 1))\n        # -R^T t for NumPy\n        top_right = -np.matmul(R_transposed, T)\n        inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1))\n    else:\n        R_transposed = R.transpose(1, 2)  # (N,3,3)\n        top_right = -torch.bmm(R_transposed, T)  # (N,3,1)\n        inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1)\n        inverted_matrix = inverted_matrix.to(R.dtype).to(R.device)\n\n    inverted_matrix[:, :3, :3] = R_transposed\n    inverted_matrix[:, :3, 3:] = top_right\n\n    return inverted_matrix\n"
  },
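A small self-contained sketch of the unprojection utilities above, using a constant-depth map, identity extrinsics, and a made-up pinhole intrinsic matrix to show the expected shapes and the world point under the principal point.

```python
import numpy as np

from mvtracker.models.core.vggt.utils.geometry import unproject_depth_map_to_point_map

S, H, W = 1, 4, 6
depth = np.full((S, H, W), 2.0, dtype=np.float32)          # every pixel 2 units away
intr = np.array([[[100.0, 0.0, W / 2],
                  [0.0, 100.0, H / 2],
                  [0.0, 0.0, 1.0]]], dtype=np.float32)      # (S, 3, 3) pinhole intrinsics
extr = np.tile(np.eye(4, dtype=np.float32)[:3], (S, 1, 1))  # (S, 3, 4), camera frame == world frame

world_points = unproject_depth_map_to_point_map(depth, extr, intr)
print(world_points.shape)     # (1, 4, 6, 3)
print(world_points[0, 2, 3])  # pixel at the principal point unprojects to ~[0, 0, 2]
```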
  {
    "path": "mvtracker/models/core/vggt/utils/load_fn.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\nimport torch\nfrom PIL import Image\nfrom torchvision import transforms as TF\n\n\ndef load_and_preprocess_images(image_path_list, mode=\"crop\"):\n    \"\"\"\n    A quick start function to load and preprocess images for model input.\n    This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes.\n\n    Args:\n        image_path_list (list): List of paths to image files\n        mode (str, optional): Preprocessing mode, either \"crop\" or \"pad\".\n                             - \"crop\" (default): Sets width to 518px and center crops height if needed.\n                             - \"pad\": Preserves all pixels by making the largest dimension 518px\n                               and padding the smaller dimension to reach a square shape.\n\n    Returns:\n        torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W)\n\n    Raises:\n        ValueError: If the input list is empty or if mode is invalid\n\n    Notes:\n        - Images with different dimensions will be padded with white (value=1.0)\n        - A warning is printed when images have different shapes\n        - When mode=\"crop\": The function ensures width=518px while maintaining aspect ratio\n          and height is center-cropped if larger than 518px\n        - When mode=\"pad\": The function ensures the largest dimension is 518px while maintaining aspect ratio\n          and the smaller dimension is padded to reach a square shape (518x518)\n        - Dimensions are adjusted to be divisible by 14 for compatibility with model requirements\n    \"\"\"\n    # Check for empty list\n    if len(image_path_list) == 0:\n        raise ValueError(\"At least 1 image is required\")\n    \n    # Validate mode\n    if mode not in [\"crop\", \"pad\"]:\n        raise ValueError(\"Mode must be either 'crop' or 'pad'\")\n\n    images = []\n    shapes = set()\n    to_tensor = TF.ToTensor()\n    target_size = 518\n\n    # First process all images and collect their shapes\n    for image_path in image_path_list:\n\n        # Open image\n        img = Image.open(image_path)\n\n        # If there's an alpha channel, blend onto white background:\n        if img.mode == \"RGBA\":\n            # Create white background\n            background = Image.new(\"RGBA\", img.size, (255, 255, 255, 255))\n            # Alpha composite onto the white background\n            img = Image.alpha_composite(background, img)\n\n        # Now convert to \"RGB\" (this step assigns white for transparent areas)\n        img = img.convert(\"RGB\")\n\n        width, height = img.size\n        \n        if mode == \"pad\":\n            # Make the largest dimension 518px while maintaining aspect ratio\n            if width >= height:\n                new_width = target_size\n                new_height = round(height * (new_width / width) / 14) * 14  # Make divisible by 14\n            else:\n                new_height = target_size\n                new_width = round(width * (new_height / height) / 14) * 14  # Make divisible by 14\n        else:  # mode == \"crop\"\n            # Original behavior: set width to 518px\n            new_width = target_size\n            # Calculate height maintaining aspect ratio, divisible by 14\n            new_height = round(height * 
(new_width / width) / 14) * 14\n\n        # Resize with new dimensions (width, height)\n        img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)\n        img = to_tensor(img)  # Convert to tensor (0, 1)\n\n        # Center crop height if it's larger than 518 (only in crop mode)\n        if mode == \"crop\" and new_height > target_size:\n            start_y = (new_height - target_size) // 2\n            img = img[:, start_y : start_y + target_size, :]\n        \n        # For pad mode, pad to make a square of target_size x target_size\n        if mode == \"pad\":\n            h_padding = target_size - img.shape[1]\n            w_padding = target_size - img.shape[2]\n            \n            if h_padding > 0 or w_padding > 0:\n                pad_top = h_padding // 2\n                pad_bottom = h_padding - pad_top\n                pad_left = w_padding // 2\n                pad_right = w_padding - pad_left\n                \n                # Pad with white (value=1.0)\n                img = torch.nn.functional.pad(\n                    img, (pad_left, pad_right, pad_top, pad_bottom), mode=\"constant\", value=1.0\n                )\n\n        shapes.add((img.shape[1], img.shape[2]))\n        images.append(img)\n\n    # Check if we have different shapes\n    # In theory our model can also work well with different shapes\n    if len(shapes) > 1:\n        print(f\"Warning: Found images with different shapes: {shapes}\")\n        # Find maximum dimensions\n        max_height = max(shape[0] for shape in shapes)\n        max_width = max(shape[1] for shape in shapes)\n\n        # Pad images if necessary\n        padded_images = []\n        for img in images:\n            h_padding = max_height - img.shape[1]\n            w_padding = max_width - img.shape[2]\n\n            if h_padding > 0 or w_padding > 0:\n                pad_top = h_padding // 2\n                pad_bottom = h_padding - pad_top\n                pad_left = w_padding // 2\n                pad_right = w_padding - pad_left\n\n                img = torch.nn.functional.pad(\n                    img, (pad_left, pad_right, pad_top, pad_bottom), mode=\"constant\", value=1.0\n                )\n            padded_images.append(img)\n        images = padded_images\n\n    images = torch.stack(images)  # concatenate images\n\n    # Ensure correct shape when single image\n    if len(image_path_list) == 1:\n        # Verify shape is (1, C, H, W)\n        if images.dim() == 3:\n            images = images.unsqueeze(0)\n\n    return images\n"
  },
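A minimal sketch of `load_and_preprocess_images`: it writes a throwaway 640x480 test image to a temporary path and runs both preprocessing modes; the shape comments restate the crop/pad behavior documented in the docstring.

```python
import os
import tempfile

from PIL import Image

from mvtracker.models.core.vggt.utils.load_fn import load_and_preprocess_images

tmp_path = os.path.join(tempfile.gettempdir(), "vggt_demo_frame.png")
Image.new("RGB", (640, 480), color=(128, 128, 128)).save(tmp_path)  # dummy 640x480 frame

cropped = load_and_preprocess_images([tmp_path], mode="crop")
padded = load_and_preprocess_images([tmp_path], mode="pad")
print(cropped.shape)  # width forced to 518, height kept divisible by 14: torch.Size([1, 3, 392, 518])
print(padded.shape)   # longest side 518, then padded to a square: torch.Size([1, 3, 518, 518])
```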
  {
    "path": "mvtracker/models/core/vggt/utils/pose_enc.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\nimport torch\nfrom .rotation import quat_to_mat, mat_to_quat\n\n\ndef extri_intri_to_pose_encoding(\n    extrinsics,\n    intrinsics,\n    image_size_hw=None,  # e.g., (256, 512)\n    pose_encoding_type=\"absT_quaR_FoV\",\n):\n    \"\"\"Convert camera extrinsics and intrinsics to a compact pose encoding.\n\n    This function transforms camera parameters into a unified pose encoding format,\n    which can be used for various downstream tasks like pose prediction or representation.\n\n    Args:\n        extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4,\n            where B is batch size and S is sequence length.\n            In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation.\n            The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector.\n        intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3.\n            Defined in pixels, with format:\n            [[fx, 0, cx],\n             [0, fy, cy],\n             [0,  0,  1]]\n            where fx, fy are focal lengths and (cx, cy) is the principal point\n        image_size_hw (tuple): Tuple of (height, width) of the image in pixels.\n            Required for computing field of view values. For example: (256, 512).\n        pose_encoding_type (str): Type of pose encoding to use. Currently only\n            supports \"absT_quaR_FoV\" (absolute translation, quaternion rotation, field of view).\n\n    Returns:\n        torch.Tensor: Encoded camera pose parameters with shape BxSx9.\n            For \"absT_quaR_FoV\" type, the 9 dimensions are:\n            - [:3] = absolute translation vector T (3D)\n            - [3:7] = rotation as quaternion quat (4D)\n            - [7:] = field of view (2D)\n    \"\"\"\n\n    # extrinsics: BxSx3x4\n    # intrinsics: BxSx3x3\n\n    if pose_encoding_type == \"absT_quaR_FoV\":\n        R = extrinsics[:, :, :3, :3]  # BxSx3x3\n        T = extrinsics[:, :, :3, 3]  # BxSx3\n\n        quat = mat_to_quat(R)\n        # Note the order of h and w here\n        H, W = image_size_hw\n        fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1])\n        fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0])\n        pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float()\n    else:\n        raise NotImplementedError\n\n    return pose_encoding\n\n\ndef pose_encoding_to_extri_intri(\n    pose_encoding,\n    image_size_hw=None,  # e.g., (256, 512)\n    pose_encoding_type=\"absT_quaR_FoV\",\n    build_intrinsics=True,\n):\n    \"\"\"Convert a pose encoding back to camera extrinsics and intrinsics.\n\n    This function performs the inverse operation of extri_intri_to_pose_encoding,\n    reconstructing the full camera parameters from the compact encoding.\n\n    Args:\n        pose_encoding (torch.Tensor): Encoded camera pose parameters with shape BxSx9,\n            where B is batch size and S is sequence length.\n            For \"absT_quaR_FoV\" type, the 9 dimensions are:\n            - [:3] = absolute translation vector T (3D)\n            - [3:7] = rotation as quaternion quat (4D)\n            - [7:] = field of view (2D)\n        image_size_hw (tuple): Tuple of (height, width) of the image in pixels.\n            Required for 
reconstructing intrinsics from field of view values.\n            For example: (256, 512).\n        pose_encoding_type (str): Type of pose encoding used. Currently only\n            supports \"absT_quaR_FoV\" (absolute translation, quaternion rotation, field of view).\n        build_intrinsics (bool): Whether to reconstruct the intrinsics matrix.\n            If False, only extrinsics are returned and intrinsics will be None.\n\n    Returns:\n        tuple: (extrinsics, intrinsics)\n            - extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4.\n              In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world\n              transformation. The format is [R|t] where R is a 3x3 rotation matrix and t is\n              a 3x1 translation vector.\n            - intrinsics (torch.Tensor or None): Camera intrinsic parameters with shape BxSx3x3,\n              or None if build_intrinsics is False. Defined in pixels, with format:\n              [[fx, 0, cx],\n               [0, fy, cy],\n               [0,  0,  1]]\n              where fx, fy are focal lengths and (cx, cy) is the principal point,\n              assumed to be at the center of the image (W/2, H/2).\n    \"\"\"\n\n    intrinsics = None\n\n    if pose_encoding_type == \"absT_quaR_FoV\":\n        T = pose_encoding[..., :3]\n        quat = pose_encoding[..., 3:7]\n        fov_h = pose_encoding[..., 7]\n        fov_w = pose_encoding[..., 8]\n\n        R = quat_to_mat(quat)\n        extrinsics = torch.cat([R, T[..., None]], dim=-1)\n\n        if build_intrinsics:\n            H, W = image_size_hw\n            fy = (H / 2.0) / torch.tan(fov_h / 2.0)\n            fx = (W / 2.0) / torch.tan(fov_w / 2.0)\n            intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device)\n            intrinsics[..., 0, 0] = fx\n            intrinsics[..., 1, 1] = fy\n            intrinsics[..., 0, 2] = W / 2\n            intrinsics[..., 1, 2] = H / 2\n            intrinsics[..., 2, 2] = 1.0  # Set the homogeneous coordinate to 1\n    else:\n        raise NotImplementedError\n\n    return extrinsics, intrinsics\n"
  },
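A round-trip sketch of the pose encoding above: identity extrinsics and a centered pinhole intrinsic matrix (made-up values) are encoded into the 9-D `absT_quaR_FoV` vector and decoded back. The reconstruction matches here because the decoder assumes the principal point at the image center.

```python
import torch

from mvtracker.models.core.vggt.utils.pose_enc import (
    extri_intri_to_pose_encoding,
    pose_encoding_to_extri_intri,
)

B, S, H, W = 1, 2, 256, 512
extrinsics = torch.eye(4)[:3][None, None].repeat(B, S, 1, 1)   # (B, S, 3, 4), camera == world
intrinsics = torch.tensor([[300.0, 0.0, W / 2],
                           [0.0, 300.0, H / 2],
                           [0.0, 0.0, 1.0]]).expand(B, S, 3, 3)

pose_enc = extri_intri_to_pose_encoding(extrinsics, intrinsics, image_size_hw=(H, W))
extr_back, intr_back = pose_encoding_to_extri_intri(pose_enc, image_size_hw=(H, W))

print(pose_enc.shape)                                    # torch.Size([1, 2, 9]): T (3) + quat (4) + FoV (2)
print(torch.allclose(extr_back, extrinsics, atol=1e-5))  # True
print(torch.allclose(intr_back, intrinsics, atol=1e-3))  # True, since (cx, cy) really is (W/2, H/2) here
```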
  {
    "path": "mvtracker/models/core/vggt/utils/rotation.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\n# Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d\n\nimport torch\nimport numpy as np\nimport torch.nn.functional as F\n\n\ndef quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Quaternion Order: XYZW or say ijkr, scalar-last\n\n    Convert rotations given as quaternions to rotation matrices.\n    Args:\n        quaternions: quaternions with real part last,\n            as tensor of shape (..., 4).\n\n    Returns:\n        Rotation matrices as tensor of shape (..., 3, 3).\n    \"\"\"\n    i, j, k, r = torch.unbind(quaternions, -1)\n    # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.\n    two_s = 2.0 / (quaternions * quaternions).sum(-1)\n\n    o = torch.stack(\n        (\n            1 - two_s * (j * j + k * k),\n            two_s * (i * j - k * r),\n            two_s * (i * k + j * r),\n            two_s * (i * j + k * r),\n            1 - two_s * (i * i + k * k),\n            two_s * (j * k - i * r),\n            two_s * (i * k - j * r),\n            two_s * (j * k + i * r),\n            1 - two_s * (i * i + j * j),\n        ),\n        -1,\n    )\n    return o.reshape(quaternions.shape[:-1] + (3, 3))\n\n\ndef mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Convert rotations given as rotation matrices to quaternions.\n\n    Args:\n        matrix: Rotation matrices as tensor of shape (..., 3, 3).\n\n    Returns:\n        quaternions with real part last, as tensor of shape (..., 4).\n        Quaternion Order: XYZW or say ijkr, scalar-last\n    \"\"\"\n    if matrix.size(-1) != 3 or matrix.size(-2) != 3:\n        raise ValueError(f\"Invalid rotation matrix shape {matrix.shape}.\")\n\n    batch_dim = matrix.shape[:-2]\n    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)\n\n    q_abs = _sqrt_positive_part(\n        torch.stack(\n            [\n                1.0 + m00 + m11 + m22,\n                1.0 + m00 - m11 - m22,\n                1.0 - m00 + m11 - m22,\n                1.0 - m00 - m11 + m22,\n            ],\n            dim=-1,\n        )\n    )\n\n    # we produce the desired quaternion multiplied by each of r, i, j, k\n    quat_by_rijk = torch.stack(\n        [\n            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and\n            #  `int`.\n            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),\n            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and\n            #  `int`.\n            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),\n            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and\n            #  `int`.\n            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),\n            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and\n            #  `int`.\n            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),\n        ],\n        dim=-2,\n    )\n\n    # We floor here at 0.1 but the exact level is not important; if q_abs is small,\n    # the candidate won't be picked.\n    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)\n    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., 
None].max(flr))\n\n    # if not for numerical problems, quat_candidates[i] should be same (up to a sign),\n    # forall i; we pick the best-conditioned one (with the largest denominator)\n    out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,))\n\n    # Convert from rijk to ijkr\n    out = out[..., [1, 2, 3, 0]]\n\n    out = standardize_quaternion(out)\n\n    return out\n\n\ndef _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Returns torch.sqrt(torch.max(0, x))\n    but with a zero subgradient where x is 0.\n    \"\"\"\n    ret = torch.zeros_like(x)\n    positive_mask = x > 0\n    if torch.is_grad_enabled():\n        ret[positive_mask] = torch.sqrt(x[positive_mask])\n    else:\n        ret = torch.where(positive_mask, torch.sqrt(x), ret)\n    return ret\n\n\ndef standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Convert a unit quaternion to a standard form: one in which the real\n    part is non negative.\n\n    Args:\n        quaternions: Quaternions with real part last,\n            as tensor of shape (..., 4).\n\n    Returns:\n        Standardized quaternions as tensor of shape (..., 4).\n    \"\"\"\n    return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)\n"
  },
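A quick consistency sketch for the scalar-last (x, y, z, w) convention used above: random unit quaternions survive the `quat_to_mat` -> `mat_to_quat` round trip up to the sign standardization.

```python
import torch

from mvtracker.models.core.vggt.utils.rotation import mat_to_quat, quat_to_mat

quat = torch.randn(5, 4)
quat = quat / quat.norm(dim=-1, keepdim=True)  # random unit quaternions, (x, y, z, w)

R = quat_to_mat(quat)                          # (5, 3, 3) rotation matrices
quat_back = mat_to_quat(R)                     # returned with a non-negative real part

# q and -q encode the same rotation, so standardize the input sign before comparing.
quat_std = torch.where(quat[..., 3:4] < 0, -quat, quat)
print(torch.allclose(quat_back, quat_std, atol=1e-4))  # True up to float error
```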
  {
    "path": "mvtracker/models/core/vggt/utils/visual_track.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\nimport cv2\nimport torch\nimport numpy as np\nimport os\n\n\ndef color_from_xy(x, y, W, H, cmap_name=\"hsv\"):\n    \"\"\"\n    Map (x, y) -> color in (R, G, B).\n    1) Normalize x,y to [0,1].\n    2) Combine them into a single scalar c in [0,1].\n    3) Use matplotlib's colormap to convert c -> (R,G,B).\n\n    You can customize step 2, e.g., c = (x + y)/2, or some function of (x, y).\n    \"\"\"\n    import matplotlib.cm\n    import matplotlib.colors\n\n    x_norm = x / max(W - 1, 1)\n    y_norm = y / max(H - 1, 1)\n    # Simple combination:\n    c = (x_norm + y_norm) / 2.0\n\n    cmap = matplotlib.cm.get_cmap(cmap_name)\n    # cmap(c) -> (r,g,b,a) in [0,1]\n    rgba = cmap(c)\n    r, g, b = rgba[0], rgba[1], rgba[2]\n    return (r, g, b)  # in [0,1], RGB order\n\n\ndef get_track_colors_by_position(tracks_b, vis_mask_b=None, image_width=None, image_height=None, cmap_name=\"hsv\"):\n    \"\"\"\n    Given all tracks in one sample (b), compute a (N,3) array of RGB color values\n    in [0,255]. The color is determined by the (x,y) position in the first\n    visible frame for each track.\n\n    Args:\n        tracks_b: Tensor of shape (S, N, 2). (x,y) for each track in each frame.\n        vis_mask_b: (S, N) boolean mask; if None, assume all are visible.\n        image_width, image_height: used for normalizing (x, y).\n        cmap_name: for matplotlib (e.g., 'hsv', 'rainbow', 'jet').\n\n    Returns:\n        track_colors: np.ndarray of shape (N, 3), each row is (R,G,B) in [0,255].\n    \"\"\"\n    S, N, _ = tracks_b.shape\n    track_colors = np.zeros((N, 3), dtype=np.uint8)\n\n    if vis_mask_b is None:\n        # treat all as visible\n        vis_mask_b = torch.ones(S, N, dtype=torch.bool, device=tracks_b.device)\n\n    for i in range(N):\n        # Find first visible frame for track i\n        visible_frames = torch.where(vis_mask_b[:, i])[0]\n        if len(visible_frames) == 0:\n            # track is never visible; just assign black or something\n            track_colors[i] = (0, 0, 0)\n            continue\n\n        first_s = int(visible_frames[0].item())\n        # use that frame's (x,y)\n        x, y = tracks_b[first_s, i].tolist()\n\n        # map (x,y) -> (R,G,B) in [0,1]\n        r, g, b = color_from_xy(x, y, W=image_width, H=image_height, cmap_name=cmap_name)\n        # scale to [0,255]\n        r, g, b = int(r * 255), int(g * 255), int(b * 255)\n        track_colors[i] = (r, g, b)\n\n    return track_colors\n\n\ndef visualize_tracks_on_images(\n    images,\n    tracks,\n    track_vis_mask=None,\n    out_dir=\"track_visuals_concat_by_xy\",\n    image_format=\"CHW\",  # \"CHW\" or \"HWC\"\n    normalize_mode=\"[0,1]\",\n    cmap_name=\"hsv\",  # e.g. 
\"hsv\", \"rainbow\", \"jet\"\n    frames_per_row=4,  # New parameter for grid layout\n    save_grid=True,  # Flag to control whether to save the grid image\n):\n    \"\"\"\n    Visualizes frames in a grid layout with specified frames per row.\n    Each track's color is determined by its (x,y) position\n    in the first visible frame (or frame 0 if always visible).\n    Finally convert the BGR result to RGB before saving.\n    Also saves each individual frame as a separate PNG file.\n\n    Args:\n        images: torch.Tensor (S, 3, H, W) if CHW or (S, H, W, 3) if HWC.\n        tracks: torch.Tensor (S, N, 2), last dim = (x, y).\n        track_vis_mask: torch.Tensor (S, N) or None.\n        out_dir: folder to save visualizations.\n        image_format: \"CHW\" or \"HWC\".\n        normalize_mode: \"[0,1]\", \"[-1,1]\", or None for direct raw -> 0..255\n        cmap_name: a matplotlib colormap name for color_from_xy.\n        frames_per_row: number of frames to display in each row of the grid.\n        save_grid: whether to save all frames in one grid image.\n\n    Returns:\n        None (saves images in out_dir).\n    \"\"\"\n\n    if len(tracks.shape) == 4:\n        tracks = tracks.squeeze(0)\n        images = images.squeeze(0)\n        if track_vis_mask is not None:\n            track_vis_mask = track_vis_mask.squeeze(0)\n\n    import matplotlib\n\n    matplotlib.use(\"Agg\")  # for non-interactive (optional)\n\n    os.makedirs(out_dir, exist_ok=True)\n\n    S = images.shape[0]\n    _, N, _ = tracks.shape  # (S, N, 2)\n\n    # Move to CPU\n    images = images.cpu().clone()\n    tracks = tracks.cpu().clone()\n    if track_vis_mask is not None:\n        track_vis_mask = track_vis_mask.cpu().clone()\n\n    # Infer H, W from images shape\n    if image_format == \"CHW\":\n        # e.g. images[s].shape = (3, H, W)\n        H, W = images.shape[2], images.shape[3]\n    else:\n        # e.g. 
images[s].shape = (H, W, 3)\n        H, W = images.shape[1], images.shape[2]\n\n    # Pre-compute the color for each track i based on first visible position\n    track_colors_rgb = get_track_colors_by_position(\n        tracks,  # shape (S, N, 2)\n        vis_mask_b=track_vis_mask if track_vis_mask is not None else None,\n        image_width=W,\n        image_height=H,\n        cmap_name=cmap_name,\n    )\n\n    # We'll accumulate each frame's drawn image in a list\n    frame_images = []\n\n    for s in range(S):\n        # shape => either (3, H, W) or (H, W, 3)\n        img = images[s]\n\n        # Convert to (H, W, 3)\n        if image_format == \"CHW\":\n            img = img.permute(1, 2, 0)  # (H, W, 3)\n        # else \"HWC\", do nothing\n\n        img = img.numpy().astype(np.float32)\n\n        # Scale to [0,255] if needed\n        if normalize_mode == \"[0,1]\":\n            img = np.clip(img, 0, 1) * 255.0\n        elif normalize_mode == \"[-1,1]\":\n            img = (img + 1.0) * 0.5 * 255.0\n            img = np.clip(img, 0, 255.0)\n        # else no normalization\n\n        # Convert to uint8\n        img = img.astype(np.uint8)\n\n        # For drawing in OpenCV, convert to BGR\n        img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)\n\n        # Draw each visible track\n        cur_tracks = tracks[s]  # shape (N, 2)\n        if track_vis_mask is not None:\n            valid_indices = torch.where(track_vis_mask[s])[0]\n        else:\n            valid_indices = range(N)\n\n        cur_tracks_np = cur_tracks.numpy()\n        for i in valid_indices:\n            x, y = cur_tracks_np[i]\n            pt = (int(round(x)), int(round(y)))\n\n            # track_colors_rgb[i] is (R,G,B). For OpenCV circle, we need BGR\n            R, G, B = track_colors_rgb[i]\n            color_bgr = (int(B), int(G), int(R))\n            cv2.circle(img_bgr, pt, radius=3, color=color_bgr, thickness=-1)\n\n        # Convert back to RGB for consistent final saving:\n        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)\n\n        # Save individual frame\n        frame_path = os.path.join(out_dir, f\"frame_{s:04d}.png\")\n        # Convert to BGR for OpenCV imwrite\n        frame_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)\n        cv2.imwrite(frame_path, frame_bgr)\n\n        frame_images.append(img_rgb)\n\n    # Only create and save the grid image if save_grid is True\n    if save_grid:\n        # Calculate grid dimensions\n        num_rows = (S + frames_per_row - 1) // frames_per_row  # Ceiling division\n\n        # Create a grid of images\n        grid_img = None\n        for row in range(num_rows):\n            start_idx = row * frames_per_row\n            end_idx = min(start_idx + frames_per_row, S)\n\n            # Concatenate this row horizontally\n            row_img = np.concatenate(frame_images[start_idx:end_idx], axis=1)\n\n            # If this row has fewer than frames_per_row images, pad with black\n            if end_idx - start_idx < frames_per_row:\n                padding_width = (frames_per_row - (end_idx - start_idx)) * W\n                padding = np.zeros((H, padding_width, 3), dtype=np.uint8)\n                row_img = np.concatenate([row_img, padding], axis=1)\n\n            # Add this row to the grid\n            if grid_img is None:\n                grid_img = row_img\n            else:\n                grid_img = np.concatenate([grid_img, row_img], axis=0)\n\n        out_path = os.path.join(out_dir, \"tracks_grid.png\")\n        # Convert back to BGR for OpenCV 
imwrite\n        grid_img_bgr = cv2.cvtColor(grid_img, cv2.COLOR_RGB2BGR)\n        cv2.imwrite(out_path, grid_img_bgr)\n        print(f\"[INFO] Saved color-by-XY track visualization grid -> {out_path}\")\n\n    print(f\"[INFO] Saved {S} individual frames to {out_dir}/frame_*.png\")\n"
  },
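A hedged sketch of the visualizer above: it draws three synthetic tracks drifting across random frames and writes the per-frame PNGs plus the grid image into a temporary folder (the tensor values are made up purely for illustration).

```python
import tempfile

import torch

from mvtracker.models.core.vggt.utils.visual_track import visualize_tracks_on_images

S, N, H, W = 4, 3, 120, 160
images = torch.rand(S, 3, H, W)                                 # (S, 3, H, W) frames in [0, 1]
start_xy = torch.tensor([[10.0, 10.0], [40.0, 20.0], [70.0, 50.0]])
tracks = start_xy[None] + torch.arange(S)[:, None, None] * 5.0  # (S, N, 2), each track drifts diagonally
visibility = torch.ones(S, N, dtype=torch.bool)

out_dir = tempfile.mkdtemp(prefix="track_vis_")
visualize_tracks_on_images(images, tracks, visibility, out_dir=out_dir)
print("Wrote frame_*.png and tracks_grid.png to", out_dir)
```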
  {
    "path": "mvtracker/models/core/vit/__init__.py",
    "content": ""
  },
  {
    "path": "mvtracker/models/core/vit/common.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\nfrom typing import Type\n\nimport torch\nimport torch.nn as nn\n\n\nclass MLPBlock(nn.Module):\n    def __init__(\n            self,\n            embedding_dim: int,\n            mlp_dim: int,\n            act: Type[nn.Module] = nn.GELU,\n    ) -> None:\n        super().__init__()\n        self.lin1 = nn.Linear(embedding_dim, mlp_dim)\n        self.lin2 = nn.Linear(mlp_dim, embedding_dim)\n        self.act = act()\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        return self.lin2(self.act(self.lin1(x)))\n\n\n# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa\n# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119  # noqa\nclass LayerNorm2d(nn.Module):\n    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(num_channels))\n        self.bias = nn.Parameter(torch.zeros(num_channels))\n        self.eps = eps\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        u = x.mean(1, keepdim=True)\n        s = (x - u).pow(2).mean(1, keepdim=True)\n        x = (x - u) / torch.sqrt(s + self.eps)\n        x = self.weight[:, None, None] * x + self.bias[:, None, None]\n        return x\n"
  },
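A tiny sketch of the two building blocks above: `LayerNorm2d` normalizes each spatial location across the channels of an NCHW tensor, and `MLPBlock` keeps the token dimension unchanged.

```python
import torch

from mvtracker.models.core.vit.common import LayerNorm2d, MLPBlock

x = torch.randn(2, 8, 16, 16)               # (N, C, H, W)
y = LayerNorm2d(num_channels=8)(x)
print(y.mean(dim=1).abs().max())            # ~0: per-location mean over channels
print(y.var(dim=1, unbiased=False).mean())  # ~1: per-location variance over channels

mlp = MLPBlock(embedding_dim=32, mlp_dim=64)
tokens = torch.randn(2, 10, 32)             # (batch, tokens, dim)
print(mlp(tokens).shape)                    # torch.Size([2, 10, 32])
```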
  {
    "path": "mvtracker/models/core/vit/encoder.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\nfrom typing import Optional, Tuple, Type\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom mvtracker.models.core.vit.common import (\n    LayerNorm2d, MLPBlock\n)\n\n\n# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa\nclass ImageEncoderViT(nn.Module):\n    def __init__(\n            self,\n            img_size: int = 1024,\n            patch_size: int = 16,\n            in_chans: int = 3,\n            embed_dim: int = 768,\n            depth: int = 12,\n            num_heads: int = 12,\n            mlp_ratio: float = 4.0,\n            out_chans: int = 256,\n            qkv_bias: bool = True,\n            norm_layer: Type[nn.Module] = nn.LayerNorm,\n            act_layer: Type[nn.Module] = nn.GELU,\n            use_abs_pos: bool = True,\n            use_rel_pos: bool = False,\n            rel_pos_zero_init: bool = True,\n            window_size: int = 0,\n            global_attn_indexes: Tuple[int, ...] = (),\n    ) -> None:\n        \"\"\"\n        Args:\n            img_size (int): Input image size.\n            patch_size (int): Patch size.\n            in_chans (int): Number of input image channels.\n            embed_dim (int): Patch embedding dimension.\n            depth (int): Depth of ViT.\n            num_heads (int): Number of attention heads in each ViT block.\n            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n            qkv_bias (bool): If True, add a learnable bias to query, key, value.\n            norm_layer (nn.Module): Normalization layer.\n            act_layer (nn.Module): Activation layer.\n            use_abs_pos (bool): If True, use absolute positional embeddings.\n            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.\n            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.\n            window_size (int): Window size for window attention blocks.\n            global_attn_indexes (list): Indexes for blocks using global attention.\n        \"\"\"\n        super().__init__()\n        self.img_size = img_size\n\n        self.patch_embed = PatchEmbed(\n            kernel_size=(patch_size, patch_size),\n            stride=(patch_size, patch_size),\n            in_chans=in_chans,\n            embed_dim=embed_dim,\n        )\n\n        self.pos_embed: Optional[nn.Parameter] = None\n        if use_abs_pos:\n            # Initialize absolute positional embedding with pretrain image size.\n            self.pos_embed = nn.Parameter(\n                torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)\n            )\n\n        self.blocks = nn.ModuleList()\n        for i in range(depth):\n            block = Block(\n                dim=embed_dim,\n                num_heads=num_heads,\n                mlp_ratio=mlp_ratio,\n                qkv_bias=qkv_bias,\n                norm_layer=norm_layer,\n                act_layer=act_layer,\n                use_rel_pos=use_rel_pos,\n                rel_pos_zero_init=rel_pos_zero_init,\n                window_size=window_size if i not in global_attn_indexes else 0,\n                input_size=(img_size // patch_size, 
img_size // patch_size),\n            )\n            self.blocks.append(block)\n\n        self.neck = nn.Sequential(\n            nn.Conv2d(\n                embed_dim,\n                out_chans,\n                kernel_size=1,\n                bias=False,\n            ),\n            LayerNorm2d(out_chans),\n            nn.Conv2d(\n                out_chans,\n                out_chans,\n                kernel_size=3,\n                padding=1,\n                bias=False,\n            ),\n            LayerNorm2d(out_chans),\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n\n        x = self.patch_embed(x)\n        if self.pos_embed is not None:\n            x = x + self.pos_embed\n\n        for blk in self.blocks:\n            x = blk(x)\n\n        x = self.neck(x.permute(0, 3, 1, 2))\n\n        return x\n\n\nclass Block(nn.Module):\n    \"\"\"Transformer blocks with support of window attention and residual propagation blocks\"\"\"\n\n    def __init__(\n            self,\n            dim: int,\n            num_heads: int,\n            mlp_ratio: float = 4.0,\n            qkv_bias: bool = True,\n            norm_layer: Type[nn.Module] = nn.LayerNorm,\n            act_layer: Type[nn.Module] = nn.GELU,\n            use_rel_pos: bool = False,\n            rel_pos_zero_init: bool = True,\n            window_size: int = 0,\n            input_size: Optional[Tuple[int, int]] = None,\n    ) -> None:\n        \"\"\"\n        Args:\n            dim (int): Number of input channels.\n            num_heads (int): Number of attention heads in each ViT block.\n            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n            qkv_bias (bool): If True, add a learnable bias to query, key, value.\n            norm_layer (nn.Module): Normalization layer.\n            act_layer (nn.Module): Activation layer.\n            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.\n            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.\n            window_size (int): Window size for window attention blocks. 
If it equals 0, then\n                use global attention.\n            input_size (tuple(int, int) or None): Input resolution for calculating the relative\n                positional parameter size.\n        \"\"\"\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = Attention(\n            dim,\n            num_heads=num_heads,\n            qkv_bias=qkv_bias,\n            use_rel_pos=use_rel_pos,\n            rel_pos_zero_init=rel_pos_zero_init,\n            input_size=input_size if window_size == 0 else (window_size, window_size),\n        )\n\n        self.norm2 = norm_layer(dim)\n        self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)\n\n        self.window_size = window_size\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        shortcut = x\n        x = self.norm1(x)\n        # Window partition\n        if self.window_size > 0:\n            H, W = x.shape[1], x.shape[2]\n            x, pad_hw = window_partition(x, self.window_size)\n\n        x = self.attn(x)\n        # Reverse window partition\n        if self.window_size > 0:\n            x = window_unpartition(x, self.window_size, pad_hw, (H, W))\n\n        x = shortcut + x\n        x = x + self.mlp(self.norm2(x))\n\n        return x\n\n\nclass Attention(nn.Module):\n    \"\"\"Multi-head Attention block with relative position embeddings.\"\"\"\n\n    def __init__(\n            self,\n            dim: int,\n            num_heads: int = 8,\n            qkv_bias: bool = True,\n            use_rel_pos: bool = False,\n            rel_pos_zero_init: bool = True,\n            input_size: Optional[Tuple[int, int]] = None,\n    ) -> None:\n        \"\"\"\n        Args:\n            dim (int): Number of input channels.\n            num_heads (int): Number of attention heads.\n            qkv_bias (bool):  If True, add a learnable bias to query, key, value.\n            rel_pos (bool): If True, add relative positional embeddings to the attention map.\n            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.\n            input_size (tuple(int, int) or None): Input resolution for calculating the relative\n                positional parameter size.\n        \"\"\"\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = head_dim ** -0.5\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.proj = nn.Linear(dim, dim)\n\n        self.use_rel_pos = use_rel_pos\n        if self.use_rel_pos:\n            assert (\n                    input_size is not None\n            ), \"Input size must be provided if using relative positional encoding.\"\n            # initialize relative positional embeddings\n            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))\n            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        B, H, W, _ = x.shape\n        # qkv with shape (3, B, nHead, H * W, C)\n        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)\n        # q, k, v with shape (B * nHead, H * W, C)\n        q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)\n\n        attn = (q * self.scale) @ k.transpose(-2, -1)\n\n        if self.use_rel_pos:\n            attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))\n\n        attn = 
attn.softmax(dim=-1)\n        x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)\n        x = self.proj(x)\n\n        return x\n\n\ndef window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:\n    \"\"\"\n    Partition into non-overlapping windows with padding if needed.\n    Args:\n        x (tensor): input tokens with [B, H, W, C].\n        window_size (int): window size.\n\n    Returns:\n        windows: windows after partition with [B * num_windows, window_size, window_size, C].\n        (Hp, Wp): padded height and width before partition\n    \"\"\"\n    B, H, W, C = x.shape\n\n    pad_h = (window_size - H % window_size) % window_size\n    pad_w = (window_size - W % window_size) % window_size\n    if pad_h > 0 or pad_w > 0:\n        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))\n    Hp, Wp = H + pad_h, W + pad_w\n\n    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)\n    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)\n    return windows, (Hp, Wp)\n\n\ndef window_unpartition(\n        windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]\n) -> torch.Tensor:\n    \"\"\"\n    Window unpartition into original sequences and removing padding.\n    Args:\n        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].\n        window_size (int): window size.\n        pad_hw (Tuple): padded height and width (Hp, Wp).\n        hw (Tuple): original height and width (H, W) before padding.\n\n    Returns:\n        x: unpartitioned sequences with [B, H, W, C].\n    \"\"\"\n    Hp, Wp = pad_hw\n    H, W = hw\n    B = windows.shape[0] // (Hp * Wp // window_size // window_size)\n    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)\n    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)\n\n    if Hp > H or Wp > W:\n        x = x[:, :H, :W, :].contiguous()\n    return x\n\n\ndef get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Get relative positional embeddings according to the relative positions of\n        query and key sizes.\n    Args:\n        q_size (int): size of query q.\n        k_size (int): size of key k.\n        rel_pos (Tensor): relative position embeddings (L, C).\n\n    Returns:\n        Extracted positional embeddings according to relative positions.\n    \"\"\"\n    max_rel_dist = int(2 * max(q_size, k_size) - 1)\n    # Interpolate rel pos if needed.\n    if rel_pos.shape[0] != max_rel_dist:\n        # Interpolate rel pos.\n        rel_pos_resized = F.interpolate(\n            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),\n            size=max_rel_dist,\n            mode=\"linear\",\n        )\n        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)\n    else:\n        rel_pos_resized = rel_pos\n\n    # Scale the coords with short length if shapes for q and k are different.\n    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)\n    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)\n    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)\n\n    return rel_pos_resized[relative_coords.long()]\n\n\ndef add_decomposed_rel_pos(\n        attn: torch.Tensor,\n        q: torch.Tensor,\n        rel_pos_h: torch.Tensor,\n        rel_pos_w: torch.Tensor,\n        q_size: Tuple[int, 
int],\n        k_size: Tuple[int, int],\n) -> torch.Tensor:\n    \"\"\"\n    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.\n    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py   # noqa B950\n    Args:\n        attn (Tensor): attention map.\n        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).\n        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.\n        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.\n        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).\n        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).\n\n    Returns:\n        attn (Tensor): attention map with added relative positional embeddings.\n    \"\"\"\n    q_h, q_w = q_size\n    k_h, k_w = k_size\n    Rh = get_rel_pos(q_h, k_h, rel_pos_h)\n    Rw = get_rel_pos(q_w, k_w, rel_pos_w)\n\n    B, _, dim = q.shape\n    r_q = q.reshape(B, q_h, q_w, dim)\n    rel_h = torch.einsum(\"bhwc,hkc->bhwk\", r_q, Rh)\n    rel_w = torch.einsum(\"bhwc,wkc->bhwk\", r_q, Rw)\n\n    attn = (\n            attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]\n    ).view(B, q_h * q_w, k_h * k_w)\n\n    return attn\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\"\n    Image to Patch Embedding.\n    \"\"\"\n\n    def __init__(\n            self,\n            kernel_size: Tuple[int, int] = (16, 16),\n            stride: Tuple[int, int] = (16, 16),\n            padding: Tuple[int, int] = (0, 0),\n            in_chans: int = 3,\n            embed_dim: int = 768,\n    ) -> None:\n        \"\"\"\n        Args:\n            kernel_size (Tuple): kernel size of the projection layer.\n            stride (Tuple): stride of the projection layer.\n            padding (Tuple): padding size of the projection layer.\n            in_chans (int): Number of input image channels.\n            embed_dim (int): Patch embedding dimension.\n        \"\"\"\n        super().__init__()\n\n        self.proj = nn.Conv2d(\n            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.proj(x)\n        # B C H W -> B H W C\n        x = x.permute(0, 2, 3, 1)\n        return x\n"
  },
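A hedged smoke-test sketch for `ImageEncoderViT`: the hyperparameters below are deliberately tiny (not the defaults above) so the forward pass is cheap; it only illustrates the NCHW-in, (B, out_chans, img_size/patch_size, img_size/patch_size)-out contract of the backbone.

```python
import torch

from mvtracker.models.core.vit.encoder import ImageEncoderViT

encoder = ImageEncoderViT(
    img_size=64,               # tiny settings purely for a smoke test
    patch_size=16,
    embed_dim=32,
    depth=2,
    num_heads=2,
    out_chans=16,
    window_size=2,             # block 0 uses windowed attention ...
    global_attn_indexes=(1,),  # ... block 1 attends globally
)

x = torch.randn(1, 3, 64, 64)
features = encoder(x)
print(features.shape)  # torch.Size([1, 16, 4, 4])
```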
  {
    "path": "mvtracker/models/evaluation_predictor_3dpt.py",
    "content": "import os\nimport random\nfrom typing import Optional, Tuple\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom tqdm import tqdm\n\nfrom mvtracker.models.core.model_utils import bilinear_sample2d, get_points_on_a_grid\nfrom mvtracker.models.core.model_utils import world_space_to_pixel_xy_and_camera_z\nfrom mvtracker.models.core.mvtracker.mvtracker import save_pointcloud_to_ply\nfrom mvtracker.utils.basic import to_homogeneous, from_homogeneous, time_now\nfrom mvtracker.utils.visualizer_mp4 import MultiViewVisualizer\n\n\nclass EvaluationPredictor(torch.nn.Module):\n    def __init__(\n            self,\n            multiview_model: torch.nn.Module,\n            interp_shape: Optional[Tuple[int, int]] = (384, 512),\n            visibility_threshold=0.5,\n            grid_size: int = 5,\n            n_grids_per_view: int = 1,\n            local_grid_size: int = 8,\n            local_extent: int = 50,\n            single_point: bool = False,\n            sift_size: int = 0,\n            num_uniformly_sampled_pts: int = 0,\n            n_iters: int = 6,\n    ) -> None:\n        super(EvaluationPredictor, self).__init__()\n        self.model = multiview_model\n        self.interp_shape = interp_shape\n        self.visibility_threshold = visibility_threshold\n        self.grid_size = grid_size\n        self.n_grids_per_view = n_grids_per_view\n        self.local_grid_size = local_grid_size\n        self.local_extent = local_extent\n        self.single_point = single_point\n        self.sift_size = sift_size\n        self.num_uniformly_sampled_pts = num_uniformly_sampled_pts\n        self.n_iters = n_iters\n\n        self.model.eval()\n\n    def forward(\n            self,\n            rgbs,\n            depths,\n            query_points_3d,\n            intrs,\n            extrs,\n            save_debug_logs=False,\n            debug_logs_path=\"\",\n            query_points_view=None,\n            **kwargs,\n    ):\n        batch_size, num_views, num_frames, _, height_raw, width_raw = rgbs.shape\n        _, num_points, _ = query_points_3d.shape\n\n        assert rgbs.shape == (batch_size, num_views, num_frames, 3, height_raw, width_raw)\n        assert depths.shape == (batch_size, num_views, num_frames, 1, height_raw, width_raw)\n        assert query_points_3d.shape == (batch_size, num_points, 4)\n        assert intrs.shape == (batch_size, num_views, num_frames, 3, 3)\n        assert extrs.shape == (batch_size, num_views, num_frames, 3, 4)\n\n        if batch_size != 1:\n            raise NotImplementedError\n\n        # Interpolate the inputs to the desired resolution, if needed\n        if self.interp_shape is None:\n            height, width = height_raw, width_raw\n        else:\n            height, width = self.interp_shape\n            rgbs = rgbs.reshape(-1, 3, height_raw, width_raw)\n            rgbs = F.interpolate(rgbs, (height, width), mode=\"nearest\")\n            rgbs = rgbs.reshape(batch_size, num_views, num_frames, 3, height, width)\n            depths = depths.reshape(-1, 1, height_raw, width_raw)\n            depths = F.interpolate(depths, (height, width), mode=\"nearest\")\n            depths = depths.reshape(batch_size, num_views, num_frames, 1, height, width)\n            intrs_resize_transform = torch.tensor([\n                [width / width_raw, 0, 0],\n                [0, height / height_raw, 0],\n                [0, 0, 1],\n            ], device=intrs.device, dtype=intrs.dtype)\n            intrs = torch.einsum(\"ij,BVTjk->BVTik\", 
intrs_resize_transform, intrs)\n\n        # Unpack the query points\n        query_points_t = query_points_3d[:, :, :1].long()\n        query_points_xyz_worldspace = query_points_3d[:, :, 1:]\n\n        # Invert intrinsics and extrinsics\n        intrs_inv = torch.inverse(intrs.float()).type(intrs.dtype)\n        extrs_square = torch.eye(4).to(extrs.device)[None].repeat(batch_size, num_views, num_frames, 1, 1)\n        extrs_square[:, :, :, :3, :] = extrs\n        extrs_inv = torch.inverse(extrs_square.float()).type(extrs.dtype)\n\n        support_points = torch.zeros((batch_size, 0, 4), device=rgbs.device)\n\n        grid_points = []\n        if self.grid_size > 0:\n            pixel_xy = get_points_on_a_grid(self.grid_size, (height, width), device=rgbs.device)\n            pixel_xy_homo = to_homogeneous(pixel_xy)\n            for t in range(0, num_frames, max(1, num_frames // self.n_grids_per_view)):\n                for view_idx in range(num_views):\n                    camera_z = bilinear_sample2d(\n                        depths[0, view_idx, t][None],\n                        pixel_xy[..., 0],\n                        pixel_xy[..., 1],\n                    ).permute(0, 2, 1)\n                    camera_xyz = torch.einsum('Bij,BNj->BNi', intrs_inv[:, view_idx, t, :, :], pixel_xy_homo)\n                    camera_xyz = camera_xyz * camera_z\n                    camera_xyz_homo = to_homogeneous(camera_xyz)\n                    world_xyz_homo = torch.einsum('Bij,BNj->BNi', extrs_inv[:, view_idx, t, :, :], camera_xyz_homo)\n                    world_xyz = from_homogeneous(world_xyz_homo)\n                    grid_points_i = torch.cat([torch.ones_like(world_xyz[:, :, :1]) * t, world_xyz], dim=2)\n                    grid_points.append(grid_points_i)\n            grid_points = torch.cat(grid_points, dim=1)\n            support_points = torch.concat([support_points, grid_points], dim=1)\n\n            if save_debug_logs:\n                os.makedirs(debug_logs_path, exist_ok=True)\n                save_pointcloud_to_ply(\n                    filename=os.path.join(debug_logs_path, time_now() + \"__predictor__query_points.ply\"),\n                    points=query_points_xyz_worldspace[0].cpu().numpy(),\n                    colors=np.ones_like(query_points_xyz_worldspace[0].cpu().numpy(), dtype=int) * np.array(\n                        [255, 30, 60]),\n                )\n                save_pointcloud_to_ply(\n                    filename=os.path.join(debug_logs_path, time_now() + \"__predictor__support_grid_points.ply\"),\n                    points=grid_points[0, :, 1:].cpu().numpy(),\n                    colors=np.ones_like(grid_points[0, :, 1:].cpu().numpy(), dtype=int) * np.array([45, 255, 60]),\n                )\n\n        sift_points = []\n        if self.sift_size > 0:\n            raise NotImplementedError\n            # xy = get_sift_sampled_pts(video, sift_size, T, [H, W], device=device)\n            # if xy.shape[1] == sift_size:\n            #     queries = torch.cat([queries, xy], dim=1)  #\n            # else:\n            #     sift_size = 0\n            sift_points = torch.cat(sift_points, dim=1)\n            support_points = torch.concat([support_points, sift_points], dim=1)\n\n        support_uniform_pts = []\n        if self.num_uniformly_sampled_pts > 0:\n            sampled_pts = get_uniformly_sampled_pts(\n                self.num_uniformly_sampled_pts,\n                num_frames,\n                (height, width),\n                device=rgbs.device,\n            )[0]  
# shape: (N, 3) where each row is (t, y, x)\n\n            t_samples = sampled_pts[:, 0].long()\n            y_samples = sampled_pts[:, 1].float()\n            x_samples = sampled_pts[:, 2].float()\n\n            pixel_xy = torch.stack([x_samples, y_samples], dim=-1)[None]  # (1, N, 2)\n            pixel_xy_homo = to_homogeneous(pixel_xy)\n\n            for idx in range(sampled_pts.shape[0]):\n                t = t_samples[idx].item()\n                x = x_samples[idx].item()\n                y = y_samples[idx].item()\n\n                for view_idx in range(num_views):\n                    depth_val = bilinear_sample2d(\n                        depths[0, view_idx, t][None],  # shape (1, 1, H, W)\n                        torch.tensor([[x]], device=rgbs.device),\n                        torch.tensor([[y]], device=rgbs.device),\n                    ).item()\n\n                    cam_xy_h = torch.tensor([[x, y, 1.0]], device=rgbs.device).T\n                    K_inv = intrs_inv[0, view_idx, t]\n                    extr_inv = extrs_inv[0, view_idx, t]\n\n                    cam_xyz = (K_inv @ cam_xy_h).squeeze() * depth_val\n                    cam_xyz_h = to_homogeneous(cam_xyz[None])[0]\n                    world_xyz_h = extr_inv @ cam_xyz_h\n                    world_xyz = from_homogeneous(world_xyz_h[None])[0]\n\n                    support_point = torch.cat([torch.tensor([t], device=rgbs.device), world_xyz])\n                    support_uniform_pts.append(support_point)\n\n            if support_uniform_pts:\n                support_uniform_pts = torch.stack(support_uniform_pts, dim=0)[None]  # (1, N, 4)\n                support_points = torch.cat([support_points, support_uniform_pts], dim=1)\n\n        if self.single_point:\n            # Project the queries to each view\n            # This will be needed if adding local grid points\n            query_points_xyz_worldspace_homo = to_homogeneous(query_points_xyz_worldspace)\n            query_points_perview_camera_xyz = torch.einsum('BVTij,BNj->BVTNi', extrs, query_points_xyz_worldspace_homo)\n            query_points_perview_pixel_xy_homo = torch.einsum('BVTij,BVTNj->BVTNi', intrs,\n                                                              query_points_perview_camera_xyz)\n            query_points_perview_pixel_xy = from_homogeneous(query_points_perview_pixel_xy_homo)\n            query_points_perview_camera_xyz = query_points_perview_camera_xyz[\n                # Extract at the correct per-query timestep\n                torch.arange(batch_size)[:, None, None],\n                torch.arange(num_views)[None, :, None],\n                query_points_t[:, None, :, 0],\n                torch.arange(num_points)[None, None, :],\n            ]\n            query_points_perview_pixel_xy = query_points_perview_pixel_xy[  # Extract at the correct per-query timestep\n                torch.arange(batch_size)[:, None, None],\n                torch.arange(num_views)[None, :, None],\n                query_points_t[:, None, :, 0],\n                torch.arange(num_points)[None, None, :],\n            ]\n            query_points_perview_camera_z = query_points_perview_camera_xyz[..., -1:]\n\n            traj_e = torch.zeros((batch_size, num_frames, num_points, 3), device=rgbs.device)\n            vis_e = torch.zeros((batch_size, num_frames, num_points), device=rgbs.device)\n            for point_idx in tqdm(range(num_points), desc=\"Single point evaluation\"):\n                # Support points for this query point\n                support_points_i 
= torch.zeros((batch_size, 0, 4), device=rgbs.device)\n\n                # Add the local support points\n                if self.local_grid_size > 0:\n                    t = query_points_t[0, point_idx, 0].item()\n                    local_grid_points = torch.zeros((batch_size, 0, 4), device=rgbs.device)\n                    for view_idx in range(num_views):\n                        pixel_xy = get_points_on_a_grid(\n                            size=self.local_grid_size,\n                            extent=(self.local_extent, self.local_extent),\n                            center=(query_points_perview_pixel_xy[0, view_idx, point_idx, 1].item(),\n                                    query_points_perview_pixel_xy[0, view_idx, point_idx, 0].item()),\n                            device=rgbs.device,\n                        )\n                        inside_frame = ((pixel_xy[0, :, 0] >= 0)\n                                        & (pixel_xy[0, :, 0] < width)\n                                        & (pixel_xy[0, :, 1] >= 0)\n                                        & (pixel_xy[0, :, 1] < height))\n                        if not inside_frame.any():\n                            continue\n                        pixel_xy = pixel_xy[:, inside_frame, :]\n                        pixel_xy_homo = to_homogeneous(pixel_xy)\n                        camera_z = bilinear_sample2d(\n                            depths[0, view_idx, t][None],\n                            pixel_xy[..., 0],\n                            pixel_xy[..., 1],\n                        ).permute(0, 2, 1)\n                        camera_xyz = torch.einsum('Bij,BNj->BNi', intrs_inv[:, view_idx, t, :, :], pixel_xy_homo)\n                        camera_xyz = camera_xyz * camera_z\n                        camera_xyz_homo = to_homogeneous(camera_xyz)\n                        world_xyz_homo = torch.einsum('Bij,BNj->BNi', extrs_inv[:, view_idx, t, :, :], camera_xyz_homo)\n                        world_xyz = from_homogeneous(world_xyz_homo)\n                        local_grid_points_i = torch.cat([torch.ones_like(world_xyz[:, :, :1]) * t, world_xyz], dim=2)\n                        local_grid_points = torch.cat([local_grid_points, local_grid_points_i], dim=1)\n                    support_points_i = torch.cat([support_points_i, local_grid_points], dim=1)\n\n                # Add the global support points\n                support_points_i = torch.cat([support_points_i, support_points], dim=1)\n\n                # Forward pass for this query point\n                query_points_i = torch.cat([query_points_3d[:, point_idx: point_idx + 1, :], support_points_i], dim=1)\n                if query_points_view is not None:\n                    query_points_view = torch.cat([\n                        query_points_view[:, point_idx: point_idx + 1],\n                        query_points_view.new_zeros(support_points_i[:, :, 0].shape),\n                    ], dim=1)\n                results_i = self.model(\n                    rgbs,\n                    depths=depths,\n                    query_points=query_points_i,\n                    intrs=intrs,\n                    extrs=extrs,\n                    iters=self.n_iters,\n                    save_debug_logs=save_debug_logs and point_idx == 0,\n                    debug_logs_path=debug_logs_path,\n                    query_points_view=query_points_view,\n                    **kwargs,\n                )\n                traj_e[:, :, point_idx: point_idx + 1] = results_i[\"traj_e\"][:, :, :1]\n                
vis_e[:, :, point_idx: point_idx + 1] = results_i[\"vis_e\"][:, :, :1]\n\n                if save_debug_logs and (point_idx in [0, 1, 2, 3, 4] or point_idx % 100 == 0):\n                    visualizer = MultiViewVisualizer(\n                        save_dir=debug_logs_path,\n                        pad_value=16,\n                        fps=12,\n                        show_first_frame=0,\n                        tracks_leave_trace=0,\n                    )\n\n                    # filename, pred_trajectories, pred_visibilities, qps\n                    tuples_to_process = []\n                    tuples_to_process += [(\n                        f\"predictor__pidx={point_idx}__viz_A_pred\",\n                        results_i[\"traj_e\"][:, :, :1],\n                        results_i[\"vis_e\"][:, :, :1],\n                        query_points_i[:, :1, :],\n                    )]\n                    tuples_to_process += [(\n                        f\"predictor__pidx={point_idx}__viz_B_pred_w_support\",\n                        results_i[\"traj_e\"],\n                        results_i[\"vis_e\"],\n                        query_points_i[:, :, :],\n                    )]\n                    if self.local_grid_size > 0 and local_grid_points.shape[1] > 0:\n                        num_local_support_points = local_grid_points.shape[1]\n                        tuples_to_process += [(\n                            f\"predictor__pidx={point_idx}__viz_C_local_support_grid\",\n                            results_i[\"traj_e\"][:, :, 1:1 + num_local_support_points, :],\n                            results_i[\"vis_e\"][:, :, 1:1 + num_local_support_points],\n                            query_points_i[:, 1:1 + num_local_support_points, :],\n                        )]\n                    if self.grid_size > 0:\n                        num_global_support_points = support_points.shape[1]\n                        tuples_to_process += [(\n                            f\"predictor__pidx={point_idx}__viz_D_global_support_grid\",\n                            results_i[\"traj_e\"][:, :, -num_global_support_points:, :],\n                            results_i[\"vis_e\"][:, :, -num_global_support_points:],\n                            query_points_i[:, -num_global_support_points:, :],\n                        )]\n                    for filename, pred_trajectories, pred_visibilities, qps in tuples_to_process:\n                        filename = time_now() + \"__\" + filename\n                        # Project the predictions to pixel space for visualization\n                        pred_trajectories_pixel_xy_camera_z_per_view = torch.stack([\n                            torch.cat(world_space_to_pixel_xy_and_camera_z(\n                                world_xyz=pred_trajectories[0],\n                                intrs=intrs[0, view_idx],\n                                extrs=extrs[0, view_idx],\n                            ), dim=-1)\n                            for view_idx in range(num_views)\n                        ], dim=0)[None]\n                        pred_viz, _ = visualizer.visualize(\n                            video=rgbs,\n                            video_depth=depths,\n                            tracks=pred_trajectories_pixel_xy_camera_z_per_view,\n                            visibility=pred_visibilities > 0.5,\n                            query_frame=qps[..., 0].long().clone(),\n                            filename=filename,\n                            writer=None,\n                            
step=0,\n                            save_video=True,\n                        )\n\n        else:\n            query_points_3d = torch.cat([query_points_3d, support_points], dim=1)\n            if query_points_view is not None:\n                query_points_view = torch.cat([\n                    query_points_view, query_points_view.new_zeros(support_points[:, :, 0].shape)\n                ], dim=1)\n            results = self.model(\n                rgbs,\n                depths=depths,\n                query_points=query_points_3d,\n                intrs=intrs,\n                extrs=extrs,\n                iters=self.n_iters,\n                save_debug_logs=save_debug_logs,\n                debug_logs_path=debug_logs_path,\n                query_points_view=query_points_view,\n                **kwargs,\n            )\n            traj_e = results[\"traj_e\"][:, :, :num_points, :]\n            vis_e = results[\"vis_e\"][:, :, :num_points]\n\n            if save_debug_logs:\n                visualizer = MultiViewVisualizer(\n                    save_dir=debug_logs_path,\n                    pad_value=16,\n                    fps=12,\n                    show_first_frame=0,\n                    tracks_leave_trace=0,\n                )\n                num_support_grid_points = grid_points.shape[1] if self.grid_size > 0 else 0\n                view_pts_all_timesteps = num_support_grid_points // num_views\n                view_pts = view_pts_all_timesteps // self.n_grids_per_view if self.grid_size > 0 else 0\n                for filename, pred_trajectories, pred_visibilities, qps in [\n                    (\"predictor__viz_A_pred\", traj_e, vis_e, query_points_3d[:, :num_points, :]),\n                    (\"predictor__viz_B_pred_w_support_grid\", results[\"traj_e\"], results[\"vis_e\"], query_points_3d),\n                    (\"predictor__viz_C_support_grid_only\", results[\"traj_e\"][:, :, num_points:, :],\n                     results[\"vis_e\"][:, :, num_points:], query_points_3d[:, num_points:, :]),\n                    *[(\n                            f\"predictor__viz_D_support_grid_only__t-0_view-{view_idx}\",\n                            results[\"traj_e\"][:, :,\n                            num_points + view_pts * view_idx:num_points + view_pts * (view_idx + 1), :],\n                            results[\"vis_e\"][:, :,\n                            num_points + view_pts * view_idx:num_points + view_pts * (view_idx + 1)],\n                            query_points_3d[:, num_points + view_pts * view_idx:num_points + view_pts * (view_idx + 1),\n                            :],\n                    ) for view_idx in range(num_views)],\n                ]:\n                    filename = time_now() + \"__\" + filename\n                    # Project the predictions to pixel space for visualization\n                    pred_trajectories_pixel_xy_camera_z_per_view = torch.stack([\n                        torch.cat(world_space_to_pixel_xy_and_camera_z(\n                            world_xyz=pred_trajectories[0],\n                            intrs=intrs[0, view_idx],\n                            extrs=extrs[0, view_idx],\n                        ), dim=-1)\n                        for view_idx in range(num_views)\n                    ], dim=0)[None]\n                    pred_viz, _ = visualizer.visualize(\n                        video=rgbs,\n                        video_depth=depths,\n                        tracks=pred_trajectories_pixel_xy_camera_z_per_view,\n                        
visibility=pred_visibilities > 0.5,\n                        query_frame=qps[..., 0].long().clone(),\n                        filename=filename,\n                        writer=None,\n                        step=0,\n                        save_video=True,\n                    )\n\n        return {\n            \"traj_e\": traj_e,\n            \"vis_e\": vis_e > self.visibility_threshold,\n            \"vis_e_as_prob\": vis_e,\n        }\n\n\ndef get_uniformly_sampled_pts(\n        size: int,\n        num_frames: int,\n        extent: Tuple[float, ...],\n        device: Optional[torch.device] = torch.device(\"cpu\"),\n):\n    \"\"\"Uniformly sample `size` query seeds over random frames and the given spatial extent.\"\"\"\n    time_points = torch.randint(low=0, high=num_frames, size=(size, 1), device=device)\n    space_points = torch.rand(size, 2, device=device) * torch.tensor(\n        [extent[1], extent[0]], device=device\n    )\n    points = torch.cat((time_points, space_points), dim=1)\n    return points[None]\n\n\ndef get_superpoint_sampled_pts(\n        video,\n        size: int,\n        num_frames: int,\n        extent: Tuple[float, ...],\n        device: Optional[torch.device] = torch.device(\"cpu\"),\n):\n    \"\"\"Sample up to `size` SuperPoint keypoint seeds from early frames of view 0.\"\"\"\n    # Keep the extractor and the sampled frame indices on the same device as the video.\n    extractor = SuperPoint(max_num_keypoints=48).eval().to(video.device)\n    points = list()\n    for _ in range(8):\n        frame_num = random.randint(0, int(num_frames * 0.25))\n        key_points = extractor.extract(\n            video[0, frame_num, :, :, :] / 255.0, resize=None\n        )[\"keypoints\"]\n        frame_tensor = torch.full((1, key_points.shape[1], 1), float(frame_num), device=video.device)\n        points.append(torch.cat([frame_tensor, key_points], dim=2))\n    return torch.cat(points, dim=1)[:, :size, :]\n\n\ndef get_sift_sampled_pts(\n        video,\n        size: int,\n        num_frames: int,\n        extent: Tuple[float, ...],\n        device: Optional[torch.device] = torch.device(\"cpu\"),\n        num_sampled_frames: int = 8,\n        sampling_length_percent: float = 0.25,\n):\n    \"\"\"Sample up to `size` SIFT keypoint seeds from early frames of view 0.\"\"\"\n    import cv2  # local import: OpenCV is only needed for this sampler\n    sift = cv2.SIFT_create(nfeatures=size // num_sampled_frames)\n    points = list()\n    for _ in range(num_sampled_frames):\n        frame_num = random.randint(0, int(num_frames * sampling_length_percent))\n        key_points, _ = sift.detectAndCompute(\n            video[0, frame_num, :, :, :]\n            .cpu()\n            .permute(1, 2, 0)\n            .numpy()\n            .astype(np.uint8),\n            None,\n        )\n        for kp in key_points:\n            points.append([frame_num, int(kp.pt[0]), int(kp.pt[1])])\n    if not points:\n        # No keypoints were detected in the sampled frames.\n        return torch.zeros((1, 0, 3), device=device)\n    return torch.tensor(points[:size], device=device)[None]\n"
  },
  {
    "path": "mvtracker/utils/__init__.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n"
  },
  {
    "path": "mvtracker/utils/basic.py",
    "content": "import os\nfrom datetime import datetime\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\n\nEPS = 1e-6\n\n\ndef sub2ind(height, width, y, x):\n    return y * width + x\n\n\ndef ind2sub(height, width, ind):\n    y = ind // width\n    x = ind % width\n    return y, x\n\n\ndef get_lr_str(lr):\n    lrn = \"%.1e\" % lr  # e.g., 5.0e-04\n    lrn = lrn[0] + lrn[3:5] + lrn[-1]  # e.g., 5e-4\n    return lrn\n\n\ndef strnum(x):\n    s = '%g' % x\n    if '.' in s:\n        if x < 1.0:\n            s = s[s.index('.'):]\n        s = s[:min(len(s), 4)]\n    return s\n\n\ndef assert_same_shape(t1, t2):\n    for (x, y) in zip(list(t1.shape), list(t2.shape)):\n        assert (x == y)\n\n\ndef print_stats(name, tensor):\n    shape = tensor.shape\n    tensor = tensor.detach().cpu().numpy()\n    print('%s (%s) min = %.2f, mean = %.2f, max = %.2f' % (\n        name, tensor.dtype, np.min(tensor), np.mean(tensor), np.max(tensor)), shape)\n\n\ndef print_stats_py(name, tensor):\n    shape = tensor.shape\n    print('%s (%s) min = %.2f, mean = %.2f, max = %.2f' % (\n        name, tensor.dtype, np.min(tensor), np.mean(tensor), np.max(tensor)), shape)\n\n\ndef print_(name, tensor):\n    tensor = tensor.detach().cpu().numpy()\n    print(name, tensor, tensor.shape)\n\n\ndef mkdir(path):\n    if not os.path.exists(path):\n        os.makedirs(path)\n\n\ndef normalize_single(d):\n    # d is a whatever shape torch tensor\n    dmin = torch.min(d)\n    dmax = torch.max(d)\n    d = (d - dmin) / (EPS + (dmax - dmin))\n    return d\n\n\ndef normalize(d):\n    # d is B x whatever. normalize within each element of the batch\n    out = torch.zeros(d.size())\n    if d.is_cuda:\n        out = out.cuda()\n    B = list(d.size())[0]\n    for b in list(range(B)):\n        out[b] = normalize_single(d[b])\n    return out\n\n\ndef hard_argmax2d(tensor):\n    B, C, Y, X = list(tensor.shape)\n    assert (C == 1)\n\n    # flatten the Tensor along the height and width axes\n    flat_tensor = tensor.reshape(B, -1)\n    # argmax of the flat tensor\n    argmax = torch.argmax(flat_tensor, dim=1)\n\n    # convert the indices into 2d coordinates\n    argmax_y = torch.floor(argmax / X)  # row\n    argmax_x = argmax % X  # col\n\n    argmax_y = argmax_y.reshape(B)\n    argmax_x = argmax_x.reshape(B)\n    return argmax_y, argmax_x\n\n\ndef argmax2d(heat, hard=True):\n    B, C, Y, X = list(heat.shape)\n    assert (C == 1)\n\n    if hard:\n        # hard argmax\n        loc_y, loc_x = hard_argmax2d(heat)\n        loc_y = loc_y.float()\n        loc_x = loc_x.float()\n    else:\n        heat = heat.reshape(B, Y * X)\n        prob = torch.nn.functional.softmax(heat, dim=1)\n\n        grid_y, grid_x = meshgrid2d(B, Y, X)\n\n        grid_y = grid_y.reshape(B, -1)\n        grid_x = grid_x.reshape(B, -1)\n\n        loc_y = torch.sum(grid_y * prob, dim=1)\n        loc_x = torch.sum(grid_x * prob, dim=1)\n        # these are B\n\n    return loc_y, loc_x\n\n\ndef reduce_masked_mean(x, mask, dim=None, keepdim=False):\n    # x and mask are the same shape, or at least broadcastably so < actually it's safer if you disallow broadcasting\n    # returns shape-1\n    # axis can be a list of axes\n    for (a, b) in zip(x.size(), mask.size()):\n        # if not b==1: \n        assert (a == b)  # some shape mismatch!\n    # assert(x.size() == mask.size())\n    prod = x * mask\n    if dim is None:\n        numer = torch.sum(prod)\n        denom = EPS + torch.sum(mask)\n    else:\n        numer = torch.sum(prod, dim=dim, keepdim=keepdim)\n   
     denom = EPS + torch.sum(mask, dim=dim, keepdim=keepdim)\n\n    mean = numer / denom\n    return mean\n\n\ndef reduce_masked_median(x, mask, keep_batch=False):\n    # x and mask are the same shape\n    assert (x.size() == mask.size())\n    device = x.device\n\n    B = list(x.shape)[0]\n    x = x.detach().cpu().numpy()\n    mask = mask.detach().cpu().numpy()\n\n    if keep_batch:\n        x = np.reshape(x, [B, -1])\n        mask = np.reshape(mask, [B, -1])\n        meds = np.zeros([B], np.float32)\n        for b in list(range(B)):\n            xb = x[b]\n            mb = mask[b]\n            if np.sum(mb) > 0:\n                xb = xb[mb > 0]\n                meds[b] = np.median(xb)\n            else:\n                meds[b] = np.nan\n        meds = torch.from_numpy(meds).to(device)\n        return meds.float()\n    else:\n        x = np.reshape(x, [-1])\n        mask = np.reshape(mask, [-1])\n        if np.sum(mask) > 0:\n            x = x[mask > 0]\n            med = np.median(x)\n        else:\n            med = np.nan\n        med = np.array([med], np.float32)\n        med = torch.from_numpy(med).to(device)\n        return med.float()\n\n\ndef pack_seqdim(tensor, B):\n    shapelist = list(tensor.shape)\n    B_, S = shapelist[:2]\n    assert (B == B_)\n    otherdims = shapelist[2:]\n    tensor = torch.reshape(tensor, [B * S] + otherdims)\n    return tensor\n\n\ndef unpack_seqdim(tensor, B):\n    shapelist = list(tensor.shape)\n    BS = shapelist[0]\n    assert (BS % B == 0)\n    otherdims = shapelist[1:]\n    S = int(BS / B)\n    tensor = torch.reshape(tensor, [B, S] + otherdims)\n    return tensor\n\n\ndef meshgrid2d(B, Y, X, stack=False, norm=False, device='cuda', on_chans=False):\n    # returns a meshgrid sized B x Y x X\n\n    grid_y = torch.linspace(0.0, Y - 1, Y, device=torch.device(device))\n    grid_y = torch.reshape(grid_y, [1, Y, 1])\n    grid_y = grid_y.repeat(B, 1, X)\n\n    grid_x = torch.linspace(0.0, X - 1, X, device=torch.device(device))\n    grid_x = torch.reshape(grid_x, [1, 1, X])\n    grid_x = grid_x.repeat(B, Y, 1)\n\n    if norm:\n        grid_y, grid_x = normalize_grid2d(\n            grid_y, grid_x, Y, X)\n\n    if stack:\n        # note we stack in xy order\n        # (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample)\n        if on_chans:\n            grid = torch.stack([grid_x, grid_y], dim=1)\n        else:\n            grid = torch.stack([grid_x, grid_y], dim=-1)\n        return grid\n    else:\n        return grid_y, grid_x\n\n\ndef meshgrid3d(B, Z, Y, X, stack=False, norm=False, device='cuda'):\n    # returns a meshgrid sized B x Z x Y x X\n\n    grid_z = torch.linspace(0.0, Z - 1, Z, device=device)\n    grid_z = torch.reshape(grid_z, [1, Z, 1, 1])\n    grid_z = grid_z.repeat(B, 1, Y, X)\n\n    grid_y = torch.linspace(0.0, Y - 1, Y, device=device)\n    grid_y = torch.reshape(grid_y, [1, 1, Y, 1])\n    grid_y = grid_y.repeat(B, Z, 1, X)\n\n    grid_x = torch.linspace(0.0, X - 1, X, device=device)\n    grid_x = torch.reshape(grid_x, [1, 1, 1, X])\n    grid_x = grid_x.repeat(B, Z, Y, 1)\n\n    # if cuda:\n    #     grid_z = grid_z.cuda()\n    #     grid_y = grid_y.cuda()\n    #     grid_x = grid_x.cuda()\n\n    if norm:\n        grid_z, grid_y, grid_x = normalize_grid3d(\n            grid_z, grid_y, grid_x, Z, Y, X)\n\n    if stack:\n        # note we stack in xyz order\n        # (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample)\n        grid = torch.stack([grid_x, grid_y, 
grid_z], dim=-1)\n        return grid\n    else:\n        return grid_z, grid_y, grid_x\n\n\ndef normalize_grid2d(grid_y, grid_x, Y, X, clamp_extreme=True):\n    # make things in [-1,1]\n    grid_y = 2.0 * (grid_y / float(Y - 1)) - 1.0\n    grid_x = 2.0 * (grid_x / float(X - 1)) - 1.0\n\n    if clamp_extreme:\n        grid_y = torch.clamp(grid_y, min=-2.0, max=2.0)\n        grid_x = torch.clamp(grid_x, min=-2.0, max=2.0)\n\n    return grid_y, grid_x\n\n\ndef normalize_grid3d(grid_z, grid_y, grid_x, Z, Y, X, clamp_extreme=True):\n    # make things in [-1,1]\n    grid_z = 2.0 * (grid_z / float(Z - 1)) - 1.0\n    grid_y = 2.0 * (grid_y / float(Y - 1)) - 1.0\n    grid_x = 2.0 * (grid_x / float(X - 1)) - 1.0\n\n    if clamp_extreme:\n        grid_z = torch.clamp(grid_z, min=-2.0, max=2.0)\n        grid_y = torch.clamp(grid_y, min=-2.0, max=2.0)\n        grid_x = torch.clamp(grid_x, min=-2.0, max=2.0)\n\n    return grid_z, grid_y, grid_x\n\n\ndef gridcloud2d(B, Y, X, norm=False, device='cuda'):\n    # we want to sample for each location in the grid\n    grid_y, grid_x = meshgrid2d(B, Y, X, norm=norm, device=device)\n    x = torch.reshape(grid_x, [B, -1])\n    y = torch.reshape(grid_y, [B, -1])\n    # these are B x N\n    xy = torch.stack([x, y], dim=2)\n    # this is B x N x 2\n    return xy\n\n\ndef gridcloud3d(B, Z, Y, X, norm=False, device='cuda'):\n    # we want to sample for each location in the grid\n    grid_z, grid_y, grid_x = meshgrid3d(B, Z, Y, X, norm=norm, device=device)\n    x = torch.reshape(grid_x, [B, -1])\n    y = torch.reshape(grid_y, [B, -1])\n    z = torch.reshape(grid_z, [B, -1])\n    # these are B x N\n    xyz = torch.stack([x, y, z], dim=2)\n    # this is B x N x 3\n    return xyz\n\n\nimport re\n\n\ndef readPFM(file):\n    file = open(file, 'rb')\n\n    color = None\n    width = None\n    height = None\n    scale = None\n    endian = None\n\n    header = file.readline().rstrip()\n    if header == b'PF':\n        color = True\n    elif header == b'Pf':\n        color = False\n    else:\n        raise Exception('Not a PFM file.')\n\n    dim_match = re.match(rb'^(\\d+)\\s(\\d+)\\s$', file.readline())\n    if dim_match:\n        width, height = map(int, dim_match.groups())\n    else:\n        raise Exception('Malformed PFM header.')\n\n    scale = float(file.readline().rstrip())\n    if scale < 0:  # little-endian\n        endian = '<'\n        scale = -scale\n    else:\n        endian = '>'  # big-endian\n\n    data = np.fromfile(file, endian + 'f')\n    shape = (height, width, 3) if color else (height, width)\n\n    data = np.reshape(data, shape)\n    data = np.flipud(data)\n    return data\n\n\ndef normalize_boxlist2d(boxlist2d, H, W):\n    boxlist2d = boxlist2d.clone()\n    ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2)\n    ymin = ymin / float(H)\n    ymax = ymax / float(H)\n    xmin = xmin / float(W)\n    xmax = xmax / float(W)\n    boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2)\n    return boxlist2d\n\n\ndef unnormalize_boxlist2d(boxlist2d, H, W):\n    boxlist2d = boxlist2d.clone()\n    ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2)\n    ymin = ymin * float(H)\n    ymax = ymax * float(H)\n    xmin = xmin * float(W)\n    xmax = xmax * float(W)\n    boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2)\n    return boxlist2d\n\n\ndef unnormalize_box2d(box2d, H, W):\n    return unnormalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1)\n\n\ndef normalize_box2d(box2d, H, W):\n    return normalize_boxlist2d(box2d.unsqueeze(1), H, 
W).squeeze(1)\n\n\ndef get_gaussian_kernel_2d(channels, kernel_size=3, sigma=2.0, mid_one=False):\n    C = channels\n    xy_grid = gridcloud2d(C, kernel_size, kernel_size)  # C x N x 2\n\n    mean = (kernel_size - 1) / 2.0\n    variance = sigma ** 2.0\n\n    gaussian_kernel = (1.0 / (2.0 * np.pi * variance) ** 1.5) * torch.exp(\n        -torch.sum((xy_grid - mean) ** 2.0, dim=-1) / (2.0 * variance))  # C X N\n    gaussian_kernel = gaussian_kernel.view(C, 1, kernel_size, kernel_size)  # C x 1 x 3 x 3\n    kernel_sum = torch.sum(gaussian_kernel, dim=(2, 3), keepdim=True)\n\n    gaussian_kernel = gaussian_kernel / kernel_sum  # normalize\n\n    if mid_one:\n        # normalize so that the middle element is 1\n        maxval = gaussian_kernel[:, :, (kernel_size // 2), (kernel_size // 2)].reshape(C, 1, 1, 1)\n        gaussian_kernel = gaussian_kernel / maxval\n\n    return gaussian_kernel\n\n\ndef gaussian_blur_2d(input, kernel_size=3, sigma=2.0, reflect_pad=False, mid_one=False):\n    B, C, Z, X = input.shape\n    kernel = get_gaussian_kernel_2d(C, kernel_size, sigma, mid_one=mid_one)\n    if reflect_pad:\n        pad = (kernel_size - 1) // 2\n        out = F.pad(input, (pad, pad, pad, pad), mode='reflect')\n        out = F.conv2d(out, kernel, padding=0, groups=C)\n    else:\n        out = F.conv2d(input, kernel, padding=(kernel_size - 1) // 2, groups=C)\n    return out\n\n\ndef gradient2d(x, absolute=False, square=False, return_sum=False):\n    # x should be B x C x H x W\n    dh = x[:, :, 1:, :] - x[:, :, :-1, :]\n    dw = x[:, :, :, 1:] - x[:, :, :, :-1]\n\n    zeros = torch.zeros_like(x)\n    zero_h = zeros[:, :, 0:1, :]\n    zero_w = zeros[:, :, :, 0:1]\n    dh = torch.cat([dh, zero_h], axis=2)\n    dw = torch.cat([dw, zero_w], axis=3)\n    if absolute:\n        dh = torch.abs(dh)\n        dw = torch.abs(dw)\n    if square:\n        dh = dh ** 2\n        dw = dw ** 2\n    if return_sum:\n        return dh + dw\n    else:\n        return dh, dw\n\n\ndef to_homogeneous(x):\n    return torch.cat([x, x.new_ones(x[..., :1].shape)], -1)\n\n\ndef from_homogeneous(x, assert_homogeneous_part_is_equal_to_1=False, eps=0.1):\n    if assert_homogeneous_part_is_equal_to_1:\n        assert torch.allclose(x[..., -1], x.new_ones(x[..., -1].shape), atol=eps)\n    return x[..., :-1] / x[..., -1:]\n\n\ndef time_now():\n    return datetime.now().strftime(\"%Y%m%d_%H%M%S_%f\")\n"
  },
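  {
    "path": "examples/unproject_pixels_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the original release: lift a pixel with known depth\ninto world coordinates, mirroring how the predictor unprojects support points. The file\nname, the helper name, and the toy camera values below are assumptions for this example.\"\"\"\nimport torch\n\nfrom mvtracker.utils.basic import from_homogeneous, to_homogeneous\n\n\ndef pixel_depth_to_world(x, y, depth, intr, extr):\n    # Pixel -> camera: back-project through the inverse intrinsics, scaled by the sampled depth.\n    cam_xyz = (torch.linalg.inv(intr) @ torch.tensor([x, y, 1.0])) * depth\n    # Camera -> world: apply the inverse (world-from-camera) extrinsics in homogeneous coordinates.\n    world_h = torch.linalg.inv(extr) @ to_homogeneous(cam_xyz[None])[0]\n    return from_homogeneous(world_h[None])[0]\n\n\nif __name__ == \"__main__\":\n    intr = torch.tensor([[500.0, 0.0, 128.0], [0.0, 500.0, 96.0], [0.0, 0.0, 1.0]])\n    extr = torch.eye(4)  # toy world-to-camera extrinsics: camera at the world origin\n    print(pixel_depth_to_world(160.0, 120.0, 2.0, intr, extr))\n"
  },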
  {
    "path": "mvtracker/utils/eval_utils.py",
    "content": "import os\n\nimport matplotlib\nimport numpy as np\nimport rerun as rr\nimport json\nfrom tqdm import tqdm\nfrom scipy.stats import multivariate_normal\n\n\n\ndef medianTrajError(output, target):\n\n    diff = np.linalg.norm(target - output, axis = 1)\n    orderedDiff = np.sort(diff)\n\n    return orderedDiff[len(orderedDiff)//2]\n\n\ndef averageTrajError(output, target):\n\n    diff = np.linalg.norm(target - output, axis = 1)\n\n    return np.mean(diff, axis = 0)\n\n\ndef pointTrack(queryPoint, anchorPos, anchorRot):\n    R = qToRot(anchorRot[0])\n    \n    t0 = R.T@(queryPoint - anchorPos[0])\n    track = []\n    for idx in tqdm(range(len(anchorPos)), 'Track', position = 1, leave = False):\n        track.append(anchorPos[idx] + qToRot(anchorRot[idx])@t0)\n    \n    return np.array(track)\n\ndef qToRot(q):\n    norm = np.linalg.norm(q)\n    r = q[0]/norm\n    x = q[1]/norm\n    y = q[2]/norm\n    z = q[3]/norm\n\n\n    R = np.array(\n        [[1.0 - 2.0 * (y * y + z * z), 2.0 * (x * y - r * z), 2.0 * (x * z + r * y)],\n        [2.0 * (x * y + r * z), 1.0 - 2.0 * (x * x + z * z), 2.0 * (y * z - r * x)],\n        [2.0 * (x * z - r * y), 2.0 * (y * z + r * x), 1.0 - 2.0 * (x * x + y * y)]]\n    )\n    return R\n\ndef get3DCov(scale, rotation, scale_mod = 1):\n    \n    S = np.zeros((3,3))\n    S[0][0] = scale_mod * scale[0]\n    S[1][1] = scale_mod * scale[1]\n    S[2][2] = scale_mod * scale[2]\n\n    R = qToRot(rotation)\n    M = S * R\n    \n    sigma = np.transpose(M) * M\n    \n    return sigma\n\ndef getAll3DCov(scales, rotations, scale_mod = 1):\n\n    cov3Ds = []\n    for idx in tqdm(range(len(scales)), 'Cov'):\n        cov3Ds.append(get3DCov(scales[idx], rotations[idx], scale_mod))\n\n    return np.array(cov3Ds)\n\ndef getContributions(mean3Ds, cov3Ds, query):\n\n    assert len(mean3Ds) == len(cov3Ds), f'{mean3Ds.shape} {cov3Ds.shape}'\n\n    PDFs = []\n\n    for idx in tqdm(range(len(mean3Ds)),'PDF', position = 1, leave = False):\n        try:\n            pdf = multivariate_normal.pdf(query, mean = mean3Ds[idx], cov = cov3Ds[idx])\n            PDFs.append(pdf)\n        except:\n            PDFs.append(-1)\n\n    return np.array(PDFs)\n    \n\n"
  },
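  {
    "path": "examples/gaussian_contribution_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the original release: evaluate how much a single\nanisotropic 3D Gaussian contributes at a query location using the eval_utils helpers.\nThe file name and the toy numbers are assumptions; quaternions follow the (w, x, y, z)\nconvention used by qToRot.\"\"\"\nimport numpy as np\n\nfrom mvtracker.utils.eval_utils import getAll3DCov, getContributions\n\n# One Gaussian: identity rotation and anisotropic per-axis scales.\nmeans = np.array([[0.0, 0.0, 1.0]])\nscales = np.array([[0.2, 0.2, 0.05]])\nrotations = np.array([[1.0, 0.0, 0.0, 0.0]])\n\ncovs = getAll3DCov(scales, rotations)  # (1, 3, 3) covariance matrices\nquery = np.array([0.05, 0.0, 1.0])  # world-space query point\nprint(getContributions(means, covs, query))  # PDF of each Gaussian at the query\n"
  },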
  {
    "path": "mvtracker/utils/geom.py",
    "content": "import numpy as np\nimport torch\nimport torchvision.ops as ops\n\n\ndef matmul2(mat1, mat2):\n    return torch.matmul(mat1, mat2)\n\n\ndef matmul3(mat1, mat2, mat3):\n    return torch.matmul(mat1, torch.matmul(mat2, mat3))\n\n\ndef eye_3x3(B, device='cuda'):\n    rt = torch.eye(3, device=torch.device(device)).view(1, 3, 3).repeat([B, 1, 1])\n    return rt\n\n\ndef eye_4x4(B, device='cuda'):\n    rt = torch.eye(4, device=torch.device(device)).view(1, 4, 4).repeat([B, 1, 1])\n    return rt\n\n\ndef safe_inverse(a):  # parallel version\n    B, _, _ = list(a.shape)\n    inv = a.clone()\n    r_transpose = a[:, :3, :3].transpose(1, 2)  # inverse of rotation matrix\n\n    inv[:, :3, :3] = r_transpose\n    inv[:, :3, 3:4] = -torch.matmul(r_transpose, a[:, :3, 3:4])\n\n    return inv\n\n\ndef safe_inverse_single(a):\n    r, t = split_rt_single(a)\n    t = t.view(3, 1)\n    r_transpose = r.t()\n    inv = torch.cat([r_transpose, -torch.matmul(r_transpose, t)], 1)\n    bottom_row = a[3:4, :]  # this is [0, 0, 0, 1]\n    # bottom_row = torch.tensor([0.,0.,0.,1.]).view(1,4)\n    inv = torch.cat([inv, bottom_row], 0)\n    return inv\n\n\ndef split_intrinsics(K):\n    # K is B x 3 x 3 or B x 4 x 4\n    fx = K[:, 0, 0]\n    fy = K[:, 1, 1]\n    x0 = K[:, 0, 2]\n    y0 = K[:, 1, 2]\n    return fx, fy, x0, y0\n\n\ndef apply_pix_T_cam(pix_T_cam, xyz):\n    fx, fy, x0, y0 = split_intrinsics(pix_T_cam)\n\n    # xyz is shaped B x H*W x 3\n    # returns xy, shaped B x H*W x 2\n\n    B, N, C = list(xyz.shape)\n    assert (C == 3)\n\n    x, y, z = torch.unbind(xyz, axis=-1)\n\n    fx = torch.reshape(fx, [B, 1])\n    fy = torch.reshape(fy, [B, 1])\n    x0 = torch.reshape(x0, [B, 1])\n    y0 = torch.reshape(y0, [B, 1])\n\n    EPS = 1e-4\n    z = torch.clamp(z, min=EPS)\n    x = (x * fx) / (z) + x0\n    y = (y * fy) / (z) + y0\n    xy = torch.stack([x, y], axis=-1)\n    return xy\n\n\ndef apply_pix_T_cam_py(pix_T_cam, xyz):\n    fx, fy, x0, y0 = split_intrinsics(pix_T_cam)\n\n    # xyz is shaped B x H*W x 3\n    # returns xy, shaped B x H*W x 2\n\n    B, N, C = list(xyz.shape)\n    assert (C == 3)\n\n    x, y, z = xyz[:, :, 0], xyz[:, :, 1], xyz[:, :, 2]\n\n    fx = np.reshape(fx, [B, 1])\n    fy = np.reshape(fy, [B, 1])\n    x0 = np.reshape(x0, [B, 1])\n    y0 = np.reshape(y0, [B, 1])\n\n    EPS = 1e-4\n    z = np.clip(z, EPS, None)\n    x = (x * fx) / (z) + x0\n    y = (y * fy) / (z) + y0\n    xy = np.stack([x, y], axis=-1)\n    return xy\n\n\ndef get_camM_T_camXs(origin_T_camXs, ind=0):\n    B, S = list(origin_T_camXs.shape)[0:2]\n    camM_T_camXs = torch.zeros_like(origin_T_camXs)\n    for b in list(range(B)):\n        camM_T_origin = safe_inverse_single(origin_T_camXs[b, ind])\n        for s in list(range(S)):\n            camM_T_camXs[b, s] = torch.matmul(camM_T_origin, origin_T_camXs[b, s])\n    return camM_T_camXs\n\n\ndef apply_4x4(RT, xyz):\n    B, N, _ = list(xyz.shape)\n    ones = torch.ones_like(xyz[:, :, 0:1])\n    xyz1 = torch.cat([xyz, ones], 2)\n    xyz1_t = torch.transpose(xyz1, 1, 2)\n    # this is B x 4 x N\n    xyz2_t = torch.matmul(RT, xyz1_t)\n    xyz2 = torch.transpose(xyz2_t, 1, 2)\n    xyz2 = xyz2[:, :, :3]\n    return xyz2\n\n\ndef apply_4x4_py(RT, xyz):\n    # print('RT', RT.shape)\n    B, N, _ = list(xyz.shape)\n    ones = np.ones_like(xyz[:, :, 0:1])\n    xyz1 = np.concatenate([xyz, ones], 2)\n    # print('xyz1', xyz1.shape)\n    xyz1_t = xyz1.transpose(0, 2, 1)\n    # print('xyz1_t', xyz1_t.shape)\n    # this is B x 4 x N\n    xyz2_t = np.matmul(RT, xyz1_t)\n    # 
print('xyz2_t', xyz2_t.shape)\n    xyz2 = xyz2_t.transpose(0, 2, 1)\n    # print('xyz2', xyz2.shape)\n    xyz2 = xyz2[:, :, :3]\n    return xyz2\n\n\ndef apply_3x3(RT, xy):\n    B, N, _ = list(xy.shape)\n    ones = torch.ones_like(xy[:, :, 0:1])\n    xy1 = torch.cat([xy, ones], 2)\n    xy1_t = torch.transpose(xy1, 1, 2)\n    # this is B x 4 x N\n    xy2_t = torch.matmul(RT, xy1_t)\n    xy2 = torch.transpose(xy2_t, 1, 2)\n    xy2 = xy2[:, :, :2]\n    return xy2\n\n\ndef generate_polygon(ctr_x, ctr_y, avg_r, irregularity, spikiness, num_verts):\n    '''\n    Start with the center of the polygon at ctr_x, ctr_y, \n    Then creates the polygon by sampling points on a circle around the center.\n    Random noise is added by varying the angular spacing between sequential points,\n    and by varying the radial distance of each point from the centre.\n\n    Params:\n        ctr_x, ctr_y - coordinates of the \"centre\" of the polygon\n        avg_r - in px, the average radius of this polygon, this roughly controls how large the polygon is, really only useful for order of magnitude.\n        irregularity - [0,1] indicating how much variance there is in the angular spacing of vertices. [0,1] will map to [0, 2pi/numberOfVerts]\n        spikiness - [0,1] indicating how much variance there is in each vertex from the circle of radius avg_r. [0,1] will map to [0, avg_r]\npp        num_verts\n\n    Returns:\n        np.array [num_verts, 2] - CCW order.\n    '''\n    # spikiness\n    spikiness = np.clip(spikiness, 0, 1) * avg_r\n\n    # generate n angle steps\n    irregularity = np.clip(irregularity, 0, 1) * 2 * np.pi / num_verts\n    lower = (2 * np.pi / num_verts) - irregularity\n    upper = (2 * np.pi / num_verts) + irregularity\n\n    # angle steps\n    angle_steps = np.random.uniform(lower, upper, num_verts)\n    sc = (2 * np.pi) / angle_steps.sum()\n    angle_steps *= sc\n\n    # get all radii\n    angle = np.random.uniform(0, 2 * np.pi)\n    radii = np.clip(np.random.normal(avg_r, spikiness, num_verts), 0, 2 * avg_r)\n\n    # compute all points\n    points = []\n    for i in range(num_verts):\n        x = ctr_x + radii[i] * np.cos(angle)\n        y = ctr_y + radii[i] * np.sin(angle)\n        points.append([x, y])\n        angle += angle_steps[i]\n\n    return np.array(points).astype(int)\n\n\ndef get_random_affine_2d(B, rot_min=-5.0, rot_max=5.0, tx_min=-0.1, tx_max=0.1, ty_min=-0.1, ty_max=0.1, sx_min=-0.05,\n                         sx_max=0.05, sy_min=-0.05, sy_max=0.05, shx_min=-0.05, shx_max=0.05, shy_min=-0.05,\n                         shy_max=0.05):\n    '''\n    Params:\n        rot_min: rotation amount min\n        rot_max: rotation amount max\n\n        tx_min: translation x min\n        tx_max: translation x max\n\n        ty_min: translation y min\n        ty_max: translation y max\n\n        sx_min: scaling x min\n        sx_max: scaling x max\n\n        sy_min: scaling y min\n        sy_max: scaling y max\n\n        shx_min: shear x min\n        shx_max: shear x max\n\n        shy_min: shear y min\n        shy_max: shear y max\n\n    Returns:\n        transformation matrix: (B, 3, 3)\n    '''\n    # rotation\n    if rot_max - rot_min != 0:\n        rot_amount = np.random.uniform(low=rot_min, high=rot_max, size=B)\n        rot_amount = np.pi / 180.0 * rot_amount\n    else:\n        rot_amount = rot_min\n    rotation = np.zeros((B, 3, 3))  # B, 3, 3\n    rotation[:, 2, 2] = 1\n    rotation[:, 0, 0] = np.cos(rot_amount)\n    rotation[:, 0, 1] = -np.sin(rot_amount)\n    rotation[:, 1, 0] = 
np.sin(rot_amount)\n    rotation[:, 1, 1] = np.cos(rot_amount)\n\n    # translation\n    translation = np.zeros((B, 3, 3))  # B, 3, 3\n    translation[:, [0, 1, 2], [0, 1, 2]] = 1\n    if (tx_max - tx_min) > 0:\n        trans_x = np.random.uniform(low=tx_min, high=tx_max, size=B)\n        translation[:, 0, 2] = trans_x\n    # else:\n    #     translation[:, 0, 2] = tx_max\n    if ty_max - ty_min != 0:\n        trans_y = np.random.uniform(low=ty_min, high=ty_max, size=B)\n        translation[:, 1, 2] = trans_y\n    # else:\n    #     translation[:, 1, 2] = ty_max\n\n    # scaling\n    scaling = np.zeros((B, 3, 3))  # B, 3, 3\n    scaling[:, [0, 1, 2], [0, 1, 2]] = 1\n    if (sx_max - sx_min) > 0:\n        scale_x = 1 + np.random.uniform(low=sx_min, high=sx_max, size=B)\n        scaling[:, 0, 0] = scale_x\n    # else:\n    #     scaling[:, 0, 0] = sx_max\n    if (sy_max - sy_min) > 0:\n        scale_y = 1 + np.random.uniform(low=sy_min, high=sy_max, size=B)\n        scaling[:, 1, 1] = scale_y\n    # else:\n    #     scaling[:, 1, 1] = sy_max\n\n    # shear\n    shear = np.zeros((B, 3, 3))  # B, 3, 3\n    shear[:, [0, 1, 2], [0, 1, 2]] = 1\n    if (shx_max - shx_min) > 0:\n        shear_x = np.random.uniform(low=shx_min, high=shx_max, size=B)\n        shear[:, 0, 1] = shear_x\n    # else:\n    #     shear[:, 0, 1] = shx_max\n    if (shy_max - shy_min) > 0:\n        shear_y = np.random.uniform(low=shy_min, high=shy_max, size=B)\n        shear[:, 1, 0] = shear_y\n    # else:\n    #     shear[:, 1, 0] = shy_max\n\n    # compose all those\n    rt = np.einsum(\"ijk,ikl->ijl\", rotation, translation)\n    ss = np.einsum(\"ijk,ikl->ijl\", scaling, shear)\n    trans = np.einsum(\"ijk,ikl->ijl\", rt, ss)\n\n    return trans\n\n\ndef get_centroid_from_box2d(box2d):\n    ymin = box2d[:, 0]\n    xmin = box2d[:, 1]\n    ymax = box2d[:, 2]\n    xmax = box2d[:, 3]\n    x = (xmin + xmax) / 2.0\n    y = (ymin + ymax) / 2.0\n    return y, x\n\n\ndef normalize_boxlist2d(boxlist2d, H, W):\n    boxlist2d = boxlist2d.clone()\n    ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2)\n    ymin = ymin / float(H)\n    ymax = ymax / float(H)\n    xmin = xmin / float(W)\n    xmax = xmax / float(W)\n    boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2)\n    return boxlist2d\n\n\ndef unnormalize_boxlist2d(boxlist2d, H, W):\n    boxlist2d = boxlist2d.clone()\n    ymin, xmin, ymax, xmax = torch.unbind(boxlist2d, dim=2)\n    ymin = ymin * float(H)\n    ymax = ymax * float(H)\n    xmin = xmin * float(W)\n    xmax = xmax * float(W)\n    boxlist2d = torch.stack([ymin, xmin, ymax, xmax], dim=2)\n    return boxlist2d\n\n\ndef unnormalize_box2d(box2d, H, W):\n    return unnormalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1)\n\n\ndef normalize_box2d(box2d, H, W):\n    return normalize_boxlist2d(box2d.unsqueeze(1), H, W).squeeze(1)\n\n\ndef get_size_from_box2d(box2d):\n    ymin = box2d[:, 0]\n    xmin = box2d[:, 1]\n    ymax = box2d[:, 2]\n    xmax = box2d[:, 3]\n    height = ymax - ymin\n    width = xmax - xmin\n    return height, width\n\n\ndef crop_and_resize(im, boxlist, PH, PW, boxlist_is_normalized=False):\n    B, C, H, W = im.shape\n    B2, N, D = boxlist.shape\n    assert (B == B2)\n    assert (D == 4)\n    # PH, PW is the size to resize to\n\n    # output is B,N,C,PH,PW\n\n    # pt wants xy xy, unnormalized\n    if boxlist_is_normalized:\n        boxlist_unnorm = unnormalize_boxlist2d(boxlist, H, W)\n    else:\n        boxlist_unnorm = boxlist\n\n    ymin, xmin, ymax, xmax = boxlist_unnorm.unbind(2)\n    # 
boxlist_pt = torch.stack([boxlist_unnorm[:,1], boxlist_unnorm[:,0], boxlist_unnorm[:,3], boxlist_unnorm[:,2]], dim=1)\n    boxlist_pt = torch.stack([xmin, ymin, xmax, ymax], dim=2)\n    # we want a B-len list of K x 4 arrays\n\n    # print('im', im.shape)\n    # print('boxlist', boxlist.shape)\n    # print('boxlist_pt', boxlist_pt.shape)\n\n    # boxlist_pt = list(boxlist_pt.unbind(0))\n\n    crops = []\n    for b in range(B):\n        crops_b = ops.roi_align(im[b:b + 1], [boxlist_pt[b]], output_size=(PH, PW))\n        crops.append(crops_b)\n    # # crops = im\n\n    # print('crops', crops.shape)\n    # crops = crops.reshape(B,N,C,PH,PW)\n\n    # crops = []\n    # for b in range(B):\n    #     crop_b = ops.roi_align(im[b:b+1], [boxlist_pt[b]], output_size=(PH, PW))\n    #     print('crop_b', crop_b.shape)\n    #     crops.append(crop_b)\n    crops = torch.stack(crops, dim=0)\n\n    # print('crops', crops.shape)\n    # boxlist_list = boxlist_pt.unbind(0)\n    # print('rgb_crop', rgb_crop.shape)\n\n    return crops\n\n\n# def get_boxlist_from_centroid_and_size(cy, cx, h, w, clip=True):\n#     # cy,cx are both B,N\n#     ymin = cy - h/2\n#     ymax = cy + h/2\n#     xmin = cx - w/2\n#     xmax = cx + w/2\n\n#     box = torch.stack([ymin, xmin, ymax, xmax], dim=-1)\n#     if clip:\n#         box = torch.clamp(box, 0, 1)\n#     return box\n\n\ndef get_boxlist_from_centroid_and_size(cy, cx, h, w):  # , clip=False):\n    # cy,cx are the same shape\n    ymin = cy - h / 2\n    ymax = cy + h / 2\n    xmin = cx - w / 2\n    xmax = cx + w / 2\n\n    # if clip:\n    #     ymin = torch.clamp(ymin, 0, H-1)\n    #     ymax = torch.clamp(ymax, 0, H-1)\n    #     xmin = torch.clamp(xmin, 0, W-1)\n    #     xmax = torch.clamp(xmax, 0, W-1)\n\n    box = torch.stack([ymin, xmin, ymax, xmax], dim=-1)\n    return box\n\n\ndef get_box2d_from_mask(mask, normalize=False):\n    # mask is B, 1, H, W\n\n    B, C, H, W = mask.shape\n    assert (C == 1)\n    xy = utils.basic.gridcloud2d(B, H, W, norm=False, device=mask.device)  # B, H*W, 2\n\n    box = torch.zeros((B, 4), dtype=torch.float32, device=mask.device)\n    for b in range(B):\n        xy_b = xy[b]  # H*W, 2\n        mask_b = mask[b].reshape(H * W)\n        xy_ = xy_b[mask_b > 0]\n        x_ = xy_[:, 0]\n        y_ = xy_[:, 1]\n        ymin = torch.min(y_)\n        ymax = torch.max(y_)\n        xmin = torch.min(x_)\n        xmax = torch.max(x_)\n        box[b] = torch.stack([ymin, xmin, ymax, xmax], dim=0)\n    if normalize:\n        box = normalize_boxlist2d(box.unsqueeze(1), H, W).squeeze(1)\n    return box\n\n\ndef convert_box2d_to_intrinsics(box2d, pix_T_cam, H, W, use_image_aspect_ratio=True, mult_padding=1.0):\n    # box2d is B x 4, with ymin, xmin, ymax, xmax in normalized coords\n    # ymin, xmin, ymax, xmax = torch.unbind(box2d, dim=1)\n    # H, W is the original size of the image\n    # mult_padding is relative to object size in pixels\n\n    # i assume we're rendering an image the same size as the original (H, W)\n\n    if not mult_padding == 1.0:\n        y, x = get_centroid_from_box2d(box2d)\n        h, w = get_size_from_box2d(box2d)\n        box2d = get_box2d_from_centroid_and_size(\n            y, x, h * mult_padding, w * mult_padding, clip=False)\n\n    if use_image_aspect_ratio:\n        h, w = get_size_from_box2d(box2d)\n        y, x = get_centroid_from_box2d(box2d)\n\n        # note h,w are relative right now\n        # we need to undo this, to see the real ratio\n\n        h = h * float(H)\n        w = w * float(W)\n        box_ratio = h / 
w\n        im_ratio = H / float(W)\n\n        # print('box_ratio:', box_ratio)\n        # print('im_ratio:', im_ratio)\n\n        if box_ratio >= im_ratio:\n            w = h / im_ratio\n            # print('setting w:', h/im_ratio)\n        else:\n            h = w * im_ratio\n            # print('setting h:', w*im_ratio)\n\n        box2d = get_box2d_from_centroid_and_size(\n            y, x, h / float(H), w / float(W), clip=False)\n\n    assert (h > 1e-4)\n    assert (w > 1e-4)\n\n    ymin, xmin, ymax, xmax = torch.unbind(box2d, dim=1)\n\n    fx, fy, x0, y0 = split_intrinsics(pix_T_cam)\n\n    # the topleft of the new image will now have a different offset from the center of projection\n\n    new_x0 = x0 - xmin * W\n    new_y0 = y0 - ymin * H\n\n    pix_T_cam = pack_intrinsics(fx, fy, new_x0, new_y0)\n    # this alone will give me an image in original resolution,\n    # with its topleft at the box corner\n\n    box_h, box_w = get_size_from_box2d(box2d)\n    # these are normalized, and shaped B. (e.g., [0.4], [0.3])\n\n    # we are going to scale the image by the inverse of this,\n    # since we are zooming into this area\n\n    sy = 1. / box_h\n    sx = 1. / box_w\n\n    pix_T_cam = scale_intrinsics(pix_T_cam, sx, sy)\n    return pix_T_cam, box2d\n\n\ndef pixels2camera(x, y, z, fx, fy, x0, y0):\n    # x and y are locations in pixel coordinates, z is a depth in meters\n    # they can be images or pointclouds\n    # fx, fy, x0, y0 are camera intrinsics\n    # returns xyz, sized B x N x 3\n\n    B = x.shape[0]\n\n    fx = torch.reshape(fx, [B, 1])\n    fy = torch.reshape(fy, [B, 1])\n    x0 = torch.reshape(x0, [B, 1])\n    y0 = torch.reshape(y0, [B, 1])\n\n    x = torch.reshape(x, [B, -1])\n    y = torch.reshape(y, [B, -1])\n    z = torch.reshape(z, [B, -1])\n\n    # unproject\n    x = (z / fx) * (x - x0)\n    y = (z / fy) * (y - y0)\n\n    xyz = torch.stack([x, y, z], dim=2)\n    # B x N x 3\n    return xyz\n\n\ndef camera2pixels(xyz, pix_T_cam):\n    # xyz is shaped B x H*W x 3\n    # returns xy, shaped B x H*W x 2\n\n    fx, fy, x0, y0 = split_intrinsics(pix_T_cam)\n    x, y, z = torch.unbind(xyz, dim=-1)\n    B = list(z.shape)[0]\n\n    fx = torch.reshape(fx, [B, 1])\n    fy = torch.reshape(fy, [B, 1])\n    x0 = torch.reshape(x0, [B, 1])\n    y0 = torch.reshape(y0, [B, 1])\n    x = torch.reshape(x, [B, -1])\n    y = torch.reshape(y, [B, -1])\n    z = torch.reshape(z, [B, -1])\n\n    EPS = 1e-4\n    z = torch.clamp(z, min=EPS)\n    x = (x * fx) / z + x0\n    y = (y * fy) / z + y0\n    xy = torch.stack([x, y], dim=-1)\n    return xy\n\n\ndef depth2pointcloud(z, pix_T_cam):\n    B, C, H, W = list(z.shape)\n    device = z.device\n    y, x = utils.basic.meshgrid2d(B, H, W, device=device)\n    z = torch.reshape(z, [B, H, W])\n    fx, fy, x0, y0 = split_intrinsics(pix_T_cam)\n    xyz = pixels2camera(x, y, z, fx, fy, x0, y0)\n    return xyz\n"
  },
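  {
    "path": "examples/pixel_camera_roundtrip_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the original release: round-trip a pixel through\ngeom.py's pixels2camera / camera2pixels. The file name and the toy intrinsics are\nassumptions for this example.\"\"\"\nimport torch\n\nfrom mvtracker.utils.geom import camera2pixels, pixels2camera, split_intrinsics\n\npix_T_cam = torch.tensor([[[500.0, 0.0, 128.0],\n                           [0.0, 500.0, 96.0],\n                           [0.0, 0.0, 1.0]]])  # (B, 3, 3) intrinsics\nx = torch.tensor([[160.0]])\ny = torch.tensor([[120.0]])\nz = torch.tensor([[2.0]])  # depth in meters\n\nfx, fy, x0, y0 = split_intrinsics(pix_T_cam)\nxyz = pixels2camera(x, y, z, fx, fy, x0, y0)  # (B, N, 3) camera-space points\nxy = camera2pixels(xyz, pix_T_cam)  # should recover roughly (160, 120)\nprint(xyz, xy)\n"
  },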
  {
    "path": "mvtracker/utils/improc.py",
    "content": "import cv2\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torchvision\nfrom matplotlib import cm\nfrom sklearn.decomposition import PCA\n\nEPS = 1e-6\n\nfrom skimage.color import (\n    hsv2rgb)\n\n\ndef _convert(input_, type_):\n    return {\n        'float': input_.float(),\n        'double': input_.double(),\n    }.get(type_, input_)\n\n\ndef _generic_transform_sk_3d(transform, in_type='', out_type=''):\n    def apply_transform_individual(input_):\n        device = input_.device\n        input_ = input_.cpu()\n        input_ = _convert(input_, in_type)\n\n        input_ = input_.permute(1, 2, 0).detach().numpy()\n        transformed = transform(input_)\n        output = torch.from_numpy(transformed).float().permute(2, 0, 1)\n        output = _convert(output, out_type)\n        return output.to(device)\n\n    def apply_transform(input_):\n        to_stack = []\n        for image in input_:\n            to_stack.append(apply_transform_individual(image))\n        return torch.stack(to_stack)\n\n    return apply_transform\n\n\nhsv_to_rgb = _generic_transform_sk_3d(hsv2rgb)\n\n\ndef preprocess_color_tf(x):\n    import tensorflow as tf\n    return tf.cast(x, tf.float32) * 1. / 255 - 0.5\n\n\ndef preprocess_color(x):\n    if isinstance(x, np.ndarray):\n        return x.astype(np.float32) * 1. / 255 - 0.5\n    else:\n        return x.float() * 1. / 255 - 0.5\n\n\ndef pca_embed(emb, keep, valid=None):\n    ## emb -- [S,H/2,W/2,C]\n    ## keep is the number of principal components to keep\n    ## Helper function for reduce_emb.\n    emb = emb + EPS\n    # emb is B x C x H x W\n    emb = emb.permute(0, 2, 3, 1).cpu().detach().numpy()  # this is B x H x W x C\n\n    if valid:\n        valid = valid.cpu().detach().numpy().reshape((H * W))\n\n    emb_reduced = list()\n\n    B, H, W, C = np.shape(emb)\n    for img in emb:\n        if np.isnan(img).any():\n            emb_reduced.append(np.zeros([H, W, keep]))\n            continue\n\n        pixels_kd = np.reshape(img, (H * W, C))\n\n        if valid:\n            pixels_kd_pca = pixels_kd[valid]\n        else:\n            pixels_kd_pca = pixels_kd\n\n        P = PCA(keep)\n        P.fit(pixels_kd_pca)\n\n        if valid:\n            pixels3d = P.transform(pixels_kd) * valid\n        else:\n            pixels3d = P.transform(pixels_kd)\n\n        out_img = np.reshape(pixels3d, [H, W, keep]).astype(np.float32)\n        if np.isnan(out_img).any():\n            emb_reduced.append(np.zeros([H, W, keep]))\n            continue\n\n        emb_reduced.append(out_img)\n\n    emb_reduced = np.stack(emb_reduced, axis=0).astype(np.float32)\n\n    return torch.from_numpy(emb_reduced).permute(0, 3, 1, 2)\n\n\ndef pca_embed_together(emb, keep):\n    ## emb -- [S,H/2,W/2,C]\n    ## keep is the number of principal components to keep\n    ## Helper function for reduce_emb.\n    emb = emb + EPS\n    # emb is B x C x H x W\n    emb = emb.permute(0, 2, 3, 1).cpu().detach().numpy()  # this is B x H x W x C\n\n    B, H, W, C = np.shape(emb)\n    if np.isnan(emb).any():\n        return torch.zeros(B, keep, H, W)\n\n    pixelskd = np.reshape(emb, (B * H * W, C))\n    P = PCA(keep)\n    P.fit(pixelskd)\n    pixels3d = P.transform(pixelskd)\n    out_img = np.reshape(pixels3d, [B, H, W, keep]).astype(np.float32)\n\n    if np.isnan(out_img).any():\n        return torch.zeros(B, keep, H, W)\n\n    return torch.from_numpy(out_img).permute(0, 3, 1, 2)\n\n\ndef reduce_emb(emb, valid=None, inbound=None, 
together=False):\n    ## emb -- [S,C,H/2,W/2], inbound -- [S,1,H/2,W/2]\n    ## Reduce number of chans to 3 with PCA. For vis.\n    # S,H,W,C = emb.shape.as_list()\n    S, C, H, W = list(emb.size())\n    keep = 3\n\n    if together:\n        reduced_emb = pca_embed_together(emb, keep)\n    else:\n        reduced_emb = pca_embed(emb, keep, valid)  # not im\n\n    reduced_emb = utils.basic.normalize(reduced_emb) - 0.5\n    if inbound is not None:\n        emb_inbound = emb * inbound\n    else:\n        emb_inbound = None\n\n    return reduced_emb, emb_inbound\n\n\ndef get_feat_pca(feat, valid=None):\n    B, C, D, W = list(feat.size())\n    # feat is B x C x D x W. If 3D input, average it through Height dimension before passing into this function.\n\n    pca, _ = reduce_emb(feat, valid=valid, inbound=None, together=True)\n    # pca is B x 3 x W x D\n    return pca\n\n\ndef gif_and_tile(ims, just_gif=False):\n    S = len(ims)\n    # each im is B x H x W x C\n    # i want a gif in the left, and the tiled frames on the right\n    # for the gif tool, this means making a B x S x H x W tensor\n    # where the leftmost part is sequential and the rest is tiled\n    gif = torch.stack(ims, dim=1)\n    if just_gif:\n        return gif\n    til = torch.cat(ims, dim=2)\n    til = til.unsqueeze(dim=1).repeat(1, S, 1, 1, 1)\n    im = torch.cat([gif, til], dim=3)\n    return im\n\n\ndef back2color(i, blacken_zeros=False):\n    if blacken_zeros:\n        const = torch.tensor([-0.5])\n        i = torch.where(i == 0.0, const.cuda() if i.is_cuda else const, i)\n        return back2color(i)\n    else:\n        return ((i + 0.5) * 255).type(torch.ByteTensor)\n\n\ndef convert_occ_to_height(occ, reduce_axis=3):\n    B, C, D, H, W = list(occ.shape)\n    assert (C == 1)\n    # note that height increases DOWNWARD in the tensor\n    # (like pixel/camera coordinates)\n\n    G = list(occ.shape)[reduce_axis]\n    values = torch.linspace(float(G), 1.0, steps=G, dtype=torch.float32, device=occ.device)\n    if reduce_axis == 2:\n        # fro view\n        values = values.view(1, 1, G, 1, 1)\n    elif reduce_axis == 3:\n        # top view\n        values = values.view(1, 1, 1, G, 1)\n    elif reduce_axis == 4:\n        # lateral view\n        values = values.view(1, 1, 1, 1, G)\n    else:\n        assert (False)  # you have to reduce one of the spatial dims (2-4)\n    values = torch.max(occ * values, dim=reduce_axis)[0] / float(G)\n    # values = values.view([B, C, D, W])\n    return values\n\n\ndef xy2heatmap(xy, sigma, grid_xs, grid_ys, norm=False):\n    # xy is B x N x 2, containing float x and y coordinates of N things\n    # grid_xs and grid_ys are B x N x Y x X\n\n    B, N, Y, X = list(grid_xs.shape)\n\n    mu_x = xy[:, :, 0].clone()\n    mu_y = xy[:, :, 1].clone()\n\n    x_valid = (mu_x > -0.5) & (mu_x < float(X + 0.5))\n    y_valid = (mu_y > -0.5) & (mu_y < float(Y + 0.5))\n    not_valid = ~(x_valid & y_valid)\n\n    mu_x[not_valid] = -10000\n    mu_y[not_valid] = -10000\n\n    mu_x = mu_x.reshape(B, N, 1, 1).repeat(1, 1, Y, X)\n    mu_y = mu_y.reshape(B, N, 1, 1).repeat(1, 1, Y, X)\n\n    sigma_sq = sigma * sigma\n    # sigma_sq = (sigma*sigma).reshape(B, N, 1, 1)\n    sq_diff_x = (grid_xs - mu_x) ** 2\n    sq_diff_y = (grid_ys - mu_y) ** 2\n\n    term1 = 1. / 2. * np.pi * sigma_sq\n    term2 = torch.exp(-(sq_diff_x + sq_diff_y) / (2. 
* sigma_sq))\n    gauss = term1 * term2\n\n    if norm:\n        # normalize so each gaussian peaks at 1\n        gauss_ = gauss.reshape(B * N, Y, X)\n        gauss_ = utils.basic.normalize(gauss_)\n        gauss = gauss_.reshape(B, N, Y, X)\n\n    return gauss\n\n\ndef xy2heatmaps(xy, Y, X, sigma=30.0, norm=True):\n    # xy is B x N x 2\n\n    B, N, D = list(xy.shape)\n    assert (D == 2)\n\n    device = xy.device\n\n    grid_y, grid_x = utils.basic.meshgrid2d(B, Y, X, device=device)\n    # grid_x and grid_y are B x Y x X\n    grid_xs = grid_x.unsqueeze(1).repeat(1, N, 1, 1)\n    grid_ys = grid_y.unsqueeze(1).repeat(1, N, 1, 1)\n    heat = xy2heatmap(xy, sigma, grid_xs, grid_ys, norm=norm)\n    return heat\n\n\ndef draw_circles_at_xy(xy, Y, X, sigma=12.5, round=False):\n    B, N, D = list(xy.shape)\n    assert (D == 2)\n    prior = xy2heatmaps(xy, Y, X, sigma=sigma)\n    # prior is B x N x Y x X\n    if round:\n        prior = (prior > 0.5).float()\n    return prior\n\n\ndef seq2color(im, norm=True, colormap='coolwarm'):\n    B, S, H, W = list(im.shape)\n    # S is sequential\n\n    # prep a mask of the valid pixels, so we can blacken the invalids later\n    mask = torch.max(im, dim=1, keepdim=True)[0]\n\n    # turn the S dim into an explicit sequence\n    coeffs = np.linspace(1.0, float(S), S).astype(np.float32) / float(S)\n\n    # # increase the spacing from the center\n    # coeffs[:int(S/2)] -= 2.0\n    # coeffs[int(S/2)+1:] += 2.0\n\n    coeffs = torch.from_numpy(coeffs).float().cuda()\n    coeffs = coeffs.reshape(1, S, 1, 1).repeat(B, 1, H, W)\n    # scale each channel by the right coeff\n    im = im * coeffs\n    # now im is in [1/S, 1], except for the invalid parts which are 0\n    # keep the highest valid coeff at each pixel\n    im = torch.max(im, dim=1, keepdim=True)[0]\n\n    out = []\n    for b in range(B):\n        im_ = im[b]\n        # move channels out to last dim_\n        im_ = im_.detach().cpu().numpy()\n        im_ = np.squeeze(im_)\n        # im_ is H x W\n        if colormap == 'coolwarm':\n            im_ = cm.coolwarm(im_)[:, :, :3]\n        elif colormap == 'PiYG':\n            im_ = cm.PiYG(im_)[:, :, :3]\n        elif colormap == 'winter':\n            im_ = cm.winter(im_)[:, :, :3]\n        elif colormap == 'spring':\n            im_ = cm.spring(im_)[:, :, :3]\n        elif colormap == 'onediff':\n            im_ = np.reshape(im_, (-1))\n            im0_ = cm.spring(im_)[:, :3]\n            im1_ = cm.winter(im_)[:, :3]\n            im1_[im_ == 1 / float(S)] = im0_[im_ == 1 / float(S)]\n            im_ = np.reshape(im1_, (H, W, 3))\n        else:\n            assert (False)  # invalid colormap\n        # move channels into dim 0\n        im_ = np.transpose(im_, [2, 0, 1])\n        im_ = torch.from_numpy(im_).float().cuda()\n        out.append(im_)\n    out = torch.stack(out, dim=0)\n\n    # blacken the invalid pixels, instead of using the 0-color\n    out = out * mask\n    # out = out*255.0\n\n    # put it in [-0.5, 0.5]\n    out = out - 0.5\n\n    return out\n\n\ndef colorize(d):\n    # this is actually just grayscale right now\n\n    if d.ndim == 2:\n        d = d.unsqueeze(dim=0)\n    else:\n        assert (d.ndim == 3)\n\n    # color_map = cm.get_cmap('plasma')\n    color_map = cm.get_cmap('inferno')\n    # S1, D = traj.shape\n\n    # print('d1', d.shape)\n    C, H, W = d.shape\n    assert (C == 1)\n    d = d.reshape(-1)\n    d = d.detach().cpu().numpy()\n    # print('d2', d.shape)\n    color = np.array(color_map(d)) * 255  # rgba\n    # print('color1', 
color.shape)\n    color = np.reshape(color[:, :3], [H * W, 3])\n    # print('color2', color.shape)\n    color = torch.from_numpy(color).permute(1, 0).reshape(3, H, W)\n    # # gather\n    # cm = matplotlib.cm.get_cmap(cmap if cmap is not None else 'gray')\n    # if cmap=='RdBu' or cmap=='RdYlGn':\n    #     colors = cm(np.arange(256))[:, :3]\n    #  else:\n    #      colors = cm.colors\n    #      colors = np.array(colors).astype(np.float32)\n    #      colors = np.reshape(colors, [-1, 3])\n    #      colors = tf.constant(colors, dtype=tf.float32)\n\n    #      value = tf.gather(colors, indices)\n    # colorize(value, normalize=True, vmin=None, vmax=None, cmap=None, vals=255)\n\n    # copy to the three chans\n    # d = d.repeat(3, 1, 1)\n    return color\n\n\ndef oned2inferno(d, norm=True, do_colorize=False):\n    # convert a 1chan input to a 3chan image output\n\n    # if it's just B x H x W, add a C dim\n    if d.ndim == 3:\n        d = d.unsqueeze(dim=1)\n    # d should be B x C x H x W, where C=1\n    B, C, H, W = list(d.shape)\n    assert (C == 1)\n\n    if norm:\n        d = utils.basic.normalize(d)\n\n    if do_colorize:\n        rgb = torch.zeros(B, 3, H, W)\n        for b in list(range(B)):\n            rgb[b] = colorize(d[b])\n    else:\n        rgb = d.repeat(1, 3, 1, 1) * 255.0\n    # rgb = (255.0*rgb).type(torch.ByteTensor)\n    rgb = rgb.type(torch.ByteTensor)\n\n    # rgb = tf.cast(255.0*rgb, tf.uint8)\n    # rgb = tf.reshape(rgb, [-1, hyp.H, hyp.W, 3])\n    # rgb = tf.expand_dims(rgb, axis=0)\n    return rgb\n\n\ndef oned2gray(d, norm=True):\n    # convert a 1chan input to a 3chan image output\n\n    # if it's just B x H x W, add a C dim\n    if d.ndim == 3:\n        d = d.unsqueeze(dim=1)\n    # d should be B x C x H x W, where C=1\n    B, C, H, W = list(d.shape)\n    assert (C == 1)\n\n    if norm:\n        d = utils.basic.normalize(d)\n\n    rgb = d.repeat(1, 3, 1, 1)\n    rgb = (255.0 * rgb).type(torch.ByteTensor)\n    return rgb\n\n\ndef draw_frame_id_on_vis(vis, frame_id, scale=0.5, left=5, top=20):\n    rgb = vis.detach().cpu().numpy()[0]\n    rgb = np.transpose(rgb, [1, 2, 0])  # put channels last\n    rgb = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)\n    color = (255, 255, 255)\n    # print('putting frame id', frame_id)\n\n    frame_str = utils.basic.strnum(frame_id)\n\n    text_color_bg = (0, 0, 0)\n    font = cv2.FONT_HERSHEY_SIMPLEX\n    text_size, _ = cv2.getTextSize(frame_str, font, scale, 1)\n    text_w, text_h = text_size\n    cv2.rectangle(rgb, (left, top - text_h), (left + text_w, top + 1), text_color_bg, -1)\n\n    cv2.putText(\n        rgb,\n        frame_str,\n        (left, top),  # from left, from top\n        font,\n        scale,  # font scale (float)\n        color,\n        1)  # font thickness (int)\n    rgb = cv2.cvtColor(rgb.astype(np.uint8), cv2.COLOR_BGR2RGB)\n    vis = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)\n    return vis\n\n\nCOLORMAP_FILE = \"./utils/bremm.png\"\n\n\nclass ColorMap2d:\n    def __init__(self, filename=None):\n        self._colormap_file = filename or COLORMAP_FILE\n        self._img = plt.imread(self._colormap_file)\n\n        self._height = self._img.shape[0]\n        self._width = self._img.shape[1]\n\n    def __call__(self, X):\n        assert len(X.shape) == 2\n        output = np.zeros((X.shape[0], 3))\n        for i in range(X.shape[0]):\n            x, y = X[i, :]\n            xp = int((self._width - 1) * x)\n            yp = int((self._height - 1) * y)\n            xp = np.clip(xp, 0, self._width - 1)\n      
      yp = np.clip(yp, 0, self._height - 1)\n            output[i, :] = self._img[yp, xp]\n        return output\n\n\ndef get_n_colors(N, sequential=False):\n    label_colors = []\n    for ii in range(N):\n        if sequential:\n            rgb = cm.winter(ii / (N - 1))\n            rgb = (np.array(rgb) * 255).astype(np.uint8)[:3]\n        else:\n            rgb = np.zeros(3)\n            while np.sum(rgb) < 128:  # ensure min brightness\n                rgb = np.random.randint(0, 256, 3)\n        label_colors.append(rgb)\n    return label_colors\n\n\nclass Summ_writer(object):\n    def __init__(self, writer, global_step, log_freq=10, fps=8, scalar_freq=100, just_gif=False):\n        self.writer = writer\n        self.global_step = global_step\n        self.log_freq = log_freq\n        self.fps = fps\n        self.just_gif = just_gif\n        self.maxwidth = 10000\n        self.save_this = (self.global_step % self.log_freq == 0)\n        self.scalar_freq = max(scalar_freq, 1)\n\n    def summ_gif(self, name, tensor, blacken_zeros=False):\n        # tensor should be in B x S x C x H x W\n\n        assert tensor.dtype in {torch.uint8, torch.float32}\n        shape = list(tensor.shape)\n\n        if tensor.dtype == torch.float32:\n            tensor = back2color(tensor, blacken_zeros=blacken_zeros)\n\n        video_to_write = tensor[0:1]\n\n        S = video_to_write.shape[1]\n        if S == 1:\n            # video_to_write is 1 x 1 x C x H x W\n            self.writer.add_image(name, video_to_write[0, 0], global_step=self.global_step)\n        else:\n            self.writer.add_video(name, video_to_write, fps=self.fps, global_step=self.global_step)\n\n        return video_to_write\n\n    def draw_boxlist2d_on_image(self, rgb, boxlist, scores=None, tids=None, linewidth=1):\n        B, C, H, W = list(rgb.shape)\n        assert (C == 3)\n        B2, N, D = list(boxlist.shape)\n        assert (B2 == B)\n        assert (D == 4)  # ymin, xmin, ymax, xmax\n\n        rgb = back2color(rgb)\n        if scores is None:\n            scores = torch.ones(B2, N).float()\n        if tids is None:\n            tids = torch.arange(N).reshape(1, N).repeat(B2, N).long()\n            # tids = torch.zeros(B2, N).long()\n        out = self.draw_boxlist2d_on_image_py(\n            rgb[0].cpu().detach().numpy(),\n            boxlist[0].cpu().detach().numpy(),\n            scores[0].cpu().detach().numpy(),\n            tids[0].cpu().detach().numpy(),\n            linewidth=linewidth)\n        out = torch.from_numpy(out).type(torch.ByteTensor).permute(2, 0, 1)\n        out = torch.unsqueeze(out, dim=0)\n        out = preprocess_color(out)\n        out = torch.reshape(out, [1, C, H, W])\n        return out\n\n    def draw_boxlist2d_on_image_py(self, rgb, boxlist, scores, tids, linewidth=1):\n        # all inputs are numpy tensors\n        # rgb is H x W x 3\n        # boxlist is N x 4\n        # scores is N\n        # tids is N\n\n        rgb = np.transpose(rgb, [1, 2, 0])  # put channels last\n        # rgb = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)\n\n        rgb = rgb.astype(np.uint8).copy()\n\n        H, W, C = rgb.shape\n        assert (C == 3)\n        N, D = boxlist.shape\n        assert (D == 4)\n\n        # color_map = cm.get_cmap('tab20')\n        # color_map = cm.get_cmap('set1')\n        color_map = cm.get_cmap('Accent')\n        color_map = color_map.colors\n        # print('color_map', color_map)\n\n        # draw\n        for ind, box in enumerate(boxlist):\n            # box is 4\n            if not 
np.isclose(scores[ind], 0.0):\n                # box = utils.geom.scale_box2d(box, H, W)\n                ymin, xmin, ymax, xmax = box\n\n                # ymin, ymax = ymin*H, ymax*H\n                # xmin, xmax = xmin*W, xmax*W\n\n                # print 'score = %.2f' % scores[ind]\n                # color_id = tids[ind] % 20\n                color_id = tids[ind]\n                color = color_map[color_id]\n                color = np.array(color) * 255.0\n                color = color.round()\n                # color = color.astype(np.uint8)\n                # color = color[::-1]\n                # print('color', color)\n\n                # print 'tid = %d; score = %.3f' % (tids[ind], scores[ind])\n\n                # if False:\n                if scores[ind] < 1.0:  # not gt\n                    cv2.putText(rgb,\n                                # '%d (%.2f)' % (tids[ind], scores[ind]), \n                                '%.2f' % (scores[ind]),\n                                (int(xmin), int(ymin)),\n                                cv2.FONT_HERSHEY_SIMPLEX,\n                                0.5,  # font size\n                                color),\n                    # 1) # font weight\n\n                xmin = np.clip(int(xmin), 0, W - 1)\n                xmax = np.clip(int(xmax), 0, W - 1)\n                ymin = np.clip(int(ymin), 0, H - 1)\n                ymax = np.clip(int(ymax), 0, H - 1)\n\n                cv2.line(rgb, (xmin, ymin), (xmin, ymax), color, linewidth, cv2.LINE_AA)\n                cv2.line(rgb, (xmin, ymin), (xmax, ymin), color, linewidth, cv2.LINE_AA)\n                cv2.line(rgb, (xmax, ymin), (xmax, ymax), color, linewidth, cv2.LINE_AA)\n                cv2.line(rgb, (xmax, ymax), (xmin, ymax), color, linewidth, cv2.LINE_AA)\n\n        # rgb = cv2.cvtColor(rgb.astype(np.uint8), cv2.COLOR_BGR2RGB)\n        return rgb\n\n    def summ_boxlist2d(self, name, rgb, boxlist, scores=None, tids=None, frame_id=None, only_return=False, linewidth=2):\n        B, C, H, W = list(rgb.shape)\n        boxlist_vis = self.draw_boxlist2d_on_image(rgb, boxlist, scores=scores, tids=tids, linewidth=linewidth)\n        return self.summ_rgb(name, boxlist_vis, frame_id=frame_id, only_return=only_return)\n\n    def summ_rgbs(self, name, ims, frame_ids=None, blacken_zeros=False, only_return=False):\n        if self.save_this:\n\n            ims = gif_and_tile(ims, just_gif=self.just_gif)\n            vis = ims\n\n            assert vis.dtype in {torch.uint8, torch.float32}\n\n            if vis.dtype == torch.float32:\n                vis = back2color(vis, blacken_zeros)\n\n            B, S, C, H, W = list(vis.shape)\n\n            if frame_ids is not None:\n                assert (len(frame_ids) == S)\n                for s in range(S):\n                    vis[:, s] = draw_frame_id_on_vis(vis[:, s], frame_ids[s])\n\n            if int(W) > self.maxwidth:\n                vis = vis[:, :, :, :self.maxwidth]\n\n            if only_return:\n                return vis\n            else:\n                return self.summ_gif(name, vis, blacken_zeros)\n\n    def summ_rgb(self, name, ims, blacken_zeros=False, frame_id=None, only_return=False, halfres=False):\n        if self.save_this:\n            assert ims.dtype in {torch.uint8, torch.float32}\n\n            if ims.dtype == torch.float32:\n                ims = back2color(ims, blacken_zeros)\n\n            # ims is B x C x H x W\n            vis = ims[0:1]  # just the first one\n            B, C, H, W = list(vis.shape)\n\n            if 
halfres:\n                vis = F.interpolate(vis, scale_factor=0.5)\n\n            if frame_id is not None:\n                vis = draw_frame_id_on_vis(vis, frame_id)\n\n            if int(W) > self.maxwidth:\n                vis = vis[:, :, :, :self.maxwidth]\n\n            if only_return:\n                return vis\n            else:\n                return self.summ_gif(name, vis.unsqueeze(1), blacken_zeros)\n\n    def flow2color(self, flow, clip=50.0):\n        \"\"\"\n        :param flow: Optical flow tensor.\n        :return: RGB image normalized between 0 and 1.\n        \"\"\"\n\n        # flow is B x C x H x W\n\n        B, C, H, W = list(flow.size())\n\n        flow = flow.clone().detach()\n\n        abs_image = torch.abs(flow)\n        flow_mean = abs_image.mean(dim=[1, 2, 3])\n        flow_std = abs_image.std(dim=[1, 2, 3])\n\n        if clip:\n            flow = torch.clamp(flow, -clip, clip) / clip\n        else:\n            # Apply some kind of normalization. Divide by the perceived maximum (mean + std*2)\n            flow_max = flow_mean + flow_std * 2 + 1e-10\n            for b in range(B):\n                flow[b] = flow[b].clamp(-flow_max[b].item(), flow_max[b].item()) / flow_max[b].clamp(min=1)\n\n        radius = torch.sqrt(torch.sum(flow ** 2, dim=1, keepdim=True))  # B x 1 x H x W\n        radius_clipped = torch.clamp(radius, 0.0, 1.0)\n\n        angle = torch.atan2(flow[:, 1:], flow[:, 0:1]) / np.pi  # B x 1 x H x W\n\n        hue = torch.clamp((angle + 1.0) / 2.0, 0.0, 1.0)\n        saturation = torch.ones_like(hue) * 0.75\n        value = radius_clipped\n        hsv = torch.cat([hue, saturation, value], dim=1)  # B x 3 x H x W\n\n        # flow = tf.image.hsv_to_rgb(hsv)\n        flow = hsv_to_rgb(hsv)\n        flow = (flow * 255.0).type(torch.ByteTensor)\n        return flow\n\n    def summ_flow(self, name, im, clip=0.0, only_return=False, frame_id=None):\n        # flow is B x C x D x W\n        if self.save_this:\n            return self.summ_rgb(name, self.flow2color(im, clip=clip), only_return=only_return, frame_id=frame_id)\n        else:\n            return None\n\n    def summ_oneds(self, name, ims, frame_ids=None, bev=False, fro=False, logvis=False, reduce_max=False, max_val=0.0,\n                   norm=True, only_return=False, do_colorize=False):\n        if self.save_this:\n            if bev:\n                B, C, H, _, W = list(ims[0].shape)\n                if reduce_max:\n                    ims = [torch.max(im, dim=3)[0] for im in ims]\n                else:\n                    ims = [torch.mean(im, dim=3) for im in ims]\n            elif fro:\n                B, C, _, H, W = list(ims[0].shape)\n                if reduce_max:\n                    ims = [torch.max(im, dim=2)[0] for im in ims]\n                else:\n                    ims = [torch.mean(im, dim=2) for im in ims]\n\n            if len(ims) != 1:  # sequence\n                im = gif_and_tile(ims, just_gif=self.just_gif)\n            else:\n                im = torch.stack(ims, dim=1)  # single frame\n\n            B, S, C, H, W = list(im.shape)\n\n            if logvis and max_val:\n                max_val = np.log(max_val)\n                im = torch.log(torch.clamp(im, 0) + 1.0)\n                im = torch.clamp(im, 0, max_val)\n                im = im / max_val\n                norm = False\n            elif max_val:\n                im = torch.clamp(im, 0, max_val)\n                im = im / max_val\n                norm = False\n\n            if norm:\n                # 
normalize before oned2inferno,\n                # so that the ranges are similar within B across S\n                im = utils.basic.normalize(im)\n\n            im = im.view(B * S, C, H, W)\n            vis = oned2inferno(im, norm=norm, do_colorize=do_colorize)\n            vis = vis.view(B, S, 3, H, W)\n\n            if frame_ids is not None:\n                assert (len(frame_ids) == S)\n                for s in range(S):\n                    vis[:, s] = draw_frame_id_on_vis(vis[:, s], frame_ids[s])\n\n            if W > self.maxwidth:\n                vis = vis[..., :self.maxwidth]\n\n            if only_return:\n                return vis\n            else:\n                self.summ_gif(name, vis)\n\n    def summ_oned(self, name, im, bev=False, fro=False, logvis=False, max_val=0, max_along_y=False, norm=True,\n                  frame_id=None, only_return=False):\n        if self.save_this:\n\n            if bev:\n                B, C, H, _, W = list(im.shape)\n                if max_along_y:\n                    im = torch.max(im, dim=3)[0]\n                else:\n                    im = torch.mean(im, dim=3)\n            elif fro:\n                B, C, _, H, W = list(im.shape)\n                if max_along_y:\n                    im = torch.max(im, dim=2)[0]\n                else:\n                    im = torch.mean(im, dim=2)\n            else:\n                B, C, H, W = list(im.shape)\n\n            im = im[0:1]  # just the first one\n            assert (C == 1)\n\n            if logvis and max_val:\n                max_val = np.log(max_val)\n                im = torch.log(im)\n                im = torch.clamp(im, 0, max_val)\n                im = im / max_val\n                norm = False\n            elif max_val:\n                im = torch.clamp(im, 0, max_val) / max_val\n                norm = False\n\n            vis = oned2inferno(im, norm=norm)\n            if W > self.maxwidth:\n                vis = vis[..., :self.maxwidth]\n            return self.summ_rgb(name, vis, blacken_zeros=False, frame_id=frame_id, only_return=only_return)\n\n    def summ_feats(self, name, feats, valids=None, pca=True, fro=False, only_return=False, frame_ids=None):\n        if self.save_this:\n            if valids is not None:\n                valids = torch.stack(valids, dim=1)\n\n            feats = torch.stack(feats, dim=1)\n            # feats leads with B x S x C\n\n            if feats.ndim == 6:\n\n                # feats is B x S x C x D x H x W\n                if fro:\n                    reduce_dim = 3\n                else:\n                    reduce_dim = 4\n\n                if valids is None:\n                    feats = torch.mean(feats, dim=reduce_dim)\n                else:\n                    valids = valids.repeat(1, 1, feats.size()[2], 1, 1, 1)\n                    feats = utils.basic.reduce_masked_mean(feats, valids, dim=reduce_dim)\n\n            B, S, C, D, W = list(feats.size())\n\n            if not pca:\n                # feats leads with B x S x C\n                feats = torch.mean(torch.abs(feats), dim=2, keepdims=True)\n                # feats leads with B x S x 1\n                feats = torch.unbind(feats, dim=1)\n                return self.summ_oneds(name=name, ims=feats, norm=True, only_return=only_return, frame_ids=frame_ids)\n\n            else:\n                __p = lambda x: utils.basic.pack_seqdim(x, B)\n                __u = lambda x: utils.basic.unpack_seqdim(x, B)\n\n                feats_ = __p(feats)\n\n                if valids is None:\n  
                  feats_pca_ = get_feat_pca(feats_)\n                else:\n                    valids_ = __p(valids)\n                    feats_pca_ = get_feat_pca(feats_, valids_)\n\n                feats_pca = __u(feats_pca_)\n\n                return self.summ_rgbs(name=name, ims=torch.unbind(feats_pca, dim=1), only_return=only_return,\n                                      frame_ids=frame_ids)\n\n    def summ_feat(self, name, feat, valid=None, pca=True, only_return=False, bev=False, fro=False, frame_id=None):\n        if self.save_this:\n            if feat.ndim == 5:  # B x C x D x H x W\n\n                if bev:\n                    reduce_axis = 3\n                elif fro:\n                    reduce_axis = 2\n                else:\n                    # default to bev\n                    reduce_axis = 3\n\n                if valid is None:\n                    feat = torch.mean(feat, dim=reduce_axis)\n                else:\n                    valid = valid.repeat(1, feat.size()[1], 1, 1, 1)\n                    feat = utils.basic.reduce_masked_mean(feat, valid, dim=reduce_axis)\n\n            B, C, D, W = list(feat.shape)\n\n            if not pca:\n                feat = torch.mean(torch.abs(feat), dim=1, keepdims=True)\n                # feat is B x 1 x D x W\n                return self.summ_oned(name=name, im=feat, norm=True, only_return=only_return, frame_id=frame_id)\n            else:\n                feat_pca = get_feat_pca(feat, valid)\n                return self.summ_rgb(name, feat_pca, only_return=only_return, frame_id=frame_id)\n\n    def summ_scalar(self, name, value):\n        if (not (isinstance(value, int) or isinstance(value, float) or isinstance(value, np.float32))) and (\n                'torch' in value.type()):\n            value = value.detach().cpu().numpy()\n        if not np.isnan(value):\n            if (self.log_freq == 1):\n                self.writer.add_scalar(name, value, global_step=self.global_step)\n            elif self.save_this or np.mod(self.global_step, self.scalar_freq) == 0:\n                self.writer.add_scalar(name, value, global_step=self.global_step)\n\n    def summ_seg(self, name, seg, only_return=False, frame_id=None, colormap='tab20', label_colors=None):\n        if not self.save_this:\n            return\n\n        B, H, W = seg.shape\n\n        if label_colors is None:\n            custom_label_colors = False\n            # label_colors = get_n_colors(int(torch.max(seg).item()), sequential=True)\n            label_colors = cm.get_cmap(colormap).colors\n            label_colors = [[int(i * 255) for i in l] for l in label_colors]\n        else:\n            custom_label_colors = True\n        # label_colors = matplotlib.cm.get_cmap(colormap).colors\n        # label_colors = [[int(i*255) for i in l] for l in label_colors]\n        # print('label_colors', label_colors)\n\n        # label_colors = [\n        #     (0, 0, 0),         # None\n        #     (70, 70, 70),      # Buildings\n        #     (190, 153, 153),   # Fences\n        #     (72, 0, 90),       # Other\n        #     (220, 20, 60),     # Pedestrians\n        #     (153, 153, 153),   # Poles\n        #     (157, 234, 50),    # RoadLines\n        #     (128, 64, 128),    # Roads\n        #     (244, 35, 232),    # Sidewalks\n        #     (107, 142, 35),    # Vegetation\n        #     (0, 0, 255),      # Vehicles\n        #     (102, 102, 156),  # Walls\n        #     (220, 220, 0)     # TrafficSigns\n        # ]\n\n        r = torch.zeros_like(seg, 
dtype=torch.uint8)\n        g = torch.zeros_like(seg, dtype=torch.uint8)\n        b = torch.zeros_like(seg, dtype=torch.uint8)\n\n        for label in range(0, len(label_colors)):\n            if (not custom_label_colors):  # and (N > 20):\n                label_ = label % 20\n            else:\n                label_ = label\n\n            idx = (seg == label + 1)\n            r[idx] = label_colors[label_][0]\n            g[idx] = label_colors[label_][1]\n            b[idx] = label_colors[label_][2]\n\n        rgb = torch.stack([r, g, b], axis=1)\n        return self.summ_rgb(name, rgb, only_return=only_return, frame_id=frame_id)\n\n    def summ_pts_on_rgb(self, name, trajs, rgb, valids=None, frame_id=None, only_return=False, show_dots=True,\n                        cmap='coolwarm', linewidth=1):\n        # trajs is B, S, N, 2\n        # rgbs is B, S, C, H, W\n        B, C, H, W = rgb.shape\n        B, S, N, D = trajs.shape\n\n        rgb = rgb[0]  # C, H, W\n        trajs = trajs[0]  # S, N, 2\n        if valids is None:\n            valids = torch.ones_like(trajs[:, :, 0])  # S, N\n        else:\n            valids = valids[0]\n        # print('trajs', trajs.shape)\n        # print('valids', valids.shape)\n\n        rgb = back2color(rgb).detach().cpu().numpy()\n        rgb = np.transpose(rgb, [1, 2, 0])  # put channels last\n\n        trajs = trajs.long().detach().cpu().numpy()  # S, N, 2\n        valids = valids.long().detach().cpu().numpy()  # S, N\n\n        rgb = rgb.astype(np.uint8).copy()\n\n        for i in range(N):\n            if cmap == 'onediff' and i == 0:\n                cmap_ = 'spring'\n            elif cmap == 'onediff':\n                cmap_ = 'winter'\n            else:\n                cmap_ = cmap\n            traj = trajs[:, i]  # S,2\n            valid = valids[:, i]  # S\n\n            color_map = cm.get_cmap(cmap)\n            color = np.array(color_map(i)[:3]) * 255  # rgb\n            for s in range(S):\n                if valid[s]:\n                    cv2.circle(rgb, (int(traj[s, 0]), int(traj[s, 1])), linewidth, color, -1)\n        rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)\n        rgb = preprocess_color(rgb)\n        return self.summ_rgb(name, rgb, only_return=only_return, frame_id=frame_id)\n\n    def summ_pts_on_rgbs(self, name, trajs, rgbs, valids=None, frame_ids=None, only_return=False, show_dots=True,\n                         cmap='coolwarm', linewidth=1):\n        # trajs is B, S, N, 2\n        # rgbs is B, S, C, H, W\n        B, S, C, H, W = rgbs.shape\n        B, S2, N, D = trajs.shape\n        assert (S == S2)\n\n        rgbs = rgbs[0]  # S, C, H, W\n        trajs = trajs[0]  # S, N, 2\n        if valids is None:\n            valids = torch.ones_like(trajs[:, :, 0])  # S, N\n        else:\n            valids = valids[0]\n        # print('trajs', trajs.shape)\n        # print('valids', valids.shape)\n\n        rgbs_color = []\n        for rgb in rgbs:\n            rgb = back2color(rgb).detach().cpu().numpy()\n            rgb = np.transpose(rgb, [1, 2, 0])  # put channels last\n            rgbs_color.append(rgb)  # each element 3 x H x W\n\n        trajs = trajs.long().detach().cpu().numpy()  # S, N, 2\n        valids = valids.long().detach().cpu().numpy()  # S, N\n\n        rgbs_color = [rgb.astype(np.uint8).copy() for rgb in rgbs_color]\n\n        for i in range(N):\n            traj = trajs[:, i]  # S,2\n            valid = valids[:, i]  # S\n\n            color_map = cm.get_cmap(cmap)\n            color = 
np.array(color_map(0)[:3]) * 255  # rgb\n            for s in range(S):\n                if valid[s]:\n                    cv2.circle(rgbs_color[s], (traj[s, 0], traj[s, 1]), linewidth, color, -1)\n        rgbs = []\n        for rgb in rgbs_color:\n            rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)\n            rgbs.append(preprocess_color(rgb))\n\n        return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids)\n\n    def summ_traj2ds_on_rgbs(self, name, trajs, rgbs, valids=None, frame_ids=None, only_return=False, show_dots=False,\n                             cmap='coolwarm', vals=None, linewidth=1):\n        # trajs is B, S, N, 2\n        # rgbs is B, S, C, H, W\n        B, S, C, H, W = rgbs.shape\n        B, S2, N, D = trajs.shape\n        assert (S == S2)\n\n        rgbs = rgbs[0]  # S, C, H, W\n        trajs = trajs[0]  # S, N, 2\n        if valids is None:\n            valids = torch.ones_like(trajs[:, :, 0])  # S, N\n        else:\n            valids = valids[0]\n\n        # print('trajs', trajs.shape)\n        # print('valids', valids.shape)\n\n        if vals is not None:\n            vals = vals[0]  # N\n            # print('vals', vals.shape)\n\n        rgbs_color = []\n        for rgb in rgbs:\n            rgb = back2color(rgb).detach().cpu().numpy()\n            rgb = np.transpose(rgb, [1, 2, 0])  # put channels last\n            rgbs_color.append(rgb)  # each element 3 x H x W\n\n        for i in range(N):\n            if cmap == 'onediff' and i == 0:\n                cmap_ = 'spring'\n            elif cmap == 'onediff':\n                cmap_ = 'winter'\n            else:\n                cmap_ = cmap\n            traj = trajs[:, i].long().detach().cpu().numpy()  # S, 2\n            valid = valids[:, i].long().detach().cpu().numpy()  # S\n\n            # print('traj', traj.shape)\n            # print('valid', valid.shape)\n\n            if vals is not None:\n                # val = vals[:,i].float().detach().cpu().numpy() # []\n                val = vals[i].float().detach().cpu().numpy()  # []\n                # print('val', val.shape)\n            else:\n                val = None\n\n            for t in range(S):\n                # if valid[t]:\n                # traj_seq = traj[max(t-16,0):t+1]\n                traj_seq = traj[max(t - 8, 0):t + 1]\n                val_seq = np.linspace(0, 1, len(traj_seq))\n                # if t<2:\n                #     val_seq = np.zeros_like(val_seq)\n                # print('val_seq', val_seq)\n                # val_seq = 1.0\n                # val_seq = np.arange(8)/8.0\n                # val_seq = val_seq[-len(traj_seq):]\n                # rgbs_color[t] = self.draw_traj_on_image_py(rgbs_color[t], traj_seq, S=S, show_dots=show_dots, cmap=cmap_, val=val_seq, linewidth=linewidth)\n                rgbs_color[t] = self.draw_traj_on_image_py(rgbs_color[t], traj_seq, S=S, show_dots=show_dots,\n                                                           cmap=cmap_, val=val_seq, linewidth=linewidth)\n            # input()\n\n        for i in range(N):\n            if cmap == 'onediff' and i == 0:\n                cmap_ = 'spring'\n            elif cmap == 'onediff':\n                cmap_ = 'winter'\n            else:\n                cmap_ = cmap\n            traj = trajs[:, i]  # S,2\n            # vis = visibles[:,i] # S\n            vis = torch.ones_like(traj[:, 0])  # S\n            valid = valids[:, i]  # S\n            rgbs_color = self.draw_circ_on_images_py(rgbs_color, traj, vis, S=0, 
show_dots=show_dots, cmap=cmap_,\n                                                     linewidth=linewidth)\n\n        rgbs = []\n        for rgb in rgbs_color:\n            rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)\n            rgbs.append(preprocess_color(rgb))\n\n        return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids)\n\n    def summ_traj2ds_on_rgbs2(self, name, trajs, visibles, rgbs, valids=None, frame_ids=None, only_return=False,\n                              show_dots=True, cmap=None, linewidth=1):\n        # trajs is B, S, N, 2\n        # rgbs is B, S, C, H, W\n        B, S, C, H, W = rgbs.shape\n        B, S2, N, D = trajs.shape\n        assert (S == S2)\n\n        rgbs = rgbs[0]  # S, C, H, W\n        trajs = trajs[0]  # S, N, 2\n        visibles = visibles[0]  # S, N\n        if valids is None:\n            valids = torch.ones_like(trajs[:, :, 0])  # S, N\n        else:\n            valids = valids[0]\n        # print('trajs', trajs.shape)\n        # print('valids', valids.shape)\n\n        rgbs_color = []\n        for rgb in rgbs:\n            rgb = back2color(rgb).detach().cpu().numpy()\n            rgb = np.transpose(rgb, [1, 2, 0])  # put channels last\n            rgbs_color.append(rgb)  # each element 3 x H x W\n\n        trajs = trajs.long().detach().cpu().numpy()  # S, N, 2\n        visibles = visibles.float().detach().cpu().numpy()  # S, N\n        valids = valids.long().detach().cpu().numpy()  # S, N\n\n        for i in range(N):\n            if cmap == 'onediff' and i == 0:\n                cmap_ = 'spring'\n            elif cmap == 'onediff':\n                cmap_ = 'winter'\n            else:\n                cmap_ = cmap\n            traj = trajs[:, i]  # S,2\n            vis = visibles[:, i]  # S\n            valid = valids[:, i]  # S\n            rgbs_color = self.draw_traj_on_images_py(rgbs_color, traj, S=S, show_dots=show_dots, cmap=cmap_,\n                                                     linewidth=linewidth)\n\n        for i in range(N):\n            if cmap == 'onediff' and i == 0:\n                cmap_ = 'spring'\n            elif cmap == 'onediff':\n                cmap_ = 'winter'\n            else:\n                cmap_ = cmap\n            traj = trajs[:, i]  # S,2\n            vis = visibles[:, i]  # S\n            valid = valids[:, i]  # S\n            if valid[0]:\n                rgbs_color = self.draw_circ_on_images_py(rgbs_color, traj, vis, S=S, show_dots=show_dots, cmap=None,\n                                                         linewidth=linewidth)\n\n        rgbs = []\n        for rgb in rgbs_color:\n            rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)\n            rgbs.append(preprocess_color(rgb))\n\n        return self.summ_rgbs(name, rgbs, only_return=only_return, frame_ids=frame_ids)\n\n    def summ_traj2ds_on_rgb(self, name, trajs, rgb, valids=None, show_dots=False, show_lines=True, frame_id=None,\n                            only_return=False, cmap='coolwarm', linewidth=1):\n        # trajs is B, S, N, 2\n        # rgb is B, C, H, W\n        B, C, H, W = rgb.shape\n        B, S, N, D = trajs.shape\n\n        rgb = rgb[0]  # S, C, H, W\n        trajs = trajs[0]  # S, N, 2\n\n        if valids is None:\n            valids = torch.ones_like(trajs[:, :, 0])\n        else:\n            valids = valids[0]\n\n        rgb_color = back2color(rgb).detach().cpu().numpy()\n        rgb_color = np.transpose(rgb_color, [1, 2, 0])  # put channels last\n\n        # using maxdist 
will dampen the colors for short motions\n        norms = torch.sqrt(1e-4 + torch.sum((trajs[-1] - trajs[0]) ** 2, dim=1))  # N\n        maxdist = torch.quantile(norms, 0.95).detach().cpu().numpy()\n        maxdist = None\n        trajs = trajs.long().detach().cpu().numpy()  # S, N, 2\n        valids = valids.long().detach().cpu().numpy()  # S, N\n\n        for i in range(N):\n            if cmap == 'onediff' and i == 0:\n                cmap_ = 'spring'\n            elif cmap == 'onediff':\n                cmap_ = 'winter'\n            else:\n                cmap_ = cmap\n            traj = trajs[:, i]  # S, 2\n            valid = valids[:, i]  # S\n            if valid[0] == 1:\n                traj = traj[valid > 0]\n                rgb_color = self.draw_traj_on_image_py(\n                    rgb_color, traj, S=S, show_dots=show_dots, show_lines=show_lines, cmap=cmap_, maxdist=maxdist,\n                    linewidth=linewidth)\n\n        rgb_color = torch.from_numpy(rgb_color).permute(2, 0, 1).unsqueeze(0)\n        rgb = preprocess_color(rgb_color)\n        return self.summ_rgb(name, rgb, only_return=only_return, frame_id=frame_id)\n\n    def draw_traj_on_image_py(self, rgb, traj, S=50, linewidth=1, show_dots=False, show_lines=True, cmap='coolwarm',\n                              val=None, maxdist=None):\n        # all inputs are numpy tensors\n        # rgb is 3 x H x W\n        # traj is S x 2\n\n        H, W, C = rgb.shape\n        assert (C == 3)\n\n        rgb = rgb.astype(np.uint8).copy()\n\n        S1, D = traj.shape\n        assert (D == 2)\n\n        color_map = cm.get_cmap(cmap)\n        S1, D = traj.shape\n\n        for s in range(S1):\n            if val is not None:\n                # if len(val) == S1:\n                color = np.array(color_map(val[s])[:3]) * 255  # rgb\n                # else:\n                #     color = np.array(color_map(val)[:3]) * 255 # rgb\n            else:\n                if maxdist is not None:\n                    val = (np.sqrt(np.sum((traj[s] - traj[0]) ** 2)) / maxdist).clip(0, 1)\n                    color = np.array(color_map(val)[:3]) * 255  # rgb\n                else:\n                    color = np.array(color_map((s) / max(1, float(S - 2)))[:3]) * 255  # rgb\n\n            if show_lines and s < (S1 - 1):\n                cv2.line(rgb,\n                         (int(traj[s, 0]), int(traj[s, 1])),\n                         (int(traj[s + 1, 0]), int(traj[s + 1, 1])),\n                         color,\n                         linewidth,\n                         cv2.LINE_AA)\n            if show_dots:\n                cv2.circle(rgb, (int(traj[s, 0]), int(traj[s, 1])), linewidth, np.array(color_map(1)[:3]) * 255, -1)\n\n        # if maxdist is not None:\n        #     val = (np.sqrt(np.sum((traj[-1]-traj[0])**2))/maxdist).clip(0,1)\n        #     color = np.array(color_map(val)[:3]) * 255 # rgb\n        # else:\n        #     # draw the endpoint of traj, using the next color (which may be the last color)\n        #     color = np.array(color_map((S1-1)/max(1,float(S-2)))[:3]) * 255 # rgb\n\n        # # emphasize endpoint\n        # cv2.circle(rgb, (traj[-1,0], traj[-1,1]), linewidth*2, color, -1)\n\n        return rgb\n\n    def draw_traj_on_images_py(self, rgbs, traj, S=50, linewidth=1, show_dots=False, cmap='coolwarm', maxdist=None):\n        # all inputs are numpy tensors\n        # rgbs is a list of H,W,3\n        # traj is S,2\n        H, W, C = rgbs[0].shape\n        assert (C == 3)\n\n        rgbs = [rgb.astype(np.uint8).copy() 
for rgb in rgbs]\n\n        S1, D = traj.shape\n        assert (D == 2)\n\n        x = int(np.clip(traj[0, 0], 0, W - 1))\n        y = int(np.clip(traj[0, 1], 0, H - 1))\n        color = rgbs[0][y, x]\n        color = (int(color[0]), int(color[1]), int(color[2]))\n        for s in range(S):\n            # bak_color = np.array(color_map(1.0)[:3]) * 255 # rgb\n            # cv2.circle(rgbs[s], (traj[s,0], traj[s,1]), linewidth*4, bak_color, -1)\n            cv2.polylines(rgbs[s],\n                          [traj[:s + 1]],\n                          False,\n                          color,\n                          linewidth,\n                          cv2.LINE_AA)\n        return rgbs\n\n    def draw_circs_on_image_py(self, rgb, xy, colors=None, linewidth=10, radius=3, show_dots=False, maxdist=None):\n        # all inputs are numpy tensors\n        # rgbs is a list of 3,H,W\n        # xy is N,2\n        H, W, C = rgb.shape\n        assert (C == 3)\n\n        rgb = rgb.astype(np.uint8).copy()\n\n        N, D = xy.shape\n        assert (D == 2)\n\n        xy = xy.astype(np.float32)\n        xy[:, 0] = np.clip(xy[:, 0], 0, W - 1)\n        xy[:, 1] = np.clip(xy[:, 1], 0, H - 1)\n        xy = xy.astype(np.int32)\n\n        if colors is None:\n            colors = get_n_colors(N)\n\n        for n in range(N):\n            color = colors[n]\n            # print('color', color)\n            # color = (color[0]*255).astype(np.uint8) \n            color = (int(color[0]), int(color[1]), int(color[2]))\n\n            # x = int(np.clip(xy[0,0], 0, W-1))\n            # y = int(np.clip(xy[0,1], 0, H-1))\n            # color_ = rgbs[0][y,x]\n            # color_ = (int(color_[0]),int(color_[1]),int(color_[2]))\n            # color_ = (int(color_[0]),int(color_[1]),int(color_[2]))\n\n            cv2.circle(rgb, (xy[n, 0], xy[n, 1]), linewidth, color, 3)\n            # vis_color = int(np.squeeze(vis[s])*255)\n            # vis_color = (vis_color,vis_color,vis_color)\n            # cv2.circle(rgbs[s], (traj[s,0], traj[s,1]), linewidth+1, vis_color, -1)\n        return rgb\n\n    def draw_circ_on_images_py(self, rgbs, traj, vis, S=50, linewidth=1, show_dots=False, cmap=None, maxdist=None):\n        # all inputs are numpy tensors\n        # rgbs is a list of 3,H,W\n        # traj is S,2\n        H, W, C = rgbs[0].shape\n        assert (C == 3)\n\n        rgbs = [rgb.astype(np.uint8).copy() for rgb in rgbs]\n\n        S1, D = traj.shape\n        assert (D == 2)\n\n        if cmap is None:\n            bremm = ColorMap2d()\n            traj_ = traj[0:1].astype(np.float32)\n            traj_[:, 0] /= float(W)\n            traj_[:, 1] /= float(H)\n            color = bremm(traj_)\n            # print('color', color)\n            color = (color[0] * 255).astype(np.uint8)\n            # color = (int(color[0]),int(color[1]),int(color[2]))\n            color = (int(color[2]), int(color[1]), int(color[0]))\n\n        for s in range(S1):\n            if cmap is not None:\n                color_map = cm.get_cmap(cmap)\n                # color = np.array(color_map(s/(S-1))[:3]) * 255 # rgb\n                color = np.array(color_map((s + 1) / max(1, float(S - 1)))[:3]) * 255  # rgb\n                # color = color.astype(np.uint8)\n                # color = (color[0], color[1], color[2])\n                # print('color', color)\n            # import ipdb; ipdb.set_trace()\n\n            cv2.circle(rgbs[s], (int(traj[s, 0]), int(traj[s, 1])), linewidth + 1, color, -1)\n            # vis_color = int(np.squeeze(vis[s])*255)\n 
           # vis_color = (vis_color,vis_color,vis_color)\n            # cv2.circle(rgbs[s], (int(traj[s,0]), int(traj[s,1])), linewidth+1, vis_color, -1)\n\n        return rgbs\n\n    def summ_traj_as_crops(self, name, trajs_e, rgbs, frame_id=None, only_return=False, show_circ=False, trajs_g=None,\n                           is_g=False):\n        B, S, N, D = trajs_e.shape\n        assert (N == 1)\n        assert (D == 2)\n\n        rgbs_vis = []\n        n = 0\n        pad_amount = 100\n        trajs_e_py = trajs_e[0].detach().cpu().numpy()\n        # trajs_e_py = np.clip(trajs_e_py, min=pad_amount/2, max=pad_amoun\n        trajs_e_py = trajs_e_py + pad_amount\n\n        if trajs_g is not None:\n            trajs_g_py = trajs_g[0].detach().cpu().numpy()\n            trajs_g_py = trajs_g_py + pad_amount\n\n        for s in range(S):\n            rgb = rgbs[0, s].detach().cpu().numpy()\n            # print('orig rgb', rgb.shape)\n            rgb = np.transpose(rgb, (1, 2, 0))  # H, W, 3\n\n            rgb = np.pad(rgb, ((pad_amount, pad_amount), (pad_amount, pad_amount), (0, 0)))\n            # print('pad rgb', rgb.shape)\n            H, W, C = rgb.shape\n\n            if trajs_g is not None:\n                xy_g = trajs_g_py[s, n]\n                xy_g[0] = np.clip(xy_g[0], pad_amount, W - pad_amount)\n                xy_g[1] = np.clip(xy_g[1], pad_amount, H - pad_amount)\n                rgb = self.draw_circs_on_image_py(rgb, xy_g.reshape(1, 2), colors=[(0, 255, 0)], linewidth=2, radius=3)\n\n            xy_e = trajs_e_py[s, n]\n            xy_e[0] = np.clip(xy_e[0], pad_amount, W - pad_amount)\n            xy_e[1] = np.clip(xy_e[1], pad_amount, H - pad_amount)\n\n            if show_circ:\n                if is_g:\n                    rgb = self.draw_circs_on_image_py(rgb, xy_e.reshape(1, 2), colors=[(0, 255, 0)], linewidth=2,\n                                                      radius=3)\n                else:\n                    rgb = self.draw_circs_on_image_py(rgb, xy_e.reshape(1, 2), colors=[(255, 0, 255)], linewidth=2,\n                                                      radius=3)\n\n            xmin = int(xy_e[0]) - pad_amount // 2\n            xmax = int(xy_e[0]) + pad_amount // 2\n            ymin = int(xy_e[1]) - pad_amount // 2\n            ymax = int(xy_e[1]) + pad_amount // 2\n\n            rgb_ = rgb[ymin:ymax, xmin:xmax]\n\n            H_, W_ = rgb_.shape[:2]\n            # if np.any(rgb_.shape==0):\n            #     input()\n            if H_ == 0 or W_ == 0:\n                import ipdb;\n                ipdb.set_trace()\n\n            rgb_ = rgb_.transpose(2, 0, 1)\n            rgb_ = torch.from_numpy(rgb_)\n\n            rgbs_vis.append(rgb_)\n\n        # nrow = int(np.sqrt(S)*(16.0/9)/2.0)\n        nrow = int(np.sqrt(S) * 1.5)\n        grid_img = torchvision.utils.make_grid(torch.stack(rgbs_vis, dim=0), nrow=nrow).unsqueeze(0)\n        # print('grid_img', grid_img.shape)\n        return self.summ_rgb(name, grid_img.byte(), frame_id=frame_id, only_return=only_return)\n\n    def summ_occ(self, name, occ, reduce_axes=[3], bev=False, fro=False, pro=False, frame_id=None, only_return=False):\n        if self.save_this:\n            B, C, D, H, W = list(occ.shape)\n            if bev:\n                reduce_axes = [3]\n            elif fro:\n                reduce_axes = [2]\n            elif pro:\n                reduce_axes = [4]\n            for reduce_axis in reduce_axes:\n                height = convert_occ_to_height(occ, reduce_axis=reduce_axis)\n              
  if reduce_axis == reduce_axes[-1]:\n                    return self.summ_oned(name=('%s_ax%d' % (name, reduce_axis)), im=height, norm=False,\n                                          frame_id=frame_id, only_return=only_return)\n                else:\n                    self.summ_oned(name=('%s_ax%d' % (name, reduce_axis)), im=height, norm=False, frame_id=frame_id,\n                                   only_return=only_return)\n\n\ndef erode2d(im, times=1, device='cuda'):\n    weights2d = torch.ones(1, 1, 3, 3, device=device)\n    for time in range(times):\n        im = 1.0 - F.conv2d(1.0 - im, weights2d, padding=1).clamp(0, 1)\n    return im\n\n\ndef dilate2d(im, times=1, device='cuda', mode='square'):\n    weights2d = torch.ones(1, 1, 3, 3, device=device)\n    if mode == 'cross':\n        weights2d[:, :, 0, 0] = 0.0\n        weights2d[:, :, 0, 2] = 0.0\n        weights2d[:, :, 2, 0] = 0.0\n        weights2d[:, :, 2, 2] = 0.0\n    for time in range(times):\n        im = F.conv2d(im, weights2d, padding=1).clamp(0, 1)\n    return im\n"
  },
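  {
    "path": "examples/summ_writer_example.py",
    "content": "# Hypothetical usage sketch (not part of the original codebase): shows how the\n# TensorBoard-backed Summ_writer defined in the visualization utils above can\n# log an RGB clip and a scalar. The import path below is an assumption about\n# where that module lives in the package; adjust it to the actual location.\nimport torch\nfrom torch.utils.tensorboard import SummaryWriter\n\nfrom mvtracker.utils.improc import Summ_writer  # assumed module path\n\nwriter = SummaryWriter(log_dir='./logs/summ_writer_example')\nfor step in range(0, 100, 10):\n    # save_this is True whenever global_step % log_freq == 0, so these steps all log\n    sw = Summ_writer(writer=writer, global_step=step, log_freq=10, fps=8)\n    # dummy 8-frame clip: a list of S tensors, each B x C x H x W, uint8 in [0, 255]\n    frames = [torch.randint(0, 256, (1, 3, 64, 64), dtype=torch.uint8) for _ in range(8)]\n    sw.summ_rgbs('example/rgbs', frames, frame_ids=list(range(8)))\n    sw.summ_scalar('example/loss', 1.0 / (step + 1))\nwriter.close()\n"
  },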
  {
    "path": "mvtracker/utils/misc.py",
    "content": "import numpy as np\nimport torch\nfrom prettytable import PrettyTable\n\n\ndef count_parameters(model):\n    table = PrettyTable([\"Modules\", \"Parameters\"])\n    total_params = 0\n    for name, parameter in model.named_parameters():\n        if not parameter.requires_grad:\n            continue\n        param = parameter.numel()\n        if param > 100000:\n            table.add_row([name, param])\n        total_params += param\n    print(table)\n    print('total params: %.2f M' % (total_params / 1000000.0))\n    return total_params\n\n\ndef posemb_sincos_2d_xy(xy, C, temperature=10000, dtype=torch.float32, cat_coords=False):\n    device = xy.device\n    dtype = xy.dtype\n    B, S, D = xy.shape\n    assert (D == 2)\n    x = xy[:, :, 0]\n    y = xy[:, :, 1]\n    assert (C % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'\n    omega = torch.arange(C // 4, device=device) / (C // 4 - 1)\n    omega = 1. / (temperature ** omega)\n\n    y = y.flatten()[:, None] * omega[None, :]\n    x = x.flatten()[:, None] * omega[None, :]\n    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)\n    pe = pe.reshape(B, S, C).type(dtype)\n    if cat_coords:\n        pe = torch.cat([pe, xy], dim=2)  # B,N,C+2\n    return pe\n\n\nclass SimplePool():\n    def __init__(self, pool_size, version='pt'):\n        self.pool_size = pool_size\n        self.version = version\n        self.items = []\n\n        if not (version == 'pt' or version == 'np'):\n            print('version = %s; please choose pt or np')\n            assert (False)  # please choose pt or np\n\n    def __len__(self):\n        return len(self.items)\n\n    def mean(self, min_size=1):\n        if min_size == 'half':\n            pool_size_thresh = self.pool_size / 2\n        else:\n            pool_size_thresh = min_size\n\n        if self.version == 'np':\n            if len(self.items) >= pool_size_thresh:\n                return np.sum(self.items) / float(len(self.items))\n            else:\n                return np.nan\n        if self.version == 'pt':\n            if len(self.items) >= pool_size_thresh:\n                return torch.sum(self.items) / float(len(self.items))\n            else:\n                return torch.from_numpy(np.nan)\n\n    def sample(self, with_replacement=True):\n        idx = np.random.randint(len(self.items))\n        if with_replacement:\n            return self.items[idx]\n        else:\n            return self.items.pop(idx)\n\n    def fetch(self, num=None):\n        if self.version == 'pt':\n            item_array = torch.stack(self.items)\n        elif self.version == 'np':\n            item_array = np.stack(self.items)\n        if num is not None:\n            # there better be some items\n            assert (len(self.items) >= num)\n\n            # if there are not that many elements just return however many there are\n            if len(self.items) < num:\n                return item_array\n            else:\n                idxs = np.random.randint(len(self.items), size=num)\n                return item_array[idxs]\n        else:\n            return item_array\n\n    def is_full(self):\n        full = len(self.items) == self.pool_size\n        return full\n\n    def empty(self):\n        self.items = []\n\n    def update(self, items):\n        for item in items:\n            if len(self.items) < self.pool_size:\n                # the pool is not full, so let's add this in\n                self.items.append(item)\n            else:\n                # the pool is full\n 
               # pop from the front\n                self.items.pop(0)\n                # add to the back\n                self.items.append(item)\n        return self.items\n\n\ndef farthest_point_sample(xyz, npoint, include_ends=False, deterministic=False):\n    \"\"\"\n    Input:\n        xyz: pointcloud data, [B, N, C], where C is probably 3\n        npoint: number of samples\n    Return:\n        inds: sampled pointcloud index, [B, npoint]\n    \"\"\"\n    device = xyz.device\n    B, N, C = xyz.shape\n    xyz = xyz.float()\n    inds = torch.zeros(B, npoint, dtype=torch.long).to(device)\n    distance = torch.ones(B, N).to(device) * 1e10\n    if deterministic:\n        farthest = torch.randint(0, 1, (B,), dtype=torch.long).to(device)\n    else:\n        farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)\n    batch_indices = torch.arange(B, dtype=torch.long).to(device)\n    for i in range(npoint):\n        if include_ends:\n            if i == 0:\n                farthest = 0\n            elif i == 1:\n                farthest = N - 1\n        inds[:, i] = farthest\n        centroid = xyz[batch_indices, farthest, :].view(B, 1, C)\n        dist = torch.sum((xyz - centroid) ** 2, -1)\n        mask = dist < distance\n        distance[mask] = dist[mask]\n        farthest = torch.max(distance, -1)[1]\n\n        if npoint > N:\n            # if we need more samples, make them random\n            distance += torch.randn_like(distance)\n    return inds\n\n\ndef farthest_point_sample_py(xyz, npoint):\n    N, C = xyz.shape\n    inds = np.zeros(npoint, dtype=np.int32)\n    distance = np.ones(N) * 1e10\n    farthest = np.random.randint(0, N, dtype=np.int32)\n    for i in range(npoint):\n        inds[i] = farthest\n        centroid = xyz[farthest, :].reshape(1, C)\n        dist = np.sum((xyz - centroid) ** 2, -1)\n        mask = dist < distance\n        distance[mask] = dist[mask]\n        farthest = np.argmax(distance, -1)\n        if npoint > N:\n            # if we need more samples, make them random\n            distance += np.random.randn(*distance.shape)\n    return inds\n"
  },
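  {
    "path": "examples/farthest_point_sample_example.py",
    "content": "# Hypothetical usage sketch (not part of the original codebase): demonstrates\n# farthest_point_sample from mvtracker/utils/misc.py on a random point cloud,\n# e.g. to pick well-spread 3D query points before tracking.\nimport torch\n\nfrom mvtracker.utils.misc import farthest_point_sample\n\nxyz = torch.rand(1, 4096, 3)  # B x N x 3 point cloud\ninds = farthest_point_sample(xyz, npoint=256, deterministic=True)  # B x 256 indices\nquery_xyz = torch.gather(xyz, 1, inds.unsqueeze(-1).expand(-1, -1, 3))  # B x 256 x 3\nprint(inds.shape, query_xyz.shape)\n"
  },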
  {
    "path": "mvtracker/utils/visualizer_mp4.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\nimport logging\nimport os\nimport threading\nfrom typing import Tuple\n\nimport cv2\nimport flow_vis\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torchvision.transforms as transforms\nfrom matplotlib import cm\nfrom moviepy.editor import ImageSequenceClip\n\nfrom mvtracker.models.core.model_utils import world_space_to_pixel_xy_and_camera_z\n\n\ndef read_video_from_path(path):\n    cap = cv2.VideoCapture(path)\n    if not cap.isOpened():\n        raise ValueError(f\"Unable to open video file: {path}\")\n\n    frames = []\n    while cap.isOpened():\n        ret, frame = cap.read()\n        if ret == True:\n            frames.append(np.array(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))\n        else:\n            break\n    cap.release()\n\n    return np.stack(frames)\n\n\nclass Visualizer:\n    def __init__(\n            self,\n            save_dir: str = \"./results\",\n            grayscale: bool = False,\n            pad_value: int = 0,\n            fps: int = 10,\n            mode: str = \"rainbow\",  # 'cool', 'optical_flow'\n            linewidth: int = 2,\n            show_first_frame: int = 10,\n            tracks_leave_trace: int = 0,  # -1 for infinite\n            tracks_use_alpha: bool = False,\n            print_debug_info: bool = False,\n    ):\n        self.mode = mode\n        self.save_dir = save_dir\n        if mode == \"rainbow\":\n            self.color_map = cm.get_cmap(\"gist_rainbow\")\n        elif mode == \"cool\":\n            self.color_map = cm.get_cmap(mode)\n        self.show_first_frame = show_first_frame\n        self.grayscale = grayscale\n        self.tracks_leave_trace = tracks_leave_trace\n        self.tracks_use_alpha = tracks_use_alpha\n        self.print_debug_info = print_debug_info\n        self.pad_value = pad_value\n        self.linewidth = linewidth\n        self.fps = fps\n\n    def visualize(\n            self,\n            video: torch.Tensor,  # (B,T,C,H,W)\n            tracks: torch.Tensor,  # (B,T,N,2)\n            visibility: torch.Tensor = None,  # (B, T, N) bool\n            gt_tracks: torch.Tensor = None,  # (B,T,N,2)\n            segm_mask: torch.Tensor = None,  # (B,1,H,W)\n            filename: str = \"video\",\n            writer=None,  # tensorboard Summary Writer, used for visualization during training\n            step: int = 0,\n            query_frame: torch.Tensor = None,  # (B,N)\n            save_video: bool = True,\n            compensate_for_camera_motion: bool = False,\n            rigid_part=None,\n            video_depth=None,  # (B,T,C,H,W)\n            vector_colors=None,\n    ):\n        batch_size, num_frames, _, height, width = video.shape\n        num_points = tracks.shape[-2]\n        num_dims = tracks.shape[-1]\n\n        assert video.shape == (batch_size, num_frames, 3, height, width)\n        assert tracks.shape == (batch_size, num_frames, num_points, num_dims)\n        if visibility is not None:\n            assert visibility.shape == (batch_size, num_frames, num_points)\n        if gt_tracks is not None:\n            assert gt_tracks.shape == (batch_size, num_frames, num_points, num_dims)\n        if query_frame is not None:\n            assert query_frame.shape == (batch_size, num_points)\n\n        if compensate_for_camera_motion:\n   
         assert segm_mask is not None\n\n        if segm_mask is not None:\n            assert (query_frame == 0).all().item()\n            coords = tracks[0, 0].round().long()\n            segm_mask = segm_mask[0, 0][coords[:, 1], coords[:, 0]].long()\n\n        video = F.pad(\n            video,\n            (self.pad_value, self.pad_value, self.pad_value, self.pad_value),\n            \"constant\",\n            255,\n        )\n        if video_depth is not None:\n            video_depth = video_depth.squeeze(2)\n            video_depth = video_depth.cpu().numpy()\n            highest_depth_value = max(video_depth.max(), 100)\n            video_depth = plt.cm.Spectral(video_depth / highest_depth_value) * 255\n            video_depth = video_depth[..., :3]\n            video_depth = video_depth.astype(np.uint8)\n            video_depth = torch.from_numpy(video_depth)\n            video_depth = video_depth.permute(0, 1, 4, 2, 3)\n            video_depth = F.pad(\n                video_depth,\n                (self.pad_value, self.pad_value, self.pad_value, self.pad_value),\n                \"constant\",\n                255,\n            )\n\n        tracks = tracks + self.pad_value\n\n        if self.grayscale:\n            transform = transforms.Grayscale()\n            video = transform(video)\n            video = video.repeat(1, 1, 3, 1, 1)\n\n        res_video, vector_colors = self.draw_tracks_on_video(\n            video=video,\n            tracks=tracks[..., :2],\n            visibility=visibility,\n            segm_mask=segm_mask,\n            gt_tracks=gt_tracks,\n            query_frame=query_frame,\n            compensate_for_camera_motion=compensate_for_camera_motion,\n            rigid_part=rigid_part,\n            vector_colors=vector_colors,\n        )\n        if video_depth is not None:\n            res_video_depth, _ = self.draw_tracks_on_video(\n                video=video_depth,\n                tracks=tracks[..., :2],\n                visibility=visibility,\n                segm_mask=segm_mask,\n                gt_tracks=gt_tracks,\n                query_frame=query_frame,\n                compensate_for_camera_motion=compensate_for_camera_motion,\n                vector_colors=vector_colors,\n            )\n            res_video = torch.cat([res_video, res_video_depth], dim=4)  # B, T, 3, H, [W]\n\n        if save_video:\n            # self.save_video(res_video, filename=filename, writer=writer, step=step)\n            thread = threading.Thread(\n                target=Visualizer.save_video,\n                args=(res_video, self.save_dir, filename, writer, self.fps, step)\n            )\n            thread.start()\n        return res_video, vector_colors\n\n    @staticmethod\n    def save_video(video, save_dir, filename, writer=None, fps=12, step=0):\n        if writer is not None:\n            writer.add_video(f\"{filename}\", video.to(torch.uint8), global_step=step, fps=fps)\n            writer.flush()\n            logging.info(f\"Video {filename} saved to tensorboard\")\n\n        if save_dir is not None:\n            os.makedirs(save_dir, exist_ok=True)\n            wide_list = list(video.unbind(1))\n            wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]\n            clip = ImageSequenceClip(wide_list, fps=fps)\n\n            # Write the video file\n            save_path = os.path.join(save_dir, f\"{filename}_step_{step}.mp4\")\n            clip.write_videofile(save_path, codec=\"libx264\", fps=fps, logger=None)\n\n            
logging.info(f\"Video saved to {save_path}\")\n\n    def draw_tracks_on_video(\n            self,\n            video: torch.Tensor,\n            tracks: torch.Tensor,\n            visibility: torch.Tensor = None,\n            segm_mask: torch.Tensor = None,\n            gt_tracks=None,\n            query_frame: torch.Tensor = None,\n            compensate_for_camera_motion=False,\n            vector_colors=None,\n            rigid_part=None,\n    ):\n        B, T, C, H, W = video.shape\n        _, _, N, D = tracks.shape\n\n        assert D == 2\n        assert C == 3\n        video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy()  # S, H, W, C\n        tracks = tracks[0].long().detach().cpu().numpy()  # S, N, 2\n        if query_frame is not None:\n            query_frame = query_frame[0].long().detach().cpu().numpy()  # N\n        if gt_tracks is not None:\n            gt_tracks = gt_tracks[0].detach().cpu().numpy()\n\n        res_video = []\n\n        # process input video\n        for rgb in video:\n            res_video.append(rgb.copy())\n\n        if vector_colors is None:\n            vector_colors = np.zeros((T, N, 3))\n            if self.mode == \"optical_flow\":\n                vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame, torch.arange(N)][None])\n            elif segm_mask is None:\n                if self.mode == \"rainbow\":\n                    # y_min, y_max = (\n                    #     tracks[query_frame, :, 1].min(),\n                    #     tracks[query_frame, :, 1].max(),\n                    # )\n                    y_min, y_max = 0, H\n                    norm = plt.Normalize(y_min, y_max)\n                    for n in range(N):\n                        color = self.color_map(norm(tracks[query_frame[n], n, 1]))\n                        color = np.array(color[:3])[None] * 255\n                        vector_colors[:, n] = np.repeat(color, T, axis=0)\n                else:\n                    # color changes with time\n                    for t in range(T):\n                        color = np.array(self.color_map(t / T)[:3])[None] * 255\n                        vector_colors[t] = np.repeat(color, N, axis=0)\n            else:\n                if self.mode == \"rainbow\":\n                    vector_colors[:, segm_mask <= 0, :] = 255\n\n                    # y_min, y_max = (\n                    #     tracks[0, segm_mask > 0, 1].min(),\n                    #     tracks[0, segm_mask > 0, 1].max(),\n                    # )\n                    y_min, y_max = 0, H\n                    norm = plt.Normalize(y_min, y_max)\n                    for n in range(N):\n                        if segm_mask[n] > 0:\n                            color = self.color_map(norm(tracks[0, n, 1]))\n                            color = np.array(color[:3])[None] * 255\n                            vector_colors[:, n] = np.repeat(color, T, axis=0)\n\n                else:\n                    # color changes with segm class\n                    segm_mask = segm_mask.cpu()\n                    color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)\n                    color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0\n                    color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0\n                    vector_colors = np.repeat(color[None], T, axis=0)\n\n        #  draw tracks\n        if self.tracks_leave_trace != 0:\n            for t in range(1, T):\n                first_ind = (\n                    max(0, t - 
self.tracks_leave_trace)\n                    if self.tracks_leave_trace >= 0\n                    else 0\n                )\n                curr_tracks = tracks[first_ind: t + 1]\n                curr_colors = vector_colors[first_ind: t + 1]\n                if compensate_for_camera_motion:\n                    diff = (\n                                   tracks[first_ind: t + 1, segm_mask <= 0]\n                                   - tracks[t: t + 1, segm_mask <= 0]\n                           ).mean(1)[:, None]\n\n                    curr_tracks = curr_tracks - diff\n                    curr_tracks = curr_tracks[:, segm_mask > 0]\n                    curr_colors = curr_colors[:, segm_mask > 0]\n\n                res_video[t] = self._draw_pred_tracks(\n                    res_video[t],\n                    curr_tracks,\n                    curr_colors,\n                    query_frame - first_ind,\n                    use_alpha=self.tracks_use_alpha,\n                )\n                if gt_tracks is not None:\n                    res_video[t] = self._draw_gt_tracks(\n                        res_video[t], gt_tracks[first_ind: t + 1]\n                    )\n\n        # Add frame number\n        if self.print_debug_info:\n            for t in range(T):\n                min_x = tracks[t].min(0)[0]\n                min_y = tracks[t].min(0)[1]\n                min_xy = f\"{min_x:6.1f}, {min_y:6.1f}\"\n\n                median_x = np.median(tracks[t], axis=0)[0]\n                median_y = np.median(tracks[t], axis=0)[1]\n                median_xy = f\"{median_x:6.1f}, {median_y:6.1f}\"\n\n                max_x = tracks[t].max(0)[0]\n                max_y = tracks[t].max(0)[1]\n                max_xy = f\"{max_x:6.1f}, {max_y:6.1f}\"\n\n                text = (\n                    f\"Frame {t}\"\n                    f\"\\nH,W={H},{W}\"\n                    f\"\\nT,N={T},{N}\"\n                    f\"\\nmin_xy    = {min_xy} \"\n                    f\"\\nmedian_xy = {median_xy} \"\n                    f\"\\nmax_xy    = {max_xy} \"\n                )\n                res_video[t] = put_debug_text_onto_image(res_video[t], text)\n\n        if rigid_part is not None:\n            cls_label = torch.unique(rigid_part)\n            cls_num = len(torch.unique(rigid_part))\n            # visualize the clustering results \n            cmap = plt.get_cmap('jet')  # get the color mapping\n            colors = cmap(np.linspace(0, 1, cls_num))\n            colors = (colors[:, :3] * 255)\n            color_map = {label.item(): color for label, color in zip(cls_label, colors)}\n\n        #  draw points\n        for t in range(T):\n            for i in range(N):\n                if query_frame is not None and query_frame[i] > t:\n                    continue\n\n                coord = (tracks[t, i, 0], tracks[t, i, 1])\n                visibile = True\n                if visibility is not None:\n                    visibile = visibility[0, t, i]\n\n                # Check for NaN or Inf in coordinates\n                if np.isnan(coord).any() or np.isinf(coord).any():\n                    logging.info(f\"Warning: Skipping track {i} at t={t} due to NaN or Inf coord={coord}.\")\n                    continue  # Skip plotting this point\n\n                if coord[0] != 0 and coord[1] != 0:\n                    if not compensate_for_camera_motion or (\n                            compensate_for_camera_motion and segm_mask[i] > 0\n                    ):\n                        if rigid_part is not None:\n           
                 color = color_map[rigid_part.squeeze()[i].item()]\n                            cv2.circle(\n                                res_video[t],\n                                coord,\n                                int(self.linewidth * 2),\n                                color.tolist(),\n                                thickness=-1 if visibile else 2 - 1,\n                            )\n                        else:\n                            cv2.circle(\n                                res_video[t],\n                                coord,\n                                int(self.linewidth * 2),\n                                vector_colors[t, i].tolist(),\n                                thickness=-1 if visibile else 2 - 1,\n                            )\n\n        #  construct the final rgb sequence\n        if self.show_first_frame > 0:\n            res_video = [res_video[0]] * self.show_first_frame + res_video[1:]\n        return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte(), vector_colors\n\n    def _draw_pred_tracks(\n            self,\n            rgb: np.ndarray,  # H x W x 3\n            tracks: np.ndarray,  # shape: [T, N, 2]\n            vector_colors: np.ndarray,  # shape: [T, N, 3]\n            query_frame: np.ndarray,  # shape: [N], each entry = birth frame for track i\n            use_alpha: bool = False,\n    ) -> np.ndarray:\n        \"\"\"\n        Draws trajectory lines from frame s to s+1, but only if s >= query_frame[i].\n        That is, no lines are drawn before the track 'appears' at query_frame[i].\n        \"\"\"\n        T, N, _ = tracks.shape\n\n        for s in range(T - 1):\n            # We'll blend older lines more lightly (alpha) if desired:\n            original_rgb = rgb.copy()\n            if use_alpha:\n                alpha = (s / T) ** 2  # or pick some function of s, T\n            else:\n                alpha = 1\n\n            for i in range(N):\n                # If the query/birth frame for track i is after s, skip drawing\n                if query_frame is not None and s < query_frame[i]:\n                    continue\n\n                pt_s = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))\n                pt_sp1 = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))\n\n                # Skip if the points are 0 or invalid\n                if pt_s[0] == 0 and pt_s[1] == 0:\n                    continue\n                if pt_sp1[0] == 0 and pt_sp1[1] == 0:\n                    continue\n\n                color = vector_colors[s, i].tolist()\n                cv2.line(rgb, pt_s, pt_sp1, color, self.linewidth, cv2.LINE_AA)\n\n            # Optionally alpha-blend older lines if you want them to fade out:\n            rgb = cv2.addWeighted(rgb, alpha, original_rgb, 1 - alpha, 0)\n\n        return rgb\n\n    def _draw_gt_tracks(\n            self,\n            rgb: np.ndarray,  # H x W x 3,\n            gt_tracks: np.ndarray,  # T x 2\n    ):\n        T, N, _ = gt_tracks.shape\n        color = np.array((211.0, 0.0, 0.0))\n\n        for t in range(T):\n            for i in range(N):\n                gt_tracks = gt_tracks[t][i]\n                #  draw a red cross\n                if gt_tracks[0] > 0 and gt_tracks[1] > 0:\n                    length = self.linewidth * 3\n                    coord_y = (int(gt_tracks[0]) + length, int(gt_tracks[1]) + length)\n                    coord_x = (int(gt_tracks[0]) - length, int(gt_tracks[1]) - length)\n                    cv2.line(\n                        rgb,\n          
              coord_y,\n                        coord_x,\n                        color,\n                        self.linewidth,\n                        cv2.LINE_AA,\n                    )\n                    coord_y = (int(gt_tracks[0]) - length, int(gt_tracks[1]) + length)\n                    coord_x = (int(gt_tracks[0]) + length, int(gt_tracks[1]) - length)\n                    cv2.line(\n                        rgb,\n                        coord_y,\n                        coord_x,\n                        color,\n                        self.linewidth,\n                        cv2.LINE_AA,\n                    )\n        return rgb\n\n\ndef put_debug_text_onto_image(img: np.ndarray, text: str, font_scale: float = 0.5, left: int = 5, top: int = 20,\n                              font_thickness: int = 1, text_color_bg: Tuple[int, int, int] = (0, 0, 0)) -> np.ndarray:\n    \"\"\"\n    Overlay debug text on the provided image.\n\n    Parameters\n    ----------\n    img : np.ndarray\n        A 3D numpy array representing the input image. The image is expected to have three color channels.\n    text : str\n        The debug text to overlay on the image. The text can include newline characters ('\\n') to create multi-line text.\n    font_scale : float, default 0.5\n        The scale factor that is multiplied by the font-specific base size.\n    left : int, default 5\n        The left-most coordinate where the text is to be put.\n    top : int, default 20\n        The top-most coordinate where the text is to be put.\n    font_thickness : int, default 1\n        Thickness of the lines used to draw the text.\n    text_color_bg : Tuple[int, int, int], default (0, 0, 0)\n        The color of the text background in BGR format.\n\n    Returns\n    -------\n    img : np.ndarray\n        A 3D numpy array representing the image with the debug text overlaid.\n    \"\"\"\n    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)\n    font_color = (255, 255, 255)\n\n    # Write each line of text in a new row\n    (_, label_height), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, font_thickness)\n    if text_color_bg is not None:\n        for i, line in enumerate(text.split('\\n')):\n            (line_width, _), _ = cv2.getTextSize(line, cv2.FONT_HERSHEY_SIMPLEX, font_scale, font_thickness)\n            top_i = top + i * label_height\n            cv2.rectangle(img, (left, top_i - label_height), (left + line_width, top_i), text_color_bg, -1)\n    for i, line in enumerate(text.split('\\n')):\n        top_i = top + i * label_height\n        cv2.putText(img, line, (left, top_i), cv2.FONT_HERSHEY_SIMPLEX, font_scale, font_color, font_thickness)\n\n    img = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_BGR2RGB)\n    return img\n\n\nclass MultiViewVisualizer(Visualizer):\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n\n    def visualize(\n            self,\n            video: torch.Tensor,  # (B,V,T,C,H,W)\n            tracks: torch.Tensor,  # (B,V,T,N,2)\n            visibility: torch.Tensor = None,  # (B,V,T,N) bool\n            gt_tracks: torch.Tensor = None,  # (B,V,T,N,2)\n            segm_mask: torch.Tensor = None,  # (B,V,1,H,W)\n            filename: str = \"video\",\n            writer=None,  # tensorboard Summary Writer, used for visualization during training\n            step: int = 0,\n            query_frame: torch.Tensor = None,  # (B,N)\n            save_video: bool = True,\n            compensate_for_camera_motion: bool = False,\n            rigid_part=None,\n   
         video_depth=None,  # (B,V,T,C,H,W)\n            vector_colors=None,\n    ):\n        # Replace NaN and Inf values with 0\n        tracks = tracks.detach().clone().clip(-1e4, 1e4)\n        tracks[torch.isnan(tracks)] = 0\n        gt_tracks = gt_tracks.detach().clone().clip(-1e4, 1e4) if gt_tracks is not None else None\n\n        batch_size, num_views, num_frames, _, height, width = video.shape\n        num_points = tracks.shape[-2]\n        num_dims = tracks.shape[-1]\n\n        # Repeat visibility for each view if only global visibility is provided\n        if visibility is not None and visibility.dim() == 3:\n            visibility = visibility[:, None, :, :].repeat(1, num_views, 1, 1)\n\n        # Assert shapes of per-view data\n        assert video.shape == (batch_size, num_views, num_frames, 3, height, width)\n        assert tracks.shape == (batch_size, num_views, num_frames, num_points, num_dims)\n        assert num_dims in [2, 3]\n        if gt_tracks is not None:\n            assert gt_tracks.shape == (batch_size, num_views, num_frames, num_points, num_dims)\n        if visibility is not None:\n            assert visibility.shape == (batch_size, num_views, num_frames, num_points)\n        if segm_mask is not None:\n            assert segm_mask.shape == (batch_size, num_views, 1, height, width)\n        if video_depth is not None:\n            assert video_depth.shape == (batch_size, num_views, num_frames, 1, height, width)\n\n        res_video_list = []\n        for view_idx in range(num_views):\n            res_video, vector_colors = super(MultiViewVisualizer, self).visualize(\n                # Extract view-specific data\n                video=video[:, view_idx],\n                tracks=tracks[:, view_idx],\n                visibility=visibility[:, view_idx],\n                gt_tracks=gt_tracks[:, view_idx] if gt_tracks is not None else None,\n                segm_mask=segm_mask[:, view_idx] if segm_mask is not None else None,\n                video_depth=video_depth[:, view_idx] if video_depth is not None else None,\n\n                # Pass-through arguments\n                step=step,\n                query_frame=query_frame,\n                compensate_for_camera_motion=compensate_for_camera_motion,\n                rigid_part=rigid_part,\n                vector_colors=vector_colors,\n\n                # Disable saving video for individual views as we will save the merged videos\n                filename=None,\n                writer=None,\n                save_video=False\n            )\n            res_video_list.append(res_video)\n        res_video = torch.cat(res_video_list, dim=3)\n        if save_video:\n            # Visualizer.save_video(res_video, self.save_dir, filename, writer, self.fps, step)\n            thread = threading.Thread(\n                target=Visualizer.save_video,\n                args=(res_video, self.save_dir, filename, writer, self.fps, step)\n            )\n            thread.start()\n        return res_video, vector_colors\n\n\ndef log_mp4_track_viz(\n        log_dir,\n        dataset_name,\n        datapoint_idx,\n        rgbs,\n        intrs,\n        extrs,\n        gt_trajectories,\n        gt_visibilities,\n        pred_trajectories,\n        pred_visibilities,\n        query_points_3d,\n        step=0,\n        prefix=\"comparison__\",\n        max_tracks_to_visualize=36,\n        max_individual_tracks_to_visualize=6,\n):\n    batch_size, num_frames, num_points, _ = gt_trajectories.shape\n    num_views = rgbs.shape[1]\n\n    
intrs_inv = torch.inverse(intrs.float()).type(intrs.dtype)\n    extrs_square = torch.eye(4).to(extrs.device)[None].repeat(batch_size, num_views, num_frames, 1, 1)\n    extrs_square[:, :, :, :3, :] = extrs\n    extrs_inv = torch.inverse(extrs_square.float()).type(extrs.dtype)\n    assert intrs_inv.shape == (batch_size, num_views, num_frames, 3, 3)\n    assert extrs_inv.shape == (batch_size, num_views, num_frames, 4, 4)\n\n    gt_pix_xy_cam_z = torch.stack([\n        torch.cat(world_space_to_pixel_xy_and_camera_z(\n            world_xyz=gt_trajectories[0],\n            intrs=intrs[0, view_idx],\n            extrs=extrs[0, view_idx],\n        ), dim=-1)\n        for view_idx in range(num_views)\n    ], dim=0)[None]\n\n    pred_pix_xy_cam_z = torch.stack([\n        torch.cat(world_space_to_pixel_xy_and_camera_z(\n            world_xyz=pred_trajectories[0],\n            intrs=intrs[0, view_idx],\n            extrs=extrs[0, view_idx],\n        ), dim=-1)\n        for view_idx in range(num_views)\n    ], dim=0)[None]\n\n    visualizer = MultiViewVisualizer(\n        save_dir=log_dir,\n        pad_value=0,\n        fps=30 if \"panoptic\" in dataset_name else 12,\n        show_first_frame=0,\n        tracks_leave_trace=-1,\n    )\n    seq_name = f\"seq-{datapoint_idx}\"\n\n    # Plot all tracks at the same time\n    gt_viz, vector_colors = visualizer.visualize(\n        video=rgbs.cpu(),\n        video_depth=None,\n        tracks=gt_pix_xy_cam_z[:, :, :, :max_tracks_to_visualize].cpu(),\n        visibility=gt_visibilities.clone()[:, :, :max_tracks_to_visualize].cpu(),\n        query_frame=query_points_3d[..., 0].long().clone()[:, :max_tracks_to_visualize].cpu(),\n        filename=f\"eval_{dataset_name}_gt_traj_{seq_name}_any_visib\",\n        save_video=False,\n    )\n    pred_viz, _ = visualizer.visualize(\n        video=rgbs.cpu(),\n        video_depth=None,\n        tracks=pred_pix_xy_cam_z[:, :, :, :max_tracks_to_visualize].cpu(),\n        visibility=pred_visibilities[:, :, :max_tracks_to_visualize].cpu(),\n        query_frame=query_points_3d[..., 0].long().clone()[:, :max_tracks_to_visualize].cpu(),\n        filename=f\"eval_{dataset_name}_pred_traj_{seq_name}\",\n        save_video=False,\n        vector_colors=vector_colors,\n    )\n    viz = torch.cat([gt_viz, pred_viz], dim=-1)\n    thread = threading.Thread(\n        target=Visualizer.save_video,\n        args=(viz, visualizer.save_dir, f\"{prefix}{seq_name}\", None, visualizer.fps, step)\n    )\n    thread.start()\n    thread.join()\n\n    # Plot individual tracks\n    for track_idx in range(min(num_points, max_individual_tracks_to_visualize)):\n        seq_name_i = f\"seq-{datapoint_idx}-point-{track_idx:02d}\"\n        gt_viz, vector_colors_i = visualizer.visualize(\n            video=rgbs.cpu(),\n            video_depth=None,\n            tracks=gt_pix_xy_cam_z[:, :, :, track_idx:track_idx + 1].cpu(),\n            visibility=gt_visibilities.clone()[:, :, track_idx:track_idx + 1].cpu(),\n            query_frame=query_points_3d[..., 0].long().clone()[:, track_idx:track_idx + 1].cpu(),\n            filename=f\"eval_{dataset_name}_gt_traj_{seq_name_i}_any_visib\",\n            step=step,\n            save_video=False,\n        )\n        pred_viz, _ = visualizer.visualize(\n            video=rgbs.cpu(),\n            video_depth=None,\n            tracks=pred_pix_xy_cam_z[:, :, :, track_idx:track_idx + 1].cpu(),\n            visibility=pred_visibilities[:, :, track_idx:track_idx + 1].cpu(),\n            query_frame=query_points_3d[..., 
0].long().clone()[:, track_idx:track_idx + 1].cpu(),\n            filename=f\"eval_{dataset_name}_pred_traj_{seq_name_i}\",\n            save_video=False,\n            vector_colors=vector_colors_i,\n        )\n        viz = torch.cat([gt_viz, pred_viz], dim=-1)\n        thread = threading.Thread(\n            target=Visualizer.save_video,\n            args=(viz, visualizer.save_dir, f\"{prefix}{seq_name_i}\", None, visualizer.fps, step)\n        )\n        thread.start()\n        thread.join()\n"
  },
  {
    "path": "mvtracker/utils/visualizer_rerun.py",
    "content": "from typing import Union, Optional, List, Dict, Any\n\nimport matplotlib\nimport numpy as np\nimport pandas as pd\nimport rerun as rr\nimport seaborn as sns\nimport torch\nfrom matplotlib import pyplot as plt, colors as mcolors, cm as cm\nfrom sklearn.decomposition import PCA\n\n\ndef setup_libs(latex=False):\n    pd.set_option('display.max_rows', 500)\n    pd.set_option('display.max_columns', 500)\n    pd.set_option('display.width', 1000)\n\n    sns.set_theme(style=\"white\", rc={\"axes.facecolor\": (0, 0, 0, 0)})\n    sns.set_style(\"ticks\")\n    sns.set_palette(\"flare\")\n\n    if latex:\n        plt.rc('font', **{'family': 'serif', 'serif': ['Computer Modern Roman']})\n        plt.rc('text', usetex=True)\n    plt.rcParams.update({\n        'figure.titlesize': '28',\n        'axes.titlesize': '22',\n        'axes.titlepad': '10',\n        'legend.title_fontsize': '16',\n        'legend.fontsize': '14',\n        'axes.labelsize': '18',\n        'xtick.labelsize': '16',\n        'ytick.labelsize': '16',\n        'figure.dpi': 200,\n    })\n\n\ndef log_pointclouds_to_rerun(\n        dataset_name: str,\n        datapoint_idx: Union[int, str],\n        rgbs: torch.Tensor,\n        depths: torch.Tensor,\n        intrs: torch.Tensor,\n        extrs: torch.Tensor,\n        depths_conf: Optional[torch.Tensor] = None,\n        conf_thrs: Optional[List[float]] = None,\n        log_only_confident_pc: bool = False,\n        radii: float = -2.45,\n        fps: float = 30.0,\n        bbox_crop: Optional[torch.Tensor] = None,  # e.g., np.array([[-4, 4], [-3, 3.7], [1.2, 5.2]])\n        sphere_radius_crop: Optional[float] = None,  # e.g., 6.0\n        sphere_center_crop: Optional[np.ndarray] = np.array([0, 0, 0]),\n        log_rgb_image: bool = False,\n        log_depthmap_as_image_v1: bool = False,\n        log_depthmap_as_image_v2: bool = False,\n        log_camera_frustrum: bool = True,\n        log_rgb_pointcloud: bool = True,\n        timesteps_to_log: Optional[List[int]] = None,\n):\n    # Set the up-axis for the world\n    # Log coordinate axes for reference\n    rr.set_time_seconds(\"frame\", 0)\n    B, V, T, _, H, W = rgbs.shape\n    assert rgbs.shape == (B, V, T, 3, H, W)\n    assert depths.shape == (B, V, T, 1, H, W)\n    assert depths_conf is None or depths_conf.shape == (B, V, T, 1, H, W)\n    assert intrs.shape == (B, V, T, 3, 3)\n    assert extrs.shape == (B, V, T, 3, 4)\n    assert B == 1\n    # Compute inverse intrinsics and extrinsics\n    intrs_inv = torch.inverse(intrs.float()).type(intrs.dtype)\n    extrs_square = torch.eye(4).to(extrs.device)[None].repeat(B, V, T, 1, 1)\n    extrs_square[:, :, :, :3, :] = extrs\n    extrs_inv = torch.inverse(extrs_square.float()).type(extrs.dtype)\n    assert intrs_inv.shape == (B, V, T, 3, 3)\n    assert extrs_inv.shape == (B, V, T, 4, 4)\n    for v in range(V):  # Iterate over views\n        for t in range(T):  # Iterate over frames\n\n            if timesteps_to_log is not None and t not in timesteps_to_log:\n                continue\n\n            rr.set_time_seconds(\"frame\", t / fps)\n\n            # Log RGB image\n            rgb_image = rgbs[0, v, t].permute(1, 2, 0).cpu().numpy()\n            if log_rgb_image:\n                rr.log(f\"sequence-{datapoint_idx}/{dataset_name}/image/view-{v}/rgb\", rr.Image(rgb_image))\n\n            # Log Depth map\n            depth_map = depths[0, v, t, 0].cpu().numpy()\n            if log_depthmap_as_image_v1:\n                
rr.log(f\"sequence-{datapoint_idx}/{dataset_name}/image/view-{v}/depth\",\n                       rr.DepthImage(depth_map, point_fill_ratio=0.2))\n\n            # Log Depth map as RGB\n            d_min, d_max = depth_map.min(), depth_map.max()\n            norm = mcolors.Normalize(vmin=d_min, vmax=d_max)\n            turbo_cmap = cm.get_cmap(\"turbo\")  # \"viridis\", \"plasma\", etc.\n            depth_color_rgba = turbo_cmap(norm(depth_map))\n            depth_color_rgb = (depth_color_rgba[..., :3] * 255).astype(np.uint8)\n            if log_depthmap_as_image_v2:\n                rr.log(f\"sequence-{datapoint_idx}/{dataset_name}/image/view-{v}/deptha-as-rgb\",\n                       rr.Image(depth_color_rgb))\n\n            # Log Camera\n            K = intrs[0, v, t].cpu().numpy()\n            world_T_cam = np.eye(4)\n            world_T_cam[:3, :3] = extrs_inv[0, v, t, :3, :3].cpu().numpy()\n            world_T_cam[:3, 3] = extrs_inv[0, v, t, :3, 3].cpu().numpy()\n            if log_camera_frustrum:\n                rr.log(f\"sequence-{datapoint_idx}/{dataset_name}/image/view-{v}\",\n                       rr.Pinhole(image_from_camera=K, width=W, height=H))\n                rr.log(f\"sequence-{datapoint_idx}/{dataset_name}/image/view-{v}\",\n                       rr.Transform3D(translation=world_T_cam[:3, 3], mat3x3=world_T_cam[:3, :3]))\n\n            # Generate and log point cloud colored by RGB values\n            # Compute 3D points from depth map\n            y, x = np.indices((H, W))\n            homo_pixel_coords = np.stack([x.ravel(), y.ravel(), np.ones_like(x).ravel()], axis=1).T\n            depth_values = depth_map.ravel()\n            cam_coords = (intrs_inv[0, v, t].cpu().numpy() @ homo_pixel_coords) * depth_values\n            cam_coords = np.vstack((cam_coords, np.ones((1, cam_coords.shape[1]))))\n            world_coords = (world_T_cam @ cam_coords)[:3].T\n            rgb_colors = rgb_image.reshape(-1, 3).astype(np.uint8)\n\n            # Log point clouds\n            if log_rgb_pointcloud:\n                # Filter out points with zero depth\n                valid_mask = depth_values > 0\n\n                # Filter out points outside this bbox\n                # bbox_crop = np.array([[-4, 4], [-3, 3.7], [1.2, 5.2]])\n                if bbox_crop is not None:\n                    bbox_mask = (\n                            (world_coords[..., 0] > bbox_crop[0, 0])\n                            & (world_coords[..., 0] < bbox_crop[0, 1])\n                            & (world_coords[..., 1] > bbox_crop[1, 0])\n                            & (world_coords[..., 1] < bbox_crop[1, 1])\n                            & (world_coords[..., 2] > bbox_crop[2, 0])\n                            & (world_coords[..., 2] < bbox_crop[2, 1])\n                    )\n                    valid_mask = valid_mask & bbox_mask\n\n                # Lightweight Kubric and DexYCB\n                if sphere_radius_crop is not None:\n                    assert sphere_center_crop is not None\n                    sphere_mask = ((world_coords - sphere_center_crop) ** 2).sum(-1) < sphere_radius_crop ** 2\n                    valid_mask = valid_mask & sphere_mask\n\n                # Filter out points with confidence below threshold\n                pc_name__mask__tuples = []\n                if not (log_only_confident_pc and depths_conf is not None):\n                    pc_name__mask__tuples += [(\"point_cloud\", valid_mask)]\n                if depths_conf is not None:\n                    confs = 
depths_conf[0, v, t, 0].cpu().numpy()\n                    assert conf_thrs is not None\n                    for thr in conf_thrs:\n                        name = f\"point_cloud__conf-{thr}\"\n                        mask = valid_mask & (confs.ravel() > thr)\n                        if (valid_mask == mask).all():\n                            continue\n                        pc_name__mask__tuples += [(name, mask)]\n                for pc_name, mask in pc_name__mask__tuples:\n                    rr.log(f\"sequence-{datapoint_idx}/{dataset_name}/{pc_name}/view-{v}\",\n                           rr.Points3D(world_coords[mask], colors=rgb_colors[mask], radii=radii))\n\n\ndef _log_tracks_to_rerun(\n        tracks: np.ndarray,\n        visibles: np.ndarray,\n        query_timestep: np.ndarray,\n        colors: np.ndarray,\n        track_names=None,\n        fps=30.0,\n\n        entity_format_str=\"{}\",\n\n        log_points=True,\n        points_radii=-3.6,\n\n        log_line_strips=True,\n        max_strip_length_past=10,\n        max_strip_length_future=0,\n        strips_radii=-1.8,\n\n        log_error_lines=False,\n        error_lines_radii=0.0042,\n        error_lines_color=[1., 0., 0.],\n        gt_for_error_lines=None,\n) -> None:\n    \"\"\"\n    Log tracks to Rerun.\n\n    Parameters:\n        tracks: Shape (T, N, 3), the 3D trajectories of points.\n        visibles: Shape (T, N), boolean visibility mask for each point at each timestep.\n        query_timestep: Shape (T, N), the frame index after which the tracks start.\n        colors: Shape (N, 4), RGBA colors for each point.\n    \"\"\"\n    T, N, _ = tracks.shape\n    assert tracks.shape == (T, N, 3)\n    assert visibles.shape == (T, N)\n    assert query_timestep.shape == (N,)\n    assert query_timestep.min() >= 0\n    assert query_timestep.max() < T\n    assert colors.shape == (N, 4)\n\n    for n in range(N):\n        track_name = track_names[n] if track_names is not None else f\"track-{n}\"\n        rr.log(entity_format_str.format(track_name), rr.Clear(recursive=True))\n        for t in range(query_timestep[n], T):\n            # if t not in [0] + [T * (x + 1) // 3 - 1 for x in range(3)]:\n            # if t not in [T - 1]:\n            #     continue\n            rr.set_time_seconds(\"frame\", t / fps)\n\n            # Log the point (special handling for invisible points)\n            if log_points:\n                rr.log(\n                    entity_format_str.format(f\"{track_name}/point\"),\n                    rr.Points3D(\n                        positions=[tracks[t, n]],\n                        colors=[colors[n, :3]] if visibles[t, n] else [colors[n, :3] * 0.7],\n                        radii=points_radii,\n                    ),\n                )\n\n            # Log line segments for visible tracks\n            if log_line_strips and t > query_timestep[n]:\n                strip_t_start = max(t - max_strip_length_past, query_timestep[n].item())\n                strip_t_end = min(t + max_strip_length_future, T - 1)\n\n                strips = np.stack([\n                    tracks[strip_t_start:strip_t_end, n],\n                    tracks[strip_t_start + 1:strip_t_end + 1, n],\n                ], axis=-2)\n                strips_visibility = visibles[strip_t_start + 1:strip_t_end + 1, n]\n                strips_colors = np.where(\n                    strips_visibility[:, None],\n                    colors[None, n, :3],\n                    colors[None, n, :3] * 0.7,\n                )\n\n                rr.log(\n     
               entity_format_str.format(f\"{track_name}/line\"),\n                    rr.LineStrips3D(strips=strips, colors=strips_colors, radii=strips_radii),\n                )\n\n            if log_error_lines:\n                assert gt_for_error_lines is not None\n                strips = np.stack([\n                    tracks[t, n],\n                    gt_for_error_lines[t, n],\n                ], axis=-2)\n                rr.log(\n                    entity_format_str.format(f\"{track_name}/error\"),\n                    rr.LineStrips3D(strips=strips, colors=error_lines_color, radii=error_lines_radii),\n                )\n\n\ndef _log_tracks_to_rerun_lightweight(\n        tracks: np.ndarray,\n        visibles: np.ndarray,\n        query_timestep: np.ndarray,\n        colors: np.ndarray,\n        track_names=None,\n        fps=30.0,\n\n        entity_format_str=\"{}\",\n\n        log_points=True,\n        points_radii=0.01,\n\n        log_line_strips=True,\n        max_strip_length_past=24,\n        max_strip_length_future=0,\n        strips_radii=0.0042,\n\n        log_error_lines=False,\n        error_lines_radii=0.0010,\n        error_lines_color=[1., 0., 0.],\n        gt_for_error_lines=None,\n) -> None:\n    \"\"\"\n    Log tracks to Rerun.\n\n    Parameters:\n        tracks: Shape (T, N, 3), the 3D trajectories of points.\n        visibles: Shape (T, N), boolean visibility mask for each point at each timestep.\n        query_timestep: Shape (T, N), the frame index after which the tracks start.\n        colors: Shape (N, 4), RGBA colors for each point.\n    \"\"\"\n    T, N, _ = tracks.shape\n    assert tracks.shape == (T, N, 3)\n    assert visibles.shape == (T, N)\n    assert query_timestep.shape == (N,)\n    assert query_timestep.min() >= 0\n    assert query_timestep.max() < T\n    assert colors.shape == (N, 4)\n\n    for t in range(T):\n        rr.set_time_seconds(\"frame\", t / fps)\n        points_list, points_colors = [], []\n        strips_list, strips_colors_list = [], []\n        errors_list = []\n        for n in range(N):\n            if t > query_timestep[n]:\n                strip_t_start = max(t - max_strip_length_past, query_timestep[n].item())\n                strip_t_end = min(t + max_strip_length_future, T - 1)\n\n                strips = np.stack([\n                    tracks[strip_t_start:strip_t_end, n],\n                    tracks[strip_t_start + 1:strip_t_end + 1, n],\n                ], axis=-2)\n                strips_visibility = visibles[strip_t_start + 1:strip_t_end + 1, n]\n                strips_colors = np.where(\n                    strips_visibility[:, None],\n                    colors[None, n, :3],\n                    colors[None, n, :3] * 0.7,\n                )\n                if log_line_strips:\n                    strips_list.append(strips)\n                    strips_colors_list.append(strips_colors)\n\n                for t_ in range(strip_t_start, strip_t_end + 1):\n                    if log_points:\n                        points_list += [tracks[t_, n]]\n                        points_colors += [colors[n, :3]] if visibles[t_, n] else [colors[n, :3] * 0.7]\n\n                    if log_error_lines:\n                        assert gt_for_error_lines is not None\n                        error_lines = np.stack([\n                            tracks[t_, n],\n                            gt_for_error_lines[t_, n],\n                        ], axis=-2)\n                        errors_list.append(error_lines)\n\n        if log_points and 
len(points_list) > 0:\n            rr.log(\n                entity_format_str.format(f\"points\"),\n                rr.Points3D(\n                    positions=points_list,\n                    colors=points_colors,\n                    radii=points_radii,\n                ),\n            )\n        if log_line_strips and len(strips_list) > 0:\n            rr.log(\n                entity_format_str.format(f\"trajectories\"),\n                rr.LineStrips3D(\n                    strips=np.concatenate(strips_list, axis=0),\n                    colors=np.concatenate(strips_colors_list, axis=0),\n                    radii=strips_radii,\n                ),\n            )\n        if log_error_lines and len(errors_list) > 0:\n            rr.log(\n                entity_format_str.format(f\"errors\"),\n                rr.LineStrips3D(\n                    strips=np.stack(errors_list),\n                    colors=error_lines_color,\n                    radii=error_lines_radii,\n                ),\n            )\n\n\ndef log_tracks_to_rerun(\n        dataset_name: str,\n        datapoint_idx: Union[int, str],\n        predictor_name: str,\n        gt_trajectories_3d_worldspace: Optional[torch.Tensor],\n        gt_visibilities_any_view: Optional[torch.Tensor],\n        query_points_3d: torch.Tensor,\n        pred_trajectories: torch.Tensor,\n        pred_visibilities: torch.Tensor,\n        per_track_results: Optional[Dict[str, Any]] = None,\n        radii_scale: float = 1.0,\n        fps: float = 30.0,\n        sphere_radius_crop: Optional[float] = None,  # e.g., 6.0\n        sphere_center_crop: Optional[np.ndarray] = np.array([0, 0, 0]),\n        log_per_interval_results: bool = False,\n        max_tracks_to_log: Optional[int] = None,\n        track_batch_size: int = 100,\n        method_id: Optional[int] = None,\n        color_per_method_id: Optional[Dict[int, tuple]] = None,  # { 0: (46, 204, 113), ... 
}\n        memory_lightweight_logging: bool = True,\n):\n    # Prepare track data\n    gt_tracks = gt_trajectories_3d_worldspace[0].cpu().numpy() if gt_trajectories_3d_worldspace is not None else None\n    gt_vis = gt_visibilities_any_view[0].cpu().numpy() if gt_visibilities_any_view is not None else None\n    pred_tracks = pred_trajectories[0].cpu().numpy()\n    pred_vis = pred_visibilities[0].cpu().numpy()\n    query_timestep = query_points_3d[0, :, 0].cpu().numpy().astype(int)\n    T, N, _ = pred_tracks.shape\n    assert gt_tracks is None or gt_tracks.shape == (T, N, 3)\n    assert gt_vis is None or gt_vis.shape == (T, N)\n    assert pred_tracks.shape == (T, N, 3)\n    assert pred_vis.shape == (T, N)\n    assert query_timestep.shape == (N,)\n\n    if sphere_radius_crop is not None:\n        pred_tracks = pred_tracks.copy()\n        assert sphere_center_crop is not None\n        dist = np.linalg.norm(pred_tracks - sphere_center_crop, axis=-1, keepdims=True)\n        mask = dist > sphere_radius_crop\n        pred_tracks[mask[..., 0]] = (\n                sphere_center_crop + sphere_radius_crop *\n                (pred_tracks[mask[..., 0]] - sphere_center_crop) /\n                dist[mask][..., None]\n        )\n        if gt_tracks is not None:\n            gt_tracks = gt_tracks.copy()\n            assert sphere_center_crop is not None\n            dist = np.linalg.norm(gt_tracks - sphere_center_crop, axis=-1, keepdims=True)\n            mask = dist > sphere_radius_crop\n            gt_tracks[mask[..., 0]] = (\n                    sphere_center_crop + sphere_radius_crop *\n                    (gt_tracks[mask[..., 0]] - sphere_center_crop) /\n                    dist[mask][..., None]\n            )\n\n    # Last timestamp determines track color (unless method_id is specified)\n    final_xyz = gt_tracks[-1] if gt_tracks is not None else pred_tracks[-1]  # (N, 3)\n    pca = PCA(n_components=1).fit_transform(final_xyz)  # Apply PCA to spread values across 1D axis\n    pca_normalized = (pca - pca.min()) / (pca.max() - pca.min() + 1e-8)  # Normalize to [0, 1]\n    cmap = matplotlib.colormaps[\"gist_rainbow\"]\n    colors = cmap(pca_normalized[:, 0])  # Map to colormap\n    assert colors.shape == (N, 4)\n\n    # If method_id is specified, use fixed colors\n    # Fixed color mapping per method\n    if color_per_method_id is None:\n        color_per_method_id = {\n            0: (46, 204, 113),\n            1: (52, 152, 219),\n            2: (241, 196, 15),\n            3: (155, 89, 182),\n            4: (230, 126, 34),\n            5: (26, 188, 156),\n        }\n    if method_id is not None:\n        assert method_id in color_per_method_id\n        base_rgb = np.array(color_per_method_id[method_id]) / 255.0\n        colors = np.tile(np.append(base_rgb, 1.0), (N, 1))\n\n    assert colors.shape == (N, 4)\n\n    # Log the tracks\n    common_kwargs = {\n        \"points_radii\": -3.6 * radii_scale,\n        \"strips_radii\": -1.8 * radii_scale,\n        \"error_lines_radii\": 0.0042 * radii_scale,\n        \"fps\": fps,\n    }\n    if max_tracks_to_log:\n        N = min(N, max_tracks_to_log)\n    for tracks_batch_start in range(0, N, track_batch_size):\n        tracks_batch_end = min(tracks_batch_start + track_batch_size, N)\n        entity_format_strs = []\n        entity_format_strs += [\n            f\"sequence-{datapoint_idx}/tracks/{{track_name}}/{tracks_batch_start:02d}-{tracks_batch_end:02d}/{{{{}}}}\"\n        ]\n        if not memory_lightweight_logging:\n            entity_format_strs += 
[\n                f\"sequence-{datapoint_idx}/tracks/all/{tracks_batch_start:02d}-{tracks_batch_end:02d}/{{{{}}}}/{{track_name}}\"\n            ]\n        for entity_format_str in entity_format_strs:\n            log_tracks_fn = _log_tracks_to_rerun if not memory_lightweight_logging else _log_tracks_to_rerun_lightweight\n            # Log the GT tracks\n            if gt_tracks is not None and (method_id is None or method_id == 0):\n                log_tracks_fn(\n                    tracks=gt_tracks[:, tracks_batch_start:tracks_batch_end],\n                    visibles=gt_vis[:, tracks_batch_start:tracks_batch_end],\n                    query_timestep=query_timestep[tracks_batch_start:tracks_batch_end],\n                    colors=colors[tracks_batch_start:tracks_batch_end] * 0 + np.array([1, 1, 1, 1]),\n                    track_names=[f\"track-{i:02d}\" for i in range(tracks_batch_start, tracks_batch_end)],\n\n                    entity_format_str=entity_format_str.format(track_name=f\"gt\"),\n\n                    **common_kwargs,\n                )\n            # Log the predicted tracks\n            log_tracks_fn(\n                tracks=pred_tracks[:, tracks_batch_start:tracks_batch_end],\n                visibles=pred_vis[:, tracks_batch_start:tracks_batch_end],\n                query_timestep=query_timestep[tracks_batch_start:tracks_batch_end],\n                colors=colors[tracks_batch_start:tracks_batch_end],\n                track_names=[f\"track-{i:02d}\" for i in range(tracks_batch_start, tracks_batch_end)],\n\n                entity_format_str=entity_format_str.format(track_name=f\"pred--{predictor_name}\"),\n\n                log_error_lines=gt_tracks is not None,\n                gt_for_error_lines=gt_tracks[:, tracks_batch_start:tracks_batch_end] if gt_tracks is not None else None,\n\n                **common_kwargs,\n            )\n\n    if log_per_interval_results and per_track_results is not None:\n        intervals = [(i / 10 * 100, (i + 1) / 10 * 100) for i in range(10)]  # Intervals for 0-10%, ..., 90-100%\n        intervals += [(0, 33), (33, 66), (66, 100)]  # Intervals for lower, middle, upper third\n    else:\n        intervals = []\n    for lower, upper in intervals:\n        for point_type in [\"dynamic\", \"very_dynamic\", \"static\", \"any\"]:\n            if f\"all_{point_type}\" not in per_track_results:\n                continue\n            if lower == 0:  # Special case to include 0\n                track_indices = per_track_results[f\"all_{point_type}\"].indices[\n                    (per_track_results[f\"all_{point_type}\"].average_pts_within_thresh_per_track >= lower) &\n                    (per_track_results[f\"all_{point_type}\"].average_pts_within_thresh_per_track <= upper)\n                    ]\n            else:\n                track_indices = per_track_results[f\"all_{point_type}\"].indices[\n                    (per_track_results[f\"all_{point_type}\"].average_pts_within_thresh_per_track > lower) &\n                    (per_track_results[f\"all_{point_type}\"].average_pts_within_thresh_per_track <= upper)\n                    ]\n            if len(track_indices) == 0:\n                continue\n            entity_format_str = f\"sequence-{datapoint_idx}/tracks/location-accuracy-for-{point_type}/{int(lower)}-{int(upper)}-percent-{{track_name}}/{{{{}}}}\"\n            # Log the GT tracks\n            _log_tracks_to_rerun(\n                tracks=gt_tracks[:, track_indices],\n                visibles=gt_vis[:, track_indices],\n                
query_timestep=query_timestep[track_indices],\n                colors=colors[track_indices] * 0 + np.array([1, 1, 1, 1]),\n                track_names=[f\"track-{i:02d}\" for i in track_indices],\n                entity_format_str=entity_format_str.format(track_name=f\"gt\"),\n                **common_kwargs,\n            )\n            # Log the predicted tracks\n            _log_tracks_to_rerun(\n                tracks=pred_tracks[:, track_indices],\n                visibles=pred_vis[:, track_indices],\n                query_timestep=query_timestep[track_indices],\n                colors=colors[track_indices],\n                track_names=[f\"track-{i:02d}\" for i in track_indices],\n                entity_format_str=entity_format_str.format(track_name=f\"pred-{dataset_name}\"),\n                log_error_lines=True,\n                gt_for_error_lines=gt_tracks[:, track_indices],\n                **common_kwargs,\n            )\n"
  },
  {
    "path": "requirements.full.txt",
    "content": "# Minimal runtime\nnumpy==1.24.3\nhuggingface-hub==0.30.2\neasydict==1.13\npandas==2.2.2\neinops==0.7.0\nopencv-python==4.11.0.86\nmatplotlib==3.8.3\nseaborn==0.13.2\nscikit-image==0.22.0\nscikit-learn==1.4.1.post1\npypng==0.20220715.0\nkornia==0.7.3\nflow-vis==0.1\nmoviepy==1.0.3\nmediapy==1.2.0\nrerun-sdk==0.21.0\n\n# Training / baselines\ntorchdata==0.11.0\nlightning==2.4.0\ntimm==0.6.7\nprettytable==3.10.0\n# tensorflow==2.12.1\n# tensorflow-datasets==4.9.8\n# tensorflow-graphics==2021.12.3\ntensorboard==2.12.3\ntqdm==4.67.1\ngpustat==1.1.1\nhydra-core==1.3.2\nwandb==0.19.9\nrich==14.0.0\n"
  },
  {
    "path": "requirements.txt",
    "content": "# Minimal dependencies\nnumpy==1.24.3\nhuggingface-hub==0.30.2\neasydict==1.13\npandas==2.2.2\neinops==0.7.0\nopencv-python==4.11.0.86\nmatplotlib==3.8.3\nseaborn==0.13.2\nscikit-image==0.22.0\nscikit-learn==1.4.1.post1\npypng==0.20220715.0\nkornia==0.7.3\nflow-vis==0.1\nmoviepy==1.0.3\nmediapy==1.2.0\nrerun-sdk==0.21.0\n"
  },
  {
    "path": "scripts/4ddress_preprocessing.py",
    "content": "\"\"\"\nFirst download the dataset. You'll have to fill in an online ETH form\nand then wait for a few days to get a temporary access code over email.\nI used the following sequence of commands to download and unpack the data\ninto the expected structure. You can probably replace the `dt=...` with\nyour access token that you can probably find in the access URL (or otherwise\nin the page source of the download page that will be linked). Note that\nyou don't need to download all the data if you don't need it, e.g., maybe\nyou just want to download a small sample. Note also that in the commands below,\nI didn't delete the `*.tar.gz` and `*.zip` files, but you can do so if you'd like.\nNote also that the extraction of 00135 had some unexpected structure in that some\ntakes were in the root of 00135 instead of subfolders, but I ignored that.\n```bash\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00122_Inner.tar.gz' -O 00122_Inner.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00122_Outer.tar.gz' -O 00122_Outer.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00123_Inner.tar.gz' -O 00123_Inner.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00123_Outer.tar.gz' -O 00123_Outer.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00127_Inner.tar.gz' -O 00127_Inner.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00127_Outer.tar.gz' -O 00127_Outer.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00129_Inner.tar.gz' -O 00129_Inner.tar.gz\nwget 
'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00129_Outer.tar.gz' -O 00129_Outer.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00134_Inner.tar.gz' -O 00134_Inner.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00134_Outer.tar.gz' -O 00134_Outer.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00135_Inner.tar.gz' -O 00135_Inner.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00135_Outer_1.tar.gz' -O 00135_Outer_1.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00135_Outer_2.tar.gz' -O 00135_Outer_2.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00136_Inner.tar.gz' -O 00136_Inner.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00136_Outer_1.tar.gz' -O 00136_Outer_1.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00136_Outer_2.tar.gz' -O 00136_Outer_2.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00137_Inner_1.tar.gz' -O 00137_Inner_1.tar.gz\nwget 
'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00137_Inner_2.tar.gz' -O 00137_Inner_2.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00137_Outer_1.tar.gz' -O 00137_Outer_1.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00137_Outer_2.tar.gz' -O 00137_Outer_2.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00140_Inner_1.tar.gz' -O 00140_Inner_1.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00140_Inner_2.tar.gz' -O 00140_Inner_2.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00140_Outer_1.tar.gz' -O 00140_Outer_1.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00140_Outer_2.tar.gz' -O 00140_Outer_2.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00147_Inner.tar.gz' -O 00147_Inner.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00147_Outer.tar.gz' -O 00147_Outer.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00148_Inner.tar.gz' -O 00148_Inner.tar.gz\nwget 
'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00148_Outer.tar.gz' -O 00148_Outer.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00149_Inner_1.tar.gz' -O 00149_Inner_1.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00149_Inner_2.tar.gz' -O 00149_Inner_2.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00149_Outer_1.tar.gz' -O 00149_Outer_1.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00149_Outer_2.tar.gz' -O 00149_Outer_2.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00151_Inner.tar.gz' -O 00151_Inner.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00151_Outer.tar.gz' -O 00151_Outer.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00152_Inner.tar.gz' -O 00152_Inner.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00152_Outer_1.tar.gz' -O 00152_Outer_1.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00152_Outer_2.tar.gz' -O 00152_Outer_2.tar.gz\nwget 
'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00154_Inner.tar.gz' -O 00154_Inner.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00154_Outer_1.tar.gz' -O 00154_Outer_1.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00154_Outer_2.tar.gz' -O 00154_Outer_2.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00156_Inner.tar.gz' -O 00156_Inner.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00156_Outer.tar.gz' -O 00156_Outer.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00160_Inner.tar.gz' -O 00160_Inner.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00160_Outer.tar.gz' -O 00160_Outer.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00163_Inner_1.tar.gz' -O 00163_Inner_1.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00163_Inner_2.tar.gz' -O 00163_Inner_2.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00163_Outer.tar.gz' -O 00163_Outer.tar.gz\nwget 
'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00167_Inner.tar.gz' -O 00167_Inner.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00167_Outer.tar.gz' -O 00167_Outer.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00168_Inner.tar.gz' -O 00168_Inner.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00168_Outer_1.tar.gz' -O 00168_Outer_1.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00168_Outer_2.tar.gz' -O 00168_Outer_2.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00169_Inner.tar.gz' -O 00169_Inner.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00169_Outer.tar.gz' -O 00169_Outer.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00170_Inner_1.tar.gz' -O 00170_Inner_1.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00170_Inner_2.tar.gz' -O 00170_Inner_2.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00170_Outer.tar.gz' -O 00170_Outer.tar.gz\nwget 
'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00174_Inner.tar.gz' -O 00174_Inner.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00174_Outer.tar.gz' -O 00174_Outer.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00175_Inner_1.tar.gz' -O 00175_Inner_1.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00175_Inner_2.tar.gz' -O 00175_Inner_2.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00175_Outer_1.tar.gz' -O 00175_Outer_1.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00175_Outer_2.tar.gz' -O 00175_Outer_2.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00176_Inner.tar.gz' -O 00176_Inner.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00176_Outer.tar.gz' -O 00176_Outer.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00179_Inner.tar.gz' -O 00179_Inner.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00179_Outer.tar.gz' -O 00179_Outer.tar.gz\nwget 
'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00180_Inner.tar.gz' -O 00180_Inner.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00180_Outer.tar.gz' -O 00180_Outer.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00185_Inner_1.tar.gz' -O 00185_Inner_1.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00185_Inner_2.tar.gz' -O 00185_Inner_2.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00185_Outer_1.tar.gz' -O 00185_Outer_1.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00185_Outer_2.tar.gz' -O 00185_Outer_2.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00187_Inner_1.tar.gz' -O 00187_Inner_1.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00187_Inner_2.tar.gz' -O 00187_Inner_2.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00187_Outer.tar.gz' -O 00187_Outer.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00188_Inner.tar.gz' -O 00188_Inner.tar.gz\nwget 
'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00188_Outer.tar.gz' -O 00188_Outer.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00190_Inner.tar.gz' -O 00190_Inner.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00190_Outer.tar.gz' -O 00190_Outer.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00191_Inner.tar.gz' -O 00191_Inner.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/00191_Outer.tar.gz' -O 00191_Outer.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/Overview.tar.gz' -O Overview.tar.gz\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/README.md' -O README.md\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/4D-DRESS/Template.tar.gz' -O Template.tar.gz\n\nmkdir benchmark\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/Benchmark/Clothing_Recon_inner.zip' -O benchmark/Clothing_Recon_inner.zip\nwget 'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/Benchmark/Clothing_Recon_outer.zip' -O benchmark/Clothing_Recon_outer.zip\nwget 
'https://4d-dress.ait.ethz.ch/download.php?dt=def5020078d99c392bec963997126c8af8d41234f84ad3799702aafec5ee264c38b6516a5527a0612a28b607f86221d617d47f2c289c0da697797c694428ca6673011edebc672fe8c769de020df868b99d42d30216ce52086a348d5fc201ec1a421f0bdbaba362d0a19ee346736c6711b492&file=/Benchmark/Human_Recon.zip' -O benchmark/Human_Recon.zip\n\nmkdir -p 00122 00123 00127 00129 00134 00135 00136 00137 00140 00147 00148 00149 00151 00152 00154 00156 00160 00163 00167 00168 00169 00170 00174 00175 00176 00179 00180 00185 00187 00188 00190 00191\ntar -xvzf 00122_Inner.tar.gz -C 00122\ntar -xvzf 00122_Outer.tar.gz -C 00122\n\ntar -xvzf 00123_Inner.tar.gz -C 00123\ntar -xvzf 00123_Outer.tar.gz -C 00123\ntar -xvzf 00127_Inner.tar.gz -C 00127\ntar -xvzf 00127_Outer.tar.gz -C 00127\ntar -xvzf 00129_Inner.tar.gz -C 00129\ntar -xvzf 00129_Outer.tar.gz -C 00129\ntar -xvzf 00134_Inner.tar.gz -C 00134\ntar -xvzf 00134_Outer.tar.gz -C 00134\ntar -xvzf 00135_Inner.tar.gz -C 00135\ntar -xvzf 00135_Outer_1.tar.gz -C 00135\ntar -xvzf 00135_Outer_2.tar.gz -C 00135\ntar -xvzf 00136_Inner.tar.gz -C 00136\ntar -xvzf 00136_Outer_1.tar.gz -C 00136\ntar -xvzf 00136_Outer_2.tar.gz -C 00136\ntar -xvzf 00137_Inner_1.tar.gz -C 00137\ntar -xvzf 00137_Inner_2.tar.gz -C 00137\ntar -xvzf 00137_Outer_1.tar.gz -C 00137\ntar -xvzf 00137_Outer_2.tar.gz -C 00137\ntar -xvzf 00140_Inner_1.tar.gz -C 00140\ntar -xvzf 00140_Inner_2.tar.gz -C 00140\ntar -xvzf 00140_Outer_1.tar.gz -C 00140\ntar -xvzf 00140_Outer_2.tar.gz -C 00140\ntar -xvzf 00147_Inner.tar.gz -C 00147\ntar -xvzf 00147_Outer.tar.gz -C 00147\ntar -xvzf 00148_Inner.tar.gz -C 00148\ntar -xvzf 00148_Outer.tar.gz -C 00148\ntar -xvzf 00149_Inner_1.tar.gz -C 00149\ntar -xvzf 00149_Inner_2.tar.gz -C 00149\ntar -xvzf 00149_Outer_1.tar.gz -C 00149\ntar -xvzf 00149_Outer_2.tar.gz -C 00149\ntar -xvzf 00151_Inner.tar.gz -C 00151\ntar -xvzf 00151_Outer.tar.gz -C 00151\ntar -xvzf 00152_Inner.tar.gz -C 00152\ntar -xvzf 00152_Outer_1.tar.gz -C 00152\ntar -xvzf 00152_Outer_2.tar.gz -C 00152\ntar -xvzf 00154_Inner.tar.gz -C 00154\ntar -xvzf 00154_Outer_1.tar.gz -C 00154\ntar -xvzf 00154_Outer_2.tar.gz -C 00154\ntar -xvzf 00156_Inner.tar.gz -C 00156\ntar -xvzf 00156_Outer.tar.gz -C 00156\ntar -xvzf 00160_Inner.tar.gz -C 00160\ntar -xvzf 00160_Outer.tar.gz -C 00160\ntar -xvzf 00163_Inner_1.tar.gz -C 00163\ntar -xvzf 00163_Inner_2.tar.gz -C 00163\ntar -xvzf 00163_Outer.tar.gz -C 00163\ntar -xvzf 00167_Inner.tar.gz -C 00167\ntar -xvzf 00167_Outer.tar.gz -C 00167\ntar -xvzf 00168_Inner.tar.gz -C 00168\ntar -xvzf 00168_Outer_1.tar.gz -C 00168\ntar -xvzf 00168_Outer_2.tar.gz -C 00168\ntar -xvzf 00169_Inner.tar.gz -C 00169\ntar -xvzf 00169_Outer.tar.gz -C 00169\ntar -xvzf 00170_Inner_1.tar.gz -C 00170\ntar -xvzf 00170_Inner_2.tar.gz -C 00170\ntar -xvzf 00170_Outer.tar.gz -C 00170\ntar -xvzf 00174_Inner.tar.gz -C 00174\ntar -xvzf 00174_Outer.tar.gz -C 00174\ntar -xvzf 00175_Inner_1.tar.gz -C 00175\ntar -xvzf 00175_Inner_2.tar.gz -C 00175\ntar -xvzf 00175_Outer_1.tar.gz -C 00175\ntar -xvzf 00175_Outer_2.tar.gz -C 00175\ntar -xvzf 00176_Inner.tar.gz -C 00176\ntar -xvzf 00176_Outer.tar.gz -C 00176\ntar -xvzf 00179_Inner.tar.gz -C 00179\ntar -xvzf 00179_Outer.tar.gz -C 00179\ntar -xvzf 00180_Inner.tar.gz -C 00180\ntar -xvzf 00180_Outer.tar.gz -C 00180\ntar -xvzf 00185_Inner_1.tar.gz -C 00185\ntar -xvzf 00185_Inner_2.tar.gz -C 00185\ntar -xvzf 00185_Outer_1.tar.gz -C 00185\ntar -xvzf 00185_Outer_2.tar.gz -C 00185\ntar -xvzf 00187_Inner_1.tar.gz -C 00187\ntar -xvzf 00187_Inner_2.tar.gz -C 00187\ntar -xvzf 
00187_Outer.tar.gz -C 00187\ntar -xvzf 00188_Inner.tar.gz -C 00188\ntar -xvzf 00188_Outer.tar.gz -C 00188\ntar -xvzf 00190_Inner.tar.gz -C 00190\ntar -xvzf 00190_Outer.tar.gz -C 00190\ntar -xvzf 00191_Inner.tar.gz -C 00191\ntar -xvzf 00191_Outer.tar.gz -C 00191\ntar -xvzf Overview.tar.gz\ntar -xvzf Template.tar.gz\n\ncd benchmark\nunzip Clothing_Recon_inner.zip\nunzip Clothing_Recon_outer.zip\nunzip Human_Recon.zip\n```\n\nWith the data downloaded, you can run the script: `python -m scripts.4ddress_preprocessing`.\n\nI create a subselection of the sequences as:\n```bash\nSRC=datasets/4d-dress-processed-resized-512\nDST=datasets/4d-dress-processed-resized-512-selection\nmkdir ${DST}\n\ncp ${SRC}/00129_Inner_Take3.pkl ${DST}/00129_Inner_Take3_happy.pkl\ncp ${SRC}/00129_Inner_Take4.pkl ${DST}/00129_Inner_Take4_stretch.pkl\ncp ${SRC}/00129_Inner_Take5.pkl ${DST}/00129_Inner_Take5_balerina.pkl\ncp ${SRC}/00129_Outer_Take13.pkl ${DST}/00129_Outer_Take13_kolo.pkl\n\ncp ${SRC}/00140_Inner_Take8.pkl ${DST}/00140_Inner_Take8_football.pkl\ncp ${SRC}/00140_Outer_Take13.pkl ${DST}/00140_Outer_Take13_stretch.pkl\ncp ${SRC}/00140_Outer_Take15.pkl ${DST}/00140_Outer_Take15_kicks.pkl\n\ncp ${SRC}/00147_Inner_Take10.pkl ${DST}/00147_Inner_Take10_basketball.pkl\ncp ${SRC}/00147_Inner_Take11.pkl ${DST}/00147_Inner_Take11_football.pkl\ncp ${SRC}/00147_Outer_Take16.pkl ${DST}/00147_Outer_Take16_dance.pkl\ncp ${SRC}/00147_Outer_Take17.pkl ${DST}/00147_Outer_Take17_avatar.pkl\n\ncp ${SRC}/00174_Inner_Take9.pkl ${DST}/00174_Inner_Take9_stretching.pkl\n\ncp ${SRC}/00175_Inner_Take6.pkl ${DST}/00175_Inner_Take6_basketball.pkl\n```\n\"\"\"\n\nimport os\nimport pickle\nfrom typing import Optional\n\nimport cv2\nimport numpy as np\nimport rerun as rr\nimport torch\nimport tqdm\nfrom PIL import Image\nfrom pytorch3d.renderer import (\n    PerspectiveCameras,\n    MeshRasterizer,\n    RasterizationSettings,\n)\nfrom pytorch3d.structures import Meshes\nfrom scipy.spatial.transform import Rotation\n\nfrom mvtracker.datasets.utils import transform_scene\n\n\ndef load_pickle(p):\n    with open(p, \"rb\") as f:\n        return pickle.load(f)\n\n\ndef save_pickle(p, data):\n    with open(p, \"wb\") as f:\n        pickle.dump(data, f)\n\n\ndef load_image(path):\n    return np.array(Image.open(path))\n\n\ndef extract_4d_dress_data(\n        dataset_root: str,\n        subject_name: str,\n        outfit_name: str,\n        take_name,\n        save_pkl_path,\n        downscaled_longerside: Optional[int] = None,\n        save_rerun_viz: bool = True,\n        stream_rerun_viz: bool = False,\n        skip_if_output_exists: bool = False,\n):\n    # Skip if output exists\n    if skip_if_output_exists and os.path.exists(save_pkl_path):\n        print(f\"Skipping {save_pkl_path} since it already exists\")\n        print()\n        return save_pkl_path\n    else:\n        print(f\"Processing {save_pkl_path}...\")\n\n    base_dir = os.path.join(dataset_root, subject_name, outfit_name, take_name)\n    capture_dir = os.path.join(base_dir, \"Capture\")\n    mesh_dir = os.path.join(base_dir, \"Meshes_pkl\")\n\n    basic_info = load_pickle(os.path.join(base_dir, \"basic_info.pkl\"))\n    scan_frames = basic_info['scan_frames']\n\n    cameras = load_pickle(os.path.join(capture_dir, \"cameras.pkl\"))\n    cam_names = sorted(list(cameras.keys()))\n\n    # Prepare final structure\n    rgbs, intrs, extrs, depths = {}, {}, {}, {}\n    for cam_name in cam_names:\n        rgbs[cam_name] = []\n        depths[cam_name] = []\n\n    device = 
torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n    for frame in tqdm.tqdm(scan_frames, desc=\"Extracting frame data\"):\n        mesh_path = os.path.join(mesh_dir, f\"mesh-f{frame}.pkl\")\n        mesh_data = load_pickle(mesh_path)\n        vertices = mesh_data[\"vertices\"]\n        faces = mesh_data[\"faces\"]\n\n        verts = torch.tensor(vertices, dtype=torch.float32, device=device).unsqueeze(0)\n        faces = torch.tensor(faces, dtype=torch.int64, device=device).unsqueeze(0)\n        mesh = Meshes(verts=verts, faces=faces)\n\n        for cam_name in cam_names:\n            cam_path = os.path.join(capture_dir, cam_name)\n            img_path = os.path.join(cam_path, \"images\", f\"capture-f{frame}.png\")\n            if not os.path.exists(img_path):\n                continue\n\n            image = load_image(img_path)\n            h, w = image.shape[:2]\n            intr = cameras[cam_name]['intrinsics'].copy()\n            extr = cameras[cam_name]['extrinsics'].copy()\n\n            if downscaled_longerside is not None:\n                scale = downscaled_longerside / max(h, w)\n                h, w = int(h * scale), int(w * scale)\n                image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)\n                intr[:2] *= scale\n\n            if cam_name not in intrs:\n                intrs[cam_name] = intr\n                extrs[cam_name] = extr\n\n            rgbs[cam_name].append(image)\n\n            # Convert intrinsics to normalized device coords\n            fx, fy = intr[0, 0], intr[1, 1]\n            cx, cy = intr[0, 2], intr[1, 2]\n\n            R = extr[:3, :3]\n            T = extr[:3, 3]\n\n            R = R.T\n            R = R @ np.diag(np.array([-1, -1, 1.]))  # Flip the x and y axes (or multiply f by -1)\n            T = T @ np.diag(np.array([-1, -1, 1.]))  # Flip the x and y axes (or multiply f by -1)\n\n            cameras_p3d = PerspectiveCameras(\n                focal_length=torch.tensor([[fx, fy]], dtype=torch.float32, device=device),\n                principal_point=torch.tensor([[cx, cy]], dtype=torch.float32, device=device),\n                R=torch.tensor(R, dtype=torch.float32, device=device).unsqueeze(0),\n                T=torch.tensor(T, dtype=torch.float32, device=device).unsqueeze(0),\n                image_size=torch.tensor([[h, w]], dtype=torch.float32, device=device),\n                in_ndc=False,\n                device=device,\n            )\n            raster_settings = RasterizationSettings(\n                image_size=(h, w),\n                blur_radius=0.0,\n                faces_per_pixel=1,\n                bin_size=0\n            )\n\n            rasterizer = MeshRasterizer(cameras=cameras_p3d, raster_settings=raster_settings)\n            fragments = rasterizer(mesh)\n            zbuf = fragments.zbuf.squeeze().cpu().numpy()\n            zbuf[np.isnan(zbuf)] = 0.0\n\n            depths[cam_name].append(zbuf)\n\n    for cam_name in cam_names:\n        if rgbs[cam_name]:\n            rgbs[cam_name] = np.stack(rgbs[cam_name]).transpose(0, 3, 1, 2)  # T, C, H, W\n            depths[cam_name] = np.stack(depths[cam_name])  # T, H, W\n\n    # Rotate the scene to have the ground at z=0\n    rot_x = Rotation.from_euler('x', 90, degrees=True).as_matrix()\n    rot_y = Rotation.from_euler('y', 0, degrees=True).as_matrix()\n    rot_z = Rotation.from_euler('z', 0, degrees=True).as_matrix()\n    rot = torch.from_numpy(rot_z @ rot_y @ rot_x)\n    translation = torch.tensor([0.0, 0.0, 0.0], 
dtype=torch.float32)\n    for cam_name in cam_names:\n        extrs[cam_name] = transform_scene(\n            1, rot, translation, None, torch.from_numpy(extrs[cam_name][None, None]),\n        )[1][0, 0].numpy()\n\n    # Check shapes\n    n_frames, _, h, w = rgbs[cam_names[0]].shape\n    for cam_name in cam_names:\n        assert rgbs[cam_name].shape == (n_frames, 3, h, w)\n        assert intrs[cam_name].shape == (3, 3)\n        assert extrs[cam_name].shape == (3, 4)\n\n    # Save processed output to a pickle file\n    save_pickle(save_pkl_path, dict(\n        rgbs=rgbs,\n        intrs=intrs,\n        extrs=extrs,\n        depths=depths,\n        ego_cam_name=None,\n    ))\n\n    # Visualize the data sample using rerun\n    rerun_modes = []\n    if stream_rerun_viz:\n        rerun_modes += [\"stream\"]\n    if save_rerun_viz:\n        rerun_modes += [\"save\"]\n    for rerun_mode in rerun_modes:\n        rr.init(f\"3dpt\", recording_id=\"v0.16\")\n        if rerun_mode == \"stream\":\n            rr.connect_tcp()\n\n        rr.log(\"world\", rr.ViewCoordinates.RIGHT_HAND_Z_UP, static=True)\n        rr.set_time_seconds(\"frame\", 0)\n        rr.log(\n            \"world/xyz\",\n            rr.Arrows3D(\n                vectors=[[1, 0, 0], [0, 2, 0], [0, 0, 3]],\n                colors=[[255, 0, 0], [0, 255, 0], [0, 0, 255]],\n            ),\n        )\n\n        rr.log(\n            \"mesh\",\n            rr.Mesh3D(\n                vertex_positions=vertices.astype(np.float32),  # (N, 3)\n                triangle_indices=faces.cpu().numpy().reshape(-1, 3).astype(np.int32),  # (M, 3)\n                albedo_factor=[200, 200, 255],  # Optional color\n            ),\n        )\n\n        fps = 30\n        for frame_idx in range(n_frames):\n            rr.set_time_seconds(\"frame\", frame_idx / fps)\n            for cam_name in cam_names:\n                extr = extrs[cam_name]\n                intr = intrs[cam_name]\n                img = rgbs[cam_name][frame_idx].transpose(1, 2, 0).astype(np.uint8)\n                depth = depths[cam_name][frame_idx]\n\n                h, w = img.shape[:2]\n                fx, fy = intr[0, 0], intr[1, 1]\n                cx, cy = intr[0, 2], intr[1, 2]\n\n                # Camera pose\n                T = np.eye(4)\n                T[:3, :] = extr\n                world_T_cam = np.linalg.inv(T)\n                rr.log(f\"{cam_name}/image\", rr.Transform3D(\n                    translation=world_T_cam[:3, 3],\n                    mat3x3=world_T_cam[:3, :3],\n                ))\n                rr.log(f\"{cam_name}/image\", rr.Pinhole(\n                    image_from_camera=intr,\n                    width=w,\n                    height=h\n                ))\n                rr.log(f\"{cam_name}/image\", rr.Image(img))\n\n                rr.log(f\"{cam_name}/depth\", rr.Transform3D(\n                    translation=world_T_cam[:3, 3],\n                    mat3x3=world_T_cam[:3, :3],\n                ))\n                rr.log(f\"{cam_name}/depth\", rr.Pinhole(\n                    image_from_camera=intr,\n                    width=w,\n                    height=h\n                ))\n                rr.log(f\"{cam_name}/depth\", rr.DepthImage(depth, meter=1.0, colormap=\"viridis\"))\n\n                # Unproject depth to point cloud\n                y, x = np.meshgrid(np.arange(h), np.arange(w), indexing=\"ij\")\n                z = depth\n                valid = z > 0\n                x = x[valid]\n                y = y[valid]\n                z = 
z[valid]\n\n                X = (x - cx) * z / fx\n                Y = (y - cy) * z / fy\n                pts_cam = np.stack([X, Y, z], axis=-1)\n\n                # Transform to world\n                R = world_T_cam[:3, :3]\n                t = world_T_cam[:3, 3]\n                pts_world = pts_cam @ R.T + t\n\n                # Color\n                colors = img[y, x]\n\n                rr.log(f\"point_cloud/{cam_name}\", rr.Points3D(positions=pts_world, colors=colors))\n\n        if rerun_mode == \"save\":\n            base, name = os.path.split(save_pkl_path)\n            name_no_ext = os.path.splitext(name)[0]\n            save_rrd_path = os.path.join(base, f\"rerun__{name_no_ext}.rrd\")\n            rr.save(save_rrd_path)\n            print(f\"Saved rerun viz to {os.path.abspath(save_rrd_path)}\")\n\n    print(f\"Done with {save_pkl_path}.\")\n    print()\n\n\ndef crete_overview_pngs(dataset_root, subject_names, overview_dir):\n    os.makedirs(overview_dir, exist_ok=True)\n\n    for subject_name in tqdm.tqdm(subject_names):\n        if \".\" in subject_name:\n            continue\n\n        for outfit_name in os.listdir(os.path.join(dataset_root, subject_name)):\n            if outfit_name not in [\"Inner\", \"Outer\"]:\n                continue\n\n            for take_name in os.listdir(os.path.join(dataset_root, subject_name, outfit_name)):\n                if \".\" in take_name:\n                    continue\n\n                cam_dir = os.path.join(dataset_root, subject_name, outfit_name, take_name, \"Capture\")\n                cam_names = sorted([name for name in os.listdir(cam_dir) if \".\" not in name])\n\n                first_cam = cam_names[0]\n                img_folder = os.path.join(cam_dir, first_cam, \"images\")\n                images = sorted(os.listdir(img_folder))\n\n                last_img = images[-1]\n                img_path = os.path.join(img_folder, last_img)\n\n                # Load image and overlay info\n                from PIL import Image, ImageDraw, ImageFont\n                img = Image.open(img_path).convert(\"RGB\")\n                draw = ImageDraw.Draw(img)\n                text = (\n                    f\"{subject_name} / {outfit_name} / {take_name}\\n\"\n                    f\"Frame: {last_img.split('-')[-1].split('.')[0]}\\n\"\n                    f\"Cams: {cam_names}\"\n                )\n\n                try:\n                    font = ImageFont.truetype(\"DejaVuSans-Bold.ttf\", 16)\n                except:\n                    font = ImageFont.load_default()\n\n                draw.text((10, 10), text, fill=\"white\", font=font)\n\n                # Save image\n                overview_path = os.path.join(overview_dir, f\"{subject_name}__{outfit_name}__{take_name}.png\")\n                img.save(overview_path)\n                print(f\"Saved overview to {overview_path}\")\n\n\ndef crete_overview_mp4s(dataset_root, subject_names, overview_dir, fps=30):\n    os.makedirs(overview_dir, exist_ok=True)\n\n    for subject_name in tqdm.tqdm(subject_names):\n        if \".\" in subject_name:\n            continue\n\n        for outfit_name in os.listdir(os.path.join(dataset_root, subject_name)):\n            if outfit_name not in [\"Inner\", \"Outer\"]:\n                continue\n\n            for take_name in os.listdir(os.path.join(dataset_root, subject_name, outfit_name)):\n                if \".\" in take_name:\n                    continue\n\n                cam_dir = os.path.join(dataset_root, subject_name, outfit_name, take_name, 
\"Capture\")\n                cam_names = sorted([name for name in os.listdir(cam_dir) if \".\" not in name])\n\n                first_cam = cam_names[0]\n                img_folder = os.path.join(cam_dir, first_cam, \"images\")\n                images = sorted(os.listdir(img_folder))\n\n                # Load first frame to get size\n                first_img = cv2.imread(os.path.join(img_folder, images[0]))\n                height, width = first_img.shape[:2]\n\n                video_path = os.path.join(\n                    overview_dir,\n                    f\"{subject_name}__{outfit_name}__{take_name}.mp4\"\n                )\n                writer = cv2.VideoWriter(\n                    video_path,\n                    cv2.VideoWriter_fourcc(*\"mp4v\"),\n                    fps,\n                    (width, height)\n                )\n\n                for img_name in images:\n                    img_path = os.path.join(img_folder, img_name)\n                    img = cv2.imread(img_path)\n\n                    overlay_text = (\n                        f\"{subject_name} / {outfit_name} / {take_name} | \"\n                        f\"Frame: {img_name.split('-')[-1].split('.')[0]} | \"\n                        f\"Cams: {', '.join(cam_names)}\"\n                    )\n                    cv2.putText(\n                        img,\n                        overlay_text,\n                        (10, 25),\n                        cv2.FONT_HERSHEY_SIMPLEX,\n                        0.6,\n                        (255, 255, 255),\n                        2,\n                        lineType=cv2.LINE_AA\n                    )\n\n                    writer.write(img)\n\n                writer.release()\n                print(f\"Saved video to {video_path}\")\n\n\nif __name__ == \"__main__\":\n    dataset_root = \"datasets/4d-dress\"\n    output_root = \"datasets/4d-dress-processed\"\n    create_overviews = True  # Creates an overview folder with a png/mp4 summary of each subject-outfit-take\n\n    longside_resolution: Optional[int] = 512\n    if longside_resolution is not None:\n        output_root += f\"-resized-{longside_resolution}\"\n    os.makedirs(output_root, exist_ok=True)\n\n    subject_names = [\n        \"00122\", \"00123\", \"00127\", \"00129\", \"00134\",\n        \"00135\", \"00136\", \"00137\", \"00140\", \"00147\",\n        \"00148\", \"00149\", \"00151\", \"00152\", \"00154\",\n        \"00156\", \"00160\", \"00163\", \"00167\", \"00168\",\n        \"00169\", \"00170\", \"00174\", \"00175\", \"00176\",\n        \"00179\", \"00180\", \"00185\", \"00187\", \"00188\",\n        \"00190\", \"00191\",\n    ]\n    if create_overviews:\n        crete_overview_pngs(dataset_root, subject_names, os.path.join(dataset_root, \"overview-pngs\"))\n        crete_overview_mp4s(dataset_root, subject_names, os.path.join(dataset_root, \"overview-mp4s\"))\n\n    for subject_name in tqdm.tqdm(subject_names):\n        if \".\" in subject_name:\n            continue\n\n        for outfit_name in os.listdir(os.path.join(dataset_root, subject_name)):\n            if outfit_name not in [\"Inner\", \"Outer\"]:\n                continue\n\n            for take_name in os.listdir(os.path.join(dataset_root, subject_name, outfit_name)):\n                if \".\" in take_name:\n                    continue\n\n                pkl_path = os.path.join(output_root, f\"{subject_name}_{outfit_name}_{take_name}.pkl\")\n                extract_4d_dress_data(\n                    dataset_root=dataset_root,\n              
      subject_name=subject_name,\n                    outfit_name=outfit_name,\n                    take_name=take_name,\n                    downscaled_longerside=longside_resolution,\n                    save_pkl_path=pkl_path,\n                    save_rerun_viz=True,\n                    stream_rerun_viz=False,\n                    skip_if_output_exists=True,\n                )\n"
  },
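  {
    "path": "scripts/4ddress_load_processed_example.py",
    "content": "\"\"\"\nIllustrative sketch, not part of the original pipeline: a minimal example of how the\nper-take .pkl files written by scripts/4ddress_preprocessing.py could be loaded and\ninspected, assuming the keys saved there (rgbs, intrs, extrs, depths, ego_cam_name).\nThe name of this script and the default .pkl path below are hypothetical.\n\"\"\"\n\nimport pickle\nimport sys\n\n\ndef main(pkl_path: str):\n    with open(pkl_path, \"rb\") as f:\n        data = pickle.load(f)\n\n    # Per-camera dicts: rgbs[cam] is (T, 3, H, W), depths[cam] is (T, H, W),\n    # intrs[cam] is (3, 3), extrs[cam] is (3, 4) world-to-camera.\n    for cam_name in sorted(data[\"rgbs\"].keys()):\n        print(\n            cam_name,\n            data[\"rgbs\"][cam_name].shape,\n            data[\"depths\"][cam_name].shape,\n            data[\"intrs\"][cam_name].shape,\n            data[\"extrs\"][cam_name].shape,\n        )\n\n\nif __name__ == \"__main__\":\n    default_pkl = \"datasets/4d-dress-processed-resized-512/00129_Inner_Take3.pkl\"\n    main(sys.argv[1] if len(sys.argv) > 1 else default_pkl)\n"
  },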
  {
    "path": "scripts/__init__.py",
    "content": ""
  },
  {
    "path": "scripts/compare_cdist-topk_against_pointops-knn.py",
    "content": "import time\n\nimport torch\nfrom pointops import knn_query\n\nB, N, M, D, K = 12, 49152, 928, 3, 16\n\n\ndef knn_torch(k: int, xyz_ref: torch.Tensor, xyz_query: torch.Tensor):\n    dists = torch.cdist(xyz_query, xyz_ref, p=2)  # shape: (B, M, N)\n    sorted_dists, indices = torch.topk(dists, k, dim=-1, largest=False, sorted=True)\n    return sorted_dists, indices\n\n\ndef knn_pointops(k: int, xyz_ref: torch.Tensor, xyz_query: torch.Tensor):\n    B, N, _ = xyz_ref.shape\n    _, M, _ = xyz_query.shape\n    orig_dtype = xyz_ref.dtype\n\n    xyz_ref_flat = xyz_ref.contiguous().view(B * N, 3).to(torch.float32)\n    xyz_query_flat = xyz_query.contiguous().view(B * M, 3).to(torch.float32)\n\n    offset = torch.arange(1, B + 1, device=xyz_ref.device) * N\n    new_offset = torch.arange(1, B + 1, device=xyz_query.device) * M\n\n    idx, dists = knn_query(k, xyz_ref_flat, offset, xyz_query_flat, new_offset)\n\n    # Remap global indices to local per-batch\n    idx = idx.view(B, M, k)\n    idx = idx - (torch.arange(B, device=idx.device).view(B, 1, 1) * N)\n    dists = dists.view(B, M, k).to(orig_dtype)\n\n    return dists, idx\n\n\ndef benchmark(fn, name, HALF_PRECISION=False, iters=100):\n    total_time = 0.0\n    peak_memories = []\n    for _ in range(iters):\n        xyz_ref = torch.randn(B, N, D, device=\"cuda\")\n        xyz_query = torch.randn(B, M, D, device=\"cuda\")\n        if HALF_PRECISION:\n            xyz_ref = xyz_ref.half()\n            xyz_query = xyz_query.half()\n        fn(K, xyz_ref, xyz_query)  # warm up\n        torch.cuda.empty_cache()\n        torch.cuda.reset_peak_memory_stats()\n        torch.cuda.synchronize()\n        start = time.time()\n        fn(K, xyz_ref, xyz_query)\n        torch.cuda.synchronize()\n        total_time += time.time() - start\n        peak_memories.append(torch.cuda.max_memory_allocated() / 1e6)  # MB\n\n    avg_time = total_time / iters\n    peak_memory_min = min(peak_memories)\n    peak_memory_avg = sum(peak_memories) / len(peak_memories)\n    peak_memory_max = max(peak_memories)\n    print(f\"{name:<24} | \"\n          f\"Avg Time: {avg_time:.6f} s | \"\n          f\"Peak Memory: {peak_memory_avg:>6.2f} MB (min: {peak_memory_min:>6.2f}, max: {peak_memory_max:>6.2f})\")\n\n\nprint(\"Benchmarking KNN with different methods (HALF_PRECISION=True):\")\nbenchmark(knn_torch, \"torch.cdist+torch.topk\", True)\nbenchmark(knn_pointops, \"pointops.knn_query\", True)\n\nprint(\"\\nBenchmarking KNN with different methods (HALF_PRECISION=False):\")\nbenchmark(knn_torch, \"torch.cdist+torch.topk\", False)\nbenchmark(knn_pointops, \"pointops.knn_query\", False)\n"
  },
  {
    "path": "scripts/dex_ycb_to_neus_format.py",
    "content": "\"\"\"\nBefore running the script, you need to install the toolkit and other\ndependencies, as well as download the data and necessary MANO checkpoints/models.\n\nInstall the toolkit and dependencies:\n```sh\n# Create a new conda environment\nconda create -n dexycb python=3.9\nconda activate dexycb\nconda install pytorch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 pytorch-cuda=12.1 -c pytorch -c nvidia\nconda install -c iopath iopath\npip install --upgrade setuptools wheel\npip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py39_cu121_pyt241/download.html\nconda install ninja scipy matplotlib -c conda-forge\npip install numpy==1.21.6 matplotlib==3.6 pandas==2.0 scikit-image scipy==1.11 rerun-sdk pyembree rtree --no-deps\n\n# Install dex-ycb-toolkit\ncd /home/frrajic/xode/03-macos/\ngit clone --recursive git@github.com:NVlabs/dex-ycb-toolkit.git\ncd dex-ycb-toolkit\npip install -e .\n\n# Install bop_toolkit dependencies\ncd bop_toolkit\npip install -r requirements.txt\ncd ..\n\n# Install manopth\ncd manopth\npip install -e .\ncd ..\n\n# Make sure numpy version is not too high (so that np.bool is not deprecated)\npip install numpy==1.21.6 matplotlib==3.6 pandas==2.0 scikit-image scipy==1.11 rerun-sdk pyembree rtree --no-deps\n```\n\nDownload the DexYCB dataset from the [project site](https://dex-ycb.github.io):\n```sh\nexport DEX_YCB_DIR=/home/frrajic/xode/00-data/dex-january-2025\ncd $DEX_YCB_DIR\n\n#  20200709-subject-01.tar.gz (12G)\n#  20200813-subject-02.tar.gz (12G)\n#  20200820-subject-03.tar.gz (12G)\n#  20200903-subject-04.tar.gz (12G)\n#  20200908-subject-05.tar.gz (12G)\n#  20200918-subject-06.tar.gz (12G)\n#  20200928-subject-07.tar.gz (12G)\n#  20201002-subject-08.tar.gz (12G)\n#  20201015-subject-09.tar.gz (12G)\n#  20201022-subject-10.tar.gz (12G)\ngdown --fuzzy https://drive.google.com/file/d/1Ehh92wDE3CWAiKG7E9E73HjN2Xk2XfEk/view?usp=sharing &\ngdown --fuzzy https://drive.google.com/file/d/1Uo7MLqTbXEa-8s7YQZ3duugJ1nXFEo62/view?usp=sharing &\ngdown --fuzzy https://drive.google.com/file/d/1FkUxas8sv8UcVGgAzmSZlJw1eI5W5CXq/view?usp=sharing &\ngdown --fuzzy https://drive.google.com/file/d/14up6qsTpvgEyqOQ5hir-QbjMB_dHfdpA/view?usp=sharing &\ngdown --fuzzy https://drive.google.com/file/d/1NBA_FPyGWOQF5-X9ueAat5g8lDMz-EmS/view?usp=sharing &\ngdown --fuzzy https://drive.google.com/file/d/1UWIN2-wOBZX2T0dkAi4ctAAW8KffkXMQ/view?usp=sharing &\ngdown --fuzzy https://drive.google.com/file/d/1oWEYD_o3PVh39pLzMlJcArkDtMj4nzI0/view?usp=sharing &\ngdown --fuzzy https://drive.google.com/file/d/1GTNZwhWbs7Mfez0krTgXwLPndvrw1Ztv/view?usp=sharing &\ngdown --fuzzy https://drive.google.com/file/d/1j0BLkaCjIuwjakmywKdOO9vynHTWR0UH/view?usp=sharing &\ngdown --fuzzy https://drive.google.com/file/d/1FvFlRfX-p5a5sAWoKEGc17zKJWwKaSB-/view?usp=sharing &\n\n#  bop.tar.gz (1.2G)\n#  calibration.tar.gz (16K)\n#  models.tar.gz (1.4G)\ngdown --fuzzy https://drive.google.com/file/d/1CPqLjsaYNjE3xSJbuWmqaMsGvyGIxiKL/view?usp=sharing &\ngdown --fuzzy https://drive.google.com/file/d/1UAwVKT4Rgb1fLcFoa1o71_-0NtSvvLAQ/view?usp=sharing &\ngdown --fuzzy https://drive.google.com/file/d/1cAzlQBpcTatI5ykYQ8ziQiHLUG_a_UpM/view?usp=sharing &\n\ntar xvf 20200709-subject-01.tar.gz &\ntar xvf 20200813-subject-02.tar.gz &\ntar xvf 20200820-subject-03.tar.gz &\ntar xvf 20200903-subject-04.tar.gz &\ntar xvf 20200908-subject-05.tar.gz &\ntar xvf 20200918-subject-06.tar.gz &\ntar xvf 20200928-subject-07.tar.gz &\ntar xvf 20201002-subject-08.tar.gz &\ntar 
xvf 20201015-subject-09.tar.gz &\ntar xvf 20201022-subject-10.tar.gz &\ntar xvf bop.tar.gz &\ntar xvf calibration.tar.gz &\ntar xvf models.tar.gz &\n\nrm 20200709-subject-01.tar.gz\nrm 20200813-subject-02.tar.gz\nrm 20200820-subject-03.tar.gz\nrm 20200903-subject-04.tar.gz\nrm 20200908-subject-05.tar.gz\nrm 20200918-subject-06.tar.gz\nrm 20200928-subject-07.tar.gz\nrm 20201002-subject-08.tar.gz\nrm 20201015-subject-09.tar.gz\nrm 20201022-subject-10.tar.gz\nrm bop.tar.gz\nrm calibration.tar.gz\nrm models.tar.gz\n```\n\nThe structure of the dataset should look like this:\n```sh\ntree -L 1 $DEX_YCB_DIR\n# /home/frrajic/xode/00-data/dex-january-2025\n# ├── 20200709-subject-01\n# ├── 20200813-subject-02\n# ├── 20200820-subject-03\n# ├── 20200903-subject-04\n# ├── 20200908-subject-05\n# ├── 20200918-subject-06\n# ├── 20200928-subject-07\n# ├── 20201002-subject-08\n# ├── 20201015-subject-09\n# ├── 20201022-subject-10\n# ├── bop\n# ├── calibration\n# └── models\n\ndu -sch $DEX_YCB_DIR/*\n# 13G     /home/frrajic/xode/00-data/dex-january-2025/20200709-subject-01\n# 13G     /home/frrajic/xode/00-data/dex-january-2025/20200813-subject-02\n# 13G     /home/frrajic/xode/00-data/dex-january-2025/20200820-subject-03\n# 13G     /home/frrajic/xode/00-data/dex-january-2025/20200903-subject-04\n# 13G     /home/frrajic/xode/00-data/dex-january-2025/20200908-subject-05\n# 13G     /home/frrajic/xode/00-data/dex-january-2025/20200918-subject-06\n# 13G     /home/frrajic/xode/00-data/dex-january-2025/20200928-subject-07\n# 13G     /home/frrajic/xode/00-data/dex-january-2025/20201002-subject-08\n# 13G     /home/frrajic/xode/00-data/dex-january-2025/20201015-subject-09\n# 13G     /home/frrajic/xode/00-data/dex-january-2025/20201022-subject-10\n# 24G     /home/frrajic/xode/00-data/dex-january-2025/bop\n# 200K    /home/frrajic/xode/00-data/dex-january-2025/calibration\n# 3.5G    /home/frrajic/xode/00-data/dex-january-2025/models\n# 154G    total\n```\n\nDownload MANO models and code (`mano_v1_2.zip`) from the [MANO website](https://mano.is.tue.mpg.de)\nand place the file under `manopath`. 
Unzip the file and create symlink:\n```sh\ncd /home/frrajic/xode/03-macos/dex-ycb-toolkit\n\ncd manopth\nunzip mano_v1_2.zip\ncd mano\nln -s ../mano_v1_2/models models\ncd ../..\n```\n\nFinally, run the script:\n```sh\nconda activate dexycb\nexport DEX_YCB_DIR=/home/frrajic/xode/00-data/dex-january-2025\ncd /home/frrajic/xode/03-macos/dex-ycb-toolkit\npython /home/frrajic/xode/03-macos/spatialtracker/scripts/dex_ycb_to_neus_format.py\n```\n\"\"\"\n\nimport os\n\nimport cv2\nimport imageio\nimport math\nimport matplotlib\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport open3d as o3d\nimport open3d.visualization as vis\nimport rerun as rr\nimport torch\nimport trimesh\nimport yaml\nfrom dex_ycb_toolkit.layers.mano_group_layer import MANOGroupLayer\nfrom dex_ycb_toolkit.layers.ycb_group_layer import YCBGroupLayer\nfrom dex_ycb_toolkit.layers.ycb_layer import dcm2rv, rv2dcm\nfrom matplotlib import cm\nfrom matplotlib.cm import get_cmap\nfrom pytorch3d.renderer import (\n    MeshRasterizer,\n    MeshRendererWithFragments,\n    RasterizationSettings,\n    SoftPhongShader,\n    PointLights,\n)\nfrom pytorch3d.renderer import TexturesVertex\nfrom pytorch3d.structures import Meshes\nfrom pytorch3d.utils.camera_conversions import cameras_from_opencv_projection\nfrom scipy.spatial.transform import Rotation as Rot\nfrom tqdm import tqdm\n\n\ndef sample_surface(mesh: trimesh.Trimesh, count, face_weight=None, seed=None):\n    \"\"\"\n    Sample the surface of a mesh, returning the specified\n    number of points\n\n    For individual triangle sampling uses this method:\n    http://mathworld.wolfram.com/TrianglePointPicking.html\n\n    Adapted from:\n    https://github.com/mikedh/trimesh/blob/a47b66d2d18404bc044aa9fcb983a80b1287919a/trimesh/sample.py#L23\n\n    Parameters\n    -----------\n    mesh : trimesh.Trimesh\n      Geometry to sample the surface of\n    count : int\n      Number of points to return\n    face_weight : None or len(mesh.faces) float\n      Weight faces by a factor other than face area.\n      If None will be the same as face_weight=mesh.area\n    seed : None or int\n      If passed as an integer will provide deterministic results\n      otherwise pulls the seed from operating system entropy.\n\n    Returns\n    ---------\n    samples : (count, 3) float\n      Points in space on the surface of mesh\n    face_index : (count,) int\n      Indices of faces for each sampled point\n    colors : (count, 4) float\n      Colors of each sampled point\n      Returns only when the sample_color is True\n    \"\"\"\n\n    if face_weight is None:\n        # len(mesh.faces) float, array of the areas\n        # of each face of the mesh\n        face_weight = mesh.area_faces\n\n    # cumulative sum of weights (len(mesh.faces))\n    # cumulative sum of weights (len(mesh.faces))\n    weight_cum = np.cumsum(face_weight)\n\n    # seed the random number generator as requested\n    default_rng = np.random.default_rng\n    random = default_rng(seed).random\n\n    # last value of cumulative sum is total summed weight/area\n    face_pick = random(count) * weight_cum[-1]\n    # get the index of the selected faces\n    picked_faces = np.searchsorted(weight_cum, face_pick)\n\n    # pull triangles into the form of an origin + 2 vectors\n    tri_origins = mesh.vertices[mesh.faces[:, 0]]\n    tri_vectors = mesh.vertices[mesh.faces[:, 1:]].copy()\n    tri_vectors -= np.tile(tri_origins, (1, 2)).reshape((-1, 2, 3))\n\n    # pull the vectors for the faces we are going to sample from\n    tri_origins = 
tri_origins[picked_faces]\n    tri_vectors = tri_vectors[picked_faces]\n\n    # randomly generate two 0-1 scalar components to multiply edge vectors b\n    picked_weights = random((len(tri_vectors), 2, 1))\n\n    # points will be distributed on a quadrilateral if we use 2 0-1 samples\n    # if the two scalar components sum less than 1.0 the point will be\n    # inside the triangle, so we find vectors longer than 1.0 and\n    # transform them to be inside the triangle\n    outside_triangle = picked_weights.sum(axis=1).reshape(-1) > 1.0\n    picked_weights[outside_triangle] -= 1.0\n    picked_weights = np.abs(picked_weights)\n\n    # multiply triangle edge vectors by the random lengths and sum\n    sample_vector = (tri_vectors * picked_weights).sum(axis=1)\n\n    # finally, offset by the origin to generate\n    # (n,3) points in space on the triangle\n    picked_points = sample_vector + tri_origins\n\n    return picked_faces, picked_weights, picked_points\n\n\ndef pick_points_from_mesh(mesh, picked_faces, picked_weights, reference_mesh):\n    if reference_mesh is not None:\n        # Number of vertices must match, but the 3D location of vertices can change\n        assert reference_mesh.vertices.shape == mesh.vertices.shape, \"Number of vertices must match\"\n\n        # The faces must be the same\n        assert np.allclose(reference_mesh.faces, mesh.faces), \"Faces must be the same\"\n\n    # pull triangles into the form of an origin + 2 vectors\n    tri_origins = mesh.vertices[mesh.faces[:, 0]]\n    tri_vectors = mesh.vertices[mesh.faces[:, 1:]].copy()\n    tri_vectors -= np.tile(tri_origins, (1, 2)).reshape((-1, 2, 3))\n\n    # pull the vectors for the faces we are going to sample from\n    tri_origins = tri_origins[picked_faces]\n    tri_vectors = tri_vectors[picked_faces]\n\n    # multiply triangle edge vectors by the random lengths and sum\n    sample_vector = (tri_vectors * picked_weights).sum(axis=1)\n\n    picked_points = sample_vector + tri_origins\n\n    return picked_points\n\n\nclass SequenceLoader():\n    \"\"\"DexYCB sequence loader.\"\"\"\n\n    def __init__(\n            self,\n            name,\n            device='cuda:0',\n            preload=True,\n            app='viewer',\n            **kwargs,\n    ):\n        \"\"\"Constructor.\n\n        Args:\n          name: Sequence name.\n          device: A torch.device string argument. 
The specified device is used only\n            for certain data loading computations, but not storing the loaded data.\n            Currently the loaded data is always stored as numpy arrays on CPU.\n          preload: Whether to preload the point cloud or load it online.\n          app: 'viewer' or 'renderer'.\n        \"\"\"\n        assert device in ('cuda', 'cpu') or device.split(':')[0] == 'cuda'\n        assert app in ('viewer', 'renderer', 'convert_to_neus')\n        self._name = name\n        self._device = torch.device(device)\n        self._preload = preload\n        self._app = app\n\n        assert 'DEX_YCB_DIR' in os.environ, \"environment variable 'DEX_YCB_DIR' is not set\"\n        self._dex_ycb_dir = os.environ['DEX_YCB_DIR']\n\n        # Load meta.\n        meta_file = self._dex_ycb_dir + '/' + self._name + \"/meta.yml\"\n        with open(meta_file, 'r') as f:\n            meta = yaml.load(f, Loader=yaml.FullLoader)\n\n        self._serials = meta['serials']\n        self._h = 480\n        self._w = 640\n        self._num_cameras = len(self._serials)\n        self._data_dir = [\n            self._dex_ycb_dir + '/' + self._name + '/' + s for s in self._serials\n        ]\n        self._color_prefix = \"color_\"\n        self._depth_prefix = \"aligned_depth_to_color_\"\n        self._label_prefix = \"labels_\"\n        self._num_frames = meta['num_frames']\n        self._ycb_ids = meta['ycb_ids']\n        self._mano_sides = meta['mano_sides']\n\n        # Load intrinsics.\n        def intr_to_K(x):\n            return torch.tensor(\n                [[x['fx'], 0.0, x['ppx']], [0.0, x['fy'], x['ppy']], [0.0, 0.0, 1.0]],\n                dtype=torch.float32,\n                device=self._device)\n\n        self._K = []\n        for s in self._serials:\n            intr_file = self._dex_ycb_dir + \"/calibration/intrinsics/\" + s + '_' + str(\n                self._w) + 'x' + str(self._h) + \".yml\"\n            with open(intr_file, 'r') as f:\n                intr = yaml.load(f, Loader=yaml.FullLoader)\n            K = intr_to_K(intr['color'])\n            self._K.append(K)\n        self._K_inv = [torch.inverse(k) for k in self._K]\n\n        # Load extrinsics.\n        extr_file = self._dex_ycb_dir + \"/calibration/extrinsics_\" + meta[\n            'extrinsics'] + \"/extrinsics.yml\"\n        with open(extr_file, 'r') as f:\n            extr = yaml.load(f, Loader=yaml.FullLoader)\n        T = extr['extrinsics']\n        T = {\n            s: torch.tensor(T[s], dtype=torch.float32,\n                            device=self._device).view(3, 4) for s in T\n        }\n        self._R = [T[s][:, :3] for s in self._serials]\n        self._t = [T[s][:, 3] for s in self._serials]\n        self._R_inv = [torch.inverse(r) for r in self._R]\n        self._t_inv = [torch.mv(r, -t) for r, t in zip(self._R_inv, self._t)]\n        self._master_intrinsics = self._K[[\n            i for i, s in enumerate(self._serials) if s == extr['master']\n        ][0]].cpu().numpy()\n        self._tag_R = T['apriltag'][:, :3]\n        self._tag_t = T['apriltag'][:, 3]\n        self._tag_R_inv = torch.inverse(self._tag_R)\n        self._tag_t_inv = torch.mv(self._tag_R_inv, -self._tag_t)\n        self._tag_lim = [-0.00, +1.20, -0.10, +0.70, -0.10, +0.70]\n\n        # Compute texture coordinates.\n        y, x = torch.meshgrid(torch.arange(self._h), torch.arange(self._w), indexing=\"ij\")\n        x = x.float()\n        y = y.float()\n        s = torch.stack((x / (self._w - 1), y / (self._h - 1)), dim=2)\n   
     self._pcd_tex_coord = [s.numpy()] * self._num_cameras\n\n        # Compute rays.\n        self._p = []\n        ones = torch.ones((self._h, self._w), dtype=torch.float32)\n        xy1s = torch.stack((x, y, ones), dim=2).view(self._w * self._h, 3).t()\n        xy1s = xy1s.to(self._device)\n        for c in range(self._num_cameras):\n            p = torch.mm(self._K_inv[c], xy1s)\n            self._p.append(p)\n\n        # Load point cloud.\n        if self._preload:\n            print('Preloading point cloud')\n            self._color = []\n            self._depth = []\n            for c in range(self._num_cameras):\n                color = []\n                depth = []\n                for i in range(self._num_frames):\n                    rgb, d = self._load_frame_rgbd(c, i)\n                    color.append(rgb)\n                    depth.append(d)\n                self._color.append(color)\n                self._depth.append(depth)\n            self._color = np.array(self._color, dtype=np.uint8)\n            self._depth = np.array(self._depth, dtype=np.uint16)\n            self._pcd_rgb = [x for x in self._color]\n            self._pcd_vert = []\n            self._pcd_mask = []\n            for c in range(self._num_cameras):\n                p, m = self._deproject_depth_and_filter_points(self._depth[c], c)\n                self._pcd_vert.append(p)\n                self._pcd_mask.append(m)\n        else:\n            print('Loading point cloud online')\n            self._pcd_rgb = [\n                np.zeros((self._h, self._w, 3), dtype=np.uint8)\n                for _ in range(self._num_cameras)\n            ]\n            self._pcd_vert = [\n                np.zeros((self._h, self._w, 3), dtype=np.float32)\n                for _ in range(self._num_cameras)\n            ]\n            self._pcd_mask = [\n                np.zeros((self._h, self._w), dtype=np.bool)\n                for _ in range(self._num_cameras)\n            ]\n\n        # Create YCB group layer.\n        self._ycb_group_layer = YCBGroupLayer(self._ycb_ids).to(self._device)\n\n        self._ycb_model_dir = self._dex_ycb_dir + \"/models\"\n        self._ycb_count = self._ycb_group_layer.count\n        self._ycb_material = self._ycb_group_layer.material\n        self._ycb_tex_coords = self._ycb_group_layer.tex_coords\n\n        # Create MANO group layer.\n        mano_betas = []\n        for m in meta['mano_calib']:\n            mano_calib_file = self._dex_ycb_dir + \"/calibration/mano_\" + m + \"/mano.yml\"\n            with open(mano_calib_file, 'r') as f:\n                mano_calib = yaml.load(f, Loader=yaml.FullLoader)\n            betas = np.array(mano_calib['betas'], dtype=np.float32)\n            mano_betas.append(betas)\n\n        self._mano_group_layer = MANOGroupLayer(self._mano_sides,\n                                                mano_betas).to(self._device)\n\n        # Prepare data for viewer.\n        if app == 'viewer':\n            s = np.cumsum([0] + self._ycb_group_layer.count[:-1])\n            e = np.cumsum(self._ycb_group_layer.count)\n            self._ycb_seg = list(zip(s, e))\n\n            ycb_file = self._dex_ycb_dir + '/' + self._name + \"/pose.npz\"\n            data = np.load(ycb_file)\n            ycb_pose = data['pose_y']\n            i = np.any(ycb_pose != [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], axis=2)\n            pose = ycb_pose.reshape(-1, 7)\n            v, n = self.transform_ycb(pose)\n            self._ycb_vert = [\n                np.zeros((self._num_frames, n, 3), 
dtype=np.float32)\n                for n in self._ycb_count\n            ]\n            self._ycb_norm = [\n                np.zeros((self._num_frames, n, 3), dtype=np.float32)\n                for n in self._ycb_count\n            ]\n            for o in range(self._ycb_group_layer.num_obj):\n                io = i[:, o]\n                self._ycb_vert[o][io] = v[io, self._ycb_seg[o][0]:self._ycb_seg[o][1]]\n                self._ycb_norm[o][io] = n[io, self._ycb_seg[o][0]:self._ycb_seg[o][1]]\n\n            mano_file = self._dex_ycb_dir + '/' + self._name + \"/pose.npz\"\n            data = np.load(mano_file)\n            mano_pose = data['pose_m']\n            i = np.any(mano_pose != 0.0, axis=2)\n            pose = torch.from_numpy(mano_pose).to(self._device)\n            pose = pose.view(-1, self._mano_group_layer.num_obj * 51)\n            verts, _ = self._mano_group_layer(pose)\n            # Numpy array is faster than PyTorch Tensor here.\n            verts = verts.cpu().numpy()\n            f = self._mano_group_layer.f.cpu().numpy()\n            v = verts[:, f.ravel()]\n            n = np.cross(v[:, 1::3, :] - v[:, 0::3, :], v[:, 2::3, :] - v[:, 1::3, :])\n            n = np.repeat(n, 3, axis=1)\n            l = verts[:, f[:, [0, 1, 1, 2, 2, 0]].ravel(), :]\n            self._mano_vert = [\n                np.zeros((self._num_frames, 4614, 3), dtype=np.float32)\n                for _ in range(self._mano_group_layer.num_obj)\n            ]\n            self._mano_norm = [\n                np.zeros((self._num_frames, 4614, 3), dtype=np.float32)\n                for _ in range(self._mano_group_layer.num_obj)\n            ]\n            self._mano_line = [\n                np.zeros((self._num_frames, 9228, 3), dtype=np.float32)\n                for _ in range(self._mano_group_layer.num_obj)\n            ]\n            for o in range(self._mano_group_layer.num_obj):\n                io = i[:, o]\n                self._mano_vert[o][io] = v[io, 4614 * o:4614 * (o + 1), :]\n                self._mano_norm[o][io] = n[io, 4614 * o:4614 * (o + 1), :]\n                self._mano_line[o][io] = l[io, 9228 * o:9228 * (o + 1), :]\n\n        # Prepare data for renderer.\n        if app == 'renderer':\n            self._ycb_pose = []\n            self._mano_vert = []\n            self._mano_joint_3d = []\n\n            for c in range(self._num_cameras):\n                ycb_pose = []\n                mano_pose = []\n                mano_joint_3d = []\n                for i in range(self._num_frames):\n                    label_file = self._data_dir[\n                                     c] + '/' + self._label_prefix + \"{:06d}.npz\".format(i)\n                    label = np.load(label_file)\n                    pose_y = np.hstack((label['pose_y'],\n                                        np.array([[[0, 0, 0, 1]]] * len(label['pose_y']),\n                                                 dtype=np.float32)))\n                    pose_m = label['pose_m']\n                    joint_3d = label['joint_3d']\n                    ycb_pose.append(pose_y)\n                    mano_pose.append(pose_m)\n                    mano_joint_3d.append(joint_3d)\n                ycb_pose = np.array(ycb_pose, dtype=np.float32)\n                mano_pose = np.array(mano_pose, dtype=np.float32)\n                mano_joint_3d = np.array(mano_joint_3d, dtype=np.float32)\n                self._ycb_pose.append(ycb_pose)\n                self._mano_joint_3d.append(mano_joint_3d)\n\n                i = np.any(mano_pose != 0.0, 
axis=2)\n                pose = torch.from_numpy(mano_pose).to(self._device)\n                pose = pose.view(-1, self._mano_group_layer.num_obj * 51)\n                verts, _ = self._mano_group_layer(pose)\n                verts = verts.cpu().numpy()\n                mano_vert = [\n                    np.zeros((self._num_frames, 778, 3), dtype=np.float32)\n                    for _ in range(self._mano_group_layer.num_obj)\n                ]\n                for o in range(self._mano_group_layer.num_obj):\n                    io = i[:, o]\n                    mano_vert[o][io] = verts[io, 778 * o:778 * (o + 1), :]\n                self._mano_vert.append(mano_vert)\n\n        # Convert to Neus format.\n        if app == \"convert_to_neus\":\n            output_dataset_path = kwargs.get(\"output_dataset_path\", \"output_dataset\")\n            downscaling_factor = kwargs.get(\"downscaling_factor\", 1)\n            n_points = kwargs.get(\"n_points\", 3_600)\n            n_subsample = kwargs.get(\"n_subsample\", 1)\n            seed = kwargs.get(\"seed\", 72)\n            stream_rerun_viz = kwargs.get(\"stream_rerun_viz\", False)\n            save_rerun_viz = kwargs.get(\"save_rerun_viz\", False)\n\n            np.random.seed(seed)\n            torch.manual_seed(seed)\n\n            # Save camera centers as a .ply pointcloud, for debugging purposes.\n            t_centered = torch.stack(self._t) - torch.tensor([0., 0., 1.])  # Move along z axis by -1\n            colors = cm.get_cmap('tab10')(np.linspace(0, 1, self._num_cameras))[:, :3]\n            pcd = o3d.geometry.PointCloud()\n            pcd.points = o3d.utility.Vector3dVector(t_centered.cpu().numpy())\n            pcd.colors = o3d.utility.Vector3dVector(colors)\n            pcd_file = os.path.join(output_dataset_path, f\"camera_center__{c:02d}_cameras.ply\")\n            o3d.io.write_point_cloud(pcd_file, pcd)\n\n            # Create the view folders.\n            for c in range(self._num_cameras):\n                view_folder = os.path.join(output_dataset_path, f\"view_{c:02d}\")\n                os.makedirs(view_folder, exist_ok=True)\n\n                # Save the intrinsics.txt file.\n                intrinsics_file = os.path.join(view_folder, \"intrinsics.txt\")\n                intrinsics = np.zeros((4, 4), dtype=np.float32)\n                intrinsics[:3, :3] = self._K[c].cpu().numpy()\n                intrinsics[3, 3] = 1\n                intrinsics_str = '\\n'.join([' '.join([str(x) for x in row]) for row in intrinsics])\n                with open(intrinsics_file, \"w\") as f:\n                    f.write(intrinsics_str)\n\n                # Save the cameras_sphere.npz file.\n                R = self._R\n                t = self._t\n                t_centered = [t_ - torch.tensor([0., 0., 1.]) for t_ in t]  # Move along z axis by -1\n                R_inv = [torch.inverse(r) for r in R]\n                t_centered_inv = [torch.mv(r, -t) for r, t in zip(R_inv, t_centered)]\n                extrinsics = np.zeros((4, 4), dtype=np.float32)\n                extrinsics[:3, :3] = R_inv[c].cpu().numpy()\n                extrinsics[:3, 3] = t_centered_inv[c].cpu().numpy()\n                extrinsics[3, 3] = 1\n                cameras_sphere_file = os.path.join(view_folder, \"cameras_sphere.npz\")\n                cameras_sphere = {\n                    **{f'world_mat_{output_frame_id}': intrinsics @ extrinsics\n                       for output_frame_id in range(math.ceil(self._num_frames / n_subsample))},\n                    
**{f'scale_mat_{output_frame_id}': np.diag(\n                        [downscaling_factor, downscaling_factor, downscaling_factor, 1.0])\n                        for output_frame_id in range(math.ceil(self._num_frames / n_subsample))}\n                }\n                np.savez_compressed(cameras_sphere_file, **cameras_sphere)\n\n                # Also, save the intrinsics and extrinsics directly into a .npz file.\n                camera_params_path = os.path.join(view_folder, \"intrinsics_extrinsics.npz\")\n                np.savez_compressed(camera_params_path, intrinsics=intrinsics, extrinsics=extrinsics)\n\n                # Save the rgb and depth images. And dummy masks.\n                rgb_folder = os.path.join(view_folder, \"rgb\")\n                depth_folder = os.path.join(view_folder, \"depth\")\n                mask_folder = os.path.join(view_folder, \"mask\")\n                rgb_with_valid_depth_folder = os.path.join(view_folder, \"rgb_with_valid_depth\")\n                os.makedirs(rgb_folder, exist_ok=True)\n                os.makedirs(depth_folder, exist_ok=True)\n                os.makedirs(mask_folder, exist_ok=True)\n                os.makedirs(rgb_with_valid_depth_folder, exist_ok=True)\n                for output_frame_id in range(math.ceil(self._num_frames / n_subsample)):\n                    input_frame_id = output_frame_id * n_subsample\n                    rgb = self._color[c][input_frame_id][:, :, ::-1]\n                    rgb_file = os.path.join(rgb_folder, f\"{output_frame_id:05d}.png\")\n                    cv2.imwrite(rgb_file, rgb)\n\n                    depth = self._depth[c][input_frame_id]\n                    depth_file = os.path.join(depth_folder, f\"{output_frame_id:05d}.png\")\n                    cv2.imwrite(depth_file, depth)\n\n                    rgb_plot = rgb.copy()\n                    rgb_plot[depth == 0] = 255\n                    cv2.imwrite(os.path.join(rgb_with_valid_depth_folder, f\"{output_frame_id:05d}.png\"), rgb_plot)\n\n                    label_file = self._data_dir[c] + '/' + self._label_prefix + \"{:06d}.npz\".format(input_frame_id)\n                    label = np.load(label_file)\n                    seg_mask = label[\"seg\"]\n                    mask = seg_mask != 0  # Everything that is not background\n                    mask = mask[:, :, None].astype(np.uint8).repeat(3, 2) * 255\n                    # dummy_mask = np.ones((self._h, self._w, 3)).astype(np.uint8) * 255\n                    # mask = dummy_mask\n                    mask_file = os.path.join(mask_folder, f\"{output_frame_id:05d}.png\")\n                    imageio.imwrite(mask_file, mask)\n\n                    # Backproject the depth image to 3D points for visualization purposes.\n                    if output_frame_id in [0, math.ceil(self._num_frames / n_subsample) - 1] and c in range(\n                            self._num_cameras):\n                        d = self._depth[c][input_frame_id]\n                        d = d.astype(np.float32) / 1000\n                        d = torch.from_numpy(d).to(self._device)\n\n                        p = torch.mul(\n                            d.view(1, -1, self._w * self._h).expand(3, -1, -1),\n                            self._p[c].unsqueeze(1))\n                        p = torch.addmm(self._t[c].unsqueeze(1), self._R[c], p.view(3, -1))\n                        p = p.t().view(self._h, self._w, 3)\n                        p = p.cpu().numpy()\n\n                        m = d > 0\n                        p = p[m]\n      
                  colors = self._color[c][input_frame_id][m] / 255\n\n                        pcd = o3d.geometry.PointCloud()\n                        pcd.points = o3d.utility.Vector3dVector(p)\n                        pcd.colors = o3d.utility.Vector3dVector(colors)\n                        pcd_file = os.path.join(view_folder, f\"pcd_for_t{output_frame_id:03d}.ply\")\n                        o3d.io.write_point_cloud(pcd_file, pcd)\n\n            # Compute meshes for each frame.\n            s = np.cumsum([0] + self._ycb_group_layer.count[:-1])\n            e = np.cumsum(self._ycb_group_layer.count)\n            self._ycb_seg = list(zip(s, e))\n\n            ycb_file = self._dex_ycb_dir + '/' + self._name + \"/pose.npz\"\n            data = np.load(ycb_file)\n            ycb_pose = data['pose_y'][::n_subsample]\n            i = np.any(ycb_pose != [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], axis=2)\n            pose = ycb_pose.reshape(-1, 7)\n            v, n = self.transform_ycb(pose)\n            self._ycb_vert = [\n                np.zeros((math.ceil(self._num_frames / n_subsample), n, 3), dtype=np.float32)\n                for n in self._ycb_count\n            ]\n            self._ycb_norm = [\n                np.zeros((math.ceil(self._num_frames / n_subsample), n, 3), dtype=np.float32)\n                for n in self._ycb_count\n            ]\n            for o in range(self._ycb_group_layer.num_obj):\n                io = i[:, o]\n                self._ycb_vert[o][io] = v[io, self._ycb_seg[o][0]:self._ycb_seg[o][1]]\n                self._ycb_norm[o][io] = n[io, self._ycb_seg[o][0]:self._ycb_seg[o][1]]\n            self._ycb_faces = [\n                np.arange(n).reshape(-1, 3)\n                for n in self._ycb_count\n            ]\n\n            mano_file = self._dex_ycb_dir + '/' + self._name + \"/pose.npz\"\n            data = np.load(mano_file)\n            mano_pose = data['pose_m'][::n_subsample]\n            i = np.any(mano_pose != 0.0, axis=2)\n            pose = torch.from_numpy(mano_pose).to(self._device)\n            pose = pose.view(-1, self._mano_group_layer.num_obj * 51)\n            verts, _ = self._mano_group_layer(pose)\n            # Numpy array is faster than PyTorch Tensor here.\n            verts = verts.cpu().numpy()\n            f = self._mano_group_layer.f.cpu().numpy()\n            v = verts[:, f.ravel()]\n            n = np.cross(v[:, 1::3, :] - v[:, 0::3, :], v[:, 2::3, :] - v[:, 1::3, :])\n            n = np.repeat(n, 3, axis=1)\n            l = verts[:, f[:, [0, 1, 1, 2, 2, 0]].ravel(), :]\n            self._mano_vert = [\n                np.zeros((math.ceil(self._num_frames / n_subsample), 4614, 3), dtype=np.float32)\n                for _ in range(self._mano_group_layer.num_obj)\n            ]\n            self._mano_norm = [\n                np.zeros((math.ceil(self._num_frames / n_subsample), 4614, 3), dtype=np.float32)\n                for _ in range(self._mano_group_layer.num_obj)\n            ]\n            self._mano_line = [\n                np.zeros((math.ceil(self._num_frames / n_subsample), 9228, 3), dtype=np.float32)\n                for _ in range(self._mano_group_layer.num_obj)\n            ]\n            self._mano_faces = [\n                np.arange(4614).reshape(-1, 3)\n                for _ in range(self._mano_group_layer.num_obj)\n            ]\n            for o in range(self._mano_group_layer.num_obj):\n                io = i[:, o]\n                self._mano_vert[o][io] = v[io, 4614 * o:4614 * (o + 1), :]\n                
self._mano_norm[o][io] = n[io, 4614 * o:4614 * (o + 1), :]\n                self._mano_line[o][io] = l[io, 9228 * o:9228 * (o + 1), :]\n\n            vert = []\n            vert += self._ycb_vert\n            vert += self._mano_vert\n\n            norm = []\n            norm += self._ycb_norm\n            norm += self._mano_norm\n\n            ids = []\n            ids += self._ycb_group_layer._ids\n            ids += [255 for _ in self._mano_group_layer._sides]\n\n            names = []\n            names += [\"ycb-\" + layer._class_name for layer in self._ycb_group_layer._layers]\n            names += [f\"mano-{side}-hand\" for side in self._mano_group_layer._sides]\n\n            faces = []\n            faces += self._ycb_faces\n            faces += self._mano_faces\n\n            print(f\"Number of meshes: {len(vert)}\")\n            assert len(vert) == len(norm) == len(faces) == len(ids) == len(names)\n\n            print(f\"Mesh names: {names}\")\n            print(f\"Mesh IDS: {ids}\")\n\n            all_vertices = np.concatenate(vert, axis=1)\n            all_normals = np.concatenate(norm, axis=1)\n            all_faces = np.concatenate([\n                f + np.sum([v.shape[1] for v in vert[:i]]).astype(np.uint32)\n                for i, f in enumerate(faces)\n            ])\n            all_ids = np.concatenate([np.full(v.shape[1], i) for i, v in enumerate(vert)])\n            assert all_vertices.shape[0] == all_normals.shape[0]\n            assert all_vertices.shape[1] == all_normals.shape[1] == all_faces.shape[0] * 3 == all_ids.shape[0]\n            assert all_faces.max() + 1 == all_vertices.shape[1]\n            print(f\"all_vertices.shape: {all_vertices.shape}\")\n            print(f\"all_normals.shape: {all_normals.shape}\")\n            print(f\"all_faces.shape: {all_faces.shape}\")\n            print(f\"all_ids.shape: {all_ids.shape}\")\n\n            n_frames = all_vertices.shape[0]\n            meshes = [\n                trimesh.Trimesh(\n                    vertices=all_vertices[frame_idx],\n                    faces=all_faces,\n                    vertex_normals=all_normals[frame_idx],\n                    process=False,\n                )\n                for frame_idx in range(n_frames)\n            ]\n\n            # Put the query points onto the frame where the hand is first visible\n            hands_visible = np.any(mano_pose != 0.0, axis=2).all(axis=1)\n            assert np.any(hands_visible), \"Hands must be visible in at least one frame\"\n            t0 = np.argmax(hands_visible, axis=0)\n\n            objects_visible = np.any(ycb_pose != [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], axis=2).all(axis=1)\n            assert objects_visible[t0], \"Objects must be visible in the first frame where the hands are visible\"\n\n            picked_faces, picked_weights, picked_points = sample_surface(meshes[t0], n_points, seed=seed)\n            assert np.allclose(picked_points,\n                               pick_points_from_mesh(meshes[t0], picked_faces, picked_weights, meshes[t0]))\n            picked_vertices = meshes[t0].faces[:, 0][picked_faces]\n            picked_ids = all_ids[picked_vertices]\n\n            # Track the points\n            tracks_3d = []\n            for frame_idx in range(n_frames):\n                points = pick_points_from_mesh(meshes[frame_idx], picked_faces, picked_weights, meshes[t0])\n                tracks_3d.append(points)\n                if frame_idx == t0:\n                    assert np.allclose(points, picked_points)\n            
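# Note: the same (face, barycentric-weight) pairs sampled on the mesh at t0 are re-evaluated\n            # on every frame's mesh above, so the 3D tracks follow the deforming surface by construction\n            # (the mesh topology is shared across all frames).\n            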
tracks_3d = np.stack(tracks_3d)  # (n_frames, n_points, 3)\n\n            # Project the points to the camera\n            tracks_2d = []\n            tracks_2d_z = []\n            for c in range(self._num_cameras):\n                p = torch.from_numpy(tracks_3d).to(self._device).T.reshape(3, -1)\n                p = self._R_inv[c].double() @ p + self._t_inv[c][:, None]\n                p = self._K[c].double() @ p\n                z = p[2]\n                p = p[:2] / z\n\n                p = p.cpu().numpy().reshape(2, n_points, math.ceil(self._num_frames / n_subsample)).T\n                z = z.cpu().numpy().reshape(n_points, math.ceil(self._num_frames / n_subsample)).T\n\n                tracks_2d.append(p)\n                tracks_2d_z.append(z)\n\n            tracks_2d = np.stack(tracks_2d)\n            tracks_2d_z = np.stack(tracks_2d_z)\n\n            # --- Estimate occlusion\n            rendered_depth = []\n            for c in range(self._num_cameras):\n                rendered_depth_camera = []\n                for frame_idx in range(n_frames):\n                    rgb = self._color[c][0]\n                    depth = (self._depth[c][0] / 1000).clip(0, 2)\n                    h, w = self._h, self._w\n                    K = self._K[c].cpu().numpy()\n                    w2c = np.eye(4, dtype=float)\n                    w2c[:3, :3] = self._R_inv[c].cpu().numpy()\n                    w2c[:3, 3] = self._t_inv[c].cpu().numpy()\n                    c2w = np.linalg.inv(w2c)\n\n                    # Render depth\n                    device = \"cuda\"\n                    vertices = torch.tensor(all_vertices[frame_idx], dtype=torch.float32).to(device)\n                    faces = torch.tensor(all_faces, dtype=torch.int64).to(device)\n                    vertex_colors = torch.ones_like(vertices).unsqueeze(0).to(device)\n                    textures = TexturesVertex(verts_features=vertex_colors)\n                    mesh = Meshes(verts=[vertices], faces=[faces], textures=textures)\n                    intrinsics = torch.eye(4, dtype=torch.float32).to(device)\n                    intrinsics[:3, :3] = torch.from_numpy(K)\n                    cameras = cameras_from_opencv_projection(\n                        R=torch.from_numpy(w2c[:3, :3]).to(device)[None].float(),\n                        tvec=torch.from_numpy(w2c[:3, 3]).to(device)[None].float(),\n                        camera_matrix=self._K[c].to(device)[None].float(),\n                        image_size=torch.tensor([self._h, self._w], dtype=torch.int32).to(device)[None].float(),\n                    )\n                    raster_settings = RasterizationSettings(\n                        image_size=(self._h, self._w),\n                        blur_radius=0.0,\n                        faces_per_pixel=1,\n                        bin_size=0,\n                    )\n                    renderer = MeshRendererWithFragments(\n                        rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),\n                        shader=SoftPhongShader(device=device, cameras=cameras, lights=PointLights(device=device)),\n                    )\n                    images, fragments = renderer(mesh)\n                    depth_map = fragments.zbuf\n                    rendered_depth_camera.append(depth_map.cpu().numpy()[0, :, :, 0])\n                rendered_depth.append(rendered_depth_camera)\n            rendered_depth = np.stack(rendered_depth)\n            assert rendered_depth.shape == (self._num_cameras, n_frames, self._h, 
self._w)\n\n            seg_masks = []\n            for c in range(self._num_cameras):\n                seg_masks_camera = []\n                for frame_idx in range(n_frames):\n                    input_frame_id = frame_idx * n_subsample\n                    label_file = self._data_dir[c] + '/' + self._label_prefix + \"{:06d}.npz\".format(input_frame_id)\n                    label = np.load(label_file)\n                    seg_masks_camera.append(label[\"seg\"])\n                seg_masks.append(seg_masks_camera)\n            seg_masks = np.stack(seg_masks)\n            assert seg_masks.shape == (self._num_cameras, n_frames, self._h, self._w)\n\n            seg_unique = np.unique(seg_masks)\n            cmap = get_cmap(\"tab10\")\n            seg_masks_rgb = np.zeros((*seg_masks.shape, 3), dtype=np.uint8)\n            for idx, val in enumerate(seg_unique):\n                seg_masks_rgb[seg_masks == val] = (np.array(cmap(idx / len(seg_unique))[:3]) * 255).astype(np.uint8)\n            assert seg_masks_rgb.shape == (self._num_cameras, n_frames, self._h, self._w, 3)\n\n            def estimate_occlusion_by_depth_and_segment(\n                    depth_map,\n                    x,\n                    y,\n                    num_frames,\n                    thresh,\n                    seg_id=None,\n                    segments=None,\n                    min_or_max_reduce=\"max\",\n                    convert_to_pixel_coords=True,\n                    occlude_if_depth_larger_than_xxx=None,\n            ):\n                # need to convert from raster to pixel coordinates\n                if convert_to_pixel_coords:\n                    x = x - 0.5\n                    y = y - 0.5\n\n                x0 = np.floor(x).astype(np.int32)\n                x1 = x0 + 1\n                y0 = np.floor(y).astype(np.int32)\n                y1 = y0 + 1\n\n                shp = depth_map.shape\n                assert len(depth_map.shape) == 3\n                x0 = np.clip(x0, 0, shp[2] - 1)\n                x1 = np.clip(x1, 0, shp[2] - 1)\n                y0 = np.clip(y0, 0, shp[1] - 1)\n                y1 = np.clip(y1, 0, shp[1] - 1)\n\n                depth_map = depth_map.reshape(-1)\n                rng = np.arange(num_frames)[:, np.newaxis]\n                assert x.shape[0] == y.shape[0] == num_frames\n                i1 = np.take(depth_map, rng * shp[1] * shp[2] + y0 * shp[2] + x0)\n                i2 = np.take(depth_map, rng * shp[1] * shp[2] + y1 * shp[2] + x0)\n                i3 = np.take(depth_map, rng * shp[1] * shp[2] + y0 * shp[2] + x1)\n                i4 = np.take(depth_map, rng * shp[1] * shp[2] + y1 * shp[2] + x1)\n\n                if min_or_max_reduce == \"max\":\n                    depth = np.maximum(np.maximum(np.maximum(i1, i2), i3), i4)\n                elif min_or_max_reduce == \"min\":\n                    depth = np.minimum(np.minimum(np.minimum(i1, i2), i3), i4)\n                else:\n                    raise ValueError(f\"Unknown min_or_max_reduce: {min_or_max_reduce}\")\n                if occlude_if_depth_larger_than_xxx is not None:\n                    depth[depth >= occlude_if_depth_larger_than_xxx] = 0\n                depth_occluded = depth < thresh\n                print(\"┌ Depth occlusion: \", depth_occluded.sum(), \"/\", depth_occluded.size)\n\n                occluded = depth_occluded\n                if segments is not None:\n                    segments = segments.reshape(-1)\n                    i1 = np.take(segments, rng * shp[1] * shp[2] + y0 * shp[2] + 
x0)\n                    i2 = np.take(segments, rng * shp[1] * shp[2] + y1 * shp[2] + x0)\n                    i3 = np.take(segments, rng * shp[1] * shp[2] + y0 * shp[2] + x1)\n                    i4 = np.take(segments, rng * shp[1] * shp[2] + y1 * shp[2] + x1)\n                    seg_occluded = np.ones_like(depth_occluded, dtype=bool)\n                    for i in [i1, i2, i3, i4]:\n                        i = i.astype(np.int32)\n                        seg_occluded = np.logical_and(seg_occluded, seg_id != i)\n                    print(\"| Segmentation occlusion: \", seg_occluded.sum(), \"/\", seg_occluded.size)\n                    occluded = np.logical_or(occluded, seg_occluded)\n\n                return occluded\n\n            tracks_2d_visibilities = []\n            for c in range(self._num_cameras):\n                occlusion = np.zeros((tracks_2d[c].shape[0], tracks_2d[c].shape[1]), dtype=bool)\n                print(f\"N occluded: {occlusion.sum()} / {occlusion.size}\")\n                occlusion = np.logical_or(occlusion, (tracks_2d_z[c] <= 0) | (tracks_2d_z[c] >= (65535 / 1000)))\n                print(f\"N occluded (after Z): {occlusion.sum()} / {occlusion.size}\")\n                occlusion = np.logical_or(occlusion, tracks_2d[c][:, :, 0] <= 0)\n                occlusion = np.logical_or(occlusion, tracks_2d[c][:, :, 1] <= 0)\n                occlusion = np.logical_or(occlusion, tracks_2d[c][:, :, 0] >= self._w - 1)\n                occlusion = np.logical_or(occlusion, tracks_2d[c][:, :, 1] >= self._h - 1)\n                print(f\"N occluded (& out-of-frame): {occlusion.sum()} / {occlusion.size}\")\n\n                # # V1: Use the depth map to estimate occlusion\n                # depth_map_for_occlusion = self._depth[c][::n_subsample].copy()\n                # depth_map_for_occlusion[depth_map_for_occlusion == 0] = 65535\n                # depth_map_for_occlusion = depth_map_for_occlusion / 1000.0\n\n                # # V2: Make the depth for occlussion be the depth from projected predicted points, taking the minimum z over all points at a pixel\n                # depth_map_for_occlusion = np.ones((tracks_2d_z.shape[1], self._h, self._w),\n                #                                   dtype=np.float32) * 65535 / 1000\n                # for frame_idx in range(math.ceil(self._num_frames / n_subsample)):\n                #     for point_idx in range(n_points):\n                #         if np.isnan(tracks_2d[c][frame_idx, point_idx]).any():\n                #             continue\n                #         x = int(tracks_2d[c][frame_idx, point_idx, 0])\n                #         y = int(tracks_2d[c][frame_idx, point_idx, 1])\n                #         z = tracks_2d_z[c][frame_idx, point_idx]\n                #         if 0 <= x < self._w and 0 <= y < self._h:\n                #             depth_map_for_occlusion[frame_idx, y - 3:y + 3, x - 3:x + 3] = np.minimum(\n                #                 depth_map_for_occlusion[frame_idx, y - 3:y + 3, x - 3:x + 3],\n                #                 z,\n                #             )\n                # # Visualize it side by side with GT depth\n                # if False:\n                #     for frame_idx in range(math.ceil(self._num_frames / n_subsample)):\n                #         if frame_idx not in [0, math.ceil(self._num_frames / n_subsample) - 1]:\n                #             continue\n                #         d1 = self._depth[c][frame_idx * n_subsample] / 1000\n                #         d2 = 
depth_map_for_occlusion[frame_idx]\n                #         d12 = np.concatenate([d1, d2], axis=1)\n                #         plt.figure(dpi=150, figsize=(d12.shape[1] / 100, d12.shape[0] / 100))\n                #         plt.title(f\"Depth GT (left) vs Depth used for occlusion (right), frame {frame_idx}\")\n                #         plt.imshow(d12.clip(0.5, 1))\n                #         plt.axis('off')\n                #         plt.tight_layout(pad=0)\n                #         plt.savefig(os.path.join(output_dataset_path,\n                #                                  f\"depth_used_for_occlussion_view_{c:02d}_frame_{frame_idx:05d}.png\"))\n                #         # plt.show()\n                #\n                # seg_mask = []\n                # for output_frame_id in range(math.ceil(self._num_frames / n_subsample)):\n                #     input_frame_id = output_frame_id * n_subsample\n                #     label_file = self._data_dir[c] + '/' + self._label_prefix + \"{:06d}.npz\".format(input_frame_id)\n                #     label = np.load(label_file)\n                #     seg_mask.append(label[\"seg\"])\n                # seg_mask = np.stack(seg_mask)\n                # depth_or_segment_occluded = estimate_occlusion_by_depth_and_segment(\n                #     depth_map=depth_map_for_occlusion,\n                #     segments=seg_mask,\n                #     x=tracks_2d[c][:, :, 0],\n                #     y=tracks_2d[c][:, :, 1],\n                #     num_frames=tracks_2d[c].shape[0],\n                #     thresh=tracks_2d_z[c] * 0.995,\n                #     seg_id=np.array(ids)[picked_ids],\n                # )\n                # occlusion = np.logical_or(occlusion, depth_or_segment_occluded)\n                # print(f\"N occluded (& obscured by other objects): {occlusion.sum()} / {occlusion.size}\")\n                # print()\n                # tracks_2d_visibilities.append(~occlusion)\n\n                # # V3.a: Neither the GT depth nor the segmentation mask are reliable for occlusion estimation.\n                # #       Instead, we will use the rendered depth map, with a little help from the GT depth.\n                # #       First, the rendered depth needs to match the point depth, if not, the point is occluded.\n                # #       This will work perfectly for all the objects that have a full 3D mesh over time. So all\n                # #       the objects on the table, plus the MANO hand (but not the arm). 
This is susceptible to\n                # #       errors in estimating the mesh location, but it should be less problematic than the other\n                # #       segmentation mask and GT depth in that it will be less noisy and more consistent over time.\n                # rendered_depth_for_occlusion = rendered_depth[c].copy()\n                # rendered_depth_for_occlusion[rendered_depth_for_occlusion <= 0] = 65535 / 1000\n                # depth_or_segment_occluded = estimate_occlusion_by_depth_and_segment(\n                #     depth_map=rendered_depth_for_occlusion,\n                #     x=tracks_2d[c, :, :, 0],\n                #     y=tracks_2d[c, :, :, 1],\n                #     num_frames=n_frames,\n                #     thresh=tracks_2d_z[c, :, :] - 0.01,\n                #     min_or_max_reduce=\"min\",\n                #     convert_to_pixel_coords=False,\n                #     occlude_if_depth_larger_than_xxx=65535 / 1000,\n                # )\n                # occlusion = np.logical_or(occlusion, depth_or_segment_occluded)\n                # print(f\"N occluded (& obscured in rendered depth): {occlusion.sum()} / {occlusion.size}\")\n                # print()\n                #\n                # # # V3.b: Second, to avoid occlusion by the arm, we will use the GT depth map but with a high threshold.\n                # # #       This will avoid the arm occluding the points as the arm is not in the rendered depth map.\n                # # depth_or_segment_occluded = estimate_occlusion_by_depth_and_segment(\n                # #     depth_map=self._depth[c][::n_subsample] / 1000,\n                # #     x=tracks_2d[c, :, :, 0],\n                # #     y=tracks_2d[c, :, :, 1],\n                # #     num_frames=n_frames,\n                # #     thresh=tracks_2d_z[c, :, :] * 0.995,\n                # # )\n                # # occlusion = np.logical_or(occlusion, depth_or_segment_occluded)\n                # # print(f\"N occluded (& obscured in GT depth): {occlusion.sum()} / {occlusion.size}\")\n                # # print()\n\n                # V4.a: Forget the rendered depths, it's still difficult because the depth map is pixelized.\n                #       Instead, let's shoot rays from the camera onto the scene mesh and see where they intersect.\n                #       It is very very slow but most accurate.\n                camera_center = self._t[c].cpu().numpy()\n                for frame_idx in range(n_frames):\n                    for track_idx in tqdm(range(n_points), desc=f\"Ray casting for camera {c} frame {frame_idx}\"):\n                        if occlusion[frame_idx, track_idx]:\n                            continue\n                        ray_direction = tracks_3d[frame_idx, track_idx] - camera_center\n                        ray_direction /= np.linalg.norm(ray_direction)\n                        intersections = meshes[frame_idx].ray.intersects_location(camera_center[None],\n                                                                                  ray_direction[None])\n                        if len(intersections[0]) == 0:\n                            occlusion[frame_idx, track_idx] = True\n                            continue\n                        intersection_depth = np.inf\n                        for intersection in intersections[0]:\n                            intersection_depth = min(intersection_depth, np.linalg.norm(intersection - camera_center))\n                        track_depth = np.linalg.norm(tracks_3d[frame_idx, track_idx] - 
camera_center)\n                        occlusion[frame_idx, track_idx] = not np.isclose(intersection_depth, track_depth, atol=0.001)\n                print(f\"N occluded (& obscured in scene mesh): {occlusion.sum()} / {occlusion.size}\")\n                print()\n\n                # V4.b: The arm is not in the scene mesh and it is causing problems for 1/2 cameras. Let's use the\n                #       GT depths to figure out if the arm is occluding the points. Unfortunately, this will also not\n                #       work perfectly because the GT depths are missing around silhouette edges.\n                depth_map_for_occlusion = self._depth[c][::n_subsample].copy()\n                depth_map_for_occlusion[depth_map_for_occlusion <= 0] = 65535\n                depth_map_for_occlusion = depth_map_for_occlusion / 1000.0\n                depth_or_segment_occluded = estimate_occlusion_by_depth_and_segment(\n                    depth_map=depth_map_for_occlusion,\n                    x=tracks_2d[c, :, :, 0],\n                    y=tracks_2d[c, :, :, 1],\n                    num_frames=n_frames,\n                    thresh=tracks_2d_z[c, :, :] - 0.12,\n                    min_or_max_reduce=\"min\",\n                    convert_to_pixel_coords=False,\n                )\n                occlusion = np.logical_or(occlusion, depth_or_segment_occluded)\n                print(f\"N occluded (& obscured in GT depth): {occlusion.sum()} / {occlusion.size}\")\n                print()\n\n                # Idea for V5: Do V4 and additionally try looking at if the RGB changed a lot for the point.\n                #              If it did, then it is likely that the point is occluded by the arm/person.\n                #              However, this might suffer from the same problem as in V3: edges would be noisy\n                #              and might quickly jump from visible to occluded to visible again.\n                ...\n\n                tracks_2d_visibilities.append(~occlusion)\n\n            tracks_2d_visibilities = np.stack(tracks_2d_visibilities)\n            tracks_3d_visibilities = tracks_2d_visibilities.any(axis=0)\n\n            if stream_rerun_viz or save_rerun_viz:\n                assert not (stream_rerun_viz and save_rerun_viz), (\"Stream and save rerun at the same time not \"\n                                                                   \"supported. But you can save what was streamed \"\n                                                                   \"within the rerun viewer. Or run again. 
Or impl it.\")\n\n                rr.init(\"dexycb_preprocessing\", recording_id=\"v0.1\")\n                if stream_rerun_viz:\n                    rr.connect_tcp()\n                rr.log(\"world\", rr.ViewCoordinates.RIGHT_HAND_Z_UP, static=True)\n                rr.set_time_seconds(\"frame\", 0)\n                rr.log(\n                    \"world/xyz\",\n                    rr.Arrows3D(\n                        vectors=[[1, 0, 0], [0, 2, 0], [0, 0, 3]],\n                        colors=[[255, 0, 0], [0, 255, 0], [0, 0, 255]],\n                    ),\n                )\n                entity_prefix = f\"{os.path.basename(output_dataset_path)}/\"\n                radii_scale = 0.1\n                for t in range(n_frames):\n                    t_input = t * n_subsample\n                    rr.set_time_seconds(\"frame\", t / 12)\n                    rr.log(f\"{entity_prefix}mesh\", rr.Mesh3D(\n                        vertex_positions=np.asarray(meshes[t].vertices),\n                        triangle_indices=np.asarray(meshes[t].faces),\n                    ))\n                    for c in range(self._num_cameras):\n                        rgb = self._color[c, t_input]\n                        depth = (self._depth[c, t_input] / 1000).clip(0, 2)\n                        rend_depth = rendered_depth[c, t].clip(0, 2)\n                        seg_mask = seg_masks_rgb[c, t]\n                        seg_rgb = seg_masks_rgb[c, t]\n                        h, w = self._h, self._w\n                        K = self._K[c].cpu().numpy()\n                        K_inv = np.linalg.inv(K)\n                        w2c = np.eye(4, dtype=float)\n                        w2c[:3, :3] = self._R_inv[c].cpu().numpy()\n                        w2c[:3, 3] = self._t_inv[c].cpu().numpy()\n                        c2w = np.linalg.inv(w2c)\n\n                        cam_pinhole = rr.Pinhole(image_from_camera=K, width=w, height=h)\n                        cam_transform = rr.Transform3D(translation=c2w[:3, 3], mat3x3=c2w[:3, :3])\n                        for name, archetype in [\n                            (\"rgb\", rr.Image(rgb)),\n                            (\"seg\", rr.Image(seg_rgb)),\n                            (\"depth-gt\", rr.DepthImage(depth, point_fill_ratio=0.2)),\n                            (\"depth-rendered\", rr.DepthImage(rend_depth, point_fill_ratio=0.2)),\n                        ]:\n                            rr.log(f\"{entity_prefix}/image/{name}/view-{c:02d}\", cam_pinhole)\n                            rr.log(f\"{entity_prefix}/image/{name}/view-{c:02d}\", cam_transform)\n                            rr.log(f\"{entity_prefix}/image/{name}/view-{c:02d}/{name}\", archetype)\n\n                        # Compute 3D points from GT depth map\n                        y, x = np.indices((self._h, self._w))\n                        homo_pixel_coords = np.stack([x.ravel(), y.ravel(), np.ones_like(x).ravel()], axis=1).T\n                        cam_coords = (K_inv @ homo_pixel_coords) * depth.ravel()\n                        cam_coords = np.vstack((cam_coords, np.ones((1, cam_coords.shape[1]))))\n                        world_coords = (c2w @ cam_coords)[:3].T\n                        valid_mask = depth.ravel() > 0\n                        rr.log(f\"{entity_prefix}point_cloud/rgb-gt/view-{c}\", rr.Points3D(\n                            positions=world_coords[valid_mask],\n                            colors=rgb.reshape(-1, 3)[valid_mask].astype(np.uint8),\n                            radii=0.01 * 
radii_scale,\n                        ))\n                        rr.log(f\"{entity_prefix}point_cloud/seg-gt/view-{c}\", rr.Points3D(\n                            positions=world_coords[valid_mask],\n                            colors=seg_rgb.reshape(-1, 3)[valid_mask].astype(np.uint8),\n                            radii=0.01 * radii_scale,\n                        ))\n\n                        # Compute 3D points from GT depth map\n                        y, x = np.indices((self._h, self._w))\n                        homo_pixel_coords = np.stack([x.ravel(), y.ravel(), np.ones_like(x).ravel()], axis=1).T\n                        cam_coords = (K_inv @ homo_pixel_coords) * rend_depth.ravel()\n                        cam_coords = np.vstack((cam_coords, np.ones((1, cam_coords.shape[1]))))\n                        world_coords = (c2w @ cam_coords)[:3].T\n                        valid_mask = rend_depth.ravel() > 0\n                        rr.log(f\"{entity_prefix}point_cloud/rgb-rend/view-{c}\", rr.Points3D(\n                            positions=world_coords[valid_mask],\n                            colors=rgb.reshape(-1, 3)[valid_mask].astype(np.uint8),\n                            radii=0.01 * radii_scale,\n                        ))\n                        rr.log(f\"{entity_prefix}point_cloud/seg-rend/view-{c}\", rr.Points3D(\n                            positions=world_coords[valid_mask],\n                            colors=seg_rgb.reshape(-1, 3)[valid_mask].astype(np.uint8),\n                            radii=0.01 * radii_scale,\n                        ))\n\n                def log_tracks(\n                        tracks: np.ndarray,\n                        visibles: np.ndarray,\n                        query_timestep: np.ndarray,\n                        colors: np.ndarray,\n\n                        entity_format_str=\"{}\",\n\n                        log_points=True,\n                        points_radii=0.03 * radii_scale,\n                        invisible_color=[0., 0., 0.],\n\n                        log_line_strips=True,\n                        max_strip_length_past=12,\n                        max_strip_length_future=1,\n                        strips_radii=0.0042 * radii_scale,\n\n                        log_error_lines=False,\n                        error_lines_radii=0.0072 * radii_scale,\n                        error_lines_color=[1., 0., 0.],\n                        gt_for_error_lines=None,\n                ) -> None:\n                    \"\"\"\n                    Log tracks to Rerun.\n\n                    Parameters:\n                        tracks: Shape (T, N, 3), the 3D trajectories of points.\n                        visibles: Shape (T, N), boolean visibility mask for each point at each timestep.\n                        query_timestep: Shape (T, N), the frame index after which the tracks start.\n                        colors: Shape (N, 4), RGBA colors for each point.\n                        entity_prefix: String prefix for entity hierarchy in Rerun.\n                        entity_suffix: String suffix for entity hierarchy in Rerun.\n                    \"\"\"\n\n                    T, N, _ = tracks.shape\n                    assert tracks.shape == (T, N, 3)\n                    assert visibles.shape == (T, N)\n                    assert query_timestep.shape == (N,)\n                    assert query_timestep.min() >= 0\n                    assert query_timestep.max() < T\n                    assert colors.shape == (N, 4)\n\n                    for n in 
range(N):\n                        rr.log(entity_format_str.format(f\"track-{n}\"), rr.Clear(recursive=True))\n                        for t in range(query_timestep[n], T):\n                            rr.set_time_seconds(\"frame\", t / 12)\n\n                            # Log the point (special handling for invisible points)\n                            if log_points:\n                                rr.log(\n                                    entity_format_str.format(f\"track-{n}/point\"),\n                                    rr.Points3D(\n                                        positions=[tracks[t, n]],\n                                        colors=[colors[n, :3]] if visibles[t, n] else [invisible_color],\n                                        radii=points_radii,\n                                    ),\n                                )\n\n                            # Log line segments for visible tracks\n                            if log_line_strips and t > query_timestep[n]:\n                                strip_t_start = max(t - max_strip_length_past, query_timestep[n].item())\n                                strip_t_end = min(t + max_strip_length_future, T - 1)\n\n                                strips = np.stack([\n                                    tracks[strip_t_start:strip_t_end, n],\n                                    tracks[strip_t_start + 1:strip_t_end + 1, n],\n                                ], axis=-2)\n                                strips_visibility = visibles[strip_t_start + 1:strip_t_end + 1, n]\n                                strips_colors = np.where(\n                                    strips_visibility[:, None],\n                                    colors[None, n, :3],\n                                    [invisible_color],\n                                )\n\n                                rr.log(\n                                    entity_format_str.format(f\"track-{n}/line\"),\n                                    rr.LineStrips3D(strips=strips, colors=strips_colors, radii=strips_radii),\n                                )\n\n                            if log_error_lines:\n                                assert gt_for_error_lines is not None\n                                strips = np.stack([\n                                    tracks[t, n],\n                                    gt_for_error_lines[t, n],\n                                ], axis=-2)\n                                rr.log(\n                                    entity_format_str.format(f\"track-{n}/error\"),\n                                    rr.LineStrips3D(strips=strips, colors=error_lines_color, radii=error_lines_radii),\n                                )\n\n                # Log the tracks\n                cmap = matplotlib.colormaps[\"gist_rainbow\"]\n                norm = matplotlib.colors.Normalize(vmin=tracks_3d[..., 0].min(), vmax=tracks_3d[..., 0].max())\n                track_color = cmap(norm(tracks_3d[-1, :, 0]))\n                track_color = track_color * 0 + 1  # Just make all tracks white\n\n                N = 800\n                B = 200\n                for tracks_batch_start in range(0, N, B):\n                    tracks_batch_end = min(tracks_batch_start + B, N)\n                    for name, visibles in [\n                        (\"tracks/c01234567-visibility\",\n                         tracks_2d_visibilities.any(0)[:, tracks_batch_start:tracks_batch_end]),\n                        (\"tracks/c0123-visibility\",\n                         
tracks_2d_visibilities.any(0)[:, tracks_batch_start:tracks_batch_end]),\n                        (\"tracks/c2345-visibility\",\n                         tracks_2d_visibilities.any(0)[:, tracks_batch_start:tracks_batch_end]),\n                        (\"tracks/c0-visibility\", tracks_2d_visibilities[0, :, tracks_batch_start:tracks_batch_end]),\n                        (\"tracks/c1-visibility\", tracks_2d_visibilities[1, :, tracks_batch_start:tracks_batch_end]),\n                        (\"tracks/c2-visibility\", tracks_2d_visibilities[2, :, tracks_batch_start:tracks_batch_end]),\n                        (\"tracks/c3-visibility\", tracks_2d_visibilities[3, :, tracks_batch_start:tracks_batch_end]),\n                        (\"tracks/c4-visibility\", tracks_2d_visibilities[4, :, tracks_batch_start:tracks_batch_end]),\n                        (\"tracks/c5-visibility\", tracks_2d_visibilities[5, :, tracks_batch_start:tracks_batch_end]),\n                        (\"tracks/c6-visibility\", tracks_2d_visibilities[6, :, tracks_batch_start:tracks_batch_end]),\n                        (\"tracks/c7-visibility\", tracks_2d_visibilities[7, :, tracks_batch_start:tracks_batch_end]),\n                    ]:\n                        log_tracks(\n                            tracks=tracks_3d[:, tracks_batch_start:tracks_batch_end],\n                            visibles=visibles,\n                            query_timestep=visibles.argmax(axis=0),\n                            colors=track_color[tracks_batch_start:tracks_batch_end],\n                            entity_format_str=f\"{entity_prefix}/{name}/{tracks_batch_start:02d}-{tracks_batch_end:02d}/{{}}\",\n                            max_strip_length_future=0,\n                        )\n\n                if save_rerun_viz:\n                    rr_rrd_path = os.path.join(output_dataset_path, f\"rerun_viz.rrd\")\n                    rr.save(rr_rrd_path)\n                    print(f\"Saved Rerun recording to: {os.path.abspath(rr_rrd_path)}\")\n\n            # import pydevd_pycharm\n            # pydevd_pycharm.settrace('localhost', port=51234, stdoutToServer=True, stderrToServer=True)\n\n            # Save the tracks\n            tracks_3d_file = os.path.join(output_dataset_path, \"tracks_3d.npz\")\n            np.savez(\n                tracks_3d_file,\n                tracks_3d=(tracks_3d - np.array([0., 0., 1.])) / DOWNSCALING_FACTOR,\n                tracks_3d_visibilities=tracks_3d_visibilities,\n                object_ids=np.array(ids)[picked_ids],\n                object_id_to_name={i: name for i, name in zip(ids, names)},\n                tracks_2d=tracks_2d,\n                tracks_2d_z=tracks_2d_z,\n                tracks_2d_visibilities=tracks_2d_visibilities,\n            )\n\n            # Save some .ply files of the trajectories for debugging\n            colors = plt.cm.viridis(tracks_3d[t0, :, 2] / tracks_3d[t0, :, 2].max())[:, :3]\n            for frame_idx in [0, t0, t0 + 1, n_frames // 3, (2 * n_frames) // 3, n_frames - 1]:\n                pcd = o3d.geometry.PointCloud()\n                pcd.points = o3d.utility.Vector3dVector(tracks_3d[frame_idx])\n                pcd.colors = o3d.utility.Vector3dVector(colors)\n                pcd_file = os.path.join(output_dataset_path, f\"tracks_3d_{frame_idx}.ply\")\n                o3d.io.write_point_cloud(pcd_file, pcd)\n\n                pcd = o3d.geometry.PointCloud()\n                pcd.points = o3d.utility.Vector3dVector(tracks_3d[frame_idx][tracks_3d_visibilities[frame_idx]])\n                
pcd.colors = o3d.utility.Vector3dVector(colors[tracks_3d_visibilities[frame_idx]])\n                pcd_file = os.path.join(output_dataset_path, f\"tracks_3d_{frame_idx}_visible.ply\")\n                o3d.io.write_point_cloud(pcd_file, pcd)\n\n            # Also save the first frame trimesh as a mesh\n            meshes[0].export(os.path.join(output_dataset_path, \"first_frame_mesh.obj\"))\n\n        self._frame = -1\n\n    def _load_frame_rgbd(self, c, i):\n        \"\"\"Loads an RGB-D frame.\n\n        Args:\n          c: Camera index.\n          i: Frame index.\n\n        Returns:\n          color: A unit8 numpy array of shape [H, W, 3] containing the color image.\n          depth: A uint16 numpy array of shape [H, W] containing the depth image.\n        \"\"\"\n        color_file = self._data_dir[\n                         c] + '/' + self._color_prefix + \"{:06d}.jpg\".format(i)\n        color = cv2.imread(color_file)\n        color = color[:, :, ::-1]\n        depth_file = self._data_dir[\n                         c] + '/' + self._depth_prefix + \"{:06d}.png\".format(i)\n        depth = cv2.imread(depth_file, cv2.IMREAD_ANYDEPTH)\n        return color, depth\n\n    def _deproject_depth_and_filter_points(self, d, c):\n        \"\"\"Deprojects a depth image to point cloud and filters points.\n\n        Args:\n          d: A uint16 numpy array of shape [F, H, W] or [H, W] containing the depth\n            image in millimeters.\n          c: Camera index.\n\n        Returns:\n          p: A float32 numpy array of shape [F, H, W, 3] or [H, W, 3] containing the\n            point cloud.\n          m: A bool numpy array of shape [F, H, W] or [H, W] containing the mask for\n            points within the tag cooridnate limit.\n        \"\"\"\n        nd = d.ndim\n        d = d.astype(np.float32) / 1000\n        d = torch.from_numpy(d).to(self._device)\n        p = torch.mul(\n            d.view(1, -1, self._w * self._h).expand(3, -1, -1),\n            self._p[c].unsqueeze(1))\n        p = torch.addmm(self._t[c].unsqueeze(1), self._R[c], p.view(3, -1))\n        p_tag = torch.addmm(self._tag_t_inv.unsqueeze(1), self._tag_R_inv, p)\n        mx1 = p_tag[0, :] > self._tag_lim[0]\n        mx2 = p_tag[0, :] < self._tag_lim[1]\n        my1 = p_tag[1, :] > self._tag_lim[2]\n        my2 = p_tag[1, :] < self._tag_lim[3]\n        mz1 = p_tag[2, :] > self._tag_lim[4]\n        mz2 = p_tag[2, :] < self._tag_lim[5]\n        m = mx1 & mx2 & my1 & my2 & mz1 & mz2\n        p = p.t().view(-1, self._h, self._w, 3)\n        m = m.view(-1, self._h, self._w)\n        if nd == 2:\n            p = p.squeeze(0)\n            m = m.squeeze(0)\n        p = p.cpu().numpy()\n        m = m.cpu().numpy()\n        return p, m\n\n    def transform_ycb(self,\n                      pose,\n                      c=None,\n                      camera_to_world=True,\n                      run_ycb_group_layer=True,\n                      return_trans_mat=False):\n        \"\"\"Transforms poses in SE3 between world and camera frames.\n\n        Args:\n          pose: A float32 numpy array of shape [N, 7] or [N, 6] containing the\n            poses. 
Each row contains one pose represented by rotation in quaternion\n            (x, y, z, w) or rotation vector and translation.\n          c: Camera index.\n          camera_to_world: Whether from camera to world or from world to camera.\n          run_ycb_group_layer: Whether to return vertices and normals by running the\n            YCB group layer or to return poses.\n          return_trans_mat: Whether to return poses in transformation matrices.\n\n        Returns:\n          If run_ycb_group_layer is True:\n            v: A float32 numpy array of shape [F, V, 3] containing the vertices.\n            n: A float32 numpy array of shape [F, V, 3] containing the normals.\n          else:\n            A float32 numpy array of shape [N, 6] containing the transformed poses.\n        \"\"\"\n        if pose.shape[1] == 7:\n            q = pose[:, :4]\n            t = pose[:, 4:]\n            R = Rot.from_quat(q).as_matrix().astype(np.float32)\n            R = torch.from_numpy(R).to(self._device)\n            t = torch.from_numpy(t).to(self._device)\n        if pose.shape[1] == 6:\n            r = pose[:, :3]\n            t = pose[:, 3:]\n            r = torch.from_numpy(r).to(self._device)\n            t = torch.from_numpy(t).to(self._device)\n            R = rv2dcm(r)\n        if c is not None:\n            if camera_to_world:\n                R_c = self._R[c]\n                t_c = self._t[c]\n            else:\n                R_c = self._R_inv[c]\n                t_c = self._t_inv[c]\n            R = torch.bmm(R_c.expand(R.size(0), -1, -1), R)\n            t = torch.addmm(t_c, t, R_c.t())\n        if run_ycb_group_layer or not return_trans_mat:\n            r = dcm2rv(R)\n            p = torch.cat([r, t], dim=1)\n        else:\n            p = torch.cat([R, t.unsqueeze(2)], dim=2)\n            p = torch.cat([\n                p,\n                torch.tensor([[[0, 0, 0, 1]]] * R.size(0),\n                             dtype=torch.float32,\n                             device=self._device)\n            ],\n                dim=1)\n        if run_ycb_group_layer:\n            p = p.view(-1, self._ycb_group_layer.num_obj * 6)\n            v, n = self._ycb_group_layer(p)\n            v = v[:, self._ycb_group_layer.f.view(-1)]\n            n = n[:, self._ycb_group_layer.f.view(-1)]\n            v = v.cpu().numpy()\n            n = n.cpu().numpy()\n            return v, n\n        else:\n            p = p.cpu().numpy()\n            return p\n\n    @property\n    def serials(self):\n        return self._serials\n\n    @property\n    def num_cameras(self):\n        return self._num_cameras\n\n    @property\n    def num_frames(self):\n        return self._num_frames\n\n    @property\n    def dimensions(self):\n        return self._w, self._h\n\n    @property\n    def ycb_ids(self):\n        return self._ycb_ids\n\n    @property\n    def K(self):\n        return self._K\n\n    @property\n    def master_intrinsics(self):\n        return self._master_intrinsics\n\n    def step(self):\n        \"\"\"Steps the frame.\"\"\"\n        self._frame = (self._frame + 1) % self._num_frames\n        if not self._preload:\n            self._update_pcd()\n\n    def _update_pcd(self):\n        \"\"\"Updates the point cloud.\"\"\"\n        for c in range(self._num_cameras):\n            rgb, d = self._load_frame_rgbd(c, self._frame)\n            p, m = self._deproject_depth_and_filter_points(d, c)\n            self._pcd_rgb[c][:] = rgb\n            self._pcd_vert[c][:] = p\n            self._pcd_mask[c][:] = m\n\n    
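# Usage sketch (assumption, not part of the original pipeline): advance frames with step()\n    # and then read the per-frame accessors below, e.g.\n    #   loader.step()\n    #   rgb, verts, mask = loader.pcd_rgb, loader.pcd_vert, loader.pcd_mask\n\n    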
@property\n    def pcd_rgb(self):\n        if self._preload:\n            return [x[self._frame] for x in self._pcd_rgb]\n        else:\n            return self._pcd_rgb\n\n    @property\n    def pcd_vert(self):\n        if self._preload:\n            return [x[self._frame] for x in self._pcd_vert]\n        else:\n            return self._pcd_vert\n\n    @property\n    def pcd_tex_coord(self):\n        return self._pcd_tex_coord\n\n    @property\n    def pcd_mask(self):\n        if self._preload:\n            return [x[self._frame] for x in self._pcd_mask]\n        else:\n            return self._pcd_mask\n\n    @property\n    def ycb_group_layer(self):\n        return self._ycb_group_layer\n\n    @property\n    def num_ycb(self):\n        return self._ycb_group_layer.num_obj\n\n    @property\n    def ycb_model_dir(self):\n        return self._ycb_model_dir\n\n    @property\n    def ycb_count(self):\n        return self._ycb_count\n\n    @property\n    def ycb_material(self):\n        return self._ycb_material\n\n    @property\n    def ycb_pose(self):\n        if self._app == 'viewer':\n            return None\n        if self._app == 'renderer':\n            return [x[self._frame] for x in self._ycb_pose]\n\n    @property\n    def ycb_vert(self):\n        if self._app == 'viewer':\n            return [x[self._frame] for x in self._ycb_vert]\n        if self._app == 'renderer':\n            return None\n\n    @property\n    def ycb_norm(self):\n        if self._app == 'viewer':\n            return [x[self._frame] for x in self._ycb_norm]\n        if self._app == 'renderer':\n            return None\n\n    @property\n    def ycb_tex_coords(self):\n        return self._ycb_tex_coords\n\n    @property\n    def mano_group_layer(self):\n        return self._mano_group_layer\n\n    @property\n    def num_mano(self):\n        return self._mano_group_layer.num_obj\n\n    @property\n    def mano_vert(self):\n        if self._app == 'viewer':\n            return [x[self._frame] for x in self._mano_vert]\n        if self._app == 'renderer':\n            return [[y[self._frame] for y in x] for x in self._mano_vert]\n\n    @property\n    def mano_norm(self):\n        if self._app == 'viewer':\n            return [x[self._frame] for x in self._mano_norm]\n        if self._app == 'renderer':\n            return None\n\n    @property\n    def mano_line(self):\n        if self._app == 'viewer':\n            return [x[self._frame] for x in self._mano_line]\n        if self._app == 'renderer':\n            return None\n\n    @property\n    def mano_joint_3d(self):\n        if self._app == 'viewer':\n            return None\n        if self._app == 'renderer':\n            return [x[self._frame] for x in self._mano_joint_3d]\n\n\n# Some hacking with global variables to make the visualization work\nfirst_frame_seen = False\nready_to_close = False\n\n\ndef visualize_3dpt_tracks(tracks_path, output_video_path):\n    global first_frame_seen, ready_to_close\n    print(f\"Visualizing 3D point tracks from {tracks_path} to {output_video_path}...\")\n\n    tracks = np.load(tracks_path)[\"tracks_3d\"] + np.array([0, 0, 1])\n    n_frames, n_points, _ = tracks.shape\n\n    frames_path = f\"{output_video_path}__frames\"\n    os.makedirs(frames_path, exist_ok=True)\n    first_frame_seen = False\n    ready_to_close = False\n\n    # images = [imageio.imread(f\"{frames_path}/frame_{i:04d}.png\") for i in range(n_frames)]\n    # video_writer = imageio.get_writer(output_video_path, fps=10)\n    # for img in images:\n    #     
video_writer.append_data(img)\n    # video_writer.close()\n    # ready_to_close = True\n    # return\n\n    z = tracks[2 * n_frames // 3, :, 2]\n\n    point_colors = np.zeros((n_points, 3))\n    point_colors[:, 0] = np.sin(z)\n    point_colors[:, 1] = np.sin(z + 2 * np.pi / 3)\n    point_colors[:, 2] = np.sin(z + 4 * np.pi / 3)\n\n    point_colors = cm.jet(z / np.percentile(z, 99.9))[:, :3]\n\n    print(\"Preparing clouds...\")\n    pointclouds = []\n    for frame_idx in tqdm(range(n_frames)):\n        pc = o3d.geometry.PointCloud()\n        pc.points = o3d.utility.Vector3dVector(tracks[frame_idx])\n        pc.colors = o3d.utility.Vector3dVector(point_colors)\n        pointclouds += [{\n            \"name\": f\"cloud t={frame_idx}\",\n            \"geometry\": pc,\n            \"time\": frame_idx / 4,\n        }]\n\n    def start_animation(w: o3d.cpu.pybind.visualization.O3DVisualizer) -> None:\n        w.is_animating = True\n\n    frames_path = f\"{output_video_path}__frames\"\n    os.makedirs(frames_path, exist_ok=True)\n    first_frame_seen = False\n    ready_to_close = False\n\n    def create_video(w: o3d.cpu.pybind.visualization.O3DVisualizer, t: float) -> None:\n        global first_frame_seen, ready_to_close\n        if ready_to_close:\n            print(\"Please close the window to finish the video export.\")\n            return\n        if t == 0 and not first_frame_seen:\n            first_frame_seen = True\n        elif t == 0 and first_frame_seen:\n            images = [imageio.imread(f\"{frames_path}/frame_{i:04d}.png\") for i in range(n_frames)]\n            video_writer = imageio.get_writer(output_video_path, fps=10)\n            for img in images:\n                video_writer.append_data(img)\n            video_writer.close()\n            ready_to_close = True\n            return\n        w.export_current_image(f\"{frames_path}/frame_{int(t * 4):04d}.png\")\n\n    vis.draw(\n        title=tracks_path,\n        width=1920,\n        height=1080,\n        point_size=4,\n        geometry=pointclouds,\n        animation_time_step=1 / 4,\n        # ibl=\"crossroads\",\n        eye=np.array([0, 0, 0]),\n        lookat=np.array([0, 0, 1]),\n        up=np.array([0, -1, 0]),\n        field_of_view=60.0,\n        on_init=start_animation,\n        on_animation_frame=create_video,\n        on_animation_tick=None,\n    )\n\n\nDOWNSCALING_FACTOR = 1.0\nSEQUENCES = [\n    # Each sequence has a different target object and a different (human) subject performing an action.\n    \"20200709-subject-01/20200709_141754\",\n    \"20200813-subject-02/20200813_145653\",\n    \"20200820-subject-03/20200820_135841\",\n    \"20200903-subject-04/20200903_104428\",\n    \"20200908-subject-05/20200908_144409\",\n    \"20200918-subject-06/20200918_114117\",\n    \"20200928-subject-07/20200928_144906\",\n    \"20201002-subject-08/20201002_110227\",\n    \"20201015-subject-09/20201015_144721\",\n    \"20201022-subject-10/20201022_112651\",\n]\n\n\ndef main():\n    assert os.environ['DEX_YCB_DIR']\n    for n_subsample in [3]:\n        for sequence in tqdm(SEQUENCES):\n            print(f\"Processing sequence: {sequence}\")\n            SequenceLoader(\n                sequence,\n                device=\"cpu\",\n                preload=True,\n                app=\"convert_to_neus\",\n                output_dataset_path=os.path.join(os.environ['DEX_YCB_DIR'],\n                                                 f\"neus_nsubsample-{n_subsample}/{sequence.replace('/', '__')}\"),\n                
downscaling_factor=DOWNSCALING_FACTOR,\n                n_subsample=n_subsample,\n                seed=72,\n                stream_rerun_viz=False,\n                save_rerun_viz=True,\n            )\n    print(\"Done converting the dataset.\")\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "scripts/egoexo4d_preprocessing.py",
    "content": "\"\"\"\nEnvironment setup:\n```bash\ncd ..\n\n# Clone the projectaria_tools repository\ngit clone -b 1.5.0 https://github.com/facebookresearch/projectaria_tools\ncd projectaria_tools/\n\n# Install required libraries using Conda\nconda install -c conda-forge cmake fmt xxhash libjpeg-turbo gcc_linux-64 gxx_linux-64\nconda install -c conda-forge boost-cpp=1.82.0 boost=1.82.0\n\n# Set compiler environment variables\nexport BOOST_ROOT=$CONDA_PREFIX\nexport BOOST_INCLUDEDIR=$CONDA_PREFIX/include\nexport BOOST_LIBRARYDIR=$CONDA_PREFIX/lib\n\n# Clean previous builds and install projectaria_tools\nrm -rf build/ dist/ *.egg-info\ncmake -S . -B build \\\n  -DBOOST_ROOT=$BOOST_ROOT \\\n  -DBoost_NO_SYSTEM_PATHS=ON \\\n  -DBoost_INCLUDE_DIR=$BOOST_INCLUDEDIR \\\n  -DBoost_LIBRARY_DIR=$BOOST_LIBRARYDIR \\\n  -DBUILD_PYTHON_BINDINGS=ON\ncmake --build build -j\npip install .\n\n# Additional packages (if required)\npip install av\n\ncd ../mvtracker\n```\n\nDownload a subset of the data:\n```bash\n# Install CLI for downloading the data\npip install ego4d --upgrade\n\n# Get an access id and key after filling a form at https://ego4ddataset.com/egoexo-license/\n...\n\n# Install AWS CLI from https://aws.amazon.com/cli/ (assuming no sudo)\ncd ..\ncurl \"https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip\" -o \"awscliv2.zip\"\nunzip awscliv2.zip\n./aws/install -i ~/local/aws-cli -b ~/local/bin\n# Add to ~/.bashrc: export PATH=$HOME/.local/bin:$PATH\nsource ~/.bashrc\naws --version\n# aws-cli/2.27.49 Python/3.13.4 Linux/6.8.0-57-generic exe/x86_64.ubuntu.24\naws configure\n# Now you can enter the access id and key...\n\n# Download a small subset of the data (around 100 GB)\negoexo -o ./datasets/egoexo4d --parts metadata\negoexo -o ./datasets/egoexo4d --parts take_trajectory take_vrs_noimagestream captures annotations metadata takes downscaled_takes/448 --uids ed3ec638-8363-4e1d-9851-c7936cbfad8c 51fc36b3-e769-4617-b087-3826b280cad3 f179e1a2-3265-464a-a106-a08c30d0a2ae 43dca3b5-21d9-4ebf-856e-515a5c417699 c3915dd7-3ac0-40b7-a69b-73b7326bd15c e08856e4-a1c7-4e36-96b6-a233efb27bfd 425d8f94-ed65-49d5-86e7-174f555fda5d ed698f62-ccdb-4601-8a0a-ee89a0a7e1c0 4e5aa06a-7a60-4e23-9853-d55260a9e6e9 001ae9a5-9c8a-4710-9f7f-7dc67597a02f 2aaaca24-108a-437e-ab9a-bc3e8d65fcdf 503cc92d-7052-44ff-a21d-da6c4a5d6927 f32dc6d9-0eb8-4c85-8ab9-7d47b8c4c660 2423c2ff-c85d-4998-afab-29de8d26d263 0e5d13c6-87ba-4c9b-ab2f-1aaac4e0aacb 1a9a21ab-9023-402f-ac64-df08feaabb5b 811ad284-702f-4d38-af99-a2c4006fa298 a261cc1d-7a45-479f-81a9-7c73eb379e6c c2fb62e3-8894-4101-9923-5eedeb1b4282\negoexo -o ./datasets/egoexo4d --parts captures\negoexo -o ./datasets/egoexo4d --parts annotations --benchmarks egopose\n\n```\n\nRunning the script: `PYTHONPATH=/local/home/frrajic/xode/duster:$PYTHONPATH python -m scripts.egoexo4d_preprocessing`\nNote that you need to set up dust3r first, see docstring of `scripts/estimate_depth_with_duster.py`.\n\"\"\"\nimport json\nimport os\nimport pickle\nimport time\nfrom typing import Optional\n\nimport av\nimport cv2\nimport math\nimport numpy as np\nimport pandas as pd\nimport rerun as rr\nimport torch\nfrom projectaria_tools.core import calibration\nfrom projectaria_tools.core import data_provider\nfrom projectaria_tools.core import mps\nfrom projectaria_tools.core.calibration import CameraCalibration, KANNALA_BRANDT_K3\nfrom projectaria_tools.core.stream_id import StreamId\nfrom tqdm import tqdm\n\nfrom scripts.estimate_depth_with_duster import run_duster\n\n\ndef main_preprocess_egoexo4d(\n        
release_dir: str,\n        take_name: str,\n        outputs_dir: str,\n        max_frames: Optional[int] = None,\n        frames_downsampling_factor: Optional[int] = None,\n        downscaled_longerside: Optional[int] = None,\n        save_rerun_viz: bool = True,\n        stream_rerun_viz: bool = False,\n        skip_if_output_exists: bool = True,\n):\n    # Skip if output exists\n    save_pkl_path = os.path.join(outputs_dir, f\"{take_name}.pkl\")\n    if skip_if_output_exists and os.path.exists(save_pkl_path):\n        print(f\"Skipping {save_pkl_path} since it already exists\")\n        print()\n        return\n    else:\n        print(f\"Processing {take_name}...\")\n\n    # Load necessary metadata files\n    egoexo = {\n        \"takes\": os.path.join(release_dir, \"takes.json\"),\n        \"captures\": os.path.join(release_dir, \"captures.json\")\n    }\n    for k, v in egoexo.items():\n        egoexo[k] = json.load(open(v))\n    takes = egoexo[\"takes\"]\n    captures = egoexo[\"captures\"]\n    takes_by_name = {x[\"take_name\"]: x for x in takes}\n\n    # Take the take\n    take = takes_by_name[take_name]\n\n    # Initialize exo cameras from calibration file\n    traj_dir = os.path.join(release_dir, take[\"root_dir\"], \"trajectory\")\n    exo_traj_path = os.path.join(traj_dir, \"gopro_calibs.csv\")\n\n    exo_traj_df = pd.read_csv(exo_traj_path)\n    exo_cam_names = list(exo_traj_df[\"cam_uid\"])\n    ego_cam_names = [x[\"cam_id\"] for x in take[\"capture\"][\"cameras\"] if x[\"is_ego\"] and x[\"cam_id\"].startswith(\"aria\")]\n    all_cams = ego_cam_names + exo_cam_names\n    ego_cam_name = ego_cam_names[0]\n    print(\"exo cameras: \", exo_cam_names)\n    print(\" ego camera: \", ego_cam_name)\n\n    go_pro_proxy = {}\n    static_calibrations = mps.read_static_camera_calibrations(exo_traj_path)\n    for static_calibration in static_calibrations:\n        # assert the GoPro was correctly localized\n        if static_calibration.quality != 1.0:\n            print(f\"Camera: {static_calibration.camera_uid} was not localized, ignoring this camera.\")\n            continue\n        proxy = {}\n        proxy[\"name\"] = static_calibration.camera_uid\n        proxy[\"pose\"] = static_calibration.transform_world_cam\n        proxy[\"camera\"] = CameraCalibration(\n            static_calibration.camera_uid,\n            KANNALA_BRANDT_K3,\n            static_calibration.intrinsics,\n            static_calibration.transform_world_cam,  # probably extrinsics\n            static_calibration.width,\n            static_calibration.height,\n            None,\n            math.pi,\n            \"\")\n\n        go_pro_proxy[static_calibration.camera_uid] = proxy\n\n    # Configure the VRSDataProvider (interface used to retrieve Trajectory data)\n    ego_exo_project_path = os.path.join(release_dir, 'takes', take['take_name'])\n\n    aria_dir = os.path.join(release_dir, take[\"root_dir\"])\n    aria_path = os.path.join(aria_dir, f\"{ego_cam_name}.vrs\")\n    vrs_data_provider = data_provider.create_vrs_data_provider(aria_path)\n    device_calibration = vrs_data_provider.get_device_calibration()\n\n    ego_stream_name = \"214-1\"\n    rgb_stream_id = StreamId(ego_stream_name)\n    rgb_stream_label = vrs_data_provider.get_label_from_stream_id(rgb_stream_id)\n    rgb_camera_calibration = device_calibration.get_camera_calib(rgb_stream_label)\n\n    mps_data_paths_provider = mps.MpsDataPathsProvider(ego_exo_project_path)\n    mps_data_paths = mps_data_paths_provider.get_data_paths()\n    
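# The MPS data provider is used below to query the closed-loop device poses per frame timestamp.\n    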
mps_data_provider = mps.MpsDataProvider(mps_data_paths)\n\n    # Extract ego extrinsics\n    capture_name = take[\"capture\"][\"capture_name\"]\n    timesync = pd.read_csv(os.path.join(release_dir, f\"captures/{capture_name}/timesync.csv\"))\n\n    start_idx = take[\"timesync_start_idx\"] + 1\n    end_idx = take[\"timesync_end_idx\"]\n    take_timestamps = []\n    for idx in range(start_idx, end_idx):\n        ts = timesync.iloc[idx][f\"{ego_cam_name}_{ego_stream_name}_capture_timestamp_ns\"]\n        take_timestamps.append(ts)\n    if frames_downsampling_factor is not None:\n        take_timestamps = take_timestamps[::frames_downsampling_factor]\n    if max_frames is not None:\n        take_timestamps = take_timestamps[:max_frames]\n    valid_frames = np.array([not np.isnan(ts) for ts in take_timestamps])\n    if not valid_frames.all():\n        print(f\"Number of invalid frames (with nan ego timesync): {(~valid_frames).sum()}\")\n    take_timestamps = np.array(take_timestamps)[valid_frames].astype(int)\n    ego_closed_loop_poses = [mps_data_provider.get_closed_loop_pose(t) for t in take_timestamps]\n\n    ego_extrs = []\n    T_device_camera = rgb_camera_calibration.get_transform_device_camera()\n    for pose in ego_closed_loop_poses:\n        assert pose is not None\n        T_world_device = pose.transform_world_device\n        T_world_camera = T_world_device @ T_device_camera\n        extrinsic_matrix = T_world_camera.inverse().to_matrix()[:3, :]\n\n        # Rotate camera 90° clockwise around Z\n        R_z_90 = np.array([\n            [0, -1, 0],\n            [1, 0, 0],\n            [0, 0, 1]\n        ])\n        extrinsic_matrix[:3, :] = R_z_90 @ extrinsic_matrix[:3, :]\n\n        ego_extrs.append(extrinsic_matrix)\n\n    # Extract videos\n    base_directory = os.path.join(release_dir, take[\"root_dir\"])\n    videos = {}\n    for cam_name in all_cams:\n        if cam_name in exo_cam_names:\n            stream_name = '0'\n        else:\n            stream_name = 'rgb'\n\n        local_path = os.path.join(base_directory, take['frame_aligned_videos'][cam_name][stream_name]['relative_path'])\n        container = av.open(local_path)\n\n        frames = []\n        for frame_idx, frame in enumerate(tqdm(container.decode(video=0))):\n            if frame_idx % frames_downsampling_factor != 0:\n                continue\n            if max_frames is not None and len(frames) >= max_frames:\n                break\n            frames.append(np.array(frame.to_image()))\n        frames = np.stack(frames)[valid_frames]\n        videos[cam_name] = frames\n\n    # Undistorted videos\n    rgbs = {}\n    intrs = {}\n    extrs = {}\n    for cam_name in all_cams:\n        frames = videos[cam_name]\n        h, w = frames[0].shape[:2]\n\n        if cam_name in exo_cam_names:\n            calib = exo_traj_df[exo_traj_df.cam_uid == cam_name].iloc[0].to_dict()\n            D = np.array([calib[f\"intrinsics_{i}\"] for i in range(4, 8)])\n            K = np.array([\n                [calib[\"intrinsics_0\"], 0, calib[\"intrinsics_2\"]],\n                [0, calib[\"intrinsics_1\"], calib[\"intrinsics_3\"]],\n                [0, 0, 1]\n            ])\n            width, height = calib[\"image_width\"], calib[\"image_height\"]\n            scaled_K = K * w / width\n            scaled_K[2][2] = 1.0\n\n            new_K = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify(scaled_K, D, (w, h), np.eye(3), balance=0.0)\n            map1, map2 = cv2.fisheye.initUndistortRectifyMap(scaled_K, D, np.eye(3), new_K, (w, 
h), cv2.CV_16SC2)\n            undistorted = []\n            for img in tqdm(frames, desc=f\"Undistorting {cam_name}\"):\n                ud = cv2.remap(img, map1, map2, interpolation=cv2.INTER_LINEAR)\n                undistorted.append(ud)\n\n            intrs[cam_name] = new_K\n            extrs[cam_name] = go_pro_proxy[cam_name][\"pose\"].inverse().to_matrix()[:3, :]\n            rgbs[cam_name] = np.stack([f.transpose(2, 0, 1) for f in undistorted])\n\n        else:\n            src_calib = rgb_camera_calibration\n            dst_calib = calibration.get_linear_camera_calibration(w, h, 450)\n\n            fx, fy = dst_calib.get_focal_lengths()\n            cx, cy = dst_calib.get_principal_point()\n            K = np.array([[fx, 0, cx],\n                          [0, fy, cy],\n                          [0, 0, 1]])\n\n            undistorted = []\n            for img in tqdm(frames, desc=f\"Undistorting {cam_name}\"):\n                img = cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)\n                ud = calibration.distort_by_calibration(img, dst_calib, src_calib)\n                ud = cv2.rotate(ud, cv2.ROTATE_90_CLOCKWISE)\n                undistorted.append(ud)\n            undistorted = [ud.transpose(2, 0, 1) for ud in undistorted]\n\n            intrs[cam_name] = K\n            extrs[cam_name] = np.stack(ego_extrs)\n            rgbs[cam_name] = np.stack(undistorted)\n\n    # Check shapes\n    n_frames, _, h_exo, w_exo = rgbs[exo_cam_names[0]].shape\n    _, _, h_ego, w_ego = rgbs[ego_cam_name].shape\n    for cam_name in all_cams:\n        if cam_name in exo_cam_names:\n            assert rgbs[cam_name].shape == (n_frames, 3, h_exo, w_exo)\n            assert intrs[cam_name].shape == (3, 3)\n            assert extrs[cam_name].shape == (3, 4)\n        else:\n            assert rgbs[cam_name].shape == (n_frames, 3, h_ego, w_ego)\n            assert intrs[cam_name].shape == (3, 3)\n            assert extrs[cam_name].shape == (n_frames, 3, 4)\n\n    # Save downsized version\n    if downscaled_longerside is not None:\n        print(f\"Downscaling to longer side {downscaled_longerside}\")\n        for cam_name in rgbs:\n            _, _, h, w = rgbs[cam_name].shape\n            scale = downscaled_longerside / max(h, w)\n            new_h, new_w = int(h * scale), int(w * scale)\n\n            resized = []\n            for img in rgbs[cam_name]:\n                img = img.transpose(1, 2, 0)  # CHW -> HWC\n                img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)\n                resized.append(img.transpose(2, 0, 1))  # HWC -> CHW\n            rgbs[cam_name] = np.stack(resized)\n\n            # scale intrinsics\n            intrs[cam_name][:2] *= scale\n\n    # Save processed output to a pickle file\n    os.makedirs(outputs_dir, exist_ok=True)\n    with open(save_pkl_path, \"wb\") as f:\n        pickle.dump(\n            dict(\n                rgbs=rgbs,\n                intrs=intrs,\n                extrs=extrs,\n                ego_cam_name=ego_cam_name,\n            ),\n            f,\n            protocol=pickle.HIGHEST_PROTOCOL,\n        )\n    print(f\"Saved {save_pkl_path}\")\n\n    # Visualize the data sample using rerun\n    rerun_modes = []\n    if stream_rerun_viz:\n        rerun_modes += [\"stream\"]\n    if save_rerun_viz:\n        rerun_modes += [\"save\"]\n    for rerun_mode in rerun_modes:\n        rr.init(f\"3dpt\", recording_id=\"v0.16\")\n        if rerun_mode == \"stream\":\n            rr.connect_tcp()\n\n        rr.log(\"world\", 
rr.ViewCoordinates.RIGHT_HAND_Z_UP, static=True)\n        rr.set_time_seconds(\"frame\", 0)\n        rr.log(\n            \"world/xyz\",\n            rr.Arrows3D(\n                vectors=[[1, 0, 0], [0, 2, 0], [0, 0, 3]],\n                colors=[[255, 0, 0], [0, 255, 0], [0, 0, 255]],\n            ),\n        )\n\n        fps = 30\n        for frame_idx in range(n_frames):\n            rr.set_time_seconds(\"frame\", frame_idx / fps)\n\n            for cam_name in all_cams:\n                extr = extrs[cam_name] if cam_name in exo_cam_names else extrs[cam_name][frame_idx]\n                intr = intrs[cam_name]\n                img = rgbs[cam_name][frame_idx].transpose(1, 2, 0).astype(np.uint8)\n\n                # Camera pose logging\n                E = extr if extr.shape == (3, 4) else extr[0]\n                T = np.eye(4)\n                T[:3, :] = E\n                T_world_cam = np.linalg.inv(T)\n                rr.log(f\"{cam_name}/image\", rr.Transform3D(\n                    translation=T_world_cam[:3, 3],\n                    mat3x3=T_world_cam[:3, :3],\n                ))\n\n                # Intrinsics and image\n                rr.log(f\"{cam_name}/image\", rr.Pinhole(\n                    image_from_camera=intr,\n                    width=img.shape[1],\n                    height=img.shape[0]\n                ))\n                rr.log(f\"{cam_name}/image\", rr.Image(img))\n\n        if rerun_mode == \"save\":\n            save_rrd_path = os.path.join(outputs_dir, f\"rerun__{take_name}.rrd\")\n            rr.save(save_rrd_path)\n            print(f\"Saved rerun viz to {os.path.abspath(save_rrd_path)}\")\n\n\ndef main_estimate_duster_depth(\n        pkl_scene_file,\n        depths_output_dir,\n        save_rerun_viz=False,\n        skip_if_output_already_exists=True,\n):\n    duster_kwargs = {\n        \"model_name_or_path\": \"../duster/checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth\",\n        \"silent\": True,\n        \"output_2d_matches\": False,\n        \"dump_exhaustive_data\": False,\n        \"save_ply\": False,\n        \"save_png_viz\": False,\n        \"show_debug_plots\": False,\n        \"save_rerun_viz\": save_rerun_viz,\n        \"skip_if_output_already_exists\": skip_if_output_already_exists,\n    }\n\n    print(f\"Generating DUSt3R depths to {os.path.abspath(depths_output_dir)}\")\n    assert os.path.exists(pkl_scene_file)\n    with open(pkl_scene_file, \"rb\") as f:\n        scene = pickle.load(f)\n\n    rgbs = scene[\"rgbs\"]\n    intrs = scene[\"intrs\"]\n    extrs = scene[\"extrs\"]\n    ego_cam_name = scene[\"ego_cam_name\"]\n    exo_cam_names = sorted([cam_name for cam_name in rgbs.keys() if cam_name != ego_cam_name])\n\n    n_frames, _, h, w = rgbs[exo_cam_names[0]].shape\n\n    fx, fy, cx, cy, extrinsics = [], [], [], [], []\n    for cam_name in exo_cam_names:\n        intrinsics = intrs[cam_name]\n        extrinsics_view = np.eye(4)\n        extrinsics_view[:3, :4] = extrs[cam_name]\n\n        assert np.isclose(intrinsics[0, 1], 0)\n        assert np.isclose(intrinsics[1, 0], 0)\n        assert np.isclose(intrinsics[2, 0], 0)\n        assert np.isclose(intrinsics[2, 1], 0)\n        assert np.isclose(intrinsics[2, 2], 1)\n\n        fx.append(intrinsics[0, 0])\n        fy.append(intrinsics[1, 1])\n        cx.append(intrinsics[0, 2])\n        cy.append(intrinsics[1, 2])\n        extrinsics.append(extrinsics_view)\n\n    fx = torch.tensor(fx).float()\n    fy = torch.tensor(fy).float()\n    cx = torch.tensor(cx).float()\n    cy = 
torch.tensor(cy).float()\n    extrinsics = torch.from_numpy(np.stack(extrinsics)).float()\n\n    start = time.time()\n    images_tensor = torch.from_numpy(np.stack([rgbs[cam_name] for cam_name in exo_cam_names]))\n    run_duster(images_tensor, depths_output_dir, fx, fy, cx, cy, extrinsics, **duster_kwargs)\n    time_elapsed = time.time() - start\n    print(f\"Time elapsed for DUST3R: {time_elapsed:.2f} seconds\")\n\n\nif __name__ == '__main__':\n    release_dir = \"datasets/egoexo4d/\"\n    outputs_dir = \"datasets/egoexo4d-processed/\"\n\n    num_devices = 1\n    device_id = int(os.environ.get(\"CUDA_VISIBLE_DEVICES\", \"0\"))\n    device_id = device_id % num_devices\n    print(f\"Device ID: {device_id} (out of {num_devices}). The devices split the work.\")\n\n    for i, take_name in enumerate([\n        \"fair_cooking_06_4\",  # take_uid = \"a261cc1d-7a45-479f-81a9-7c73eb379e6c\"\n        \"cmu_bike01_2\",  # take_uid = \"ed3ec638-8363-4e1d-9851-c7936cbfad8c\"\n        \"georgiatech_cooking_01_01_2\",  # take_uid = \"51fc36b3-e769-4617-b087-3826b280cad3\"\n        \"iiith_cooking_49_2\",  # take_uid = \"f179e1a2-3265-464a-a106-a08c30d0a2ae\"\n        \"indiana_bike_12_5\",  # take_uid = \"43dca3b5-21d9-4ebf-856e-515a5c417699\"\n        \"minnesota_rockclimbing_033_20\",  # take_uid = \"c3915dd7-3ac0-40b7-a69b-73b7326bd15c\"\n        \"sfu_basketball_09_21\",  # take_uid = \"425d8f94-ed65-49d5-86e7-174f555fda5d\"\n        \"unc_basketball_03-09-23_02_11\",  # take_uid = \"ed698f62-ccdb-4601-8a0a-ee89a0a7e1c0\"\n        \"unc_music_04-26-23_02_7\",  # take_uid = \"4e5aa06a-7a60-4e23-9853-d55260a9e6e9\"\n        \"uniandes_dance_017_57\",  # take_uid = \"0e5d13c6-87ba-4c9b-ab2f-1aaac4e0aacb\"\n        \"upenn_0331_Guitar_2_4\",  # take_uid = \"1a9a21ab-9023-402f-ac64-df08feaabb5b\"\n        \"unc_basketball_02-24-23_01_12\",  # take_uid = \"c2fb62e3-8894-4101-9923-5eedeb1b4282\"\n    ]):\n        if i % num_devices != device_id:\n            continue\n\n        for max_frames, frames_downsampling_factor, downscaled_longerside in [(300, 1, 512), (300, 1, 518)]:\n            # Extract rgbs, intrs, extrs from EgoExo4D dataset\n            outputs_subdir = os.path.join(outputs_dir, f\"maxframes-{max_frames}_\"\n                                                       f\"downsample-{frames_downsampling_factor}_\"\n                                                       f\"downscale-{downscaled_longerside}\")\n            main_preprocess_egoexo4d(release_dir, take_name, outputs_subdir,\n                                     max_frames, frames_downsampling_factor, downscaled_longerside)\n\n            # Run Dust3r to estimate depths from rgbs, fix the known intrs and extrs during multi-view stereo optim\n            take_pkl = os.path.join(outputs_subdir, f\"{take_name}.pkl\")\n            depth_subdir = os.path.join(outputs_subdir, f\"duster_depths__{take_name}\")\n            main_estimate_duster_depth(\n                pkl_scene_file=take_pkl,\n                depths_output_dir=depth_subdir,\n            )\n\n            # Run VGGT to estimate depths from rgbs, align with the known extrs afterward\n            ..."
  },
  {
    "path": "scripts/estimate_depth_with_duster.py",
    "content": "\"\"\"\nSet up the environment:\n```sh\ncd /local/home/frrajic/xode\n\ngit clone --recursive git@github.com:ethz-vlg/duster.git\ncd duster\n\n# Fix models path, since there are two in the project\nsed -i 's/from models/from croco.models/g' croco/*.py\nsed -i 's/from models/from croco.models/g' croco/*/*.py\nsed -i 's/from models/from croco.models/g' dust3r/*.py\nsed -i 's/from models/from croco.models/g' dust3r/*/*.py\n\n# Download the checkpoint\nwget https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth -P checkpoints\nmd5sum checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth\n# c3fab9b455b03f23d20e6bf77f2607bb  checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth\n\n# You should be able to use the same environment as for\n# the rest of the project, just install missing packages:\npip install roma==1.5.1\n```\nRunning the script:\n```sh\ncd /local/home/frrajic/xode/mvtracker\nexport PYTHONPATH=/local/home/frrajic/xode/duster:$PYTHONPATH\n\npython scripts/estimate_depth_with_duster.py --dataset dexycb\npython scripts/estimate_depth_with_duster.py --dataset kubric-val\npython scripts/estimate_depth_with_duster.py --dataset kubric-train\n```\n\nRunning the script on Panoptic Sports from Dynamic 3DGS:\n```sh\n# Download the data\ncd datasets\nwget https://omnomnom.vision.rwth-aachen.de/data/Dynamic3DGaussians/data.zip\nunzip data.zip\nmv data panoptic_d3dgs\ncd -\n\n# Run the script\ncd /local/home/frrajic/xode/duster/mvtracker\nexport PYTHONPATH=/local/home/frrajic/xode/duster:$PYTHONPATH\n\npython scripts/estimate_depth_with_duster.py --dataset panoptic_d3dgs\n```\n\"\"\"\nimport argparse\nimport json\nimport os\nimport random\nimport time\nimport warnings\nfrom copy import deepcopy\nfrom pathlib import Path\n\nimport cv2\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport rerun as rr\nimport torch\nimport torch.nn.functional as F\nimport trimesh\nfrom PIL import Image\nfrom PIL.ImageOps import exif_transpose\n\nfrom dust3r.cloud_opt import PointCloudOptimizer\nfrom dust3r.image_pairs import make_pairs\nfrom dust3r.inference import inference\nfrom dust3r.model import AsymmetricCroCo3DStereo\nfrom dust3r.utils.device import to_numpy\nfrom dust3r.utils.geometry import find_reciprocal_matches, xy_grid\nfrom dust3r.utils.image import load_images\nfrom dust3r.utils.image import rgb, heif_support_enabled, _resize_pil_image, ImgNorm\nfrom mvtracker.datasets import KubricMultiViewDataset\n\ntorch.backends.cuda.matmul.allow_tf32 = True  # for gpu >= Ampere and pytorch >= 1.12\n\n\ndef seed_all(seed):\n    \"\"\"\n    Seed all random number generators.\n\n    Parameters\n    ----------\n    seed : int\n        The seed to use.\n\n    Returns\n    -------\n    None\n    \"\"\"\n    random.seed(seed)\n    os.environ[\"PYTHONHASHSEED\"] = str(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.backends.cudnn.deterministic = True\n    torch.backends.cudnn.benchmark = False\n\n\ndef get_view_visibility(scene, pts):\n    vis = np.zeros((len(scene.imgs), len(pts)), dtype=bool)\n    poses = scene.get_im_poses().detach().cpu().numpy()\n    extrinsics = np.linalg.inv(poses)\n    focals = scene.get_focals().squeeze(-1).detach().cpu().numpy()\n    pps = scene.get_principal_points().detach().cpu().numpy()\n    depths = [d.detach().cpu().numpy() for d in scene.get_depthmaps(raw=False)]\n\n    # Apply masks to the depthmaps as to not consider points that have low confidence\n    
per_view_masks = [m.detach().cpu().numpy() for m in scene.get_masks()]\n    for view_idx, mask in enumerate(per_view_masks):\n        depths[view_idx] = depths[view_idx] * mask\n\n    for view_idx in range(len(scene.imgs)):\n        p_world = pts\n        p_world = np.concatenate([p_world, np.ones((len(p_world), 1))], axis=1)\n        p_cam = extrinsics[view_idx] @ p_world.T\n        z = p_cam[2]\n        x = p_cam[0, :] / z[:] * focals[view_idx, 0] + pps[view_idx, 0]\n        y = p_cam[1, :] / z[:] * focals[view_idx, 1] + pps[view_idx, 1]\n        x_floor = np.floor(x).astype(int)\n        y_floor = np.floor(y).astype(int)\n        x_ceil = np.ceil(x).astype(int)\n        y_ceil = np.ceil(y).astype(int)\n        h, w = depths[view_idx].shape[:2]\n        out_of_view = (\n                (x_floor < 0)\n                | (x_ceil >= w)\n                | (y_floor < 0)\n                | (y_ceil >= h)\n                | (z < 0)\n        )\n        z_from_depthmap_1 = depths[view_idx][y_floor[~out_of_view], x_floor[~out_of_view]]\n        z_from_depthmap_2 = depths[view_idx][y_floor[~out_of_view], x_ceil[~out_of_view]]\n        z_from_depthmap_3 = depths[view_idx][y_ceil[~out_of_view], x_floor[~out_of_view]]\n        z_from_depthmap_4 = depths[view_idx][y_ceil[~out_of_view], x_ceil[~out_of_view]]\n        z_from_depthmap = np.stack([z_from_depthmap_1, z_from_depthmap_2, z_from_depthmap_3, z_from_depthmap_4], axis=0)\n        vis[view_idx] = ~out_of_view\n        vis[view_idx][~out_of_view] = np.isclose(z[~out_of_view], z_from_depthmap.min(axis=0), rtol=0.001, atol=0.1)\n\n        # import pandas as pd\n        # x = pd.Series(np.abs(z[~out_of_view] - z_from_depthmap.min(axis=0)))\n        # quantiles_to_print = [0.001, 0.01, 0.05, 0.1, 0.5, 0.9, 0.95, 0.99, 0.999]\n        # print(f\"Quantiles of the difference between the depthmap and the z coordinate of the point in the camera frame\")\n        # for q in quantiles_to_print:\n        #     print(f\"{q=}: {x.quantile(q)}\")\n\n    return vis\n\n\ndef get_3D_model_from_scene(\n        output_file_prefix,\n        silent,\n        scene,\n        min_conf_thr=3,\n        mask_sky=False,\n        clean_depth=False,\n        feats=None,\n\n        dump_exhaustive_data=False,\n        save_ply=False,\n        save_png_viz=False,\n        save_rerun_viz=False,\n        rerun_radii=0.01,\n        rerun_viz_timestamp=0,\n):\n    scene = deepcopy(scene)\n    if clean_depth:\n        scene = scene.clean_pointcloud()\n    if mask_sky:\n        scene = scene.mask_sky()\n\n    rgbimg = scene.imgs\n    pts3d = to_numpy(scene.get_pts3d())\n    scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))\n    msk = to_numpy(scene.get_masks())\n\n    if not silent:\n        print(f'Exporting 3D scene to prefix={output_file_prefix}')\n\n    assert len(pts3d) == len(msk) <= len(rgbimg)\n    pts3d = to_numpy(pts3d)\n    pts3d_view_idx = [view_idx * np.ones_like(p[:, :, 0]) for view_idx, p in enumerate(pts3d)]\n    imgs = to_numpy(rgbimg)\n\n    pts_view_idx = np.concatenate([pvi[m] for pvi, m in zip(pts3d_view_idx, msk)])\n    pts = np.concatenate([p[m] for p, m in zip(pts3d, msk)])\n    col = np.concatenate([p[m] for p, m in zip(imgs, msk)])\n    # get_view_visibility(scene, np.stack(pts3d).reshape(-1, 3)[:10], np.stack(pts3d_view_idx).reshape(-1)[:10])  # debug\n    vis = get_view_visibility(scene, pts)\n\n    msk = np.stack([m for m in msk])\n\n    depths = to_numpy(scene.get_depthmaps())\n    depths = np.stack([d for d in depths])\n\n    confs = 
to_numpy([c for c in scene.im_conf])\n    confs = np.stack([c for c in confs])\n\n    output_dict = {\n        \"depths\": depths,\n        \"confs\": confs,\n        \"cleaned_mask\": msk,\n        \"min_conf_thr\": min_conf_thr,\n        \"mask_sky\": mask_sky,\n        \"clean_depth\": clean_depth,\n    }\n    if dump_exhaustive_data:\n        output_dict.update({\n            \"pts\": pts,\n            \"pts_view\": pts_view_idx,\n            \"col\": col,\n            \"vis\": vis,\n            \"rgbs\": imgs,\n        })\n    if feats is not None:\n        output_dict[\"feats\"] = feats\n    np.savez(f\"{output_file_prefix}__scene.npz\", **output_dict)\n\n    if save_ply:\n        pcd = trimesh.PointCloud(vertices=pts, colors=col)\n        pcd.export(f\"{output_file_prefix}__pc.ply\")\n\n    if rerun_viz_timestamp == 0:\n        init_pt_cld = np.concatenate([pts, col, np.ones_like(pts[:, :1])], axis=1)\n        np.savez(f\"{output_file_prefix}__init_pt_cld.npz\", data=init_pt_cld)\n\n    if save_png_viz:\n        # Results visualization\n        rgbimg = scene.imgs\n        cmap = plt.get_cmap('jet')\n        depths_max = max([d.max() for d in depths])\n        depths_viz = [d / depths_max for d in depths]\n        confs_max = max([d.max() for d in confs])\n        confs_viz = [cmap(d / confs_max) for d in confs]\n        assert len(rgbimg) == len(depths_viz) == len(confs)\n        H, W = rgbimg[0].shape[:2]\n        N = len(rgbimg)\n        plt.figure(dpi=100, figsize=(4 * W / 100, N * H / 100))\n        for i in range(N):\n            a = rgbimg[i]\n            b = rgb(depths_viz[i])\n            c = rgb(confs_viz[i])\n            d = rgb(msk[i])\n            plt.subplot(N, 4, 1 + 4 * i)\n            plt.imshow(a)\n            plt.axis('off')\n            plt.subplot(N, 4, 2 + 4 * i)\n            plt.imshow(b)\n            plt.axis('off')\n            plt.subplot(N, 4, 3 + 4 * i)\n            plt.imshow(c)\n            plt.axis('off')\n            plt.subplot(N, 4, 4 + 4 * i)\n            plt.imshow(d)\n            plt.axis('off')\n        plt.tight_layout(pad=0)\n        plt.savefig(f\"{output_file_prefix}__viz.png\")\n        plt.close()\n\n    if save_rerun_viz:\n        rr.init(\"reconstruction\", recording_id=\"v0.1\")\n        # rr.connect_tcp()\n        rr.log(\"world\", rr.ViewCoordinates.RIGHT_HAND_Z_UP, static=True)\n        rr.set_time_seconds(\"frame\", 0)\n        rr.log(\n            \"world/xyz\",\n            rr.Arrows3D(\n                vectors=[[1, 0, 0], [0, 2, 0], [0, 0, 3]],\n                colors=[[255, 0, 0], [0, 255, 0], [0, 0, 255]],\n            ),\n        )\n        rr.set_time_seconds(\"frame\", rerun_viz_timestamp / 30)\n        for v in range(len(rgbimg)):\n            h, w = scene.imshape\n            fx, fy = scene.get_focals().cpu().numpy()[v]\n            cx, cy = scene.get_principal_points().cpu().numpy()[v]\n            K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])\n            c2w = scene.get_im_poses().cpu().numpy()[v]\n            rr.log(f\"image/view-{v}/rgb\", rr.Image(scene.imgs[v]))\n            rr.log(f\"image/view-{v}/depth\", rr.DepthImage(depths[v], point_fill_ratio=0.2))\n            rr.log(f\"image/view-{v}\", rr.Pinhole(image_from_camera=K, width=w, height=h))\n            rr.log(f\"image/view-{v}\", rr.Transform3D(translation=c2w[:3, 3], mat3x3=c2w[:3, :3]))\n            rr.log(f\"point_cloud/duster-cleaned/view-{v}\", rr.Points3D(pts, colors=col, radii=rerun_radii))\n            
rr.log(f\"point_cloud/duster-raw/view-{v}\", rr.Points3D(positions=np.stack(pts3d).reshape(-1, 3),\n                                                                   colors=np.stack(imgs).reshape(-1, 3),\n                                                                   radii=rerun_radii))\n        rr_rrd_path = f\"{output_file_prefix}__rerun_viz.rrd\"\n        rr.save(rr_rrd_path)\n        print(f\"Saved Rerun recording to: {os.path.abspath(rr_rrd_path)}\")\n\n\ndef get_2D_matches(output_file_prefix, scene, input_views, min_conf_thr, clean_depth, viz_matches=False):\n    scene = deepcopy(scene)\n    scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))\n    if clean_depth:\n        scene = scene.clean_pointcloud()\n\n    # retrieve useful values from scene:\n    imgs = scene.imgs\n    pts3d = scene.get_pts3d()\n    confidence_masks = scene.get_masks()\n\n    pts2d_list, pts3d_list = {}, {}\n    for view_i in range(len(input_views)):\n        conf_i = confidence_masks[view_i].cpu().numpy()\n        pts2d_list[view_i] = xy_grid(*imgs[view_i].shape[:2][::-1])[conf_i]  # imgs[i].shape[:2] = (H, W)\n        pts3d_list[view_i] = pts3d[view_i].detach().cpu().numpy()[conf_i]\n\n    matches = {}\n    for view_i in range(len(input_views) - 1):\n        for view_j in range(view_i + 1, len(input_views)):\n\n            # find 2D-2D matches between the two images\n            reciprocal_in_P2, nn2_in_P1, num_matches = find_reciprocal_matches(pts3d_list[view_i], pts3d_list[view_j])\n            assert num_matches == reciprocal_in_P2.sum()\n            print(f'view_{view_i}-view_{view_j}: {num_matches} matches')\n            matches_i_xy = pts2d_list[view_i][nn2_in_P1][reciprocal_in_P2]\n            matches_j_xy = pts2d_list[view_j][reciprocal_in_P2]\n            matches_i_xyz = pts3d_list[view_i][nn2_in_P1][reciprocal_in_P2]\n            matches_j_xyz = pts3d_list[view_j][reciprocal_in_P2]\n            assert len(matches_i_xy) == len(matches_j_xy) == len(matches_i_xyz) == len(matches_j_xyz) == num_matches\n\n            # store the matches\n            matches[(view_i, view_j)] = {\n                'matches_i_xy': matches_i_xy,\n                'matches_j_xy': matches_j_xy,\n                'matches_i_xyz': matches_i_xyz,\n                'matches_j_xyz': matches_j_xyz,\n            }\n\n            # visualize a few matches\n            if viz_matches:\n                n_viz = 18\n                match_idx_to_viz = np.round(np.linspace(0, num_matches - 1, n_viz)).astype(int)\n                viz_matches_im0, viz_matches_im1 = matches_i_xy[match_idx_to_viz], matches_j_xy[match_idx_to_viz]\n                H0, W0, H1, W1 = *imgs[0].shape[:2], *imgs[1].shape[:2]\n                img0 = np.pad(imgs[0], ((0, max(H1 - H0, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)\n                img1 = np.pad(imgs[1], ((0, max(H0 - H1, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)\n                img = np.concatenate((img0, img1), axis=1)\n                plt.figure(dpi=200)\n                plt.imshow(img)\n                cmap = plt.get_cmap('jet')\n                for i in range(n_viz):\n                    (x0, y0), (x1, y1) = viz_matches_im0[i].T, viz_matches_im1[i].T\n                    plt.plot([x0, x1 + W0], [y0, y1], '-+', color=cmap(i / (n_viz - 1)), scalex=False, scaley=False)\n                plt.savefig(f\"{output_file_prefix}__matches__v{view_i}-v{view_j}.png\")\n                plt.tight_layout(pad=0)\n                plt.close()\n\n    # save the matches\n    
np.savez(f\"{output_file_prefix}__matches.npz\", matches=matches)\n\n\ndef load_images(folder_or_list, size, square_ok=False, verbose=True):\n    \"\"\" open and convert all images in a list or folder to proper input format for DUSt3R\n    \"\"\"\n    if isinstance(folder_or_list, str):\n        if verbose:\n            print(f'>> Loading images from {folder_or_list}')\n        root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list))\n\n    elif isinstance(folder_or_list, list):\n        if verbose:\n            print(f'>> Loading a list of {len(folder_or_list)} images')\n        root, folder_content = '', folder_or_list\n\n    else:\n        raise ValueError(f'bad {folder_or_list=} ({type(folder_or_list)})')\n\n    supported_images_extensions = ['.jpg', '.jpeg', '.png']\n    if heif_support_enabled:\n        supported_images_extensions += ['.heic', '.heif']\n    supported_images_extensions = tuple(supported_images_extensions)\n\n    imgs = []\n    for path in folder_content:\n        if not path.lower().endswith(supported_images_extensions):\n            continue\n        img = exif_transpose(Image.open(os.path.join(root, path))).convert('RGB')\n        W1, H1 = img.size\n        if size == 224:\n            # resize short side to 224 (then crop)\n            img = _resize_pil_image(img, round(size * max(W1 / H1, H1 / W1)))\n        else:\n            # resize long side to 512\n            img = _resize_pil_image(img, size)\n\n        # W, H = img.size\n        # cx, cy = W // 2, H // 2\n        # if size == 224:\n        #     half = min(cx, cy)\n        #     img = img.crop((cx - half, cy - half, cx + half, cy + half))\n        # else:\n        #     halfw, halfh = ((2 * cx) // 16) * 8, ((2 * cy) // 16) * 8\n        #     if not (square_ok) and W == H:\n        #         halfh = 3 * halfw / 4\n        #     img = img.crop((cx - halfw, cy - halfh, cx + halfw, cy + halfh))\n\n        W2, H2 = img.size\n        if verbose:\n            print(f' - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}')\n        imgs.append(dict(img=ImgNorm(img)[None], true_shape=np.int32(\n            [img.size[::-1]]), idx=len(imgs), instance=str(len(imgs))))\n\n    assert imgs, 'no images foud at ' + root\n    if verbose:\n        print(f' (Found {len(imgs)} images)')\n    return imgs, (W1, H1, W2, H2)\n\n\ndef tensor_to_pil(img_tensor):\n    \"\"\"Convert uint8 torch tensor [3, H, W] to PIL.Image\"\"\"\n    return Image.fromarray(img_tensor.permute(1, 2, 0).cpu().numpy())\n\n\ndef load_tensor_images(tensor_list, size, square_ok=False, verbose=True):\n    \"\"\"Convert torch.Tensor RGB uint8 images to DUSt3R-ready format\"\"\"\n    imgs = []\n    for i, tensor in enumerate(tensor_list):\n        if not (isinstance(tensor, torch.Tensor) and tensor.dtype == torch.uint8 and tensor.ndim == 3 and tensor.shape[\n            0] == 3):\n            raise ValueError(f\"Invalid tensor at index {i}\")\n\n        img = tensor_to_pil(tensor)\n        W1, H1 = img.size\n\n        if size == 224:\n            img = _resize_pil_image(img, round(size * max(W1 / H1, H1 / W1)))\n        else:\n            img = _resize_pil_image(img, size)\n\n        W2, H2 = img.size\n        if verbose:\n            print(f' - tensor[{i}] resolution {W1}x{H1} --> {W2}x{H2}')\n        imgs.append(dict(\n            img=ImgNorm(img)[None],\n            true_shape=np.int32([img.size[::-1]]),\n            idx=i,\n            instance=str(i)\n        ))\n\n    if not imgs:\n        raise ValueError('No valid images in input 
list.')\n\n    return imgs, (W1, H1, W2, H2)\n\n\ndef global_aligner(dust3r_output, device, **optim_kw):\n    view1, view2, pred1, pred2 = [dust3r_output[k] for k in 'view1 view2 pred1 pred2'.split()]\n    net = PointCloudOptimizer(view1, view2, pred1, pred2, **optim_kw).to(device)\n    return net\n\n\ndef load_known_camera_parameters_from_neus_dataset(dataset_path, input_views):\n    fx = []\n    fy = []\n    cx = []\n    cy = []\n    extrinsics = []\n    for input_view in input_views:\n        cameras_sphere_path = os.path.join(dataset_path, input_view, \"cameras_sphere.npz\")\n        assert os.path.exists(cameras_sphere_path)\n\n        cameras_sphere = np.load(cameras_sphere_path)\n        world_mat_0 = cameras_sphere['world_mat_0']\n\n        out = cv2.decomposeProjectionMatrix(world_mat_0[:3, :])\n        K, R, t = out[:3]\n        K = K / K[2, 2]\n        t = t[:3].squeeze() / t[3]\n\n        fx.append(K[0, 0])\n        fy.append(K[1, 1])\n        cx.append(K[0, 2])\n        cy.append(K[1, 2])\n\n        pose = np.eye(4)\n        pose[:3, :3] = R.T\n        pose[:3, 3] = t\n        extrinsics_ = np.linalg.inv(pose)\n        extrinsics.append(extrinsics_)\n\n    fx = torch.tensor(fx).float()\n    fy = torch.tensor(fy).float()\n    cx = torch.tensor(cx).float()\n    cy = torch.tensor(cy).float()\n    extrinsics = torch.from_numpy(np.stack(extrinsics)).float()\n    return fx, fy, cx, cy, extrinsics\n\n\ndef run_duster(\n        images_tensor_or_image_paths,\n        output_path,\n        fx,\n        fy,\n        cx,\n        cy,\n        extrinsics,\n\n        model_name_or_path=\"../duster/checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth\",\n        device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),\n        image_size=512,\n\n        skip_if_output_already_exists=True,\n        silent=False,\n        output_2d_matches=False,\n        dump_exhaustive_data=False,\n        save_ply=False,\n        save_png_viz=False,\n        show_debug_plots=False,\n        save_rerun_viz=False,\n        rerun_radii=0.01,\n        frame_selection=None,\n\n        ga_lr=0.01,\n        ga_schedule='linear',  # linear, cosine\n        scenegraph_type=\"complete\",  # complete, swin, oneref\n        use_known_poses_for_pairwise_pose_init=False,  # True, False\n        ga_niter=300,  # from 0 to 5000, default in demo was 300\n\n        min_conf_thr=20,  # from 1 to 20, step 0.1, defualt in demo was 3\n        mask_sky=False,  # True, False, default in demo was False\n        clean_depth=True,  # True, False, default in demo was True\n\n):\n    # Set the random seed\n    seed_all(72)\n    os.makedirs(output_path, exist_ok=True)\n    output_path = Path(output_path)\n\n    # Load the model\n    model = AsymmetricCroCo3DStereo.from_pretrained(model_name_or_path).to(device)\n\n    # Load images into a torch tensor\n    images_all = []\n    n_views, n_frames = None, None\n    original_w, original_h, target_w, target_h = None, None, None, None\n    if not isinstance(images_tensor_or_image_paths, torch.Tensor):\n        n_views = len(images_tensor_or_image_paths)\n        n_frames = len(images_tensor_or_image_paths[0])\n\n        for frame_idx in range(n_frames):\n            frame_img_paths = [str(images_tensor_or_image_paths[view_idx][frame_idx]) for view_idx in range(n_views)]\n            images, shapes = load_images(frame_img_paths, image_size, verbose=not silent)\n            if original_w is None:\n                original_w, original_h, target_w, target_h = shapes\n            
images_all.append(images)\n    else:\n        n_views, n_frames, _, original_h, original_w = images_tensor_or_image_paths.shape\n        for frame_idx in range(n_frames):\n            frame_imgs = [images_tensor_or_image_paths[view_idx, frame_idx] for view_idx in range(n_views)]\n            images, shapes = load_tensor_images(frame_imgs, image_size, verbose=not silent)\n            if target_w is None:\n                assert (original_w, original_h) == shapes[:2]\n                _, _, target_w, target_h = shapes\n            images_all.append(images)\n\n    # Check the input data\n    assert len(fx) == len(fy) == len(cx) == len(cy) == len(extrinsics) == n_views\n    assert all(extrinsics[view_idx].shape == (4, 4) for view_idx in range(n_views))\n\n    # Assume known camera parameters\n    known_poses = extrinsics.inverse()\n    known_focals = torch.stack([fx, fy], dim=-1)\n    known_pp = torch.stack([cx, cy], dim=-1)\n\n    patch_h, patch_w = model.patch_embed.patch_size  # e.g., (16, 16)\n    pad_h = (patch_h - (target_h % patch_h)) % patch_h\n    pad_w = (patch_w - (target_w % patch_w)) % patch_w\n    assert pad_h % 2 == 0, f\"pad_h {pad_h} is not divisible by 2\"\n    assert pad_w % 2 == 0, f\"pad_w {pad_w} is not divisible by 2\"\n    pad_top = pad_h // 2\n    pad_bottom = pad_h - pad_top\n    pad_left = pad_w // 2\n    pad_right = pad_w - pad_left\n\n    if pad_h or pad_w:\n        for frame_images in images_all:  # images_all[frame_idx] == list of dicts per view\n            for im_dict in frame_images:\n                # shape: [1, 3, H, W]\n                assert im_dict[\"img\"].shape[-2:] == (target_h, target_w)\n                # F.pad takes (left, right, top, bottom)\n                im_dict[\"img\"] = F.pad(im_dict[\"img\"], (pad_left, pad_right, pad_top, pad_bottom), mode=\"replicate\")\n                im_dict[\"true_shape\"] = np.int32([[target_h + pad_h, target_w + pad_w]])\n\n        # shift principal point to the padded image coordinate system\n        # (we padded symmetrically, so add half the padding on each axis)\n        known_pp = known_pp.clone()\n        known_pp[..., 0] = known_pp[..., 0] + pad_left  # cx\n        known_pp[..., 1] = known_pp[..., 1] + pad_top  # cy\n\n    if frame_selection is None:\n        frame_selection = range(n_frames)\n    for frame_idx in frame_selection:\n        print(f\"Processing frame {frame_idx:05d}/{n_frames:05d}...\")\n        if skip_if_output_already_exists and os.path.exists(output_path / f\"3d_model__{frame_idx:05d}__scene.npz\"):\n            try:\n                np.load(output_path / f\"3d_model__{frame_idx:05d}__scene.npz\")\n                print(f\"Skipping frame because the output file already exists.\")\n                continue\n            except Exception as e:\n                print(f\"Output file already exists but is corrupted: {e}\")\n\n        # Load preprocessed input images\n        images = images_all[frame_idx]\n\n        assert (target_h + pad_h, target_w + pad_w) == images[0]['img'].shape[-2:]\n        assert len(images) == n_views\n        print(f\"Loaded {len(images)} images. \"\n              f\"Original resolution: {original_w}x{original_h}. 
\"\n              f\"Target resolution: {target_w}x{target_h}.\")\n\n        # Extract encoder features for each image\n        feats = []\n        for view_idx in range(n_views):\n            with torch.no_grad():\n                feat, pos_enc, _ = model._encode_image(images[view_idx][\"img\"].to(device),\n                                                       images[view_idx][\"true_shape\"])\n                feats.append(feat)\n        feats = torch.concat(feats).detach().cpu().numpy()\n\n        # Run DUSt3R on the pairs\n        pairs = make_pairs(images, scene_graph=scenegraph_type, prefilter=None, symmetrize=True)\n        output = inference(pairs, model, device, batch_size=1, verbose=not silent)\n\n        # Unpad the output if padding was applied\n        if pad_h or pad_w:\n            H_pad = target_h + pad_h\n            W_pad = target_w + pad_w\n            t, l = pad_top, pad_left\n            b, r = t + target_h, l + target_w\n\n            assert output[\"view1\"][\"img\"].shape == (len(pairs), 3, H_pad, W_pad)\n            assert output[\"view2\"][\"img\"].shape == (len(pairs), 3, H_pad, W_pad)\n            assert output[\"pred1\"][\"conf\"].shape == (len(pairs), H_pad, W_pad)\n            assert output[\"pred2\"][\"conf\"].shape == (len(pairs), H_pad, W_pad)\n            assert output[\"pred1\"][\"pts3d\"].shape == (len(pairs), H_pad, W_pad, 3)\n            assert output[\"pred2\"][\"pts3d_in_other_view\"].shape == (len(pairs), H_pad, W_pad, 3)\n\n            output[\"view1\"][\"img\"] = output[\"view1\"][\"img\"][:, :, t:b, l:r].contiguous()\n            output[\"view2\"][\"img\"] = output[\"view2\"][\"img\"][:, :, t:b, l:r].contiguous()\n            output[\"pred1\"][\"conf\"] = output[\"pred1\"][\"conf\"][:, t:b, l:r].contiguous()\n            output[\"pred2\"][\"conf\"] = output[\"pred2\"][\"conf\"][:, t:b, l:r].contiguous()\n            output[\"pred1\"][\"pts3d\"] = output[\"pred1\"][\"pts3d\"][:, t:b, l:r, :].contiguous()\n            output[\"pred2\"][\"pts3d_in_other_view\"] = output[\"pred2\"][\"pts3d_in_other_view\"][:, t:b, l:r, :].contiguous()\n            output[\"view1\"][\"true_shape\"] = np.int32([[target_h, target_w]])\n            output[\"view2\"][\"true_shape\"] = np.int32([[target_h, target_w]])\n\n        # Set the known camera parameters\n        scene = global_aligner(output, device=device, verbose=not silent)\n        if not np.isclose(target_w / original_w, target_h / original_h):\n            warnings.warn(f\"The aspect ratio of the input images is different from the target aspect ratio:\\n\"\n                          f\" - rescaling factor x: {target_w}/{original_w} = {target_w / original_w}\\n\"\n                          f\" - rescaling factor y: {target_h}/{original_h} = {target_h / original_h}\")\n        if target_w == 512:\n            rescaling_factor = target_w / original_w\n        elif target_h == 512:\n            rescaling_factor = target_h / original_h\n        else:\n            raise ValueError(f\"Unexpected target resolution: {target_w}x{target_h}\")\n        print(f\"We will use the rescaling factor: {target_w}/{original_w} = {rescaling_factor}\")\n        scene.preset_focal(known_focals.clone() * rescaling_factor)\n        scene.im_pp.requires_grad_(True)\n        scene.preset_principal_point(known_pp.clone() * rescaling_factor)\n        scene.preset_pose(known_poses.clone())\n        # scene.im_pp.requires_grad_(True)\n\n        # Run global alignment to get the global pointcloud and estimated camera parameters\n        
init = 'mst' if not use_known_poses_for_pairwise_pose_init else 'known_poses'\n        try:\n            loss = scene.compute_global_alignment(init=init, niter=ga_niter, schedule=ga_schedule, lr=ga_lr)\n        except Exception as e:\n            other_init = {\"mst\": \"known_poses\", \"known_poses\": \"mst\"}\n            print(f\"Error during global alignment: {e}\")\n            print(f\"Trying the other initialization method init={other_init[init]} instead of init={init}\")\n            loss = scene.compute_global_alignment(init=other_init[init], niter=ga_niter, schedule=ga_schedule, lr=ga_lr)\n        print(f\"Global alignment loss: {loss}\")\n        print(f\"Poses after global alignment:\")\n        print(f\"{scene.get_im_poses().cpu().tolist()},\")\n        print(f\"Intrinsic after global alignment:\")\n        print(f\"{scene.get_focals().cpu().tolist()}\")\n        print(f\"{scene.get_principal_points().cpu().tolist()}\")\n        print()\n\n        # Save the scene data, pointclouds, and camera parameters\n        if feats is not None and (pad_h or pad_w):\n            warnings.warn(f\"The saved 'feats' won't take into account the padding (pad_h={pad_h}, pad_w={pad_w}).\")\n        get_3D_model_from_scene(\n            output_file_prefix=output_path / f\"3d_model__{frame_idx:05d}\",\n            silent=silent,\n            scene=scene,\n            min_conf_thr=min_conf_thr,\n            mask_sky=mask_sky,\n            clean_depth=clean_depth,\n            feats=feats,\n            dump_exhaustive_data=dump_exhaustive_data,\n            save_ply=save_ply,\n            save_png_viz=save_png_viz,\n            save_rerun_viz=save_rerun_viz,\n            rerun_radii=rerun_radii,\n            rerun_viz_timestamp=frame_idx,\n        )\n        # get_3D_model_from_scene(output_path / f\"low_threshold_3d_model__{frame_idx:05d}\", silent, scene, 1, mask_sky, clean_depth)\n        # get_3D_model_from_scene(output_path / f\"non_clean_3d_model__{frame_idx:05d}\", silent, scene, 0, mask_sky, False)\n        if output_2d_matches:\n            output_file_prefix = os.path.join(output_path, f\"frame_{frame_idx}\")\n            get_2D_matches(output_file_prefix, scene, images, min_conf_thr, clean_depth, viz_matches=True)\n\n        if show_debug_plots:\n            from sklearn.decomposition import PCA\n            reducer = PCA(n_components=3)\n            fvec_flat_all = feats.reshape(-1, 1024)\n            reducer.fit(fvec_flat_all)\n            fvec_reduced = reducer.transform(fvec_flat_all)\n            reducer_min = fvec_reduced.min(axis=0)\n            reducer_max = fvec_reduced.max(axis=0)\n\n            def fvec_to_rgb(fvec):\n                fvec_reduced = reducer.transform(fvec)\n                fvec_reduced_rescaled = (fvec_reduced - reducer_min) / (reducer_max - reducer_min)\n                fvec_reduced_rgb = (fvec_reduced_rescaled * 255).astype(int)\n                return fvec_reduced_rgb\n\n            rgb_with_feat_list = []\n            for view_idx in range(n_views):\n                fvec_flat = feats[view_idx, :, :].reshape(((target_h + pad_h) // 16) * ((target_w + pad_w) // 16), 1024)\n                fvec_reduced_rgb = fvec_to_rgb(fvec_flat).reshape((target_h + pad_h) // 16, (target_w + pad_w) // 16, 3)\n                rgb_img = ((images[view_idx][\"img\"][0].permute(1, 2, 0).numpy() / 2 + 0.5) * 255).astype(int)\n                fvec_img = np.kron(fvec_reduced_rgb, np.ones((16, 16, 1))).astype(int)\n                rgb_with_feat = np.concatenate([rgb_img, fvec_img], 
axis=1)\n                rgb_with_feat_list.append(rgb_with_feat)\n            rgb_with_feat = np.concatenate(rgb_with_feat_list, axis=0)\n\n            import matplotlib.pyplot as plt;\n            plt.figure(figsize=(rgb_with_feat.shape[1] / 100, rgb_with_feat.shape[0] / 100), dpi=100)\n            plt.imshow(rgb_with_feat)\n            plt.axis('off')\n            plt.tight_layout(pad=0)\n            plt.savefig(os.path.join(output_path, f\"debug__{frame_idx:05d}__rgb_with_encoder_features.png\"))\n            # plt.show()\n            plt.close()\n\n\ndef main_on_neus_scene(scene_root, views_selection, **duster_kwargs):\n    views_selection_str = ''.join(str(v) for v in views_selection)\n    output_path = scene_root / f'duster-views-{views_selection_str}'\n    view_paths = [scene_root / f\"view_{v:02d}\" for v in views_selection]\n\n    frame_paths = [sorted((view_path / \"rgb\").glob(\"*.png\")) for view_path in view_paths]\n    n_frames = len(frame_paths[0])\n    assert n_frames > 0\n    assert all(len(f) == n_frames for f in frame_paths)\n\n    fx, fy, cx, cy, extrinsics = [], [], [], [], []\n    for view_path in view_paths:\n        camera_params_file = os.path.join(view_path, \"intrinsics_extrinsics.npz\")\n        params = np.load(camera_params_file)\n        intrinsics = params[\"intrinsics\"]\n        extrinsics_view = params[\"extrinsics\"]\n\n        assert intrinsics[0, 1] == 0\n        assert intrinsics[1, 0] == 0\n        assert intrinsics[2, 0] == 0\n        assert intrinsics[2, 1] == 0\n        assert intrinsics[2, 2] == 1\n\n        fx.append(intrinsics[0, 0])\n        fy.append(intrinsics[1, 1])\n        cx.append(intrinsics[0, 2])\n        cy.append(intrinsics[1, 2])\n        extrinsics.append(extrinsics_view)\n\n    fx = torch.tensor(fx).float()\n    fy = torch.tensor(fy).float()\n    cx = torch.tensor(cx).float()\n    cy = torch.tensor(cy).float()\n    extrinsics = torch.from_numpy(np.stack(extrinsics)).float()\n\n    print(f\"Processing {output_path}\")\n    run_duster(frame_paths, output_path, fx, fy, cx, cy, extrinsics, **duster_kwargs)\n\n\ndef main_on_kubric_scene(scene_root, views_selection, **duster_kwargs):\n    views_selection_str = ''.join(str(v) for v in views_selection)\n    output_path = scene_root / f'duster-views-{views_selection_str}'\n    view_paths = [scene_root / f\"view_{v:01d}\" for v in views_selection]\n\n    frame_paths = [sorted(view_path.glob(\"rgba_*.png\")) for view_path in view_paths]\n    n_frames = len(frame_paths[0])\n    assert n_frames > 0\n    assert all(len(f) == n_frames for f in frame_paths)\n\n    datapoint = KubricMultiViewDataset.getitem_raw_datapoint(scene_root)\n    fx, fy, cx, cy, extrinsics = [], [], [], [], []\n    for view_idx in views_selection:\n        intrinsics = datapoint[\"views\"][view_idx][\"intrinsics\"]\n        extrinsics_view = np.eye(4)\n        extrinsics_view[:3, :4] = datapoint[\"views\"][view_idx][\"extrinsics\"][0]\n\n        assert intrinsics[0, 1] == 0\n        assert intrinsics[1, 0] == 0\n        assert intrinsics[2, 0] == 0\n        assert intrinsics[2, 1] == 0\n        assert intrinsics[2, 2] == 1\n\n        fx.append(intrinsics[0, 0])\n        fy.append(intrinsics[1, 1])\n        cx.append(intrinsics[0, 2])\n        cy.append(intrinsics[1, 2])\n        extrinsics.append(extrinsics_view)\n\n    fx = torch.tensor(fx).float()\n    fy = torch.tensor(fy).float()\n    cx = torch.tensor(cx).float()\n    cy = torch.tensor(cy).float()\n    extrinsics = torch.from_numpy(np.stack(extrinsics)).float()\n\n  
  start = time.time()\n    print(f\"Processing {output_path}\")\n    run_duster(frame_paths, output_path, fx, fy, cx, cy, extrinsics, **duster_kwargs)\n    time_elapsed = time.time() - start\n    print(f\"Time elapsed for DUST3R: {time_elapsed:.2f} seconds\")\n\n\ndef main_on_d3dgs_panoptic_scene(\n        scene_root,\n        views_selection,\n        save_rerun_viz=False,\n        rerun_radii=0.002,\n        **duster_kwargs,\n):\n    md = json.load(open(os.path.join(scene_root, \"train_meta.json\"), 'r'))\n    n_frames = len(md['fn'])\n\n    # Check that the selected views are in the training set\n    view_paths = []\n    for view_idx in views_selection:\n        view_path = scene_root / \"ims\" / f\"{view_idx}\"\n        assert view_idx in md[\"cam_id\"][0], f\"Camera {view_idx} is not in the training set\"\n        assert view_path.exists()\n        view_paths.append(view_path)\n    frame_paths = [sorted(view_path.glob(\"*.jpg\")) for view_path in view_paths]\n    assert all(len(frame_paths[v]) == n_frames for v in range(len(views_selection)))\n\n    # Create the output directory\n    views_selection_str = '-'.join(str(v) for v in views_selection)\n    output_path = scene_root / f'duster-views-{views_selection_str}'\n    os.makedirs(output_path, exist_ok=True)\n\n    # Load the camera parameters\n    fx, fy, cx, cy, extrinsics = [], [], [], [], []\n    for view_idx in views_selection:\n        fx_current, fy_current, cx_current, cy_current, extrinsics_current = [], [], [], [], []\n        for t in range(n_frames):\n            view_idx_in_array = md['cam_id'][t].index(view_idx)\n            k = md['k'][t][view_idx_in_array]\n            w2c = np.array(md['w2c'][t][view_idx_in_array])\n\n            fx_current.append(k[0][0])\n            fy_current.append(k[1][1])\n            cx_current.append(k[0][2])\n            cy_current.append(k[1][2])\n            extrinsics_current.append(w2c)\n\n        assert all(np.equal(fx_current[0], fx_current[t]).all() for t in range(1, n_frames))\n        assert all(np.equal(fy_current[0], fy_current[t]).all() for t in range(1, n_frames))\n        assert all(np.equal(cx_current[0], cx_current[t]).all() for t in range(1, n_frames))\n        assert all(np.equal(cy_current[0], cy_current[t]).all() for t in range(1, n_frames))\n        assert all(np.equal(extrinsics_current[0], extrinsics_current[t]).all() for t in range(1, n_frames))\n\n        fx.append(fx_current[0])\n        fy.append(fy_current[0])\n        cx.append(cx_current[0])\n        cy.append(cy_current[0])\n        extrinsics.append(extrinsics_current[0])\n\n    fx = torch.tensor(fx).float()\n    fy = torch.tensor(fy).float()\n    cx = torch.tensor(cx).float()\n    cy = torch.tensor(cy).float()\n    extrinsics = torch.from_numpy(np.stack(extrinsics)).float()\n\n    # Visualize the initialization point cloud used in D3DGS\n    if save_rerun_viz:\n        init_pt_cld = np.load(scene_root / \"init_pt_cld.npz\")[\"data\"]\n        xyz = init_pt_cld[:, :3]\n        col = init_pt_cld[:, 3:6]\n        seg = init_pt_cld[:, 6:7]\n        rr.init(\"reconstruction\", recording_id=\"v0.1\")\n        # rr.connect_tcp()\n        rr.set_time_seconds(\"frame\", 0 / 30)\n        rr.log(f\"point_cloud/sfm-full\", rr.Points3D(xyz, colors=col, radii=rerun_radii))\n        rr.log(f\"point_cloud/sfm-full-seg\", rr.Points3D(xyz, colors=col * seg, radii=rerun_radii))\n        rr.log(\"world\", rr.ViewCoordinates.RIGHT_HAND_Z_UP, static=True)\n        rr.set_time_seconds(\"frame\", 0)\n        rr.log(\n            
\"world/xyz\",\n            rr.Arrows3D(\n                vectors=[[1, 0, 0], [0, 2, 0], [0, 0, 3]],\n                colors=[[255, 0, 0], [0, 255, 0], [0, 0, 255]],\n            ),\n        )\n        # moge_depths = []\n        # moge_masks = []\n        for selected_view_idx, view_idx in enumerate(views_selection):\n            rgbs = np.stack([np.array(Image.open(frame_paths[selected_view_idx][t])) for t in range(n_frames)])\n            rgbs = torch.from_numpy(rgbs).permute(0, 3, 1, 2).float()\n            H, W = rgbs.shape[-2], rgbs.shape[-1]\n            K = np.array([\n                [fx[selected_view_idx], 0, cx[selected_view_idx]],\n                [0, fy[selected_view_idx], cy[selected_view_idx]],\n                [0, 0, 1],\n            ])\n            K_inv = np.linalg.inv(K)\n            K_for_moge = np.array([\n                [fx[selected_view_idx] / W, 0, 0.5],\n                [0, fy[selected_view_idx] / H, 0.5],\n                [0, 0, 1],\n            ])\n            # depths, i, _, _, mask = moge(rgbs[::10], intrinsics=K_for_moge)\n            # moge_depths.append(depths)\n            # moge_masks.append(mask)\n            for t in range(0, n_frames, 10):\n                rr.set_time_seconds(\"frame\", t / 30)\n                c2w = torch.linalg.inv(extrinsics[selected_view_idx]).numpy()\n                rr.log(f\"image/view-{view_idx}/rgb\", rr.Image(rgbs[t].permute(1, 2, 0).numpy()))\n                # rr.log(f\"image/view-{view_idx}/depth\",\n                #        rr.DepthImage(moge_depths[selected_view_idx][t // 10], point_fill_ratio=0.2))\n                rr.log(f\"image/view-{view_idx}\", rr.Pinhole(image_from_camera=K, width=W, height=H))\n                rr.log(f\"image/view-{view_idx}\", rr.Transform3D(translation=c2w[:3, 3], mat3x3=c2w[:3, :3]))\n\n                # # Generate and log point cloud colored by RGB values\n                # y, x = np.indices((H, W))\n                # homo_pixel_coords = np.stack([x.ravel(), y.ravel(), np.ones_like(x).ravel()], axis=1).T\n                # depth_values = moge_depths[selected_view_idx][t // 10].ravel()\n                # cam_coords = (K_inv @ homo_pixel_coords) * depth_values\n                # cam_coords = np.vstack((cam_coords, np.ones((1, cam_coords.shape[1]))))\n                # world_coords = (c2w @ cam_coords)[:3].T\n                # valid_mask = (depth_values > 0) & moge_masks[selected_view_idx][t // 10].reshape(-1, )\n                # world_coords = world_coords[valid_mask]\n                # rgb_colors = rgbs[t].permute(1, 2, 0).reshape(-1, 3).numpy()[valid_mask].astype(np.uint8)\n                # rr.log(f\"point_cloud/view-{view_idx}\", rr.Points3D(world_coords, colors=rgb_colors, radii=rerun_radii))\n        rr.save(output_path / \"init_pt_cld.rrd\")\n\n    # Run DUSt3R\n    print(f\"Processing {output_path}\")\n    run_duster(frame_paths, output_path, fx, fy, cx, cy, extrinsics,\n               save_rerun_viz=save_rerun_viz, rerun_radii=rerun_radii, **duster_kwargs)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--dataset', type=str, required=True, help='The dataset to process')\n    args = parser.parse_args()\n\n    duster_kwargs = {\n        \"model_name_or_path\": \"../duster/checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth\",\n        \"silent\": False,\n        \"output_2d_matches\": False,\n        \"dump_exhaustive_data\": True,\n        \"save_ply\": True,\n        \"save_png_viz\": True,\n        \"show_debug_plots\": True,\n    
}\n\n    if args.dataset == \"dexycb\":\n        data_root = Path('./datasets/dex-january-2025/neus_nsubsample-3/')\n        views_selections = [\n            [0, 1, 2, 3],\n            [2, 3, 4, 5],\n            [4, 5, 6, 7],\n            [0, 1, 2, 3, 4, 5, 6, 7],\n        ]\n        for scene_root in sorted(data_root.glob(\"*\")):\n            for views_selection in views_selections:\n                main_on_neus_scene(scene_root, views_selection, **duster_kwargs)\n\n    elif args.dataset == \"kubric-val\":\n        data_root = Path('./datasets/kubric_multiview_003/test/')\n        duster_kwargs[\"save_rerun_viz\"] = True\n        views_selections = [\n            # [0, 1],\n            [0, 1, 2, 3],\n            [0, 1, 2, 3, 4, 5, 6, 7],\n        ]\n        for scene_root in sorted(data_root.glob(\"[!.]*\")):\n            for views_selection in views_selections:\n                main_on_kubric_scene(scene_root, views_selection, **duster_kwargs)\n\n    elif args.dataset == \"kubric-train\":\n        # Save space by not saving all logs\n        duster_kwargs[\"dump_exhaustive_data\"] = False\n        duster_kwargs[\"save_ply\"] = False\n        duster_kwargs[\"save_png_viz\"] = False\n        duster_kwargs[\"show_debug_plots\"] = False\n\n        data_root = Path('./datasets/kubric_multiview_003/train/')\n        views_selections = [\n            [0, 1, 2, 3],\n            [0, 1, 2, 3, 4, 5, 6, 7],\n        ]\n\n        # # Parallelize across a machine with 4 GPUs\n        # total_gpus = 4\n        # gpu_id = int(os.environ.get(\"CUDA_VISIBLE_DEVICES\"))\n        # # Run, e.g., as:\n        # # --------------\n        # # CUDA_VISIBLE_DEVICES=0 python scripts/estimate_depth_with_duster.py --dataset kubric-train\n        # # CUDA_VISIBLE_DEVICES=1 python scripts/estimate_depth_with_duster.py --dataset kubric-train\n        # # CUDA_VISIBLE_DEVICES=2 python scripts/estimate_depth_with_duster.py --dataset kubric-train\n        # # CUDA_VISIBLE_DEVICES=3 python scripts/estimate_depth_with_duster.py --dataset kubric-train\n\n        # Parallelize across 128 machines with 4 GPUs each\n        total_gpus = 128 * 4\n        a = int(os.environ.get(\"CHUNK\"))\n        b = int(os.environ.get(\"CUDA_VISIBLE_DEVICES\"))\n        gpu_id = a * 4 + b\n        # Run, e.g., as:\n        # --------------\n        # CHUNK=0 CUDA_VISIBLE_DEVICES=0 python scripts/estimate_depth_with_duster.py --dataset kubric-train\n        # CHUNK=0 CUDA_VISIBLE_DEVICES=1 python scripts/estimate_depth_with_duster.py --dataset kubric-train\n        # CHUNK=0 CUDA_VISIBLE_DEVICES=2 python scripts/estimate_depth_with_duster.py --dataset kubric-train\n        # CHUNK=0 CUDA_VISIBLE_DEVICES=3 python scripts/estimate_depth_with_duster.py --dataset kubric-train\n        # CHUNK=1 CUDA_VISIBLE_DEVICES=1 python scripts/estimate_depth_with_duster.py --dataset kubric-train\n        # ...\n        # CHUNK=15 CUDA_VISIBLE_DEVICES=3 python scripts/estimate_depth_with_duster.py --dataset kubric-train\n\n        print(f\"Running on GPU {gpu_id} (out of {total_gpus})\")\n        print(f'Total scenes to process: {len(sorted(data_root.glob(\"[!.]*\"))[gpu_id::total_gpus])}')\n        for scene_root in sorted(data_root.glob(\"[!.]*\"))[gpu_id::total_gpus]:\n            for views_selection in views_selections:\n                main_on_kubric_scene(scene_root, views_selection, **duster_kwargs)\n\n    elif args.dataset == \"panoptic_d3dgs\":\n        duster_kwargs[\"skip_if_output_already_exists\"] = True\n        duster_kwargs[\"save_rerun_viz\"] 
= False\n        duster_kwargs[\"frame_selection\"] = None  # [0]\n        data_root = Path('./datasets/panoptic_d3dgs/')\n        views_selections = [\n            # [27, 16, 14, 8, 11, 19, 11, 6, 23, 1],  # 10 views\n            [27, 16, 14, 8, 11, 19, 11, 6],  # 8 views\n            [27, 16, 14, 8],  # 4 views\n            [27, 16],  # 2 views\n\n            # [1, 4, 7, 11, 14, 17, 20, 23, 26, 29],  # 10 views\n            # # [5, 8, 11, 14, 17, 20, 23, 26, 29],  # 9 views\n            [1, 4, 7, 11, 14, 17, 20, 23],  # 8 views\n            #\n            [1, 4, 7, 11, ],  # 4 views - v1\n            [1, 7, 14, 20, ],  # 4 views - v2\n            #\n            # [1, 4],  # 2 views - v1\n            # [1, 14],  # 2 views - v2\n        ]\n        for scene_root in sorted(data_root.glob(\"[!.]*\")):\n            for views_selection in views_selections:\n                main_on_d3dgs_panoptic_scene(scene_root, views_selection, **duster_kwargs)\n\n    else:\n        raise ValueError(f\"Unknown dataset: {args.dataset}\")\n\n    print(\"Done.\")\n"
  },
  {
    "path": "scripts/hi4d_preprocessing.py",
    "content": "\"\"\"\nFirst download the dataset. You'll have to fill in an online ETH form\nand then wait for a few days to get a temporary access code over email.\nI used the following sequence of commands to download and unpack the data\ninto the expected structure. You can probably replace the `dt=...` with\nyour access token that you can probably find in the access URL (or otherwise\nin the page source of the download page that will be linked). Note that\nyou don't need to download all the data if you don't need it, e.g., maybe\nyou just want to download a small sample. Note also that in the commands below,\nI didn't delete the `*.tar.gz` files, but you can do so if you'd like.\n```bash\nwget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/LICENSE.txt' -O LICENSE.txt\nwget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/README.md' -O README.md\nwget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair00_1.tar.gz' -O pair00_1.tar.gz\nwget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair00_2.tar.gz' -O pair00_2.tar.gz\nwget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair01.tar.gz' -O pair01.tar.gz\nwget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair02_1.tar.gz' -O pair02_1.tar.gz\nwget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair02_2.tar.gz' -O pair02_2.tar.gz\nwget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair09.tar.gz' -O pair09.tar.gz\nwget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair10.tar.gz' -O pair10.tar.gz\nwget 
'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair12.tar.gz' -O pair12.tar.gz\nwget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair13_1.tar.gz' -O pair13_1.tar.gz\nwget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair13_2.tar.gz' -O pair13_2.tar.gz\nwget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair14.tar.gz' -O pair14.tar.gz\nwget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair15_1.tar.gz' -O pair15_1.tar.gz\nwget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair15_2.tar.gz' -O pair15_2.tar.gz\nwget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair16.tar.gz' -O pair16.tar.gz\nwget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair17_1.tar.gz' -O pair17_1.tar.gz\nwget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair17_2.tar.gz' -O pair17_2.tar.gz\nwget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair18_1.tar.gz' -O pair18_1.tar.gz\nwget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair18_2.tar.gz' -O pair18_2.tar.gz\nwget 
'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair19_1.tar.gz' -O pair19_1.tar.gz\nwget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair19_2.tar.gz' -O pair19_2.tar.gz\nwget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair21_1.tar.gz' -O pair21_1.tar.gz\nwget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair21_2.tar.gz' -O pair21_2.tar.gz\nwget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair22.tar.gz' -O pair22.tar.gz\nwget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair23_1.tar.gz' -O pair23_1.tar.gz\nwget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair23_2.tar.gz' -O pair23_2.tar.gz\nwget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair27_1.tar.gz' -O pair27_1.tar.gz\nwget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair27_2.tar.gz' -O pair27_2.tar.gz\nwget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair28.tar.gz' -O pair28.tar.gz\nwget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair32_1.tar.gz' -O pair32_1.tar.gz\nwget 
'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair32_2.tar.gz' -O pair32_2.tar.gz\nwget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair37_1.tar.gz' -O pair37_1.tar.gz\nwget 'https://hi4d.ait.ethz.ch/download.php?dt=def502001190eca4e725f10acbfbd3520f0caca29004163d940aa67e31c024acac2f55ce060924e95b528e99e47e167d6d3e8dd34449e7c89fc60b1139e6ee28f45ed216e5f452230156127a2a1919ef0b796c8cc016630353296abd0c4294db83582d7a99a132d033e95928e4a1&file=/pair37_2.tar.gz' -O pair37_2.tar.gz\n\nmkdir -p pair00 pair01 pair02 pair09 pair10 pair12 pair13 pair14 pair15 pair16 pair17 pair18 pair19 pair21 pair22 pair23 pair27 pair28 pair32 pair37\n\ntar -xvzf pair00_1.tar.gz -C pair00\ntar -xvzf pair00_2.tar.gz -C pair00\ntar -xvzf pair01.tar.gz pair01\ntar -xvzf pair02_1.tar.gz -C pair02\ntar -xvzf pair02_2.tar.gz -C pair02\ntar -xvzf pair09.tar.gz -C pair09\ntar -xvzf pair10.tar.gz -C pair10\ntar -xvzf pair12.tar.gz -C pair12\ntar -xvzf pair13_1.tar.gz -C pair13\ntar -xvzf pair13_2.tar.gz -C pair13\ntar -xvzf pair14.tar.gz -C pair14\ntar -xvzf pair15_1.tar.gz -C pair15\ntar -xvzf pair15_2.tar.gz -C pair15\ntar -xvzf pair16.tar.gz -C pair16\ntar -xvzf pair17_1.tar.gz -C pair17\ntar -xvzf pair17_2.tar.gz -C pair17\ntar -xvzf pair18_1.tar.gz -C pair18\ntar -xvzf pair18_2.tar.gz -C pair18\ntar -xvzf pair19_1.tar.gz -C pair19\ntar -xvzf pair19_2.tar.gz -C pair19\ntar -xvzf pair21_1.tar.gz -C pair21\ntar -xvzf pair21_2.tar.gz -C pair21\ntar -xvzf pair22.tar.gz -C pair22\ntar -xvzf pair23_1.tar.gz -C pair23\ntar -xvzf pair23_2.tar.gz -C pair23\ntar -xvzf pair27_1.tar.gz -C pair27\ntar -xvzf pair27_2.tar.gz -C pair27\ntar -xvzf pair28.tar.gz -C pair28\ntar -xvzf pair32_1.tar.gz -C pair32\ntar -xvzf pair32_2.tar.gz -C pair32\ntar -xvzf pair37_1.tar.gz -C pair37\ntar -xvzf pair37_2.tar.gz -C pair37\n\n# Some cleanup because the tars were not consistently structured\nmv pair00/pair00/* pair00/\nmv pair01/pair01/* pair01/\nmv pair02/pair02/* pair02/\nmv pair09/pair09/* pair09/\nmv pair10/pair10/* pair10/\nmv pair12/pair12/* pair12/\nmv pair13/pair13/* pair13/\nmv pair14/pair14/* pair14/\nmv pair15/pair15/* pair15/\nmv pair16/pair16/* pair16/\nmv pair17/pair17/* pair17/\nmv pair18/pair18/* pair18/\nmv pair19/pair19/* pair19/\nmv pair21/pair21/* pair21/\nmv pair22/pair22/* pair22/\nmv pair23/pair23/* pair23/\nmv pair27/pair27/* pair27/\nmv pair28/pair28/* pair28/\nmv pair32/pair32/* pair32/\nmv pair37/pair37/* pair37/\nrm -rf pair*/pair*/\n```\n\nWith the data downloaded, you can run the script: `python -m scripts.hi4d_preprocessing`.\n\"\"\"\nfrom mvtracker.datasets.utils import transform_scene\n\n\ndef load_pickle(p):\n    with open(p, \"rb\") as f:\n        return pickle.load(f)\n\n\nimport glob\nimport os\nimport pickle\nfrom typing import Optional, Dict, List, Tuple\n\nimport cv2\nimport numpy as np\nimport rerun as rr\nimport torch\nimport tqdm\nfrom PIL import Image\nfrom pytorch3d.io import load_objs_as_meshes\nfrom pytorch3d.renderer import (\n    PerspectiveCameras,\n    MeshRasterizer,\n    RasterizationSettings,\n)\nfrom pytorch3d.structures import Meshes\nfrom 
scipy.spatial.transform import Rotation\n\n\ndef save_pickle(p, data):\n    os.makedirs(os.path.dirname(p), exist_ok=True)\n    with open(p, \"wb\") as f:\n        pickle.dump(data, f)\n\n\ndef load_image(path):\n    return np.array(Image.open(path))\n\n\ndef _safe_load_rgb_cameras(npz_path: str) -> Dict[str, np.ndarray]:\n    \"\"\"\n    Hi4D has a typo in docs ('intirnsics'). Support both.\n    Returns dict with keys: ids [N], intrinsics [N,3,3], extrinsics [N,3,4], dist_coeffs [N,5]\n    \"\"\"\n    data = dict(np.load(npz_path))\n    ids = data.get(\"ids\")\n    intr = data.get(\"intrinsics\", data.get(\"intirnsics\"))\n    extr = data.get(\"extrinsics\")\n    dist = data.get(\"dist_coeffs\")\n    assert ids is not None and intr is not None and extr is not None, \\\n        f\"Missing keys in {npz_path}. Found keys: {list(data.keys())}\"\n    return {\"ids\": ids, \"intrinsics\": intr, \"extrinsics\": extr, \"dist_coeffs\": dist}\n\n\ndef _find_all_frames_for_action(images_root: str, cam_ids: List[int]) -> List[int]:\n    \"\"\"\n    Robustly infer the list of frame indices by intersecting the available frames across cams.\n    Hi4D names images as 000XXX.jpg (zero-padded 6).\n    \"\"\"\n    per_cam_sets = []\n    for cid in cam_ids:\n        cam_dir = os.path.join(images_root, f\"{cid}\")\n        jpgs = sorted(glob.glob(os.path.join(cam_dir, \"*.jpg\")))\n        frames = set(int(os.path.splitext(os.path.basename(p))[0]) for p in jpgs)\n        per_cam_sets.append(frames)\n    if not per_cam_sets:\n        return []\n    common = set.intersection(*per_cam_sets) if len(per_cam_sets) > 1 else per_cam_sets[0]\n    return sorted(list(common))\n\n\ndef _mesh_path_for_frame(frames_dir: str, frame_idx: int) -> str:\n    \"\"\"\n    Hi4D meshes are 'mesh-f00XXX.obj' (5 digits). 
We'll format with 5 digits.\n    \"\"\"\n    return os.path.join(frames_dir, f\"mesh-f{frame_idx:05d}.obj\")\n\n\ndef extract_hi4d_action_to_pkl(\n        dataset_root: str,\n        pair: str,\n        action: str,\n        save_pkl_path: str,\n        downscaled_longerside: Optional[int] = None,\n        save_rerun_viz: bool = True,\n        stream_rerun_viz: bool = False,\n        skip_if_output_exists: bool = False,\n):\n    \"\"\"\n    Build a single .pkl for a (pair, action):\n      - rgbs:  dict[cam_id_str] -> [T,3,H,W] uint8\n      - intrs: dict[cam_id_str] -> [3,3] float32  (scaled if resized)\n      - extrs: dict[cam_id_str] -> [3,4] float32\n      - depths:dict[cam_id_str] -> [T,H,W] float32  (mesh-rendered)\n      - ego_cam_name: None\n    \"\"\"\n    if skip_if_output_exists and os.path.exists(save_pkl_path):\n        print(f\"Skipping {save_pkl_path} (exists).\")\n        return save_pkl_path\n    print(f\"Processing {pair}/{action} -> {save_pkl_path}\")\n\n    root = os.path.join(dataset_root, pair, action)\n    frames_dir = os.path.join(root, \"frames\")\n    images_dir = os.path.join(root, \"images\")\n    cameras_npz = os.path.join(root, \"cameras\", \"rgb_cameras.npz\")\n    meta_npz = os.path.join(root, \"meta.npz\")\n\n    cams = _safe_load_rgb_cameras(cameras_npz)\n    cam_ids: List[int] = list(map(int, cams[\"ids\"]))  # e.g., [4,16,28,40,52,64,76,88]\n    intr_all = cams[\"intrinsics\"].astype(np.float32)  # [N,3,3]\n    extr_all = cams[\"extrinsics\"].astype(np.float32)  # [N,3,4]\n\n    meta = dict(np.load(meta_npz))\n    frame_ids = _find_all_frames_for_action(images_dir, cam_ids)\n    assert len(frame_ids) > 0, f\"No common frames found across cameras at {images_dir}\"\n    assert frame_ids[0] == meta[\"start\"].item()\n    assert frame_ids[-1] == meta[\"end\"].item()\n    assert len(frame_ids) == (meta[\"end\"].item() - meta[\"start\"].item() + 1)\n\n    # Build containers\n    rgbs: Dict[str, List[np.ndarray]] = {str(cid): [] for cid in cam_ids}\n    depths: Dict[str, List[np.ndarray]] = {str(cid): [] for cid in cam_ids}\n    intrs: Dict[str, np.ndarray] = {}\n    extrs: Dict[str, np.ndarray] = {}\n\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n    # Pre-load a single mesh per frame and rasterize to each camera\n    # (This is typically faster than reloading the mesh V times.)\n    raster_settings_cache: Dict[Tuple[int, int], RasterizationSettings] = {}\n\n    for frame in tqdm.tqdm(frame_ids, desc=f\"Frames {pair}/{action}\"):\n        mesh_path = _mesh_path_for_frame(frames_dir, frame)\n        if not os.path.isfile(mesh_path):\n            # Some sequences may use different padding; try 6 digits as fallback.\n            alt = os.path.join(frames_dir, f\"mesh-f{frame:06d}.obj\")\n            if os.path.isfile(alt):\n                mesh_path = alt\n            else:\n                # Skip missing mesh frame\n                continue\n\n        # Load mesh (geometry only is enough for depth)\n        meshes: Meshes = load_objs_as_meshes([mesh_path], device=device)\n\n        # For each camera, render depth & collect RGB\n        for i, cid in enumerate(cam_ids):\n            cam_name = str(cid)\n            img_path = os.path.join(images_dir, cam_name, f\"{frame:06d}.jpg\")\n            if not os.path.isfile(img_path):\n                # Skip if that particular view is missing the image for this frame\n                continue\n\n            image = load_image(img_path)\n            h0, w0 = image.shape[:2]\n\n            
# Copy camera params\n            K = intr_all[i].copy()  # [3,3]\n            E = extr_all[i].copy()  # [3,4]  world->cam (Hi4D)\n\n            # Optional downscale (longer side) + scale intrinsics\n            if downscaled_longerside is not None:\n                scale = downscaled_longerside / float(max(h0, w0))\n                nh, nw = int(round(h0 * scale)), int(round(w0 * scale))\n                if (nh, nw) != (h0, w0):\n                    image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA)\n                    K[:2] *= scale\n                h, w = nh, nw\n            else:\n                h, w = h0, w0\n\n            # Stash static intr/extr once (raw, no global transform)\n            if cam_name not in intrs:\n                intrs[cam_name] = K.astype(np.float32)\n                extrs[cam_name] = E.astype(np.float32)\n\n            rgbs[cam_name].append(image)\n\n            # Build PyTorch3D camera from raw E\n            fx, fy = K[0, 0], K[1, 1]\n            cx, cy = K[0, 2], K[1, 2]\n\n            R = E[:3, :3]\n            t = E[:3, 3]\n\n            # 4D-DRESS convention: transpose + flip X/Y\n            R = R.T\n            R = R @ np.diag(np.array([-1.0, -1.0, 1.0], dtype=np.float32))\n            t = t @ np.diag(np.array([-1.0, -1.0, 1.0], dtype=np.float32))\n\n            cameras_p3d = PerspectiveCameras(\n                focal_length=torch.tensor([[fx, fy]], dtype=torch.float32, device=device),\n                principal_point=torch.tensor([[cx, cy]], dtype=torch.float32, device=device),\n                R=torch.tensor(R, dtype=torch.float32, device=device).unsqueeze(0),\n                T=torch.tensor(t, dtype=torch.float32, device=device).unsqueeze(0),\n                image_size=torch.tensor([[h, w]], dtype=torch.float32, device=device),\n                in_ndc=False,\n                device=device,\n            )\n\n            # Rasterize (no global transform on mesh here)\n            rs_key = (h, w)\n            if rs_key not in raster_settings_cache:\n                raster_settings_cache[rs_key] = RasterizationSettings(\n                    image_size=(h, w),\n                    blur_radius=0.0,\n                    faces_per_pixel=1,\n                    bin_size=0,\n                )\n            rasterizer = MeshRasterizer(cameras=cameras_p3d, raster_settings=raster_settings_cache[rs_key])\n            fragments = rasterizer(meshes)\n\n            # faces_per_pixel=1 -> (1,H,W,1) -> (H,W)\n            zbuf = fragments.zbuf[0, ..., 0].detach().cpu().numpy()\n            zbuf = np.nan_to_num(zbuf, nan=0.0)\n\n            depths[cam_name].append(zbuf.astype(np.float32))\n\n    # Stack per-camera data\n    cam_names = sorted(rgbs.keys(), key=lambda s: int(s))\n    for cam_name in cam_names:\n        if len(rgbs[cam_name]) == 0:\n            # Camera had no valid frames (skip)\n            del intrs[cam_name], extrs[cam_name], rgbs[cam_name], depths[cam_name]\n            continue\n        rgbs[cam_name] = np.stack(rgbs[cam_name]).transpose(0, 3, 1, 2).astype(np.uint8)  # [T,3,H,W]\n        depths[cam_name] = np.stack(depths[cam_name]).astype(np.float32)  # [T,H,W]\n\n    # Basic shape checks (use first cam as reference)\n    kept_cams = sorted(rgbs.keys(), key=lambda s: int(s))\n    assert len(kept_cams) > 0, \"No cameras with data.\"\n    n_frames, _, h, w = rgbs[kept_cams[0]].shape\n    for cam_name in kept_cams:\n        assert rgbs[cam_name].shape == (n_frames, 3, h, w)\n        assert intrs[cam_name].shape == (3, 3)\n        assert 
extrs[cam_name].shape == (3, 4)\n        assert depths[cam_name].shape == (n_frames, h, w)\n\n    # Rotate the scene to have the ground at z=0\n    rot_x = Rotation.from_euler('x', 90, degrees=True).as_matrix()\n    rot_y = Rotation.from_euler('y', 0, degrees=True).as_matrix()\n    rot_z = Rotation.from_euler('z', 0, degrees=True).as_matrix()\n    rot = torch.from_numpy(rot_z @ rot_y @ rot_x)\n    translation = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)\n    for cam_name in kept_cams:\n        E = torch.from_numpy(extrs[cam_name][None, None])  # [1,1,3,4]\n        E_tx = transform_scene(1, rot, translation, None, E, None, None, None)[1]\n        extrs[cam_name] = E_tx[0, 0].numpy()\n\n    # Save\n    save_pickle(save_pkl_path, dict(\n        rgbs=rgbs,\n        intrs=intrs,\n        extrs=extrs,\n        depths=depths,\n        ego_cam_name=None,\n    ))\n\n    # Visualize the data sample using rerun\n    rerun_modes = []\n    if stream_rerun_viz:\n        rerun_modes += [\"stream\"]\n    if save_rerun_viz:\n        rerun_modes += [\"save\"]\n    for rerun_mode in rerun_modes:\n        rr.init(f\"3dpt\", recording_id=\"v0.16\")\n        if rerun_mode == \"stream\":\n            rr.connect_tcp()\n\n        rr.log(\"world\", rr.ViewCoordinates.RIGHT_HAND_Z_UP, static=True)\n        rr.set_time_seconds(\"frame\", 0)\n        rr.log(\n            \"world/xyz\",\n            rr.Arrows3D(\n                vectors=[[1, 0, 0], [0, 2, 0], [0, 0, 3]],\n                colors=[[255, 0, 0], [0, 255, 0], [0, 0, 255]],\n            ),\n        )\n\n        mesh_vertices = meshes._verts_list[0].cpu()\n        mesh_faces = meshes._faces_list[0].cpu()\n        mesh_vertices = transform_scene(1, rot, translation, None, None, None, mesh_vertices[None], None)[3][0]\n        rr.log(\n            \"mesh\",\n            rr.Mesh3D(\n                vertex_positions=mesh_vertices.numpy().astype(np.float32),  # (N, 3)\n                triangle_indices=mesh_faces.numpy().reshape(-1, 3).astype(np.int32),  # (M, 3)\n                albedo_factor=[200, 200, 255],  # Optional color\n            ),\n        )\n\n        fps = 30\n        for frame_idx in range(n_frames):\n            rr.set_time_seconds(\"frame\", frame_idx / fps)\n            for cam_name in cam_names:\n                extr = extrs[cam_name]\n                intr = intrs[cam_name]\n                img = rgbs[cam_name][frame_idx].transpose(1, 2, 0).astype(np.uint8)\n                depth = depths[cam_name][frame_idx]\n\n                h, w = img.shape[:2]\n                fx, fy = intr[0, 0], intr[1, 1]\n                cx, cy = intr[0, 2], intr[1, 2]\n\n                # Camera pose\n                T = np.eye(4)\n                T[:3, :] = extr\n                world_T_cam = np.linalg.inv(T)\n                rr.log(f\"{cam_name}/image\", rr.Transform3D(\n                    translation=world_T_cam[:3, 3],\n                    mat3x3=world_T_cam[:3, :3],\n                ))\n                rr.log(f\"{cam_name}/image\", rr.Pinhole(\n                    image_from_camera=intr,\n                    width=w,\n                    height=h\n                ))\n                rr.log(f\"{cam_name}/image\", rr.Image(img))\n\n                rr.log(f\"{cam_name}/depth\", rr.Transform3D(\n                    translation=world_T_cam[:3, 3],\n                    mat3x3=world_T_cam[:3, :3],\n                ))\n                rr.log(f\"{cam_name}/depth\", rr.Pinhole(\n                    image_from_camera=intr,\n                    width=w,\n              
      height=h\n                ))\n                rr.log(f\"{cam_name}/depth\", rr.DepthImage(depth, meter=1.0, colormap=\"viridis\"))\n\n                # Unproject depth to point cloud\n                y, x = np.meshgrid(np.arange(h), np.arange(w), indexing=\"ij\")\n                z = depth\n                valid = z > 0\n                x = x[valid]\n                y = y[valid]\n                z = z[valid]\n\n                X = (x - cx) * z / fx\n                Y = (y - cy) * z / fy\n                pts_cam = np.stack([X, Y, z], axis=-1)\n\n                # Transform to world\n                R = world_T_cam[:3, :3]\n                t = world_T_cam[:3, 3]\n                pts_world = pts_cam @ R.T + t\n\n                # Color\n                colors = img[y, x]\n\n                rr.log(f\"point_cloud/{cam_name}\", rr.Points3D(positions=pts_world, colors=colors))\n\n        if rerun_mode == \"save\":\n            base, name = os.path.split(save_pkl_path)\n            name_no_ext = os.path.splitext(name)[0]\n            save_rrd_path = os.path.join(base, f\"rerun__{name_no_ext}.rrd\")\n            rr.save(save_rrd_path)\n            print(f\"Saved rerun viz to {os.path.abspath(save_rrd_path)}\")\n\n    print(f\"Done with {save_pkl_path}.\")\n    print()\n\n\nif __name__ == \"__main__\":\n    dataset_root = \"datasets/hi4d\"\n    output_root = \"datasets/hi4d-processed\"\n\n    longside_resolution: Optional[int] = 512\n    if longside_resolution is not None:\n        output_root += f\"-resized-{longside_resolution}\"\n    os.makedirs(output_root, exist_ok=True)\n\n    pairs = [\n        \"pair00\", \"pair01\", \"pair02\", \"pair09\", \"pair10\",\n        \"pair12\", \"pair13\", \"pair14\", \"pair15\", \"pair16\",\n        \"pair17\", \"pair18\", \"pair19\", \"pair21\", \"pair22\",\n        \"pair23\", \"pair27\", \"pair28\", \"pair32\", \"pair37\"\n    ]\n\n    # Enumerate actions per pair automatically\n    for pair in tqdm.tqdm(pairs, desc=\"Pairs\"):\n        pair_dir = os.path.join(dataset_root, pair)\n        assert os.path.isdir(pair_dir)\n        actions = sorted([\n            d for d in os.listdir(pair_dir)\n            if os.path.isdir(os.path.join(pair_dir, d)) and not d.startswith(\".\")\n        ])\n\n        for action in tqdm.tqdm(actions, desc=f\"{pair} actions\", leave=False):\n            out_pkl = os.path.join(output_root, f\"{pair}__{action}.pkl\")\n            extract_hi4d_action_to_pkl(\n                dataset_root=dataset_root,\n                pair=pair,\n                action=action,\n                save_pkl_path=out_pkl,\n                downscaled_longerside=longside_resolution,\n                save_rerun_viz=True,\n                stream_rerun_viz=False,\n                skip_if_output_exists=True,\n            )\n"
  },
  {
    "path": "scripts/merge_comparison_mp4s.py",
    "content": "\"\"\"\nMerge MP4 files of different methods into a single side-by-side comparison,\nadding a small text bar for each method using Pillow + ImageClip\ninstead of MoviePy's TextClip (which requires ImageMagick).\n\nUsage: python merge_comparison_mp4s.py\n\"\"\"\n\nimport os\n\nimport numpy as np\nfrom PIL import Image, ImageDraw, ImageFont\nfrom moviepy.editor import (\n    VideoFileClip,\n    ImageClip,\n    clips_array,\n    CompositeVideoClip\n)\n\n\ndef create_title_image(text, width, height=50, bg_color=(255, 255, 255)):\n    \"\"\"\n    Creates a PIL Image of size (width x height) with the given text, centered.\n    Returns a NumPy array (H x W x 3).\n    \"\"\"\n    # Create a blank RGB image\n    img = Image.new(\"RGB\", (width, height), color=bg_color)\n    draw = ImageDraw.Draw(img)\n\n    # Choose a default font. If you have a TTF file, specify it here:\n    font = ImageFont.truetype(\"times_new_roman.ttf\", size=36)\n    # font = ImageFont.truetype(\"/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf\", size=24)\n    # If you don't have a TTF file handy, ImageFont.load_default() is the fallback:\n    # font = ImageFont.load_default()\n\n    text_w, text_h = draw.textsize(text, font=font)\n    x = (width - text_w) // 2\n    y = (height - text_h) // 2\n    draw.text((x, y), text, fill=(0, 0, 0), font=font)\n\n    return np.array(img)\n\n\ndef merge_mp4s(mp4s_title_to_path_dict, merged_mp4_output_path, num_columns):\n    \"\"\"\n    Merges each input MP4 (which presumably has a 'first column' or 'second column'\n    that you want to extract) into a side-by-side comparison video, arranged in\n    multiple rows if num_columns < number_of_videos, AND places each method's\n    title bar above its own clip.\n\n    :param mp4s_title_to_path_dict: dict of {title: path_to_video}\n    :param merged_mp4_output_path: output MP4 path\n    :param num_columns: number of clips to display per row\n    \"\"\"\n    titles = list(mp4s_title_to_path_dict.keys())\n    raw_clips = []\n\n    # 1) Load each video and crop the relevant half-column\n    for title in titles:\n        path = mp4s_title_to_path_dict[title]\n        if not os.path.exists(path):\n            raise FileNotFoundError(f\"Video file not found: {path}\")\n        clip = VideoFileClip(path)\n\n        w, h = clip.size  # (width, height)\n        if \"GT\" in title:\n            # Crop the first column\n            sub_clip = clip.crop(x1=0, x2=w // 2, y1=0, y2=h)\n        else:\n            # Crop the second column\n            sub_clip = clip.crop(x1=w // 2, x2=w, y1=0, y2=h)\n\n        raw_clips.append((title, sub_clip))\n\n    # 2) For each sub-clip, create a small \"title bar\" on top\n    #    so each method has its own label above its clip.\n    bar_height = 50\n    titled_clips = []\n    for (title, subclip) in raw_clips:\n        # Create a bar image for the subclip width\n        title_img_array = create_title_image(title, subclip.w, bar_height)\n        title_iclip = ImageClip(title_img_array, duration=subclip.duration)\n\n        # Shift subclip downward by bar_height\n        subclip_shifted = subclip.set_position((0, bar_height))\n\n        # Composite them vertically: [title bar on top, subclip below]\n        comp_h = bar_height + subclip.h\n        comp_w = subclip.w\n        composite = CompositeVideoClip(\n            [title_iclip, subclip_shifted],\n            size=(comp_w, comp_h)\n        )\n\n        titled_clips.append(composite)\n\n    # 3) Normalize all titled_clips to the same height if 
they differ.\n    import math\n    min_height = min(tc.h for tc in titled_clips)\n    normalized_clips = []\n    for tc in titled_clips:\n        if tc.h != min_height:\n            scale = min_height / tc.h\n            new_w = int(tc.w * scale)\n            resized = tc.resize((new_w, min_height))\n            normalized_clips.append(resized)\n        else:\n            normalized_clips.append(tc)\n\n    # 4) Arrange the normalized clips in rows of length `num_columns`.\n    n = len(normalized_clips)\n    n_rows = math.ceil(n / num_columns)\n    rows = []\n    idx = 0\n    for _ in range(n_rows):\n        row_clips = normalized_clips[idx: idx + num_columns]\n        rows.append(row_clips)\n        idx += num_columns\n\n    # 5) Stack them using clips_array\n    final_clip = clips_array(rows)\n\n    # 6) Write to output\n    final_clip.write_videofile(\n        merged_mp4_output_path,\n        fps=12,\n        codec=\"libx264\",\n        threads=4  # adjust as needed\n    )\n    print(f\"✅ Merged video saved successfully to {merged_mp4_output_path}\")\n\n\nif __name__ == '__main__':\n    for selection in [\"A\", \"B\", \"C\"]:\n        if selection == \"A\":\n            datasets_seq = [\n                *[(\"kubric-multiview-v3-views0123-novelviews4\", seq) for seq in [0, 3, 4, 5]],\n                *[(\"panoptic-multiview-views1_7_14_20-novelviews24\", seq) for seq in [0, 3, 4, 5]],\n                *[(\"panoptic-multiview-views1_7_14_20-novelviews27\", seq) for seq in [0, 3, 4, 5]],\n            ]\n            mp4s = {\n                \"GT\": \"logs/cotracker3-online/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4\",\n                \"Dynamic 3DGS\": \"logs/dynamic_3dgs/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4\",\n                \"Shape of Motion\": \"logs/shape_of_motion/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4\",\n                \"LocoTrack\": \"logs/locotrack/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4\",\n                \"CoTracker3\": \"logs/cotracker3-online/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4\",\n                \"DELTA\": \"logs/delta/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4\",\n                # \"SpaTracker-1\": \"logs/kubric_v3_duster0123/multiview-adapter-pretrained-001/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4\",\n                \"SpaTracker\": \"logs/kubric_v3/multiview-adapter-002/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_69799.mp4\",\n                # \"SpaTracker-3\": \"logs/kubric_v3_duster0123/multiview-adapter-001/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_90799.mp4\",\n                \"Triplane Baseline\": \"logs/kubric_v3/multiview-v1-with-128-triplane-001/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_99999.mp4\",\n                # \"Triplane-2\": \"logs/kubric_v3_duster0123/multiview-v1-with-256-triplane-001/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_99999.mp4\",\n                \"Ours\": \"logs/kubric_v3_augs/mvtracker/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_159999.mp4\",\n            }\n        elif selection == \"B\":\n            datasets_seq = [\n                *[(\"dex-ycb-multiview-duster0123-novelviews4\", seq) for seq in [0, 3, 4, 5]],\n                *[(\"dex-ycb-multiview-duster0123-novelviews5\", seq) for seq in [0, 3, 4, 5]],\n                *[(\"dex-ycb-multiview-duster0123-novelviews6\", seq) for seq in [0, 3, 4, 5]],\n                
*[(\"dex-ycb-multiview-duster0123-novelviews7\", seq) for seq in [0, 3, 4, 5]],\n            ]\n            mp4s = {\n                \"GT\": \"logs/cotracker3-online/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4\",\n                \"Dynamic 3DGS\": \"logs/dynamic_3dgs/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4\",\n                \"Shape of Motion\": \"logs/shape_of_motion/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4\",\n                \"LocoTrack\": \"logs/locotrack/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4\",\n                \"CoTracker3\": \"logs/cotracker3-online/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4\",\n                \"DELTA\": \"logs/delta/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4\",\n                # \"SpaTracker-1\": \"logs/kubric_v3_duster0123/multiview-adapter-pretrained-001/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4\",\n                # \"SpaTracker-2\": \"logs/kubric_v3/multiview-adapter-002/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_69799.mp4\",\n                \"SpaTracker\": \"logs/kubric_v3_duster0123/multiview-adapter-001/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_90799.mp4\",\n                # \"Triplane-1\": \"logs/kubric_v3/multiview-v1-with-128-triplane-001/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_99999.mp4\",\n                \"Triplane Baseline\": \"logs/kubric_v3_duster0123/multiview-v1-with-256-triplane-001/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_99999.mp4\",\n                \"Ours\": \"logs/kubric_v3_augs/mvtracker/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_159999.mp4\",\n            }\n        elif selection == \"C\":\n            datasets_seq = [\n                *[(\"dex-ycb-multiview-duster2345-novelviews7\", seq) for seq in [0, 3, 4, 5]],\n                *[(\"dex-ycb-multiview-duster4567-novelviews7\", seq) for seq in [0, 3, 4, 5]],\n                *[(\"dex-ycb-multiview-duster4567-novelviews0\", seq) for seq in [0, 3, 4, 5]],\n            ]\n            mp4s = {\n                \"GT\": \"logs/cotracker3-online/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4\",\n                # \"Dynamic 3DGS\": \"logs/dynamic_3dgs/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4\",\n                # \"Shape of Motion\": \"logs/shape_of_motion/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4\",\n                # \"LocoTrack\": \"logs/locotrack/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4\",\n                \"CoTracker3\": \"logs/cotracker3-online/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4\",\n                \"DELTA\": \"logs/delta/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4\",\n                # \"SpaTracker-1\": \"logs/kubric_v3_duster0123/multiview-adapter-pretrained-001/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_-1.mp4\",\n                # \"SpaTracker-2\": \"logs/kubric_v3/multiview-adapter-002/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_69799.mp4\",\n                \"SpaTracker\": \"logs/kubric_v3_duster0123/multiview-adapter-001/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_90799.mp4\",\n                # \"Triplane-1\": \"logs/kubric_v3/multiview-v1-with-128-triplane-001/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_99999.mp4\",\n                \"Triplane Baseline\": 
\"logs/kubric_v3_duster0123/multiview-v1-with-256-triplane-001/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_99999.mp4\",\n                \"Ours\": \"logs/kubric_v3_augs/mvtracker/eval_{dataset}/comparison_v4b-novel__seq-{seq}_step_159999.mp4\",\n            }\n        else:\n            raise ValueError(f\"Invalid selection: {selection}\")\n\n        for dataset, seq in datasets_seq:\n            mp4s_title_to_path_dict = {\n                key: path.format(dataset=dataset, seq=seq)\n                for key, path in mp4s.items()\n            }\n            if not mp4s_title_to_path_dict:\n                print(f\"⚠️ Warning: No valid MP4 files found for dataset {dataset} seq {seq}. Skipping...\")\n                continue\n            merged_mp4 = f\"logs/comparison_v4__{dataset}__seq-{seq}.mp4\"\n            merge_mp4s(mp4s_title_to_path_dict, merged_mp4, num_columns=3)\n"
  },
  {
    "path": "scripts/panoptic_studio_preprocessing.py",
    "content": "\"\"\"\nThis script will convert the Panoptic Studio subset of TAPVid-3D to multi-view 3D point tracking dataset.\n\nFirst, follow the instructions at https://github.com/google-deepmind/tapnet/tree/main/tapnet/tapvid3d\nto download the raw panoptic studio data, for example, as follows:\n```bash\n# Set up a temporary environment\nconda create -n panoptic-preprocessing python=3.10.12 -y\nconda activate panoptic-preprocessing\npip install \"git+https://github.com/google-deepmind/tapnet.git#egg=tapnet[tapvid3d_eval,tapvid3d_generation]\"\n\n# Download the raw data\npython -m tapnet.tapvid3d.annotation_generation.generate_pstudio --output_dir datasets/panoptic_studio_tapvid3d\nmkdir datasets/panoptic-multiview\nmv datasets/panoptic_studio_tapvid3d/tmp/data/* datasets/panoptic-multiview/\n\n# If you like, you can remove the temporary environment now\nconda deactivate\nconda env remove -n panoptic-preprocessing\n```\n\nFollowing https://github.com/JonathonLuiten/Dynamic3DGaussians#run-visualizer-on-pretrained-models,\ndownload and unzip the pretrained Dynamic3DGS checkpoints, e.g. as follows:\n```bash\nwget https://omnomnom.vision.rwth-aachen.de/data/Dynamic3DGaussians/output.zip -O checkpoints/output.zip\nunzip checkpoints/output.zip -d checkpoints/\nrm checkpoints/output.zip\nmv checkpoints/output/pretrained checkpoints/dynamic3dgs_pretrained\n```\n\nInstall the missing dependencies needed by Dynamic3DGS:\n```bash\nconda activate 3dpt\nconda install -c conda-forge gcc_linux-64=11.3.0 gxx_linux-64=11.3.0 gcc=11.3.0 gxx=11.3.0 -y\npip install git+https://github.com/JonathonLuiten/diff-gaussian-rasterization-w-depth.git\npip install open3d==0.16.0\n```\n\nNow you can run this script to generate the Dynamic3DGS depths and merge the TAP-Vid3D annotations:\n```bash\npython -m scripts.panoptic_studio_preprocessing \\\n  --dataset_root ./datasets/panoptic-multiview \\\n  --checkpoint_root ./checkpoints/dynamic3dgs_pretrained \\\n  --tapvid3d_root ./datasets/panoptic_studio_tapvid3d\n```\n\nThe processed dataset is now stored in ./datasets/panoptic-multiview.\nIf you'd like, you can remove the raw tapvid3d data now to save space:\n```bash\nrm -rf ./datasets/panoptic_studio_tapvid3d\n```\n\"\"\"\n\nimport argparse\nfrom pathlib import Path\nfrom tqdm import tqdm\n\nfrom mvtracker.models.core.dynamic3dgs.export_depths_from_pretrained_checkpoint import export_depth\nfrom mvtracker.models.core.dynamic3dgs.merge_tapvid3d_per_camera_annotations import merge_annotations\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Preprocess Panoptic Studio TAPVid-3D subset.\")\n    parser.add_argument(\"--dataset_root\", type=Path, required=True,\n                        help=\"Root path to Panoptic Studio dataset (per-sequence folders).\")\n    parser.add_argument(\"--checkpoint_root\", type=Path, required=True,\n                        help=\"Root path to Dynamic3DGS pretrained checkpoints (per-sequence).\")\n    parser.add_argument(\"--tapvid3d_root\", type=Path, required=True,\n                        help=\"Root path to TAPVid-3D annotations for Panoptic Studio.\")\n    return parser.parse_args()\n\n\nif __name__ == '__main__':\n    args = parse_args()\n    sequences = [\"basketball\", \"boxes\", \"football\", \"juggle\", \"softball\", \"tennis\"]\n\n    print(\"Exporting depths from pretrained checkpoints\")\n    for sequence_name in tqdm(sequences):\n        scene_root = args.dataset_root / sequence_name\n        output_path = scene_root / \"dynamic3dgs_depth\"\n        
checkpoint_path = args.checkpoint_root / sequence_name\n        export_depth(scene_root, output_path, checkpoint_path)\n\n    print(\"Merging TAP-Vid3D per-camera annotations.\")\n    for sequence_name in tqdm(sequences):\n        scene_root = args.dataset_root / sequence_name\n        checkpoint_path = args.checkpoint_root / sequence_name\n        tapvid3d_annotation_paths = list(args.tapvid3d_root.glob(f\"{sequence_name}_*.npz\"))\n        merge_annotations(\n            scene_root,\n            checkpoint_path,\n            tapvid3d_annotation_paths,\n            skip_if_output_already_exists=True,\n            rerun_logging=True\n        )\n"
  },
  {
    "path": "scripts/plot_aj_for_varying_depth_noise_levels.py",
    "content": "import os\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport seaborn as sns\n\n\n# set_size from https://jwalton.info/Embed-Publication-Matplotlib-Latex/\ndef set_size(width, fraction=1, golden_ratio=(5 ** .5 - 1) / 2):\n    \"\"\"Set figure dimensions to avoid scaling in LaTeX.\n\n    Parameters\n    ----------\n    width: float\n            Document textwidth or columnwidth in pts\n    fraction: float, optional\n            Fraction of the width which you wish the figure to occupy\n\n    Returns\n    -------\n    fig_dim: tuple\n            Dimensions of figure in inches\n    \"\"\"\n    # Width of figure (in pts)\n    fig_width_pt = width * fraction\n\n    # Convert from pt to inches\n    inches_per_pt = 1 / 72.27\n\n    # Golden ratio to set aesthetic figure height\n    # https://disq.us/p/2940ij3\n    # golden_ratio = (5 ** .5 - 1) / 2\n\n    # Figure width in inches\n    fig_width_in = fig_width_pt * inches_per_pt\n    # Figure height in inches\n    fig_height_in = fig_width_in * golden_ratio\n\n    fig_dim = (fig_width_in, fig_height_in)\n\n    return fig_dim\n\n\ndef setup_plot():\n    sns.set_theme(style=\"whitegrid\")\n    sns.set_palette(\"tab10\")\n    plt.rcParams[\"font.family\"] = \"Times New Roman\"\n    plt.rcParams['font.weight'] = 'normal'\n\n\ndef plot_aj(\n        save_name='plot_robustness_to_depth_noise.pdf',\n        width_in_paper_pts=237.13594,  # \\showthe\\linewidth --> > 237.13594pt.\n        linewidth=1.5,\n        marker_size=5,\n        label_font_size=9,\n        tick_font_size=9,\n        legend_font_size=7,\n        dpi=400,\n        results_dir=None,\n        save_svg=False,\n):\n    setup_plot()\n\n    fig, ax = plt.subplots(figsize=set_size(width_in_paper_pts, golden_ratio=0.3), dpi=dpi)\n\n    x_labels = ['0', '1', '2', '5', '10', '20', '50', '100', '200']\n    x = np.arange(len(x_labels))\n    # x = np.array([0, 1, 2, 5, 10, 20, 50, 100, 200])\n    # x_labels = ['0', '1', '2', '5', '10', '20', '50', '100', '200']\n\n    results = {\n        \"Ours\": [81.6, 80.7, 77.4, 69.8, 63.1, 59.3, 56.1, 54.3, 52.8],\n        \"Triplane\": [75.4, 75.0, 73.7, 69.2, 63.4, 57.4, 51.5, 49.1, 47.6],\n        \"SpaTracker\": [65.5, 63.8, 62.1, 58.7, 55.8, 52.6, 48.6, 45.4, 43.3],\n        \"DELTA\": [57.4, 51.8, 46.2, 34.3, 23.8, 13.2, 5.0, 2.3, 1.0],\n    }\n\n    for label, y in results.items():\n        sns.lineplot(x=x, y=y, ax=ax, linewidth=linewidth, marker='o', markersize=marker_size, label=label)\n    # ax.axhline(y=47.2, color=sns.color_palette(\"tab10\")[1], linestyle='--', linewidth=1.5, label='Blind Baseline')\n\n    ax.set_xticks(x)\n    ax.set_xticklabels(x_labels)\n    ax.set_yticks(np.arange(40, 90, 10))\n    ax.set_ylim([40, 83])\n    ax.tick_params(axis='both', which='major', labelsize=tick_font_size)\n\n    ax.set_xlabel('Depth Noise (σ, in cm)', fontsize=label_font_size, fontweight='normal', labelpad=0)\n    ax.set_ylabel('AJ', fontsize=label_font_size, fontweight='normal', labelpad=2)\n\n    for spine in ax.spines.values():\n        spine.set_color('black')\n\n    ax.grid(axis='y', color='lightgrey')\n\n    ax.tick_params(axis=\"y\", direction=\"in\")\n    ax.tick_params(axis=\"x\", direction=\"in\")\n\n    legend = plt.legend(\n        frameon=True,\n        fancybox=False,\n        loc=(0.675, 0.265),\n        # loc=\"upper right\",\n        prop={'size': legend_font_size},\n        handletextpad=0.2,\n        labelspacing=0.1,\n    )\n    # legend.get_frame().set_facecolor('white')\n    # 
legend.get_frame().set_edgecolor('black')\n\n    plt.tight_layout(pad=0)\n\n    if save_name:\n        if results_dir:\n            os.makedirs(results_dir, exist_ok=True)\n            save_name = os.path.join(results_dir, save_name)\n\n        plt.savefig(save_name, bbox_inches='tight', pad_inches=0)\n        if save_svg:\n            plt.savefig(save_name.replace('.pdf', '.svg'), bbox_inches='tight', pad_inches=0)\n\n    plt.show()\n\n\nif __name__ == '__main__':\n    plot_aj()\n"
  },
  {
    "path": "scripts/plot_aj_for_varying_n_of_views.py",
    "content": "import os\n\nimport matplotlib.pyplot as plt\nimport matplotlib.ticker as ticker\nimport numpy as np\nimport seaborn as sns\n\n\n# set_size from https://jwalton.info/Embed-Publication-Matplotlib-Latex/\ndef set_size(width, fraction=1):\n    \"\"\"Set figure dimensions to avoid scaling in LaTeX.\n\n    Parameters\n    ----------\n    width: float\n            Document textwidth or columnwidth in pts\n    fraction: float, optional\n            Fraction of the width which you wish the figure to occupy\n\n    Returns\n    -------\n    fig_dim: tuple\n            Dimensions of figure in inches\n    \"\"\"\n    # Width of figure (in pts)\n    fig_width_pt = width * fraction\n\n    # Convert from pt to inches\n    inches_per_pt = 1 / 72.27\n\n    # Golden ratio to set aesthetic figure height\n    # https://disq.us/p/2940ij3\n    golden_ratio = (5 ** .5 - 1) / 2\n\n    # Figure width in inches\n    fig_width_in = fig_width_pt * inches_per_pt\n    # Figure height in inches\n    fig_height_in = fig_width_in * golden_ratio\n\n    fig_dim = (fig_width_in, fig_height_in)\n\n    return fig_dim\n\n\ndef setup_plot():\n    sns.set_theme(style=\"whitegrid\")\n    sns.set_palette(\"tab10\")\n    plt.rcParams[\"font.family\"] = \"Times New Roman\"\n    plt.rcParams['font.weight'] = 'normal'\n\n\ndef plot_aj(\n        save_name='plot_number_of_views.pdf',\n        width_in_paper_pts=237.13594,  # \\showthe\\linewidth --> > 237.13594pt.\n        linewidth=1.5,\n        marker_size=5,\n        label_font_size=9,\n        tick_font_size=9,\n        legend_font_size=7,\n        dpi=400,\n        results_dir=None,\n        save_svg=False,\n):\n    setup_plot()\n\n    fig, ax = plt.subplots(figsize=set_size(width_in_paper_pts), dpi=dpi)\n\n    x = np.arange(1, 9)\n\n    y_data = {\n        \"MVTracker (ours)\": [64.0, 66.8, 73.2, 71.1, 77.4, 76.7, 77.3, 79.2],\n        \"Triplane\": [44.0, 48.0, 56.0, 57.6, 63.5, 64.5, 65.5, 66.8],\n        # \"TAPIP3D\": [36.6, 35.6, 40.5, 38.8, 57.7, 54.2, 55.2, 56.4],\n        # \"SpatialTrackerV2\": [39.8, 39.5, 36.5, 35.5, 41.1, 37.1, 37.0, 37.7],\n        \"SpatialTracker\": [60.6, 58.4, 61.8, 58.3, 63.2, 62.4, 62.9, 63.4],\n        \"CoTracker3\": [28.6, 27.0, 29.5, 29.4, 39.1, 37.5, 37.1, 37.3],\n        # \"CoTracker2\": [29.8, 26.4, 29.2, 28.8, 37.8, 36.2, 36.0, 36.0],\n        \"DELTA\": [33.0, 34.3, 38.0, 36.5, 37.2, 35.4, 34.9, 35.7],\n        \"LocoTrack\": [27.9, 26.0, 28.1, 27.8, 36.3, 34.8, 34.7, 34.9]\n    }\n\n    for label, y in y_data.items():\n        sns.lineplot(x=x, y=y, label=label, ax=ax, linewidth=linewidth, marker='o', markersize=marker_size)\n\n    ax.set_xticks(x)\n    ax.set_yticks(np.arange(30, 81, 10))\n    ax.set_ylim([25, 80])\n    ax.xaxis.set_major_formatter(ticker.ScalarFormatter())\n    ax.tick_params(axis='both', which='major', labelsize=tick_font_size)\n\n    ax.set_xlabel('Number of Views', fontsize=label_font_size, fontweight='normal', labelpad=0)\n    ax.set_ylabel('Average Jaccard (AJ)', fontsize=label_font_size, fontweight='normal', labelpad=2)\n\n    for spine in ax.spines.values():\n        spine.set_color('black')\n\n    ax.grid(axis='y', color='lightgrey')\n\n    ax.tick_params(axis=\"y\", direction=\"in\")\n    ax.tick_params(axis=\"x\", direction=\"in\")\n\n    legend = plt.legend(\n        frameon=True,\n        fancybox=False,\n        loc=(0.625, 0.265),\n        prop={'size': legend_font_size},\n        handletextpad=0.2,\n        labelspacing=0.1,\n\n    )\n    # legend.get_frame().set_facecolor('white')\n    
# legend.get_frame().set_edgecolor('black')\n\n    plt.tight_layout(pad=0)\n\n    if save_name:\n        if results_dir:\n            os.makedirs(results_dir, exist_ok=True)\n            save_name = os.path.join(results_dir, save_name)\n\n        plt.savefig(save_name, bbox_inches='tight', pad_inches=0)\n        if save_svg:\n            plt.savefig(save_name.replace('.pdf', '.svg'), bbox_inches='tight', pad_inches=0)\n\n    plt.show()\n\n\nif __name__ == '__main__':\n    plot_aj()\n"
  },
  {
    "path": "scripts/profiling.md",
    "content": "# Profiling Notes\n\nThis document summarizes how to run performance profiling using PyTorch’s built-in tools, and how to interpret the results.\n\nTo profile one training iteration (forward + backward + optimizer step), the following snippet can be used:\n\n```python\nfrom torch.profiler import profile, ProfilerActivity\n\nwith profile(\n    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],\n    with_stack=True,\n    with_flops=True,\n    profile_memory=True,\n    record_shapes=True,\n) as prof:\n    # one iteration of fwd + bwd + optimize\n    pass\n\nprint(prof.key_averages().table(sort_by=\"cuda_time_total\", row_limit=36))\nprint(prof.key_averages().table(sort_by=\"self_cuda_time_total\", row_limit=36))\nprint(prof.key_averages().table(sort_by=\"self_cpu_time_total\", row_limit=36))\nprint(prof.key_averages().table(sort_by=\"self_cuda_memory_usage\", row_limit=36))\n\nprof.export_chrome_trace(\"trace.json\")\nbreakpoint()\n```\n\nThe printed summary tables produced by `.key_averages()` are already extremely informative. Sorting by `cuda_time_total` highlights which operations dominate the runtime on GPU. The `self_cuda_memory_usage` sort can reveal the main contributors to memory consumption. Enabling `record_shapes=True` further helps diagnose operations that unexpectedly receive large tensors.\n\nFor more detailed inspection, the exported trace file `trace.json` can be opened in Chrome at `chrome://tracing`. This presents a flamegraph-style timeline of kernel launches, memory activity, and execution order, which can be especially helpful for understanding the global structure and scheduling behavior of the code. For a brief tutorial on how to navigate this view, see [here](https://www.youtube.com/watch?v=AhIOohJYSrw).\n\nNote that the trace file can become very large (e.g., over 1 GB) depending on how much code is profiled. This may slow down both trace generation and visualization. While `.key_averages()` provides a fast, summary-level view that is often sufficient for identifying key bottlenecks, the flamegraph timeline can be equally valuable for temporal insights."
  },
  {
    "path": "scripts/selfcap_preprocessing.py",
    "content": "\"\"\"\nSelfCap dataset (https://zju3dv.github.io/longvolcap/)\n\nDownload the dataset (but first fill in the form at https://forms.gle/MzJqZjBfyZ53fRMZ7):\n```bash\nmkdir -p datasets/selfcap\ncd datasets/selfcap\n\ngdown --fuzzy https://drive.google.com/file/d/1iTr6sTVQoCtTK4FbA3lRxMrh7sC0MhzP/view?usp=share_link  # LICENSE\ngdown --fuzzy https://drive.google.com/file/d/1cg54hE_IBsnVXuMCj44JCQEGnqU1Hr5b/view?usp=share_link  # yoga-calib.tar.gz\ngdown --fuzzy https://drive.google.com/file/d/1l84Pna4eO9m_bql2mR8nm6VnLO80e717/view?usp=share_link  # hair-calib.tar.gz\ngdown --fuzzy https://drive.google.com/file/d/1Desj7th500-vsyRYzRq8Xb6TtUgDPU4u/view?usp=share_link  # README.md\ngdown --fuzzy https://drive.google.com/file/d/1Ex3OtLmz6kBbgB84MImlDLJpVE6vI3ks/view?usp=share_link  # bike-release.tar.gz\ngdown --fuzzy https://drive.google.com/file/d/1muPLxdCm4il_X6TRVLaxx-6sYO6XYIwH/view?usp=share_link  # yoga-release.tar.gz\ngdown --fuzzy https://drive.google.com/file/d/12mRUCpaTk1XearBq2hUIf5ZbHZw4AQAw/view?usp=share_link  # dance-release.tar.gz\ngdown --fuzzy https://drive.google.com/file/d/1AEiQBC9CIthR97qZeZzkH2nlXXpogfxH/view?usp=share_link  # hair-release.tar.gz\ngdown --fuzzy https://drive.google.com/file/d/1NFrHh-SxUER4jWBV0irnCcDhEmkg3WUg/view?usp=share_link  # corgi-release.tar.gz\ngdown --fuzzy https://drive.google.com/file/d/1b9Hf3YY_usPrtddgpMe569dSqh0bEGLo/view?usp=share_link  # bar-release.tar.gz\n\ntar xvf bar-release.tar.gz\ntar xvf bike-release.tar.gz\ntar xvf corgi-release.tar.gz\ntar xvf dance-release.tar.gz\ntar xvf hair-calib.tar.gz\ntar xvf hair-release.tar.gz\ntar xvf yoga-calib.tar.gz\ntar xvf yoga-release.tar.gz\n\nrm *.tar.gz\n\ncd -\n```\nRunning the script: `PYTHONPATH=/local/home/frrajic/xode/duster:$PYTHONPATH python -m scripts.selfcap_preprocessing`\nNote that you need to set up dust3r first, see docstring of `scripts/estimate_depth_with_duster.py`.\n\"\"\"\n\nimport concurrent.futures\nimport json\nimport os\nimport pickle\nfrom typing import Optional\n\nimport cv2\nimport numpy as np\nimport rerun as rr\nfrom scipy.spatial.transform import Rotation as R\nfrom tqdm import tqdm\n\nfrom scripts.egoexo4d_preprocessing import main_estimate_duster_depth\n\n\ndef main_preprocess_selfcap(\n        dataset_root: str,\n        scene_name: str,\n        outputs_dir: str,\n        num_cameras: Optional[int] = None,\n        sample_cameras_sequentially: Optional[bool] = False,\n        start_frame: Optional[int] = None,\n        max_frames: Optional[int] = None,\n        frames_downsampling_factor: Optional[int] = None,\n        downscaled_longerside: Optional[int] = None,\n        save_rerun_viz: bool = True,\n        stream_rerun_viz: bool = False,\n        skip_if_output_exists: bool = True,\n):\n    # Skip if output exists\n    save_pkl_path = os.path.join(outputs_dir, f\"{scene_name}.pkl\")\n    if skip_if_output_exists and os.path.exists(save_pkl_path):\n        print(f\"Skipping {save_pkl_path} since it already exists\")\n        print()\n        return save_pkl_path\n    else:\n        print(f\"Processing {scene_name}...\")\n\n    # --- Load calibration ---\n    calib_dir = os.path.join(dataset_root, f\"{scene_name}-calib\", \"optimized\")\n    intri_path = os.path.join(calib_dir, \"intri.yml\")\n    extri_path = os.path.join(calib_dir, \"extri.yml\")\n    sync_path = os.path.join(calib_dir, \"sync.json\")\n\n    assert all(os.path.exists(p) for p in [intri_path, extri_path, sync_path])\n\n    intri_fs = cv2.FileStorage(intri_path, 
cv2.FILE_STORAGE_READ)\n    extri_fs = cv2.FileStorage(extri_path, cv2.FILE_STORAGE_READ)\n    with open(sync_path) as f:\n        sync_data = json.load(f)\n\n    # --- Load videos ---\n    video_dir = os.path.join(dataset_root, f\"{scene_name}-release\", \"videos\")\n    cam_names = sorted([f.replace(\".mp4\", \"\") for f in os.listdir(video_dir) if f.endswith(\".mp4\")])\n\n    if num_cameras is not None and num_cameras < len(cam_names):\n        if sample_cameras_sequentially:\n            cam_names = cam_names[:num_cameras]\n        else:\n            step = len(cam_names) / num_cameras\n            cam_names = [cam_names[int(i * step)] for i in range(num_cameras)]\n\n    rgbs, intrs, extrs = {}, {}, {}\n\n    def load_cam_video(cam):\n        vid_path = os.path.join(video_dir, f\"{cam}.mp4\")\n        cap = cv2.VideoCapture(vid_path)\n        fps = cap.get(cv2.CAP_PROP_FPS)\n        offset = int(round(sync_data[cam] * fps))\n\n        frames = []\n        i = 0\n        while cap.isOpened():\n            ret, frame = cap.read()\n            if not ret:\n                break\n            idx = i - offset\n            i += 1\n            if idx < 0:\n                continue\n            if start_frame is not None and idx < start_frame:\n                continue\n            if frames_downsampling_factor and ((idx - (start_frame or 0)) % frames_downsampling_factor != 0):  # treat start_frame=None as 0\n                continue\n            if max_frames and len(frames) >= max_frames:\n                break\n            img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB).transpose(2, 0, 1)\n            frames.append(img)\n        cap.release()\n\n        if not frames:\n            return None, None, None, None  # caller unpacks (cam, rgb, intr, extr)\n\n        rgb = np.stack(frames)\n        intr = intri_fs.getNode(f\"K_{cam}\").mat().astype(np.float32)\n        R = extri_fs.getNode(f\"Rot_{cam}\").mat().astype(np.float32)\n        T = extri_fs.getNode(f\"T_{cam}\").mat().astype(np.float32).reshape(3)\n        extr = np.concatenate([R, T[:, None]], axis=1)\n\n        return cam, rgb, intr, extr\n\n    # Run parallel loading\n    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:\n        futures = [executor.submit(load_cam_video, cam) for cam in cam_names]\n        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):\n            cam, rgb, intr, extr = future.result()\n            if cam is None:\n                print(\"Warning: camera skipped due to no usable frames.\")\n                continue\n            rgbs[cam] = rgb\n            intrs[cam] = intr\n            extrs[cam] = extr\n\n    intri_fs.release()\n    extri_fs.release()\n\n    # Apply a global -90° rotation around X axis to the scene\n    rot_x = R.from_euler('x', -90, degrees=True).as_matrix()\n    rot_y = R.from_euler('y', 0, degrees=True).as_matrix()\n    rot_z = R.from_euler('z', 0, degrees=True).as_matrix()\n    rot = rot_z @ rot_y @ rot_x\n    T_rot = np.eye(4)\n    T_rot[:3, :3] = rot\n    for cam in extrs:\n        extrs_square = np.eye(4, dtype=extrs[cam].dtype)\n        extrs_square[:3, :] = extrs[cam]\n        extrs_trans_square = np.einsum('ki,ij->kj', extrs_square, T_rot.T)\n        extrs_trans = extrs_trans_square[..., :3, :]\n        assert np.allclose(extrs_trans_square[..., 3, 3], np.ones_like(extrs_trans_square[..., 3, 3]))\n        extrs[cam] = extrs_trans\n\n    print(f\"Loaded SelfCap scene '{scene_name}' with {len(cam_names)} cams and {rgbs[cam_names[0]].shape[0]} frames.\")\n\n    # Check shapes\n    n_frames, _, h, w = 
rgbs[cam_names[0]].shape\n    for cam_name in cam_names:\n        assert rgbs[cam_name].shape == (n_frames, 3, h, w)\n        assert intrs[cam_name].shape == (3, 3)\n        assert extrs[cam_name].shape == (3, 4)\n\n    # Save downsized version\n    if downscaled_longerside is not None:\n        print(f\"Downscaling to longer side {downscaled_longerside}\")\n        for cam_name in tqdm(cam_names, desc=\"Downscaling\"):\n            _, _, h, w = rgbs[cam_name].shape\n            scale = downscaled_longerside / max(h, w)\n            new_h, new_w = int(h * scale), int(w * scale)\n\n            resized = []\n            for img in rgbs[cam_name]:\n                img = img.transpose(1, 2, 0)  # CHW -> HWC\n                img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)\n                resized.append(img.transpose(2, 0, 1))  # HWC -> CHW\n            rgbs[cam_name] = np.stack(resized)\n\n            # scale intrinsics\n            intrs[cam_name][:2] *= scale\n\n    # Save processed output to a pickle file\n    os.makedirs(outputs_dir, exist_ok=True)\n    with open(save_pkl_path, \"wb\") as f:\n        pickle.dump(\n            dict(\n                rgbs=rgbs,\n                intrs=intrs,\n                extrs=extrs,\n                ego_cam_name=None,\n            ),\n            f,\n            protocol=pickle.HIGHEST_PROTOCOL,\n        )\n    print(f\"Saved {save_pkl_path}\")\n\n    # Visualize the data sample using rerun\n    rerun_modes = []\n    if stream_rerun_viz:\n        rerun_modes += [\"stream\"]\n    if save_rerun_viz:\n        rerun_modes += [\"save\"]\n    for rerun_mode in rerun_modes:\n        rr.init(f\"3dpt\", recording_id=\"v0.16\")\n        if rerun_mode == \"stream\":\n            rr.connect_tcp()\n\n        rr.log(\"world\", rr.ViewCoordinates.RIGHT_HAND_Z_UP, static=True)\n        rr.set_time_seconds(\"frame\", 0)\n        rr.log(\n            \"world/xyz\",\n            rr.Arrows3D(\n                vectors=[[1, 0, 0], [0, 2, 0], [0, 0, 3]],\n                colors=[[255, 0, 0], [0, 255, 0], [0, 0, 255]],\n            ),\n        )\n\n        fps = 30\n        for frame_idx in range(min(n_frames, 30)):\n            rr.set_time_seconds(\"frame\", frame_idx / fps)\n\n            for cam_name in cam_names:\n                extr = extrs[cam_name]\n                intr = intrs[cam_name]\n                img = rgbs[cam_name][frame_idx].transpose(1, 2, 0).astype(np.uint8)\n\n                # Camera pose logging\n                E = extr if extr.shape == (3, 4) else extr[0]\n                T = np.eye(4)\n                T[:3, :] = E\n                T_world_cam = np.linalg.inv(T)\n                rr.log(f\"{cam_name}/image\", rr.Transform3D(\n                    translation=T_world_cam[:3, 3],\n                    mat3x3=T_world_cam[:3, :3],\n                ))\n\n                # Intrinsics and image\n                rr.log(f\"{cam_name}/image\", rr.Pinhole(\n                    image_from_camera=intr,\n                    width=img.shape[1],\n                    height=img.shape[0]\n                ))\n                rr.log(f\"{cam_name}/image\", rr.Image(img))\n\n        if rerun_mode == \"save\":\n            save_rrd_path = os.path.join(outputs_dir, f\"rerun__{scene_name}.rrd\")\n            rr.save(save_rrd_path)\n            print(f\"Saved rerun viz to {os.path.abspath(save_rrd_path)}\")\n\n    return save_pkl_path\n\n\nif __name__ == '__main__':\n    dataset_root = \"datasets/selfcap/\"\n    outputs_dir = \"datasets/selfcap-processed/\"\n\n 
   for scene_name in [\"yoga\", \"hair\"]:\n        for num_cameras, sequential_cams, start_frame, max_frames, frames_downsampling_factor, downscaled_longerside in [\n            (8, False, 90, 256, 10, 512),\n            (8, True, 90, 256, 10, 512),\n            (8, False, 90, 2560, 10, 512),\n\n            (4, False, 90, 256, 10, 512),\n            (4, True, 90, 256, 10, 512),\n\n            (16, False, 90, 256, 10, 512),\n            (16, True, 90, 256, 10, 512),\n            (16, True, 90, 2560, 10, 512),\n\n            (8, False, 90, 256, 1, 512),\n            (8, False, 90, 2560, 1, 512),\n            (8, False, 90, 256, 5, 512),\n            (8, False, 90, 256, 20, 512),\n            (8, False, 90, 256, 30, 512),\n        ]:\n            # Extract rgbs, intrs, extrs from SelfCap\n            outputs_subdir = os.path.join(\n                outputs_dir, f\"numcams-{num_cameras}-seq-{sequential_cams}_\"\n                             f\"startframe-{start_frame}_\"\n                             f\"maxframes-{max_frames}_\"\n                             f\"downsample-{frames_downsampling_factor}_\"\n                             f\"downscale-{downscaled_longerside}\"\n            )\n            scene_pkl = main_preprocess_selfcap(\n                dataset_root=dataset_root,\n                scene_name=scene_name,\n                outputs_dir=outputs_subdir,\n                num_cameras=num_cameras,\n                sample_cameras_sequentially=sequential_cams,\n                start_frame=start_frame,\n                max_frames=max_frames,\n                frames_downsampling_factor=frames_downsampling_factor,\n                downscaled_longerside=downscaled_longerside,\n            )\n\n            # Run Dust3r to estimate depths from rgbs, fix the known intrs and extrs during multi-view stereo optim\n            depth_subdir = os.path.join(outputs_subdir, f\"duster_depths__{scene_name}\")\n            main_estimate_duster_depth(\n                pkl_scene_file=scene_pkl,\n                depths_output_dir=depth_subdir,\n            )\n\n            # Run VGGT to estimate depths from rgbs, align with the known extrs afterward\n            ...\n"
  },
  {
    "path": "scripts/slurm/eval.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=eval-058\n#SBATCH --nodes=1\n#SBATCH --ntasks-per-node=1\n#SBATCH --cpus-per-task=32\n#SBATCH --gres=gpu:1\n#SBATCH --mem=460000\n#SBATCH --partition=normal\n#SBATCH --account=a-a03\n#SBATCH --time=00:10:00\n#SBATCH --dependency=singleton\n#SBATCH --mail-type=begin\n#SBATCH --mail-type=end\n#SBATCH --mail-user=frano.rajic@inf.ethz.ch\n#SBATCH --output=./logs/slurm_logs/%x-%j.out\n#SBATCH --array=0-85\n\nset -x\ncat $0\nDIR=$(realpath .)\nmkdir -p $DIR/runs\n\nCKPTS=(\n# \"experiment_path=logs/eval/copycat     model=copycat\"\n# \"experiment_path=logs/dynamic_3dgs     model=locotrack\"\n# \"experiment_path=logs/shape_of_motion  model=locotrack\"\n#\n# \"experiment_path=logs/eval/tapip3d            model=tapip3d\"\n# \"experiment_path=logs/eval/scenetracker       model=scenetracker\"\n# \"experiment_path=logs/eval/locotrack          model=locotrack\"\n# \"experiment_path=logs/eval/delta              model=delta\"\n# \"experiment_path=logs/eval/cotracker2_online  model=cotracker2_online\"\n# \"experiment_path=logs/eval/cotracker3_online  model=cotracker3_online\"\n#\n# \"experiment_path=logs/eval/spatracker_monocular_pretrained       model=spatracker_monocular_pretrained restore_ckpt_path=checkpoints/spatracker_monocular_original-authors-ckpt.pth\"\n# \"experiment_path=logs/eval/spatracker_monocular_kubric-training  model=spatracker_monocular            restore_ckpt_path=checkpoints/spatracker_monocular_trained-on-kubric-depth_069800.pth\"\n# \"experiment_path=logs/eval/spatracker_monocular_duster-training  model=spatracker_monocular            restore_ckpt_path=checkpoints/spatracker_monocular_trained-on-duster-depth_090800.pth\"\n# \"experiment_path=logs/eval/spatracker_multiview_kubric-training  model=spatracker_multiview            restore_ckpt_path=checkpoints/spatracker_multiview_trained-on-kubric-depth_100000.pth model.triplane_xres=128 model.triplane_yres=128 model.triplane_zres=128\"\n# \"experiment_path=logs/eval/spatracker_multiview_duster-training  model=spatracker_multiview            restore_ckpt_path=checkpoints/spatracker_multiview_trained-on-duster-depth_100000.pth model.triplane_xres=256 model.triplane_yres=256 model.triplane_zres=128\"\n#\n# \"experiment_path=logs/eval/mvtracker-v0_kubric-training  model=mvtracker  restore_ckpt_path=checkpoints/mvtracker_v0_trained-on-kubric-depth_091600.pth model.updatetransformer_type=spatracker model.apply_sigmoid_to_vis=true trainer.precision=16-mixed model.fmaps_dim=384 model.hidden_size=384 model.num_heads=8\"\n# \"experiment_path=logs/eval/mvtracker-v0_duster-training  model=mvtracker  restore_ckpt_path=checkpoints/mvtracker_v0_trained-on-duster-depth_100000.pth model.updatetransformer_type=spatracker model.apply_sigmoid_to_vis=true trainer.precision=16-mixed model.fmaps_dim=384\"\n# \"experiment_path=logs/eval/mvtracker-iccv-march2025      model=mvtracker  restore_ckpt_path=checkpoints/mvtracker_160000_march2025.pth model.updatetransformer_type=spatracker model.apply_sigmoid_to_vis=true trainer.precision=16-mixed \"\n \"experiment_path=logs/mvtracker                          model=mvtracker  restore_ckpt_path=checkpoints/mvtracker_200000_june2025.pth\"\n)\nDATASETS=(\n############################\n### ~~~ Main results ~~~ ###\n############################\n  dex-ycb-multiview\n  dex-ycb-multiview-duster0123\n  dex-ycb-multiview-duster0123cleaned\n  panoptic-multiview-views1_7_14_20\n  panoptic-multiview-views27_16_14_8\n  panoptic-multiview-views1_4_7_11\n  
kubric-multiview-v3-views0123\n  kubric-multiview-v3-duster0123\n  kubric-multiview-v3-duster0123cleaned\n  tapvid2d-davis-mogewithextrinsics-256x256\n\n#############################\n### ~~~ 2DPT Ablation ~~~ ###\n#############################\n  dex-ycb-multiview-2dpt\n  dex-ycb-multiview-duster0123-2dpt\n  panoptic-multiview-views1_7_14_20-2dpt\n  panoptic-multiview-views27_16_14_8-2dpt\n  panoptic-multiview-views1_4_7_11-2dpt\n  kubric-multiview-v3-views0123-2dpt\n  kubric-multiview-v3-duster0123-2dpt\n\n####################################\n### ~~~ Single-point results ~~~ ###\n####################################\n#  dex-ycb-multiview-single\n#  dex-ycb-multiview-duster0123-single\n#  dex-ycb-multiview-duster0123cleaned-single\n#  panoptic-multiview-views1_7_14_20-single\n#  panoptic-multiview-views27_16_14_8-single\n#  panoptic-multiview-views1_4_7_11-single\n#  kubric-multiview-v3-views0123-single\n#  kubric-multiview-v3-duster0123-single\n#  kubric-multiview-v3-duster0123cleaned-single\n#  tapvid2d-davis-mogewithextrinsics-256x256-single\n\n#####################################\n### ~~~ Camera-setup Ablation ~~~ ###\n#####################################\n  panoptic-multiview-views1_7_14_20\n  panoptic-multiview-views27_16_14_8\n  panoptic-multiview-views1_4_7_11\n  dex-ycb-multiview-duster0123\n  dex-ycb-multiview-duster2345\n  dex-ycb-multiview-duster4567\n\n########################################\n### ~~~ Number-of-views Ablation ~~~ ###\n########################################\n  kubric-multiview-v3-views0\n  kubric-multiview-v3-views01\n  kubric-multiview-v3-views012\n  kubric-multiview-v3-views0123\n  kubric-multiview-v3-views01234\n  kubric-multiview-v3-views012345\n  kubric-multiview-v3-views0123456\n  kubric-multiview-v3-views01234567\n  kubric-multiview-v3-duster0123-views0\n  kubric-multiview-v3-duster0123-views01\n  kubric-multiview-v3-duster0123-views012\n  kubric-multiview-v3-duster0123-views0123\n  kubric-multiview-v3-duster01234567-views01234\n  kubric-multiview-v3-duster01234567-views012345\n  kubric-multiview-v3-duster01234567-views0123456\n  kubric-multiview-v3-duster01234567-views01234567\n  panoptic-multiview-views1\n  panoptic-multiview-views1_14\n  panoptic-multiview-views1_7_14\n  panoptic-multiview-views1_7_14_20\n  panoptic-multiview-views1_4_7_14_20\n  panoptic-multiview-views1_4_7_14_17_20\n  panoptic-multiview-views1_4_7_11_14_17_20\n  panoptic-multiview-views1_4_7_11_14_17_20_23\n  dex-ycb-multiview-duster0123-views0\n  dex-ycb-multiview-duster0123-views01\n  dex-ycb-multiview-duster0123-views012\n  dex-ycb-multiview-duster0123-views0123\n  dex-ycb-multiview-duster01234567-views01234\n  dex-ycb-multiview-duster01234567-views012345\n  dex-ycb-multiview-duster01234567-views0123456\n  dex-ycb-multiview-duster01234567-views01234567\n\n\n#####################################\n### ~~~ For video comparisons ~~~ ###\n#####################################\n  kubric-multiview-v3-views0123-novelviews4\n  panoptic-multiview-views1_7_14_20-novelviews24\n  panoptic-multiview-views1_7_14_20-novelviews27\n  dex-ycb-multiview-duster0123-novelviews4\n  dex-ycb-multiview-duster0123-novelviews5\n  dex-ycb-multiview-duster0123-novelviews6\n  dex-ycb-multiview-duster0123-novelviews7\n  dex-ycb-multiview-duster2345-novelviews7\n  dex-ycb-multiview-duster4567-novelviews7\n  dex-ycb-multiview-duster4567-novelviews0\n\n\n####################################\n### ~~~ For noise experiment ~~~ ###\n####################################\n  kubric-multiview-v3-noise0cm\n  
kubric-multiview-v3-noise1cm\n  kubric-multiview-v3-noise2cm\n  kubric-multiview-v3-noise5cm\n  kubric-multiview-v3-noise10cm\n  kubric-multiview-v3-noise20cm\n  kubric-multiview-v3-noise50cm\n  kubric-multiview-v3-noise100cm\n  kubric-multiview-v3-noise200cm\n  kubric-multiview-v3-noise1000cm\n)\n\n# Compute number of jobs needed\nNUM_CKPTS=${#CKPTS[@]}\nNUM_DATASETS=${#DATASETS[@]}\nTOTAL_JOBS=$((NUM_CKPTS * NUM_DATASETS))\n\n# Check if SLURM_ARRAY_TASK_ID is valid\nif [ \"$SLURM_ARRAY_TASK_ID\" -ge \"$TOTAL_JOBS\" ]; then\n    echo \"Error: SLURM_ARRAY_TASK_ID=$SLURM_ARRAY_TASK_ID exceeds the max index $((TOTAL_JOBS-1))\"\n    exit 1\nfi\n\n# Map SLURM_ARRAY_TASK_ID to checkpoint and dataset\nCKPT_INDEX=$((SLURM_ARRAY_TASK_ID % NUM_CKPTS))\nDATASET_INDEX=$((SLURM_ARRAY_TASK_ID / NUM_CKPTS))\n\nSELECTED_CKPT=${CKPTS[$CKPT_INDEX]}\nSELECTED_DATASET=${DATASETS[$DATASET_INDEX]}\n\necho \"Selected Checkpoint: $SELECTED_CKPT\"\necho \"Selected Dataset: $SELECTED_DATASET\"\n\n# Run the job with the extracted checkpoint & dataset\nsrun -ul --container-writable --environment=my_pytorch_env numactl --membind=0-3 bash -c \"\n    source /users/fraji/venvs/spa10/bin/activate &&\n    CUDA_VISIBLE_DEVICES=0 TORCH_HOME=./checkpoints/.cache python eval.py $SELECTED_CKPT datasets.eval.names=[$SELECTED_DATASET]\n\""
  },
  {
    "path": "scripts/slurm/mvtracker-nodepthaugs.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=mvtracker_200000_june2025_cleandepths\n#SBATCH --nodes=2\n#SBATCH --ntasks-per-node=4\n#SBATCH --cpus-per-task=72\n#SBATCH --gres=gpu:4\n#SBATCH --mem=460000\n#SBATCH --partition=normal\n#SBATCH --account=a136\n#SBATCH --time=12:00:00\n#SBATCH --dependency=singleton\n#SBATCH --mail-type=begin\n#SBATCH --mail-type=end\n#SBATCH --mail-user=frano.rajic@inf.ethz.ch\n#SBATCH --output=./logs/slurm_logs/%x-%j.out\n#SBATCH --error=./logs/slurm_logs/%x-%j.out\n#SBATCH --signal=USR1@60\n\nset -euo pipefail\nset -x\ncat $0\nDIR=$(realpath .)\n\n# Wrap the commands\nCMD=\"\nsource /users/fraji/venvs/spa10/bin/activate\ncd $DIR\n\npython train.py model=mvtracker \\\n  trainer.num_steps=200000 \\\n  trainer.eval_freq=10000 \\\n  trainer.viz_freq=10000 \\\n  trainer.save_ckpt_freq=500 \\\n  trainer.lr=0.0005 \\\n  datasets.train.traj_per_sample=2048 \\\n  model.updatetransformer_type=cotracker2 \\\n  reproducibility.seed=36 \\\n  trainer.precision=bf16-mixed \\\n  modes.do_initial_static_pretrain=false \\\n  trainer.augment_train_iters=false \\\n  model.apply_sigmoid_to_vis=false \\\n  augmentations.variable_depth_type=false \\\n  logging.log_wandb=true \\\n  experiment_path=logs/${SLURM_JOB_NAME}\n\"\n\n# Execute within the container\nsrun -ul --environment=my_pytorch_env bash -c \"$CMD\"\n"
  },
  {
    "path": "scripts/slurm/mvtracker.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=mvtracker_200000_june2025\n#SBATCH --nodes=2\n#SBATCH --ntasks-per-node=4\n#SBATCH --cpus-per-task=72\n#SBATCH --gres=gpu:4\n#SBATCH --mem=460000\n#SBATCH --partition=normal\n#SBATCH --account=a136\n#SBATCH --time=12:00:00\n#SBATCH --dependency=singleton\n#SBATCH --mail-type=begin\n#SBATCH --mail-type=end\n#SBATCH --mail-user=frano.rajic@inf.ethz.ch\n#SBATCH --output=./logs/slurm_logs/%x-%j.out\n#SBATCH --error=./logs/slurm_logs/%x-%j.out\n#SBATCH --signal=USR1@60\n\nset -euo pipefail\nset -x\ncat $0\nDIR=$(realpath .)\n\n# Wrap the commands\nCMD=\"\nsource /users/fraji/venvs/spa10/bin/activate\ncd $DIR\n\npython train.py model=mvtracker \\\n  trainer.num_steps=200000 \\\n  trainer.eval_freq=10000 \\\n  trainer.viz_freq=10000 \\\n  trainer.save_ckpt_freq=500 \\\n  trainer.lr=0.0005 \\\n  datasets.train.traj_per_sample=2048 \\\n  model.updatetransformer_type=cotracker2 \\\n  reproducibility.seed=36 \\\n  trainer.precision=bf16-mixed \\\n  modes.do_initial_static_pretrain=false \\\n  trainer.augment_train_iters=false \\\n  model.apply_sigmoid_to_vis=false \\\n  logging.log_wandb=true \\\n  experiment_path=logs/${SLURM_JOB_NAME}\n\"\n\n# Execute within the container\nsrun -ul --environment=my_pytorch_env bash -c \"$CMD\"\n"
  },
  {
    "path": "scripts/slurm/spatracker.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=spatracker_monocular\n#SBATCH --nodes=8\n#SBATCH --ntasks-per-node=4\n#SBATCH --cpus-per-task=72\n#SBATCH --gres=gpu:4\n#SBATCH --mem=460000\n#SBATCH --partition=normal\n#SBATCH --account=a-a136-1\n#SBATCH --time=12:00:00\n#SBATCH --dependency=singleton\n#SBATCH --mail-type=begin\n#SBATCH --mail-type=end\n#SBATCH --mail-user=frano.rajic@inf.ethz.ch\n#SBATCH --output=./logs/slurm_logs/%x-%j.out\n#SBATCH --error=./logs/slurm_logs/%x-%j.out\n#SBATCH --signal=USR1@60\n\nset -euo pipefail\nset -x\ncat $0\nDIR=$(realpath .)\n\n# Wrap the commands\nCMD=\"\nsource /users/fraji/venvs/spa10/bin/activate\ncd $DIR\n\npython train.py model=spatracker_monocular \\\n  trainer.num_steps=200000 \\\n  trainer.eval_freq=10000 \\\n  trainer.viz_freq=10000 \\\n  trainer.save_ckpt_freq=500 \\\n  trainer.lr=0.001 \\\n  datasets.train.traj_per_sample=512 \\\n  reproducibility.seed=72 \\\n  trainer.precision=bf16-mixed \\\n  modes.do_initial_static_pretrain=true \\\n  trainer.augment_train_iters=true \\\n  experiment_path=logs/${SLURM_JOB_NAME}\n\"\n\n# Execute within the container\nsrun -ul --environment=my_pytorch_env bash -c \"$CMD\"\n"
  },
  {
    "path": "scripts/slurm/test_reproducibility.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=repro-test-mvtracker\n#SBATCH --nodes=2\n#SBATCH --ntasks-per-node=4\n#SBATCH --cpus-per-task=32\n#SBATCH --gres=gpu:4\n#SBATCH --mem=460000\n#SBATCH --partition=normal\n#SBATCH --account=a-a03\n#SBATCH --time=00:20:00\n#SBATCH --output=./logs/slurm_logs/%x-%j.out\n\nset -euo pipefail\nset -x\ncat $0\nDIR=$(realpath .)\ncd \"$DIR\"\n\n# Use job ID for run directory\nRUN1=\"logs/debug/test_repro_${SLURM_JOB_ID}_run1\"\nRUN2=\"logs/debug/test_repro_${SLURM_JOB_ID}_run2\"\n[[ ! -e \"$RUN1\" ]] || { echo \"ERROR: $RUN1 already exists\"; exit 1; }\n[[ ! -e \"$RUN2\" ]] || { echo \"ERROR: $RUN2 already exists\"; exit 1; }\n\n# Wrap the commands\nCMD=\"\nsource /users/fraji/venvs/spa10/bin/activate\ncd $DIR\n\nexport CUBLAS_WORKSPACE_CONFIG=:4096:8\nexport PYTHONHASHSEED=0\n\n# === Run 1 ===\npython train.py +experiment=mvtracker_overfit \\\n  datasets.eval.names=[] \\\n  modes.tune_per_scene=false \\\n  trainer.num_steps=10 \\\n  reproducibility.deterministic=true \\\n  dataset.train.num_workers=4 \\\n  trainer.precision=32 \\\n  experiment_path=$RUN1\n\n# === Run 2 ===\npython train.py +experiment=mvtracker_overfit \\\n  datasets.eval.names=[] \\\n  modes.tune_per_scene=false \\\n  trainer.num_steps=10 \\\n  reproducibility.deterministic=true \\\n  dataset.train.num_workers=4 \\\n  trainer.precision=32 \\\n  experiment_path=$RUN2\n\"\n\n# Execute within the container\nsrun -ul --environment=my_pytorch_env bash -c \"$CMD\"\n"
  },
  {
    "path": "scripts/slurm/triplane-128.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=spatracker_multiview_128\n#SBATCH --nodes=8\n#SBATCH --ntasks-per-node=4\n#SBATCH --cpus-per-task=72\n#SBATCH --gres=gpu:4\n#SBATCH --mem=460000\n#SBATCH --partition=normal\n#SBATCH --account=a-a136-1\n#SBATCH --time=12:00:00\n#SBATCH --dependency=singleton\n#SBATCH --mail-type=begin\n#SBATCH --mail-type=end\n#SBATCH --mail-user=frano.rajic@inf.ethz.ch\n#SBATCH --output=./logs/slurm_logs/%x-%j.out\n#SBATCH --error=./logs/slurm_logs/%x-%j.out\n#SBATCH --signal=USR1@60\n\nset -euo pipefail\nset -x\ncat $0\nDIR=$(realpath .)\n\n# Wrap the commands\nCMD=\"\nsource /users/fraji/venvs/spa10/bin/activate\ncd $DIR\n\npython train.py model=spatracker_multiview model.triplane_xres=128 model.triplane_yres=128 model.triplane_zres=128 \\\n  trainer.num_steps=200000 \\\n  trainer.eval_freq=10000 \\\n  trainer.viz_freq=10000 \\\n  trainer.save_ckpt_freq=500 \\\n  trainer.lr=0.001 \\\n  datasets.train.traj_per_sample=768 \\\n  reproducibility.seed=36 \\\n  trainer.precision=bf16-mixed \\\n  modes.do_initial_static_pretrain=true \\\n  trainer.augment_train_iters=true \\\n  experiment_path=logs/${SLURM_JOB_NAME}\n\"\n\n# Execute within the container\nsrun -ul --environment=my_pytorch_env bash -c \"$CMD\"\n"
  },
  {
    "path": "scripts/slurm/triplane-256.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=spatracker_multiview_256\n#SBATCH --nodes=8\n#SBATCH --ntasks-per-node=4\n#SBATCH --cpus-per-task=72\n#SBATCH --gres=gpu:4\n#SBATCH --mem=460000\n#SBATCH --partition=normal\n#SBATCH --account=a-a136-1\n#SBATCH --time=12:00:00\n#SBATCH --dependency=singleton\n#SBATCH --mail-type=begin\n#SBATCH --mail-type=end\n#SBATCH --mail-user=frano.rajic@inf.ethz.ch\n#SBATCH --output=./logs/slurm_logs/%x-%j.out\n#SBATCH --error=./logs/slurm_logs/%x-%j.out\n#SBATCH --signal=USR1@60\n\nset -euo pipefail\nset -x\ncat $0\nDIR=$(realpath .)\n\n# Wrap the commands\nCMD=\"\nsource /users/fraji/venvs/spa10/bin/activate\ncd $DIR\n\npython train.py model=spatracker_multiview model.triplane_xres=256 model.triplane_yres=256 model.triplane_zres=128 \\\n  trainer.num_steps=200000 \\\n  trainer.eval_freq=10000 \\\n  trainer.viz_freq=10000 \\\n  trainer.save_ckpt_freq=500 \\\n  trainer.lr=0.001 \\\n  datasets.train.traj_per_sample=384 \\\n  reproducibility.seed=36 \\\n  trainer.precision=bf16-mixed \\\n  modes.do_initial_static_pretrain=true \\\n  trainer.augment_train_iters=true \\\n  experiment_path=logs/${SLURM_JOB_NAME}\n\"\n\n# Execute within the container\nsrun -ul --environment=my_pytorch_env bash -c \"$CMD\"\n"
  },
  {
    "path": "scripts/summarize_eval_results.py",
    "content": "import os\nimport re\nimport warnings\n\nimport pandas as pd\n\nREMAP_KUBRIC = {\n    \"Method\": (\"\", \"Method\"),\n    \"average_jaccard__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"Jacc.\"),\n    \"jaccard_0.05__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"< 0.05\"),\n    \"jaccard_0.10__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"< 0.1\"),\n    \"jaccard_0.20__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"< 0.2\"),\n    \"jaccard_0.40__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"< 0.4\"),\n    \"jaccard_0.80__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"< 0.8\"),\n    \"average_pts_within_thresh__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"Loc.\"),\n    \"pts_within_0.05__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"< 0.05\"),\n    \"pts_within_0.10__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"< 0.1\"),\n    \"pts_within_0.20__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"< 0.2\"),\n    \"pts_within_0.40__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"< 0.4\"),\n    \"pts_within_0.80__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"< 0.8\"),\n    \"survival__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"Surv.\"),\n    \"occlusion_accuracy__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"OA\"),\n    \"mte_visible__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"MTE\"),\n    \"ate_visible__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"ATE\"),\n    \"fde_visible__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"FDE\"),\n    \"n__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"n\"),\n    \"v__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"v\"),\n    \"average_jaccard__very_dynamic\": (\"Very Dynamic\", \"Jacc.\"),\n    \"average_pts_within_thresh__very_dynamic\": (\"Very Dynamic\", \"Loc.\"),\n    \"survival__very_dynamic\": (\"Very Dynamic\", \"Surv.\"),\n    \"occlusion_accuracy__very_dynamic\": (\"Very Dynamic\", \"OA\"),\n    \"mte_visible__very_dynamic\": (\"Very Dynamic\", \"MTE\"),\n    \"average_jaccard__static\": (\"Static Points (motion < 0.01)\", \"Jacc.\"),\n    \"average_pts_within_thresh__static\": (\"Static Points (motion < 0.01)\", \"Loc.\"),\n    \"survival__static\": (\"Static Points (motion < 0.01)\", \"Surv.\"),\n    \"occlusion_accuracy__static\": (\"Static Points (motion < 0.01)\", \"OA\"),\n    \"mte_visible__static\": (\"Static Points (motion < 0.01)\", \"MTE\"),\n    \"average_jaccard__any\": (\"Any Points\", \"Jacc.\"),\n    \"average_pts_within_thresh__any\": (\"Any Points\", \"Loc.\"),\n    \"survival__any\": (\"Any Points\", \"Surv.\"),\n    \"occlusion_accuracy__any\": (\"Any Points\", \"OA\"),\n    \"mte_visible__any\": (\"Any Points\", \"MTE\"),\n    \"n_iters\": (\"\", \"#iters\"),\n}\n\nREMAP_DEXYCB_V1 = {\n    \"Method\": (\"\", \"Method\"),\n    \"average_jaccard__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"Jacc.\"),\n    \"jaccard_0.01__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"< 0.01\"),\n    \"jaccard_0.02__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"< 0.02\"),\n    \"jaccard_0.05__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"< 0.05\"),\n    \"jaccard_0.10__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"< 0.10\"),\n    \"jaccard_0.20__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"< 0.20\"),\n    \"average_pts_within_thresh__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"Loc.\"),\n    \"pts_within_0.01__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"< 0.01\"),\n    \"pts_within_0.02__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"< 0.02\"),\n    
\"pts_within_0.05__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"< 0.05\"),\n    \"pts_within_0.10__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"< 0.10\"),\n    \"pts_within_0.20__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"< 0.20\"),\n    \"survival__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"Surv.\"),\n    \"occlusion_accuracy__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"OA\"),\n    \"mte_visible__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"MTE\"),\n    \"ate_visible__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"ATE\"),\n    \"fde_visible__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"FDE\"),\n    \"n__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"n\"),\n    \"v__dynamic\": (\"Dynamic Points (motion > 0.1)\", \"v\"),\n    \"average_jaccard__very_dynamic\": (\"Very Dynamic\", \"Jacc.\"),\n    \"average_pts_within_thresh__very_dynamic\": (\"Very Dynamic\", \"Loc.\"),\n    \"survival__very_dynamic\": (\"Very Dynamic\", \"Surv.\"),\n    \"occlusion_accuracy__very_dynamic\": (\"Very Dynamic\", \"OA\"),\n    \"average_jaccard__static\": (\"Static Points (motion < 0.01)\", \"Jacc.\"),\n    \"average_pts_within_thresh__static\": (\"Static Points (motion < 0.01)\", \"Loc.\"),\n    \"survival__static\": (\"Static Points (motion < 0.01)\", \"Surv.\"),\n    \"occlusion_accuracy__static\": (\"Static Points (motion < 0.01)\", \"OA\"),\n    \"average_jaccard__any\": (\"Any Points\", \"Jacc.\"),\n    \"average_pts_within_thresh__any\": (\"Any Points\", \"Loc.\"),\n    \"survival__any\": (\"Any Points\", \"Surv.\"),\n    \"occlusion_accuracy__any\": (\"Any Points\", \"OA\"),\n    \"n_iters\": (\"\", \"#iters\"),\n}\n\n# Initialize remapping dictionary with the correct order\nREMAP_DEXYCB_V2 = {}\nREMAP_DEXYCB_V2[\"Method\"] = (\"\", \"Method\")\n\n# Define ordered point categories (dynamic first, then very dynamic, static, and any)\nPOINT_TYPES = {\n    \"dynamic\": \"Dynamic Points (motion > 0.1)\",\n    \"very_dynamic\": \"Very Dynamic\",\n    \"static\": \"Static Points (motion < 0.01)\",\n    \"any\": \"Any Points\",\n    \"dynamic-static-mean\": \"Dynamic+Static Points Mean\",\n}\nMETRICS = {\n    \"average_jaccard\": \"Jacc.\",\n    \"jaccard\": \"<{threshold}\",\n    \"average_pts_within_thresh\": \"Loc.\",\n    \"pts_within\": \"<{threshold}\",\n    \"survival\": \"Surv.\",\n    \"occlusion_accuracy\": \"OA\",\n    \"occlusion_accuracy_for_vis0\": \"OA(v=0)\",\n    \"occlusion_accuracy_for_vis1\": \"OA(v=1)\",\n    \"mte_visible\": \"MTE\",\n    \"ate_visible\": \"ATE\",\n    \"fde_visible\": \"FDE\",\n    \"n\": \"n\",\n    \"v\": \"v\"\n}\nTHRESHOLDS = [\"0.01\", \"0.02\", \"0.05\", \"0.10\", \"0.20\"]\nfor pt_key, pt_label in POINT_TYPES.items():\n    for metric, metric_label in METRICS.items():\n        if metric in [\"jaccard\", \"pts_within\"]:  # Threshold-based metrics\n            for thresh in THRESHOLDS:\n                REMAP_DEXYCB_V2[f\"{metric}_{thresh}__{pt_key}\"] = (pt_label, metric_label.format(threshold=thresh))\n        else:  # Regular metrics\n            REMAP_DEXYCB_V2[f\"{metric}__{pt_key}\"] = (pt_label, metric_label)\nREMAP_DEXYCB_V2[\"n_iters\"] = (\"\", \"#iters\")\n\nREMAP_TAPVID2D_INDEX_NAMES = [\"Metric Definition\", \"Metric\"]\nREMAP_TAPVID2D = {\n    \"Method\": (\"\", \"Method\",),\n    \"average_jaccard__any\": (\"Our Metrics\", \"Jacc.\",),\n    \"jaccard_1.00__any\": (\"Our Metrics\", \"<  1\",),\n    \"jaccard_2.00__any\": (\"Our Metrics\", \"<  2\",),\n    \"jaccard_4.00__any\": (\"Our Metrics\", \"<  4\",),\n    
\"jaccard_8.00__any\": (\"Our Metrics\", \"<  8\",),\n    \"jaccard_16.00__any\": (\"Our Metrics\", \"< 16\",),\n    \"average_pts_within_thresh__any\": (\"Our Metrics\", \"Loc.\",),\n    \"pts_within_1.00__any\": (\"Our Metrics\", \"<  1\",),\n    \"pts_within_2.00__any\": (\"Our Metrics\", \"<  2\",),\n    \"pts_within_4.00__any\": (\"Our Metrics\", \"<  4\",),\n    \"pts_within_8.00__any\": (\"Our Metrics\", \"<  8\",),\n    \"pts_within_16.00__any\": (\"Our Metrics\", \"< 16\",),\n    \"survival__any\": (\"Our Metrics\", \"Surv.\",),\n    \"occlusion_accuracy__any\": (\"Our Metrics\", \"OA\",),\n    \"occlusion_accuracy_for_vis0__any\": (\"Our Metrics\", \"OA(v=0)\",),\n    \"occlusion_accuracy_for_vis1__any\": (\"Our Metrics\", \"OA(v=1)\",),\n    \"mte_visible__any\": (\"Our Metrics\", \"MTE\",),\n    \"ate_visible__any\": (\"Our Metrics\", \"ATE\",),\n    \"fde_visible__any\": (\"Our Metrics\", \"FDE\",),\n    \"n__any\": (\"Our Metrics\", \"n\",),\n    \"v__any\": (\"Our Metrics\", \"v\",),\n    \"tapvid2d_average_jaccard\": (\"TAPVid-2D Metrics\", \"Jacc.\",),\n    \"tapvid2d_jaccard_1\": (\"TAPVid-2D Metrics\", \"<  1\",),\n    \"tapvid2d_jaccard_2\": (\"TAPVid-2D Metrics\", \"<  2\",),\n    \"tapvid2d_jaccard_4\": (\"TAPVid-2D Metrics\", \"<  4\",),\n    \"tapvid2d_jaccard_8\": (\"TAPVid-2D Metrics\", \"<  8\",),\n    \"tapvid2d_jaccard_16\": (\"TAPVid-2D Metrics\", \"< 16\",),\n    \"tapvid2d_average_pts_within_thresh\": (\"TAPVid-2D Metrics\", \"Loc.\",),\n    \"tapvid2d_pts_within_1\": (\"TAPVid-2D Metrics\", \"<  1\",),\n    \"tapvid2d_pts_within_2\": (\"TAPVid-2D Metrics\", \"<  2\",),\n    \"tapvid2d_pts_within_4\": (\"TAPVid-2D Metrics\", \"<  4\",),\n    \"tapvid2d_pts_within_8\": (\"TAPVid-2D Metrics\", \"<  8\",),\n    \"tapvid2d_pts_within_16\": (\"TAPVid-2D Metrics\", \"< 16\",),\n    \"tapvid2d_occlusion_accuracy\": (\"TAPVid-2D Metrics\", \"OA\",),\n    \"n_iters\": (\"\", \"#iters\",),\n}\n\nREMAP_PANOPTIC = {}\nREMAP_PANOPTIC[\"Method\"] = (\"\", \"Method\")\nfor pt_key in [\"any\"]:\n    pt_label = POINT_TYPES[pt_key]\n    for metric, metric_label in METRICS.items():\n        if metric in [\"jaccard\", \"pts_within\"]:  # Threshold-based metrics\n            for thresh in [\"0.05\", \"0.10\", \"0.20\", \"0.40\"]:\n                REMAP_PANOPTIC[f\"{metric}_{thresh}__{pt_key}\"] = (pt_label, metric_label.format(threshold=thresh))\n        else:  # Regular metrics\n            REMAP_PANOPTIC[f\"{metric}__{pt_key}\"] = (pt_label, metric_label)\nREMAP_PANOPTIC[\"n_iters\"] = (\"\", \"#iters\")\n\nPARTIAL_REMAP_FOR_2DPT_ABLATION = {}\nfor pt_key, pt_label in POINT_TYPES.items():\n    for metric, metric_label in METRICS.items():\n        if \"jaccard\" in metric or \"occlusion\" in metric:\n            continue\n        if metric in [\"jaccard\", \"pts_within\"]:  # Threshold-based metrics\n            for thresh in [\"1.00\", \"2.00\", \"4.00\", \"8.00\", \"16.00\"]:\n                PARTIAL_REMAP_FOR_2DPT_ABLATION[f\"2dpt__{metric}_{thresh}__{pt_key}\"] = (\n                    \"(2DPT) \" + pt_label, metric_label.format(threshold=thresh)\n                )\n        else:  # Regular metrics\n            PARTIAL_REMAP_FOR_2DPT_ABLATION[f\"2dpt__{metric}__{pt_key}\"] = (\"(2DPT) \" + pt_label, metric_label)\nfor logged_key, (pt_label, metric_label) in REMAP_TAPVID2D.items():\n    if \"jaccard\" in logged_key or \"occlusion\" in logged_key:\n        continue\n    if \"tapvid2d\" not in logged_key:\n        continue\n    
PARTIAL_REMAP_FOR_2DPT_ABLATION[f\"2dpt__{logged_key}\"] = (\"(2DPT) \" + pt_label, metric_label)\nREMAP_2DPT_ABLATION = REMAP_KUBRIC | PARTIAL_REMAP_FOR_2DPT_ABLATION\n\nONE_REMAP_TO_RULE_THEM_ALL = {}\nONE_REMAP_TO_RULE_THEM_ALL[\"Method\"] = (\"\", \"Method\")\nONE_REMAP_TO_RULE_THEM_ALL[\"Dataset\"] = (\"\", \"Dataset\")\nTHRESHOLDS = [\"0.01\", \"0.02\", \"0.05\", \"0.10\", \"0.20\", \"0.40\"]\nfor pt_key, pt_label in POINT_TYPES.items():\n    for metric, metric_label in METRICS.items():\n        if metric in [\"jaccard\", \"pts_within\"]:  # Threshold-based metrics\n            for thresh in THRESHOLDS:\n                ONE_REMAP_TO_RULE_THEM_ALL[f\"{metric}_{thresh}__{pt_key}\"] = (\n                    pt_label, metric_label.format(threshold=thresh))\n        else:  # Regular metrics\n            ONE_REMAP_TO_RULE_THEM_ALL[f\"{metric}__{pt_key}\"] = (pt_label, metric_label)\nONE_REMAP_TO_RULE_THEM_ALL[\"n_iters\"] = (\"\", \"#iters\")\n\n\ndef find_file_with_max_steps(folder):\n    if not os.path.isdir(folder):\n        return None, -1\n    pattern = re.compile(r\"step-(\\d+)_metrics_avg.csv\")\n    max_steps = -1\n    max_file = None\n    for filename in os.listdir(folder):\n        m = pattern.search(filename)\n        if m:\n            steps = int(m.group(1))\n            if steps > max_steps:\n                max_steps = steps\n                max_file = filename\n    return max_file, max_steps\n\n\ndef create_table(\n        method_name_to_csv_path,\n        remap=REMAP_KUBRIC,\n        remap_index_names=[\"Type\", \"Metric\"],\n        header=True,\n        skip_missing=False,\n):\n    assert len(method_name_to_csv_path) > 0, \"No CSV files provided\"\n    rows = []\n    order = []\n    for method_name, path in method_name_to_csv_path.items():\n        if \"step-?_\" in path:\n            filename, n_iters = find_file_with_max_steps(os.path.dirname(path))\n            if filename is None:\n                warnings.warn(f\"No CSV files found in {os.path.dirname(path)}\")\n                continue\n            path = os.path.join(os.path.dirname(path), filename)\n        if not os.path.exists(path):\n            if skip_missing:\n                warnings.warn(f\"Skipping missing file: {path}\")\n                continue\n            raise FileNotFoundError(f\"File not found: {path}\")\n        df = pd.read_csv(path, header=None, names=[\"Metric\", \"Value\"])\n        df = df.dropna(subset=[\"Metric\"]).reset_index(drop=True)\n        if type(method_name) == tuple:\n            method_name, dataset_name = method_name\n        else:\n            dataset_name = os.path.basename(os.path.dirname(path)).replace(\"eval_\", \"\")\n        df[\"Method\"] = method_name\n        match = re.search(r\"step-(\\d+)\", path)\n        n_iters = int(match.group(1)) if match else 0\n        df.loc[len(df)] = [\"n_iters\", n_iters, method_name]\n        df[\"Metric\"] = df[\"Metric\"].str.split(\"/\").str[-1].str.replace(\"model__\", \"\")\n        df[\"Dataset\"] = dataset_name\n        rows.append(df)\n        order.append((method_name, dataset_name))\n    combined_df = pd.concat(rows)\n    pivot_df = combined_df.pivot(index=[\"Method\", \"Dataset\"], columns=\"Metric\", values=\"Value\").reset_index()\n\n    pivot_df = pivot_df.set_index([\"Method\", \"Dataset\"]).reindex(order).reset_index()\n\n    # Define a mapping for the new names\n    for k in remap.keys():\n        if k not in pivot_df.columns:\n            pivot_df[k] = None\n            pivot_df = pivot_df.copy()  # To avoid 
\"DataFrame is highly fragmented\" warning\n    pivot_df = pivot_df[remap.keys()]\n    multi_index = pd.MultiIndex.from_tuples(\n        tuples=[remap[col] for col in pivot_df.columns],\n        names=remap_index_names,\n    )\n    pivot_df.columns = multi_index\n\n    return pivot_df, pivot_df.to_csv(index=False, header=header)\n\n\ndef kubric_single_point():\n    print(\"Kubric single-point evaluation results:\")\n    print(\"================================\")\n    df, csv_str = create_table({\n        # ls logs/kubric_v3/*/eval_kubric-multiview-v3-single/step-*_kubric-multiview-v3-single_metrics_avg.csv | cat\n        \"SpaTracker (pretrained)\": \"logs/kubric_v3/multiview-adapter-pretrained-004/eval_kubric-multiview-v3-single/step--1_kubric-multiview-v3-single_metrics_avg.csv\",\n        \"SpaTracker (single-view baseline)\": \"logs/kubric_v3/multiview-adapter-002/eval_kubric-multiview-v3-single/step-69799_kubric-multiview-v3-single_metrics_avg.csv\",\n        \"Multi-view-V1 (ours)\": \"logs/kubric_v3/multiview-v1-with-128-triplane-001/eval_kubric-multiview-v3-single/step-99999_kubric-multiview-v3-single_metrics_avg.csv\",\n        \"Multi-view-V2 (ours)\": \"logs/kubric_v3/multiview-v2-002--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-single/step-91599_kubric-multiview-v3-single_metrics_avg.csv\",\n    })\n    print(csv_str)\n\n\ndef kubric_before_gt0123():\n    print(\"Kubric multi-point evaluation results:\")\n    print(\"================================\")\n    df, csv_str = create_table({\n        # ls logs/kubric_v3/*/eval_kubric-multiview-v3/step-*_kubric-multiview-v3_metrics_avg.csv | cat\n        \"CopyCat (No motion baseline)\": \"logs/copycat/eval_kubric-multiview-v3/step--1_kubric-multiview-v3_metrics_avg.csv\",\n        \"SpaTracker (pretrained)\": \"logs/kubric_v3/multiview-adapter-pretrained-004/eval_kubric-multiview-v3/step--1_kubric-multiview-v3_metrics_avg.csv\",\n        \"SpaTracker (single-view baseline)\": \"logs/kubric_v3/multiview-adapter-002/eval_kubric-multiview-v3/step-69799_kubric-multiview-v3_metrics_avg.csv\",\n        \"Multi-view-V1 (ours)\": \"logs/kubric_v3/multiview-v1-with-128-triplane-001/eval_kubric-multiview-v3/step-99999_kubric-multiview-v3_metrics_avg.csv\",\n        \"Multi-view-V2 (ours)\": \"logs/kubric_v3/multiview-v2-002--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3/step-91599_kubric-multiview-v3_metrics_avg.csv\",\n\n        \"Multi-view-V1 (ours) (128; trained on D4)\": \"logs/kubric_v3_duster0123/multiview-v1-with-128-triplane-001/eval_kubric-multiview-v3/step-100000_kubric-multiview-v3_metrics_avg.csv\",\n        \"Multi-view-V1 (ours) (256; trained on D4)\": \"logs/kubric_v3_duster0123/multiview-v1-with-256-triplane-001/eval_kubric-multiview-v3/step-100000_kubric-multiview-v3_metrics_avg.csv\",\n        \"Multi-view-V2 (ours) (finetuned on D4c)\": \"logs/kubric_v3_duster0123/multiview-v2-pretrained-cleaned-003--lr-2.5e-4--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3/step-10000_kubric-multiview-v3_metrics_avg.csv\",\n        \"Multi-view-V2 (ours) (trained on D4)\": \"logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3/step-99999_kubric-multiview-v3_metrics_avg.csv\",\n    })\n    print(csv_str)\n\n\ndef kubric():\n    print(\"Kubric multi-point evaluation results:\")\n    
print(\"================================\")\n    df, csv_str = create_table({\n        \"CopyCat (No motion baseline)\": \"logs/copycat/eval_kubric-multiview-v3-gt0123/step--1_kubric-multiview-v3-gt0123_metrics_avg.csv\",\n        \"SpaTracker (pretrained)\": \"logs/kubric_v3/multiview-adapter-pretrained-004/eval_kubric-multiview-v3-gt0123/step--1_kubric-multiview-v3-gt0123_metrics_avg.csv\",\n        \"SpaTracker (single-view baseline)\": \"logs/kubric_v3/multiview-adapter-002/eval_kubric-multiview-v3-gt0123/step-69799_kubric-multiview-v3-gt0123_metrics_avg.csv\",\n        \"Multi-view-V1 (ours)\": \"logs/kubric_v3/multiview-v1-with-128-triplane-001/eval_kubric-multiview-v3-gt0123/step-99999_kubric-multiview-v3-gt0123_metrics_avg.csv\",\n        \"Multi-view-V2 (ours)\": \"logs/kubric_v3/multiview-v2-002--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-gt0123/step-91599_kubric-multiview-v3-gt0123_metrics_avg.csv\",\n\n        \"Multi-view-V1 (ours) (128; trained on D4)\": \"logs/kubric_v3_duster0123/multiview-v1-with-128-triplane-001/eval_kubric-multiview-v3-gt0123/step-99999_kubric-multiview-v3-gt0123_metrics_avg.csv\",\n        \"Multi-view-V1 (ours) (256; trained on D4)\": \"logs/kubric_v3_duster0123/multiview-v1-with-256-triplane-001/eval_kubric-multiview-v3-gt0123/step-99999_kubric-multiview-v3-gt0123_metrics_avg.csv\",\n        \"Multi-view-V2 (ours) (finetuned on D4c)\": \"logs/kubric_v3_duster0123/multiview-v2-pretrained-cleaned-003--lr-2.5e-4--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-gt0123/step-9999_kubric-multiview-v3-gt0123_metrics_avg.csv\",\n        \"Multi-view-V2 (ours) (trained on D4)\": \"logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-gt0123/step-99999_kubric-multiview-v3-gt0123_metrics_avg.csv\",\n        \"Multi-view-V2 (ours) (trained on D4c)\": \"logs/kubric_v3_duster0123/multiview-v2-cleaned-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-gt0123/step-99999_metrics_avg.csv\",\n        \"Multi-view-V3 (ours) (finetuned^2 on D4c;s=4)\": \"logs/kubric_v3_duster0123/multiview-v3-001--lr-2.5e-4--fmaps-384/eval_kubric-multiview-v3-gt0123/step-9999_kubric-multiview-v3-gt0123_metrics_avg.csv\",\n    })\n    print(csv_str)\n\n\ndef kubric_duster():\n    print(\"Kubric multi-point evaluation results, Duster0123:\")\n    print(\"================================\")\n    df, csv_str = create_table({\n        # ls logs/kubric_v3/*/eval_kubric-multiview-v3-duster0123/step-*_kubric-multiview-v3-duster0123_metrics_avg.csv | cat\n        \"CopyCat (No motion baseline)\": \"logs/copycat/eval_kubric-multiview-v3-duster0123/step--1_kubric-multiview-v3-duster0123_metrics_avg.csv\",\n        \"SpaTracker (pretrained)\": \"logs/kubric_v3/multiview-adapter-pretrained-004/eval_kubric-multiview-v3-duster0123/step--1_kubric-multiview-v3-duster0123_metrics_avg.csv\",\n        \"SpaTracker (single-view baseline)\": \"logs/kubric_v3/multiview-adapter-002/eval_kubric-multiview-v3-duster0123/step-69799_kubric-multiview-v3-duster0123_metrics_avg.csv\",\n        \"Multi-view-V1 (ours)\": \"logs/kubric_v3/multiview-v1-with-128-triplane-001/eval_kubric-multiview-v3-duster0123/step-99999_kubric-multiview-v3-duster0123_metrics_avg.csv\",\n        \"Multi-view-V2 (ours)\": 
\"logs/kubric_v3/multiview-v2-002--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-duster0123/step-91599_kubric-multiview-v3-duster0123_metrics_avg.csv\",\n\n        \"SpaTracker (single-view baseline) (trained on D4)\": \"logs/kubric_v3_duster0123/multiview-adapter-001/eval_kubric-multiview-v3-duster0123/step-90000_kubric-multiview-v3-duster0123_metrics_avg.csv\",\n        \"Multi-view-V1 (ours) (128; trained on D4)\": \"logs/kubric_v3_duster0123/multiview-v1-with-128-triplane-001/eval_kubric-multiview-v3-duster0123/step-100000_kubric-multiview-v3-duster0123_metrics_avg.csv\",\n        \"Multi-view-V1 (ours) (256; trained on D4)\": \"logs/kubric_v3_duster0123/multiview-v1-with-256-triplane-001/eval_kubric-multiview-v3-duster0123/step-100000_kubric-multiview-v3-duster0123_metrics_avg.csv\",\n        \"Multi-view-V2 (ours) (finetuned on D4)\": \"logs/kubric_v3_duster0123/multiview-v2-pretrained-cleaned-003--lr-2.5e-4--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-duster0123/step-10000_kubric-multiview-v3-duster0123_metrics_avg.csv\",\n        \"Multi-view-V2 (ours) (trained on D4)\": \"logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-duster0123/step-100000_kubric-multiview-v3-duster0123_metrics_avg.csv\",\n        # \"Multi-view-V2 (ours) (trained on D4c)\": \"logs/kubric_v3_duster0123/multiview-v2-cleaned-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-duster0123/step-70000_kubric-multiview-v3-duster0123_metrics_avg.csv\",\n        # \"Multi-view-V3 (ours) (finetuned^2 on D4c;s=4)\": \"logs/kubric_v3_duster0123/multiview-v3-001--lr-2.5e-4--fmaps-384/eval_kubric-multiview-v3-duster0123/step-10000_kubric-multiview-v3-duster0123_metrics_avg.csv\",\n    })\n    print(csv_str)\n    df, csv_str = create_table({\n        # \"SpaTracker (single-view baseline) (trained on D4)\": \"logs/kubric_v3_duster0123/multiview-adapter-001/eval_kubric-multiview-v3-duster0123cleaned/step-90000_kubric-multiview-v3-duster0123cleaned_metrics_avg.csv\",\n        # \"Multi-view-V1 (ours) (128; trained on D4)\": \"logs/kubric_v3_duster0123/multiview-v1-with-128-triplane-001/eval_kubric-multiview-v3-duster0123cleaned/step-100000_kubric-multiview-v3-duster0123cleaned_metrics_avg.csv\",\n        # \"Multi-view-V1 (ours) (256; trained on D4)\": \"logs/kubric_v3_duster0123/multiview-v1-with-256-triplane-001/eval_kubric-multiview-v3-duster0123cleaned/step-100000_kubric-multiview-v3-duster0123cleaned_metrics_avg.csv\",\n        # \"Multi-view-V2 (ours) (finetuned on D4)\": \"logs/kubric_v3_duster0123/multiview-v2-pretrained-cleaned-003--lr-2.5e-4--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-duster0123cleaned/step-10000_kubric-multiview-v3-duster0123cleaned_metrics_avg.csv\",\n        # \"Multi-view-V2 (ours) (trained on D4)\": \"logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-duster0123cleaned/step-100000_kubric-multiview-v3-duster0123cleaned_metrics_avg.csv\",\n        \"Multi-view-V2 (ours) (trained on D4c)\": \"logs/kubric_v3_duster0123/multiview-v2-cleaned-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-duster0123cleaned/step-70000_kubric-multiview-v3-duster0123cleaned_metrics_avg.csv\",\n        
\"Multi-view-V3 (ours) (finetuned^2 on D4c;s=4)\": \"logs/kubric_v3_duster0123/multiview-v3-001--lr-2.5e-4--fmaps-384/eval_kubric-multiview-v3-duster0123cleaned/step-10000_kubric-multiview-v3-duster0123cleaned_metrics_avg.csv\",\n        # \"Multi-view-V3 (ours) (trained on D4c;s=4)\": \"TBD\",\n        # \"Multi-view-V3 (ours) (trained on D4c;s=16)\": \"TBD\",\n    })\n\n\ndef mv3_kubric_duster_transformed():\n    print(\"Kubric transformed, Duster0123, Multi-view-V3 (ours) (finetuned^2 on D4c;s=4):\")\n    print(\"================================\")\n    df, csv_str = create_table({\n        \"Kubric (no world space transformations)\": f\"logs/kubric_v3_duster0123/multiview-v3-001--lr-2.5e-4--fmaps-384/eval_kubric-multiview-v3-duster0123cleaned/step-9999_kubric-multiview-v3-duster0123cleaned_metrics_avg.00_no_transform.csv\",\n        \"Kubric (translated by z-10)\": f\"logs/kubric_v3_duster0123/multiview-v3-001--lr-2.5e-4--fmaps-384/eval_kubric-multiview-v3-duster0123cleaned/step-9999_kubric-multiview-v3-duster0123cleaned_metrics_avg.01_translate_z-10.csv\",\n        \"Kubric (translated by x+4, y+4)\": f\"logs/kubric_v3_duster0123/multiview-v3-001--lr-2.5e-4--fmaps-384/eval_kubric-multiview-v3-duster0123cleaned/step-9999_kubric-multiview-v3-duster0123cleaned_metrics_avg.02a_translate_x+4_y+4.csv\",\n        \"Kubric (translated by x+10, y+10)\": f\"logs/kubric_v3_duster0123/multiview-v3-001--lr-2.5e-4--fmaps-384/eval_kubric-multiview-v3-duster0123cleaned/step-9999_kubric-multiview-v3-duster0123cleaned_metrics_avg.02b_translate_x+10_y+10.csv\",\n        \"Kubric (rotated x+90)\": f\"logs/kubric_v3_duster0123/multiview-v3-001--lr-2.5e-4--fmaps-384/eval_kubric-multiview-v3-duster0123cleaned/step-9999_kubric-multiview-v3-duster0123cleaned_metrics_avg.03_rotate_x+90.csv\",\n        \"Kubric (rotated y+90)\": f\"logs/kubric_v3_duster0123/multiview-v3-001--lr-2.5e-4--fmaps-384/eval_kubric-multiview-v3-duster0123cleaned/step-9999_kubric-multiview-v3-duster0123cleaned_metrics_avg.04_rotate_y+90.csv\",\n        \"Kubric (rotated z+90)\": f\"logs/kubric_v3_duster0123/multiview-v3-001--lr-2.5e-4--fmaps-384/eval_kubric-multiview-v3-duster0123cleaned/step-9999_kubric-multiview-v3-duster0123cleaned_metrics_avg.05_rotate_z+90.csv\",\n        \"Kubric (scaled down 2x)\": f\"logs/kubric_v3_duster0123/multiview-v3-001--lr-2.5e-4--fmaps-384/eval_kubric-multiview-v3-duster0123cleaned/step-9999_kubric-multiview-v3-duster0123cleaned_metrics_avg.06_scale_down_2x.csv\",\n        \"Kubric (scaled down 8x)\": f\"logs/kubric_v3_duster0123/multiview-v3-001--lr-2.5e-4--fmaps-384/eval_kubric-multiview-v3-duster0123cleaned/step-9999_kubric-multiview-v3-duster0123cleaned_metrics_avg.07_scale_down_8x.csv\",\n        \"Kubric (scaled up 2x)\": f\"logs/kubric_v3_duster0123/multiview-v3-001--lr-2.5e-4--fmaps-384/eval_kubric-multiview-v3-duster0123cleaned/step-9999_kubric-multiview-v3-duster0123cleaned_metrics_avg.08_scale_up_2x.csv\",\n        \"Kubric (scaled up 8x)\": f\"logs/kubric_v3_duster0123/multiview-v3-001--lr-2.5e-4--fmaps-384/eval_kubric-multiview-v3-duster0123cleaned/step-9999_kubric-multiview-v3-duster0123cleaned_metrics_avg.09_scale_up_8x.csv\",\n    })\n    \"\"\n    print(csv_str)\n\n\ndef mv3_kubric_nviews():\n    print(\"eval_kubric-multiview-v3-views..., Multi-view-V2 (ours) (trained on D4):\")\n    print(\"================================\")\n    df, csv_str = create_table({\n        \"1\": 
\"logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-views0/step-99999_metrics_avg.csv\",\n        \"2\": \"logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-views01/step-99999_metrics_avg.csv\",\n        \"3\": \"logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-views012/step-99999_metrics_avg.csv\",\n        \"4\": \"logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-views0123/step-99999_metrics_avg.csv\",\n        \"5\": \"logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-views01234/step-99999_metrics_avg.csv\",\n        \"6\": \"logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-views012345/step-99999_metrics_avg.csv\",\n        \"7\": \"logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-views0123456/step-99999_metrics_avg.csv\",\n        \"8\": \"logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-views01234567/step-99999_metrics_avg.csv\",\n        \"9\": \"logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-views012345678/step-99999_metrics_avg.csv\",\n        \"10\": \"logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-views0123456789/step-99999_metrics_avg.csv\",\n    })\n    print(csv_str)\n\n\ndef mv3_kubric_duster_nviews():\n    print(\"eval_kubric-multiview-v3-duster0123-views..., Multi-view-V2 (ours) (trained on D4):\")\n    print(\"================================\")\n    df, csv_str = create_table({\n        \"1\": \"logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-duster0123-views0/step-99999_metrics_avg.csv\",\n        \"2\": \"logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-duster0123-views01/step-99999_metrics_avg.csv\",\n        \"3\": \"logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-duster0123-views012/step-99999_metrics_avg.csv\",\n        \"4\": \"logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-duster0123-views0123/step-99999_metrics_avg.csv\",\n        \"5\": \"logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-duster01234567-views01234/step-99999_metrics_avg.csv\",\n        \"6\": \"logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-duster01234567-views012345/step-99999_metrics_avg.csv\",\n        \"7\": 
\"logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-duster01234567-views0123456/step-99999_metrics_avg.csv\",\n        \"8\": \"logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_kubric-multiview-v3-duster01234567-views01234567/step-99999_metrics_avg.csv\",\n    })\n    print(csv_str)\n\n\ndef kubric_nviews():\n    print(\"=\" * 80)\n    print(\"=\" * 80)\n    print(\"=\" * 80)\n    method_name_to_csv_path_template = {\n        \"CopyCat (No motion baseline),{}\": \"logs/copycat/eval_{}/step--1_metrics_avg.csv\",\n        \"SpaTracker (pretrained),{}\": \"logs/kubric_v3/multiview-adapter-pretrained-004/eval_{}/step--1_metrics_avg.csv\",\n        \"SpaTracker (single-view baseline),{}\": \"logs/kubric_v3/multiview-adapter-002/eval_{}/step-69799_metrics_avg.csv\",\n        \"Multi-view-V1 (ours),{}\": \"logs/kubric_v3/multiview-v1-with-128-triplane-001/eval_{}/step-99999_metrics_avg.csv\",\n        # \"Multi-view-V2 (ours),{}\": \"logs/kubric_v3/multiview-v2-002--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_{}/step-91599_metrics_avg.csv\",\n\n        \"SpaTracker (single-view baseline) (trained on D4),{}\": \"logs/kubric_v3_duster0123/multiview-adapter-001/eval_{}/step-logs/kubric_v3_duster0123/multiview-adapter-001_metrics_avg.csv\",\n        # \"Multi-view-V1 (ours) (128; trained on D4),{}\": \"logs/kubric_v3_duster0123/multiview-v1-with-128-triplane-001/eval_{}/step-99999_metrics_avg.csv\",\n        \"Multi-view-V1 (ours) (256; trained on D4),{}\": \"logs/kubric_v3_duster0123/multiview-v1-with-256-triplane-001/eval_{}/step-99999_metrics_avg.csv\",\n        # \"Multi-view-V2 (ours) (finetuned on D4c),{}\": \"logs/kubric_v3_duster0123/multiview-v2-pretrained-cleaned-003--lr-2.5e-4--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_{}/step-9999_metrics_avg.csv\",\n        \"Multi-view-V2 (ours) (trained on D4),{}\": \"logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_{}/step-99999_metrics_avg.csv\",\n        # \"Multi-view-V2 (ours) (trained on D4c),{}\": \"logs/kubric_v3_duster0123/multiview-v2-cleaned-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_{}/step-99999_metrics_avg.csv\",\n        # \"Multi-view-V3 (ours) (finetuned^2 on D4c;s=4),{}\": \"logs/kubric_v3_duster0123/multiview-v3-001--lr-2.5e-4--fmaps-384/eval_{}/step-9999_metrics_avg.csv\",\n    }\n    method_name_to_csv_path_per_dataset = {}\n    for dataset_prefix in [\n        \"kubric-multiview-v3-views\",\n        \"kubric-multiview-v3-duster0123-views\",\n        \"kubric-multiview-v3-duster01234567-views\",\n        \"kubric-multiview-v3-duster0123cleaned-views\",\n        \"kubric-multiview-v3-duster01234567cleaned-views\",\n    ]:\n        method_name_to_csv_path_per_dataset[dataset_prefix] = {}\n        for method_name_template, csv_path_template in method_name_to_csv_path_template.items():\n            for n in range(8):\n                if (\"-duster0123-\" in dataset_prefix or \"-duster0123cleaned-\" in dataset_prefix) and n > 4:\n                    continue\n                if (\"-duster01234567-\" in dataset_prefix or \"-duster01234567cleaned-\" in dataset_prefix) and n < 5:\n                    continue\n                dataset = dataset_prefix + \"\".join(str(i) for i in range(n + 1))\n                method_name = 
method_name_template.format(n + 1)\n                csv_path = csv_path_template.format(dataset)\n                assert method_name not in method_name_to_csv_path_per_dataset[\n                    dataset_prefix], f\"Duplicate method name: {method_name}\"\n                method_name_to_csv_path_per_dataset[dataset_prefix][method_name] = csv_path\n    for dataset_prefix, method_name_to_csv_path in method_name_to_csv_path_per_dataset.items():\n        print(method_name_to_csv_path)\n        print(f\"Kubric multi-point evaluation results, {dataset_prefix}:\")\n        print(\"================================\")\n        df, csv_str = create_table(method_name_to_csv_path)\n        print(csv_str)\n\n\nMODELS = {\n    \"copycat\": {\n        \"name\": \"CopyCat (No motion baseline)\",\n        \"csv\": \"logs/copycat/eval_{dataset}/step--1_metrics_avg.csv\",\n    },\n    \"cotracker3\": {\n        \"name\": \"CoTracker3 Offline (x)\",\n        \"csv\": \"logs/cotracker3/eval_{dataset}/step--1_metrics_avg.csv\",\n    },\n    \"cotracker3offline\": {\n        \"name\": \"CoTracker3 Offline\",\n        \"csv\": \"logs/cotracker3-offline/eval_{dataset}/step--1_metrics_avg.csv\",\n    },\n    \"cotracker3online\": {\n        \"name\": \"CoTracker3 Online\",\n        \"csv\": \"logs/cotracker3-online/eval_{dataset}/step--1_metrics_avg.csv\",\n    },\n    \"cotracker2offline\": {\n        \"name\": \"CoTracker2 Offline\",\n        \"csv\": \"logs/cotracker2-offline/eval_{dataset}/step--1_metrics_avg.csv\",\n    },\n    \"cotracker2online\": {\n        \"name\": \"CoTracker2 Online\",\n        \"csv\": \"logs/cotracker2-online/eval_{dataset}/step--1_metrics_avg.csv\",\n    },\n    \"cotracker1offline\": {\n        \"name\": \"CoTracker1 Offline\",\n        \"csv\": \"logs/cotracker1-offline/eval_{dataset}/step--1_metrics_avg.csv\",\n    },\n    \"cotracker1online\": {\n        \"name\": \"CoTracker1 Online\",\n        \"csv\": \"logs/cotracker1-online/eval_{dataset}/step--1_metrics_avg.csv\",\n    },\n    \"delta\": {\n        \"name\": \"DELTA\",\n        \"csv\": \"logs/delta/eval_{dataset}/step--1_metrics_avg.csv\",\n    },\n    \"locotrack\": {\n        \"name\": \"LocoTrack\",\n        \"csv\": \"logs/locotrack/eval_{dataset}/step--1_metrics_avg.csv\",\n    },\n    \"scenetracker\": {\n        \"name\": \"SceneTracker\",\n        \"csv\": \"logs/scenetracker/eval_{dataset}/step--1_metrics_avg.csv\",\n    },\n    \"spatracker-pretrained\": {\n        \"name\": \"SpaTracker (pretrained)\",\n        \"csv\": \"logs/kubric_v3_duster0123/multiview-adapter-pretrained-001/eval_{dataset}/step--1_metrics_avg.csv\",\n    },\n    \"spatracker\": {\n        \"name\": \"SpaTracker (single-view baseline)\",\n        \"csv\": \"logs/kubric_v3/multiview-adapter-002/eval_{dataset}/step-69799_metrics_avg.csv\",\n    },\n    \"mv1\": {\n        \"name\": \"Multi-view-V1 (ours)\",\n        \"csv\": \"logs/kubric_v3/multiview-v1-with-128-triplane-001/eval_{dataset}/step-99999_metrics_avg.csv\",\n    },\n    \"mv2\": {\n        \"name\": \"Multi-view-V2 (ours)\",\n        \"csv\": \"logs/kubric_v3/multiview-v2-002--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_{dataset}/step-91599_metrics_avg.csv\",\n    },\n    \"spatracker-d4\": {\n        \"name\": \"SpaTracker (single-view baseline) (trained on D4)\",\n        \"csv\": \"logs/kubric_v3_duster0123/multiview-adapter-001/eval_{dataset}/step-90799_metrics_avg.csv\",\n    },\n    \"mv1-d4\": {\n        \"name\": \"Multi-view-V1 (ours) (256; 
trained on D4)\",\n        \"csv\": \"logs/kubric_v3_duster0123/multiview-v1-with-256-triplane-001/eval_{dataset}/step-99999_metrics_avg.csv\",\n    },\n    \"mv2-d4\": {\n        \"name\": \"Multi-view-V2 (ours) (trained on D4)\",\n        \"csv\": \"logs/kubric_v3_duster0123/multiview-v2-001--k-16--fmaps-384--groups-1--levels-4--grad-clip-1--iters-4--window-12/eval_{dataset}/step-99999_metrics_avg.csv\",\n    },\n    \"mv3-d4c\": {\n        \"name\": \"Multi-view-V3 (ours) (finetuned^2 on D4c;s=4)\",\n        \"csv\": \"logs/kubric_v3_duster0123/multiview-v3-001--lr-2.5e-4--fmaps-384/eval_{dataset}/step-9999_metrics_avg.csv\",\n    },\n    \"mv4-a07\": {\n        \"name\": \"Multi-view-V4 (ours) (A07)\",\n        \"csv\": \"logs/kubric_v3_augs/multiview-v4-A07.augs_4.002/eval_{dataset}/step-25599_metrics_avg.csv\",\n    },\n    \"mv4-b01\": {\n        \"name\": \"Multi-view-V4 (ours) (B01)\",\n        \"csv\": \"logs/kubric_v3_augs/multiview-v4-B01.vary_n_views.004/eval_{dataset}/step-199999_metrics_avg.csv\",\n    },\n    \"mv4-b02\": {\n        \"name\": \"Multi-view-V4 (ours) (B02)\",\n        \"csv\": \"logs/kubric_v3_augs/multiview-v4-B02.vary_depth_type.002a/eval_{dataset}/step-199999_metrics_avg.csv\",\n    },\n    \"mv4-b03\": {\n        \"name\": \"Multi-view-V4 (ours) (B03)\",\n        \"csv\": \"logs/kubric_v3_augs/multiview-v4-B03.vary_both.004/eval_{dataset}/step-?_metrics_avg.csv\",\n    },\n    \"mv4-b03-paper\": {\n        \"name\": \"Multi-view-V4 (ours) (B03 paper ckpt)\",\n        \"csv\": \"logs/kubric_v3_augs/multiview-v4-B03.vary_both.004/eval_{dataset}/step-153999_metrics_avg.csv\",\n    },\n\n    # # \"C01.001.0\" : {\n    # #     \"name\": \"Ablation (C01) – Offset 1 AddXYZ 0 K 16 P 4\",\n    # #     \"csv\": \"logs/kubric_v3_augs/ablate-correlation.001.0_K-16_FMAP-128_PYR-4_KNN-default/eval_{dataset}/step-?_metrics_avg.csv\",\n    # # },\n    # \"C01.001.1\" : {\n    #     \"name\": \"Ablation (C01) – Offset 0 AddXYZ 0\",\n    #     \"csv\": \"logs/kubric_v3_augs/ablate-correlation.001.1_K-16_FMAP-128_PYR-4_KNN-remove_offset/eval_{dataset}/step-?_metrics_avg.csv\",\n    # },\n    # \"C01.001.2\" : {\n    #     \"name\": \"Ablation (C01) – Offset 1 AddXYZ 1\",\n    #     \"csv\": \"logs/kubric_v3_augs/ablate-correlation.001.2_K-16_FMAP-128_PYR-4_KNN-add_neighbor_xyz/eval_{dataset}/step-?_metrics_avg.csv\",\n    # },\n    # # \"C01.001.3\" : {\n    # #     \"name\": \"Ablation (C01) – Offset 0 AddXYZ 1\",\n    # #     \"csv\": \"logs/kubric_v3_augs/ablate-correlation.001.3_K-16_FMAP-128_PYR-4_KNN-remove_offset_and_add_neighbor_xyz/eval_{dataset}/step-?_metrics_avg.csv\",\n    # # },\n    # \"C01.001.4\" : {\n    #     \"name\": \"Ablation (C01) – K 1\",\n    #     \"csv\": \"logs/kubric_v3_augs/ablate-correlation.001.4_K-1_FMAP-128_PYR-4_KNN-default/eval_{dataset}/step-?_metrics_avg.csv\",\n    # },\n    # \"C01.001.5\" : {\n    #     \"name\": \"Ablation (C01) – K 4\",\n    #     \"csv\": \"logs/kubric_v3_augs/ablate-correlation.001.5_K-4_FMAP-128_PYR-4_KNN-default/eval_{dataset}/step-?_metrics_avg.csv\",\n    # },\n    # \"C01.001.6\" : {\n    #     \"name\": \"Ablation (C01) – K 8\",\n    #     \"csv\": \"logs/kubric_v3_augs/ablate-correlation.001.6_K-8_FMAP-128_PYR-4_KNN-default/eval_{dataset}/step-?_metrics_avg.csv\",\n    # },\n    # # \"C01.001.7\" : {\n    # #     \"name\": \"Ablation (C01) – K 32\",\n    # #     \"csv\": \"logs/kubric_v3_augs/ablate-correlation.001.7_K-32_FMAP-128_PYR-4_KNN-default/eval_{dataset}/step-?_metrics_avg.csv\",\n    # # },\n    
# # \"C01.001.8\" : {\n    # #     \"name\": \"Ablation (C01) – K 64\",\n    # #     \"csv\": \"logs/kubric_v3_augs/ablate-correlation.001.8_K-64_FMAP-128_PYR-4_KNN-default/eval_{dataset}/step-?_metrics_avg.csv\",\n    # # },\n    # \"C01.001.9\" : {\n    #     \"name\": \"Ablation (C01) – P 1\",\n    #     \"csv\": \"logs/kubric_v3_augs/ablate-correlation.001.9_K-16_FMAP-128_PYR-1_KNN-default/eval_{dataset}/step-?_metrics_avg.csv\",\n    # },\n    # \"C01.001.10\" : {\n    #     \"name\": \"Ablation (C01) – P 2\",\n    #     \"csv\": \"logs/kubric_v3_augs/ablate-correlation.001.10_K-16_FMAP-128_PYR-2_KNN-default/eval_{dataset}/step-?_metrics_avg.csv\",\n    # },\n    # # \"C01.001.11\" : {\n    # #     \"name\": \"Ablation (C01) – P 6\",\n    # #     \"csv\": \"logs/kubric_v3_augs/ablate-correlation.001.11_K-16_FMAP-128_PYR-6_KNN-default/eval_{dataset}/step-?_metrics_avg.csv\",\n    # # },\n\n    \"C02.001.0\": {\n        \"name\": \"Ablation (C02) – Offset 1 AddXYZ 0 K 16 P 4\",\n        \"csv\": \"logs/kubric_v3_augs/ablate-correlation.dusterdepths.C02.001.0_K-16_FMAP-128_PYR-4_KNN-default/eval_{dataset}/step-?_metrics_avg.csv\",\n    },\n    \"C02.001.1\": {\n\n        \"name\": \"Ablation (C02) – Offset 0 AddXYZ 0\",\n        \"csv\": \"logs/kubric_v3_augs/ablate-correlation.dusterdepths.C02.001.1_K-16_FMAP-128_PYR-4_KNN-remove_offset/eval_{dataset}/step-?_metrics_avg.csv\",\n    },\n    \"C02.001.2\": {\n        \"name\": \"Ablation (C02) – Offset 1 AddXYZ 1\",\n        \"csv\": \"logs/kubric_v3_augs/ablate-correlation.dusterdepths.C02.001.2_K-16_FMAP-128_PYR-4_KNN-add_neighbor_xyz/eval_{dataset}/step-?_metrics_avg.csv\",\n    },\n    # \"C02.001.3\" : {\n    #     \"name\": \"Ablation (C02) – Offset 0 AddXYZ 1\",\n    #     \"csv\": \"logs/kubric_v3_augs/ablate-correlation.dusterdepths.C02.001.3_K-16_FMAP-128_PYR-4_KNN-remove_offset_and_add_neighbor_xyz/eval_{dataset}/step-?_metrics_avg.csv\",\n    # },\n    \"C02.001.4\": {\n        \"name\": \"Ablation (C02) – K 1\",\n        \"csv\": \"logs/kubric_v3_augs/ablate-correlation.dusterdepths.C02.001.4_K-1_FMAP-128_PYR-4_KNN-default/eval_{dataset}/step-?_metrics_avg.csv\",\n    },\n    \"C02.001.5\": {\n        \"name\": \"Ablation (C02) – K 4\",\n        \"csv\": \"logs/kubric_v3_augs/ablate-correlation.dusterdepths.C02.001.5_K-4_FMAP-128_PYR-4_KNN-default/eval_{dataset}/step-?_metrics_avg.csv\",\n    },\n    \"C02.001.6\": {\n        \"name\": \"Ablation (C02) – K 8\",\n        \"csv\": \"logs/kubric_v3_augs/ablate-correlation.dusterdepths.C02.001.6_K-8_FMAP-128_PYR-4_KNN-default/eval_{dataset}/step-?_metrics_avg.csv\",\n    },\n    # \"C02.002.7\" : {\n    #     \"name\": \"Ablation (C02) – K 32\",\n    #     \"csv\": \"logs/kubric_v3_augs/ablate-correlation.dusterdepths.C02.002.7_K-32_FMAP-128_PYR-4_KNN-default/eval_{dataset}/step-?_metrics_avg.csv\",\n    # },\n    \"C02.001.8\": {\n        \"name\": \"Ablation (C02) – K 64\",\n        \"csv\": \"logs/kubric_v3_augs/ablate-correlation.dusterdepths.C02.001.8_K-64_FMAP-128_PYR-4_KNN-default/eval_{dataset}/step-?_metrics_avg.csv\",\n    },\n    \"C02.001.9\": {\n        \"name\": \"Ablation (C02) – P 1\",\n        \"csv\": \"logs/kubric_v3_augs/ablate-correlation.dusterdepths.C02.001.9_K-16_FMAP-128_PYR-1_KNN-default/eval_{dataset}/step-?_metrics_avg.csv\",\n    },\n    \"C02.001.10\": {\n        \"name\": \"Ablation (C02) – P 2\",\n        \"csv\": 
\"logs/kubric_v3_augs/ablate-correlation.dusterdepths.C02.001.10_K-16_FMAP-128_PYR-2_KNN-default/eval_{dataset}/step-?_metrics_avg.csv\",\n    },\n    \"C02.001.11\": {\n        \"name\": \"Ablation (C02) – P 6\",\n        \"csv\": \"logs/kubric_v3_augs/ablate-correlation.dusterdepths.C02.001.11_K-16_FMAP-128_PYR-6_KNN-default/eval_{dataset}/step-?_metrics_avg.csv\",\n    },\n    \"shape-of-motion\": {\n        \"name\": \"Shape of Motion (MV)\",\n        \"csv\": \"logs/shape_of_motion/eval_{dataset}/step--1_metrics_avg.csv\",\n    },\n\n    # June 2025\n    \"mvtracker-march\": {\n        \"name\": \"MV-Tracker (ours; March 2025)\",\n        \"csv\": \"logs/eval/mvtracker-iccv-march2025/eval_{dataset}/step--1_metrics_avg.csv\",\n    },\n    \"mvtracker-june\": {\n        \"name\": \"MV-Tracker (ours; June 2025)\",\n        \"csv\": \"logs/eval/mvtracker-june2025/eval_{dataset}/step--1_metrics_avg.csv\",\n    },\n}\n\n\ndef tavid2d_davis():\n    print(\"TAPVid-2D DAVIS:\")\n    print(\"================\")\n    models_to_report = [\n        \"copycat\",\n        \"locotrack\", \"scenetracker\", \"delta\", \"cotracker1online\", \"cotracker2online\", \"cotracker3online\",\n        \"cotracker1offline\", \"cotracker2offline\", \"cotracker3offline\",\n        \"spatracker-pretrained\",\n        \"spatracker\", \"spatracker-d4\", \"mv1-d4\", \"mv2-d4\",\n        \"mv4-b01\", \"mv4-b02\", \"mv4-b03\",\n    ]\n    assert all(m in MODELS for m in models_to_report)\n    for resolution in [\n        \"-256x256\",\n        # \"\",\n    ]:\n        for depth_estimator in [\n            # \"zoedepth\",\n            # \"moge\",\n            \"mogewithextrinsics\",\n        ]:\n            df, csv_str = create_table({\n                MODELS[m][\"name\"]: MODELS[m][\"csv\"].format(dataset=f\"tapvid2d-davis-{depth_estimator}{resolution}\")\n                for m in models_to_report\n            }, remap=REMAP_TAPVID2D, remap_index_names=REMAP_TAPVID2D_INDEX_NAMES)\n            print(f\"Resolution: {resolution}, Depth estimator: {depth_estimator}\")\n            print(csv_str)\n            print()\n\n\ndef dexycb():\n    print(\"DexYCB evaluation results:\")\n    print(\"==========================\")\n    for models_to_report, depths in [\n        ([\"copycat\",\n          \"locotrack\", \"scenetracker\", \"delta\", \"cotracker1online\", \"cotracker2online\", \"cotracker3online\",\n          \"cotracker1offline\", \"cotracker2offline\", \"cotracker3offline\",\n          \"spatracker-pretrained\",\n          \"spatracker\", \"mv1\", \"mv2\",\n          \"spatracker-d4\", \"mv1-d4\", \"mv2-d4\", \"mv3-d4c\",\n          \"mv4-b01\", \"mv4-b02\", \"mv4-b03\"], \"\"),\n        ([\"copycat\",\n          \"locotrack\", \"scenetracker\", \"delta\", \"cotracker1online\", \"cotracker2online\", \"cotracker3online\",\n          \"cotracker1offline\", \"cotracker2offline\", \"cotracker3offline\",\n          \"spatracker-pretrained\",\n          \"spatracker\", \"mv1\", \"mv2\",\n          \"spatracker-d4\", \"mv1-d4\", \"mv2-d4\",\n          \"mv4-b01\", \"mv4-b02\", \"mv4-b03\", \"shape-of-motion\", \"mv4-b03-paper\"], \"-duster0123\"),\n        ([\"locotrack\", \"scenetracker\", \"delta\", \"cotracker1online\", \"cotracker2online\", \"cotracker3online\",\n          \"cotracker1offline\", \"cotracker2offline\", \"cotracker3offline\",\n          \"mv3-d4c\",\n          \"mv4-b01\", \"mv4-b02\", \"mv4-b03\"], \"-duster0123cleaned\"),\n    ]:\n        assert all(m in MODELS for m in models_to_report)\n        # for 
remove_hand in [\"\", \"-removehand\"]:\n        for remove_hand in [\"\"]:\n            df, csv_str = create_table({\n                MODELS[m][\"name\"]: MODELS[m][\"csv\"].format(dataset=f\"dex-ycb-multiview{depths}{remove_hand}\")\n                for m in models_to_report\n            }, remap=REMAP_DEXYCB_V2)\n            print(f\"Depths: {depths} Remove hand: {remove_hand}\")\n            print(csv_str)\n            print()\n\n\ndef kubric_refactored():\n    print(\"Kubric evaluation results:\")\n    print(\"==========================\")\n    for models_to_report, depths in [\n        ([\"copycat\",\n          \"locotrack\", \"scenetracker\", \"delta\", \"cotracker1online\", \"cotracker2online\", \"cotracker3online\",\n          \"cotracker1offline\", \"cotracker2offline\", \"cotracker3offline\",\n          \"spatracker-pretrained\",\n          \"spatracker\", \"mv1\", \"mv2\",\n          \"spatracker-d4\", \"mv1-d4\", \"mv2-d4\",\n          \"mv4-b01\", \"mv4-b02\", \"mv4-b03\", \"shape-of-motion\", \"mv4-b03-paper\"], \"-views0123\"),\n        ([\"copycat\",\n          \"locotrack\", \"scenetracker\", \"delta\", \"cotracker1online\", \"cotracker2online\", \"cotracker3online\",\n          \"cotracker1offline\", \"cotracker2offline\", \"cotracker3offline\",\n          # \"spatracker-pretrained\",\n          # \"spatracker\",\n          \"mv1\", \"mv2\",\n          # \"spatracker-d4\",\n          \"mv1-d4\", \"mv2-d4\",\n          \"mv4-b01\", \"mv4-b02\", \"mv4-b03\"], \"-duster0123\"),\n        ([\"spatracker\", \"spatracker-d4\", ], \"-duster0123-views0123\"),\n        ([\"copycat\",\n          \"locotrack\", \"scenetracker\", \"delta\", \"cotracker1online\", \"cotracker2online\", \"cotracker3online\",\n          \"cotracker1offline\", \"cotracker2offline\", \"cotracker3offline\",\n          # \"spatracker-pretrained\",\n          # \"spatracker\",\n          \"mv1\", \"mv2\",\n          # \"spatracker-d4\",\n          \"mv1-d4\", \"mv2-d4\",\n          \"mv4-b01\", \"mv4-b02\", \"mv4-b03\"], \"-duster0123cleaned\"),\n        ([\"spatracker-d4\"], \"-duster0123cleaned-views0123\"),\n    ]:\n        assert all(m in MODELS for m in models_to_report)\n        df, csv_str = create_table({\n            MODELS[m][\"name\"]: MODELS[m][\"csv\"].format(dataset=f\"kubric-multiview-v3{depths}\")\n            for m in models_to_report\n        }, remap=REMAP_KUBRIC)\n        print(f\"Depths: {depths}\")\n        print(csv_str)\n        print()\n\n\ndef panoptic():\n    print(\"Panoptic Studio evaluation results:\")\n    print(\"===================================\")\n    models_to_report = [\n        \"copycat\",\n        \"locotrack\", \"scenetracker\", \"delta\", \"cotracker1online\", \"cotracker2online\", \"cotracker3online\",\n        \"cotracker1offline\", \"cotracker2offline\", \"cotracker3offline\",\n        \"spatracker-pretrained\",\n        \"spatracker\", \"mv1\", \"mv2\",\n        \"spatracker-d4\", \"mv1-d4\", \"mv2-d4\",\n        \"mv4-b01\", \"mv4-b02\", \"mv4-b03\",\n        \"shape-of-motion\"\n    ]\n    assert all(m in MODELS for m in models_to_report)\n    for views in [\"-views1_7_14_20\", \"-views27_16_14_8\", \"-views1_4_7_11\"]:\n        df, csv_str = create_table({\n            MODELS[m][\"name\"]: MODELS[m][\"csv\"].format(dataset=f\"panoptic-multiview{views}\")\n            for m in models_to_report\n        }, remap=REMAP_PANOPTIC)\n        print(f\"*** Views: {views} ***\")\n        print(csv_str)\n        print()\n\n\ndef kubric_single():\n    print(\"Kubric 
single-point evaluation results:\")\n    print(\"==========================\")\n    for models_to_report, depths in [\n        ([\"copycat\", \"cotracker3\", \"spatracker-pretrained\",\n          \"spatracker\", \"mv1\", \"mv2\",\n          \"spatracker-d4\", \"mv1-d4\", \"mv2-d4\"], \"-views0123\"),\n        ([\"copycat\", \"cotracker3\",\n          # \"spatracker-pretrained\",\n          # \"spatracker\",\n          \"mv1\", \"mv2\",\n          # \"spatracker-d4\",\n          \"mv1-d4\", \"mv2-d4\", ], \"-duster0123\"),\n        ([\"cotracker3\", \"spatracker-pretrained\",\n          \"spatracker\", \"mv1\", \"mv2\",\n          \"spatracker-d4\", \"mv1-d4\", \"mv2-d4\", ], \"-duster0123cleaned\"),\n    ]:\n        assert all(m in MODELS for m in models_to_report)\n        df, csv_str = create_table({\n            MODELS[m][\"name\"]: MODELS[m][\"csv\"].format(dataset=f\"kubric-multiview-v3{depths}\")\n            for m in models_to_report\n        }, remap=REMAP_KUBRIC)\n        print(f\"Depths: {depths}\")\n        print(csv_str)\n        print()\n\n\ndef dexycb_single():\n    print(\"DexYCB single-point evaluation results:\")\n    print(\"==========================\")\n    for models_to_report, depths in [\n        ([\"copycat\", \"cotracker3\", \"spatracker-pretrained\",\n          \"spatracker\", \"mv1\", \"mv2\",\n          \"spatracker-d4\", \"mv1-d4\", \"mv2-d4\"], \"\"),\n        ([\"copycat\", \"cotracker3\", \"spatracker-pretrained\",\n          \"spatracker\", \"mv1\", \"mv2\",\n          \"spatracker-d4\", \"mv1-d4\", \"mv2-d4\", ], \"-duster0123\"),\n        ([\"copycat\", \"cotracker3\",\n          # \"spatracker-pretrained\",\n          # \"spatracker\",\n          \"mv1\", \"mv2\",\n          # \"spatracker-d4\",\n          \"mv1-d4\", \"mv2-d4\", ], \"-duster0123cleaned\"),\n    ]:\n        assert all(m in MODELS for m in models_to_report)\n        df, csv_str = create_table({\n            MODELS[m][\"name\"]: MODELS[m][\"csv\"].format(dataset=f\"dex-ycb-multiview{depths}-single\")\n            for m in models_to_report\n        }, remap=REMAP_DEXYCB_V2)\n        print(f\"Depths: {depths}\")\n        print(csv_str)\n        print()\n\n\ndef panoptic_single():\n    print(\"Panoptic Studio single-point evaluation results:\")\n    print(\"================================================\")\n    models_to_report = [\n        \"copycat\", \"cotracker3\", \"spatracker-pretrained\",\n        \"spatracker\", \"mv1\",\n        \"mv2\",\n        \"spatracker-d4\", \"mv1-d4\", \"mv2-d4\",\n    ]\n    assert all(m in MODELS for m in models_to_report)\n    for views in [\n        # \"-views27_16_14_8\",\n        # \"-views1_4_7_11\",\n        \"-views1_7_14_20\",\n    ]:\n        df, csv_str = create_table({\n            MODELS[m][\"name\"]: MODELS[m][\"csv\"].format(dataset=f\"panoptic-multiview{views}-single\")\n            for m in models_to_report\n        }, remap=REMAP_PANOPTIC)\n        print(f\"*** Views: {views} ***\")\n        print(csv_str)\n        print()\n\n\nMODEL_KEYS_ABLATION = [\n    \"copycat\",\n    \"locotrack\", \"scenetracker\", \"delta\",\n    \"cotracker1online\", \"cotracker2online\", \"cotracker3online\",\n    \"cotracker1offline\", \"cotracker2offline\", \"cotracker3offline\",\n    \"spatracker-pretrained\",\n    \"spatracker\", \"mv1\", \"mv2\",\n    \"spatracker-d4\", \"mv1-d4\", \"mv2-d4\",\n    \"mv4-b01\", \"mv4-b02\", \"mv4-b03\",\n]\n\n\ndef ablation_2dpt():\n    datasets = [\n        \"kubric-multiview-v3-views0123-2dpt\",\n        
\"kubric-multiview-v3-duster0123-2dpt\",\n        \"dex-ycb-multiview-2dpt\",\n        \"dex-ycb-multiview-duster0123-2dpt\",\n        \"panoptic-multiview-views1_7_14_20-2dpt\",\n        \"panoptic-multiview-views27_16_14_8-2dpt\",\n        \"panoptic-multiview-views1_4_7_11-2dpt\",\n    ]\n    models_to_report = MODEL_KEYS_ABLATION\n    assert all(m in MODELS for m in models_to_report)\n    for dataset in datasets:\n        df, csv_str = create_table({\n            MODELS[m][\"name\"]: MODELS[m][\"csv\"].format(dataset=dataset)\n            for m in models_to_report\n        }, remap=REMAP_KUBRIC | PARTIAL_REMAP_FOR_2DPT_ABLATION, header=dataset == datasets[0])\n        print(f\"DATASET: {dataset}\")\n        print(csv_str)\n        print()\n\n\ndef one_to_rule_them_all(models, datasets, separate_datasets=True, **create_table_kwargs):\n    assert all(m in MODELS for m in models)\n    if not separate_datasets:\n        df, csv_str = create_table({\n            (MODELS[m][\"name\"], dataset): MODELS[m][\"csv\"].format(dataset=dataset)\n            for m in models\n            for dataset in datasets\n        }, remap=ONE_REMAP_TO_RULE_THEM_ALL, header=True, **create_table_kwargs)\n        print(csv_str)\n        print()\n\n    else:\n        for dataset in datasets:\n            df, csv_str = create_table({\n                MODELS[m][\"name\"]: MODELS[m][\"csv\"].format(dataset=dataset)\n                for m in models\n            }, remap=ONE_REMAP_TO_RULE_THEM_ALL, header=dataset == datasets[0], **create_table_kwargs)\n            print(f\"DATASET: {dataset}\")\n            print(csv_str)\n            print()\n\n\ndef ablation_model_params():\n    datasets = [\n        \"kubric-multiview-v3-views0123\",\n        \"kubric-multiview-v3-duster0123\",\n        \"dex-ycb-multiview\",\n        \"dex-ycb-multiview-duster0123\",\n        \"panoptic-multiview-views1_7_14_20\",\n        \"panoptic-multiview-views27_16_14_8\",\n        \"panoptic-multiview-views1_4_7_11\",\n    ]\n    models = [m for m in MODELS if m.startswith(\"C01\") or m.startswith(\"C02\")]\n    one_to_rule_them_all(models, datasets)\n\n\ndef ablation_camera_setups():\n    datasets = [\n        \"panoptic-multiview-views1_7_14_20\",\n        \"panoptic-multiview-views27_16_14_8\",\n        \"panoptic-multiview-views1_4_7_11\",\n        \"dex-ycb-multiview-duster0123\",\n        \"dex-ycb-multiview-duster2345\",\n        \"dex-ycb-multiview-duster4567\",\n    ]\n    one_to_rule_them_all(MODEL_KEYS_ABLATION, datasets)\n\n\ndef ablation_num_views(separate_datasets):\n    datasets = [\n        \"kubric-multiview-v3-views0\",\n        \"kubric-multiview-v3-views01\",\n        \"kubric-multiview-v3-views012\",\n        \"kubric-multiview-v3-views0123\",\n        \"kubric-multiview-v3-views01234\",\n        \"kubric-multiview-v3-views012345\",\n        \"kubric-multiview-v3-views0123456\",\n        \"kubric-multiview-v3-views01234567\",\n        \"kubric-multiview-v3-duster0123-views0\",\n        \"kubric-multiview-v3-duster0123-views01\",\n        \"kubric-multiview-v3-duster0123-views012\",\n        \"kubric-multiview-v3-duster0123-views0123\",\n        \"kubric-multiview-v3-duster01234567-views01234\",\n        \"kubric-multiview-v3-duster01234567-views012345\",\n        \"kubric-multiview-v3-duster01234567-views0123456\",\n        \"kubric-multiview-v3-duster01234567-views01234567\",\n        \"panoptic-multiview-views1\",\n        \"panoptic-multiview-views1_14\",\n        \"panoptic-multiview-views1_7_14\",\n        
\"panoptic-multiview-views1_7_14_20\",\n        \"panoptic-multiview-views1_4_7_14_20\",\n        \"panoptic-multiview-views1_4_7_14_17_20\",\n        \"panoptic-multiview-views1_4_7_11_14_17_20\",\n        \"panoptic-multiview-views1_4_7_11_14_17_20_23\",\n        \"dex-ycb-multiview-duster0123-views0\",\n        \"dex-ycb-multiview-duster0123-views01\",\n        \"dex-ycb-multiview-duster0123-views012\",\n        \"dex-ycb-multiview-duster0123-views0123\",\n        \"dex-ycb-multiview-duster01234567-views01234\",\n        \"dex-ycb-multiview-duster01234567-views012345\",\n        \"dex-ycb-multiview-duster01234567-views0123456\",\n        \"dex-ycb-multiview-duster01234567-views01234567\",\n    ]\n    one_to_rule_them_all(MODEL_KEYS_ABLATION, datasets, separate_datasets=separate_datasets, skip_missing=True)\n\n\nif __name__ == '__main__':\n    # kubric_single_point()\n    # kubric_before_gt0123()\n    # kubric()\n    # kubric_duster()\n\n    # mv3_kubric_duster_transformed()\n    # mv3_kubric_nviews()\n    # mv3_kubric_duster_nviews()\n\n    # kubric_nviews()\n\n    # tavid2d_davis()\n\n    # dexycb()\n    # kubric_refactored()\n    # panoptic()\n\n    # kubric_single()\n    # dexycb_single()\n    # panoptic_single()\n\n    # ablation_model_params()\n\n    # ablation_2dpt()\n    # ablation_camera_setups()\n    # ablation_num_views(separate_datasets=False)\n    # ablation_num_views(separate_datasets=True)\n\n    #########################################\n\n    # print(\"Dirty results:\")\n    # print(\"==========================\")\n    # df, csv_str = create_table({\n    #     \"CoTracker3 Online\": \"logs/eval/cotracker3_online/eval_tapvid2d-davis-megasam-256x256/step--1_metrics_avg.csv\",\n    #     \"MV-Tracker + MoGe\": \"logs/mvtracker-may/eval_tapvid2d-davis-moge-256x256/step--1_metrics_avg.csv\",\n    #     \"MV-Tracker + MoGe-with-extrinsics\": \"logs/mvtracker-may/eval_tapvid2d-davis-mogewithextrinsics-256x256/step--1_metrics_avg.csv\",\n    #     \"MV-Tracker + ZoeDepth\": \"logs/mvtracker-may/eval_tapvid2d-davis-zoedepth-256x256/step--1_metrics_avg.csv\",\n    #     \"MV-Tracker + MegaSAM\": \"logs/mvtracker-may/eval_tapvid2d-davis-megasam-256x256/step--1_metrics_avg.csv\",\n    # }, remap=REMAP_TAPVID2D, remap_index_names=REMAP_TAPVID2D_INDEX_NAMES)\n    # print(csv_str)\n    #\n    # print(\"Depth + Gaussian noise\")\n    # print(\"==========================\")\n    # df, csv_str = create_table({\n    #     f\"{model};{noise}\": f\"{model}/eval_kubric-multiview-v3-noise{noise}/step--1_metrics_avg.csv\"\n    #     for model in [\n    #         \"logs/eval/delta\",\n    #         \"logs/eval/spatracker_monocular_pretrained\",\n    #         \"logs/eval/spatracker_monocular_kubric-training\",\n    #         \"logs/eval/spatracker_monocular_duster-training\",\n    #         \"logs/eval/spatracker_multiview_kubric-training\",\n    #         # \"logs/eval/spatracker_multiview_duster-training\",\n    #         # \"logs/mvtracker-noise2\",\n    #         \"logs/eval/spatracker_multiview_duster-training-noise3\",\n    #         \"logs/mvtracker-noise3\",\n    #     ]\n    #     for noise in [\"0cm\", \"1cm\", \"2cm\", \"5cm\", \"10cm\", \"20cm\", \"50cm\", \"100cm\", \"200cm\", \"1000cm\"]\n    # }, remap=ONE_REMAP_TO_RULE_THEM_ALL, remap_index_names=REMAP_TAPVID2D_INDEX_NAMES)\n    # print(csv_str)\n\n    #########################################\n\n    print(\"Final full-scale model re-training (June 2025)\")\n    print(\"==========================\")\n    datasets = [\n        
\"kubric-multiview-v3-views0123\",\n        \"kubric-multiview-v3-duster0123\",\n        \"dex-ycb-multiview\",\n        \"dex-ycb-multiview-duster0123\",\n        \"panoptic-multiview-views1_7_14_20\",\n        \"panoptic-multiview-views27_16_14_8\",\n        \"panoptic-multiview-views1_4_7_11\",\n        \"tapvid2d-davis-mogewithextrinsics-256x256\",\n        \"tapvid2d-davis-megasam-256x256\",\n    ]\n    models = [\"mvtracker-march\", \"mvtracker-june\"]\n    one_to_rule_them_all(models, datasets)\n\n"
  }
]