Repository: naver/dust3r
Branch: main
Commit: 4c24a6ebf048
Files: 91
Total size: 422.9 KB

Directory structure:
gitextract_0hi9z8ek/

├── .gitignore
├── .gitmodules
├── LICENSE
├── NOTICE
├── README.md
├── datasets_preprocess/
│   ├── habitat/
│   │   ├── README.md
│   │   ├── find_scenes.py
│   │   ├── habitat_renderer/
│   │   │   ├── __init__.py
│   │   │   ├── habitat_sim_envmaps_renderer.py
│   │   │   ├── multiview_crop_generator.py
│   │   │   ├── projections.py
│   │   │   └── projections_conversions.py
│   │   └── preprocess_habitat.py
│   ├── path_to_root.py
│   ├── preprocess_arkitscenes.py
│   ├── preprocess_blendedMVS.py
│   ├── preprocess_co3d.py
│   ├── preprocess_megadepth.py
│   ├── preprocess_scannetpp.py
│   ├── preprocess_staticthings3d.py
│   ├── preprocess_waymo.py
│   └── preprocess_wildrgbd.py
├── demo.py
├── docker/
│   ├── docker-compose-cpu.yml
│   ├── docker-compose-cuda.yml
│   ├── files/
│   │   ├── cpu.Dockerfile
│   │   ├── cuda.Dockerfile
│   │   └── entrypoint.sh
│   └── run.sh
├── dust3r/
│   ├── __init__.py
│   ├── cloud_opt/
│   │   ├── __init__.py
│   │   ├── base_opt.py
│   │   ├── commons.py
│   │   ├── init_im_poses.py
│   │   ├── modular_optimizer.py
│   │   ├── optimizer.py
│   │   └── pair_viewer.py
│   ├── datasets/
│   │   ├── __init__.py
│   │   ├── arkitscenes.py
│   │   ├── base/
│   │   │   ├── __init__.py
│   │   │   ├── base_stereo_view_dataset.py
│   │   │   ├── batched_sampler.py
│   │   │   └── easy_dataset.py
│   │   ├── blendedmvs.py
│   │   ├── co3d.py
│   │   ├── habitat.py
│   │   ├── megadepth.py
│   │   ├── scannetpp.py
│   │   ├── staticthings3d.py
│   │   ├── utils/
│   │   │   ├── __init__.py
│   │   │   ├── cropping.py
│   │   │   └── transforms.py
│   │   ├── waymo.py
│   │   └── wildrgbd.py
│   ├── demo.py
│   ├── heads/
│   │   ├── __init__.py
│   │   ├── dpt_head.py
│   │   ├── linear_head.py
│   │   └── postprocess.py
│   ├── image_pairs.py
│   ├── inference.py
│   ├── losses.py
│   ├── model.py
│   ├── optim_factory.py
│   ├── patch_embed.py
│   ├── post_process.py
│   ├── training.py
│   ├── utils/
│   │   ├── __init__.py
│   │   ├── device.py
│   │   ├── geometry.py
│   │   ├── image.py
│   │   ├── misc.py
│   │   ├── parallel.py
│   │   └── path_to_croco.py
│   └── viz.py
├── dust3r_visloc/
│   ├── README.md
│   ├── __init__.py
│   ├── datasets/
│   │   ├── __init__.py
│   │   ├── aachen_day_night.py
│   │   ├── base_colmap.py
│   │   ├── base_dataset.py
│   │   ├── cambridge_landmarks.py
│   │   ├── inloc.py
│   │   ├── sevenscenes.py
│   │   └── utils.py
│   ├── evaluation.py
│   └── localization.py
├── requirements.txt
├── requirements_optional.txt
├── train.py
└── visloc.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
data/
checkpoints/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/


================================================
FILE: .gitmodules
================================================
[submodule "croco"]
	path = croco
	url = https://github.com/naver/croco


================================================
FILE: LICENSE
================================================
DUSt3R, Copyright (c) 2024-present Naver Corporation, is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license.

A summary of the CC BY-NC-SA 4.0 license is located here:
	https://creativecommons.org/licenses/by-nc-sa/4.0/

The CC BY-NC-SA 4.0 license is located here:
	https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode


================================================
FILE: NOTICE
================================================
DUSt3R
Copyright 2024-present NAVER Corp.

This project contains subcomponents with separate copyright notices and license terms. 
Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses.

====

naver/croco
https://github.com/naver/croco/

Creative Commons Attribution-NonCommercial-ShareAlike 4.0


================================================
FILE: README.md
================================================
![demo](assets/dust3r.jpg)

Official implementation of `DUSt3R: Geometric 3D Vision Made Easy`  
[[Project page](https://dust3r.europe.naverlabs.com/)], [[DUSt3R arxiv](https://arxiv.org/abs/2312.14132)]  

> Make sure to also check our other works:  
> [Grounding Image Matching in 3D with MASt3R](https://github.com/naver/mast3r): DUSt3R with a local feature head, metric pointmaps, and a more scalable global alignment!  
> [Pow3R: Empowering Unconstrained 3D Reconstruction with Camera and Scene Priors](https://github.com/naver/pow3r): DUSt3R with known depth / focal length / poses.  
> [MUSt3R: Multi-view Network for Stereo 3D Reconstruction](https://github.com/naver/must3r): Multi-view predictions (RGB SLAM/SfM) without any global alignment.    

![Example of reconstruction from two images](assets/pipeline1.jpg)

![High level overview of DUSt3R capabilities](assets/dust3r_archi.jpg)

```bibtex
@inproceedings{dust3r_cvpr24,
      title={DUSt3R: Geometric 3D Vision Made Easy}, 
      author={Shuzhe Wang and Vincent Leroy and Yohann Cabon and Boris Chidlovskii and Jerome Revaud},
      booktitle = {CVPR},
      year = {2024}
}

@misc{dust3r_arxiv23,
      title={DUSt3R: Geometric 3D Vision Made Easy}, 
      author={Shuzhe Wang and Vincent Leroy and Yohann Cabon and Boris Chidlovskii and Jerome Revaud},
      year={2023},
      eprint={2312.14132},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```

## Table of Contents

- [Table of Contents](#table-of-contents)
- [License](#license)
- [Get Started](#get-started)
  - [Installation](#installation)
  - [Checkpoints](#checkpoints)
  - [Interactive demo](#interactive-demo)
  - [Interactive demo with docker](#interactive-demo-with-docker)
- [Usage](#usage)
- [Training](#training)
  - [Datasets](#datasets)
  - [Demo](#demo)
  - [Our Hyperparameters](#our-hyperparameters)

## License

The code is distributed under the CC BY-NC-SA 4.0 License.
See [LICENSE](LICENSE) for more information.

```python
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
```

## Get Started

### Installation

1. Clone DUSt3R.
```bash
git clone --recursive https://github.com/naver/dust3r
cd dust3r
# if you have already cloned dust3r:
# git submodule update --init --recursive
```

2. Create the environment. Here we show an example using conda.
```bash
conda create -n dust3r python=3.11 cmake=3.14.0
conda activate dust3r 
conda install pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia  # use the correct version of cuda for your system
pip install -r requirements.txt
# Optional: you can also install additional packages to:
# - add support for HEIC images
# - add pyrender, used to render depthmap in some datasets preprocessing
# - add required packages for visloc.py
pip install -r requirements_optional.txt
```

3. Optional: compile the CUDA kernels for RoPE (as in CroCo v2).
```bash
# DUSt3R relies on RoPE positional embeddings, for which you can compile CUDA kernels for a faster runtime.
cd croco/models/curope/
python setup.py build_ext --inplace
cd ../../../
```

### Checkpoints

You can obtain the checkpoints in two ways:

1) You can use our huggingface_hub integration: the models will be downloaded automatically.

2) Otherwise, we provide several pre-trained models:

| Modelname   | Training resolutions | Head | Encoder | Decoder |
|-------------|----------------------|------|---------|---------|
| [`DUSt3R_ViTLarge_BaseDecoder_224_linear.pth`](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_224_linear.pth) | 224x224 | Linear | ViT-L | ViT-B |
| [`DUSt3R_ViTLarge_BaseDecoder_512_linear.pth`](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_linear.pth)   | 512x384, 512x336, 512x288, 512x256, 512x160 | Linear | ViT-L | ViT-B |
| [`DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth`](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth) | 512x384, 512x336, 512x288, 512x256, 512x160 | DPT | ViT-L | ViT-B |

You can check the hyperparameters we used to train these models in the [section: Our Hyperparameters](#our-hyperparameters)

To download a specific model, for example `DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth`:
```bash
mkdir -p checkpoints/
wget https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth -P checkpoints/
```
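
For scripted setups, the same download can be sketched in Python. This is a minimal illustration that only mirrors the `wget` command above: the helper names `checkpoint_url` and `download_checkpoint` are hypothetical and not part of the repository, and the URL pattern is inferred from the links in the table.

```python
from pathlib import Path
from urllib.request import urlretrieve

# Base URL taken from the checkpoint links listed in the table above.
BASE_URL = "https://download.europe.naverlabs.com/ComputerVision/DUSt3R"


def checkpoint_url(model_name: str) -> str:
    """Build the download URL for a checkpoint name (hypothetical helper)."""
    return f"{BASE_URL}/{model_name}.pth"


def download_checkpoint(model_name: str, out_dir: str = "checkpoints") -> Path:
    """Download the checkpoint into out_dir unless the file already exists."""
    target = Path(out_dir) / f"{model_name}.pth"
    if not target.exists():
        target.parent.mkdir(parents=True, exist_ok=True)  # same as `mkdir -p`
        urlretrieve(checkpoint_url(model_name), str(target))
    return target


# Example (performs a large download, so it is left commented out):
# download_checkpoint("DUSt3R_ViTLarge_BaseDecoder_512_dpt")
```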

For the checkpoints, make sure to agree to the license of all the public training datasets and base checkpoints we used, in addition to CC-BY-NC-SA 4.0. Again, see [section: Our Hyperparameters](#our-hyperparameters) for details.

### Interactive demo

In this demo, you should be able to run DUSt3R on your machine to reconstruct a scene.
First select images that depict the same scene.

You can adjust the global alignment schedule and its number of iterations.

> [!NOTE]
> If you selected one or two images, the global alignment procedure will be skipped (mode=GlobalAlignerMode.PairViewer)

Hit "Run" and wait.
When the global alignment ends, the reconstruction appears.
Use the slider "min_conf_thr" to show or remove low confidence areas.
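
The rule in the note above can be written as a small selection function. This is only a sketch of the stated behavior: `AlignerMode` is a local stand-in that copies the two mode names mentioned in this README so the snippet runs without the `dust3r` package, and `choose_mode` is a hypothetical helper, not the demo's actual code.

```python
from enum import Enum


class AlignerMode(Enum):
    """Local stand-in mirroring the mode names used in this README."""
    PointCloudOptimizer = "PointCloudOptimizer"
    PairViewer = "PairViewer"


def choose_mode(num_images: int) -> AlignerMode:
    """Skip global alignment (PairViewer) when one or two images are given."""
    if num_images < 1:
        raise ValueError("at least one image is required")
    if num_images > 2:
        return AlignerMode.PointCloudOptimizer
    return AlignerMode.PairViewer


print(choose_mode(2).name)
print(choose_mode(5).name)
```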

```bash
python3 demo.py --model_name DUSt3R_ViTLarge_BaseDecoder_512_dpt

# Use --weights to load a checkpoint from a local file, eg --weights checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth
# Use --image_size to select the correct resolution for the selected checkpoint. 512 (default) or 224
# Use --local_network to make it accessible on the local network, or --server_name to specify the url manually
# Use --server_port to change the port, by default it will search for an available port starting at 7860
# Use --device to use a different device, by default it's "cuda"
```
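
To illustrate what "search for an available port starting at 7860" can mean, here is a generic sketch. It is an assumption about the general technique, not the code used by the demo or by its web UI library; `port_is_free` and `find_available_port` are hypothetical names.

```python
import socket
from typing import Callable


def port_is_free(port: int, host: str = "127.0.0.1") -> bool:
    """Return True if a TCP socket can currently be bound to (host, port)."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        try:
            sock.bind((host, port))
        except OSError:
            return False
        return True


def find_available_port(start: int = 7860, max_tries: int = 100,
                        is_free: Callable[[int], bool] = port_is_free) -> int:
    """Scan upward from `start` and return the first port reported as free."""
    for port in range(start, start + max_tries):
        if is_free(port):
            return port
    raise RuntimeError(f"no free port in [{start}, {start + max_tries})")
```

The probe is passed in as a parameter so the scanning logic can be checked without opening real sockets.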

### Interactive demo with docker

To run DUSt3R using Docker, including with NVIDIA CUDA support, follow these instructions:

1. **Install Docker**: If not already installed, download and install `docker` and `docker compose` from the [Docker website](https://www.docker.com/get-started).

2. **Install NVIDIA Docker Toolkit**: For GPU support, install the NVIDIA Docker toolkit from the [Nvidia website](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html).

3. **Build the Docker image and run it**: `cd` into the `./docker` directory and run the following commands: 

```bash
cd docker
bash run.sh --with-cuda --model_name="DUSt3R_ViTLarge_BaseDecoder_512_dpt"
```

Or if you want to run the demo without CUDA support, run the following command:

```bash 
cd docker
bash run.sh --model_name="DUSt3R_ViTLarge_BaseDecoder_512_dpt"
```

By default, `demo.py` is launched with the option `--local_network`.  
Visit `http://localhost:7860/` to access the web UI (or replace `localhost` with the machine's name to access it from the network).  

`run.sh` will launch docker-compose using either the [docker-compose-cuda.yml](docker/docker-compose-cuda.yml) or [docker-compose-cpu.yml](docker/docker-compose-cpu.yml) config file, then it starts the demo using [entrypoint.sh](docker/files/entrypoint.sh).


![demo](assets/demo.jpg)

## Usage

```python
from dust3r.inference import inference
from dust3r.model import AsymmetricCroCo3DStereo
from dust3r.utils.image import load_images
from dust3r.image_pairs import make_pairs
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode

if __name__ == '__main__':
    device = 'cuda'
    batch_size = 1
    schedule = 'cosine'
    lr = 0.01
    niter = 300

    model_name = "naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt"
    # you can put the path to a local checkpoint in model_name if needed
    model = AsymmetricCroCo3DStereo.from_pretrained(model_name).to(device)
    # load_images can take a list of images or a directory
    images = load_images(['croco/assets/Chateau1.png', 'croco/assets/Chateau2.png'], size=512)
    pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)
    output = inference(pairs, model, device, batch_size=batch_size)

    # at this stage, you have the raw dust3r predictions
    view1, pred1 = output['view1'], output['pred1']
    view2, pred2 = output['view2'], output['pred2']
    # here, view1, pred1, view2, pred2 are dicts of lists of len(2)
    #  -> because we symmetrize we have (im1, im2) and (im2, im1) pairs
    # in each view you have:
    # an integer image identifier: view1['idx'] and view2['idx']
    # the img: view1['img'] and view2['img']
    # the image shape: view1['true_shape'] and view2['true_shape']
    # an instance string output by the dataloader: view1['instance'] and view2['instance']
    # pred1 and pred2 contain the confidence values: pred1['conf'] and pred2['conf']
    # pred1 contains 3D points for view1['img'] in view1['img'] space: pred1['pts3d']
    # pred2 contains 3D points for view2['img'] in view1['img'] space: pred2['pts3d_in_other_view']

    # next we'll use the global_aligner to align the predictions
    # depending on your task, you may be fine with the raw output and not need it
    # with only two input images, you could use GlobalAlignerMode.PairViewer: it would just convert the output
    # if using GlobalAlignerMode.PairViewer, no need to run compute_global_alignment
    scene = global_aligner(output, device=device, mode=GlobalAlignerMode.PointCloudOptimizer)
    loss = scene.compute_global_alignment(init="mst", niter=niter, schedule=schedule, lr=lr)

    # retrieve useful values from scene:
    imgs = scene.imgs
    focals = scene.get_focals()
    poses = scene.get_im_poses()
    pts3d = scene.get_pts3d()
    confidence_masks = scene.get_masks()

    # visualize reconstruction
    scene.show()

    # find 2D-2D matches between the two images
    from dust3r.utils.geometry import find_reciprocal_matches, xy_grid
    pts2d_list, pts3d_list = [], []
    for i in range(2):
        conf_i = confidence_masks[i].cpu().numpy()
        pts2d_list.append(xy_grid(*imgs[i].shape[:2][::-1])[conf_i])  # imgs[i].shape[:2] = (H, W)
        pts3d_list.append(pts3d[i].detach().cpu().numpy()[conf_i])
    reciprocal_in_P2, nn2_in_P1, num_matches = find_reciprocal_matches(*pts3d_list)
    print(f'found {num_matches} matches')
    matches_im1 = pts2d_list[1][reciprocal_in_P2]
    matches_im0 = pts2d_list[0][nn2_in_P1][reciprocal_in_P2]

    # visualize a few matches
    import numpy as np
    from matplotlib import pyplot as pl
    n_viz = 10
    match_idx_to_viz = np.round(np.linspace(0, num_matches-1, n_viz)).astype(int)
    viz_matches_im0, viz_matches_im1 = matches_im0[match_idx_to_viz], matches_im1[match_idx_to_viz]

    H0, W0, H1, W1 = *imgs[0].shape[:2], *imgs[1].shape[:2]
    img0 = np.pad(imgs[0], ((0, max(H1 - H0, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
    img1 = np.pad(imgs[1], ((0, max(H0 - H1, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
    img = np.concatenate((img0, img1), axis=1)
    pl.figure()
    pl.imshow(img)
    cmap = pl.get_cmap('jet')
    for i in range(n_viz):
        (x0, y0), (x1, y1) = viz_matches_im0[i].T, viz_matches_im1[i].T
        pl.plot([x0, x1 + W0], [y0, y1], '-+', color=cmap(i / (n_viz - 1)), scalex=False, scaley=False)
    pl.show(block=True)

```
![matching example on croco pair](assets/matching.jpg)
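
The matching step in the example above relies on reciprocal (mutual) nearest neighbors between the two 3D point sets. The sketch below shows the idea with a brute-force search in pure Python. It is not the implementation of `find_reciprocal_matches`, and it returns index pairs rather than that function's three return values; `nearest_indices` and `mutual_nearest_neighbors` are hypothetical names.

```python
import math
from typing import List, Sequence, Tuple

Point = Sequence[float]


def nearest_indices(queries: Sequence[Point], targets: Sequence[Point]) -> List[int]:
    """For each query point, return the index of its closest target point."""
    return [
        min(range(len(targets)), key=lambda j: math.dist(q, targets[j]))
        for q in queries
    ]


def mutual_nearest_neighbors(points_a: Sequence[Point],
                             points_b: Sequence[Point]) -> List[Tuple[int, int]]:
    """Keep (i, j) only if b[j] is nearest to a[i] AND a[i] is nearest to b[j]."""
    nn_a_in_b = nearest_indices(points_a, points_b)
    nn_b_in_a = nearest_indices(points_b, points_a)
    return [(i, j) for i, j in enumerate(nn_a_in_b) if nn_b_in_a[j] == i]


points_a = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (5.0, 5.0, 5.0)]
points_b = [(1.1, 0.0, 0.0), (0.1, 0.0, 0.0)]
print(mutual_nearest_neighbors(points_a, points_b))
```

The third point of `points_a` has a nearest neighbor in `points_b`, but that neighbor prefers another point, so the pair is discarded. Brute force is quadratic in the number of points; a spatial index is the usual choice for full-resolution pointmaps.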

## Training

In this section, we present a short demonstration to get started with training DUSt3R.

### Datasets
At this moment, we have added the following training datasets:
  - [CO3Dv2](https://github.com/facebookresearch/co3d) - [Creative Commons Attribution-NonCommercial 4.0 International](https://github.com/facebookresearch/co3d/blob/main/LICENSE)
  - [ARKitScenes](https://github.com/apple/ARKitScenes) - [Creative Commons Attribution-NonCommercial-ShareAlike 4.0](https://github.com/apple/ARKitScenes/tree/main?tab=readme-ov-file#license)
  - [ScanNet++](https://kaldir.vc.in.tum.de/scannetpp/) - [non-commercial research and educational purposes](https://kaldir.vc.in.tum.de/scannetpp/static/scannetpp-terms-of-use.pdf)
  - [BlendedMVS](https://github.com/YoYo000/BlendedMVS) - [Creative Commons Attribution 4.0 International License](https://creativecommons.org/licenses/by/4.0/)
  - [WayMo Open dataset](https://github.com/waymo-research/waymo-open-dataset) - [Non-Commercial Use](https://waymo.com/open/terms/)
  - [Habitat-Sim](https://github.com/facebookresearch/habitat-sim/blob/main/DATASETS.md)
  - [MegaDepth](https://www.cs.cornell.edu/projects/megadepth/)
  - [StaticThings3D](https://github.com/lmb-freiburg/robustmvd/blob/master/rmvd/data/README.md#staticthings3d)
  - [WildRGB-D](https://github.com/wildrgbd/wildrgbd/)

For each dataset, we provide a preprocessing script in the `datasets_preprocess` directory and an archive containing the list of pairs when needed.
You have to download the datasets yourself from their official sources, agree to their license, download our list of pairs, and run the preprocessing script.

Links:  
  
[ARKitScenes pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/arkitscenes_pairs.zip)  
[ScanNet++ v1 pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/scannetpp_pairs.zip)  
[ScanNet++ v2 pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/scannetpp_v2_pairs.zip)  
[BlendedMVS pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/blendedmvs_pairs.npy)  
[WayMo Open dataset pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/waymo_pairs.npz)  
[Habitat metadata](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/habitat_5views_v1_512x512_metadata.tar.gz)  
[MegaDepth pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/megadepth_pairs.npz)  
[StaticThings3D pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/staticthings_pairs.npy)  

> [!NOTE]
> They are not strictly equivalent to what was used to train DUSt3R, but they should be close enough.

### Demo
For this training demo, we're going to download and prepare a subset of [CO3Dv2](https://github.com/facebookresearch/co3d) - [Creative Commons Attribution-NonCommercial 4.0 International](https://github.com/facebookresearch/co3d/blob/main/LICENSE) and launch the training code on it.
The demo model will be trained for a few epochs on a very small dataset.
It will not be very good.

```bash
# download and prepare the co3d subset
mkdir -p data/co3d_subset
cd data/co3d_subset
git clone https://github.com/facebookresearch/co3d
cd co3d
python3 ./co3d/download_dataset.py --download_folder ../ --single_sequence_subset
rm ../*.zip
cd ../../..

python3 datasets_preprocess/preprocess_co3d.py --co3d_dir data/co3d_subset --output_dir data/co3d_subset_processed  --single_sequence_subset

# download the pretrained croco v2 checkpoint
mkdir -p checkpoints/
wget https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTLarge_BaseDecoder.pth -P checkpoints/

# the training of dust3r is done in 3 steps.
# for this example we'll do fewer epochs, for the actual hyperparameters we used in the paper, see the next section: "Our Hyperparameters"
# step 1 - train dust3r for 224 resolution
torchrun --nproc_per_node=4 train.py \
    --train_dataset "1000 @ Co3d(split='train', ROOT='data/co3d_subset_processed', aug_crop=16, mask_bg='rand', resolution=224, transform=ColorJitter)" \
    --test_dataset "100 @ Co3d(split='test', ROOT='data/co3d_subset_processed', resolution=224, seed=777)" \
    --model "AsymmetricCroCo3DStereo(pos_embed='RoPE100', img_size=(224, 224), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \
    --train_criterion "ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
    --test_criterion "Regr3D_ScaleShiftInv(L21, gt_scale=True)" \
    --pretrained "checkpoints/CroCo_V2_ViTLarge_BaseDecoder.pth" \
    --lr 0.0001 --min_lr 1e-06 --warmup_epochs 1 --epochs 10 --batch_size 16 --accum_iter 1 \
    --save_freq 1 --keep_freq 5 --eval_freq 1 \
    --output_dir "checkpoints/dust3r_demo_224"	  

# step 2 - train dust3r for 512 resolution
torchrun --nproc_per_node=4 train.py \
    --train_dataset "1000 @ Co3d(split='train', ROOT='data/co3d_subset_processed', aug_crop=16, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter)" \
    --test_dataset "100 @ Co3d(split='test', ROOT='data/co3d_subset_processed', resolution=(512,384), seed=777)" \
    --model "AsymmetricCroCo3DStereo(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \
    --train_criterion "ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
    --test_criterion "Regr3D_ScaleShiftInv(L21, gt_scale=True)" \
    --pretrained "checkpoints/dust3r_demo_224/checkpoint-best.pth" \
    --lr 0.0001 --min_lr 1e-06 --warmup_epochs 1 --epochs 10 --batch_size 4 --accum_iter 4 \
    --save_freq 1 --keep_freq 5 --eval_freq 1 \
    --output_dir "checkpoints/dust3r_demo_512"

# step 3 - train dust3r for 512 resolution with dpt
torchrun --nproc_per_node=4 train.py \
    --train_dataset "1000 @ Co3d(split='train', ROOT='data/co3d_subset_processed', aug_crop=16, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter)" \
    --test_dataset "100 @ Co3d(split='test', ROOT='data/co3d_subset_processed', resolution=(512,384), seed=777)" \
    --model "AsymmetricCroCo3DStereo(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='dpt', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \
    --train_criterion "ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
    --test_criterion "Regr3D_ScaleShiftInv(L21, gt_scale=True)" \
    --pretrained "checkpoints/dust3r_demo_512/checkpoint-best.pth" \
    --lr 0.0001 --min_lr 1e-06 --warmup_epochs 1 --epochs 10 --batch_size 2 --accum_iter 8 \
    --save_freq 1 --keep_freq 5 --eval_freq 1 --disable_cudnn_benchmark \
    --output_dir "checkpoints/dust3r_demo_512dpt"

```
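
The three demo steps trade batch size against gradient accumulation. Assuming `--batch_size` is the per-process batch size and each `torchrun` process drives one GPU (an assumption, not stated in this README), the number of samples per optimizer step can be estimated as follows; `effective_batch_size` is a hypothetical helper.

```python
def effective_batch_size(num_processes: int, batch_size: int, accum_iter: int) -> int:
    """Samples per optimizer step: processes x per-process batch x accumulation steps."""
    return num_processes * batch_size * accum_iter


# (nproc_per_node, batch_size, accum_iter) copied from the three demo commands above.
demo_steps = {
    "step 1 (224, linear)": (4, 16, 1),
    "step 2 (512, linear)": (4, 4, 4),
    "step 3 (512, dpt)": (4, 2, 8),
}

for name, args in demo_steps.items():
    print(f"{name}: {effective_batch_size(*args)} samples per optimizer step")
```

Under that assumption, all three steps use the same number of samples per optimizer step; the higher-resolution steps fit into memory by using a smaller per-process batch and more accumulation iterations.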

### Our Hyperparameters

Here are the commands we used for training the models:

```bash
# NOTE: ROOT path omitted for datasets
# 224 linear
torchrun --nproc_per_node 8 train.py \
    --train_dataset=" + 100_000 @ Habitat(1_000_000, split='train', aug_crop=16, resolution=224, transform=ColorJitter) + 100_000 @ BlendedMVS(split='train', aug_crop=16, resolution=224, transform=ColorJitter) + 100_000 @ MegaDepth(split='train', aug_crop=16, resolution=224, transform=ColorJitter) + 100_000 @ ARKitScenes(aug_crop=256, resolution=224, transform=ColorJitter) + 100_000 @ Co3d(split='train', aug_crop=16, mask_bg='rand', resolution=224, transform=ColorJitter) + 100_000 @ StaticThings3D(aug_crop=256, mask_bg='rand', resolution=224, transform=ColorJitter) + 100_000 @ ScanNetpp(split='train', aug_crop=256, resolution=224, transform=ColorJitter) + 100_000 @ InternalUnreleasedDataset(aug_crop=128, resolution=224, transform=ColorJitter) " \
    --test_dataset=" Habitat(1_000, split='val', resolution=224, seed=777) + 1_000 @ BlendedMVS(split='val', resolution=224, seed=777) + 1_000 @ MegaDepth(split='val', resolution=224, seed=777) + 1_000 @ Co3d(split='test', mask_bg='rand', resolution=224, seed=777) " \
    --train_criterion="ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
    --test_criterion="Regr3D_ScaleShiftInv(L21, gt_scale=True)" \
    --model="AsymmetricCroCo3DStereo(pos_embed='RoPE100', img_size=(224, 224), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \
    --pretrained="checkpoints/CroCo_V2_ViTLarge_BaseDecoder.pth" \
    --lr=0.0001 --min_lr=1e-06 --warmup_epochs=10 --epochs=100 --batch_size=16 --accum_iter=1 \
    --save_freq=5 --keep_freq=10 --eval_freq=1 \
    --output_dir="checkpoints/dust3r_224"

# 512 linear
torchrun --nproc_per_node 8 train.py \
    --train_dataset=" + 10_000 @ Habitat(1_000_000, split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ BlendedMVS(split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ MegaDepth(split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ ARKitScenes(aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ Co3d(split='train', aug_crop=16, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ StaticThings3D(aug_crop=256, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ ScanNetpp(split='train', aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ InternalUnreleasedDataset(aug_crop=128, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) " \
    --test_dataset=" Habitat(1_000, split='val', resolution=(512,384), seed=777) + 1_000 @ BlendedMVS(split='val', resolution=(512,384), seed=777) + 1_000 @ MegaDepth(split='val', resolution=(512,336), seed=777) + 1_000 @ Co3d(split='test', resolution=(512,384), seed=777) " \
    --train_criterion="ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
    --test_criterion="Regr3D_ScaleShiftInv(L21, gt_scale=True)" \
    --model="AsymmetricCroCo3DStereo(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \
    --pretrained="checkpoints/dust3r_224/checkpoint-best.pth" \
    --lr=0.0001 --min_lr=1e-06 --warmup_epochs=20 --epochs=100 --batch_size=4 --accum_iter=2 \
    --save_freq=10 --keep_freq=10 --eval_freq=1 --print_freq=10 \
    --output_dir="checkpoints/dust3r_512"

# 512 dpt
torchrun --nproc_per_node 8 train.py \
    --train_dataset=" + 10_000 @ Habitat(1_000_000, split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ BlendedMVS(split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ MegaDepth(split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ ARKitScenes(aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ Co3d(split='train', aug_crop=16, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ StaticThings3D(aug_crop=256, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ ScanNetpp(split='train', aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ InternalUnreleasedDataset(aug_crop=128, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) " \
    --test_dataset=" Habitat(1_000, split='val', resolution=(512,384), seed=777) + 1_000 @ BlendedMVS(split='val', resolution=(512,384), seed=777) + 1_000 @ MegaDepth(split='val', resolution=(512,336), seed=777) + 1_000 @ Co3d(split='test', resolution=(512,384), seed=777) " \
    --train_criterion="ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
    --test_criterion="Regr3D_ScaleShiftInv(L21, gt_scale=True)" \
    --model="AsymmetricCroCo3DStereo(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='dpt', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \
    --pretrained="checkpoints/dust3r_512/checkpoint-best.pth" \
    --lr=0.0001 --min_lr=1e-06 --warmup_epochs=15 --epochs=90 --batch_size=4 --accum_iter=2 \
    --save_freq=5 --keep_freq=10 --eval_freq=1 --print_freq=10 --disable_cudnn_benchmark \
    --output_dir="checkpoints/dust3r_512dpt"

```
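
The dataset arguments above read like Python expressions, where `N @ Dataset(...)` sets the number of samples drawn from a dataset and `+` concatenates datasets. The toy classes below sketch how such a mini-language can be built with operator overloading. They are illustrative stand-ins, not the classes in `dust3r/datasets/base/easy_dataset.py`, and the resizing here simply wraps indices, which may differ from the real sampling strategy.

```python
class Composable:
    """Mixin giving datasets the `N @ dataset` and `a + b` operators."""

    def __rmatmul__(self, n):   # called for `n @ dataset` when n is an int
        return Resized(self, n)

    def __add__(self, other):   # called for `dataset_a + dataset_b`
        return Concat([self, other])


class Toy(Composable):
    """Stand-in dataset whose items are (name, index) tuples."""

    def __init__(self, name, size):
        self.name, self.size = name, size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return (self.name, idx)


class Resized(Composable):
    """Present `dataset` as having exactly `n` items by wrapping indices."""

    def __init__(self, dataset, n):
        self.dataset, self.n = dataset, n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return self.dataset[idx % len(self.dataset)]


class Concat(Composable):
    """Chain several datasets one after another."""

    def __init__(self, parts):
        self.parts = list(parts)

    def __len__(self):
        return sum(len(p) for p in self.parts)

    def __getitem__(self, idx):
        for part in self.parts:
            if idx < len(part):
                return part[idx]
            idx -= len(part)
        raise IndexError("index out of range")


# `@` binds tighter than `+`, so this reads as (1000 @ a) + (100 @ b).
mix = 1000 @ Toy("a", 50) + 100 @ Toy("b", 7)
print(len(mix), mix[0], mix[1000])
```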


================================================
FILE: datasets_preprocess/habitat/README.md
================================================
## Steps to reproduce synthetic training data using the Habitat-Sim simulator

### Create a conda environment
```bash
conda create -n habitat python=3.8 habitat-sim=0.2.1 headless=2.0 -c aihabitat -c conda-forge
conda activate habitat
conda install pytorch -c pytorch
pip install opencv-python tqdm
```

or (if you get the error `For headless systems, compile with --headless for EGL support`)
```bash
git clone --branch stable https://github.com/facebookresearch/habitat-sim.git
cd habitat-sim

conda create -n habitat python=3.9 cmake=3.14.0
conda activate habitat
pip install . -v
conda install pytorch -c pytorch
pip install opencv-python tqdm
```

### Download Habitat-Sim scenes
- Download links can be found here: https://github.com/facebookresearch/habitat-sim/blob/main/DATASETS.md
- We used scenes from the HM3D, habitat-test-scenes, ReplicaCAD and ScanNet datasets.
- Please put the scenes in a directory `$SCENES_DIR` following the structure below:
(Note: the habitat-sim dataset installer may install an incompatible version of ReplicaCAD baked lighting.
The correct scene dataset can be downloaded from Hugging Face: `git clone git@hf.co:datasets/ai-habitat/ReplicaCAD_baked_lighting`).
```
$SCENES_DIR/
├──hm3d/
├──gibson/
├──habitat-test-scenes/
├──ReplicaCAD_baked_lighting/
└──scannet/
```
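
Before rendering, it can help to confirm that `$SCENES_DIR` matches this layout. The following is a minimal sketch and not part of the repository: the helper name `missing_scene_dirs` is hypothetical, and the list of sub-directories is simply copied from the tree above.

```python
import os

# Sub-directory names copied from the layout above (assumed to be the complete set).
EXPECTED_SCENE_DIRS = (
    "hm3d",
    "gibson",
    "habitat-test-scenes",
    "ReplicaCAD_baked_lighting",
    "scannet",
)


def missing_scene_dirs(scenes_dir, expected=EXPECTED_SCENE_DIRS):
    """Return the expected sub-directories that are absent from scenes_dir."""
    return [name for name in expected
            if not os.path.isdir(os.path.join(scenes_dir, name))]


# Example usage (hypothetical path):
# print(missing_scene_dirs("/path/to/habitat/data/scene_datasets/"))
```

An empty list means every directory from the tree is present; it says nothing about whether the scene files inside are complete.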

### Download renderings metadata 

Download the metadata corresponding to each scene and extract it into a directory `$METADATA_DIR`:
```bash
wget https://download.europe.naverlabs.com/ComputerVision/DUSt3R/habitat_5views_v1_512x512_metadata.tar.gz
tar -xvzf habitat_5views_v1_512x512_metadata.tar.gz
```
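
To get a feel for the extracted metadata, the sketch below counts the views described by each `metadata.json`. It assumes the structure read by `preprocess_habitat.py`, namely a `view_batches` dictionary mapping batch labels to per-view parameters. The helper name `count_views` is hypothetical.

```python
import glob
import json
import os


def count_views(metadata_dir):
    """Map each metadata.json found under metadata_dir to its number of views."""
    counts = {}
    pattern = os.path.join(metadata_dir, "**", "metadata.json")
    for filename in glob.iglob(pattern, recursive=True):
        with open(filename, "r") as f:
            metadata = json.load(f)
        # "view_batches" maps a batch label to a dict of {view label: view parameters}.
        batches = metadata.get("view_batches", {})
        counts[filename] = sum(len(batch) for batch in batches.values())
    return counts


# Example usage (hypothetical path):
# for path, n_views in count_views("/path/to/habitat/5views_v1_512x512_metadata").items():
#     print(path, n_views)
```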

### Render the scenes

Render the scenes in an output directory `$OUTPUT_DIR`
```bash
export METADATA_DIR="/path/to/habitat/5views_v1_512x512_metadata"
export SCENES_DIR="/path/to/habitat/data/scene_datasets/"
export OUTPUT_DIR="data/habitat_processed"
cd datasets_preprocess/habitat/
export PYTHONPATH=$(pwd)
# Print the command lines that generate the images for each scene
python preprocess_habitat.py --scenes_dir=$SCENES_DIR --metadata_dir=$METADATA_DIR --output_dir=$OUTPUT_DIR
# Launch these command lines in parallel, e.g. using GNU Parallel:
python preprocess_habitat.py --scenes_dir=$SCENES_DIR --metadata_dir=$METADATA_DIR --output_dir=$OUTPUT_DIR | parallel -j 16
```
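
Once rendering finishes, each view should have three files in its scene output directory: `<label>.jpeg`, `<label>_depth.exr` and `<label>_camera_params.json`, which are the names written by `preprocess_habitat.py`. A minimal sketch for spotting incomplete views in one such directory follows; the function name `incomplete_views` is hypothetical.

```python
import os

# File suffixes written for every rendered view.
SUFFIXES = (".jpeg", "_depth.exr", "_camera_params.json")


def incomplete_views(output_dir):
    """Return the sorted labels of views missing at least one of the three files."""
    names = set(os.listdir(output_dir))
    labels = set()
    for name in names:
        for suffix in SUFFIXES:
            if name.endswith(suffix):
                # Strip the suffix to recover the "<batch>_<view>" label.
                labels.add(name[:-len(suffix)])
                break
    return sorted(label for label in labels
                  if not all(label + suffix in names for suffix in SUFFIXES))
```

The function inspects a single directory only, so it would need to be called once per scene output directory under `$OUTPUT_DIR`.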

### Make a list of scenes

```bash
python find_scenes.py --root $OUTPUT_DIR
```
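
`find_scenes.py` groups the rendered sub-scenes by scene, shuffles the scenes, and keeps the first tenth for validation and the rest for training, before writing files named like `Habitat_{n}_scenes_train.txt`. The split itself can be sketched as below. This is an illustration only: it uses the standard-library `random` module rather than NumPy, so it will not reproduce the script's exact ordering, and `split_scenes` is a hypothetical name.

```python
import random


def split_scenes(scenes, seed=777):
    """Shuffle scenes and return (train, val), with val being the first tenth."""
    scenes = list(scenes)
    # A dedicated generator keeps the shuffle reproducible for a given seed.
    random.Random(seed).shuffle(scenes)
    n_val = len(scenes) // 10
    return scenes[n_val:], scenes[:n_val]


# Example usage with hypothetical scene names:
# train, val = split_scenes(["scene_a", "scene_b", "scene_c"])
```

Because the split happens at the scene level, all sub-scenes rendered from one scene end up on the same side of the split.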

================================================
FILE: datasets_preprocess/habitat/find_scenes.py
================================================
#!/usr/bin/env python3
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Script to export the list of scenes for habitat (after having rendered them).
# Usage:
# python3 datasets_preprocess/habitat/find_scenes.py --root data/habitat_processed
# --------------------------------------------------------
import numpy as np
import os
from collections import defaultdict
from tqdm import tqdm


def find_all_scenes(habitat_root, n_scenes=(100000,)):
    np.random.seed(777)

    # Cached list of sub-scenes; it is rebuilt below if the file cannot be read
    fpath = os.path.join(habitat_root, 'Habitat_all_scenes.txt')
    try:
        with open(fpath) as f:
            list_subscenes = f.read().splitlines()

    except IOError:
        if input('parsing sub-folders to find scenes? (y/n) ') != 'y':
            return
        list_subscenes = []
        for root, dirs, files in tqdm(os.walk(habitat_root)):
            for f in files:
                if not f.endswith('_1_depth.exr'):
                    continue
                scene = os.path.join(os.path.relpath(root, habitat_root), f.replace('_1_depth.exr', ''))
                if hash(scene) % 1000 == 0:
                    print('... adding', scene)
                list_subscenes.append(scene)

        with open(fpath, 'w') as f:
            f.write('\n'.join(list_subscenes))
        print(f'>> wrote {fpath}')

    print(f'Loaded {len(list_subscenes)} sub-scenes')

    # separate scenes
    list_scenes = defaultdict(list)
    for scene in list_subscenes:
        scene, id = os.path.split(scene)
        list_scenes[scene].append(id)

    list_scenes = list(list_scenes.items())
    print(f'from {len(list_scenes)} scenes in total')

    np.random.shuffle(list_scenes)
    train_scenes = list_scenes[len(list_scenes) // 10:]
    val_scenes = list_scenes[:len(list_scenes) // 10]

    def write_scene_list(scenes, n, fpath):
        sub_scenes = [os.path.join(scene, id) for scene, ids in scenes for id in ids]
        np.random.shuffle(sub_scenes)

        if len(sub_scenes) < n:
            return

        with open(fpath, 'w') as f:
            f.write('\n'.join(sub_scenes[:n]))
        print(f'>> wrote {fpath}')

    for n in n_scenes:
        write_scene_list(train_scenes, n, os.path.join(habitat_root, f'Habitat_{n}_scenes_train.txt'))
        write_scene_list(val_scenes, n // 10, os.path.join(habitat_root, f'Habitat_{n//10}_scenes_val.txt'))


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--root", required=True)
    parser.add_argument("--n_scenes", nargs='+', default=[1_000, 10_000, 100_000, 1_000_000], type=int)

    args = parser.parse_args()
    find_all_scenes(args.root, args.n_scenes)


================================================
FILE: datasets_preprocess/habitat/habitat_renderer/__init__.py
================================================
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).


================================================
FILE: datasets_preprocess/habitat/habitat_renderer/habitat_sim_envmaps_renderer.py
================================================
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Render environment maps from 3D meshes using the Habitat Sim simulator.
# --------------------------------------------------------
import numpy as np
import habitat_sim
import math
from habitat_renderer import projections

# OpenCV to habitat camera convention transformation
R_OPENCV2HABITAT = np.stack((habitat_sim.geo.RIGHT, -habitat_sim.geo.UP, habitat_sim.geo.FRONT), axis=0)

CUBEMAP_FACE_LABELS = ["left", "front", "right", "back", "up", "down"]
# Expressed in the Habitat coordinate system
CUBEMAP_FACE_ORIENTATIONS_ROTVEC = [
    [0, math.pi / 2, 0],  # Left
    [0, 0, 0],  # Front
    [0, -math.pi / 2, 0],  # Right
    [0, math.pi, 0],  # Back
    [math.pi / 2, 0, 0],  # Up
    [-math.pi / 2, 0, 0],  # Down
]

class NoNaviguableSpaceError(RuntimeError):
    def __init__(self, *args):
        super().__init__(*args)

class HabitatEnvironmentMapRenderer:
    def __init__(self,
                 scene,
                 navmesh,
                 scene_dataset_config_file,
                 render_equirectangular=False,
                 equirectangular_resolution=(512, 1024),
                 render_cubemap=False,
                 cubemap_resolution=(512, 512),
                 render_depth=False,
                 gpu_id=0):
        self.scene = scene
        self.navmesh = navmesh
        self.scene_dataset_config_file = scene_dataset_config_file
        self.gpu_id = gpu_id

        self.render_equirectangular = render_equirectangular
        self.equirectangular_resolution = equirectangular_resolution
        self.equirectangular_projection = projections.EquirectangularProjection(*equirectangular_resolution)
        # 3D unit ray associated to each pixel of the equirectangular map
        equirectangular_rays = projections.get_projection_rays(self.equirectangular_projection)
        # Not needed, but just in case.
        equirectangular_rays /= np.linalg.norm(equirectangular_rays, axis=-1, keepdims=True)
        # Depth maps created by Habitat are produced by warping a cubemap,
        # so the values do not correspond to the distance to the center and need some scaling.
        self.equirectangular_depth_scale_factors = 1.0 / np.max(np.abs(equirectangular_rays), axis=-1)

        self.render_cubemap = render_cubemap
        self.cubemap_resolution = cubemap_resolution

        self.render_depth = render_depth

        self.seed = None
        self._lazy_initialization()

    def _lazy_initialization(self):
        # Lazy random seeding and instantiation of the simulator to deal with multiprocessing properly
        if self.seed is None:
            # Re-seed numpy generator
            np.random.seed()
            self.seed = np.random.randint(2**32-1)
            sim_cfg = habitat_sim.SimulatorConfiguration()
            sim_cfg.scene_id = self.scene
            if self.scene_dataset_config_file is not None and self.scene_dataset_config_file != "":
                sim_cfg.scene_dataset_config_file = self.scene_dataset_config_file
            sim_cfg.random_seed = self.seed
            sim_cfg.load_semantic_mesh = False
            sim_cfg.gpu_device_id = self.gpu_id

            sensor_specifications = []

            # Add cubemaps
            if self.render_cubemap:
                for face_id, orientation in enumerate(CUBEMAP_FACE_ORIENTATIONS_ROTVEC):
                    rgb_sensor_spec = habitat_sim.CameraSensorSpec()
                    rgb_sensor_spec.uuid = f"color_cubemap_{CUBEMAP_FACE_LABELS[face_id]}"
                    rgb_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR
                    rgb_sensor_spec.resolution = self.cubemap_resolution
                    rgb_sensor_spec.hfov = 90
                    rgb_sensor_spec.position = [0.0, 0.0, 0.0]
                    rgb_sensor_spec.orientation = orientation
                    sensor_specifications.append(rgb_sensor_spec)

                    if self.render_depth:
                        depth_sensor_spec = habitat_sim.CameraSensorSpec()
                        depth_sensor_spec.uuid = f"depth_cubemap_{CUBEMAP_FACE_LABELS[face_id]}"
                        depth_sensor_spec.sensor_type = habitat_sim.SensorType.DEPTH
                        depth_sensor_spec.resolution = self.cubemap_resolution
                        depth_sensor_spec.hfov = 90
                        depth_sensor_spec.position = [0.0, 0.0, 0.0]
                        depth_sensor_spec.orientation = orientation
                        sensor_specifications.append(depth_sensor_spec)

            # Add equirectangular map
            if self.render_equirectangular:
                rgb_sensor_spec = habitat_sim.bindings.EquirectangularSensorSpec()
                rgb_sensor_spec.uuid = "color_equirectangular"
                rgb_sensor_spec.resolution = self.equirectangular_resolution
                rgb_sensor_spec.position = [0.0, 0.0, 0.0]
                sensor_specifications.append(rgb_sensor_spec)

                if self.render_depth:
                    depth_sensor_spec = habitat_sim.bindings.EquirectangularSensorSpec()
                    depth_sensor_spec.uuid = "depth_equirectangular"
                    depth_sensor_spec.sensor_type = habitat_sim.SensorType.DEPTH
                    depth_sensor_spec.resolution = self.equirectangular_resolution
                    depth_sensor_spec.position = [0.0, 0.0, 0.0]
                    # The orientation is left at its default, as for the color sensor above.
                    sensor_specifications.append(depth_sensor_spec)

            agent_cfg = habitat_sim.agent.AgentConfiguration(sensor_specifications=sensor_specifications)

            cfg = habitat_sim.Configuration(sim_cfg, [agent_cfg])
            self.sim = habitat_sim.Simulator(cfg)
            if self.navmesh is not None and self.navmesh != "":
                # Use the pre-computed navmesh (the automatically generated one does odd things, such as extending onto the roof)
                # See https://youtu.be/kunFMRJAu2U?t=1522 regarding navmeshes
                self.sim.pathfinder.load_nav_mesh(self.navmesh)

            # Check that the navmesh is not empty
            if not self.sim.pathfinder.is_loaded:
                # Try to compute a navmesh
                navmesh_settings = habitat_sim.NavMeshSettings()
                navmesh_settings.set_defaults()
                self.sim.recompute_navmesh(self.sim.pathfinder, navmesh_settings, True)

            # Check that the navmesh is not empty
            if not self.sim.pathfinder.is_loaded:
                raise NoNaviguableSpaceError(f"No navigable location (scene: {self.scene} -- navmesh: {self.navmesh})")

            self.agent = self.sim.initialize_agent(agent_id=0)

    def close(self):
        if hasattr(self, 'sim'):
            self.sim.close()

    def __del__(self):
        self.close()

    def render_viewpoint(self, viewpoint_position):
        agent_state = habitat_sim.AgentState()
        agent_state.position = viewpoint_position
        # agent_state.rotation = viewpoint_orientation
        self.agent.set_state(agent_state)
        viewpoint_observations = self.sim.get_sensor_observations(agent_ids=0)

        try:
            # Depth map values have been obtained using cubemap rendering internally,
            # so they do not really correspond to distance to the viewpoint in practice
            # and they need some scaling
            viewpoint_observations["depth_equirectangular"] *= self.equirectangular_depth_scale_factors
        except KeyError:
            pass

        data = dict(observations=viewpoint_observations, position=viewpoint_position)
        return data

    def up_direction(self):
        return np.asarray(habitat_sim.geo.UP).tolist()
    
    def R_cam_to_world(self):
        return R_OPENCV2HABITAT.tolist()


================================================
FILE: datasets_preprocess/habitat/habitat_renderer/multiview_crop_generator.py
================================================
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Generate pairs of crops from a dataset of environment maps.
# --------------------------------------------------------
import os
import numpy as np
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"  # noqa
import cv2
import collections
from habitat_renderer import projections, projections_conversions
from habitat_renderer.habitat_sim_envmaps_renderer import HabitatEnvironmentMapRenderer

ViewpointData = collections.namedtuple("ViewpointData", ["colormap", "distancemap", "pointmap", "position"])

class HabitatMultiviewCrops:
    def __init__(self,
                 scene,
                 navmesh,
                 scene_dataset_config_file,
                 equirectangular_resolution=(400, 800),
                 crop_resolution=(240, 320),
                 pixel_jittering_iterations=5,
                 jittering_noise_level=1.0):
        self.crop_resolution = crop_resolution

        self.pixel_jittering_iterations = pixel_jittering_iterations
        self.jittering_noise_level = jittering_noise_level

        # Instantiate the low resolution habitat sim renderer
        self.lowres_envmap_renderer = HabitatEnvironmentMapRenderer(scene=scene,
                                                                    navmesh=navmesh,
                                                                    scene_dataset_config_file=scene_dataset_config_file,
                                                                    equirectangular_resolution=equirectangular_resolution,
                                                                    render_depth=True,
                                                                    render_equirectangular=True)
        self.R_cam_to_world = np.asarray(self.lowres_envmap_renderer.R_cam_to_world())
        self.up_direction = np.asarray(self.lowres_envmap_renderer.up_direction())

        # Projection applied by each environment map
        self.envmap_height, self.envmap_width = self.lowres_envmap_renderer.equirectangular_resolution
        base_projection = projections.EquirectangularProjection(self.envmap_height, self.envmap_width)
        self.envmap_projection = projections.RotatedProjection(base_projection, self.R_cam_to_world.T)
        # 3D Rays map associated to each envmap
        self.envmap_rays = projections.get_projection_rays(self.envmap_projection)

    def compute_pointmap(self, distancemap, position):
        # Point cloud associated to each ray
        return self.envmap_rays * distancemap[:, :, None] + position

    def render_viewpoint_data(self, position):
        data = self.lowres_envmap_renderer.render_viewpoint(np.asarray(position))
        colormap = data['observations']['color_equirectangular'][..., :3]  # Ignore the alpha channel
        distancemap = data['observations']['depth_equirectangular']
        pointmap = self.compute_pointmap(distancemap, position)
        return ViewpointData(colormap=colormap, distancemap=distancemap, pointmap=pointmap, position=position)

    def extract_cropped_camera(self, projection, color_image, distancemap, pointmap, voxelmap=None):
        remapper = projections_conversions.RemapProjection(input_projection=self.envmap_projection, output_projection=projection,
                                                           pixel_jittering_iterations=self.pixel_jittering_iterations, jittering_noise_level=self.jittering_noise_level)
        cropped_color_image = remapper.convert(
            color_image, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_WRAP, single_map=False)
        cropped_distancemap = remapper.convert(
            distancemap, interpolation=cv2.INTER_NEAREST, borderMode=cv2.BORDER_WRAP, single_map=True)
        cropped_pointmap = remapper.convert(pointmap, interpolation=cv2.INTER_NEAREST,
                                            borderMode=cv2.BORDER_WRAP, single_map=True)
        cropped_voxelmap = (None if voxelmap is None else
                            remapper.convert(voxelmap, interpolation=cv2.INTER_NEAREST, borderMode=cv2.BORDER_WRAP, single_map=True))
        # Convert the distance map into a depth map
        cropped_depthmap = np.asarray(
            cropped_distancemap / np.linalg.norm(remapper.output_rays, axis=-1), dtype=cropped_distancemap.dtype)

        return cropped_color_image, cropped_depthmap, cropped_pointmap, cropped_voxelmap

def perspective_projection_to_dict(persp_projection, position):
    """
    Serialize a rotated perspective projection and its position into a plain dict.
    """
    camera_params = dict(camera_intrinsics=projections.colmap_to_opencv_intrinsics(persp_projection.base_projection.K).tolist(),
                         size=(persp_projection.base_projection.width, persp_projection.base_projection.height),
                         R_cam2world=persp_projection.R_to_base_projection.T.tolist(),
                         t_cam2world=position)
    return camera_params


def dict_to_perspective_projection(camera_params):
    K = projections.opencv_to_colmap_intrinsics(np.asarray(camera_params["camera_intrinsics"]))
    size = camera_params["size"]
    R_cam2world = np.asarray(camera_params["R_cam2world"])
    projection = projections.PerspectiveProjection(K, height=size[1], width=size[0])
    projection = projections.RotatedProjection(projection, R_to_base_projection=R_cam2world.T)
    position = camera_params["t_cam2world"]
    return projection, position

================================================
FILE: datasets_preprocess/habitat/habitat_renderer/projections.py
================================================
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Various 3D/2D projection utils, useful to sample virtual cameras.
# --------------------------------------------------------
import numpy as np

class EquirectangularProjection:
    """
    Convention for the central pixel of the equirectangular map similar to OpenCV perspective model:
        +X from left to right
        +Y from top to bottom
        +Z going outside the camera
    EXCEPT that the top left corner of the image is assumed to have (0,0) coordinates (OpenCV assumes (-0.5,-0.5))
    """

    def __init__(self, height, width):
        self.height = height
        self.width = width
        self.u_scaling = (2 * np.pi) / self.width
        self.v_scaling = np.pi / self.height

    def unproject(self, u, v):
        """
        Args:
            u, v: 2D coordinates
        Returns:
            unnormalized 3D rays.
        """
        longitude = self.u_scaling * u - np.pi
        minus_latitude = self.v_scaling * v - np.pi/2

        cos_latitude = np.cos(minus_latitude)
        x, z = np.sin(longitude) * cos_latitude, np.cos(longitude) * cos_latitude
        y = np.sin(minus_latitude)

        rays = np.stack([x, y, z], axis=-1)
        return rays

    def project(self, rays):
        """
        Args:
            rays: Bx3 array of 3D rays.
        Returns:
            u, v: tuple of 2D coordinates.
        """
        rays = rays / np.linalg.norm(rays, axis=-1, keepdims=True)
        x, y, z = [rays[..., i] for i in range(3)]

        longitude = np.arctan2(x, z)
        minus_latitude = np.arcsin(y)

        u = (longitude + np.pi) * (1.0 / self.u_scaling)
        v = (minus_latitude + np.pi/2) * (1.0 / self.v_scaling)
        return u, v


class PerspectiveProjection:
    """
    OpenCV convention:
    World space:
        +X from left to right
        +Y from top to bottom
        +Z going outside the camera
    Pixel space:
        +u from left to right
        +v from top to bottom
    EXCEPT that the top left corner of the image is assumed to have (0,0) coordinates (OpenCV assumes (-0.5,-0.5)).
    """

    def __init__(self, K, height, width):
        self.height = height
        self.width = width
        self.K = K
        self.Kinv = np.linalg.inv(K)

    def project(self, rays):
        uv_homogeneous = np.einsum("ik, ...k -> ...i", self.K, rays)
        uv = uv_homogeneous[..., :2] / uv_homogeneous[..., 2, None]
        return uv[..., 0], uv[..., 1]

    def unproject(self, u, v):
        uv_homogeneous = np.stack((u, v, np.ones_like(u)), axis=-1)
        rays = np.einsum("ik, ...k -> ...i", self.Kinv, uv_homogeneous)
        return rays


class RotatedProjection:
    def __init__(self, base_projection, R_to_base_projection):
        self.base_projection = base_projection
        self.R_to_base_projection = R_to_base_projection

    @property
    def width(self):
        return self.base_projection.width

    @property
    def height(self):
        return self.base_projection.height

    def project(self, rays):
        if self.R_to_base_projection is not None:
            rays = np.einsum("ik, ...k -> ...i", self.R_to_base_projection, rays)
        return self.base_projection.project(rays)

    def unproject(self, u, v):
        rays = self.base_projection.unproject(u, v)
        if self.R_to_base_projection is not None:
            rays = np.einsum("ik, ...k -> ...i", self.R_to_base_projection.T, rays)
        return rays

def get_projection_rays(projection, noise_level=0):
    """
    Return a 2D map of 3D rays corresponding to the projection.
    If noise_level > 0, add some jittering noise to these rays.
    """
    grid_u, grid_v = np.meshgrid(0.5 + np.arange(projection.width), 0.5 + np.arange(projection.height))
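    if noise_level > 0:
        # NOTE: np.clip(a, a_min, a_max) is called here with a=0 and the noise as a_min,
        # so each offset evaluates to max(noise, 0): only non-negative jitter is added.
        grid_u += np.clip(0, noise_level * np.random.uniform(-0.5, 0.5, size=grid_u.shape), projection.width)
        grid_v += np.clip(0, noise_level * np.random.uniform(-0.5, 0.5, size=grid_v.shape), projection.height)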
    return projection.unproject(grid_u, grid_v)

def compute_camera_intrinsics(height, width, hfov):
    f = width/2 / np.tan(hfov/2 * np.pi/180)
    cu, cv = width/2, height/2
    return f, cu, cv

def colmap_to_opencv_intrinsics(K):
    """
    Modify camera intrinsics to follow a different convention.
    Coordinates of the center of the top-left pixels are by default:
    - (0.5, 0.5) in Colmap
    - (0,0) in OpenCV
    """
    K = K.copy()
    K[0, 2] -= 0.5
    K[1, 2] -= 0.5
    return K

def opencv_to_colmap_intrinsics(K):
    """
    Modify camera intrinsics to follow a different convention.
    Coordinates of the center of the top-left pixels are by default:
    - (0.5, 0.5) in Colmap
    - (0,0) in OpenCV
    """
    K = K.copy()
    K[0, 2] += 0.5
    K[1, 2] += 0.5
    return K

================================================
FILE: datasets_preprocess/habitat/habitat_renderer/projections_conversions.py
================================================
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Remap data from one projection to another
# --------------------------------------------------------
import numpy as np
import cv2
from habitat_renderer import projections

class RemapProjection:
    def __init__(self, input_projection, output_projection, pixel_jittering_iterations=0, jittering_noise_level=0):
        """
        Some naive random jittering can be introduced in the remapping to mitigate aliasing artifacts.
        """
        assert jittering_noise_level >= 0
        assert pixel_jittering_iterations >= 0

        maps = []
        # Initial map
        self.output_rays = projections.get_projection_rays(output_projection)
        map_u, map_v = input_projection.project(self.output_rays)
        map_u, map_v = np.asarray(map_u, dtype=np.float32), np.asarray(map_v, dtype=np.float32)
        maps.append((map_u, map_v))

        for _ in range(pixel_jittering_iterations):
            # Define multiple mappings using some coordinates jittering to mitigate aliasing effects
            crop_rays = projections.get_projection_rays(output_projection, jittering_noise_level)
            map_u, map_v = input_projection.project(crop_rays)
            map_u, map_v = np.asarray(map_u, dtype=np.float32), np.asarray(map_v, dtype=np.float32)
            maps.append((map_u, map_v))
        self.maps = maps

    def convert(self, img, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_WRAP, single_map=False):
        remapped = []
        for map_u, map_v in self.maps:
            res = cv2.remap(img, map_u, map_v, interpolation=interpolation, borderMode=borderMode)
            remapped.append(res)
            if single_map:
                break
        if len(remapped) == 1:
            res = remapped[0]
        else:
            res = np.asarray(np.mean(remapped, axis=0), dtype=img.dtype)
        return res


================================================
FILE: datasets_preprocess/habitat/preprocess_habitat.py
================================================
#!/usr/bin/env python3
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# main executable for preprocessing habitat
# export METADATA_DIR="/path/to/habitat/5views_v1_512x512_metadata"
# export SCENES_DIR="/path/to/habitat/data/scene_datasets/"
# export OUTPUT_DIR="data/habitat_processed"
# export PYTHONPATH=$(pwd)
# python preprocess_habitat.py --scenes_dir=$SCENES_DIR --metadata_dir=$METADATA_DIR --output_dir=$OUTPUT_DIR | parallel -j 16
# --------------------------------------------------------
import os
import glob
import json

import PIL.Image
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"  # noqa
import cv2
from habitat_renderer import multiview_crop_generator
from tqdm import tqdm


def preprocess_metadata(metadata_filename,
                        scenes_dir,
                        output_dir,
                        crop_resolution=[512, 512],
                        equirectangular_resolution=None,
                        fix_existing_dataset=False):
    # Load data
    with open(metadata_filename, "r") as f:
        metadata = json.load(f)

    if metadata["scene_dataset_config_file"] == "":
        scene = os.path.join(scenes_dir, metadata["scene"])
        scene_dataset_config_file = ""
    else:
        scene = metadata["scene"]
        scene_dataset_config_file = os.path.join(scenes_dir, metadata["scene_dataset_config_file"])
    navmesh = None

    if equirectangular_resolution is None:
        # Render the environment map at 4 times the largest crop dimension in height (8 times in width).
        max_res = max(crop_resolution)
        equirectangular_resolution = (4*max_res, 8*max_res)

    print("equirectangular_resolution:", equirectangular_resolution)

    if os.path.exists(output_dir) and not fix_existing_dataset:
        raise FileExistsError(output_dir)

    # Lazy initialization
    highres_dataset = None

    for batch_label, batch in tqdm(metadata["view_batches"].items()):
        for view_label, view_params in batch.items():

            assert view_params["size"] == crop_resolution
            label = f"{batch_label}_{view_label}"

            output_camera_params_filename = os.path.join(output_dir, f"{label}_camera_params.json")
            if fix_existing_dataset and os.path.isfile(output_camera_params_filename):
                # Skip generation if we are fixing a dataset and the corresponding output file already exists
                continue

            # Lazy initialization
            if highres_dataset is None:
                highres_dataset = multiview_crop_generator.HabitatMultiviewCrops(scene=scene,
                                                                                 navmesh=navmesh,
                                                                                 scene_dataset_config_file=scene_dataset_config_file,
                                                                                 equirectangular_resolution=equirectangular_resolution,
                                                                                 crop_resolution=crop_resolution,)
                os.makedirs(output_dir, exist_ok=bool(fix_existing_dataset))

            # Generate a higher resolution crop
            original_projection, position = multiview_crop_generator.dict_to_perspective_projection(view_params)
            # Render an envmap at the given position
            viewpoint_data = highres_dataset.render_viewpoint_data(position)

            projection = original_projection
            colormap, depthmap, pointmap, _ = highres_dataset.extract_cropped_camera(
                projection, viewpoint_data.colormap, viewpoint_data.distancemap, viewpoint_data.pointmap)

            camera_params = multiview_crop_generator.perspective_projection_to_dict(projection, position)

            # Color image
            PIL.Image.fromarray(colormap).save(os.path.join(output_dir, f"{label}.jpeg"))
            # Depth image
            cv2.imwrite(os.path.join(output_dir, f"{label}_depth.exr"),
                        depthmap, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])
            # Camera parameters
            with open(output_camera_params_filename, "w") as f:
                json.dump(camera_params, f)


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--metadata_dir", required=True)
    parser.add_argument("--scenes_dir", required=True)
    parser.add_argument("--output_dir", required=True)
    parser.add_argument("--metadata_filename", default="")

    args = parser.parse_args()

    if args.metadata_filename == "":
        # Walk through the metadata dir to generate commandlines
        for filename in glob.iglob(os.path.join(args.metadata_dir, "**/metadata.json"), recursive=True):
            output_dir = os.path.join(args.output_dir, os.path.relpath(os.path.dirname(filename), args.metadata_dir))
            if not os.path.exists(output_dir):
                commandline = f"python {__file__} --metadata_filename={filename} --metadata_dir={args.metadata_dir} --scenes_dir={args.scenes_dir} --output_dir={output_dir}"
                print(commandline)
    else:
        preprocess_metadata(metadata_filename=args.metadata_filename,
                            scenes_dir=args.scenes_dir,
                            output_dir=args.output_dir)


================================================
FILE: datasets_preprocess/path_to_root.py
================================================
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# DUSt3R repo root import
# --------------------------------------------------------

import sys
import os.path as path
HERE_PATH = path.normpath(path.dirname(__file__))
DUST3R_REPO_PATH = path.normpath(path.join(HERE_PATH, '../'))
# workaround for sibling import
sys.path.insert(0, DUST3R_REPO_PATH)


================================================
FILE: datasets_preprocess/preprocess_arkitscenes.py
================================================
#!/usr/bin/env python3
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Script to pre-process the arkitscenes dataset.
# Usage:
# python3 datasets_preprocess/preprocess_arkitscenes.py --arkitscenes_dir /path/to/arkitscenes --precomputed_pairs /path/to/arkitscenes_pairs
# --------------------------------------------------------
import os
import json
import os.path as osp
import decimal
import argparse
import math
from bisect import bisect_left
from PIL import Image
import numpy as np
import quaternion
from scipy import interpolate
import cv2


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument('--arkitscenes_dir', required=True)
    parser.add_argument('--precomputed_pairs', required=True)
    parser.add_argument('--output_dir', default='data/arkitscenes_processed')
    return parser


def value_to_decimal(value, decimal_places):
    decimal.getcontext().rounding = decimal.ROUND_HALF_UP  # define rounding method
    return decimal.Decimal(str(float(value))).quantize(decimal.Decimal('1e-{}'.format(decimal_places)))


def closest(value, sorted_list):
    index = bisect_left(sorted_list, value)
    if index == 0:
        return sorted_list[0]
    elif index == len(sorted_list):
        return sorted_list[-1]
    else:
        value_before = sorted_list[index - 1]
        value_after = sorted_list[index]
        if value_after - value < value - value_before:
            return value_after
        else:
            return value_before


def get_up_vectors(pose_device_to_world):
    return np.matmul(pose_device_to_world, np.array([[0.0], [-1.0], [0.0], [0.0]]))


def get_right_vectors(pose_device_to_world):
    return np.matmul(pose_device_to_world, np.array([[1.0], [0.0], [0.0], [0.0]]))


def read_traj(traj_path):
    quaternions = []
    poses = []
    timestamps = []
    poses_p_to_w = []
    with open(traj_path) as f:
        traj_lines = f.readlines()
        for line in traj_lines:
            tokens = line.split()
            assert len(tokens) == 7
            traj_timestamp = float(tokens[0])

            timestamps_decimal_value = value_to_decimal(traj_timestamp, 3)
            timestamps.append(float(timestamps_decimal_value))  # for spline interpolation

            angle_axis = [float(tokens[1]), float(tokens[2]), float(tokens[3])]
            r_w_to_p, _ = cv2.Rodrigues(np.asarray(angle_axis))
            t_w_to_p = np.asarray([float(tokens[4]), float(tokens[5]), float(tokens[6])])

            pose_w_to_p = np.eye(4)
            pose_w_to_p[:3, :3] = r_w_to_p
            pose_w_to_p[:3, 3] = t_w_to_p

            pose_p_to_w = np.linalg.inv(pose_w_to_p)

            r_p_to_w_as_quat = quaternion.from_rotation_matrix(pose_p_to_w[:3, :3])
            t_p_to_w = pose_p_to_w[:3, 3]
            poses_p_to_w.append(pose_p_to_w)
            poses.append(t_p_to_w)
            quaternions.append(r_p_to_w_as_quat)
    return timestamps, poses, quaternions, poses_p_to_w


def main(rootdir, pairsdir, outdir):
    os.makedirs(outdir, exist_ok=True)

    subdirs = ['Test', 'Training']
    for subdir in subdirs:
        if not osp.isdir(osp.join(rootdir, subdir)):
            continue
        # STEP 1: list all scenes
        outsubdir = osp.join(outdir, subdir)
        os.makedirs(outsubdir, exist_ok=True)
        listfile = osp.join(pairsdir, subdir, 'scene_list.json')
        with open(listfile, 'r') as f:
            scene_dirs = json.load(f)

        valid_scenes = []
        for scene_subdir in scene_dirs:
            out_scene_subdir = osp.join(outsubdir, scene_subdir)
            os.makedirs(out_scene_subdir, exist_ok=True)

            scene_dir = osp.join(rootdir, subdir, scene_subdir)
            depth_dir = osp.join(scene_dir, 'lowres_depth')
            rgb_dir = osp.join(scene_dir, 'vga_wide')
            intrinsics_dir = osp.join(scene_dir, 'vga_wide_intrinsics')
            traj_path = osp.join(scene_dir, 'lowres_wide.traj')

            # STEP 2: read selected_pairs.npz
            selected_pairs_path = osp.join(pairsdir, subdir, scene_subdir, 'selected_pairs.npz')
            selected_npz = np.load(selected_pairs_path)
            selection, pairs = selected_npz['selection'], selected_npz['pairs']
            selected_sky_direction_scene = str(selected_npz['sky_direction_scene'][0])
            if len(selection) == 0 or len(pairs) == 0:
                # not a valid scene
                continue
            valid_scenes.append(scene_subdir)

            # STEP 3: parse the scene and export the list of valid (K, pose, rgb, depth) and convert images
            scene_metadata_path = osp.join(out_scene_subdir, 'scene_metadata.npz')
            if osp.isfile(scene_metadata_path):
                continue
            else:
                print(f'parsing {scene_subdir}')
                # loads traj
                timestamps, poses, quaternions, poses_cam_to_world = read_traj(traj_path)

                poses = np.array(poses)
                quaternions = np.array(quaternions, dtype=np.quaternion)
                quaternions = quaternion.unflip_rotors(quaternions)
                timestamps = np.array(timestamps)

                selected_images = [(basename, basename.split(".png")[0].split("_")[1]) for basename in selection]
                timestamps_selected = [float(frame_id) for _, frame_id in selected_images]

                sky_direction_scene, trajectories, intrinsics, images = convert_scene_metadata(scene_subdir,
                                                                                               intrinsics_dir,
                                                                                               timestamps,
                                                                                               quaternions,
                                                                                               poses,
                                                                                               poses_cam_to_world,
                                                                                               selected_images,
                                                                                               timestamps_selected)
                assert selected_sky_direction_scene == sky_direction_scene

                os.makedirs(os.path.join(out_scene_subdir, 'vga_wide'), exist_ok=True)
                os.makedirs(os.path.join(out_scene_subdir, 'lowres_depth'), exist_ok=True)
                assert isinstance(sky_direction_scene, str)
                for basename in images:
                    img_out = os.path.join(out_scene_subdir, 'vga_wide', basename.replace('.png', '.jpg'))
                    depth_out = os.path.join(out_scene_subdir, 'lowres_depth', basename)
                    if osp.isfile(img_out) and osp.isfile(depth_out):
                        continue

                    vga_wide_path = osp.join(rgb_dir, basename)
                    depth_path = osp.join(depth_dir, basename)

                    img = Image.open(vga_wide_path)
                    depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)

                    # rotate the image
                    if sky_direction_scene == 'RIGHT':
                        try:
                            img = img.transpose(Image.Transpose.ROTATE_90)
                        except Exception:
                            img = img.transpose(Image.ROTATE_90)
                        depth = cv2.rotate(depth, cv2.ROTATE_90_COUNTERCLOCKWISE)
                    elif sky_direction_scene == 'LEFT':
                        try:
                            img = img.transpose(Image.Transpose.ROTATE_270)
                        except Exception:
                            img = img.transpose(Image.ROTATE_270)
                        depth = cv2.rotate(depth, cv2.ROTATE_90_CLOCKWISE)
                    elif sky_direction_scene == 'DOWN':
                        try:
                            img = img.transpose(Image.Transpose.ROTATE_180)
                        except Exception:
                            img = img.transpose(Image.ROTATE_180)
                        depth = cv2.rotate(depth, cv2.ROTATE_180)

                    W, H = img.size
                    if not osp.isfile(img_out):
                        img.save(img_out)

                    depth = cv2.resize(depth, (W, H), interpolation=cv2.INTER_NEAREST_EXACT)
                    if not osp.isfile(depth_out):  # avoid destroying the base dataset when you mess up the paths
                        cv2.imwrite(depth_out, depth)

                # save at the end
                np.savez(scene_metadata_path,
                         trajectories=trajectories,
                         intrinsics=intrinsics,
                         images=images,
                         pairs=pairs)

        # STEP 4: save the list of valid scenes
        outlistfile = osp.join(outsubdir, 'scene_list.json')
        with open(outlistfile, 'w') as f:
            json.dump(valid_scenes, f)

        # STEP 5: concat all scene_metadata.npz into a single file
        scene_data = {}
        for scene_subdir in valid_scenes:
            scene_metadata_path = osp.join(outsubdir, scene_subdir, 'scene_metadata.npz')
            with np.load(scene_metadata_path) as data:
                trajectories = data['trajectories']
                intrinsics = data['intrinsics']
                images = data['images']
                pairs = data['pairs']
            scene_data[scene_subdir] = {'trajectories': trajectories,
                                        'intrinsics': intrinsics,
                                        'images': images,
                                        'pairs': pairs}
        offset = 0
        counts = []
        scenes = []
        sceneids = []
        images = []
        intrinsics = []
        trajectories = []
        pairs = []
        for scene_idx, (scene_subdir, data) in enumerate(scene_data.items()):
            num_imgs = data['images'].shape[0]
            img_pairs = data['pairs']

            scenes.append(scene_subdir)
            sceneids.extend([scene_idx] * num_imgs)

            images.append(data['images'])

            K = np.expand_dims(np.eye(3), 0).repeat(num_imgs, 0)
            K[:, 0, 0] = [fx for _, _, fx, _, _, _ in data['intrinsics']]
            K[:, 1, 1] = [fy for _, _, _, fy, _, _ in data['intrinsics']]
            K[:, 0, 2] = [hw for _, _, _, _, hw, _ in data['intrinsics']]
            K[:, 1, 2] = [hh for _, _, _, _, _, hh in data['intrinsics']]

            intrinsics.append(K)
            trajectories.append(data['trajectories'])

            # offset pairs
            img_pairs[:, 0:2] += offset
            pairs.append(img_pairs)
            counts.append(offset)

            offset += num_imgs

        images = np.concatenate(images, axis=0)
        intrinsics = np.concatenate(intrinsics, axis=0)
        trajectories = np.concatenate(trajectories, axis=0)
        pairs = np.concatenate(pairs, axis=0)
        np.savez(osp.join(outsubdir, 'all_metadata.npz'),
                 counts=counts,
                 scenes=scenes,
                 sceneids=sceneids,
                 images=images,
                 intrinsics=intrinsics,
                 trajectories=trajectories,
                 pairs=pairs)


def convert_scene_metadata(scene_subdir, intrinsics_dir,
                           timestamps, quaternions, poses, poses_cam_to_world,
                           selected_images, timestamps_selected):
    # find scene orientation
    sky_direction_scene, rotated_to_cam = find_scene_orientation(poses_cam_to_world)

    # find/compute pose for selected timestamps
    # most images have a valid timestamp / exact pose associated
    timestamps_selected = np.array(timestamps_selected)
    spline = interpolate.interp1d(timestamps, poses, kind='linear', axis=0)
    interpolated_rotations = quaternion.squad(quaternions, timestamps, timestamps_selected)
    interpolated_positions = spline(timestamps_selected)

    trajectories = []
    intrinsics = []
    images = []
    for i, (basename, frame_id) in enumerate(selected_images):
        intrinsic_fn = osp.join(intrinsics_dir, f"{scene_subdir}_{frame_id}.pincam")
        if not osp.exists(intrinsic_fn):
            intrinsic_fn = osp.join(intrinsics_dir, f"{scene_subdir}_{float(frame_id) - 0.001:.3f}.pincam")
        if not osp.exists(intrinsic_fn):
            intrinsic_fn = osp.join(intrinsics_dir, f"{scene_subdir}_{float(frame_id) + 0.001:.3f}.pincam")
        assert osp.exists(intrinsic_fn)
        w, h, fx, fy, hw, hh = np.loadtxt(intrinsic_fn)  # PINHOLE

        pose = np.eye(4)
        pose[:3, :3] = quaternion.as_rotation_matrix(interpolated_rotations[i])
        pose[:3, 3] = interpolated_positions[i]

        images.append(basename)
        if sky_direction_scene == 'RIGHT' or sky_direction_scene == 'LEFT':
            intrinsics.append([h, w, fy, fx, hh, hw])  # swapped intrinsics
        else:
            intrinsics.append([w, h, fx, fy, hw, hh])
        trajectories.append(pose @ rotated_to_cam)  # pose_cam_to_world @ rotated_to_cam = rotated(cam) to world

    return sky_direction_scene, trajectories, intrinsics, images


def find_scene_orientation(poses_cam_to_world):
    if len(poses_cam_to_world) > 0:
        up_vector = sum(get_up_vectors(p) for p in poses_cam_to_world) / len(poses_cam_to_world)
        right_vector = sum(get_right_vectors(p) for p in poses_cam_to_world) / len(poses_cam_to_world)
        up_world = np.array([[0.0], [0.0], [1.0], [0.0]])
    else:
        up_vector = np.array([[0.0], [-1.0], [0.0], [0.0]])
        right_vector = np.array([[1.0], [0.0], [0.0], [0.0]])
        up_world = np.array([[0.0], [0.0], [1.0], [0.0]])

    # angles in degrees, between 0 and 180
    device_up_to_world_up_angle = np.arccos(np.clip(np.dot(np.transpose(up_world),
                                                           up_vector), -1.0, 1.0)).item() * 180.0 / np.pi
    device_right_to_world_up_angle = np.arccos(np.clip(np.dot(np.transpose(up_world),
                                                              right_vector), -1.0, 1.0)).item() * 180.0 / np.pi

    up_closest_to_90 = abs(device_up_to_world_up_angle - 90.0) < abs(device_right_to_world_up_angle - 90.0)
    if up_closest_to_90:
        assert abs(device_up_to_world_up_angle - 90.0) < 45.0
        # LEFT
        if device_right_to_world_up_angle > 90.0:
            sky_direction_scene = 'LEFT'
            cam_to_rotated_q = quaternion.from_rotation_vector([0.0, 0.0, math.pi / 2.0])
        else:
            # note: metadata.csv never lists RIGHT, but it is not accurate:
            # some scenes are actually oriented this way,
            # for example Training/41124801
            sky_direction_scene = 'RIGHT'
            cam_to_rotated_q = quaternion.from_rotation_vector([0.0, 0.0, -math.pi / 2.0])
    else:
        # right is close to 90
        assert abs(device_right_to_world_up_angle - 90.0) < 45.0
        if device_up_to_world_up_angle > 90.0:
            sky_direction_scene = 'DOWN'
            cam_to_rotated_q = quaternion.from_rotation_vector([0.0, 0.0, math.pi])
        else:
            sky_direction_scene = 'UP'
            cam_to_rotated_q = quaternion.quaternion(1, 0, 0, 0)
    cam_to_rotated = np.eye(4)
    cam_to_rotated[:3, :3] = quaternion.as_rotation_matrix(cam_to_rotated_q)
    rotated_to_cam = np.linalg.inv(cam_to_rotated)
    return sky_direction_scene, rotated_to_cam


if __name__ == '__main__':
    parser = get_parser()
    args = parser.parse_args()
    main(args.arkitscenes_dir, args.precomputed_pairs, args.output_dir)


================================================
FILE: datasets_preprocess/preprocess_blendedMVS.py
================================================
#!/usr/bin/env python3
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Preprocessing code for the BlendedMVS dataset
# dataset at https://github.com/YoYo000/BlendedMVS
# 1) Download BlendedMVS.zip
# 2) Download BlendedMVS+.zip
# 3) Download BlendedMVS++.zip
# 4) Unzip everything in the same /path/to/tmp/blendedMVS/ directory
# 5) python datasets_preprocess/preprocess_blendedMVS.py --blendedmvs_dir /path/to/tmp/blendedMVS/ --precomputed_pairs /path/to/blendedmvs_pairs.npy
# --------------------------------------------------------
import os
import os.path as osp
import re
from tqdm import tqdm
import numpy as np
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
import cv2

import path_to_root  # noqa
from dust3r.utils.parallel import parallel_threads
from dust3r.datasets.utils import cropping  # noqa


def get_parser():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--blendedmvs_dir', required=True)
    parser.add_argument('--precomputed_pairs', required=True)
    parser.add_argument('--output_dir', default='data/blendedmvs_processed')
    return parser


def main(db_root, pairs_path, output_dir):
    print('>> Listing all sequences')
    sequences = [f for f in os.listdir(db_root) if len(f) == 24]
    # should find 502 scenes
    assert sequences, f'did not find any sequences at {db_root}'
    print(f'   (found {len(sequences)} sequences)')

    for i, seq in enumerate(tqdm(sequences)):
        out_dir = osp.join(output_dir, seq)
        os.makedirs(out_dir, exist_ok=True)

        # generate the crops
        root = osp.join(db_root, seq)
        cam_dir = osp.join(root, 'cams')
        func_args = [(root, f[:-8], out_dir) for f in os.listdir(cam_dir) if not f.startswith('pair')]
        parallel_threads(load_crop_and_save, func_args, star_args=True, leave=False)

    # verify that all pairs are there
    pairs = np.load(pairs_path)
    for seqh, seql, img1, img2, score in tqdm(pairs):
        for view_index in [img1, img2]:
            impath = osp.join(output_dir, f"{seqh:08x}{seql:016x}", f"{view_index:08n}.jpg")
            assert osp.isfile(impath), f'missing image at {impath=}'

    print(f'>> Done, saved everything in {output_dir}/')


def load_crop_and_save(root, img, out_dir):
    if osp.isfile(osp.join(out_dir, img + '.npz')):
        return  # already done

    # load everything
    intrinsics_in, R_camin2world, t_camin2world = _load_pose(osp.join(root, 'cams', img + '_cam.txt'))
    color_image_in = cv2.cvtColor(cv2.imread(osp.join(root, 'blended_images', img + '.jpg'), cv2.IMREAD_COLOR),
                                  cv2.COLOR_BGR2RGB)
    depthmap_in = load_pfm_file(osp.join(root, 'rendered_depth_maps', img + '.pfm'))

    # do the crop
    H, W = color_image_in.shape[:2]
    assert H * 4 == W * 3
    image, depthmap, intrinsics_out, R_in2out = _crop_image(intrinsics_in, color_image_in, depthmap_in, (512, 384))

    # write everything
    image.save(osp.join(out_dir, img + '.jpg'), quality=80)
    cv2.imwrite(osp.join(out_dir, img + '.exr'), depthmap)

    # New camera parameters
    R_camout2world = R_camin2world @ R_in2out.T
    t_camout2world = t_camin2world
    np.savez(osp.join(out_dir, img + '.npz'), intrinsics=intrinsics_out,
             R_cam2world=R_camout2world, t_cam2world=t_camout2world)


def _crop_image(intrinsics_in, color_image_in, depthmap_in, resolution_out=(800, 800)):
    image, depthmap, intrinsics_out = cropping.rescale_image_depthmap(
        color_image_in, depthmap_in, intrinsics_in, resolution_out)
    R_in2out = np.eye(3)
    return image, depthmap, intrinsics_out, R_in2out


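def _load_pose(path, ret_44=False):
    # use a context manager so the file handle is always closed
    with open(path) as f:
        # extrinsics: skip the header line, then read the 4x4 world2cam matrix
        RT = np.loadtxt(f, skiprows=1, max_rows=4, dtype=np.float32)
        assert RT.shape == (4, 4)
        RT = np.linalg.inv(RT)  # world2cam to cam2world

        # intrinsics: skip the separator lines, then read the 3x3 K matrix
        K = np.loadtxt(f, skiprows=2, max_rows=3, dtype=np.float32)
        assert K.shape == (3, 3)

    if ret_44:
        return K, RT
    return K, RT[:3, :3], RT[:3, 3]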


def load_pfm_file(file_path):
    with open(file_path, 'rb') as file:
        header = file.readline().decode('UTF-8').strip()

        if header == 'PF':
            is_color = True
        elif header == 'Pf':
            is_color = False
        else:
            raise ValueError('The provided file is not a valid PFM file.')

        dimensions = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('UTF-8'))
        if dimensions:
            img_width, img_height = map(int, dimensions.groups())
        else:
            raise ValueError('Invalid PFM header format.')

        endian_scale = float(file.readline().decode('UTF-8').strip())
        if endian_scale < 0:
            dtype = '<f'  # little-endian
        else:
            dtype = '>f'  # big-endian

        data_buffer = file.read()
        img_data = np.frombuffer(data_buffer, dtype=dtype)

        if is_color:
            img_data = np.reshape(img_data, (img_height, img_width, 3))
        else:
            img_data = np.reshape(img_data, (img_height, img_width))

        img_data = cv2.flip(img_data, 0)

    return img_data


if __name__ == '__main__':
    parser = get_parser()
    args = parser.parse_args()
    main(args.blendedmvs_dir, args.precomputed_pairs, args.output_dir)


================================================
FILE: datasets_preprocess/preprocess_co3d.py
================================================
#!/usr/bin/env python3
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Script to pre-process the CO3D dataset.
# Usage:
# python3 datasets_preprocess/preprocess_co3d.py --co3d_dir /path/to/co3d
# --------------------------------------------------------

import argparse
import random
import gzip
import json
import os
import os.path as osp

import torch
import PIL.Image
import numpy as np
import cv2

from tqdm.auto import tqdm
import matplotlib.pyplot as plt

import path_to_root  # noqa
import dust3r.datasets.utils.cropping as cropping  # noqa


CATEGORIES = [
    "apple", "backpack", "ball", "banana", "baseballbat", "baseballglove",
    "bench", "bicycle", "book", "bottle", "bowl", "broccoli", "cake", "car", "carrot",
    "cellphone", "chair", "couch", "cup", "donut", "frisbee", "hairdryer", "handbag",
    "hotdog", "hydrant", "keyboard", "kite", "laptop", "microwave",
    "motorcycle",
    "mouse", "orange", "parkingmeter", "pizza", "plant", "remote", "sandwich",
    "skateboard", "stopsign",
    "suitcase", "teddybear", "toaster", "toilet", "toybus",
    "toyplane", "toytrain", "toytruck", "tv",
    "umbrella", "vase", "wineglass",
]
CATEGORIES_IDX = {cat: i for i, cat in enumerate(CATEGORIES)}  # for seeding

SINGLE_SEQUENCE_CATEGORIES = sorted(set(CATEGORIES) - set(["microwave", "stopsign", "tv"]))


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--category", type=str, default=None)
    parser.add_argument('--single_sequence_subset', default=False, action='store_true',
                        help="prepare the single_sequence_subset instead.")
    parser.add_argument("--output_dir", type=str, default="data/co3d_processed")
    parser.add_argument("--co3d_dir", type=str, required=True)
    parser.add_argument("--num_sequences_per_object", type=int, default=50)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--min_quality", type=float, default=0.5, help="Minimum viewpoint quality score.")

    parser.add_argument("--img_size", type=int, default=512,
                        help=("lower dimension will be >= img_size * 3/4, and max dimension will be >= img_size"))
    return parser


def convert_ndc_to_pinhole(focal_length, principal_point, image_size):
    focal_length = np.array(focal_length)
    principal_point = np.array(principal_point)
    image_size_wh = np.array([image_size[1], image_size[0]])
    half_image_size = image_size_wh / 2
    rescale = half_image_size.min()
    principal_point_px = half_image_size - principal_point * rescale
    focal_length_px = focal_length * rescale
    fx, fy = focal_length_px[0], focal_length_px[1]
    cx, cy = principal_point_px[0], principal_point_px[1]
    K = np.array([[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]], dtype=np.float32)
    return K


def opencv_from_cameras_projection(R, T, focal, p0, image_size):
    R = torch.from_numpy(R)[None, :, :]
    T = torch.from_numpy(T)[None, :]
    focal = torch.from_numpy(focal)[None, :]
    p0 = torch.from_numpy(p0)[None, :]
    image_size = torch.from_numpy(image_size)[None, :]

    R_pytorch3d = R.clone()
    T_pytorch3d = T.clone()
    focal_pytorch3d = focal
    p0_pytorch3d = p0
    T_pytorch3d[:, :2] *= -1
    R_pytorch3d[:, :, :2] *= -1
    tvec = T_pytorch3d
    R = R_pytorch3d.permute(0, 2, 1)

    # Retype the image_size correctly and flip to width, height.
    image_size_wh = image_size.to(R).flip(dims=(1,))

    # NDC to screen conversion.
    scale = image_size_wh.to(R).min(dim=1, keepdim=True)[0] / 2.0
    scale = scale.expand(-1, 2)
    c0 = image_size_wh / 2.0

    principal_point = -p0_pytorch3d * scale + c0
    focal_length = focal_pytorch3d * scale

    camera_matrix = torch.zeros_like(R)
    camera_matrix[:, :2, 2] = principal_point
    camera_matrix[:, 2, 2] = 1.0
    camera_matrix[:, 0, 0] = focal_length[:, 0]
    camera_matrix[:, 1, 1] = focal_length[:, 1]
    return R[0], tvec[0], camera_matrix[0]


def get_set_list(category_dir, split, is_single_sequence_subset=False):
    listfiles = os.listdir(osp.join(category_dir, "set_lists"))
    if is_single_sequence_subset:
        # not all objects have manyview_dev
        subset_list_files = [f for f in listfiles if "manyview_dev" in f]
    else:
        subset_list_files = [f for f in listfiles if "fewview_train" in f]

    sequences_all = []
    for subset_list_file in subset_list_files:
        with open(osp.join(category_dir, "set_lists", subset_list_file)) as f:
            subset_lists_data = json.load(f)
            sequences_all.extend(subset_lists_data[split])

    return sequences_all


def prepare_sequences(category, co3d_dir, output_dir, img_size, split, min_quality, max_num_sequences_per_object,
                      seed, is_single_sequence_subset=False):
    random.seed(seed)
    category_dir = osp.join(co3d_dir, category)
    category_output_dir = osp.join(output_dir, category)
    sequences_all = get_set_list(category_dir, split, is_single_sequence_subset)
    sequences_numbers = sorted(set(seq_name for seq_name, _, _ in sequences_all))

    frame_file = osp.join(category_dir, "frame_annotations.jgz")
    sequence_file = osp.join(category_dir, "sequence_annotations.jgz")

    with gzip.open(frame_file, "r") as fin:
        frame_data = json.loads(fin.read())
    with gzip.open(sequence_file, "r") as fin:
        sequence_data = json.loads(fin.read())

    frame_data_processed = {}
    for f_data in frame_data:
        sequence_name = f_data["sequence_name"]
        frame_data_processed.setdefault(sequence_name, {})[f_data["frame_number"]] = f_data

    good_quality_sequences = set()
    for seq_data in sequence_data:
        if seq_data["viewpoint_quality_score"] > min_quality:
            good_quality_sequences.add(seq_data["sequence_name"])

    sequences_numbers = [seq_name for seq_name in sequences_numbers if seq_name in good_quality_sequences]
    if len(sequences_numbers) < max_num_sequences_per_object:
        selected_sequences_numbers = sequences_numbers
    else:
        selected_sequences_numbers = random.sample(sequences_numbers, max_num_sequences_per_object)

    selected_sequences_numbers_dict = {seq_name: [] for seq_name in selected_sequences_numbers}
    sequences_all = [(seq_name, frame_number, filepath)
                     for seq_name, frame_number, filepath in sequences_all
                     if seq_name in selected_sequences_numbers_dict]

    for seq_name, frame_number, filepath in tqdm(sequences_all):
        frame_idx = int(filepath.split('/')[-1][5:-4])
        selected_sequences_numbers_dict[seq_name].append(frame_idx)
        mask_path = filepath.replace("images", "masks").replace(".jpg", ".png")
        frame_data = frame_data_processed[seq_name][frame_number]
        focal_length = frame_data["viewpoint"]["focal_length"]
        principal_point = frame_data["viewpoint"]["principal_point"]
        image_size = frame_data["image"]["size"]
        K = convert_ndc_to_pinhole(focal_length, principal_point, image_size)
        R, tvec, camera_intrinsics = opencv_from_cameras_projection(np.array(frame_data["viewpoint"]["R"]),
                                                                    np.array(frame_data["viewpoint"]["T"]),
                                                                    np.array(focal_length),
                                                                    np.array(principal_point),
                                                                    np.array(image_size))

        depth_path = os.path.join(co3d_dir, frame_data["depth"]["path"])
        assert frame_data["depth"]["scale_adjustment"] == 1.0
        image_path = os.path.join(co3d_dir, filepath)
        mask_path_full = os.path.join(co3d_dir, mask_path)

        input_rgb_image = PIL.Image.open(image_path).convert('RGB')
        input_mask = plt.imread(mask_path_full)

        with PIL.Image.open(depth_path) as depth_pil:
            # the image is stored with 16-bit depth but PIL reads it as I (32 bit).
            # we cast it to uint16, then reinterpret as float16, then cast to float32
            input_depthmap = (
                np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16)
                .astype(np.float32)
                .reshape((depth_pil.size[1], depth_pil.size[0])))
        depth_mask = np.stack((input_depthmap, input_mask), axis=-1)
        H, W = input_depthmap.shape

        camera_intrinsics = camera_intrinsics.numpy()
        cx, cy = camera_intrinsics[:2, 2].round().astype(int)
        min_margin_x = min(cx, W - cx)
        min_margin_y = min(cy, H - cy)

        # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy)
        l, t = cx - min_margin_x, cy - min_margin_y
        r, b = cx + min_margin_x, cy + min_margin_y
        crop_bbox = (l, t, r, b)
        input_rgb_image, depth_mask, input_camera_intrinsics = cropping.crop_image_depthmap(
            input_rgb_image, depth_mask, camera_intrinsics, crop_bbox)

        # try to set the lower dimension to img_size * 3/4 -> img_size=512 => 384
        scale_final = ((img_size * 3 // 4) / min(H, W)) + 1e-8
        output_resolution = np.floor(np.array([W, H]) * scale_final).astype(int)
        if max(output_resolution) < img_size:
            # let's put the max dimension to img_size
            scale_final = (img_size / max(H, W)) + 1e-8
            output_resolution = np.floor(np.array([W, H]) * scale_final).astype(int)

        input_rgb_image, depth_mask, input_camera_intrinsics = cropping.rescale_image_depthmap(
            input_rgb_image, depth_mask, input_camera_intrinsics, output_resolution)
        input_depthmap = depth_mask[:, :, 0]
        input_mask = depth_mask[:, :, 1]

        # generate and adjust camera pose
        camera_pose = np.eye(4, dtype=np.float32)
        camera_pose[:3, :3] = R
        camera_pose[:3, 3] = tvec
        camera_pose = np.linalg.inv(camera_pose)

        # save crop images and depth, metadata
        save_img_path = os.path.join(output_dir, filepath)
        save_depth_path = os.path.join(output_dir, frame_data["depth"]["path"])
        save_mask_path = os.path.join(output_dir, mask_path)
        os.makedirs(os.path.split(save_img_path)[0], exist_ok=True)
        os.makedirs(os.path.split(save_depth_path)[0], exist_ok=True)
        os.makedirs(os.path.split(save_mask_path)[0], exist_ok=True)

        input_rgb_image.save(save_img_path)
        scaled_depth_map = (input_depthmap / np.max(input_depthmap) * 65535).astype(np.uint16)
        cv2.imwrite(save_depth_path, scaled_depth_map)
        cv2.imwrite(save_mask_path, (input_mask * 255).astype(np.uint8))

        save_meta_path = save_img_path.replace('jpg', 'npz')
        np.savez(save_meta_path, camera_intrinsics=input_camera_intrinsics,
                 camera_pose=camera_pose, maximum_depth=np.max(input_depthmap))

    return selected_sequences_numbers_dict


if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()
    assert args.co3d_dir != args.output_dir
    if args.category is None:
        if args.single_sequence_subset:
            categories = SINGLE_SEQUENCE_CATEGORIES
        else:
            categories = CATEGORIES
    else:
        categories = [args.category]
    os.makedirs(args.output_dir, exist_ok=True)

    for split in ['train', 'test']:
        selected_sequences_path = os.path.join(args.output_dir, f'selected_seqs_{split}.json')
        if os.path.isfile(selected_sequences_path):
            continue

        all_selected_sequences = {}
        for category in categories:
            category_output_dir = osp.join(args.output_dir, category)
            os.makedirs(category_output_dir, exist_ok=True)
            category_selected_sequences_path = os.path.join(category_output_dir, f'selected_seqs_{split}.json')
            if os.path.isfile(category_selected_sequences_path):
                with open(category_selected_sequences_path, 'r') as fid:
                    category_selected_sequences = json.load(fid)
            else:
                print(f"Processing {split} - category = {category}")
                category_selected_sequences = prepare_sequences(
                    category=category,
                    co3d_dir=args.co3d_dir,
                    output_dir=args.output_dir,
                    img_size=args.img_size,
                    split=split,
                    min_quality=args.min_quality,
                    max_num_sequences_per_object=args.num_sequences_per_object,
                    seed=args.seed + CATEGORIES_IDX[category],
                    is_single_sequence_subset=args.single_sequence_subset
                )
                with open(category_selected_sequences_path, 'w') as file:
                    json.dump(category_selected_sequences, file)

            all_selected_sequences[category] = category_selected_sequences
        with open(selected_sequences_path, 'w') as file:
            json.dump(all_selected_sequences, file)
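
# --------------------------------------------------------
# Illustrative sketch, not part of the original script: the depth loading above
# reinterprets 16-bit integer pixels as float16 values before widening them to
# float32. The helper name `decode_float16_depth` and the sample values are
# assumptions made for this example only.
# --------------------------------------------------------

```python
import numpy as np


def decode_float16_depth(raw_uint16):
    """Reinterpret uint16 bit patterns as float16, then widen to float32."""
    raw = np.ascontiguousarray(raw_uint16, dtype=np.uint16)
    # view() keeps the bytes untouched and only changes how they are read
    return raw.view(np.float16).astype(np.float32)


# made-up depth values, stored the way a 16-bit PNG would hold them
depth = np.array([[0.5, 1.25], [2.0, 0.0]], dtype=np.float16)
stored = depth.view(np.uint16)
decoded = decode_float16_depth(stored)
```

# The shape is preserved because uint16 and float16 have the same item size.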


================================================
FILE: datasets_preprocess/preprocess_megadepth.py
================================================
#!/usr/bin/env python3
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Preprocessing code for the MegaDepth dataset
# dataset at https://www.cs.cornell.edu/projects/megadepth/
# --------------------------------------------------------
import os
import os.path as osp
import collections
from tqdm import tqdm
import numpy as np
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
import cv2
import h5py

import path_to_root  # noqa
from dust3r.utils.parallel import parallel_threads
from dust3r.datasets.utils import cropping  # noqa


def get_parser():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--megadepth_dir', required=True)
    parser.add_argument('--precomputed_pairs', required=True)
    parser.add_argument('--output_dir', default='data/megadepth_processed')
    return parser


def main(db_root, pairs_path, output_dir):
    os.makedirs(output_dir, exist_ok=True)

    # load all pairs
    data = np.load(pairs_path, allow_pickle=True)
    scenes = data['scenes']
    images = data['images']
    pairs = data['pairs']

    # enumerate all unique images
    todo = collections.defaultdict(set)
    for scene, im1, im2, score in pairs:
        todo[scene].add(im1)
        todo[scene].add(im2)

    # for each scene, load intrinsics and then parallel crops
    for scene, im_idxs in tqdm(todo.items(), desc='Overall'):
        scene, subscene = scenes[scene].split()
        out_dir = osp.join(output_dir, scene, subscene)
        os.makedirs(out_dir, exist_ok=True)

        # load all camera params
        _, pose_w2cam, intrinsics = _load_kpts_and_poses(db_root, scene, subscene, intrinsics=True)

        in_dir = osp.join(db_root, scene, 'dense' + subscene)
        args = [(in_dir, img, intrinsics[img], pose_w2cam[img], out_dir)
                for img in [images[im_id] for im_id in im_idxs]]
        parallel_threads(resize_one_image, args, star_args=True, front_num=0, leave=False, desc=f'{scene}/{subscene}')

    # all images referenced by the pairs have been processed
    print('Done! prepared all pairs in', output_dir)


def resize_one_image(root, tag, K_pre_rectif, pose_w2cam, out_dir):
    if osp.isfile(osp.join(out_dir, tag + '.npz')):
        return

    # load image
    img = cv2.cvtColor(cv2.imread(osp.join(root, 'imgs', tag), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
    H, W = img.shape[:2]

    # load depth
    with h5py.File(osp.join(root, 'depths', osp.splitext(tag)[0] + '.h5'), 'r') as hd5:
        depthmap = np.asarray(hd5['depth'])

    # rectify = undistort the intrinsics
    imsize_pre, K_pre, distortion = K_pre_rectif
    imsize_post = img.shape[1::-1]
    K_post = cv2.getOptimalNewCameraMatrix(K_pre, distortion, imsize_pre, alpha=0,
                                           newImgSize=imsize_post, centerPrincipalPoint=True)[0]

    # downscale
    img_out, depthmap_out, intrinsics_out, R_in2out = _downscale_image(K_post, img, depthmap, resolution_out=(800, 600))

    # write everything
    img_out.save(osp.join(out_dir, tag + '.jpg'), quality=90)
    cv2.imwrite(osp.join(out_dir, tag + '.exr'), depthmap_out)

    camout2world = np.linalg.inv(pose_w2cam)
    camout2world[:3, :3] = camout2world[:3, :3] @ R_in2out.T
    np.savez(osp.join(out_dir, tag + '.npz'), intrinsics=intrinsics_out, cam2world=camout2world)


def _downscale_image(camera_intrinsics, image, depthmap, resolution_out=(512, 384)):
    H, W = image.shape[:2]
    resolution_out = sorted(resolution_out)[::+1 if W < H else -1]

    image, depthmap, intrinsics_out = cropping.rescale_image_depthmap(
        image, depthmap, camera_intrinsics, resolution_out, force=False)
    R_in2out = np.eye(3)

    return image, depthmap, intrinsics_out, R_in2out


def _load_kpts_and_poses(root, scene_id, subscene, z_only=False, intrinsics=False):
    if intrinsics:
        with open(os.path.join(root, scene_id, 'sparse', 'manhattan', subscene, 'cameras.txt'), 'r') as f:
            raw = f.readlines()[3:]  # skip the header

        camera_intrinsics = {}
        for camera in raw:
            camera = camera.split(' ')
            # expects COLMAP SIMPLE_RADIAL cameras: width, height, focal, cx, cy, k
            width, height, focal, cx, cy, k0 = [float(elem) for elem in camera[2:]]
            K = np.eye(3)
            K[0, 0] = focal
            K[1, 1] = focal
            K[0, 2] = cx
            K[1, 2] = cy
            camera_intrinsics[int(camera[0])] = ((int(width), int(height)), K, (k0, 0, 0, 0))

    with open(os.path.join(root, scene_id, 'sparse', 'manhattan', subscene, 'images.txt'), 'r') as f:
        raw = f.read().splitlines()[4:]  # skip the header

    extract_pose = colmap_raw_pose_to_principal_axis if z_only else colmap_raw_pose_to_RT

    poses = {}
    points3D_idxs = {}
    camera = []

    for image, points in zip(raw[:: 2], raw[1:: 2]):
        image = image.split(' ')
        points = points.split(' ')

        image_id = image[-1]
        camera.append(int(image[-2]))

        # parse the raw pose (quaternion + translation); yields the principal axis only if z_only
        raw_pose = [float(elem) for elem in image[1: -2]]
        poses[image_id] = extract_pose(raw_pose)

        current_points3D_idxs = {int(i) for i in points[2:: 3] if i != '-1'}
        assert -1 not in current_points3D_idxs, f'invalid 3D point index for image {image_id}'
        points3D_idxs[image_id] = current_points3D_idxs

    if intrinsics:
        image_intrinsics = {im_id: camera_intrinsics[cam] for im_id, cam in zip(poses, camera)}
        return points3D_idxs, poses, image_intrinsics
    else:
        return points3D_idxs, poses


def colmap_raw_pose_to_principal_axis(image_pose):
    qvec = image_pose[: 4]
    qvec = qvec / np.linalg.norm(qvec)
    w, x, y, z = qvec
    z_axis = np.float32([
        2 * x * z - 2 * y * w,
        2 * y * z + 2 * x * w,
        1 - 2 * x * x - 2 * y * y
    ])
    return z_axis


def colmap_raw_pose_to_RT(image_pose):
    qvec = image_pose[: 4]
    qvec = qvec / np.linalg.norm(qvec)
    w, x, y, z = qvec
    R = np.array([
        [
            1 - 2 * y * y - 2 * z * z,
            2 * x * y - 2 * z * w,
            2 * x * z + 2 * y * w
        ],
        [
            2 * x * y + 2 * z * w,
            1 - 2 * x * x - 2 * z * z,
            2 * y * z - 2 * x * w
        ],
        [
            2 * x * z - 2 * y * w,
            2 * y * z + 2 * x * w,
            1 - 2 * x * x - 2 * y * y
        ]
    ])
    # principal_axis.append(R[2, :])
    t = image_pose[4: 7]
    # World-to-Camera pose
    current_pose = np.eye(4)
    current_pose[: 3, : 3] = R
    current_pose[: 3, 3] = t
    return current_pose


if __name__ == '__main__':
    parser = get_parser()
    args = parser.parse_args()
    main(args.megadepth_dir, args.precomputed_pairs, args.output_dir)
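
# --------------------------------------------------------
# Illustrative sketch, not part of the original script: how a COLMAP-style
# (w, x, y, z) quaternion and translation become a world-to-camera matrix, and
# how inverting it gives the camera-to-world pose that is saved above. The
# function names and the sample pose are assumptions made for this example.
# --------------------------------------------------------

```python
import numpy as np


def quat_wxyz_to_rotation(qvec):
    """Convert a (w, x, y, z) quaternion to a 3x3 rotation matrix."""
    qvec = np.asarray(qvec, dtype=np.float64)
    w, x, y, z = qvec / np.linalg.norm(qvec)
    return np.array([
        [1 - 2 * y * y - 2 * z * z, 2 * x * y - 2 * z * w, 2 * x * z + 2 * y * w],
        [2 * x * y + 2 * z * w, 1 - 2 * x * x - 2 * z * z, 2 * y * z - 2 * x * w],
        [2 * x * z - 2 * y * w, 2 * y * z + 2 * x * w, 1 - 2 * x * x - 2 * y * y],
    ])


def world_to_cam(qvec, tvec):
    """Assemble the 4x4 world-to-camera matrix from rotation and translation."""
    pose = np.eye(4)
    pose[:3, :3] = quat_wxyz_to_rotation(qvec)
    pose[:3, 3] = tvec
    return pose


# hypothetical pose: quarter turn about the z axis, translated along x
half = np.sqrt(0.5)
tvec = np.array([1.0, 0.0, 0.0])
w2c = world_to_cam([half, 0.0, 0.0, half], tvec)
R = w2c[:3, :3]

# inverting world-to-camera gives camera-to-world; its last column is the camera center
cam2world = np.linalg.inv(w2c)
camera_center = cam2world[:3, 3]
```

# The third row of R is the vector that colmap_raw_pose_to_principal_axis returns.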


================================================
FILE: datasets_preprocess/preprocess_scannetpp.py
================================================
#!/usr/bin/env python3
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Script to pre-process the scannet++ dataset.
# Usage:
# python3 datasets_preprocess/preprocess_scannetpp.py --scannetpp_dir /path/to/scannetpp --precomputed_pairs /path/to/scannetpp_pairs --pyopengl-platform egl
# --------------------------------------------------------
import os
import argparse
import os.path as osp
import re
from tqdm import tqdm
import json
from scipy.spatial.transform import Rotation
import pyrender
import trimesh
import trimesh.exchange.ply
import numpy as np
import cv2
import PIL.Image as Image

import path_to_root  # noqa
from dust3r.datasets.utils.cropping import rescale_image_depthmap
import dust3r.utils.geometry as geometry

inv = np.linalg.inv
norm = np.linalg.norm
REGEXPR_DSLR = re.compile(r'^.*DSC(?P<frameid>\d+)\.JPG$')
REGEXPR_IPHONE = re.compile(r'.*frame_(?P<frameid>\d+)\.jpg$')

DEBUG_VIZ = None  # 'iou'
if DEBUG_VIZ is not None:
    import matplotlib.pyplot as plt  # noqa


OPENGL_TO_OPENCV = np.float32([[1, 0, 0, 0],
                               [0, -1, 0, 0],
                               [0, 0, -1, 0],
                               [0, 0, 0, 1]])


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument('--scannetpp_dir', required=True)
    parser.add_argument('--precomputed_pairs', required=True)
    parser.add_argument('--output_dir', default='data/scannetpp_processed')
    parser.add_argument('--target_resolution', default=920, type=int, help="images resolution")
    parser.add_argument('--pyopengl-platform', type=str, default='', help='PyOpenGL env variable')
    return parser


def pose_from_qwxyz_txyz(elems):
    qw, qx, qy, qz, tx, ty, tz = map(float, elems)
    pose = np.eye(4)
    pose[:3, :3] = Rotation.from_quat((qx, qy, qz, qw)).as_matrix()
    pose[:3, 3] = (tx, ty, tz)
    return np.linalg.inv(pose)  # returns cam2world


def get_frame_number(name, cam_type='dslr'):
    if cam_type == 'dslr':
        regex_expr = REGEXPR_DSLR
    elif cam_type == 'iphone':
        regex_expr = REGEXPR_IPHONE
    else:
        raise NotImplementedError(f'wrong {cam_type=} for get_frame_number')
    try:
        matches = re.match(regex_expr, name)
        return matches['frameid']
    except Exception as e:
        print(f'Error when parsing {name}')
        raise ValueError(f'Invalid name {name}') from e


def load_sfm(sfm_dir, cam_type='dslr'):
    # load cameras
    with open(osp.join(sfm_dir, 'cameras.txt'), 'r') as f:
        raw = f.read().splitlines()[3:]  # skip header

    intrinsics = {}
    for camera in tqdm(raw, position=1, leave=False):
        camera = camera.split(' ')
        intrinsics[int(camera[0])] = [camera[1]] + [float(cam) for cam in camera[2:]]

    # load images
    with open(os.path.join(sfm_dir, 'images.txt'), 'r') as f:
        raw = f.read().splitlines()
        raw = [line for line in raw if not line.startswith('#')]  # skip header

    img_idx = {}
    img_infos = {}
    for image, points in tqdm(zip(raw[0::2], raw[1::2]), total=len(raw) // 2, position=1, leave=False):
        image = image.split(' ')
        points = points.split(' ')

        idx = image[0]
        img_name = image[-1]
        prefixes = ['iphone/', 'video/']
        for prefix in prefixes:
            if img_name.startswith(prefix):
                img_name = img_name[len(prefix):]
        assert img_name not in img_idx, 'duplicate db image: ' + img_name
        img_idx[img_name] = idx  # register image name

        current_points2D = {int(i): (float(x), float(y))
                            for i, x, y in zip(points[2::3], points[0::3], points[1::3]) if i != '-1'}
        img_infos[idx] = dict(intrinsics=intrinsics[int(image[-2])],
                              path=img_name,
                              frame_id=get_frame_number(img_name, cam_type),
                              cam_to_world=pose_from_qwxyz_txyz(image[1: -2]),
                              sparse_pts2d=current_points2D)

    # load 3D points
    with open(os.path.join(sfm_dir, 'points3D.txt'), 'r') as f:
        raw = f.read().splitlines()
        raw = [line for line in raw if not line.startswith('#')]  # skip header

    points3D = {}
    observations = {idx: [] for idx in img_infos.keys()}
    for point in tqdm(raw, position=1, leave=False):
        point = point.split()
        point_3d_idx = int(point[0])
        points3D[point_3d_idx] = tuple(map(float, point[1:4]))
        if len(point) > 8:
            for idx, point_2d_idx in zip(point[8::2], point[9::2]):
                if idx not in observations:
                    continue
                observations[idx].append((point_3d_idx, int(point_2d_idx)))

    return img_idx, img_infos, points3D, observations


def subsample_img_infos(img_infos, num_images, allowed_name_subset=None):
    img_infos_val = [(idx, val) for idx, val in img_infos.items()]
    if allowed_name_subset is not None:
        img_infos_val = [(idx, val) for idx, val in img_infos_val if val['path'] in allowed_name_subset]

    if len(img_infos_val) > num_images:
        img_infos_val = sorted(img_infos_val, key=lambda x: x[1]['frame_id'])
        kept_idx = np.round(np.linspace(0, len(img_infos_val) - 1, num_images)).astype(int).tolist()
        img_infos_val = [img_infos_val[idx] for idx in kept_idx]
    return {idx: val for idx, val in img_infos_val}


def undistort_images(intrinsics, rgb, mask):
    camera_type = intrinsics[0]

    width = int(intrinsics[1])
    height = int(intrinsics[2])
    fx = intrinsics[3]
    fy = intrinsics[4]
    cx = intrinsics[5]
    cy = intrinsics[6]
    distortion = np.array(intrinsics[7:])

    K = np.zeros([3, 3])
    K[0, 0] = fx
    K[0, 2] = cx
    K[1, 1] = fy
    K[1, 2] = cy
    K[2, 2] = 1

    K = geometry.colmap_to_opencv_intrinsics(K)
    if camera_type == "OPENCV_FISHEYE":
        assert len(distortion) == 4

        new_K = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify(
            K,
            distortion,
            (width, height),
            np.eye(3),
            balance=0.0,
        )
        # Make the cx and cy to be the center of the image
        new_K[0, 2] = width / 2.0
        new_K[1, 2] = height / 2.0

        map1, map2 = cv2.fisheye.initUndistortRectifyMap(K, distortion, np.eye(3), new_K, (width, height), cv2.CV_32FC1)
    else:
        new_K, _ = cv2.getOptimalNewCameraMatrix(K, distortion, (width, height), 1, (width, height), True)
        map1, map2 = cv2.initUndistortRectifyMap(K, distortion, np.eye(3), new_K, (width, height), cv2.CV_32FC1)

    undistorted_image = cv2.remap(rgb, map1, map2, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT_101)
    undistorted_mask = cv2.remap(mask, map1, map2, interpolation=cv2.INTER_LINEAR,
                                 borderMode=cv2.BORDER_CONSTANT, borderValue=255)
    new_K = geometry.opencv_to_colmap_intrinsics(new_K)
    return width, height, new_K, undistorted_image, undistorted_mask


def process_scenes(root, pairsdir, output_dir, target_resolution):
    os.makedirs(output_dir, exist_ok=True)

    # default values from
    # https://github.com/scannetpp/scannetpp/blob/main/common/configs/render.yml
    znear = 0.05
    zfar = 20.0

    listfile = osp.join(pairsdir, 'scene_list.json')
    with open(listfile, 'r') as f:
        scenes = json.load(f)

    # for each of these, we will select some dslr images and some iphone images
    # we will undistort them and render their depth
    renderer = pyrender.OffscreenRenderer(0, 0)
    for scene in tqdm(scenes, position=0, leave=True):
        data_dir = os.path.join(root, 'data', scene)
        dir_dslr = os.path.join(data_dir, 'dslr')
        dir_iphone = os.path.join(data_dir, 'iphone')
        dir_scans = os.path.join(data_dir, 'scans')

        assert os.path.isdir(data_dir) and os.path.isdir(dir_dslr) \
            and os.path.isdir(dir_iphone) and os.path.isdir(dir_scans)

        output_dir_scene = os.path.join(output_dir, scene)
        scene_metadata_path = osp.join(output_dir_scene, 'scene_metadata.npz')
        if osp.isfile(scene_metadata_path):
            continue

        pairs_dir_scene = os.path.join(pairsdir, scene)
        pairs_dir_scene_selected_pairs = os.path.join(pairs_dir_scene, 'selected_pairs.npz')
        assert osp.isfile(pairs_dir_scene_selected_pairs)
        selected_npz = np.load(pairs_dir_scene_selected_pairs)
        selection, pairs = selected_npz['selection'], selected_npz['pairs']

        # set up the output paths
        output_dir_scene_rgb = os.path.join(output_dir_scene, 'images')
        output_dir_scene_depth = os.path.join(output_dir_scene, 'depth')
        os.makedirs(output_dir_scene_rgb, exist_ok=True)
        os.makedirs(output_dir_scene_depth, exist_ok=True)

        ply_path = os.path.join(dir_scans, 'mesh_aligned_0.05.ply')

        sfm_dir_dslr = os.path.join(dir_dslr, 'colmap')
        rgb_dir_dslr = os.path.join(dir_dslr, 'resized_images')
        mask_dir_dslr = os.path.join(dir_dslr, 'resized_anon_masks')

        sfm_dir_iphone = os.path.join(dir_iphone, 'colmap')
        rgb_dir_iphone = os.path.join(dir_iphone, 'rgb')
        mask_dir_iphone = os.path.join(dir_iphone, 'rgb_masks')

        # load the mesh
        with open(ply_path, 'rb') as f:
            mesh_kwargs = trimesh.exchange.ply.load_ply(f)
        mesh_scene = trimesh.Trimesh(**mesh_kwargs)

        # read colmap reconstruction, we will only use the intrinsics and pose here
        img_idx_dslr, img_infos_dslr, points3D_dslr, observations_dslr = load_sfm(sfm_dir_dslr, cam_type='dslr')
        dslr_paths = {
            "in_colmap": sfm_dir_dslr,
            "in_rgb": rgb_dir_dslr,
            "in_mask": mask_dir_dslr,
        }

        img_idx_iphone, img_infos_iphone, points3D_iphone, observations_iphone = load_sfm(
            sfm_dir_iphone, cam_type='iphone')
        iphone_paths = {
            "in_colmap": sfm_dir_iphone,
            "in_rgb": rgb_dir_iphone,
            "in_mask": mask_dir_iphone,
        }

        mesh = pyrender.Mesh.from_trimesh(mesh_scene, smooth=False)
        pyrender_scene = pyrender.Scene()
        pyrender_scene.add(mesh)

        selection_iphone = [imgname + '.jpg' for imgname in selection if 'frame_' in imgname]
        selection_dslr = [imgname + '.JPG' for imgname in selection if 'frame_' not in imgname]

        # resize the image to a more manageable size and render depth
        for selection_cam, img_idx, img_infos, paths_data in [(selection_dslr, img_idx_dslr, img_infos_dslr, dslr_paths),
                                                              (selection_iphone, img_idx_iphone, img_infos_iphone, iphone_paths)]:
            rgb_dir = paths_data['in_rgb']
            mask_dir = paths_data['in_mask']
            for imgname in tqdm(selection_cam, position=1, leave=False):
                imgidx = img_idx[imgname]
                img_infos_idx = img_infos[imgidx]
                rgb = np.array(Image.open(os.path.join(rgb_dir, img_infos_idx['path'])))
                mask = np.array(Image.open(os.path.join(mask_dir, img_infos_idx['path'][:-3] + 'png')))

                _, _, K, rgb, mask = undistort_images(img_infos_idx['intrinsics'], rgb, mask)

                # rescale_image_depthmap assumes opencv intrinsics
                intrinsics = geometry.colmap_to_opencv_intrinsics(K)
                image, mask, intrinsics = rescale_image_depthmap(
                    rgb, mask, intrinsics, (target_resolution, target_resolution * 3.0 / 4))

                W, H = image.size
                intrinsics = geometry.opencv_to_colmap_intrinsics(intrinsics)

                # update img_infos_idx in place
                img_infos_idx['intrinsics'] = intrinsics
                rgb_outpath = os.path.join(output_dir_scene_rgb, img_infos_idx['path'][:-3] + 'jpg')
                image.save(rgb_outpath)

                depth_outpath = os.path.join(output_dir_scene_depth, img_infos_idx['path'][:-3] + 'png')
                # render depth image
                renderer.viewport_width, renderer.viewport_height = W, H
                fx, fy, cx, cy = intrinsics[0, 0], intrinsics[1, 1], intrinsics[0, 2], intrinsics[1, 2]
                camera = pyrender.camera.IntrinsicsCamera(fx, fy, cx, cy, znear=znear, zfar=zfar)
                camera_node = pyrender_scene.add(camera, pose=img_infos_idx['cam_to_world'] @ OPENGL_TO_OPENCV)

                _, depth = renderer.render(pyrender_scene, flags=pyrender.RenderFlags.SKIP_CULL_FACES)
                pyrender_scene.remove_node(camera_node)  # don't forget to remove the camera

                depth = (depth * 1000).astype('uint16')
                # invalidate depth from mask before saving
                depth_mask = (mask < 255)
                depth[depth_mask] = 0
                Image.fromarray(depth).save(depth_outpath)

        trajectories = []
        intrinsics = []
        for imgname in selection:
            if 'frame_' in imgname:
                imgidx = img_idx_iphone[imgname + '.jpg']
                img_infos_idx = img_infos_iphone[imgidx]
            elif 'DSC' in imgname:
                imgidx = img_idx_dslr[imgname + '.JPG']
                img_infos_idx = img_infos_dslr[imgidx]
            else:
                raise ValueError(f'invalid image name {imgname}')

            intrinsics.append(img_infos_idx['intrinsics'])
            trajectories.append(img_infos_idx['cam_to_world'])

        intrinsics = np.stack(intrinsics, axis=0)
        trajectories = np.stack(trajectories, axis=0)
        # save metadata for this scene
        np.savez(scene_metadata_path,
                 trajectories=trajectories,
                 intrinsics=intrinsics,
                 images=selection,
                 pairs=pairs)

        del img_infos
        del pyrender_scene

    # concat all scene_metadata.npz into a single file
    scene_data = {}
    for scene_subdir in scenes:
        scene_metadata_path = osp.join(output_dir, scene_subdir, 'scene_metadata.npz')
        with np.load(scene_metadata_path) as data:
            trajectories = data['trajectories']
            intrinsics = data['intrinsics']
            images = data['images']
            pairs = data['pairs']
        scene_data[scene_subdir] = {'trajectories': trajectories,
                                    'intrinsics': intrinsics,
                                    'images': images,
                                    'pairs': pairs}

    offset = 0
    counts = []
    scenes = []
    sceneids = []
    images = []
    intrinsics = []
    trajectories = []
    pairs = []
    for scene_idx, (scene_subdir, data) in enumerate(scene_data.items()):
        num_imgs = data['images'].shape[0]
        img_pairs = data['pairs']

        scenes.append(scene_subdir)
        sceneids.extend([scene_idx] * num_imgs)

        images.append(data['images'])

        intrinsics.append(data['intrinsics'])
        trajectories.append(data['trajectories'])

        # offset pairs
        img_pairs[:, 0:2] += offset
        pairs.append(img_pairs)
        counts.append(offset)

        offset += num_imgs

    images = np.concatenate(images, axis=0)
    intrinsics = np.concatenate(intrinsics, axis=0)
    trajectories = np.concatenate(trajectories, axis=0)
    pairs = np.concatenate(pairs, axis=0)
    np.savez(osp.join(output_dir, 'all_metadata.npz'),
             counts=counts,
             scenes=scenes,
             sceneids=sceneids,
             images=images,
             intrinsics=intrinsics,
             trajectories=trajectories,
             pairs=pairs)
    print('all done')


if __name__ == '__main__':
    parser = get_parser()
    args = parser.parse_args()
    if args.pyopengl_platform.strip():
        os.environ['PYOPENGL_PLATFORM'] = args.pyopengl_platform
    process_scenes(args.scannetpp_dir, args.precomputed_pairs, args.output_dir, args.target_resolution)
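
# --------------------------------------------------------
# Illustrative sketch, not part of the original script: when the per-scene
# metadata is concatenated above, the image indices stored in each scene's
# pairs are local to that scene, so they are shifted by the number of images
# that precede the scene. The function name `merge_scene_pairs` and the toy
# data are assumptions made for this example.
# --------------------------------------------------------

```python
import numpy as np


def merge_scene_pairs(scenes):
    """scenes: list of (num_images, pairs); pairs[:, :2] holds local image indices."""
    offset = 0
    offsets, merged = [], []
    for num_images, pairs in scenes:
        pairs = np.array(pairs, copy=True)  # leave the caller's array untouched
        pairs[:, 0:2] += offset             # local indices -> global indices
        merged.append(pairs)
        offsets.append(offset)              # index of the first image of this scene
        offset += num_images
    return np.concatenate(merged, axis=0), offsets


# two toy scenes with 3 and 2 images
scene_a = (3, np.array([[0, 1], [1, 2]]))
scene_b = (2, np.array([[0, 1]]))
all_pairs, offsets = merge_scene_pairs([scene_a, scene_b])
```

# The returned offsets play the role of the `counts` array written to all_metadata.npz.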


================================================
FILE: datasets_preprocess/preprocess_staticthings3d.py
================================================
#!/usr/bin/env python3
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Preprocessing code for the StaticThings3D dataset
# dataset at https://github.com/lmb-freiburg/robustmvd/blob/master/rmvd/data/README.md#staticthings3d
# 1) Download StaticThings3D in /path/to/StaticThings3D/
#    with the script at https://github.com/lmb-freiburg/robustmvd/blob/master/rmvd/data/scripts/download_staticthings3d.sh
#    --> depths.tar.bz2 frames_finalpass.tar.bz2 poses.tar.bz2 frames_cleanpass.tar.bz2 intrinsics.tar.bz2
# 2) unzip everything in the same /path/to/StaticThings3D/ directory
# 3) python datasets_preprocess/preprocess_staticthings3d.py --StaticThings3D_dir /path/to/StaticThings3D/ --precomputed_pairs /path/to/precomputed_pairs
# --------------------------------------------------------
import os
import os.path as osp
import re
from tqdm import tqdm
import numpy as np
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
import cv2

import path_to_root  # noqa
from dust3r.utils.parallel import parallel_threads
from dust3r.datasets.utils import cropping  # noqa


def get_parser():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--StaticThings3D_dir', required=True)
    parser.add_argument('--precomputed_pairs', required=True)
    parser.add_argument('--output_dir', default='data/staticthings3d_processed')
    return parser


def main(db_root, pairs_path, output_dir):
    all_scenes = _list_all_scenes(db_root)

    # crop images
    args = [(db_root, osp.join(split, subsplit, seq), camera, f'{n:04d}', output_dir)
            for split, subsplit, seq in all_scenes for camera in ['left', 'right'] for n in range(6, 16)]
    parallel_threads(load_crop_and_save, args, star_args=True, front_num=1)

    # verify that all images are there
    CAM = {b'l': 'left', b'r': 'right'}
    pairs = np.load(pairs_path)
    for scene, seq, cam1, im1, cam2, im2 in tqdm(pairs):
        seq_path = osp.join('TRAIN', scene.decode('ascii'), f'{seq:04d}')
        for cam, idx in [(CAM[cam1], im1), (CAM[cam2], im2)]:
            for ext in ['clean', 'final']:
                impath = osp.join(output_dir, seq_path, cam, f"{idx:04d}_{ext}.jpg")
                assert osp.isfile(impath), f'missing an image at {impath=}'

    print(f'>> Saved all data to {output_dir}!')


def load_crop_and_save(db_root, relpath_, camera, num, out_dir):
    relpath = osp.join(relpath_, camera, num)
    if osp.isfile(osp.join(out_dir, relpath + '.npz')):
        return
    os.makedirs(osp.join(out_dir, relpath_, camera), exist_ok=True)

    # load everything
    intrinsics_in = readFloat(osp.join(db_root, 'intrinsics', relpath_, num + '.float3'))
    cam2world = np.linalg.inv(readFloat(osp.join(db_root, 'poses', relpath + '.float3')))
    depthmap_in = readFloat(osp.join(db_root, 'depths', relpath + '.float3'))
    img_clean = cv2.cvtColor(cv2.imread(osp.join(db_root, 'frames_cleanpass',
                             relpath + '.png'), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
    img_final = cv2.cvtColor(cv2.imread(osp.join(db_root, 'frames_finalpass',
                             relpath + '.png'), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)

    # do the crop
    assert img_clean.shape[:2] == (540, 960)
    assert img_final.shape[:2] == (540, 960)
    (clean_out, final_out), depthmap, intrinsics_out, R_in2out = _crop_image(
        intrinsics_in, (img_clean, img_final), depthmap_in, (512, 384))

    # write everything
    clean_out.save(osp.join(out_dir, relpath + '_clean.jpg'), quality=80)
    final_out.save(osp.join(out_dir, relpath + '_final.jpg'), quality=80)
    cv2.imwrite(osp.join(out_dir, relpath + '.exr'), depthmap)

    # New camera parameters
    cam2world[:3, :3] = cam2world[:3, :3] @ R_in2out.T
    np.savez(osp.join(out_dir, relpath + '.npz'), intrinsics=intrinsics_out, cam2world=cam2world)


def _crop_image(intrinsics_in, color_image_in, depthmap_in, resolution_out=(512, 512)):
    image, depthmap, intrinsics_out = cropping.rescale_image_depthmap(
        color_image_in, depthmap_in, intrinsics_in, resolution_out)
    R_in2out = np.eye(3)
    return image, depthmap, intrinsics_out, R_in2out


def _list_all_scenes(path):
    print('>> Listing all scenes')

    res = []
    for split in ['TRAIN']:
        for subsplit in 'ABC':
            for seq in os.listdir(osp.join(path, 'intrinsics', split, subsplit)):
                res.append((split, subsplit, seq))
    print(f'   (found ({len(res)}) scenes)')
    assert res, f'Did not find anything at {path=}'
    return res


def readFloat(name):
    with open(name, 'rb') as f:
        if (f.readline().decode("utf-8")) != 'float\n':
            raise Exception('float file %s did not contain <float> keyword' % name)

        dim = int(f.readline())

        dims = []
        count = 1
        for i in range(0, dim):
            d = int(f.readline())
            dims.append(d)
            count *= d

        dims = list(reversed(dims))
        data = np.fromfile(f, np.float32, count).reshape(dims)
    return data  # HxW, CxHxW or NxCxHxW


if __name__ == '__main__':
    parser = get_parser()
    args = parser.parse_args()
    main(args.StaticThings3D_dir, args.precomputed_pairs, args.output_dir)
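
# --------------------------------------------------------
# Illustrative sketch, not part of the original script: the layout that
# readFloat() parses is a `float` keyword line, the number of dimensions, one
# line per dimension (fastest-varying first), then raw float32 values. The
# writer `write_float` is hypothetical and exists only to make a round trip
# possible; native byte order is assumed, as in readFloat().
# --------------------------------------------------------

```python
import os
import tempfile

import numpy as np


def write_float(path, array):
    """Hypothetical writer mirroring the layout that readFloat() expects."""
    array = np.ascontiguousarray(array, dtype=np.float32)
    with open(path, 'wb') as f:
        f.write(b'float\n')
        f.write(f'{array.ndim}\n'.encode('ascii'))
        for d in reversed(array.shape):  # fastest-varying dimension first
            f.write(f'{d}\n'.encode('ascii'))
        f.write(array.tobytes())


def read_float(path):
    """Reader following the same steps as readFloat()."""
    with open(path, 'rb') as f:
        if f.readline().decode('utf-8') != 'float\n':
            raise ValueError(f'{path} does not start with the <float> keyword')
        ndim = int(f.readline())
        dims = [int(f.readline()) for _ in range(ndim)]
        count = int(np.prod(dims))
        data = np.frombuffer(f.read(count * 4), dtype=np.float32)
    return data.reshape(dims[::-1])  # stored order is reversed, e.g. W then H


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'example.float3')
    original = np.arange(6, dtype=np.float32).reshape(2, 3)
    write_float(path, original)
    restored = read_float(path)
```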


================================================
FILE: datasets_preprocess/preprocess_waymo.py
================================================
#!/usr/bin/env python3
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Preprocessing code for the WayMo Open dataset
# dataset at https://github.com/waymo-research/waymo-open-dataset
# 1) Accept the license
# 2) download all training/*.tfrecord files from Perception Dataset, version 1.4.2
# 3) put all .tfrecord files in '/path/to/waymo_dir'
# 4) install the waymo_open_dataset package with
#    `python3 -m pip install gcsfs waymo-open-dataset-tf-2-12-0==1.6.4`
# 5) execute this script as `python preprocess_waymo.py --waymo_dir /path/to/waymo_dir`
# --------------------------------------------------------
import sys
import os
import os.path as osp
import shutil
import json
from tqdm import tqdm
import PIL.Image
import numpy as np
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
import cv2

import tensorflow.compat.v1 as tf
tf.enable_eager_execution()

import path_to_root  # noqa
from dust3r.utils.geometry import geotrf, inv
from dust3r.utils.image import imread_cv2
from dust3r.utils.parallel import parallel_processes as parallel_map
from dust3r.datasets.utils import cropping
from dust3r.viz import show_raw_pointcloud


def get_parser():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--waymo_dir', required=True)
    parser.add_argument('--precomputed_pairs', required=True)
    parser.add_argument('--output_dir', default='data/waymo_processed')
    parser.add_argument('--workers', type=int, default=1)
    return parser


def main(waymo_root, pairs_path, output_dir, workers=1):
    extract_frames(waymo_root, output_dir, workers=workers)
    make_crops(output_dir, workers=workers)

    # make sure all pairs are there
    with np.load(pairs_path) as data:
        scenes = data['scenes']
        frames = data['frames']
        pairs = data['pairs']  # array of (scene_id, img1_id, img2_id)

    for scene_id, im1_id, im2_id in pairs:
        for im_id in (im1_id, im2_id):
            path = osp.join(output_dir, scenes[scene_id], frames[im_id] + '.jpg')
            assert osp.isfile(path), f'Missing a file at {path=}\nDid you download all .tfrecord files?'

    shutil.rmtree(osp.join(output_dir, 'tmp'))
    print('Done! all data generated at', output_dir)


def _list_sequences(db_root):
    print('>> Looking for sequences in', db_root)
    res = sorted(f for f in os.listdir(db_root) if f.endswith('.tfrecord'))
    print(f'    found {len(res)} sequences')
    return res


def extract_frames(db_root, output_dir, workers=8):
    sequences = _list_sequences(db_root)
    output_dir = osp.join(output_dir, 'tmp')
    print('>> outputting result to', output_dir)
    args = [(db_root, output_dir, seq) for seq in sequences]
    parallel_map(process_one_seq, args, star_args=True, workers=workers)


def process_one_seq(db_root, output_dir, seq):
    out_dir = osp.join(output_dir, seq)
    os.makedirs(out_dir, exist_ok=True)
    calib_path = osp.join(out_dir, 'calib.json')
    if osp.isfile(calib_path):
        return

    try:
        with tf.device('/CPU:0'):
            calib, frames = extract_frames_one_seq(osp.join(db_root, seq))
    except RuntimeError:
        print(f'/!\\ Error with sequence {seq} /!\\', file=sys.stderr)
        return  # nothing is saved

    for f, (frame_name, views) in enumerate(tqdm(frames, leave=False)):
        for cam_idx, view in views.items():
            img = PIL.Image.fromarray(view.pop('img'))
            img.save(osp.join(out_dir, f'{f:05d}_{cam_idx}.jpg'))
            np.savez(osp.join(out_dir, f'{f:05d}_{cam_idx}.npz'), **view)

    with open(calib_path, 'w') as f:
        json.dump(calib, f)


def extract_frames_one_seq(filename):
    from waymo_open_dataset import dataset_pb2 as open_dataset
    from waymo_open_dataset.utils import frame_utils

    print('>> Opening', filename)
    dataset = tf.data.TFRecordDataset(filename, compression_type='')

    calib = None
    frames = []

    for data in tqdm(dataset, leave=False):
        frame = open_dataset.Frame()
        frame.ParseFromString(bytearray(data.numpy()))

        content = frame_utils.parse_range_image_and_camera_projection(frame)
        range_images, camera_projections, _, range_image_top_pose = content

        views = {}
        frames.append((frame.context.name, views))

        # once in a sequence, read camera calibration info
        if calib is None:
            calib = []
            for cam in frame.context.camera_calibrations:
                calib.append((cam.name,
                              dict(width=cam.width,
                                   height=cam.height,
                                   intrinsics=list(cam.intrinsic),
                                   extrinsics=list(cam.extrinsic.transform))))

        # convert LIDAR to pointcloud
        points, cp_points = frame_utils.convert_range_image_to_point_cloud(
            frame,
            range_images,
            camera_projections,
            range_image_top_pose)

        # 3d points in vehicle frame.
        points_all = np.concatenate(points, axis=0)
        cp_points_all = np.concatenate(cp_points, axis=0)

        # Camera projections (camera name, x, y) associated with each lidar point.
        cp_points_all_tensor = tf.constant(cp_points_all, dtype=tf.int32)

        for i, image in enumerate(frame.images):
            # select relevant 3D points for this view
            mask = tf.equal(cp_points_all_tensor[..., 0], image.name)
            cp_points_msk_tensor = tf.cast(tf.gather_nd(cp_points_all_tensor, tf.where(mask)), dtype=tf.float32)

            pose = np.asarray(image.pose.transform).reshape(4, 4)
            timestamp = image.pose_timestamp

            rgb = tf.image.decode_jpeg(image.image).numpy()

            pix = cp_points_msk_tensor[..., 1:3].numpy().round().astype(np.int16)
            pts3d = points_all[mask.numpy()]

            views[image.name] = dict(img=rgb, pose=pose, pixels=pix, pts3d=pts3d, timestamp=timestamp)

        if not 'show full point cloud':
            show_raw_pointcloud([v['pts3d'] for v in views.values()], [v['img'] for v in views.values()])

    return calib, frames


def make_crops(output_dir, workers=16, **kw):
    tmp_dir = osp.join(output_dir, 'tmp')
    sequences = _list_sequences(tmp_dir)
    args = [(tmp_dir, output_dir, seq) for seq in sequences]
    parallel_map(crop_one_seq, args, star_args=True, workers=workers, front_num=0)


def crop_one_seq(input_dir, output_dir, seq, resolution=512):
    seq_dir = osp.join(input_dir, seq)
    out_dir = osp.join(output_dir, seq)
    if osp.isfile(osp.join(out_dir, '00100_1.jpg')):
        return
    os.makedirs(out_dir, exist_ok=True)

    # load calibration file
    try:
        with open(osp.join(seq_dir, 'calib.json')) as f:
            calib = json.load(f)
    except IOError:
        print(f'/!\\ Error: Missing calib.json in sequence {seq} /!\\', file=sys.stderr)
        return

    axes_transformation = np.array([
        [0, -1, 0, 0],
        [0, 0, -1, 0],
        [1, 0, 0, 0],
        [0, 0, 0, 1]])

    cam_K = {}
    cam_distortion = {}
    cam_res = {}
    cam_to_car = {}
    for cam_idx, cam_info in calib:
        cam_idx = str(cam_idx)
        cam_res[cam_idx] = (W, H) = (cam_info['width'], cam_info['height'])
        f1, f2, cx, cy, k1, k2, p1, p2, k3 = cam_info['intrinsics']
        cam_K[cam_idx] = np.asarray([(f1, 0, cx), (0, f2, cy), (0, 0, 1)])
        cam_distortion[cam_idx] = np.asarray([k1, k2, p1, p2, k3])
        cam_to_car[cam_idx] = np.asarray(cam_info['extrinsics']).reshape(4, 4)  # cam-to-vehicle

    frames = sorted(f[:-3] for f in os.listdir(seq_dir) if f.endswith('.jpg'))

    # from dust3r.viz import SceneViz
    # viz = SceneViz()

    for frame in tqdm(frames, leave=False):
        cam_idx = frame[-2]  # cam index
        assert cam_idx in '12345', f'bad {cam_idx=} in {frame=}'
        data = np.load(osp.join(seq_dir, frame + 'npz'))
        car_to_world = data['pose']
        W, H = cam_res[cam_idx]

        # load depthmap
        pos2d = data['pixels'].round().astype(np.uint16)
        x, y = pos2d.T
        pts3d = data['pts3d']  # already in the car frame
        pts3d = geotrf(axes_transformation @ inv(cam_to_car[cam_idx]), pts3d)
        # X=LEFT_RIGHT y=ALTITUDE z=DEPTH

        # load image
        image = imread_cv2(osp.join(seq_dir, frame + 'jpg'))

        # downscale image
        output_resolution = (resolution, 1) if W > H else (1, resolution)
        image, _, intrinsics2 = cropping.rescale_image_depthmap(image, None, cam_K[cam_idx], output_resolution)
        image.save(osp.join(out_dir, frame + 'jpg'), quality=80)

        # save as an EXR file? yes it's smaller (and easier to load)
        W, H = image.size
        depthmap = np.zeros((H, W), dtype=np.float32)
        pos2d = geotrf(intrinsics2 @ inv(cam_K[cam_idx]), pos2d).round().astype(np.int16)
        x, y = pos2d.T
        depthmap[y.clip(min=0, max=H - 1), x.clip(min=0, max=W - 1)] = pts3d[:, 2]
        cv2.imwrite(osp.join(out_dir, frame + 'exr'), depthmap)

        # save camera parameters
        cam2world = car_to_world @ cam_to_car[cam_idx] @ inv(axes_transformation)
        np.savez(osp.join(out_dir, frame + 'npz'), intrinsics=intrinsics2,
                 cam2world=cam2world, distortion=cam_distortion[cam_idx])

        # viz.add_rgbd(np.asarray(image), depthmap, intrinsics2, cam2world)
    # viz.show()


if __name__ == '__main__':
    parser = get_parser()
    args = parser.parse_args()
    main(args.waymo_dir, args.precomputed_pairs, args.output_dir, workers=args.workers)


================================================
FILE: datasets_preprocess/preprocess_wildrgbd.py
================================================
#!/usr/bin/env python3
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Script to pre-process the WildRGB-D dataset.
# Usage:
# python3 datasets_preprocess/preprocess_wildrgbd.py --wildrgbd_dir /path/to/wildrgbd
# --------------------------------------------------------

import argparse
import random
import json
import os
import os.path as osp

import PIL.Image
import numpy as np
import cv2

from tqdm.auto import tqdm
import matplotlib.pyplot as plt

import path_to_root  # noqa
import dust3r.datasets.utils.cropping as cropping  # noqa
from dust3r.utils.image import imread_cv2


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, default="data/wildrgbd_processed")
    parser.add_argument("--wildrgbd_dir", type=str, required=True)
    parser.add_argument("--train_num_sequences_per_object", type=int, default=50)
    parser.add_argument("--test_num_sequences_per_object", type=int, default=10)
    parser.add_argument("--num_frames", type=int, default=100)
    parser.add_argument("--seed", type=int, default=42)

    parser.add_argument("--img_size", type=int, default=512,
                        help=("lower dimension will be >= img_size * 3/4, and max dimension will be >= img_size"))
    return parser


def get_set_list(category_dir, split):
    listfiles = ["camera_eval_list.json", "nvs_list.json"]

    sequences_all = {s: {k: set() for k in listfiles} for s in ['train', 'val']}
    for listfile in listfiles:
        with open(osp.join(category_dir, listfile)) as f:
            subset_lists_data = json.load(f)
            for s in ['train', 'val']:
                sequences_all[s][listfile].update(subset_lists_data[s])
    train_intersection = set.intersection(*list(sequences_all['train'].values()))
    if split == "train":
        return train_intersection
    else:
        all_seqs = set.union(*list(sequences_all['train'].values()), *list(sequences_all['val'].values()))
        return all_seqs.difference(train_intersection)


def prepare_sequences(category, wildrgbd_dir, output_dir, img_size, split, max_num_sequences_per_object,
                      output_num_frames, seed):
    random.seed(seed)
    category_dir = osp.join(wildrgbd_dir, category)
    category_output_dir = osp.join(output_dir, category)
    sequences_all = get_set_list(category_dir, split)
    sequences_all = sorted(sequences_all)

    sequences_all_tmp = []
    for seq_name in sequences_all:
        scene_dir = osp.join(category_dir, seq_name)  # category_dir already includes wildrgbd_dir
        if not os.path.isdir(scene_dir):
            print(f'{scene_dir} does not exist, skipped')
            continue
        sequences_all_tmp.append(seq_name)
    sequences_all = sequences_all_tmp
    if len(sequences_all) <= max_num_sequences_per_object:
        selected_sequences = sequences_all
    else:
        selected_sequences = random.sample(sequences_all, max_num_sequences_per_object)

    selected_sequences_numbers_dict = {}
    for seq_name in tqdm(selected_sequences, leave=False):
        scene_dir = osp.join(category_dir, seq_name)
        scene_output_dir = osp.join(category_output_dir, seq_name)
        with open(osp.join(scene_dir, 'metadata'), 'r') as f:
            metadata = json.load(f)

        K = np.array(metadata["K"]).reshape(3, 3).T
        fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
        w, h = metadata["w"], metadata["h"]

        camera_intrinsics = np.array(
            [[fx, 0, cx],
             [0, fy, cy],
             [0, 0, 1]]
        )
        camera_to_world_path = os.path.join(scene_dir, 'cam_poses.txt')
        camera_to_world_content = np.genfromtxt(camera_to_world_path)
        camera_to_world = camera_to_world_content[:, 1:].reshape(-1, 4, 4)

        frame_idx = camera_to_world_content[:, 0]
        num_frames = frame_idx.shape[0]
        assert num_frames >= output_num_frames
        assert np.all(frame_idx == np.arange(num_frames))

        # selected_sequences_numbers_dict[seq_name] = num_frames

        selected_frames = np.round(np.linspace(0, num_frames - 1, output_num_frames)).astype(int).tolist()
        selected_sequences_numbers_dict[seq_name] = selected_frames

        for frame_id in tqdm(selected_frames):
            depth_path = os.path.join(scene_dir, 'depth', f'{frame_id:0>5d}.png')
            masks_path = os.path.join(scene_dir, 'masks', f'{frame_id:0>5d}.png')
            rgb_path = os.path.join(scene_dir, 'rgb', f'{frame_id:0>5d}.png')

            input_rgb_image = PIL.Image.open(rgb_path).convert('RGB')
            input_mask = plt.imread(masks_path)
            input_depthmap = imread_cv2(depth_path, cv2.IMREAD_UNCHANGED).astype(np.float64)
            depth_mask = np.stack((input_depthmap, input_mask), axis=-1)
            H, W = input_depthmap.shape

            min_margin_x = min(cx, W - cx)
            min_margin_y = min(cy, H - cy)

            # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy)
            l, t = int(cx - min_margin_x), int(cy - min_margin_y)
            r, b = int(cx + min_margin_x), int(cy + min_margin_y)
            crop_bbox = (l, t, r, b)
            input_rgb_image, depth_mask, input_camera_intrinsics = cropping.crop_image_depthmap(
                input_rgb_image, depth_mask, camera_intrinsics, crop_bbox)

            # try to set the lower dimension to img_size * 3/4 -> img_size=512 => 384
            scale_final = ((img_size * 3 // 4) / min(H, W)) + 1e-8
            output_resolution = np.floor(np.array([W, H]) * scale_final).astype(int)
            if max(output_resolution) < img_size:
                # let's put the max dimension to img_size
                scale_final = (img_size / max(H, W)) + 1e-8
                output_resolution = np.floor(np.array([W, H]) * scale_final).astype(int)

            input_rgb_image, depth_mask, input_camera_intrinsics = cropping.rescale_image_depthmap(
                input_rgb_image, depth_mask, input_camera_intrinsics, output_resolution)
            input_depthmap = depth_mask[:, :, 0]
            input_mask = depth_mask[:, :, 1]

            camera_pose = camera_to_world[frame_id]

            # save crop images and depth, metadata
            save_img_path = os.path.join(scene_output_dir, 'rgb', f'{frame_id:0>5d}.jpg')
            save_depth_path = os.path.join(scene_output_dir, 'depth', f'{frame_id:0>5d}.png')
            save_mask_path = os.path.join(scene_output_dir, 'masks', f'{frame_id:0>5d}.png')
            os.makedirs(os.path.split(save_img_path)[0], exist_ok=True)
            os.makedirs(os.path.split(save_depth_path)[0], exist_ok=True)
            os.makedirs(os.path.split(save_mask_path)[0], exist_ok=True)

            input_rgb_image.save(save_img_path)
            cv2.imwrite(save_depth_path, input_depthmap.astype(np.uint16))
            cv2.imwrite(save_mask_path, (input_mask * 255).astype(np.uint8))

            save_meta_path = os.path.join(scene_output_dir, 'metadata', f'{frame_id:0>5d}.npz')
            os.makedirs(os.path.split(save_meta_path)[0], exist_ok=True)
            np.savez(save_meta_path, camera_intrinsics=input_camera_intrinsics,
                     camera_pose=camera_pose)

    return selected_sequences_numbers_dict


if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()
    assert args.wildrgbd_dir != args.output_dir

    categories = sorted([
        dirname for dirname in os.listdir(args.wildrgbd_dir)
        if os.path.isdir(os.path.join(args.wildrgbd_dir, dirname, 'scenes'))
    ])

    os.makedirs(args.output_dir, exist_ok=True)

    splits_num_sequences_per_object = [args.train_num_sequences_per_object, args.test_num_sequences_per_object]
    for split, num_sequences_per_object in zip(['train', 'test'], splits_num_sequences_per_object):
        selected_sequences_path = os.path.join(args.output_dir, f'selected_seqs_{split}.json')
        if os.path.isfile(selected_sequences_path):
            continue
        all_selected_sequences = {}
        for category in categories:
            category_output_dir = osp.join(args.output_dir, category)
            os.makedirs(category_output_dir, exist_ok=True)
            category_selected_sequences_path = os.path.join(category_output_dir, f'selected_seqs_{split}.json')
            if os.path.isfile(category_selected_sequences_path):
                with open(category_selected_sequences_path, 'r') as fid:
                    category_selected_sequences = json.load(fid)
            else:
                print(f"Processing {split} - category = {category}")
                category_selected_sequences = prepare_sequences(
                    category=category,
                    wildrgbd_dir=args.wildrgbd_dir,
                    output_dir=args.output_dir,
                    img_size=args.img_size,
                    split=split,
                    max_num_sequences_per_object=num_sequences_per_object,
                    output_num_frames=args.num_frames,
                    # derive the offset from the category name itself, not the literal string "category"
                    seed=args.seed + int(category.encode('ascii').hex(), 16),
                )
                with open(category_selected_sequences_path, 'w') as file:
                    json.dump(category_selected_sequences, file)

            all_selected_sequences[category] = category_selected_sequences
        with open(selected_sequences_path, 'w') as file:
            json.dump(all_selected_sequences, file)


================================================
FILE: demo.py
================================================
#!/usr/bin/env python3
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# dust3r gradio demo executable
# --------------------------------------------------------
import os
import torch
import tempfile

from dust3r.model import AsymmetricCroCo3DStereo
from dust3r.demo import get_args_parser, main_demo, set_print_with_timestamp

import matplotlib.pyplot as pl
pl.ion()

torch.backends.cuda.matmul.allow_tf32 = True  # for gpu >= Ampere and pytorch >= 1.12

if __name__ == '__main__':
    parser = get_args_parser()
    args = parser.parse_args()
    set_print_with_timestamp()

    if args.tmp_dir is not None:
        tmp_path = args.tmp_dir
        os.makedirs(tmp_path, exist_ok=True)
        tempfile.tempdir = tmp_path

    if args.server_name is not None:
        server_name = args.server_name
    else:
        server_name = '0.0.0.0' if args.local_network else '127.0.0.1'

    if args.weights is not None:
        weights_path = args.weights
    else:
        weights_path = "naver/" + args.model_name
    model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(args.device)

    # dust3r will write the 3D model inside tmpdirname
    with tempfile.TemporaryDirectory(suffix='dust3r_gradio_demo') as tmpdirname:
        if not args.silent:
            print('Outputting stuff in', tmpdirname)
        main_demo(tmpdirname, model, args.device, args.image_size, server_name, args.server_port, silent=args.silent)


================================================
FILE: docker/docker-compose-cpu.yml
================================================
version: '3.8'
services:
  dust3r-demo:
    build:
      context: ./files
      dockerfile: cpu.Dockerfile 
    ports:
      - "7860:7860"
    volumes:
      - ./files/checkpoints:/dust3r/checkpoints
    environment:
      - DEVICE=cpu
      - MODEL=${MODEL:-DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth}
    cap_add:
      - IPC_LOCK
      - SYS_RESOURCE


================================================
FILE: docker/docker-compose-cuda.yml
================================================
version: '3.8'
services:
  dust3r-demo:
    build:
      context: ./files
      dockerfile: cuda.Dockerfile 
    ports:
      - "7860:7860"
    environment:
      - DEVICE=cuda
      - MODEL=${MODEL:-DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth}
    volumes:
      - ./files/checkpoints:/dust3r/checkpoints
    cap_add:
      - IPC_LOCK
      - SYS_RESOURCE
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]


================================================
FILE: docker/files/cpu.Dockerfile
================================================
FROM python:3.11-slim

LABEL description="Docker container for DUSt3R with dependencies installed. CPU VERSION"

ENV DEVICE="cpu"
ENV MODEL="DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth"
ARG DEBIAN_FRONTEND=noninteractive

RUN apt-get update && apt-get install -y \
    git \
    libgl1-mesa-glx \
    libegl1-mesa \
    libxrandr2 \
    libxss1 \
    libxcursor1 \
    libxcomposite1 \
    libasound2 \
    libxi6 \
    libxtst6 \
    libglib2.0-0 \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

RUN git clone --recursive https://github.com/naver/dust3r /dust3r
WORKDIR /dust3r

RUN pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
RUN pip install -r requirements.txt
RUN pip install -r requirements_optional.txt
RUN pip install opencv-python==4.8.0.74

WORKDIR /dust3r

COPY entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh

ENTRYPOINT ["/entrypoint.sh"]


================================================
FILE: docker/files/cuda.Dockerfile
================================================
FROM nvcr.io/nvidia/pytorch:24.01-py3

LABEL description="Docker container for DUSt3R with dependencies installed. CUDA VERSION"
ENV DEVICE="cuda"
ENV MODEL="DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth"
ARG DEBIAN_FRONTEND=noninteractive

RUN apt-get update && apt-get install -y \
    git=1:2.34.1-1ubuntu1.10 \
    libglib2.0-0=2.72.4-0ubuntu2.2 \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

RUN git clone --recursive https://github.com/naver/dust3r /dust3r
WORKDIR /dust3r
RUN pip install -r requirements.txt
RUN pip install -r requirements_optional.txt
RUN pip install opencv-python==4.8.0.74

WORKDIR /dust3r/croco/models/curope/
RUN python setup.py build_ext --inplace

WORKDIR /dust3r
COPY entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh

ENTRYPOINT ["/entrypoint.sh"]


================================================
FILE: docker/files/entrypoint.sh
================================================
#!/bin/bash

set -eux

DEVICE=${DEVICE:-cuda}
MODEL=${MODEL:-DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth}

exec python3 demo.py --weights "checkpoints/$MODEL" --device "$DEVICE" --local_network "$@"


================================================
FILE: docker/run.sh
================================================
#!/bin/bash

set -eux

# Default model name
model_name="DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth"

check_docker() {
    if ! command -v docker &>/dev/null; then
        echo "Docker could not be found. Please install Docker and try again."
        exit 1
    fi
}

download_model_checkpoint() { 
    if [ -f "./files/checkpoints/${model_name}" ]; then
        echo "Model checkpoint ${model_name} already exists. Skipping download."
        return
    fi
    echo "Downloading model checkpoint ${model_name}..."
    wget "https://download.europe.naverlabs.com/ComputerVision/DUSt3R/${model_name}" -P ./files/checkpoints
}

set_dcomp() {
    if command -v docker-compose &>/dev/null; then
        dcomp="docker-compose"
    elif command -v docker &>/dev/null && docker compose version &>/dev/null; then
        dcomp="docker compose"
    else
        echo "Docker Compose could not be found. Please install Docker Compose and try again."
        exit 1
    fi
}

run_docker() {
    export MODEL=${model_name}
    if [ "$with_cuda" -eq 1 ]; then
        $dcomp -f docker-compose-cuda.yml up --build
    else
        $dcomp -f docker-compose-cpu.yml up --build
    fi
}

with_cuda=0
for arg in "$@"; do
    case $arg in
        --with-cuda)
            with_cuda=1
            ;;
        --model_name=*)
            model_name="${arg#*=}.pth"
            ;;
        *)
            echo "Unknown parameter passed: $arg"
            exit 1
            ;;
    esac
done


main() {
    check_docker
    download_model_checkpoint
    set_dcomp
    run_docker
}

main


================================================
FILE: dust3r/__init__.py
================================================
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).


================================================
FILE: dust3r/cloud_opt/__init__.py
================================================
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# global alignment optimization wrapper function
# --------------------------------------------------------
from enum import Enum

from .optimizer import PointCloudOptimizer
from .modular_optimizer import ModularPointCloudOptimizer
from .pair_viewer import PairViewer


class GlobalAlignerMode(Enum):
    PointCloudOptimizer = "PointCloudOptimizer"
    ModularPointCloudOptimizer = "ModularPointCloudOptimizer"
    PairViewer = "PairViewer"


def global_aligner(dust3r_output, device, mode=GlobalAlignerMode.PointCloudOptimizer, **optim_kw):
    # extract all inputs
    view1, view2, pred1, pred2 = [dust3r_output[k] for k in 'view1 view2 pred1 pred2'.split()]
    # build the optimizer
    if mode == GlobalAlignerMode.PointCloudOptimizer:
        net = PointCloudOptimizer(view1, view2, pred1, pred2, **optim_kw).to(device)
    elif mode == GlobalAlignerMode.ModularPointCloudOptimizer:
        net = ModularPointCloudOptimizer(view1, view2, pred1, pred2, **optim_kw).to(device)
    elif mode == GlobalAlignerMode.PairViewer:
        net = PairViewer(view1, view2, pred1, pred2, **optim_kw).to(device)
    else:
        raise NotImplementedError(f'Unknown mode {mode}')

    return net


================================================
FILE: dust3r/cloud_opt/base_opt.py
================================================
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Base class for the global alignment procedure
# --------------------------------------------------------
from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn
import roma
import tqdm

from dust3r.utils.geometry import inv, geotrf
from dust3r.utils.device import to_numpy
from dust3r.utils.image import rgb
from dust3r.viz import SceneViz, segment_sky, auto_cam_size
from dust3r.optim_factory import adjust_learning_rate_by_lr

from dust3r.cloud_opt.commons import (edge_str, ALL_DISTS, NoGradParamDict, get_imshapes, signed_expm1, signed_log1p,
                                      cosine_schedule, linear_schedule, get_conf_trf)
import dust3r.cloud_opt.init_im_poses as init_fun


class BasePCOptimizer (nn.Module):
    """ Optimize a global scene, given a list of pairwise observations.
    Graph node: images
    Graph edges: observations = (pred1, pred2)
    """

    def __init__(self, *args, **kwargs):
        if len(args) == 1 and len(kwargs) == 0:
            other = deepcopy(args[0])
            attrs = '''edges is_symmetrized dist n_imgs pred_i pred_j imshapes
                        min_conf_thr conf_thr conf_i conf_j im_conf
                        base_scale norm_pw_scale POSE_DIM pw_poses
                        pw_adaptors has_im_poses rand_pose imgs verbose'''.split()
            self.__dict__.update({k: other[k] for k in attrs})
        else:
            self._init_from_views(*args, **kwargs)

    def _init_from_views(self, view1, view2, pred1, pred2,
                         dist='l1',
                         conf='log',
                         min_conf_thr=3,
                         base_scale=0.5,
                         allow_pw_adaptors=False,
                         pw_break=20,
                         rand_pose=torch.randn,
                         iterationsCount=None,
                         verbose=True):
        super().__init__()
        if not isinstance(view1['idx'], list):
            view1['idx'] = view1['idx'].tolist()
        if not isinstance(view2['idx'], list):
            view2['idx'] = view2['idx'].tolist()
        self.edges = [(int(i), int(j)) for i, j in zip(view1['idx'], view2['idx'])]
        self.is_symmetrized = set(self.edges) == {(j, i) for i, j in self.edges}
        self.dist = ALL_DISTS[dist]
        self.verbose = verbose

        self.n_imgs = self._check_edges()

        # input data
        pred1_pts = pred1['pts3d']
        pred2_pts = pred2['pts3d_in_other_view']
        self.pred_i = NoGradParamDict({ij: pred1_pts[n] for n, ij in enumerate(self.str_edges)})
        self.pred_j = NoGradParamDict({ij: pred2_pts[n] for n, ij in enumerate(self.str_edges)})
        self.imshapes = get_imshapes(self.edges, pred1_pts, pred2_pts)

        # work in log-scale with conf
        pred1_conf = pred1['conf']
        pred2_conf = pred2['conf']
        self.min_conf_thr = min_conf_thr
        self.conf_trf = get_conf_trf(conf)

        self.conf_i = NoGradParamDict({ij: pred1_conf[n] for n, ij in enumerate(self.str_edges)})
        self.conf_j = NoGradParamDict({ij: pred2_conf[n] for n, ij in enumerate(self.str_edges)})
        self.im_conf = self._compute_img_conf(pred1_conf, pred2_conf)
        for i in range(len(self.im_conf)):
            self.im_conf[i].requires_grad = False

        # pairwise pose parameters
        self.base_scale = base_scale
        self.norm_pw_scale = True
        self.pw_break = pw_break
        self.POSE_DIM = 7
        self.pw_poses = nn.Parameter(rand_pose((self.n_edges, 1+self.POSE_DIM)))  # pairwise poses
        self.pw_adaptors = nn.Parameter(torch.zeros((self.n_edges, 2)))  # slight xy/z adaptation
        self.pw_adaptors.requires_grad_(allow_pw_adaptors)
        self.has_im_poses = False
        self.rand_pose = rand_pose

        # possibly store images for show_pointcloud
        self.imgs = None
        if 'img' in view1 and 'img' in view2:
            imgs = [torch.zeros((3,)+hw) for hw in self.imshapes]
            for v in range(len(self.edges)):
                idx = view1['idx'][v]
                imgs[idx] = view1['img'][v]
                idx = view2['idx'][v]
                imgs[idx] = view2['img'][v]
            self.imgs = rgb(imgs)

    @property
    def n_edges(self):
        return len(self.edges)

    @property
    def str_edges(self):
        return [edge_str(i, j) for i, j in self.edges]

    @property
    def imsizes(self):
        return [(w, h) for h, w in self.imshapes]

    @property
    def device(self):
        return next(iter(self.parameters())).device

    def state_dict(self, trainable=True):
        all_params = super().state_dict()
        return {k: v for k, v in all_params.items() if k.startswith(('_', 'pred_i.', 'pred_j.', 'conf_i.', 'conf_j.')) != trainable}

    def load_state_dict(self, data):
        return super().load_state_dict(self.state_dict(trainable=False) | data)

    def _check_edges(self):
        indices = sorted({i for edge in self.edges for i in edge})
        assert indices == list(range(len(indices))), 'bad pair indices: missing values'
        return len(indices)

    @torch.no_grad()
    def _compute_img_conf(self, pred1_conf, pred2_conf):
        im_conf = nn.ParameterList([torch.zeros(hw, device=self.device) for hw in self.imshapes])
        for e, (i, j) in enumerate(self.edges):
            im_conf[i] = torch.maximum(im_conf[i], pred1_conf[e])
            im_conf[j] = torch.maximum(im_conf[j], pred2_conf[e])
        return im_conf

    def get_adaptors(self):
        adapt = self.pw_adaptors
        adapt = torch.cat((adapt[:, 0:1], adapt), dim=-1)  # (scale_xy, scale_xy, scale_z)
        if self.norm_pw_scale:  # normalize so that the product == 1
            adapt = adapt - adapt.mean(dim=1, keepdim=True)
        return (adapt / self.pw_break).exp()

    def _get_poses(self, poses):
        # normalize rotation
        Q = poses[:, :4]
        T = signed_expm1(poses[:, 4:7])
        RT = roma.RigidUnitQuat(Q, T).normalize().to_homogeneous()
        return RT

    def _set_pose(self, poses, idx, R, T=None, scale=None, force=False):
        # all poses == cam-to-world
        pose = poses[idx]
        if not (pose.requires_grad or force):
            return pose

        if R.shape == (4, 4):
            assert T is None
            T = R[:3, 3]
            R = R[:3, :3]

        if R is not None:
            pose.data[0:4] = roma.rotmat_to_unitquat(R)
        if T is not None:
            pose.data[4:7] = signed_log1p(T / (scale or 1))  # translation is function of scale

        if scale is not None:
            assert poses.shape[-1] in (8, 13)
            pose.data[-1] = np.log(float(scale))
        return pose

    def get_pw_norm_scale_factor(self):
        if self.norm_pw_scale:
            # normalize scales so that things cannot go south
            # we want that exp(scale) ~= self.base_scale
            return (np.log(self.base_scale) - self.pw_poses[:, -1].mean()).exp()
        else:
            return 1  # don't norm scale for known poses

    def get_pw_scale(self):
        scale = self.pw_poses[:, -1].exp()  # (n_edges,)
        scale = scale * self.get_pw_norm_scale_factor()
        return scale

    def get_pw_poses(self):  # cam to world
        RT = self._get_poses(self.pw_poses)
        scaled_RT = RT.clone()
        scaled_RT[:, :3] *= self.get_pw_scale().view(-1, 1, 1)  # scale the rotation AND translation
        return scaled_RT

    def get_masks(self):
        return [(conf > self.min_conf_thr) for conf in self.im_conf]

    def depth_to_pts3d(self):
        raise NotImplementedError()

    def get_pts3d(self, raw=False):
        res = self.depth_to_pts3d()
        if not raw:
            res = [dm[:h*w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)]
        return res

    def _set_focal(self, idx, focal, force=False):
        raise NotImplementedError()

    def get_focals(self):
        raise NotImplementedError()

    def get_known_focal_mask(self):
        raise NotImplementedError()

    def get_principal_points(self):
        raise NotImplementedError()

    def get_conf(self, mode=None):
        trf = self.conf_trf if mode is None else get_conf_trf(mode)
        return [trf(c) for c in self.im_conf]

    def get_im_poses(self):
        raise NotImplementedError()

    def _set_depthmap(self, idx, depth, force=False):
        raise NotImplementedError()

    def get_depthmaps(self, raw=False):
        raise NotImplementedError()

    def clean_pointcloud(self, **kw):
        cams = inv(self.get_im_poses())
        K = self.get_intrinsics()
        depthmaps = self.get_depthmaps()
        all_pts3d = self.get_pts3d()

        new_im_confs = clean_pointcloud(self.im_conf, K, cams, depthmaps, all_pts3d, **kw)

        for i, new_conf in enumerate(new_im_confs):
            self.im_conf[i].data[:] = new_conf
        return self

    def forward(self, ret_details=False):
        pw_poses = self.get_pw_poses()  # cam-to-world
        pw_adapt = self.get_adaptors()
        proj_pts3d = self.get_pts3d()
        # pre-compute pixel weights
        weight_i = {i_j: self.conf_trf(c) for i_j, c in self.conf_i.items()}
        weight_j = {i_j: self.conf_trf(c) for i_j, c in self.conf_j.items()}

        loss = 0
        if ret_details:
            details = -torch.ones((self.n_imgs, self.n_imgs))

        for e, (i, j) in enumerate(self.edges):
            i_j = edge_str(i, j)
            # distance in image i and j
            aligned_pred_i = geotrf(pw_poses[e], pw_adapt[e] * self.pred_i[i_j])
            aligned_pred_j = geotrf(pw_poses[e], pw_adapt[e] * self.pred_j[i_j])
            li = self.dist(proj_pts3d[i], aligned_pred_i, weight=weight_i[i_j]).mean()
            lj = self.dist(proj_pts3d[j], aligned_pred_j, weight=weight_j[i_j]).mean()
            loss = loss + li + lj

            if ret_details:
                details[i, j] = li + lj
        loss /= self.n_edges  # average over all pairs

        if ret_details:
            return loss, details
        return loss

    @torch.cuda.amp.autocast(enabled=False)
    def compute_global_alignment(self, init=None, niter_PnP=10, **kw):
        if init is None:
            pass
        elif init == 'msp' or init == 'mst':
            init_fun.init_minimum_spanning_tree(self, niter_PnP=niter_PnP)
        elif init == 'known_poses':
            init_fun.init_from_known_poses(self, min_conf_thr=self.min_conf_thr,
                                           niter_PnP=niter_PnP)
        else:
            raise ValueError(f'bad value for {init=}')

        return global_alignment_loop(self, **kw)

    @torch.no_grad()
    def mask_sky(self):
        res = deepcopy(self)
        for i in range(self.n_imgs):
            sky = segment_sky(self.imgs[i])
            res.im_conf[i][sky] = 0
        return res

    def show(self, show_pw_cams=False, show_pw_pts3d=False, cam_size=None, **kw):
        viz = SceneViz()
        if self.imgs is None:
            colors = np.random.randint(0, 256, size=(self.n_imgs, 3))
            colors = list(map(tuple, colors.tolist()))
            for n in range(self.n_imgs):
                viz.add_pointcloud(self.get_pts3d()[n], colors[n], self.get_masks()[n])
        else:
            viz.add_pointcloud(self.get_pts3d(), self.imgs, self.get_masks())
            colors = np.random.randint(256, size=(self.n_imgs, 3))

        # camera poses
        im_poses = to_numpy(self.get_im_poses())
        if cam_size is None:
            cam_size = auto_cam_size(im_poses)
        viz.add_cameras(im_poses, self.get_focals(), colors=colors,
                        images=self.imgs, imsizes=self.imsizes, cam_size=cam_size)
        if show_pw_cams:
            pw_poses = self.get_pw_poses()
            viz.add_cameras(pw_poses, color=(192, 0, 192), cam_size=cam_size)

            if show_pw_pts3d:
                pts = [geotrf(pw_poses[e], self.pred_i[edge_str(i, j)]) for e, (i, j) in enumerate(self.edges)]
                viz.add_pointcloud(pts, (128, 0, 128))

        viz.show(**kw)
        return viz


def global_alignment_loop(net, lr=0.01, niter=300, schedule='cosine', lr_min=1e-6):
    params = [p for p in net.parameters() if p.requires_grad]
    if not params:
        return net

    verbose = net.verbose
    if verbose:
        print('Global alignment - optimizing for:')
        print([name for name, value in net.named_parameters() if value.requires_grad])

    lr_base = lr
    optimizer = torch.optim.Adam(params, lr=lr, betas=(0.9, 0.9))

    loss = float('inf')
    if verbose:
        with tqdm.tqdm(total=niter) as bar:
            while bar.n < bar.total:
                loss, lr = global_alignment_iter(net, bar.n, niter, lr_base, lr_min, optimizer, schedule)
                bar.set_postfix_str(f'{lr=:g} loss={loss:g}')
                bar.update()
    else:
        for n in range(niter):
            loss, _ = global_alignment_iter(net, n, niter, lr_base, lr_min, optimizer, schedule)
    return loss


def global_alignment_iter(net, cur_iter, niter, lr_base, lr_min, optimizer, schedule):
    t = cur_iter / niter
    if schedule == 'cosine':
        lr = cosine_schedule(t, lr_base, lr_min)
    elif schedule == 'linear':
        lr = linear_schedule(t, lr_base, lr_min)
    else:
        raise ValueError(f'bad lr {schedule=}')
    adjust_learning_rate_by_lr(optimizer, lr)
    optimizer.zero_grad()
    loss = net()
    loss.backward()
    optimizer.step()

    return float(loss), lr


@torch.no_grad()
def clean_pointcloud(im_confs, K, cams, depthmaps, all_pts3d,
                     tol=0.001, bad_conf=0, dbg=()):
    """ Method:
    1) express all 3d points in each camera coordinate frame
    2) if they're in front of a depthmap --> then lower their confidence
    """
    assert len(im_confs) == len(cams) == len(K) == len(depthmaps) == len(all_pts3d)
    assert 0 <= tol < 1
    res = [c.clone() for c in im_confs]

    # reshape appropriately
    all_pts3d = [p.view(*c.shape, 3) for p, c in zip(all_pts3d, im_confs)]
    depthmaps = [d.view(*c.shape) for d, c in zip(depthmaps, im_confs)]

    for i, pts3d in enumerate(all_pts3d):
        for j in range(len(all_pts3d)):
            if i == j:
                continue

            # project 3dpts in other view
            proj = geotrf(cams[j], pts3d)
            proj_depth = proj[:, :, 2]
            u, v = geotrf(K[j], proj, norm=1, ncol=2).round().long().unbind(-1)

            # check which points are actually in the visible cone
            H, W = im_confs[j].shape
            msk_i = (proj_depth > 0) & (0 <= u) & (u < W) & (0 <= v) & (v < H)
            msk_j = v[msk_i], u[msk_i]

            # find bad points = those in front but less confident
            bad_points = (proj_depth[msk_i] < (1-tol) * depthmaps[j][msk_j]) & (res[i][msk_i] < res[j][msk_j])

            bad_msk_i = msk_i.clone()
            bad_msk_i[msk_i] = bad_points
            res[i][bad_msk_i] = res[i][bad_msk_i].clip_(max=bad_conf)

    return res
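

# --- Illustrative sketch (not part of the original module) ---
# Scalar walk-through of BasePCOptimizer.get_adaptors() for a single edge, using
# only `math` so it runs without torch. The two learned values (xy, z) are expanded
# to (xy, xy, z), centred on their mean when norm_pw_scale is set, divided by
# pw_break and exponentiated. Centring is what makes the three factors multiply to 1
# (the "normalize so that the product == 1" comment in get_adaptors).
# The function name and the sample inputs are hypothetical.
def _sketch_adaptor_factors(adapt_xy=0.6, adapt_z=-0.3, pw_break=20, norm=True):
    import math

    adapt = [adapt_xy, adapt_xy, adapt_z]  # (scale_xy, scale_xy, scale_z)
    if norm:
        mean = sum(adapt) / len(adapt)
        adapt = [a - mean for a in adapt]  # the three entries now sum to zero
    # dividing by pw_break damps the effect of a step on the raw parameters
    return [math.exp(a / pw_break) for a in adapt]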


================================================
FILE: dust3r/cloud_opt/commons.py
================================================
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# utility functions for global alignment
# --------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np


def edge_str(i, j):
    return f'{i}_{j}'


def i_j_ij(ij):
    return edge_str(*ij), ij


def edge_conf(conf_i, conf_j, edge):
    return float(conf_i[edge].mean() * conf_j[edge].mean())


def compute_edge_scores(edges, conf_i, conf_j):
    return {(i, j): edge_conf(conf_i, conf_j, e) for e, (i, j) in edges}


def NoGradParamDict(x):
    assert isinstance(x, dict)
    return nn.ParameterDict(x).requires_grad_(False)


def get_imshapes(edges, pred_i, pred_j):
    n_imgs = max(max(e) for e in edges) + 1
    imshapes = [None] * n_imgs
    for e, (i, j) in enumerate(edges):
        shape_i = tuple(pred_i[e].shape[0:2])
        shape_j = tuple(pred_j[e].shape[0:2])
        if imshapes[i]:
            assert imshapes[i] == shape_i, f'incorrect shape for image {i}'
        if imshapes[j]:
            assert imshapes[j] == shape_j, f'incorrect shape for image {j}'
        imshapes[i] = shape_i
        imshapes[j] = shape_j
    return imshapes


def get_conf_trf(mode):
    if mode == 'log':
        def conf_trf(x): return x.log()
    elif mode == 'sqrt':
        def conf_trf(x): return x.sqrt()
    elif mode == 'm1':
        def conf_trf(x): return x-1
    elif mode in ('id', 'none'):
        def conf_trf(x): return x
    else:
        raise ValueError(f'bad mode for {mode=}')
    return conf_trf
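

# --- Illustrative sketch (not part of the original module) ---
# Scalar version of the transforms returned by get_conf_trf(), written with `math`
# so it runs without torch. It only shows how each mode turns a raw confidence into
# a per-pixel weight; note that a confidence of exactly 1 receives zero weight under
# 'log' and 'm1'. The function name and the sample confidences are hypothetical.
def _sketch_conf_weights(conf=(1.0, 4.0, 9.0)):
    import math

    return {
        'log': [math.log(c) for c in conf],
        'sqrt': [math.sqrt(c) for c in conf],
        'm1': [c - 1 for c in conf],
        'id': list(conf),
    }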


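def l2_dist(a, b, weight):
    # weighted squared Euclidean distance between corresponding 3D points
    return ((a - b).square().sum(dim=-1) * weight)


def l1_dist(a, b, weight):
    # weighted (unsquared) Euclidean distance between corresponding 3D points
    return ((a - b).norm(dim=-1) * weight)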


ALL_DISTS = dict(l1=l1_dist, l2=l2_dist)


def signed_log1p(x):
    sign = torch.sign(x)
    return sign * torch.log1p(torch.abs(x))


def signed_expm1(x):
    sign = torch.sign(x)
    return sign * torch.expm1(torch.abs(x))
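

# --- Illustrative sketch (not part of the original module) ---
# BasePCOptimizer._set_pose() writes translations with signed_log1p() and
# _get_poses() reads them back with signed_expm1(), so the two maps are expected to
# be inverses of each other. The scalar helpers below are hypothetical stand-ins
# written with `math` (no torch needed); they mirror the tensor versions above.
def _sketch_signed_log_roundtrip(values=(-250.0, -1.0, 0.0, 0.5, 1000.0)):
    import math

    def _scalar_signed_log1p(x):
        return math.copysign(math.log1p(abs(x)), x)

    def _scalar_signed_expm1(y):
        return math.copysign(math.expm1(abs(y)), y)

    # encode each value, decode it again, and return (original, decoded) pairs
    return [(v, _scalar_signed_expm1(_scalar_signed_log1p(v))) for v in values]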


def cosine_schedule(t, lr_start, lr_end):
    assert 0 <= t <= 1
    return lr_end + (lr_start - lr_end) * (1+np.cos(t * np.pi))/2


def linear_schedule(t, lr_start, lr_end):
    assert 0 <= t <= 1
    return lr_start + (lr_end - lr_start) * t
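

# --- Illustrative sketch (not part of the original module) ---
# Scalar re-statement of the two schedules above using only `math`, to show the
# learning rates they yield over a short run. global_alignment_iter() in base_opt.py
# evaluates them at t = cur_iter / niter, so t stays below 1 and the last step uses
# a rate slightly above lr_end. The function name and niter=4 are hypothetical; the
# lr defaults follow those of global_alignment_loop().
def _sketch_lr_schedule(niter=4, lr_start=0.01, lr_end=1e-6):
    import math

    rows = []
    for cur_iter in range(niter):
        t = cur_iter / niter
        cosine = lr_end + (lr_start - lr_end) * (1 + math.cos(t * math.pi)) / 2
        linear = lr_start + (lr_end - lr_start) * t
        rows.append((t, cosine, linear))
    return rows  # one (t, cosine_lr, linear_lr) tuple per iteration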


================================================
FILE: dust3r/cloud_opt/init_im_poses.py
================================================
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Initialization functions for global alignment
# --------------------------------------------------------
from functools import cache

import numpy as np
import scipy.sparse as sp
import torch
import cv2
import roma
from tqdm import tqdm

from dust3r.utils.geometry import geotrf, inv, get_med_dist_between_poses
from dust3r.post_process import estimate_focal_knowing_depth
from dust3r.viz import to_numpy

from dust3r.cloud_opt.commons import edge_str, i_j_ij, compute_edge_scores


@torch.no_grad()
def init_from_known_poses(self, niter_PnP=10, min_conf_thr=3):
    device = self.device

    # indices of known poses
    nkp, known_poses_msk, known_poses = get_known_poses(self)
    assert nkp == self.n_imgs, 'not all poses are known'

    # get all focals
    nkf, _, im_focals = get_known_focals(self)
    assert nkf == self.n_imgs
    im_pp = self.get_principal_points()

    best_depthmaps = {}
    # init all pairwise poses
    for e, (i, j) in enumerate(tqdm(self.edges, disable=not self.verbose)):
        i_j = edge_str(i, j)

        # find relative pose for this pair
        P1 = torch.eye(4, device=device)
        msk = self.conf_i[i_j] > min(min_conf_thr, self.conf_i[i_j].min() - 0.1)
        _, P2 = fast_pnp(self.pred_j[i_j], float(im_focals[i].mean()),
                         pp=im_pp[i], msk=msk, device=device, niter_PnP=niter_PnP)

        # align the two predicted camera with the two gt cameras
        s, R, T = align_multiple_poses(torch.stack((P1, P2)), known_poses[[i, j]])
        # normally we have known_poses[i] ~= sRT_to_4x4(s,R,T,device) @ P1
        # and geotrf(sRT_to_4x4(1,R,T,device), s*P2[:3,3])
        self._set_pose(self.pw_poses, e, R, T, scale=s)

        # remember if this is a good depthmap
        score = float(self.conf_i[i_j].mean())
        if score > best_depthmaps.get(i, (0,))[0]:
            best_depthmaps[i] = score, i_j, s

    # init all image poses
    for n in range(self.n_imgs):
        assert known_poses_msk[n]
        _, i_j, scale = best_depthmaps[n]
        depth = self.pred_i[i_j][:, :, 2]
        self._set_depthmap(n, depth * scale)


@torch.no_grad()
def init_minimum_spanning_tree(self, **kw):
    """ Init all camera poses (image-wise and pairwise poses) given
        an initial set of pairwise estimations.
    """
    device = self.device
    pts3d, _, im_focals, im_poses = minimum_spanning_tree(self.imshapes, self.edges,
                                                          self.pred_i, self.pred_j, self.conf_i, self.conf_j, self.im_conf, self.min_conf_thr,
                                                          device, has_im_poses=self.has_im_poses, verbose=self.verbose,
                                                          **kw)

    return init_from_pts3d(self, pts3d, im_focals, im_poses)


def init_from_pts3d(self, pts3d, im_focals, im_poses):
    # init poses
    nkp, known_poses_msk, known_poses = get_known_poses(self)
    if nkp == 1:
        raise NotImplementedError("Would be simpler to just align everything afterwards on the single known pose")
    elif nkp > 1:
        # global rigid SE3 alignment
        s, R, T = align_multiple_poses(im_poses[known_poses_msk], known_poses[known_poses_msk])
        trf = sRT_to_4x4(s, R, T, device=known_poses.device)

        # rotate everything
        im_poses = trf @ im_poses
        im_poses[:, :3, :3] /= s  # undo scaling on the rotation part
        for img_pts3d in pts3d:
            img_pts3d[:] = geotrf(trf, img_pts3d)

    # set all pairwise poses
    for e, (i, j) in enumerate(self.edges):
        i_j = edge_str(i, j)
        # compute transform that goes from cam to world
        s, R, T = rigid_points_registration(self.pred_i[i_j], pts3d[i], conf=self.conf_i[i_j])
        self._set_pose(self.pw_poses, e, R, T, scale=s)

    # take into account the scale normalization
    s_factor = self.get_pw_norm_scale_factor()
    im_poses[:, :3, 3] *= s_factor  # apply downscaling factor
    for img_pts3d in pts3d:
        img_pts3d *= s_factor

    # init all image poses
    if self.has_im_poses:
        for i in range(self.n_imgs):
            cam2world = im_poses[i]
            depth = geotrf(inv(cam2world), pts3d[i])[..., 2]
            self._set_depthmap(i, depth)
            self._set_pose(self.im_poses, i, cam2world)
            if im_focals[i] is not None:
                self._set_focal(i, im_focals[i])

    if self.verbose:
        print(' init loss =', float(self()))


def minimum_spanning_tree(imshapes, edges, pred_i, pred_j, conf_i, conf_j, im_conf, min_conf_thr,
                          device, has_im_poses=True, niter_PnP=10, verbose=True):
    n_imgs = len(imshapes)
    sparse_graph = -dict_to_sparse_graph(compute_edge_scores(map(i_j_ij, edges), conf_i, conf_j))
    msp = sp.csgraph.minimum_spanning_tree(sparse_graph).tocoo()

    # temp variable to store 3d points
    pts3d = [None] * len(imshapes)

    todo = sorted(zip(-msp.data, msp.row, msp.col))  # sorted edges
    im_poses = [None] * n_imgs
    im_focals = [None] * n_imgs

    # init with strongest edge
    score, i, j = todo.pop()
    if verbose:
        print(f' init edge ({i}*,{j}*) {score=}')
    i_j = edge_str(i, j)
    pts3d[i] = pred_i[i_j].clone()
    pts3d[j] = pred_j[i_j].clone()
    done = {i, j}
    if has_im_poses:
        im_poses[i] = torch.eye(4, device=device)
        im_focals[i] = estimate_focal(pred_i[i_j])

    # set initial pointcloud based on pairwise graph
    msp_edges = [(i, j)]
    while todo:
        # each time, predict the next one
        score, i, j = todo.pop()
        i_j = edge_str(i, j)  # refresh the edge key so the focal below uses the current pair

        if im_focals[i] is None:
            im_focals[i] = estimate_focal(pred_i[i_j])

        if i in done:
            if verbose:
                print(f' init edge ({i},{j}*) {score=}')
            assert j not in done
            # align pred[i] with pts3d[i], and then set j accordingly
            i_j = edge_str(i, j)
            s, R, T = rigid_points_registration(pred_i[i_j], pts3d[i], conf=conf_i[i_j])
            trf = sRT_to_4x4(s, R, T, device)
            pts3d[j] = geotrf(trf, pred_j[i_j])
            done.add(j)
            msp_edges.append((i, j))

            if has_im_poses and im_poses[i] is None:
                im_poses[i] = sRT_to_4x4(1, R, T, device)

        elif j in done:
            if verbose:
                print(f' init edge ({i}*,{j}) {score=}')
            assert i not in done
            i_j = edge_str(i, j)
            s, R, T = rigid_points_registration(pred_j[i_j], pts3d[j], conf=conf_j[i_j])
            trf = sRT_to_4x4(s, R, T, device)
            pts3d[i] = geotrf(trf, pred_i[i_j])
            done.add(i)
            msp_edges.append((i, j))

            if has_im_poses and im_poses[i] is None:
                im_poses[i] = sRT_to_4x4(1, R, T, device)
        else:
            # let's try again later
            todo.insert(0, (score, i, j))

    if has_im_poses:
        # complete all missing information
        pair_scores = list(sparse_graph.values())  # already negative scores: less is best
        edges_from_best_to_worse = np.array(list(sparse_graph.keys()))[np.argsort(pair_scores)]
        for i, j in edges_from_best_to_worse.tolist():
            if im_focals[i] is None:
                im_focals[i] = estimate_focal(pred_i[edge_str(i, j)])

        for i in range(n_imgs):
            if im_poses[i] is None:
                msk = im_conf[i] > min_conf_thr
                res = fast_pnp(pts3d[i], im_focals[i], msk=msk, device=device, niter_PnP=niter_PnP)
                if res:
                    im_focals[i], im_poses[i] = res
            if im_poses[i] is None:
                im_poses[i] = torch.eye(4, device=device)
        im_poses = torch.stack(im_poses)
    else:
        im_poses = im_focals = None

    return pts3d, msp_edges, im_focals, im_poses
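

# --- Illustrative sketch (not part of the original module) ---
# Pure-Python outline of the traversal order used in minimum_spanning_tree() above:
# start from the strongest edge, then repeatedly pop the next strongest one and
# attach it if it touches an image that is already placed, otherwise push it to the
# front of the list to retry later. It leaves out all geometry (registration, focal
# and pose estimation). It assumes the input edges form a spanning tree, as the
# SciPy result does for a connected graph; a disconnected input would loop forever.
# The function name is hypothetical.
def _sketch_attach_order(scored_edges):
    todo = sorted(scored_edges)  # ascending by score, so pop() yields the strongest
    _, i, j = todo.pop()
    done = {i, j}
    order = [(i, j)]
    while todo:
        score, i, j = todo.pop()
        if i in done or j in done:
            done.update((i, j))
            order.append((i, j))
        else:
            todo.insert(0, (score, i, j))  # neither image is placed yet: retry later
    return order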


def dict_to_sparse_graph(dic):
    n_imgs = max(max(e) for e in dic) + 1
    res = sp.dok_array((n_imgs, n_imgs))
    for edge, value in dic.items():
        res[edge] = value
    return res


def rigid_points_registration(pts1, pts2, conf):
    R, T, s = roma.rigid_points_registration(
        pts1.reshape(-1, 3), pts2.reshape(-1, 3), weights=conf.ravel(), compute_scaling=True)
    return s, R, T  # return un-scaled (R, T)


def sRT_to_4x4(scale, R, T, device):
    trf = torch.eye(4, device=device)
    trf[:3, :3] = R * scale
    trf[:3, 3] = T.ravel()  # doesn't need scaling
    return trf


def estimate_focal(pts3d_i, pp=None):
    if pp is None:
        H, W, THREE = pts3d_i.shape
        assert THREE == 3
        pp = torch.tensor((W/2, H/2), device=pts3d_i.device)
    focal = estimate_focal_knowing_depth(pts3d_i.unsqueeze(0), pp.unsqueeze(0), focal_mode='weiszfeld').ravel()
    return float(focal)


@cache
def pixel_grid(H, W):
    # (H, W, 2) array holding the (x, y) pixel coordinates of every pixel
    return np.mgrid[:W, :H].T.astype(np.float32)


def fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10):
    # extract camera poses and focals with RANSAC-PnP
    if msk.sum() < 4:
        return None  # we need at least 4 points for PnP
    pts3d, msk = map(to_numpy, (pts3d, msk))

    H, W, THREE = pts3d.shape
    assert THREE == 3
    pixels = pixel_grid(H, W)

    if focal is None:
        S = max(W, H)
        tentative_focals = np.geomspace(S/2, S*3, 21)
    else:
        tentative_focals = [focal]

    if pp is None:
        pp = (W/2, H/2)
    else:
        pp = to_numpy(pp)

    best = (0,)
    for focal in tentative_focals:
        K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)])

        success, R, T, inliers = cv2.solvePnPRansac(pts3d[msk], pixels[msk], K, None,
                                                    iterationsCount=niter_PnP, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP)
        if not success:
            continue

        score = len(inliers)
        if success and score > best[0]:
            best = score, R, T, focal

    if not best[0]:
        return None

    _, R, T, best_focal = best
    R = cv2.Rodrigues(R)[0]  # world to cam
    R, T = map(torch.from_numpy, (R, T))
    return best_focal, inv(sRT_to_4x4(1, R, T, device))  # cam to world


def get_known_poses(self):
    if self.has_im_poses:
        known_poses_msk = torch.tensor([not (p.requires_grad) for p in self.im_poses])
        known_poses = self.get_im_poses()
        return known_poses_msk.sum(), known_poses_msk, known_poses
    else:
        return 0, None, None


def get_known_focals(self):
    if self.has_im_poses:
        known_focal_msk = self.get_known_focal_mask()
        known_focals = self.get_focals()
        return known_focal_msk.sum(), known_focal_msk, known_focals
    else:
        return 0, None, None


def align_multiple_poses(src_poses, target_poses):
    N = len(src_poses)
    assert src_poses.shape == target_poses.shape == (N, 4, 4)

    def center_and_z(poses):
        eps = get_med_dist_between_poses(poses) / 100
        return torch.cat((poses[:, :3, 3], poses[:, :3, 3] + eps*poses[:, :3, 2]))
    R, T, s = roma.rigid_points_registration(center_and_z(src_poses), center_and_z(target_poses), compute_scaling=True)
    return s, R, T


================================================
FILE: dust3r/cloud_opt/modular_optimizer.py
================================================
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Slower implementation of the global alignment that allows freezing some of the poses/intrinsics
# --------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn

from dust3r.cloud_opt.base_opt import BasePCOptimizer
from dust3r.utils.geometry import geotrf
from dust3r.utils.device import to_cpu, to_numpy
from dust3r.utils.geometry import depthmap_to_pts3d


class ModularPointCloudOptimizer (BasePCOptimizer):
    """ Optimize a global scene, given a list of pairwise observations.
    Unlike PointCloudOptimizer, you can fix parts of the optimization process (partial poses/intrinsics)
    Graph node: images
    Graph edges: observations = (pred1, pred2)
    """

    def __init__(self, *args, optimize_pp=False, fx_and_fy=False, focal_brake=20, **kwargs):
        super().__init__(*args, **kwargs)
        self.has_im_poses = True  # by definition of this class
        self.focal_brake = focal_brake

        # adding things to optimize
        self.im_depthmaps = nn.ParameterList(torch.randn(H, W)/10-3 for H, W in self.imshapes)  # log(depth)
        self.im_poses = nn.ParameterList(self.rand_pose(self.POSE_DIM) for _ in range(self.n_imgs))  # camera poses
        default_focals = [self.focal_brake * np.log(max(H, W)) for H, W in self.imshapes]
        self.im_focals = nn.ParameterList(torch.FloatTensor([f, f] if fx_and_fy else [
                                          f]) for f in default_focals)  # camera intrinsics
        self.im_pp = nn.ParameterList(torch.zeros((2,)) for _ in range(self.n_imgs))  # camera intrinsics
        self.im_pp.requires_grad_(optimize_pp)

    def preset_pose(self, known_poses, pose_msk=None):  # cam-to-world
        if isinstance(known_poses, torch.Tensor) and known_poses.ndim == 2:
            known_poses = [known_poses]
        for idx, pose in zip(self._get_msk_indices(pose_msk), known_poses):
            if self.verbose:
                print(f' (setting pose #{idx} = {pose[:3,3]})')
            self._no_grad(self._set_pose(self.im_poses, idx, torch.tensor(pose), force=True))

        # normalize scale if there is at most 1 known pose
        n_known_poses = sum((p.requires_grad is False) for p in self.im_poses)
        self.norm_pw_scale = (n_known_poses <= 1)

    def preset_intrinsics(self, known_intrinsics, msk=None):
        if isinstance(known_intrinsics, torch.Tensor) and known_intrinsics.ndim == 2:
            known_intrinsics = [known_intrinsics]
        for K in known_intrinsics:
            assert K.shape == (3, 3)
        self.preset_focal([K.diagonal()[:2].mean() for K in known_intrinsics], msk)
        self.preset_principal_point([K[:2, 2] for K in known_intrinsics], msk)

    def preset_focal(self, known_focals, msk=None):
        for idx, focal in zip(self._get_msk_indices(msk), known_focals):
            if self.verbose:
                print(f' (setting focal #{idx} = {focal})')
            self._no_grad(self._set_focal(idx, focal, force=True))

    def preset_principal_point(self, known_pp, msk=None):
        for idx, pp in zip(self._get_msk_indices(msk), known_pp):
            if self.verbose:
                print(f' (setting principal point #{idx} = {pp})')
            self._no_grad(self._set_principal_point(idx, pp, force=True))

    def _no_grad(self, tensor):
        return tensor.requires_grad_(False)

    def _get_msk_indices(self, msk):
        if msk is None:
            return range(self.n_imgs)
        elif isinstance(msk, int):
            return [msk]
        elif isinstance(msk, (tuple, list)):
            return self._get_msk_indices(np.array(msk))
        elif msk.dtype in (bool, torch.bool, np.bool_):
            assert len(msk) == self.n_imgs
            return np.where(msk)[0]
        elif np.issubdtype(msk.dtype, np.integer):
            return msk
        else:
            raise ValueError(f'bad {msk=}')

    def _set_focal(self, idx, focal, force=False):
        param = self.im_focals[idx]
        if param.requires_grad or force:  # can only init a parameter not already initialized
            param.data[:] = self.focal_brake * np.log(focal)
        return param

    def get_focals(self):
        log_focals = torch.stack(list(self.im_focals), dim=0)
        return (log_focals / self.focal_brake).exp()

    def _set_principal_point(self, idx, pp, force=False):
        param = self.im_pp[idx]
        H, W = self.imshapes[idx]
        if param.requires_grad or force:  # can only init a parameter not already initialized
            param.data[:] = to_cpu(to_numpy(pp) - (W/2, H/2)) / 10
        return param

    def get_principal_points(self):
        return torch.stack([pp.new((W/2, H/2))+10*pp for pp, (H, W) in zip(self.im_pp, self.imshapes)])

    def get_intrinsics(self):
        K = torch.zeros((self.n_imgs, 3, 3), device=self.device)
        focals = self.get_focals().view(self.n_imgs, -1)
        K[:, 0, 0] = focals[:, 0]
        K[:, 1, 1] = focals[:, -1]
        K[:, :2, 2] = self.get_principal_points()
        K[:, 2, 2] = 1
        return K

    def get_im_poses(self):  # cam to world
        cam2world = self._get_poses(torch.stack(list(self.im_poses)))
        return cam2world

    def _set_depthmap(self, idx, depth, force=False):
        param = self.im_depthmaps[idx]
        if param.requires_grad or force:  # can only init a parameter not already initialized
            param.data[:] = depth.log().nan_to_num(neginf=0)
        return param

    def get_depthmaps(self):
        return [d.exp() for d in self.im_depthmaps]

    def depth_to_pts3d(self):
        # get depths and projection params
        focals = self.get_focals()
        pp = self.get_principal_points()
        im_poses = self.get_im_poses()
        depth = self.get_depthmaps()

        # convert focal to (1,2,H,W) constant field
        def focal_ex(i): return focals[i][..., None, None].expand(1, *focals[i].shape, *self.imshapes[i])
        # get pointmaps in camera frame
        rel_ptmaps = [depthmap_to_pts3d(depth[i][None], focal_ex(i), pp=pp[i:i+1])[0] for i in range(im_poses.shape[0])]
        # project to world frame
        return [geotrf(pose, ptmap) for pose, ptmap in zip(im_poses, rel_ptmaps)]

    def get_pts3d(self):
        return self.depth_to_pts3d()
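

# --- Illustrative sketch (not part of the original module) ---
# Scalar view of the focal parametrization used by _set_focal() / get_focals()
# above: the stored parameter is focal_brake * log(focal), and reading it back
# applies exp(param / focal_brake). A change of `step` in the stored parameter
# therefore multiplies the focal by exp(step / focal_brake), which is why a larger
# focal_brake makes the focal move more slowly. Uses only `math`; the function name
# and the sample values are hypothetical.
def _sketch_focal_param(focal=512.0, focal_brake=20, step=0.01):
    import math

    stored = focal_brake * math.log(focal)            # value written by _set_focal()
    recovered = math.exp(stored / focal_brake)        # value returned by get_focals()
    nudged = math.exp((stored + step) / focal_brake)  # focal after moving the parameter by `step`
    return recovered, nudged / focal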


================================================
FILE: dust3r/cloud_opt/optimizer.py
================================================
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Main class for the implementation of the global alignment
# --------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn

from dust3r.cloud_opt.base_opt import BasePCOptimizer
from dust3r.utils.geometry import xy_grid, geotrf
from dust3r.utils.device import to_cpu, to_numpy


class PointCloudOptimizer(BasePCOptimizer):
    """ Optimize a global scene, given a list of pairwise observations.
    Graph node: images
    Graph edges: observations = (pred1, pred2)
    """

    def __init__(self, *args, optimize_pp=False, focal_break=20, **kwargs):
        super().__init__(*args, **kwargs)

        self.has_im_poses = True  # by definition of this class
        self.focal_break = focal_break

        # parameters to optimize
        self.im_depthmaps = nn.ParameterList(torch.randn(H, W)/10-3 for H, W in self.imshapes)  # log(depth)
        self.im_poses = nn.ParameterList(self.rand_pose(self.POSE_DIM) for _ in range(self.n_imgs))  # camera poses
        self.im_focals = nn.ParameterList(torch.FloatTensor(
            [self.focal_break*np.log(max(H, W))]) for H, W in self.imshapes)  # camera intrinsics
        self.im_pp = nn.ParameterList(torch.zeros((2,)) for _ in range(self.n_imgs))  # camera intrinsics
        self.im_pp.requires_grad_(optimize_pp)

        self.imshape = self.imshapes[0]
        im_areas = [h*w for h, w in self.imshapes]
        self.max_area = max(im_areas)

        # stack the per-image parameters into single tensors (depthmaps are padded to max_area)
        self.im_depthmaps = ParameterStack(self.im_depthmaps, is_param=True, fill=self.max_area)
        self.im_poses = ParameterStack(self.im_poses, is_param=True)
        self.im_focals = ParameterStack(self.im_focals, is_param=True)
        self.im_pp = ParameterStack(self.im_pp, is_param=True)
        self.register_buffer('_pp', torch.tensor([(w/2, h/2) for h, w in self.imshapes]))
        self.register_buffer('_grid', ParameterStack(
            [xy_grid(W, H, device=self.device) for H, W in self.imshapes], fill=self.max_area))

        # pre-compute pixel weights
        self.register_buffer('_weight_i', ParameterStack(
            [self.conf_trf(self.conf_i[i_j]) for i_j in self.str_edges], fill=self.max_area))
        self.register_buffer('_weight_j', ParameterStack(
            [self.conf_trf(self.conf_j[i_j]) for i_j in self.str_edges], fill=self.max_area))

        # precompute stacked pairwise predictions and edge indices
        self.register_buffer('_stacked_pred_i', ParameterStack(self.pred_i, self.str_edges, fill=self.max_area))
        self.register_buffer('_stacked_pred_j', ParameterStack(self.pred_j, self.str_edges, fill=self.max_area))
        self.register_buffer('_ei', torch.tensor([i for i, j in self.edges]))
        self.register_buffer('_ej', torch.tensor([j for i, j in self.edges]))
        self.total_area_i = sum([im_areas[i] for i, j in self.edges])
        self.total_area_j = sum([im_areas[j] for i, j in self.edges])

    def _check_all_imgs_are_selected(self, msk):
        assert np.all(self._get_msk_indices(msk) == np.arange(self.n_imgs)), 'incomplete mask!'

    def preset_pose(self, known_poses, pose_msk=None):  # cam-to-world
        self._check_all_imgs_are_selected(pose_msk)

        if isinstance(known_poses, torch.Tensor) and known_poses.ndim == 2:
            known_poses = [known_poses]
        for idx, pose in zip(self._get_msk_indices(pose_msk), known_poses):
            if self.verbose:
                print(f' (setting pose #{idx} = {pose[:3,3]})')
            self._no_grad(self._set_pose(self.im_poses, idx, torch.tensor(pose)))

        # normalize scale if there is at most 1 known pose
        n_known_poses = sum((p.requires_grad is False) for p in self.im_poses)
        self.norm_pw_scale = (n_known_poses <= 1)

        self.im_poses.requires_grad_(False)
        self.norm_pw_scale = False

    def preset_focal(self, known_focals, msk=None):
        self._check_all_imgs_are_selected(msk)

        for idx, focal in zip(self._get_msk_indices(msk), known_focals):
            if self.verbose:
                print(f' (setting focal #{idx} = {focal})')
            self._no_grad(self._set_focal(idx, focal))

        self.im_focals.requires_grad_(False)

    def preset_principal_point(self, known_pp, msk=None):
        self._check_all_imgs_are_selected(msk)

        for idx, pp in zip(self._get_msk_indices(msk), known_pp):
            if self.verbose:
                print(f' (setting principal point #{idx} = {pp})')
            self._no_grad(self._set_principal_point(idx, pp))

        self.im_pp.requires_grad_(False)

    def _get_msk_indices(self, msk):
        if msk is None:
            return range(self.n_imgs)
        elif isinstance(msk, int):
            return [msk]
        elif isinstance(msk, (tuple, list)):
            return self._get_msk_indices(np.array(msk))
        elif msk.dtype in (bool, torch.bool, np.bool_):
            assert len(msk) == self.n_imgs
            return np.where(msk)[0]
        elif np.issubdtype(msk.dtype, np.integer):
            return msk
        else:
            raise ValueError(f'bad {msk=}')

    def _no_grad(self, tensor):
        assert tensor.requires_grad, 'it must be True at this point, otherwise no modification occurs'

    def _set_focal(self, idx, focal, force=False):
        param = self.im_focals[idx]
        if param.requires_grad or force:  # can only init a parameter not already initialized
            param.data[:] = self.focal_break * np.log(focal)
        return param

    def get_focals(self):
        log_focals = torch.stack(list(self.im_focals), dim=0)
        return (log_focals / self.focal_break).exp()

    def get_known_focal_mask(self):
        return torch.tensor([not (p.requires_grad) for p in self.im_focals])

    def _set_principal_point(self, idx, pp, force=False):
        param = self.im_pp[idx]
        H, W = self.imshapes[idx]
        if param.requires_grad or force:  # can only init a parameter not already initialized
            param.data[:] = to_cpu(to_numpy(pp) - (W/2, H/2)) / 10
        return param

    def get_principal_points(self):
        return self._pp + 10 * self.im_pp

    def get_intrinsics(self):
        K = torch.zeros((self.n_imgs, 3, 3), device=self.device)
        focals = self.get_focals().flatten()
        K[:, 0, 0] = K[:, 1, 1] = focals
        K[:, :2, 2] = self.get_principal_points()
        K[:, 2, 2] = 1
        return K

    def get_im_poses(self):  # cam to world
        cam2world = self._get_poses(self.im_poses)
        return cam2world

    def _set_depthmap(self, idx, depth, force=False):
        depth = _ravel_hw(depth, self.max_area)

        param = self.im_depthmaps[idx]
        if param.requires_grad or force:  # can only init a parameter not already initialized
            param.data[:] = depth.log().nan_to_num(neginf=0)
        return param

    def get_depthmaps(self, raw=False):
        res = self.im_depthmaps.exp()
        if not raw:
            res = [dm[:h*w].view(h, w) for dm, (h, w) in zip(res, self.imshapes)]
        return res

    def depth_to_pts3d(self):
        # get depths and projection params from the current optimizer state
        focals = self.get_focals()
        pp = self.get_principal_points()
        im_poses = self.get_im_poses()
        depth = self.get_depthmaps(raw=True)

        # get pointmaps in camera frame
        rel_ptmaps = _fast_depthmap_to_pts3d(depth, self._grid, focals, pp=pp)
        # project to world frame
        return geotrf(im_poses, rel_ptmaps)

    def get_pts3d(self, raw=False):
        res = self.depth_to_pts3d()
        if not raw:
            res = [dm[:h*w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)]
        return res

    def forward(self):
        pw_poses = self.get_pw_poses()  # cam-to-world
        pw_adapt = self.get_adaptors().unsqueeze(1)
        proj_pts3d = self.get_pts3d(raw=True)

        # rotate pairwise prediction according to pw_poses
        aligned_pred_i = geotrf(pw_poses, pw_adapt * self._stacked_pred_i)
        aligned_pred_j = geotrf(pw_poses, pw_adapt * self._stacked_pred_j)

        # compute the loss
        li = self.dist(proj_pts3d[self._ei], aligned_pred_i, weight=self._weight_i).sum() / self.total_area_i
        lj = self.dist(proj_pts3d[self._ej], aligned_pred_j, weight=self._weight_j).sum() / self.total_area_j

        return li + lj


def _fast_depthmap_to_pts3d(depth, pixel_grid, focal, pp):
    pp = pp.unsqueeze(1)
    focal = focal.unsqueeze(1)
    assert focal.shape == (len(depth), 1, 1)
    assert pp.shape == (len(depth), 1, 2)
    assert pixel_grid.shape == depth.shape + (2,)
    depth = depth.unsqueeze(-1)
    return torch.cat((depth * (pixel_grid - pp) / focal, depth), dim=-1)
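
# Editorial sketch (not part of the repository): the pinhole back-projection performed by
# `_fast_depthmap_to_pts3d`, run on a toy 2x2 image so the shapes are easy to follow.
# The name `backproject` and the way the toy pixel grid is built are assumptions made for
# illustration; the grid layout (x first, row-major flattening) is assumed, not taken
# from `xy_grid`.

```python
import torch


def backproject(depth, pixel_grid, focal, pp):
    """depth: (B, N), pixel_grid: (B, N, 2), focal: (B, 1), pp: (B, 2) -> (B, N, 3)."""
    pp = pp.unsqueeze(1)          # (B, 1, 2), broadcast over the N pixels
    focal = focal.unsqueeze(1)    # (B, 1, 1)
    depth = depth.unsqueeze(-1)   # (B, N, 1)
    # x = z * (u - cx) / f and y = z * (v - cy) / f, with z the depth
    xy = depth * (pixel_grid - pp) / focal
    return torch.cat((xy, depth), dim=-1)


# toy 2x2 image flattened to N = 4 pixels
H, W = 2, 2
ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
grid = torch.stack((xs, ys), dim=-1).reshape(1, H * W, 2).float()

depth = torch.full((1, H * W), 2.0)
focal = torch.tensor([[2.0]])
pp = torch.tensor([[0.0, 0.0]])

pts = backproject(depth, grid, focal, pp)  # (1, 4, 3) camera-frame points
```

# With depth equal to the focal length and the principal point at the origin, the x and y
# coordinates of each point equal its pixel coordinates, which makes the result easy to check.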


def ParameterStack(params, keys=None, is_param=None, fill=0):
    if keys is not None:
        params = [params[k] for k in keys]

    if fill > 0:
        params = [_ravel_hw(p, fill) for p in params]

    requires_grad = params[0].requires_grad
    assert all(p.requires_grad == requires_grad for p in params)

    params = torch.stack(list(params)).float().detach()
    if is_param or requires_grad:
        params = nn.Parameter(params)
        params.requires_grad_(requires_grad)
    return params


def _ravel_hw(tensor, fill=0):
    # ravel H,W
    tensor = tensor.view((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])

    if len(tensor) < fill:
        tensor = torch.cat((tensor, tensor.new_zeros((fill - len(tensor),)+tensor.shape[1:])))
    return tensor
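
# Editorial sketch (not part of the repository): how flattening and zero-padding lets maps
# of different sizes share one stacked tensor, and how the padding is removed again, as
# `get_depthmaps(raw=False)` does. `ravel_and_pad` is a hypothetical stand-in for
# `_ravel_hw`; it uses `reshape` instead of `view` so it also accepts non-contiguous input.

```python
import torch


def ravel_and_pad(tensor, fill=0):
    # merge the two leading (H, W) dims into one, keep any trailing dims
    flat = tensor.reshape((tensor.shape[0] * tensor.shape[1],) + tuple(tensor.shape[2:]))
    if len(flat) < fill:
        # append zeros so every map ends up with exactly `fill` entries
        pad = flat.new_zeros((fill - len(flat),) + tuple(flat.shape[1:]))
        flat = torch.cat((flat, pad))
    return flat


shapes = [(2, 3), (2, 2)]
max_area = max(h * w for h, w in shapes)
maps = [torch.arange(h * w, dtype=torch.float32).view(h, w) + 1 for h, w in shapes]

stacked = torch.stack([ravel_and_pad(m, max_area) for m in maps])  # shape (2, max_area)

# drop the padding and restore each (H, W) shape
restored = [row[:h * w].view(h, w) for row, (h, w) in zip(stacked, shapes)]
```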


def acceptable_focal_range(H, W, minf=0.5, maxf=3.5):
    focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2))  # size / 1.1547005383792515
    return minf*focal_base, maxf*focal_base
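
# Editorial sketch (not part of the repository): the pinhole relation behind `focal_base`
# in `acceptable_focal_range`, evaluated for one example image size. `focal_from_fov` and
# the 384x512 size are hypothetical choices for illustration; the 60 degree field of view
# and the 0.5 / 3.5 factors are the defaults shown above.

```python
import numpy as np


def focal_from_fov(size, fov_deg):
    # pinhole relation: size / 2 = focal * tan(fov / 2)
    return size / (2 * np.tan(np.deg2rad(fov_deg) / 2))


H, W = 384, 512
focal_base = focal_from_fov(max(H, W), 60)
lo, hi = 0.5 * focal_base, 3.5 * focal_base

# field of view (along the longer image side) at each end of the accepted focal range
fov_at_lo = np.rad2deg(2 * np.arctan(max(H, W) / (2 * lo)))
fov_at_hi = np.rad2deg(2 * np.arctan(max(H, W) / (2 * hi)))
```

# The shorter focal bound corresponds to the wider field of view and the longer bound to
# the narrower one.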


def apply_mask(img, msk):
    img = img.copy()
    img[msk] = 0
    return img
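
# Editorial sketch (not part of the repository): the focal parametrization used by
# `_set_focal` and `get_focals` in `PointCloudOptimizer`, where the stored value is
# focal_break * log(focal). `encode_focal` and `decode_focal` are hypothetical names;
# FOCAL_BREAK mirrors the default `focal_break=20` of `__init__`.

```python
import numpy as np
import torch

FOCAL_BREAK = 20  # default value of focal_break in __init__ above


def encode_focal(focal, focal_break=FOCAL_BREAK):
    # value written into the parameter by _set_focal
    return focal_break * np.log(focal)


def decode_focal(param, focal_break=FOCAL_BREAK):
    # value returned by get_focals
    return (param / focal_break).exp()


param = torch.tensor([encode_focal(500.0)], dtype=torch.float32)
focal = decode_focal(param)

# a step of size `delta` on the stored parameter multiplies the focal by exp(delta / focal_break)
delta = 1.0
ratio = decode_focal(param + delta) / focal
```

# A larger `focal_break` therefore makes the focal length react less to a parameter step
# of a given size. The initial value focal_break * log(max(H, W)) set in `__init__`
# decodes to a focal of max(H, W) pixels.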


================================================
FILE: dust3r/cloud_opt/pair_viewer.py
================================================
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Dummy optimizer for visualizing pairs
# --------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import cv2

from dust3r.cloud_opt.base_opt import BasePCOptimizer
from dust3r.utils.geometry import inv, geotrf, depthmap_to_absolute_camera_coordinates
from dust3r.cloud_opt.commons import edge_str
from dust3r.post_process import estimate_focal_knowing_depth


class PairViewer (BasePCOptimizer):
    """
    This is a dummy optimizer.
    Use it only when the goal is to visualize the results for a pair of images (with is_symmetrized).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.is_s
SYMBOL INDEX (512 symbols across 64 files)

FILE: datasets_preprocess/habitat/find_scenes.py
  function find_all_scenes (line 16) | def find_all_scenes(habitat_root, n_scenes=[100000]):

FILE: datasets_preprocess/habitat/habitat_renderer/habitat_sim_envmaps_renderer.py
  class NoNaviguableSpaceError (line 25) | class NoNaviguableSpaceError(RuntimeError):
    method __init__ (line 26) | def __init__(self, *args):
  class HabitatEnvironmentMapRenderer (line 29) | class HabitatEnvironmentMapRenderer:
    method __init__ (line 30) | def __init__(self,
    method _lazy_initialization (line 64) | def _lazy_initialization(self):
    method close (line 141) | def close(self):
    method __del__ (line 145) | def __del__(self):
    method render_viewpoint (line 148) | def render_viewpoint(self, viewpoint_position):
    method up_direction (line 166) | def up_direction(self):
    method R_cam_to_world (line 169) | def R_cam_to_world(self):

FILE: datasets_preprocess/habitat/habitat_renderer/multiview_crop_generator.py
  class HabitatMultiviewCrops (line 17) | class HabitatMultiviewCrops:
    method __init__ (line 18) | def __init__(self,
    method compute_pointmap (line 48) | def compute_pointmap(self, distancemap, position):
    method render_viewpoint_data (line 52) | def render_viewpoint_data(self, position):
    method extract_cropped_camera (line 59) | def extract_cropped_camera(self, projection, color_image, distancemap,...
  function perspective_projection_to_dict (line 76) | def perspective_projection_to_dict(persp_projection, position):
  function dict_to_perspective_projection (line 86) | def dict_to_perspective_projection(camera_params):

FILE: datasets_preprocess/habitat/habitat_renderer/projections.py
  class EquirectangularProjection (line 9) | class EquirectangularProjection:
    method __init__ (line 18) | def __init__(self, height, width):
    method unproject (line 24) | def unproject(self, u, v):
    method project (line 41) | def project(self, rays):
  class PerspectiveProjection (line 59) | class PerspectiveProjection:
    method __init__ (line 72) | def __init__(self, K, height, width):
    method project (line 78) | def project(self, rays):
    method unproject (line 83) | def unproject(self, u, v):
  class RotatedProjection (line 89) | class RotatedProjection:
    method __init__ (line 90) | def __init__(self, base_projection, R_to_base_projection):
    method width (line 95) | def width(self):
    method height (line 99) | def height(self):
    method project (line 102) | def project(self, rays):
    method unproject (line 107) | def unproject(self, u, v):
  function get_projection_rays (line 113) | def get_projection_rays(projection, noise_level=0):
  function compute_camera_intrinsics (line 124) | def compute_camera_intrinsics(height, width, hfov):
  function colmap_to_opencv_intrinsics (line 129) | def colmap_to_opencv_intrinsics(K):
  function opencv_to_colmap_intrinsics (line 141) | def opencv_to_colmap_intrinsics(K):

FILE: datasets_preprocess/habitat/habitat_renderer/projections_conversions.py
  class RemapProjection (line 11) | class RemapProjection:
    method __init__ (line 12) | def __init__(self, input_projection, output_projection, pixel_jitterin...
    method convert (line 34) | def convert(self, img, interpolation=cv2.INTER_LINEAR, borderMode=cv2....

FILE: datasets_preprocess/habitat/preprocess_habitat.py
  function preprocess_metadata (line 26) | def preprocess_metadata(metadata_filename,

FILE: datasets_preprocess/preprocess_arkitscenes.py
  function get_parser (line 24) | def get_parser():
  function value_to_decimal (line 32) | def value_to_decimal(value, decimal_places):
  function closest (line 37) | def closest(value, sorted_list):
  function get_up_vectors (line 52) | def get_up_vectors(pose_device_to_world):
  function get_right_vectors (line 56) | def get_right_vectors(pose_device_to_world):
  function read_traj (line 60) | def read_traj(traj_path):
  function main (line 93) | def main(rootdir, pairsdir, outdir):
  function convert_scene_metadata (line 269) | def convert_scene_metadata(scene_subdir, intrinsics_dir,
  function find_scene_orientation (line 308) | def find_scene_orientation(poses_cam_to_world):

FILE: datasets_preprocess/preprocess_blendedMVS.py
  function get_parser (line 27) | def get_parser():
  function main (line 36) | def main(db_root, pairs_path, output_dir):
  function load_crop_and_save (line 63) | def load_crop_and_save(root, img, out_dir):
  function _crop_image (line 89) | def _crop_image(intrinsics_in, color_image_in, depthmap_in, resolution_o...
  function _load_pose (line 96) | def _load_pose(path, ret_44=False):
  function load_pfm_file (line 110) | def load_pfm_file(file_path):

FILE: datasets_preprocess/preprocess_co3d.py
  function get_parser (line 47) | def get_parser():
  function convert_ndc_to_pinhole (line 63) | def convert_ndc_to_pinhole(focal_length, principal_point, image_size):
  function opencv_from_cameras_projection (line 77) | def opencv_from_cameras_projection(R, T, focal, p0, image_size):
  function get_set_list (line 112) | def get_set_list(category_dir, split, is_single_sequence_subset=False):
  function prepare_sequences (line 129) | def prepare_sequences(category, co3d_dir, output_dir, img_size, split, m...

FILE: datasets_preprocess/preprocess_megadepth.py
  function get_parser (line 23) | def get_parser():
  function main (line 32) | def main(db_root, pairs_path, output_dir):
  function resize_one_image (line 65) | def resize_one_image(root, tag, K_pre_rectif, pose_w2cam, out_dir):
  function _downscale_image (line 95) | def _downscale_image(camera_intrinsics, image, depthmap, resolution_out=...
  function _load_kpts_and_poses (line 106) | def _load_kpts_and_poses(root, scene_id, subscene, z_only=False, intrins...
  function colmap_raw_pose_to_principal_axis (line 153) | def colmap_raw_pose_to_principal_axis(image_pose):
  function colmap_raw_pose_to_RT (line 165) | def colmap_raw_pose_to_RT(image_pose):

FILE: datasets_preprocess/preprocess_scannetpp.py
  function get_parser (line 43) | def get_parser():
  function pose_from_qwxyz_txyz (line 53) | def pose_from_qwxyz_txyz(elems):
  function get_frame_number (line 61) | def get_frame_number(name, cam_type='dslr'):
  function load_sfm (line 76) | def load_sfm(sfm_dir, cam_type='dslr'):
  function subsample_img_infos (line 134) | def subsample_img_infos(img_infos, num_images, allowed_name_subset=None):
  function undistort_images (line 146) | def undistort_images(intrinsics, rgb, mask):
  function process_scenes (line 191) | def process_scenes(root, pairsdir, output_dir, target_resolution):

FILE: datasets_preprocess/preprocess_staticthings3d.py
  function get_parser (line 27) | def get_parser():
  function main (line 36) | def main(db_root, pairs_path, output_dir):
  function load_crop_and_save (line 57) | def load_crop_and_save(db_root, relpath_, camera, num, out_dir):
  function _crop_image (line 88) | def _crop_image(intrinsics_in, color_image_in, depthmap_in, resolution_o...
  function _list_all_scenes (line 95) | def _list_all_scenes(path):
  function readFloat (line 108) | def readFloat(name):

FILE: datasets_preprocess/preprocess_waymo.py
  function get_parser (line 37) | def get_parser():
  function main (line 47) | def main(waymo_root, pairs_path, output_dir, workers=1):
  function _list_sequences (line 66) | def _list_sequences(db_root):
  function extract_frames (line 73) | def extract_frames(db_root, output_dir, workers=8):
  function process_one_seq (line 81) | def process_one_seq(db_root, output_dir, seq):
  function extract_frames_one_seq (line 105) | def extract_frames_one_seq(filename):
  function make_crops (line 170) | def make_crops(output_dir, workers=16, **kw):
  function crop_one_seq (line 177) | def crop_one_seq(input_dir, output_dir, seq, resolution=512):

FILE: datasets_preprocess/preprocess_wildrgbd.py
  function get_parser (line 29) | def get_parser():
  function get_set_list (line 43) | def get_set_list(category_dir, split):
  function prepare_sequences (line 60) | def prepare_sequences(category, wildrgbd_dir, output_dir, img_size, spli...

FILE: dust3r/cloud_opt/__init__.py
  class GlobalAlignerMode (line 14) | class GlobalAlignerMode(Enum):
  function global_aligner (line 20) | def global_aligner(dust3r_output, device, mode=GlobalAlignerMode.PointCl...

FILE: dust3r/cloud_opt/base_opt.py
  class BasePCOptimizer (line 27) | class BasePCOptimizer (nn.Module):
    method __init__ (line 33) | def __init__(self, *args, **kwargs):
    method _init_from_views (line 44) | def _init_from_views(self, view1, view2, pred1, pred2,
    method n_edges (line 108) | def n_edges(self):
    method str_edges (line 112) | def str_edges(self):
    method imsizes (line 116) | def imsizes(self):
    method device (line 120) | def device(self):
    method state_dict (line 123) | def state_dict(self, trainable=True):
    method load_state_dict (line 127) | def load_state_dict(self, data):
    method _check_edges (line 130) | def _check_edges(self):
    method _compute_img_conf (line 136) | def _compute_img_conf(self, pred1_conf, pred2_conf):
    method get_adaptors (line 143) | def get_adaptors(self):
    method _get_poses (line 150) | def _get_poses(self, poses):
    method _set_pose (line 157) | def _set_pose(self, poses, idx, R, T=None, scale=None, force=False):
    method get_pw_norm_scale_factor (line 178) | def get_pw_norm_scale_factor(self):
    method get_pw_scale (line 186) | def get_pw_scale(self):
    method get_pw_poses (line 191) | def get_pw_poses(self):  # cam to world
    method get_masks (line 197) | def get_masks(self):
    method depth_to_pts3d (line 200) | def depth_to_pts3d(self):
    method get_pts3d (line 203) | def get_pts3d(self, raw=False):
    method _set_focal (line 209) | def _set_focal(self, idx, focal, force=False):
    method get_focals (line 212) | def get_focals(self):
    method get_known_focal_mask (line 215) | def get_known_focal_mask(self):
    method get_principal_points (line 218) | def get_principal_points(self):
    method get_conf (line 221) | def get_conf(self, mode=None):
    method get_im_poses (line 225) | def get_im_poses(self):
    method _set_depthmap (line 228) | def _set_depthmap(self, idx, depth, force=False):
    method get_depthmaps (line 231) | def get_depthmaps(self, raw=False):
    method clean_pointcloud (line 234) | def clean_pointcloud(self, **kw):
    method forward (line 246) | def forward(self, ret_details=False):
    method compute_global_alignment (line 276) | def compute_global_alignment(self, init=None, niter_PnP=10, **kw):
    method mask_sky (line 290) | def mask_sky(self):
    method show (line 297) | def show(self, show_pw_cams=False, show_pw_pts3d=False, cam_size=None,...
  function global_alignment_loop (line 326) | def global_alignment_loop(net, lr=0.01, niter=300, schedule='cosine', lr...
  function global_alignment_iter (line 352) | def global_alignment_iter(net, cur_iter, niter, lr_base, lr_min, optimiz...
  function clean_pointcloud (line 370) | def clean_pointcloud( im_confs, K, cams, depthmaps, all_pts3d,

FILE: dust3r/cloud_opt/commons.py
  function edge_str (line 12) | def edge_str(i, j):
  function i_j_ij (line 16) | def i_j_ij(ij):
  function edge_conf (line 20) | def edge_conf(conf_i, conf_j, edge):
  function compute_edge_scores (line 24) | def compute_edge_scores(edges, conf_i, conf_j):
  function NoGradParamDict (line 28) | def NoGradParamDict(x):
  function get_imshapes (line 33) | def get_imshapes(edges, pred_i, pred_j):
  function get_conf_trf (line 48) | def get_conf_trf(mode):
  function l2_dist (line 62) | def l2_dist(a, b, weight):
  function l1_dist (line 66) | def l1_dist(a, b, weight):
  function signed_log1p (line 73) | def signed_log1p(x):
  function signed_expm1 (line 78) | def signed_expm1(x):
  function cosine_schedule (line 83) | def cosine_schedule(t, lr_start, lr_end):
  function linear_schedule (line 88) | def linear_schedule(t, lr_start, lr_end):

FILE: dust3r/cloud_opt/init_im_poses.py
  function init_from_known_poses (line 24) | def init_from_known_poses(self, niter_PnP=10, min_conf_thr=3):
  function init_minimum_spanning_tree (line 67) | def init_minimum_spanning_tree(self, **kw):
  function init_from_pts3d (line 80) | def init_from_pts3d(self, pts3d, im_focals, im_poses):
  function minimum_spanning_tree (line 123) | def minimum_spanning_tree(imshapes, edges, pred_i, pred_j, conf_i, conf_...
  function dict_to_sparse_graph (line 212) | def dict_to_sparse_graph(dic):
  function rigid_points_registration (line 220) | def rigid_points_registration(pts1, pts2, conf):
  function sRT_to_4x4 (line 226) | def sRT_to_4x4(scale, R, T, device):
  function estimate_focal (line 233) | def estimate_focal(pts3d_i, pp=None):
  function pixel_grid (line 243) | def pixel_grid(H, W):
  function fast_pnp (line 247) | def fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10):
  function get_known_poses (line 290) | def get_known_poses(self):
  function get_known_focals (line 299) | def get_known_focals(self):
  function align_multiple_poses (line 308) | def align_multiple_poses(src_poses, target_poses):

FILE: dust3r/cloud_opt/modular_optimizer.py
  class ModularPointCloudOptimizer (line 17) | class ModularPointCloudOptimizer (BasePCOptimizer):
    method __init__ (line 24) | def __init__(self, *args, optimize_pp=False, fx_and_fy=False, focal_br...
    method preset_pose (line 38) | def preset_pose(self, known_poses, pose_msk=None):  # cam-to-world
    method preset_intrinsics (line 50) | def preset_intrinsics(self, known_intrinsics, msk=None):
    method preset_focal (line 58) | def preset_focal(self, known_focals, msk=None):
    method preset_principal_point (line 64) | def preset_principal_point(self, known_pp, msk=None):
    method _no_grad (line 70) | def _no_grad(self, tensor):
    method _get_msk_indices (line 73) | def _get_msk_indices(self, msk):
    method _set_focal (line 88) | def _set_focal(self, idx, focal, force=False):
    method get_focals (line 94) | def get_focals(self):
    method _set_principal_point (line 98) | def _set_principal_point(self, idx, pp, force=False):
    method get_principal_points (line 105) | def get_principal_points(self):
    method get_intrinsics (line 108) | def get_intrinsics(self):
    method get_im_poses (line 117) | def get_im_poses(self):  # cam to world
    method _set_depthmap (line 121) | def _set_depthmap(self, idx, depth, force=False):
    method get_depthmaps (line 127) | def get_depthmaps(self):
    method depth_to_pts3d (line 130) | def depth_to_pts3d(self):
    method get_pts3d (line 144) | def get_pts3d(self):

FILE: dust3r/cloud_opt/optimizer.py
  class PointCloudOptimizer (line 16) | class PointCloudOptimizer(BasePCOptimizer):
    method __init__ (line 22) | def __init__(self, *args, optimize_pp=False, focal_break=20, **kwargs):
    method _check_all_imgs_are_selected (line 63) | def _check_all_imgs_are_selected(self, msk):
    method preset_pose (line 66) | def preset_pose(self, known_poses, pose_msk=None):  # cam-to-world
    method preset_focal (line 83) | def preset_focal(self, known_focals, msk=None):
    method preset_principal_point (line 93) | def preset_principal_point(self, known_pp, msk=None):
    method _get_msk_indices (line 103) | def _get_msk_indices(self, msk):
    method _no_grad (line 118) | def _no_grad(self, tensor):
    method _set_focal (line 121) | def _set_focal(self, idx, focal, force=False):
    method get_focals (line 127) | def get_focals(self):
    method get_known_focal_mask (line 131) | def get_known_focal_mask(self):
    method _set_principal_point (line 134) | def _set_principal_point(self, idx, pp, force=False):
    method get_principal_points (line 141) | def get_principal_points(self):
    method get_intrinsics (line 144) | def get_intrinsics(self):
    method get_im_poses (line 152) | def get_im_poses(self):  # cam to world
    method _set_depthmap (line 156) | def _set_depthmap(self, idx, depth, force=False):
    method get_depthmaps (line 164) | def get_depthmaps(self, raw=False):
    method depth_to_pts3d (line 170) | def depth_to_pts3d(self):
    method get_pts3d (line 182) | def get_pts3d(self, raw=False):
    method forward (line 188) | def forward(self):
  function _fast_depthmap_to_pts3d (line 204) | def _fast_depthmap_to_pts3d(depth, pixel_grid, focal, pp):
  function ParameterStack (line 214) | def ParameterStack(params, keys=None, is_param=None, fill=0):
  function _ravel_hw (line 231) | def _ravel_hw(tensor, fill=0):
  function acceptable_focal_range (line 240) | def acceptable_focal_range(H, W, minf=0.5, maxf=3.5):
  function apply_mask (line 245) | def apply_mask(img, msk):

FILE: dust3r/cloud_opt/pair_viewer.py
  class PairViewer (line 18) | class PairViewer (BasePCOptimizer):
    method __init__ (line 24) | def __init__(self, *args, **kwargs):
    method _set_depthmap (line 83) | def _set_depthmap(self, idx, depth, force=False):
    method get_depthmaps (line 88) | def get_depthmaps(self, raw=False):
    method _set_focal (line 92) | def _set_focal(self, idx, focal, force=False):
    method get_focals (line 95) | def get_focals(self):
    method get_known_focal_mask (line 98) | def get_known_focal_mask(self):
    method get_principal_points (line 101) | def get_principal_points(self):
    method get_intrinsics (line 104) | def get_intrinsics(self):
    method get_im_poses (line 114) | def get_im_poses(self):
    method depth_to_pts3d (line 117) | def depth_to_pts3d(self):
    method forward (line 126) | def forward(self):

FILE: dust3r/datasets/__init__.py
  function get_data_loader (line 16) | def get_data_loader(dataset, batch_size, num_workers=8, shuffle=True, dr...

FILE: dust3r/datasets/arkitscenes.py
  class ARKitScenes (line 17) | class ARKitScenes(BaseStereoViewDataset):
    method __init__ (line 18) | def __init__(self, *args, split, ROOT, **kwargs):
    method _load_data (line 30) | def _load_data(self, split):
    method __len__ (line 39) | def __len__(self):
    method _get_views (line 42) | def _get_views(self, idx, resolution, rng):

FILE: dust3r/datasets/base/base_stereo_view_dataset.py
  class BaseStereoViewDataset (line 17) | class BaseStereoViewDataset (EasyDataset):
    method __init__ (line 29) | def __init__(self, *,  # only keyword arguments
    method __len__ (line 46) | def __len__(self):
    method get_stats (line 49) | def get_stats(self):
    method __repr__ (line 52) | def __repr__(self):
    method _get_views (line 60) | def _get_views(self, idx, resolution, rng):
    method __getitem__ (line 63) | def __getitem__(self, idx):
    method _set_resolutions (line 120) | def _set_resolutions(self, resolutions):
    method _crop_resize_if_necessary (line 137) | def _crop_resize_if_necessary(self, image, depthmap, intrinsics, resol...
  function is_good_type (line 184) | def is_good_type(key, v):
  function view_name (line 194) | def view_name(view, batch_index=None):
  function transpose_to_landscape (line 202) | def transpose_to_landscape(view):

FILE: dust3r/datasets/base/batched_sampler.py
  class BatchedRandomSampler (line 11) | class BatchedRandomSampler:
    method __init__ (line 21) | def __init__(self, dataset, batch_size, pool_size, world_size=1, rank=...
    method __len__ (line 34) | def __len__(self):
    method set_epoch (line 37) | def set_epoch(self, epoch):
    method __iter__ (line 40) | def __iter__(self):
  function round_by (line 71) | def round_by(total, multiple, up=False):

FILE: dust3r/datasets/base/easy_dataset.py
  class EasyDataset (line 11) | class EasyDataset:
    method __add__ (line 22) | def __add__(self, other):
    method __rmul__ (line 25) | def __rmul__(self, factor):
    method __rmatmul__ (line 28) | def __rmatmul__(self, factor):
    method set_epoch (line 31) | def set_epoch(self, epoch):
    method make_sampler (line 34) | def make_sampler(self, batch_size, shuffle=True, world_size=1, rank=0,...
  class MulDataset (line 41) | class MulDataset (EasyDataset):
    method __init__ (line 46) | def __init__(self, multiplicator, dataset):
    method __len__ (line 51) | def __len__(self):
    method __repr__ (line 54) | def __repr__(self):
    method __getitem__ (line 57) | def __getitem__(self, idx):
    method _resolutions (line 65) | def _resolutions(self):
  class ResizedDataset (line 69) | class ResizedDataset (EasyDataset):
    method __init__ (line 74) | def __init__(self, new_size, dataset):
    method __len__ (line 79) | def __len__(self):
    method __repr__ (line 82) | def __repr__(self):
    method set_epoch (line 89) | def set_epoch(self, epoch):
    method __getitem__ (line 102) | def __getitem__(self, idx):
    method _resolutions (line 111) | def _resolutions(self):
  class CatDataset (line 115) | class CatDataset (EasyDataset):
    method __init__ (line 119) | def __init__(self, datasets):
    method __len__ (line 125) | def __len__(self):
    method __repr__ (line 128) | def __repr__(self):
    method set_epoch (line 132) | def set_epoch(self, epoch):
    method __getitem__ (line 136) | def __getitem__(self, idx):
    method _resolutions (line 153) | def _resolutions(self):

FILE: dust3r/datasets/blendedmvs.py
  class BlendedMVS (line 16) | class BlendedMVS (BaseStereoViewDataset):
    method __init__ (line 20) | def __init__(self, *args, ROOT, split=None, **kwargs):
    method _load_data (line 25) | def _load_data(self, split):
    method __len__ (line 40) | def __len__(self):
    method get_stats (line 43) | def get_stats(self):
    method _get_views (line 46) | def _get_views(self, pair_idx, resolution, rng):

FILE: dust3r/datasets/co3d.py
  class Co3d (line 21) | class Co3d(BaseStereoViewDataset):
    method __init__ (line 22) | def __init__(self, mask_bg=True, *args, ROOT, **kwargs):
    method __len__ (line 45) | def __len__(self):
    method _get_metadatapath (line 48) | def _get_metadatapath(self, obj, instance, view_idx):
    method _get_impath (line 51) | def _get_impath(self, obj, instance, view_idx):
    method _get_depthpath (line 54) | def _get_depthpath(self, obj, instance, view_idx):
    method _get_maskpath (line 57) | def _get_maskpath(self, obj, instance, view_idx):
    method _read_depthmap (line 60) | def _read_depthmap(self, depthpath, input_metadata):
    method _get_views (line 65) | def _get_views(self, idx, resolution, rng):

FILE: dust3r/datasets/habitat.py
  class Habitat (line 20) | class Habitat(BaseStereoViewDataset):
    method __init__ (line 21) | def __init__(self, size, *args, ROOT, **kwargs):
    method filter_scene (line 30) | def filter_scene(self, label, instance=None):
    method _get_views (line 39) | def _get_views(self, idx, resolution, rng):
    method _load_one_view (line 60) | def _load_one_view(self, data_path, key, view_index, resolution, rng):

FILE: dust3r/datasets/megadepth.py
  class MegaDepth (line 16) | class MegaDepth(BaseStereoViewDataset):
    method __init__ (line 17) | def __init__(self, *args, split, ROOT, **kwargs):
    method _load_data (line 31) | def _load_data(self, split):
    method __len__ (line 37) | def __len__(self):
    method get_stats (line 40) | def get_stats(self):
    method select_scene (line 43) | def select_scene(self, scene, *instances, opposite=False):
    method _get_views (line 64) | def _get_views(self, pair_idx, resolution, rng):

FILE: dust3r/datasets/scannetpp.py
  class ScanNetpp (line 18) | class ScanNetpp(BaseStereoViewDataset):
    method __init__ (line 19) | def __init__(self, *args, ROOT, **kwargs):
    method _load_data (line 25) | def _load_data(self):
    method __len__ (line 34) | def __len__(self):
    method _get_views (line 37) | def _get_views(self, idx, resolution, rng):

FILE: dust3r/datasets/staticthings3d.py
  class StaticThings3D (line 16) | class StaticThings3D (BaseStereoViewDataset):
    method __init__ (line 19) | def __init__(self, ROOT, *args, mask_bg='rand', **kwargs):
    method __len__ (line 30) | def __len__(self):
    method get_stats (line 33) | def get_stats(self):
    method _get_views (line 36) | def _get_views(self, pair_idx, resolution, rng):

FILE: dust3r/datasets/utils/cropping.py
  class ImageList (line 21) | class ImageList:
    method __init__ (line 25) | def __init__(self, images):
    method __len__ (line 34) | def __len__(self):
    method to_pil (line 37) | def to_pil(self):
    method size (line 41) | def size(self):
    method resize (line 46) | def resize(self, *args, **kwargs):
    method crop (line 49) | def crop(self, *args, **kwargs):
    method _dispatch (line 52) | def _dispatch(self, func, *args, **kwargs):
  function rescale_image_depthmap (line 56) | def rescale_image_depthmap(image, depthmap, camera_intrinsics, output_re...
  function camera_matrix_of_crop (line 87) | def camera_matrix_of_crop(input_camera_matrix, input_resolution, output_...
  function crop_image_depthmap (line 103) | def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox):
  function bbox_from_intrinsics_in_out (line 120) | def bbox_from_intrinsics_in_out(input_camera_matrix, output_camera_matri...

FILE: dust3r/datasets/waymo.py
  class Waymo (line 16) | class Waymo (BaseStereoViewDataset):
    method __init__ (line 20) | def __init__(self, *args, ROOT, **kwargs):
    method _load_data (line 25) | def _load_data(self):
    method __len__ (line 33) | def __len__(self):
    method get_stats (line 36) | def get_stats(self):
    method _get_views (line 39) | def _get_views(self, pair_idx, resolution, rng):

FILE: dust3r/datasets/wildrgbd.py
  class WildRGBD (line 18) | class WildRGBD(Co3d):
    method __init__ (line 19) | def __init__(self, mask_bg=True, *args, ROOT, **kwargs):
    method _get_metadatapath (line 23) | def _get_metadatapath(self, obj, instance, view_idx):
    method _get_impath (line 26) | def _get_impath(self, obj, instance, view_idx):
    method _get_depthpath (line 29) | def _get_depthpath(self, obj, instance, view_idx):
    method _get_maskpath (line 32) | def _get_maskpath(self, obj, instance, view_idx):
    method _read_depthmap (line 35) | def _read_depthmap(self, depthpath, input_metadata):

FILE: dust3r/demo.py
  function get_args_parser (line 30) | def get_args_parser():
  function set_print_with_timestamp (line 53) | def set_print_with_timestamp(time_format="%Y-%m-%d %H:%M:%S"):
  function _convert_scene_output_to_glb (line 66) | def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams...
  function get_3D_model_from_scene (line 110) | def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=3, as_po...
  function get_reconstructed_scene (line 135) | def get_reconstructed_scene(outdir, model, device, silent, image_size, f...
  function set_scenegraph_options (line 189) | def set_scenegraph_options(inputfiles, winsize, refid, scenegraph_type):
  function main_demo (line 210) | def main_demo(tmpdirname, model, device, image_size, server_name, server...

FILE: dust3r/heads/__init__.py
  function head_factory (line 11) | def head_factory(head_type, output_mode, net, has_conf=False):

FILE: dust3r/heads/dpt_head.py
  class DPTOutputAdapter_fix (line 20) | class DPTOutputAdapter_fix(DPTOutputAdapter):
    method init (line 26) | def init(self, dim_tokens_enc=768):
    method forward (line 34) | def forward(self, encoder_tokens: List[torch.Tensor], image_size=None):
  class PixelwiseTaskWithDPT (line 68) | class PixelwiseTaskWithDPT(nn.Module):
    method __init__ (line 71) | def __init__(self, *, n_cls_token=0, hooks_idx=None, dim_tokens=None,
    method forward (line 89) | def forward(self, x, img_info):
  function create_dpt_head (line 96) | def create_dpt_head(net, has_conf=False):

FILE: dust3r/heads/linear_head.py
  class LinearPts3d (line 12) | class LinearPts3d (nn.Module):
    method __init__ (line 18) | def __init__(self, net, has_conf=False):
    method setup (line 27) | def setup(self, croconet):
    method forward (line 30) | def forward(self, decout, img_shape):

FILE: dust3r/heads/postprocess.py
  function postprocess (line 10) | def postprocess(out, depth_mode, conf_mode):
  function reg_dense_depth (line 22) | def reg_dense_depth(xyz, mode):
  function reg_dense_conf (line 49) | def reg_dense_conf(x, mode):

FILE: dust3r/image_pairs.py
  function make_pairs (line 11) | def make_pairs(imgs, scene_graph='complete', prefilter=None, symmetrize=...
  function sel (line 71) | def sel(x, kept):
  function _filter_edges_seq (line 80) | def _filter_edges_seq(edges, seq_dis_thr, cyclic=False):
  function filter_pairs_seq (line 94) | def filter_pairs_seq(pairs, seq_dis_thr, cyclic=False):
  function filter_edges_seq (line 100) | def filter_edges_seq(view1, view2, pred1, pred2, seq_dis_thr, cyclic=Fal...

FILE: dust3r/inference.py
  function _interleave_imgs (line 14) | def _interleave_imgs(img1, img2):
  function make_batch_symmetric (line 26) | def make_batch_symmetric(batch):
  function loss_of_one_batch (line 32) | def loss_of_one_batch(batch, model, criterion, device, symmetrize_batch=...
  function inference (line 56) | def inference(pairs, model, device, batch_size=8, verbose=True):
  function check_if_same_size (line 75) | def check_if_same_size(pairs):
  function get_pred_pts3d (line 81) | def get_pred_pts3d(gt, pred, use_pose=False):
  function find_opt_scaling (line 106) | def find_opt_scaling(gt_pts1, gt_pts2, pr_pts1, pr_pts2=None, fit_mode='...

FILE: dust3r/losses.py
  function Sum (line 16) | def Sum(*losses_and_masks):
  class BaseCriterion (line 28) | class BaseCriterion(nn.Module):
    method __init__ (line 29) | def __init__(self, reduction='mean'):
  class LLoss (line 34) | class LLoss (BaseCriterion):
    method forward (line 38) | def forward(self, a, b):
    method distance (line 50) | def distance(self, a, b):
  class L21Loss (line 54) | class L21Loss (LLoss):
    method distance (line 57) | def distance(self, a, b):
  class Criterion (line 64) | class Criterion (nn.Module):
    method __init__ (line 65) | def __init__(self, criterion=None):
    method get_name (line 70) | def get_name(self):
    method with_reduction (line 73) | def with_reduction(self, mode='none'):
  class MultiLoss (line 82) | class MultiLoss (nn.Module):
    method __init__ (line 89) | def __init__(self):
    method compute_loss (line 94) | def compute_loss(self, *args, **kwargs):
    method get_name (line 97) | def get_name(self):
    method __mul__ (line 100) | def __mul__(self, alpha):
    method __add__ (line 107) | def __add__(self, loss2):
    method __repr__ (line 116) | def __repr__(self):
    method forward (line 124) | def forward(self, *args, **kwargs):
  class Regr3D (line 142) | class Regr3D (Criterion, MultiLoss):
    method __init__ (line 153) | def __init__(self, criterion, norm_mode='avg_dis', gt_scale=False):
    method get_all_pts3d (line 158) | def get_all_pts3d(self, gt1, gt2, pred1, pred2, dist_clip=None):
    method compute_loss (line 185) | def compute_loss(self, gt1, gt2, pred1, pred2, **kw):
  class ConfLoss (line 197) | class ConfLoss (MultiLoss):
    method __init__ (line 208) | def __init__(self, pixel_loss, alpha=1):
    method get_name (line 214) | def get_name(self):
    method get_conf_log (line 217) | def get_conf_log(self, x):
    method compute_loss (line 220) | def compute_loss(self, gt1, gt2, pred1, pred2, **kw):
  class Regr3D_ShiftInv (line 241) | class Regr3D_ShiftInv (Regr3D):
    method get_all_pts3d (line 245) | def get_all_pts3d(self, gt1, gt2, pred1, pred2):
  class Regr3D_ScaleInv (line 266) | class Regr3D_ScaleInv (Regr3D):
    method get_all_pts3d (line 271) | def get_all_pts3d(self, gt1, gt2, pred1, pred2):
  class Regr3D_ScaleShiftInv (line 297) | class Regr3D_ScaleShiftInv (Regr3D_ScaleInv, Regr3D_ShiftInv):

FILE: dust3r/model.py
  function load_model (line 27) | def load_model(model_path, device, verbose=True):
  class AsymmetricCroCo3DStereo (line 46) | class AsymmetricCroCo3DStereo (
    method __init__ (line 58) | def __init__(self,
    method from_pretrained (line 77) | def from_pretrained(cls, pretrained_model_name_or_path, **kw):
    method _set_patch_embed (line 87) | def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=...
    method load_state_dict (line 91) | def load_state_dict(self, ckpt, **kw):
    method set_freeze (line 100) | def set_freeze(self, freeze):  # this is for use by downstream models
    method _set_prediction_head (line 109) | def _set_prediction_head(self, *args, **kwargs):
    method set_downstream_head (line 113) | def set_downstream_head(self, output_mode, head_type, landscape_only, ...
    method _encode_image (line 128) | def _encode_image(self, image, true_shape):
    method _encode_image_pairs (line 142) | def _encode_image_pairs(self, img1, img2, true_shape1, true_shape2):
    method _encode_symmetrized (line 153) | def _encode_symmetrized(self, view1, view2):
    method _decoder (line 172) | def _decoder(self, f1, pos1, f2, pos2):
    method _downstream_head (line 193) | def _downstream_head(self, head_num, decout, img_shape):
    method forward (line 199) | def forward(self, view1, view2):

FILE: dust3r/optim_factory.py
  function adjust_learning_rate_by_lr (line 9) | def adjust_learning_rate_by_lr(optimizer, lr):

FILE: dust3r/patch_embed.py
  function get_patch_embed (line 13) | def get_patch_embed(patch_embed_cls, img_size, patch_size, enc_embed_dim):
  class PatchEmbedDust3R (line 19) | class PatchEmbedDust3R(PatchEmbed):
    method forward (line 20) | def forward(self, x, **kw):
  class ManyAR_PatchEmbed (line 32) | class ManyAR_PatchEmbed (PatchEmbed):
    method __init__ (line 38) | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=...
    method forward (line 42) | def forward(self, img, true_shape):

FILE: dust3r/post_process.py
  function estimate_focal_knowing_depth (line 12) | def estimate_focal_knowing_depth(pts3d, pp, focal_mode='median', min_foc...

FILE: dust3r/training.py
  function get_args_parser (line 39) | def get_args_parser():
  function train (line 92) | def train(args):
  function save_final_model (line 239) | def save_final_model(args, epoch, model_without_ddp, best_so_far=None):
  function build_dataset (line 253) | def build_dataset(dataset, batch_size, num_workers, test=False):
  function train_one_epoch (line 267) | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
  function test_one_epoch (line 342) | def test_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,

FILE: dust3r/utils/device.py
  function todevice (line 11) | def todevice(batch, device, callback=None, non_blocking=False):
  function to_numpy (line 42) | def to_numpy(x): return todevice(x, 'numpy')
  function to_cpu (line 43) | def to_cpu(x): return todevice(x, 'cpu')
  function to_cuda (line 44) | def to_cuda(x): return todevice(x, 'cuda')
  function collate_with_cat (line 47) | def collate_with_cat(whatever, lists=False):
  function listify (line 75) | def listify(elems):

FILE: dust3r/utils/geometry.py
  function xy_grid (line 15) | def xy_grid(W, H, device=None, origin=(0, 0), unsqueeze=None, cat_dim=-1...
  function geotrf (line 40) | def geotrf(Trf, pts, ncol=None, norm=False):
  function inv (line 104) | def inv(mat):
  function depthmap_to_pts3d (line 114) | def depthmap_to_pts3d(depth, pseudo_focal, pp=None, **_):
  function depthmap_to_camera_coordinates (line 165) | def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_f...
  function depthmap_to_absolute_camera_coordinates (line 200) | def depthmap_to_absolute_camera_coordinates(depthmap, camera_intrinsics,...
  function colmap_to_opencv_intrinsics (line 223) | def colmap_to_opencv_intrinsics(K):
  function opencv_to_colmap_intrinsics (line 236) | def opencv_to_colmap_intrinsics(K):
  function normalize_pointcloud (line 249) | def normalize_pointcloud(pts1, pts2, norm_mode='avg_dis', valid1=None, v...
  function get_joint_pointcloud_depth (line 313) | def get_joint_pointcloud_depth(z1, z2, valid_mask1, valid_mask2=None, qu...
  function get_joint_pointcloud_center_scale (line 328) | def get_joint_pointcloud_center_scale(pts1, pts2, valid_mask1=None, vali...
  function find_reciprocal_matches (line 345) | def find_reciprocal_matches(P1, P2):
  function get_med_dist_between_poses (line 364) | def get_med_dist_between_poses(poses):

FILE: dust3r/utils/image.py
  function img_to_arr (line 26) | def img_to_arr(img):
  function imread_cv2 (line 32) | def imread_cv2(path, options=cv2.IMREAD_COLOR):
  function rgb (line 45) | def rgb(ftensor, true_shape=None):
  function _resize_pil_image (line 64) | def _resize_pil_image(img, long_edge_size):
  function load_images (line 74) | def load_images(folder_or_list, size, square_ok=False, verbose=True, pat...

FILE: dust3r/utils/misc.py
  function fill_default_args (line 10) | def fill_default_args(kwargs, func):
  function freeze_all_params (line 22) | def freeze_all_params(modules):
  function is_symmetrized (line 32) | def is_symmetrized(gt1, gt2):
  function flip (line 43) | def flip(tensor):
  function interleave (line 48) | def interleave(tensor1, tensor2):
  function transpose_to_landscape (line 54) | def transpose_to_landscape(head, activate=True):
  function transposed (line 99) | def transposed(dic):
  function invalid_to_nans (line 103) | def invalid_to_nans(arr, valid_mask, ndim=999):
  function invalid_to_zeros (line 112) | def invalid_to_zeros(arr, valid_mask, ndim=999):

FILE: dust3r/utils/parallel.py
  function parallel_threads (line 12) | def parallel_threads(function, args, workers=0, star_args=False, kw_args...
  function parallel_processes (line 62) | def parallel_processes(*args, **kwargs):
  function starcall (line 70) | def starcall(args):
  function starstarcall (line 76) | def starstarcall(args):

FILE: dust3r/viz.py
  function cat_3d (line 23) | def cat_3d(vecs):
  function show_raw_pointcloud (line 29) | def show_raw_pointcloud(pts3d, colors, point_size=2):
  function pts3d_to_trimesh (line 38) | def pts3d_to_trimesh(img, pts3d, valid=None):
  function cat_meshes (line 78) | def cat_meshes(meshes):
  function show_duster_pairs (line 90) | def show_duster_pairs(view1, view2, pred1, pred2):
  function auto_cam_size (line 115) | def auto_cam_size(im_poses):
  class SceneViz (line 119) | class SceneViz:
    method __init__ (line 120) | def __init__(self):
    method add_rgbd (line 123) | def add_rgbd(self, image, depth, intrinsics=None, cam2world=None, zfar...
    method add_pointcloud (line 137) | def add_pointcloud(self, pts3d, color=(0,0,0), mask=None, denoise=False):
    method add_rgbd (line 173) | def add_rgbd(self, image, depth, intrinsics=None, cam2world=None, zfar...
    method add_camera (line 190) | def add_camera(self, pose_c2w, focal=None, color=(0, 0, 0), image=None...
    method add_cameras (line 202) | def add_cameras(self, poses, focals=None, images=None, imsizes=None, c...
    method show (line 208) | def show(self, point_size=2):
  function show_raw_pointcloud_with_cams (line 212) | def show_raw_pointcloud_with_cams(imgs, pts3d, mask, focals, cams2world,
  function add_scene_cam (line 246) | def add_scene_cam(scene, pose_c2w, edge_color, image=None, focal=None, i...
  function cat (line 322) | def cat(a, b):
  function uint8 (line 336) | def uint8(colors):
  function segment_sky (line 345) | def segment_sky(image):

FILE: dust3r_visloc/datasets/aachen_day_night.py
  class VislocAachenDayNight (line 11) | class VislocAachenDayNight(BaseVislocColmapDataset):
    method __init__ (line 12) | def __init__(self, root, subscene, pairsfile, topk=1, cache_sfm=False):

FILE: dust3r_visloc/datasets/base_colmap.py
  function kapture_to_opencv_intrinsics (line 29) | def kapture_to_opencv_intrinsics(sensor):
  function K_from_colmap (line 79) | def K_from_colmap(elems):
  function pose_from_qwxyz_txyz (line 88) | def pose_from_qwxyz_txyz(elems):
  class BaseVislocColmapDataset (line 96) | class BaseVislocColmapDataset(BaseVislocDataset):
    method __init__ (line 97) | def __init__(self, image_path, map_path, query_path, pairsfile_path, t...
    method _load_sfm (line 116) | def _load_sfm(self, sfm_dir):
    method __len__ (line 171) | def __len__(self):
    method _get_view_query (line 174) | def _get_view_query(self, imgname):
    method _get_view_map (line 220) | def _get_view_map(self, imgname, idx):
    method __getitem__ (line 274) | def __getitem__(self, idx):

FILE: dust3r_visloc/datasets/base_dataset.py
  class BaseVislocDataset (line 7) | class BaseVislocDataset:
    method __init__ (line 8) | def __init__(self):
    method set_resolution (line 11) | def set_resolution(self, model):
    method __len__ (line 15) | def __len__(self):
    method __getitem__ (line 18) | def __getitem__(self, idx):

FILE: dust3r_visloc/datasets/cambridge_landmarks.py
  class VislocCambridgeLandmarks (line 11) | class VislocCambridgeLandmarks (BaseVislocColmapDataset):
    method __init__ (line 12) | def __init__(self, root, subscene, pairsfile, topk=1, cache_sfm=False):

FILE: dust3r_visloc/datasets/inloc.py
  function read_alignments (line 23) | def read_alignments(path_to_alignment):
  class VislocInLoc (line 47) | class VislocInLoc(BaseVislocDataset):
    method __init__ (line 48) | def __init__(self, root, pairsfile, topk=1):
    method __len__ (line 88) | def __len__(self):
    method __getitem__ (line 91) | def __getitem__(self, idx):

FILE: dust3r_visloc/datasets/sevenscenes.py
  class VislocSevenScenes (line 23) | class VislocSevenScenes(BaseVislocDataset):
    method __init__ (line 24) | def __init__(self, root, subscene, pairsfile, topk=1):
    method __len__ (line 54) | def __len__(self):
    method __getitem__ (line 57) | def __getitem__(self, idx):

FILE: dust3r_visloc/datasets/utils.py
  function cam_to_world_from_kapture (line 13) | def cam_to_world_from_kapture(kdata, timestamp, camera_id):
  function get_HW_resolution (line 27) | def get_HW_resolution(H, W, maxdim, patchsize=16):
  function get_resize_function (line 51) | def get_resize_function(maxdim, patch_size, H, W, is_mask=False):
  function rescale_points3d (line 93) | def rescale_points3d(pts2d, pts3d, to_resize, HR, WR):

FILE: dust3r_visloc/evaluation.py
  function aggregate_stats (line 15) | def aggregate_stats(info_str, pose_errors, angular_errors):
  function get_pose_error (line 31) | def get_pose_error(pr_camtoworld, gt_cam_to_world):
  function export_results (line 38) | def export_results(output_dir, xp_label, query_names, poses_pred):

FILE: dust3r_visloc/localization.py
  function run_pnp (line 30) | def run_pnp(pts2D, pts3D, K, distortion = None, mode='cv2', reprojection...

FILE: visloc.py
  function get_args_parser (line 23) | def get_args_parser():
Condensed preview of 91 files, each showing path, character count, and a content snippet. The full structured content is 451K chars.
[
  {
    "path": ".gitignore",
    "chars": 1819,
    "preview": "data/\ncheckpoints/\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Di"
  },
  {
    "path": ".gitmodules",
    "chars": 72,
    "preview": "[submodule \"croco\"]\n\tpath = croco\n\turl = https://github.com/naver/croco\n"
  },
  {
    "path": "LICENSE",
    "chars": 361,
    "preview": "DUSt3R, Copyright (c) 2024-present Naver Corporation, is licensed under the Creative Commons Attribution-NonCommercial-S"
  },
  {
    "path": "NOTICE",
    "chars": 359,
    "preview": "DUSt3R\nCopyright 2024-present NAVER Corp.\n\nThis project contains subcomponents with separate copyright notices and licen"
  },
  {
    "path": "README.md",
    "chars": 24895,
    "preview": "![demo](assets/dust3r.jpg)\n\nOfficial implementation of `DUSt3R: Geometric 3D Vision Made Easy`  \n[[Project page](https:/"
  },
  {
    "path": "datasets_preprocess/habitat/README.md",
    "chars": 2406,
    "preview": "## Steps to reproduce synthetic training data using the Habitat-Sim simulator\n\n### Create a conda environment\n```bash\nco"
  },
  {
    "path": "datasets_preprocess/habitat/find_scenes.py",
    "chars": 2779,
    "preview": "#!/usr/bin/env python3\n# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA"
  },
  {
    "path": "datasets_preprocess/habitat/habitat_renderer/__init__.py",
    "chars": 129,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "datasets_preprocess/habitat/habitat_renderer/habitat_sim_envmaps_renderer.py",
    "chars": 8015,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "datasets_preprocess/habitat/habitat_renderer/multiview_crop_generator.py",
    "chars": 5541,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "datasets_preprocess/habitat/habitat_renderer/projections.py",
    "chars": 4954,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "datasets_preprocess/habitat/habitat_renderer/projections_conversions.py",
    "chars": 2035,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "datasets_preprocess/habitat/preprocess_habitat.py",
    "chars": 5602,
    "preview": "#!/usr/bin/env python3\n# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA"
  },
  {
    "path": "datasets_preprocess/path_to_root.py",
    "chars": 491,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "datasets_preprocess/preprocess_arkitscenes.py",
    "chars": 15818,
    "preview": "#!/usr/bin/env python3\n# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA"
  },
  {
    "path": "datasets_preprocess/preprocess_blendedMVS.py",
    "chars": 5373,
    "preview": "#!/usr/bin/env python3\n# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA"
  },
  {
    "path": "datasets_preprocess/preprocess_co3d.py",
    "chars": 13258,
    "preview": "#!/usr/bin/env python3\n# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA"
  },
  {
    "path": "datasets_preprocess/preprocess_megadepth.py",
    "chars": 6715,
    "preview": "#!/usr/bin/env python3\n# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA"
  },
  {
    "path": "datasets_preprocess/preprocess_scannetpp.py",
    "chars": 16037,
    "preview": "#!/usr/bin/env python3\n# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA"
  },
  {
    "path": "datasets_preprocess/preprocess_staticthings3d.py",
    "chars": 5309,
    "preview": "#!/usr/bin/env python3\n# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA"
  },
  {
    "path": "datasets_preprocess/preprocess_waymo.py",
    "chars": 9696,
    "preview": "#!/usr/bin/env python3\n# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA"
  },
  {
    "path": "datasets_preprocess/preprocess_wildrgbd.py",
    "chars": 9572,
    "preview": "#!/usr/bin/env python3\n# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA"
  },
  {
    "path": "demo.py",
    "chars": 1571,
    "preview": "#!/usr/bin/env python3\n# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA"
  },
  {
    "path": "docker/docker-compose-cpu.yml",
    "chars": 351,
    "preview": "version: '3.8'\nservices:\n  dust3r-demo:\n    build:\n      context: ./files\n      dockerfile: cpu.Dockerfile \n    ports:\n "
  },
  {
    "path": "docker/docker-compose-cuda.yml",
    "chars": 509,
    "preview": "version: '3.8'\nservices:\n  dust3r-demo:\n    build:\n      context: ./files\n      dockerfile: cuda.Dockerfile \n    ports:\n"
  },
  {
    "path": "docker/files/cpu.Dockerfile",
    "chars": 937,
    "preview": "FROM python:3.11-slim\n\nLABEL description=\"Docker container for DUSt3R with dependencies installed. CPU VERSION\"\n\nENV DEV"
  },
  {
    "path": "docker/files/cuda.Dockerfile",
    "chars": 796,
    "preview": "FROM nvcr.io/nvidia/pytorch:24.01-py3\n\nLABEL description=\"Docker container for DUSt3R with dependencies installed. CUDA "
  },
  {
    "path": "docker/files/entrypoint.sh",
    "chars": 195,
    "preview": "#!/bin/bash\n\nset -eux\n\nDEVICE=${DEVICE:-cuda}\nMODEL=${MODEL:-DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth}\n\nexec python3 demo"
  },
  {
    "path": "docker/run.sh",
    "chars": 1558,
    "preview": "#!/bin/bash\n\nset -eux\n\n# Default model name\nmodel_name=\"DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth\"\n\ncheck_docker() {\n    i"
  },
  {
    "path": "dust3r/__init__.py",
    "chars": 129,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/cloud_opt/__init__.py",
    "chars": 1379,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/cloud_opt/base_opt.py",
    "chars": 15241,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/cloud_opt/commons.py",
    "chars": 2281,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/cloud_opt/init_im_poses.py",
    "chars": 11269,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/cloud_opt/modular_optimizer.py",
    "chars": 6481,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/cloud_opt/optimizer.py",
    "chars": 9943,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/cloud_opt/pair_viewer.py",
    "chars": 4998,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/datasets/__init__.py",
    "chars": 1790,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/datasets/arkitscenes.py",
    "chars": 4216,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/datasets/base/__init__.py",
    "chars": 129,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/datasets/base/base_stereo_view_dataset.py",
    "chars": 8853,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/datasets/base/batched_sampler.py",
    "chars": 2886,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/datasets/base/easy_dataset.py",
    "chars": 5088,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/datasets/blendedmvs.py",
    "chars": 3964,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/datasets/co3d.py",
    "chars": 7089,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/datasets/habitat.py",
    "chars": 4728,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/datasets/megadepth.py",
    "chars": 4868,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/datasets/scannetpp.py",
    "chars": 3967,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/datasets/staticthings3d.py",
    "chars": 3715,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/datasets/utils/__init__.py",
    "chars": 129,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/datasets/utils/cropping.py",
    "chars": 4562,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/datasets/utils/transforms.py",
    "chars": 467,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/datasets/waymo.py",
    "chars": 3643,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/datasets/wildrgbd.py",
    "chars": 2780,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/demo.py",
    "chars": 14570,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/heads/__init__.py",
    "chars": 763,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/heads/dpt_head.py",
    "chars": 4868,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/heads/linear_head.py",
    "chars": 1349,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/heads/postprocess.py",
    "chars": 1621,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/image_pairs.py",
    "chars": 3863,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/inference.py",
    "chars": 5548,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/losses.py",
    "chars": 10630,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/model.py",
    "chars": 8895,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/optim_factory.py",
    "chars": 518,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/patch_embed.py",
    "chars": 3138,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/post_process.py",
    "chars": 2474,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/training.py",
    "chars": 16972,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/utils/__init__.py",
    "chars": 129,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/utils/device.py",
    "chars": 2473,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/utils/geometry.py",
    "chars": 13204,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/utils/image.py",
    "chars": 4480,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/utils/misc.py",
    "chars": 3755,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/utils/parallel.py",
    "chars": 2482,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/utils/path_to_croco.py",
    "chars": 854,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r/viz.py",
    "chars": 13708,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r_visloc/README.md",
    "chars": 3670,
    "preview": "# Visual Localization with DUSt3R\n\n## Dataset preparation\n\n### CambridgeLandmarks\n\nEach subscene should look like this:\n"
  },
  {
    "path": "dust3r_visloc/__init__.py",
    "chars": 129,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r_visloc/datasets/__init__.py",
    "chars": 312,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r_visloc/datasets/aachen_day_night.py",
    "chars": 1294,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r_visloc/datasets/base_colmap.py",
    "chars": 10830,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r_visloc/datasets/base_dataset.py",
    "chars": 603,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r_visloc/datasets/cambridge_landmarks.py",
    "chars": 999,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r_visloc/datasets/inloc.py",
    "chars": 7220,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r_visloc/datasets/sevenscenes.py",
    "chars": 5867,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r_visloc/datasets/utils.py",
    "chars": 4876,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r_visloc/evaluation.py",
    "chars": 2909,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "dust3r_visloc/localization.py",
    "chars": 6368,
    "preview": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial us"
  },
  {
    "path": "requirements.txt",
    "chars": 130,
    "preview": "torch\ntorchvision\nroma\ngradio\nmatplotlib\ntqdm\nopencv-python\nscipy\neinops\ntrimesh\ntensorboard\npyglet<2\nhuggingface-hub[to"
  },
  {
    "path": "requirements_optional.txt",
    "chars": 201,
    "preview": "pillow-heif  # add heif/heic image support\npyrender  # for rendering depths in scannetpp\nkapture  # for visloc data load"
  },
  {
    "path": "train.py",
    "chars": 458,
    "preview": "#!/usr/bin/env python3\n# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA"
  },
  {
    "path": "visloc.py",
    "chars": 9192,
    "preview": "#!/usr/bin/env python3\n# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA"
  }
]

About this extraction

This page contains the full source code of the naver/dust3r GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 91 files (422.9 KB), approximately 114.1k tokens, and a symbol index with 512 extracted functions, classes, methods, constants, and types.

