Repository: naver/dust3r Branch: main Commit: 4c24a6ebf048 Files: 91 Total size: 422.9 KB Directory structure: gitextract_0hi9z8ek/ ├── .gitignore ├── .gitmodules ├── LICENSE ├── NOTICE ├── README.md ├── datasets_preprocess/ │ ├── habitat/ │ │ ├── README.md │ │ ├── find_scenes.py │ │ ├── habitat_renderer/ │ │ │ ├── __init__.py │ │ │ ├── habitat_sim_envmaps_renderer.py │ │ │ ├── multiview_crop_generator.py │ │ │ ├── projections.py │ │ │ └── projections_conversions.py │ │ └── preprocess_habitat.py │ ├── path_to_root.py │ ├── preprocess_arkitscenes.py │ ├── preprocess_blendedMVS.py │ ├── preprocess_co3d.py │ ├── preprocess_megadepth.py │ ├── preprocess_scannetpp.py │ ├── preprocess_staticthings3d.py │ ├── preprocess_waymo.py │ └── preprocess_wildrgbd.py ├── demo.py ├── docker/ │ ├── docker-compose-cpu.yml │ ├── docker-compose-cuda.yml │ ├── files/ │ │ ├── cpu.Dockerfile │ │ ├── cuda.Dockerfile │ │ └── entrypoint.sh │ └── run.sh ├── dust3r/ │ ├── __init__.py │ ├── cloud_opt/ │ │ ├── __init__.py │ │ ├── base_opt.py │ │ ├── commons.py │ │ ├── init_im_poses.py │ │ ├── modular_optimizer.py │ │ ├── optimizer.py │ │ └── pair_viewer.py │ ├── datasets/ │ │ ├── __init__.py │ │ ├── arkitscenes.py │ │ ├── base/ │ │ │ ├── __init__.py │ │ │ ├── base_stereo_view_dataset.py │ │ │ ├── batched_sampler.py │ │ │ └── easy_dataset.py │ │ ├── blendedmvs.py │ │ ├── co3d.py │ │ ├── habitat.py │ │ ├── megadepth.py │ │ ├── scannetpp.py │ │ ├── staticthings3d.py │ │ ├── utils/ │ │ │ ├── __init__.py │ │ │ ├── cropping.py │ │ │ └── transforms.py │ │ ├── waymo.py │ │ └── wildrgbd.py │ ├── demo.py │ ├── heads/ │ │ ├── __init__.py │ │ ├── dpt_head.py │ │ ├── linear_head.py │ │ └── postprocess.py │ ├── image_pairs.py │ ├── inference.py │ ├── losses.py │ ├── model.py │ ├── optim_factory.py │ ├── patch_embed.py │ ├── post_process.py │ ├── training.py │ ├── utils/ │ │ ├── __init__.py │ │ ├── device.py │ │ ├── geometry.py │ │ ├── image.py │ │ ├── misc.py │ │ ├── parallel.py │ │ └── path_to_croco.py │ └── 
viz.py ├── dust3r_visloc/ │ ├── README.md │ ├── __init__.py │ ├── datasets/ │ │ ├── __init__.py │ │ ├── aachen_day_night.py │ │ ├── base_colmap.py │ │ ├── base_dataset.py │ │ ├── cambridge_landmarks.py │ │ ├── inloc.py │ │ ├── sevenscenes.py │ │ └── utils.py │ ├── evaluation.py │ └── localization.py ├── requirements.txt ├── requirements_optional.txt ├── train.py └── visloc.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ data/ checkpoints/ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ pip-wheel-metadata/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. 
#Pipfile.lock # PEP 582; used by e.g. github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ ================================================ FILE: .gitmodules ================================================ [submodule "croco"] path = croco url = https://github.com/naver/croco ================================================ FILE: LICENSE ================================================ DUSt3R, Copyright (c) 2024-present Naver Corporation, is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license. A summary of the CC BY-NC-SA 4.0 license is located here: https://creativecommons.org/licenses/by-nc-sa/4.0/ The CC BY-NC-SA 4.0 license is located here: https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode ================================================ FILE: NOTICE ================================================ DUSt3R Copyright 2024-present NAVER Corp. This project contains subcomponents with separate copyright notices and license terms. Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses. 
==== naver/croco https://github.com/naver/croco/ Creative Commons Attribution-NonCommercial-ShareAlike 4.0 ================================================ FILE: README.md ================================================ ![demo](assets/dust3r.jpg) Official implementation of `DUSt3R: Geometric 3D Vision Made Easy` [[Project page](https://dust3r.europe.naverlabs.com/)], [[DUSt3R arxiv](https://arxiv.org/abs/2312.14132)] > Make sure to also check our other works: > [Grounding Image Matching in 3D with MASt3R](https://github.com/naver/mast3r): DUSt3R with a local feature head, metric pointmaps, and a more scalable global alignment! > [Pow3R: Empowering Unconstrained 3D Reconstruction with Camera and Scene Priors](https://github.com/naver/pow3r): DUSt3R with known depth / focal length / poses. > [MUSt3R: Multi-view Network for Stereo 3D Reconstruction](https://github.com/naver/must3r): Multi-view predictions (RGB SLAM/SfM) without any global alignment. ![Example of reconstruction from two images](assets/pipeline1.jpg) ![High level overview of DUSt3R capabilities](assets/dust3r_archi.jpg) ```bibtex @inproceedings{dust3r_cvpr24, title={DUSt3R: Geometric 3D Vision Made Easy}, author={Shuzhe Wang and Vincent Leroy and Yohann Cabon and Boris Chidlovskii and Jerome Revaud}, booktitle = {CVPR}, year = {2024} } @misc{dust3r_arxiv23, title={DUSt3R: Geometric 3D Vision Made Easy}, author={Shuzhe Wang and Vincent Leroy and Yohann Cabon and Boris Chidlovskii and Jerome Revaud}, year={2023}, eprint={2312.14132}, archivePrefix={arXiv}, primaryClass={cs.CV} } ``` ## Table of Contents - [Table of Contents](#table-of-contents) - [License](#license) - [Get Started](#get-started) - [Installation](#installation) - [Checkpoints](#checkpoints) - [Interactive demo](#interactive-demo) - [Interactive demo with docker](#interactive-demo-with-docker) - [Usage](#usage) - [Training](#training) - [Datasets](#datasets) - [Demo](#demo) - [Our Hyperparameters](#our-hyperparameters) ## License The code 
is distributed under the CC BY-NC-SA 4.0 License. See [LICENSE](LICENSE) for more information. ```python # Copyright (C) 2024-present Naver Corporation. All rights reserved. # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). ``` ## Get Started ### Installation 1. Clone DUSt3R. ```bash git clone --recursive https://github.com/naver/dust3r cd dust3r # if you have already cloned dust3r: # git submodule update --init --recursive ``` 2. Create the environment, here we show an example using conda. ```bash conda create -n dust3r python=3.11 cmake=3.14.0 conda activate dust3r conda install pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia # use the correct version of cuda for your system pip install -r requirements.txt # Optional: you can also install additional packages to: # - add support for HEIC images # - add pyrender, used to render depthmap in some datasets preprocessing # - add required packages for visloc.py pip install -r requirements_optional.txt ``` 3. Optional, compile the cuda kernels for RoPE (as in CroCo v2). ```bash # DUST3R relies on RoPE positional embeddings for which you can compile some cuda kernels for faster runtime. cd croco/models/curope/ python setup.py build_ext --inplace cd ../../../ ``` ### Checkpoints You can obtain the checkpoints by two ways: 1) You can use our huggingface_hub integration: the models will be downloaded automatically. 
2) Otherwise, we provide several pre-trained models: | Modelname | Training resolutions | Head | Encoder | Decoder | |-------------|----------------------|------|---------|---------| | [`DUSt3R_ViTLarge_BaseDecoder_224_linear.pth`](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_224_linear.pth) | 224x224 | Linear | ViT-L | ViT-B | | [`DUSt3R_ViTLarge_BaseDecoder_512_linear.pth`](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_linear.pth) | 512x384, 512x336, 512x288, 512x256, 512x160 | Linear | ViT-L | ViT-B | | [`DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth`](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth) | 512x384, 512x336, 512x288, 512x256, 512x160 | DPT | ViT-L | ViT-B | You can check the hyperparameters we used to train these models in the [section: Our Hyperparameters](#our-hyperparameters) To download a specific model, for example `DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth`: ```bash mkdir -p checkpoints/ wget https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth -P checkpoints/ ``` For the checkpoints, make sure to agree to the license of all the public training datasets and base checkpoints we used, in addition to CC-BY-NC-SA 4.0. Again, see [section: Our Hyperparameters](#our-hyperparameters) for details. ### Interactive demo In this demo, you should be able to run DUSt3R on your machine to reconstruct a scene. First select images that depict the same scene. You can adjust the global alignment schedule and its number of iterations. > [!NOTE] > If you selected one or two images, the global alignment procedure will be skipped (mode=GlobalAlignerMode.PairViewer) Hit "Run" and wait. When the global alignment ends, the reconstruction appears. Use the slider "min_conf_thr" to show or remove low confidence areas. 
```bash python3 demo.py --model_name DUSt3R_ViTLarge_BaseDecoder_512_dpt # Use --weights to load a checkpoint from a local file, eg --weights checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth # Use --image_size to select the correct resolution for the selected checkpoint. 512 (default) or 224 # Use --local_network to make it accessible on the local network, or --server_name to specify the url manually # Use --server_port to change the port, by default it will search for an available port starting at 7860 # Use --device to use a different device, by default it's "cuda" ``` ### Interactive demo with docker To run DUSt3R using Docker, including with NVIDIA CUDA support, follow these instructions: 1. **Install Docker**: If not already installed, download and install `docker` and `docker compose` from the [Docker website](https://www.docker.com/get-started). 2. **Install NVIDIA Docker Toolkit**: For GPU support, install the NVIDIA Docker toolkit from the [Nvidia website](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html). 3. **Build the Docker image and run it**: `cd` into the `./docker` directory and run the following commands: ```bash cd docker bash run.sh --with-cuda --model_name="DUSt3R_ViTLarge_BaseDecoder_512_dpt" ``` Or if you want to run the demo without CUDA support, run the following command: ```bash cd docker bash run.sh --model_name="DUSt3R_ViTLarge_BaseDecoder_512_dpt" ``` By default, `demo.py` is launched with the option `--local_network`. Visit `http://localhost:7860/` to access the web UI (or replace `localhost` with the machine's name to access it from the network). `run.sh` will launch docker-compose using either the [docker-compose-cuda.yml](docker/docker-compose-cuda.yml) or [docker-compose-cpu.yml](docker/docker-compose-cpu.yml) config file, then it starts the demo using [entrypoint.sh](docker/files/entrypoint.sh). 
![demo](assets/demo.jpg) ## Usage ```python from dust3r.inference import inference from dust3r.model import AsymmetricCroCo3DStereo from dust3r.utils.image import load_images from dust3r.image_pairs import make_pairs from dust3r.cloud_opt import global_aligner, GlobalAlignerMode if __name__ == '__main__': device = 'cuda' batch_size = 1 schedule = 'cosine' lr = 0.01 niter = 300 model_name = "naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt" # you can put the path to a local checkpoint in model_name if needed model = AsymmetricCroCo3DStereo.from_pretrained(model_name).to(device) # load_images can take a list of images or a directory images = load_images(['croco/assets/Chateau1.png', 'croco/assets/Chateau2.png'], size=512) pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True) output = inference(pairs, model, device, batch_size=batch_size) # at this stage, you have the raw dust3r predictions view1, pred1 = output['view1'], output['pred1'] view2, pred2 = output['view2'], output['pred2'] # here, view1, pred1, view2, pred2 are dicts of lists of len(2) # -> because we symmetrize we have (im1, im2) and (im2, im1) pairs # in each view you have: # an integer image identifier: view1['idx'] and view2['idx'] # the img: view1['img'] and view2['img'] # the image shape: view1['true_shape'] and view2['true_shape'] # an instance string output by the dataloader: view1['instance'] and view2['instance'] # pred1 and pred2 contains the confidence values: pred1['conf'] and pred2['conf'] # pred1 contains 3D points for view1['img'] in view1['img'] space: pred1['pts3d'] # pred2 contains 3D points for view2['img'] in view1['img'] space: pred2['pts3d_in_other_view'] # next we'll use the global_aligner to align the predictions # depending on your task, you may be fine with the raw output and not need it # with only two input images, you could use GlobalAlignerMode.PairViewer: it would just convert the output # if using GlobalAlignerMode.PairViewer, no need to run 
compute_global_alignment scene = global_aligner(output, device=device, mode=GlobalAlignerMode.PointCloudOptimizer) loss = scene.compute_global_alignment(init="mst", niter=niter, schedule=schedule, lr=lr) # retrieve useful values from scene: imgs = scene.imgs focals = scene.get_focals() poses = scene.get_im_poses() pts3d = scene.get_pts3d() confidence_masks = scene.get_masks() # visualize reconstruction scene.show() # find 2D-2D matches between the two images from dust3r.utils.geometry import find_reciprocal_matches, xy_grid pts2d_list, pts3d_list = [], [] for i in range(2): conf_i = confidence_masks[i].cpu().numpy() pts2d_list.append(xy_grid(*imgs[i].shape[:2][::-1])[conf_i]) # imgs[i].shape[:2] = (H, W) pts3d_list.append(pts3d[i].detach().cpu().numpy()[conf_i]) reciprocal_in_P2, nn2_in_P1, num_matches = find_reciprocal_matches(*pts3d_list) print(f'found {num_matches} matches') matches_im1 = pts2d_list[1][reciprocal_in_P2] matches_im0 = pts2d_list[0][nn2_in_P1][reciprocal_in_P2] # visualize a few matches import numpy as np from matplotlib import pyplot as pl n_viz = 10 match_idx_to_viz = np.round(np.linspace(0, num_matches-1, n_viz)).astype(int) viz_matches_im0, viz_matches_im1 = matches_im0[match_idx_to_viz], matches_im1[match_idx_to_viz] H0, W0, H1, W1 = *imgs[0].shape[:2], *imgs[1].shape[:2] img0 = np.pad(imgs[0], ((0, max(H1 - H0, 0)), (0, 0), (0, 0)), 'constant', constant_values=0) img1 = np.pad(imgs[1], ((0, max(H0 - H1, 0)), (0, 0), (0, 0)), 'constant', constant_values=0) img = np.concatenate((img0, img1), axis=1) pl.figure() pl.imshow(img) cmap = pl.get_cmap('jet') for i in range(n_viz): (x0, y0), (x1, y1) = viz_matches_im0[i].T, viz_matches_im1[i].T pl.plot([x0, x1 + W0], [y0, y1], '-+', color=cmap(i / (n_viz - 1)), scalex=False, scaley=False) pl.show(block=True) ``` ![matching example on croco pair](assets/matching.jpg) ## Training In this section, we present a short demonstration to get started with training DUSt3R. 
### Datasets At this moment, we have added the following training datasets: - [CO3Dv2](https://github.com/facebookresearch/co3d) - [Creative Commons Attribution-NonCommercial 4.0 International](https://github.com/facebookresearch/co3d/blob/main/LICENSE) - [ARKitScenes](https://github.com/apple/ARKitScenes) - [Creative Commons Attribution-NonCommercial-ShareAlike 4.0](https://github.com/apple/ARKitScenes/tree/main?tab=readme-ov-file#license) - [ScanNet++](https://kaldir.vc.in.tum.de/scannetpp/) - [non-commercial research and educational purposes](https://kaldir.vc.in.tum.de/scannetpp/static/scannetpp-terms-of-use.pdf) - [BlendedMVS](https://github.com/YoYo000/BlendedMVS) - [Creative Commons Attribution 4.0 International License](https://creativecommons.org/licenses/by/4.0/) - [WayMo Open dataset](https://github.com/waymo-research/waymo-open-dataset) - [Non-Commercial Use](https://waymo.com/open/terms/) - [Habitat-Sim](https://github.com/facebookresearch/habitat-sim/blob/main/DATASETS.md) - [MegaDepth](https://www.cs.cornell.edu/projects/megadepth/) - [StaticThings3D](https://github.com/lmb-freiburg/robustmvd/blob/master/rmvd/data/README.md#staticthings3d) - [WildRGB-D](https://github.com/wildrgbd/wildrgbd/) For each dataset, we provide a preprocessing script in the `datasets_preprocess` directory and an archive containing the list of pairs when needed. You have to download the datasets yourself from their official sources, agree to their license, download our list of pairs, and run the preprocessing script. 
Links: [ARKitScenes pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/arkitscenes_pairs.zip) [ScanNet++ v1 pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/scannetpp_pairs.zip) [ScanNet++ v2 pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/scannetpp_v2_pairs.zip) [BlendedMVS pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/blendedmvs_pairs.npy) [WayMo Open dataset pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/waymo_pairs.npz) [Habitat metadata](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/habitat_5views_v1_512x512_metadata.tar.gz) [MegaDepth pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/megadepth_pairs.npz) [StaticThings3D pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/staticthings_pairs.npy) > [!NOTE] > They are not strictly equivalent to what was used to train DUSt3R, but they should be close enough. ### Demo For this training demo, we're going to download and prepare a subset of [CO3Dv2](https://github.com/facebookresearch/co3d) - [Creative Commons Attribution-NonCommercial 4.0 International](https://github.com/facebookresearch/co3d/blob/main/LICENSE) and launch the training code on it. The demo model will be trained for a few epochs on a very small dataset. It will not be very good. ```bash # download and prepare the co3d subset mkdir -p data/co3d_subset cd data/co3d_subset git clone https://github.com/facebookresearch/co3d cd co3d python3 ./co3d/download_dataset.py --download_folder ../ --single_sequence_subset rm ../*.zip cd ../../.. python3 datasets_preprocess/preprocess_co3d.py --co3d_dir data/co3d_subset --output_dir data/co3d_subset_processed --single_sequence_subset # download the pretrained croco v2 checkpoint mkdir -p checkpoints/ wget https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTLarge_BaseDecoder.pth -P checkpoints/ # the training of dust3r is done in 3 steps. 
# for this example we'll do fewer epochs, for the actual hyperparameters we used in the paper, see the next section: "Our Hyperparameters" # step 1 - train dust3r for 224 resolution torchrun --nproc_per_node=4 train.py \ --train_dataset "1000 @ Co3d(split='train', ROOT='data/co3d_subset_processed', aug_crop=16, mask_bg='rand', resolution=224, transform=ColorJitter)" \ --test_dataset "100 @ Co3d(split='test', ROOT='data/co3d_subset_processed', resolution=224, seed=777)" \ --model "AsymmetricCroCo3DStereo(pos_embed='RoPE100', img_size=(224, 224), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \ --train_criterion "ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \ --test_criterion "Regr3D_ScaleShiftInv(L21, gt_scale=True)" \ --pretrained "checkpoints/CroCo_V2_ViTLarge_BaseDecoder.pth" \ --lr 0.0001 --min_lr 1e-06 --warmup_epochs 1 --epochs 10 --batch_size 16 --accum_iter 1 \ --save_freq 1 --keep_freq 5 --eval_freq 1 \ --output_dir "checkpoints/dust3r_demo_224" # step 2 - train dust3r for 512 resolution torchrun --nproc_per_node=4 train.py \ --train_dataset "1000 @ Co3d(split='train', ROOT='data/co3d_subset_processed', aug_crop=16, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter)" \ --test_dataset "100 @ Co3d(split='test', ROOT='data/co3d_subset_processed', resolution=(512,384), seed=777)" \ --model "AsymmetricCroCo3DStereo(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \ --train_criterion "ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \ --test_criterion "Regr3D_ScaleShiftInv(L21, gt_scale=True)" \ 
--pretrained "checkpoints/dust3r_demo_224/checkpoint-best.pth" \ --lr 0.0001 --min_lr 1e-06 --warmup_epochs 1 --epochs 10 --batch_size 4 --accum_iter 4 \ --save_freq 1 --keep_freq 5 --eval_freq 1 \ --output_dir "checkpoints/dust3r_demo_512" # step 3 - train dust3r for 512 resolution with dpt torchrun --nproc_per_node=4 train.py \ --train_dataset "1000 @ Co3d(split='train', ROOT='data/co3d_subset_processed', aug_crop=16, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter)" \ --test_dataset "100 @ Co3d(split='test', ROOT='data/co3d_subset_processed', resolution=(512,384), seed=777)" \ --model "AsymmetricCroCo3DStereo(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='dpt', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \ --train_criterion "ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \ --test_criterion "Regr3D_ScaleShiftInv(L21, gt_scale=True)" \ --pretrained "checkpoints/dust3r_demo_512/checkpoint-best.pth" \ --lr 0.0001 --min_lr 1e-06 --warmup_epochs 1 --epochs 10 --batch_size 2 --accum_iter 8 \ --save_freq 1 --keep_freq 5 --eval_freq 1 --disable_cudnn_benchmark \ --output_dir "checkpoints/dust3r_demo_512dpt" ``` ### Our Hyperparameters Here are the commands we used for training the models: ```bash # NOTE: ROOT path omitted for datasets # 224 linear torchrun --nproc_per_node 8 train.py \ --train_dataset=" + 100_000 @ Habitat(1_000_000, split='train', aug_crop=16, resolution=224, transform=ColorJitter) + 100_000 @ BlendedMVS(split='train', aug_crop=16, resolution=224, transform=ColorJitter) + 100_000 @ MegaDepth(split='train', aug_crop=16, resolution=224, transform=ColorJitter) + 100_000 @ ARKitScenes(aug_crop=256, resolution=224, transform=ColorJitter) + 100_000 @ Co3d(split='train', aug_crop=16, mask_bg='rand', 
resolution=224, transform=ColorJitter) + 100_000 @ StaticThings3D(aug_crop=256, mask_bg='rand', resolution=224, transform=ColorJitter) + 100_000 @ ScanNetpp(split='train', aug_crop=256, resolution=224, transform=ColorJitter) + 100_000 @ InternalUnreleasedDataset(aug_crop=128, resolution=224, transform=ColorJitter) " \ --test_dataset=" Habitat(1_000, split='val', resolution=224, seed=777) + 1_000 @ BlendedMVS(split='val', resolution=224, seed=777) + 1_000 @ MegaDepth(split='val', resolution=224, seed=777) + 1_000 @ Co3d(split='test', mask_bg='rand', resolution=224, seed=777) " \ --train_criterion="ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \ --test_criterion="Regr3D_ScaleShiftInv(L21, gt_scale=True)" \ --model="AsymmetricCroCo3DStereo(pos_embed='RoPE100', img_size=(224, 224), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \ --pretrained="checkpoints/CroCo_V2_ViTLarge_BaseDecoder.pth" \ --lr=0.0001 --min_lr=1e-06 --warmup_epochs=10 --epochs=100 --batch_size=16 --accum_iter=1 \ --save_freq=5 --keep_freq=10 --eval_freq=1 \ --output_dir="checkpoints/dust3r_224" # 512 linear torchrun --nproc_per_node 8 train.py \ --train_dataset=" + 10_000 @ Habitat(1_000_000, split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ BlendedMVS(split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ MegaDepth(split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ ARKitScenes(aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ Co3d(split='train', aug_crop=16, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), 
(512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ StaticThings3D(aug_crop=256, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ ScanNetpp(split='train', aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ InternalUnreleasedDataset(aug_crop=128, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) " \ --test_dataset=" Habitat(1_000, split='val', resolution=(512,384), seed=777) + 1_000 @ BlendedMVS(split='val', resolution=(512,384), seed=777) + 1_000 @ MegaDepth(split='val', resolution=(512,336), seed=777) + 1_000 @ Co3d(split='test', resolution=(512,384), seed=777) " \ --train_criterion="ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \ --test_criterion="Regr3D_ScaleShiftInv(L21, gt_scale=True)" \ --model="AsymmetricCroCo3DStereo(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \ --pretrained="checkpoints/dust3r_224/checkpoint-best.pth" \ --lr=0.0001 --min_lr=1e-06 --warmup_epochs=20 --epochs=100 --batch_size=4 --accum_iter=2 \ --save_freq=10 --keep_freq=10 --eval_freq=1 --print_freq=10 \ --output_dir="checkpoints/dust3r_512" # 512 dpt torchrun --nproc_per_node 8 train.py \ --train_dataset=" + 10_000 @ Habitat(1_000_000, split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ BlendedMVS(split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ MegaDepth(split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ 
ARKitScenes(aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ Co3d(split='train', aug_crop=16, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ StaticThings3D(aug_crop=256, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ ScanNetpp(split='train', aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ InternalUnreleasedDataset(aug_crop=128, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) " \ --test_dataset=" Habitat(1_000, split='val', resolution=(512,384), seed=777) + 1_000 @ BlendedMVS(split='val', resolution=(512,384), seed=777) + 1_000 @ MegaDepth(split='val', resolution=(512,336), seed=777) + 1_000 @ Co3d(split='test', resolution=(512,384), seed=777) " \ --train_criterion="ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \ --test_criterion="Regr3D_ScaleShiftInv(L21, gt_scale=True)" \ --model="AsymmetricCroCo3DStereo(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='dpt', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \ --pretrained="checkpoints/dust3r_512/checkpoint-best.pth" \ --lr=0.0001 --min_lr=1e-06 --warmup_epochs=15 --epochs=90 --batch_size=4 --accum_iter=2 \ --save_freq=5 --keep_freq=10 --eval_freq=1 --print_freq=10 --disable_cudnn_benchmark \ --output_dir="checkpoints/dust3r_512dpt" ``` ================================================ FILE: datasets_preprocess/habitat/README.md ================================================ ## Steps to reproduce synthetic training data using the Habitat-Sim simulator ### Create a conda environment ```bash conda 
create -n habitat python=3.8 habitat-sim=0.2.1 headless=2.0 -c aihabitat -c conda-forge conda activate habitat conda install pytorch -c pytorch pip install opencv-python tqdm ``` or (if you get the error `For headless systems, compile with --headless for EGL support`) ``` git clone --branch stable https://github.com/facebookresearch/habitat-sim.git cd habitat-sim conda create -n habitat python=3.9 cmake=3.14.0 conda activate habitat pip install . -v conda install pytorch -c pytorch pip install opencv-python tqdm ``` ### Download Habitat-Sim scenes Download Habitat-Sim scenes: - Download links can be found here: https://github.com/facebookresearch/habitat-sim/blob/main/DATASETS.md - We used scenes from the HM3D, habitat-test-scenes, ReplicaCAD and ScanNet datasets. - Please put the scenes in a directory `$SCENES_DIR` following the structure below: (Note: the habitat-sim dataset installer may install an incompatible version for ReplicaCAD baked lighting. The correct scene dataset can be downloaded from Huggingface: `git clone git@hf.co:datasets/ai-habitat/ReplicaCAD_baked_lighting`). 
``` $SCENES_DIR/ ├──hm3d/ ├──gibson/ ├──habitat-test-scenes/ ├──ReplicaCAD_baked_lighting/ └──scannet/ ``` ### Download renderings metadata Download metadata corresponding to each scene and extract them into a directory `$METADATA_DIR` ```bash wget https://download.europe.naverlabs.com/ComputerVision/DUSt3R/habitat_5views_v1_512x512_metadata.tar.gz tar -xvzf habitat_5views_v1_512x512_metadata.tar.gz ``` ### Render the scenes Render the scenes in an output directory `$OUTPUT_DIR` ```bash export METADATA_DIR="/path/to/habitat/5views_v1_512x512_metadata" export SCENES_DIR="/path/to/habitat/data/scene_datasets/" export OUTPUT_DIR="data/habitat_processed" cd datasets_preprocess/habitat/ export PYTHONPATH=$(pwd) # Print commandlines to generate images corresponding to each scene python preprocess_habitat.py --scenes_dir=$SCENES_DIR --metadata_dir=$METADATA_DIR --output_dir=$OUTPUT_DIR # Launch these commandlines in parallel e.g. using GNU-Parallel as follows: python preprocess_habitat.py --scenes_dir=$SCENES_DIR --metadata_dir=$METADATA_DIR --output_dir=$OUTPUT_DIR | parallel -j 16 ``` ### Make a list of scenes ```bash python find_scenes.py --root $OUTPUT_DIR ``` ================================================ FILE: datasets_preprocess/habitat/find_scenes.py ================================================ #!/usr/bin/env python3 # Copyright (C) 2024-present Naver Corporation. All rights reserved. # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). # # -------------------------------------------------------- # Script to export the list of scenes for habitat (after having rendered them). 
def find_all_scenes(habitat_root, n_scenes=(100000,)):
    """Export train/val lists of rendered Habitat sub-scenes.

    The list of all sub-scenes is read from ``Habitat_all_scenes.txt`` inside
    `habitat_root`. If that file cannot be read, the user is asked whether the
    directory tree should be scanned for ``*_1_depth.exr`` files instead; the
    resulting list is then cached in that file.

    Scenes (i.e. parent folders) are shuffled with a fixed seed, the first 10%
    are kept for validation and the rest for training. For every `n` in
    `n_scenes`, ``Habitat_{n}_scenes_train.txt`` and
    ``Habitat_{n//10}_scenes_val.txt`` are written, unless the corresponding
    split holds fewer sub-scenes than requested.

    Args:
        habitat_root: root directory of the rendered dataset.
        n_scenes: iterable of training-list sizes to export.

    Returns:
        None.
    """
    np.random.seed(777)  # fixed seed: the exported splits must be reproducible

    fpath = os.path.join(habitat_root, 'Habitat_all_scenes.txt')
    try:
        with open(fpath) as f:
            list_subscenes = f.read().splitlines()

    except IOError:
        if input('parsing sub-folders to find scenes? (y/n) ') != 'y':
            return
        suffix = '_1_depth.exr'
        list_subscenes = []
        for root, dirs, files in tqdm(os.walk(habitat_root)):
            for fname in files:
                if not fname.endswith(suffix):
                    continue
                # Strip only the trailing suffix (str.replace would also alter earlier occurrences).
                scene = os.path.join(os.path.relpath(root, habitat_root), fname[:-len(suffix)])
                # Sparse progress report; which entries get printed is irrelevant to the output.
                if hash(scene) % 1000 == 0:
                    print('... adding', scene)
                list_subscenes.append(scene)

        with open(fpath, 'w') as f:
            f.write('\n'.join(list_subscenes))
        print(f'>> wrote {fpath}')

    print(f'Loaded {len(list_subscenes)} sub-scenes')

    # separate scenes: group the sub-scene identifiers by parent folder
    scene_to_ids = defaultdict(list)
    for subscene in list_subscenes:
        scene, view_id = os.path.split(subscene)
        scene_to_ids[scene].append(view_id)

    list_scenes = list(scene_to_ids.items())
    print(f'from {len(list_scenes)} scenes in total')

    # The split is done at scene level so that train and val never share a scene.
    np.random.shuffle(list_scenes)
    n_val = len(list_scenes) // 10
    train_scenes = list_scenes[n_val:]
    val_scenes = list_scenes[:n_val]

    def write_scene_list(scenes, n, fpath):
        sub_scenes = [os.path.join(scene, view_id) for scene, ids in scenes for view_id in ids]
        np.random.shuffle(sub_scenes)

        if len(sub_scenes) < n:
            # Not enough data for a list of this size: skip it silently.
            return

        with open(fpath, 'w') as f:
            f.write('\n'.join(sub_scenes[:n]))
        print(f'>> wrote {fpath}')

    for n in n_scenes:
        write_scene_list(train_scenes, n, os.path.join(habitat_root, f'Habitat_{n}_scenes_train.txt'))
        write_scene_list(val_scenes, n // 10, os.path.join(habitat_root, f'Habitat_{n//10}_scenes_val.txt'))
class HabitatEnvironmentMapRenderer:
    """Render environment maps (equirectangular and/or cubemap) with Habitat-Sim.

    Args:
        scene: scene identifier given to the simulator configuration.
        navmesh: optional path to a pre-computed navmesh (``None`` or ``""`` to
            let the simulator provide or recompute one).
        scene_dataset_config_file: optional scene dataset configuration file
            (``None`` or ``""`` to leave it unset).
        render_equirectangular: whether to add equirectangular sensors.
        equirectangular_resolution: (height, width) of the equirectangular map.
        render_cubemap: whether to add one 90-degree sensor per cubemap face.
        cubemap_resolution: resolution of each cubemap face.
        render_depth: whether to add a depth sensor next to each color sensor.
        gpu_id: GPU device used by the simulator.

    Raises:
        NoNaviguableSpaceError: if no navmesh is loaded, even after trying to
            recompute one.
    """

    def __init__(self,
                 scene,
                 navmesh,
                 scene_dataset_config_file,
                 render_equirectangular=False,
                 equirectangular_resolution=(512, 1024),
                 render_cubemap=False,
                 cubemap_resolution=(512, 512),
                 render_depth=False,
                 gpu_id=0):
        self.scene = scene
        self.navmesh = navmesh
        self.scene_dataset_config_file = scene_dataset_config_file
        self.gpu_id = gpu_id
        self.render_equirectangular = render_equirectangular
        self.equirectangular_resolution = equirectangular_resolution
        self.equirectangular_projection = projections.EquirectangularProjection(*equirectangular_resolution)
        # 3D unit ray associated to each pixel of the equirectangular map
        equirectangular_rays = projections.get_projection_rays(self.equirectangular_projection)
        # Not needed, but just in case.
        equirectangular_rays /= np.linalg.norm(equirectangular_rays, axis=-1, keepdims=True)
        # Depth maps created by Habitat are produced by warping a cubemap,
        # so the values do not correspond to distance to the center and need some scaling.
        self.equirectangular_depth_scale_factors = 1.0 / np.max(np.abs(equirectangular_rays), axis=-1)
        self.render_cubemap = render_cubemap
        self.cubemap_resolution = cubemap_resolution
        self.render_depth = render_depth
        self.seed = None
        self._lazy_initialization()

    def _lazy_initialization(self):
        """Seed and instantiate the simulator; does nothing once a seed has been drawn."""
        # Lazy random seeding and instantiation of the simulator to deal with multiprocessing properly
        if self.seed is not None:
            return

        # Re-seed numpy generator
        np.random.seed()
        self.seed = np.random.randint(2**32-1)

        sim_cfg = habitat_sim.SimulatorConfiguration()
        sim_cfg.scene_id = self.scene
        if self.scene_dataset_config_file is not None and self.scene_dataset_config_file != "":
            sim_cfg.scene_dataset_config_file = self.scene_dataset_config_file
        sim_cfg.random_seed = self.seed
        sim_cfg.load_semantic_mesh = False
        sim_cfg.gpu_device_id = self.gpu_id

        sensor_specifications = []

        # Add cubemaps
        if self.render_cubemap:
            for face_id, orientation in enumerate(CUBEMAP_FACE_ORIENTATIONS_ROTVEC):
                rgb_sensor_spec = habitat_sim.CameraSensorSpec()
                rgb_sensor_spec.uuid = f"color_cubemap_{CUBEMAP_FACE_LABELS[face_id]}"
                rgb_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR
                rgb_sensor_spec.resolution = self.cubemap_resolution
                rgb_sensor_spec.hfov = 90
                rgb_sensor_spec.position = [0.0, 0.0, 0.0]
                rgb_sensor_spec.orientation = orientation
                sensor_specifications.append(rgb_sensor_spec)

                if self.render_depth:
                    depth_sensor_spec = habitat_sim.CameraSensorSpec()
                    depth_sensor_spec.uuid = f"depth_cubemap_{CUBEMAP_FACE_LABELS[face_id]}"
                    depth_sensor_spec.sensor_type = habitat_sim.SensorType.DEPTH
                    depth_sensor_spec.resolution = self.cubemap_resolution
                    depth_sensor_spec.hfov = 90
                    depth_sensor_spec.position = [0.0, 0.0, 0.0]
                    depth_sensor_spec.orientation = orientation
                    sensor_specifications.append(depth_sensor_spec)

        # Add equirectangular map
        if self.render_equirectangular:
            rgb_sensor_spec = habitat_sim.bindings.EquirectangularSensorSpec()
            rgb_sensor_spec.uuid = "color_equirectangular"
            rgb_sensor_spec.resolution = self.equirectangular_resolution
            rgb_sensor_spec.position = [0.0, 0.0, 0.0]
            sensor_specifications.append(rgb_sensor_spec)

            if self.render_depth:
                depth_sensor_spec = habitat_sim.bindings.EquirectangularSensorSpec()
                depth_sensor_spec.uuid = "depth_equirectangular"
                depth_sensor_spec.sensor_type = habitat_sim.SensorType.DEPTH
                depth_sensor_spec.resolution = self.equirectangular_resolution
                depth_sensor_spec.position = [0.0, 0.0, 0.0]
                # The orientation is left at its default, as for the color sensor
                # (a stray no-op `depth_sensor_spec.orientation` statement was removed here).
                sensor_specifications.append(depth_sensor_spec)

        agent_cfg = habitat_sim.agent.AgentConfiguration(sensor_specifications=sensor_specifications)

        cfg = habitat_sim.Configuration(sim_cfg, [agent_cfg])
        self.sim = habitat_sim.Simulator(cfg)
        if self.navmesh is not None and self.navmesh != "":
            # Use pre-computed navmesh (the one generated automatically does some weird stuffs like going on top of the roof)
            # See https://youtu.be/kunFMRJAu2U?t=1522 regarding navmeshes
            self.sim.pathfinder.load_nav_mesh(self.navmesh)

        # Check that the navmesh is not empty
        if not self.sim.pathfinder.is_loaded:
            # Try to compute a navmesh
            navmesh_settings = habitat_sim.NavMeshSettings()
            navmesh_settings.set_defaults()
            self.sim.recompute_navmesh(self.sim.pathfinder, navmesh_settings, True)

        # Check that the navmesh is not empty
        if not self.sim.pathfinder.is_loaded:
            raise NoNaviguableSpaceError(f"No naviguable location (scene: {self.scene} -- navmesh: {self.navmesh})")

        self.agent = self.sim.initialize_agent(agent_id=0)

    def close(self):
        """Close the simulator if it has been created."""
        # `sim` is missing when construction failed before the simulator was instantiated.
        if hasattr(self, 'sim'):
            self.sim.close()

    def __del__(self):
        self.close()

    def render_viewpoint(self, viewpoint_position):
        """Move the agent to `viewpoint_position` and return its sensor observations.

        Returns:
            dict with keys ``observations`` (the simulator observations, with the
            equirectangular depth rescaled when present) and ``position``.
        """
        agent_state = habitat_sim.AgentState()
        agent_state.position = viewpoint_position
        # agent_state.rotation = viewpoint_orientation
        self.agent.set_state(agent_state)
        viewpoint_observations = self.sim.get_sensor_observations(agent_ids=0)

        try:
            # Depth map values have been obtained using cubemap rendering internally,
            # so they do not really correspond to distance to the viewpoint in practice
            # and they need some scaling
            viewpoint_observations["depth_equirectangular"] *= self.equirectangular_depth_scale_factors
        except KeyError:
            # No equirectangular depth sensor was configured: nothing to rescale.
            pass

        data = dict(observations=viewpoint_observations, position=viewpoint_position)
        return data

    def up_direction(self):
        """Return Habitat's up direction as a plain list."""
        return np.asarray(habitat_sim.geo.UP).tolist()

    def R_cam_to_world(self):
        """Return the OpenCV-to-Habitat rotation matrix as a nested list."""
        return R_OPENCV2HABITAT.tolist()
class HabitatMultiviewCrops:
    """Render equirectangular environment maps and extract camera crops from them."""

    def __init__(self,
                 scene,
                 navmesh,
                 scene_dataset_config_file,
                 equirectangular_resolution=(400, 800),
                 crop_resolution=(240, 320),
                 pixel_jittering_iterations=5,
                 jittering_noise_level=1.0):
        self.crop_resolution = crop_resolution
        self.pixel_jittering_iterations = pixel_jittering_iterations
        self.jittering_noise_level = jittering_noise_level

        # Environment map renderer (color + depth, equirectangular only)
        renderer = HabitatEnvironmentMapRenderer(
            scene=scene,
            navmesh=navmesh,
            scene_dataset_config_file=scene_dataset_config_file,
            equirectangular_resolution=equirectangular_resolution,
            render_depth=True,
            render_equirectangular=True)
        self.lowres_envmap_renderer = renderer
        self.R_cam_to_world = np.asarray(renderer.R_cam_to_world())
        self.up_direction = np.asarray(renderer.up_direction())

        # Projection model of the environment maps, rotated by the transposed camera-to-world matrix
        self.envmap_height, self.envmap_width = renderer.equirectangular_resolution
        equirect = projections.EquirectangularProjection(self.envmap_height, self.envmap_width)
        self.envmap_projection = projections.RotatedProjection(equirect, self.R_cam_to_world.T)
        # One 3D ray per environment map pixel
        self.envmap_rays = projections.get_projection_rays(self.envmap_projection)

    def compute_pointmap(self, distancemap, position):
        """Return the 3D point seen by each envmap pixel: ray * distance + position."""
        scaled_rays = self.envmap_rays * distancemap[:, :, None]
        return scaled_rays + position

    def render_viewpoint_data(self, position):
        """Render the environment map at `position` and bundle it as a `ViewpointData`."""
        rendering = self.lowres_envmap_renderer.render_viewpoint(np.asarray(position))
        observations = rendering['observations']
        # Keep the first three channels only (drop alpha)
        colormap = observations['color_equirectangular'][..., :3]
        distancemap = observations['depth_equirectangular']
        return ViewpointData(colormap=colormap,
                             distancemap=distancemap,
                             pointmap=self.compute_pointmap(distancemap, position),
                             position=position)

    def extract_cropped_camera(self, projection, color_image, distancemap, pointmap, voxelmap=None):
        """Resample the environment maps into the camera described by `projection`.

        Returns:
            (color image, depth map, point map, voxel map) for the crop; the
            voxel map is None when `voxelmap` is None.
        """
        remapper = projections_conversions.RemapProjection(
            input_projection=self.envmap_projection,
            output_projection=projection,
            pixel_jittering_iterations=self.pixel_jittering_iterations,
            jittering_noise_level=self.jittering_noise_level)

        def resample_nearest(envmap):
            # Geometric data must not be blended: nearest neighbour on the first map only.
            return remapper.convert(envmap,
                                    interpolation=cv2.INTER_NEAREST,
                                    borderMode=cv2.BORDER_WRAP,
                                    single_map=True)

        cropped_color_image = remapper.convert(color_image,
                                               interpolation=cv2.INTER_LINEAR,
                                               borderMode=cv2.BORDER_WRAP,
                                               single_map=False)
        cropped_distancemap = resample_nearest(distancemap)
        cropped_pointmap = resample_nearest(pointmap)
        cropped_voxelmap = resample_nearest(voxelmap) if voxelmap is not None else None

        # Convert the distance map into a depth map by dividing by the norm of each output ray
        ray_norms = np.linalg.norm(remapper.output_rays, axis=-1)
        cropped_depthmap = np.asarray(cropped_distancemap / ray_norms, dtype=cropped_distancemap.dtype)

        return cropped_color_image, cropped_depthmap, cropped_pointmap, cropped_voxelmap
camera_params["size"] R_cam2world = np.asarray(camera_params["R_cam2world"]) projection = projections.PerspectiveProjection(K, height=size[1], width=size[0]) projection = projections.RotatedProjection(projection, R_to_base_projection=R_cam2world.T) position = camera_params["t_cam2world"] return projection, position ================================================ FILE: datasets_preprocess/habitat/habitat_renderer/projections.py ================================================ # Copyright (C) 2024-present Naver Corporation. All rights reserved. # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). # # -------------------------------------------------------- # Various 3D/2D projection utils, useful to sample virtual cameras. # -------------------------------------------------------- import numpy as np class EquirectangularProjection: """ Convention for the central pixel of the equirectangular map similar to OpenCV perspective model: +X from left to right +Y from top to bottom +Z going outside the camera EXCEPT that the top left corner of the image is assumed to have (0,0) coordinates (OpenCV assumes (-0.5,-0.5)) """ def __init__(self, height, width): self.height = height self.width = width self.u_scaling = (2 * np.pi) / self.width self.v_scaling = np.pi / self.height def unproject(self, u, v): """ Args: u, v: 2D coordinates Returns: unnormalized 3D rays. """ longitude = self.u_scaling * u - np.pi minus_latitude = self.v_scaling * v - np.pi/2 cos_latitude = np.cos(minus_latitude) x, z = np.sin(longitude) * cos_latitude, np.cos(longitude) * cos_latitude y = np.sin(minus_latitude) rays = np.stack([x, y, z], axis=-1) return rays def project(self, rays): """ Args: rays: Bx3 array of 3D rays. Returns: u, v: tuple of 2D coordinates. 
def get_projection_rays(projection, noise_level=0):
    """
    Return a 2D map of 3D rays corresponding to the projection.
    If noise_level > 0, add some jittering noise to these rays.

    Args:
        projection: object exposing `width`, `height` and `unproject(u, v)`.
        noise_level: amplitude of the uniform jitter, in pixels. Each pixel
            center is shifted by a value drawn in [-noise_level/2, noise_level/2]
            and the result is clipped to the image extent.

    Returns:
        Whatever `projection.unproject` returns for the (height, width) grids of
        pixel-center coordinates.
    """
    grid_u, grid_v = np.meshgrid(0.5 + np.arange(projection.width), 0.5 + np.arange(projection.height))
    if noise_level > 0:
        # np.clip takes (values, min, max). The jittered coordinates are what must be clipped:
        # the former call `np.clip(0, noise, width)` clipped the constant 0 instead, which
        # only ever produced non-negative offsets.
        jitter_u = noise_level * np.random.uniform(-0.5, 0.5, size=grid_u.shape)
        grid_u = np.clip(grid_u + jitter_u, 0, projection.width)
        jitter_v = noise_level * np.random.uniform(-0.5, 0.5, size=grid_v.shape)
        grid_v = np.clip(grid_v + jitter_v, 0, projection.height)
    return projection.unproject(grid_u, grid_v)
""" assert jittering_noise_level >= 0 assert pixel_jittering_iterations >= 0 maps = [] # Initial map self.output_rays = projections.get_projection_rays(output_projection) map_u, map_v = input_projection.project(self.output_rays) map_u, map_v = np.asarray(map_u, dtype=np.float32), np.asarray(map_v, dtype=np.float32) maps.append((map_u, map_v)) for _ in range(pixel_jittering_iterations): # Define multiple mappings using some coordinates jittering to mitigate aliasing effects crop_rays = projections.get_projection_rays(output_projection, jittering_noise_level) map_u, map_v = input_projection.project(crop_rays) map_u, map_v = np.asarray(map_u, dtype=np.float32), np.asarray(map_v, dtype=np.float32) maps.append((map_u, map_v)) self.maps = maps def convert(self, img, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_WRAP, single_map=False): remapped = [] for map_u, map_v in self.maps: res = cv2.remap(img, map_u, map_v, interpolation=interpolation, borderMode=borderMode) remapped.append(res) if single_map: break if len(remapped) == 1: res = remapped[0] else: res = np.asarray(np.mean(remapped, axis=0), dtype=img.dtype) return res ================================================ FILE: datasets_preprocess/habitat/preprocess_habitat.py ================================================ #!/usr/bin/env python3 # Copyright (C) 2024-present Naver Corporation. All rights reserved. # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
def preprocess_metadata(metadata_filename,
                        scenes_dir,
                        output_dir,
                        crop_resolution=(512, 512),
                        equirectangular_resolution=None,
                        fix_existing_dataset=False):
    """Render every view listed in a metadata file and save it in `output_dir`.

    For each view, three files are written: ``{label}.jpeg`` (color),
    ``{label}_depth.exr`` (depth, half-float EXR) and
    ``{label}_camera_params.json`` (camera parameters).

    Args:
        metadata_filename: JSON file with keys ``scene``,
            ``scene_dataset_config_file`` and ``view_batches``.
        scenes_dir: root directory of the Habitat scenes.
        output_dir: directory receiving the rendered views.
        crop_resolution: expected ``size`` of every view in the metadata.
        equirectangular_resolution: resolution of the rendered environment map;
            defaults to (4, 8) times the largest crop dimension.
        fix_existing_dataset: if true, `output_dir` may already exist and views
            whose camera parameter file is present are skipped.

    Raises:
        FileExistsError: if `output_dir` exists and `fix_existing_dataset` is false.
    """
    # Load data
    with open(metadata_filename, "r") as f:
        metadata = json.load(f)

    if metadata["scene_dataset_config_file"] == "":
        scene = os.path.join(scenes_dir, metadata["scene"])
        scene_dataset_config_file = ""
    else:
        scene = metadata["scene"]
        scene_dataset_config_file = os.path.join(scenes_dir, metadata["scene_dataset_config_file"])
    navmesh = None

    if equirectangular_resolution is None:
        # Use 4 times the crop size as resolution for rendering the environment map.
        max_res = max(crop_resolution)
        equirectangular_resolution = (4*max_res, 8*max_res)

    print("equirectangular_resolution:", equirectangular_resolution)

    if os.path.exists(output_dir) and not fix_existing_dataset:
        raise FileExistsError(output_dir)

    # Lazy initialization
    highres_dataset = None

    for batch_label, batch in tqdm(metadata["view_batches"].items()):
        for view_label, view_params in batch.items():
            # Compare as lists: sizes loaded from JSON are lists, while callers may pass a tuple.
            assert list(view_params["size"]) == list(crop_resolution)
            label = f"{batch_label}_{view_label}"

            output_camera_params_filename = os.path.join(output_dir, f"{label}_camera_params.json")
            if fix_existing_dataset and os.path.isfile(output_camera_params_filename):
                # Skip generation if we are fixing a dataset and the corresponding output file already exists
                continue

            # Lazy initialization: the renderer and the output directory are only
            # created once a view actually has to be generated.
            if highres_dataset is None:
                highres_dataset = multiview_crop_generator.HabitatMultiviewCrops(
                    scene=scene,
                    navmesh=navmesh,
                    scene_dataset_config_file=scene_dataset_config_file,
                    equirectangular_resolution=equirectangular_resolution,
                    crop_resolution=crop_resolution,)
                os.makedirs(output_dir, exist_ok=bool(fix_existing_dataset))

            # Generate a higher resolution crop
            projection, position = multiview_crop_generator.dict_to_perspective_projection(view_params)
            # Render an envmap at the given position
            viewpoint_data = highres_dataset.render_viewpoint_data(position)

            colormap, depthmap, pointmap, _ = highres_dataset.extract_cropped_camera(
                projection, viewpoint_data.colormap, viewpoint_data.distancemap, viewpoint_data.pointmap)

            camera_params = multiview_crop_generator.perspective_projection_to_dict(projection, position)

            # Color image
            PIL.Image.fromarray(colormap).save(os.path.join(output_dir, f"{label}.jpeg"))
            # Depth image
            cv2.imwrite(os.path.join(output_dir, f"{label}_depth.exr"),
                        depthmap, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])
            # Camera parameters
            with open(output_camera_params_filename, "w") as f:
                json.dump(camera_params, f)
argparse parser = argparse.ArgumentParser() parser.add_argument("--metadata_dir", required=True) parser.add_argument("--scenes_dir", required=True) parser.add_argument("--output_dir", required=True) parser.add_argument("--metadata_filename", default="") args = parser.parse_args() if args.metadata_filename == "": # Walk through the metadata dir to generate commandlines for filename in glob.iglob(os.path.join(args.metadata_dir, "**/metadata.json"), recursive=True): output_dir = os.path.join(args.output_dir, os.path.relpath(os.path.dirname(filename), args.metadata_dir)) if not os.path.exists(output_dir): commandline = f"python {__file__} --metadata_filename={filename} --metadata_dir={args.metadata_dir} --scenes_dir={args.scenes_dir} --output_dir={output_dir}" print(commandline) else: preprocess_metadata(metadata_filename=args.metadata_filename, scenes_dir=args.scenes_dir, output_dir=args.output_dir) ================================================ FILE: datasets_preprocess/path_to_root.py ================================================ # Copyright (C) 2024-present Naver Corporation. All rights reserved. # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). # # -------------------------------------------------------- # DUSt3R repo root import # -------------------------------------------------------- import sys import os.path as path HERE_PATH = path.normpath(path.dirname(__file__)) DUST3R_REPO_PATH = path.normpath(path.join(HERE_PATH, '../')) # workaround for sibling import sys.path.insert(0, DUST3R_REPO_PATH) ================================================ FILE: datasets_preprocess/preprocess_arkitscenes.py ================================================ #!/usr/bin/env python3 # Copyright (C) 2024-present Naver Corporation. All rights reserved. # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). # # -------------------------------------------------------- # Script to pre-process the arkitscenes dataset. 
def closest(value, sorted_list):
    """Return the element of `sorted_list` (ascending) that is nearest to `value`.

    Values outside the range map to the first or last element. When `value` is
    equally distant from its two neighbours, the lower one is returned.
    """
    pos = bisect_left(sorted_list, value)
    if pos == 0:
        return sorted_list[0]
    if pos == len(sorted_list):
        return sorted_list[-1]

    lower, upper = sorted_list[pos - 1], sorted_list[pos]
    # Strict comparison: ties are resolved in favour of the lower neighbour.
    return upper if upper - value < value - lower else lower
t_w_to_p = np.asarray([float(tokens[4]), float(tokens[5]), float(tokens[6])]) pose_w_to_p = np.eye(4) pose_w_to_p[:3, :3] = r_w_to_p pose_w_to_p[:3, 3] = t_w_to_p pose_p_to_w = np.linalg.inv(pose_w_to_p) r_p_to_w_as_quat = quaternion.from_rotation_matrix(pose_p_to_w[:3, :3]) t_p_to_w = pose_p_to_w[:3, 3] poses_p_to_w.append(pose_p_to_w) poses.append(t_p_to_w) quaternions.append(r_p_to_w_as_quat) return timestamps, poses, quaternions, poses_p_to_w def main(rootdir, pairsdir, outdir): os.makedirs(outdir, exist_ok=True) subdirs = ['Test', 'Training'] for subdir in subdirs: if not osp.isdir(osp.join(rootdir, subdir)): continue # STEP 1: list all scenes outsubdir = osp.join(outdir, subdir) os.makedirs(outsubdir, exist_ok=True) listfile = osp.join(pairsdir, subdir, 'scene_list.json') with open(listfile, 'r') as f: scene_dirs = json.load(f) valid_scenes = [] for scene_subdir in scene_dirs: out_scene_subdir = osp.join(outsubdir, scene_subdir) os.makedirs(out_scene_subdir, exist_ok=True) scene_dir = osp.join(rootdir, subdir, scene_subdir) depth_dir = osp.join(scene_dir, 'lowres_depth') rgb_dir = osp.join(scene_dir, 'vga_wide') intrinsics_dir = osp.join(scene_dir, 'vga_wide_intrinsics') traj_path = osp.join(scene_dir, 'lowres_wide.traj') # STEP 2: read selected_pairs.npz selected_pairs_path = osp.join(pairsdir, subdir, scene_subdir, 'selected_pairs.npz') selected_npz = np.load(selected_pairs_path) selection, pairs = selected_npz['selection'], selected_npz['pairs'] selected_sky_direction_scene = str(selected_npz['sky_direction_scene'][0]) if len(selection) == 0 or len(pairs) == 0: # not a valid scene continue valid_scenes.append(scene_subdir) # STEP 3: parse the scene and export the list of valid (K, pose, rgb, depth) and convert images scene_metadata_path = osp.join(out_scene_subdir, 'scene_metadata.npz') if osp.isfile(scene_metadata_path): continue else: print(f'parsing {scene_subdir}') # loads traj timestamps, poses, quaternions, poses_cam_to_world = read_traj(traj_path) 
poses = np.array(poses) quaternions = np.array(quaternions, dtype=np.quaternion) quaternions = quaternion.unflip_rotors(quaternions) timestamps = np.array(timestamps) selected_images = [(basename, basename.split(".png")[0].split("_")[1]) for basename in selection] timestamps_selected = [float(frame_id) for _, frame_id in selected_images] sky_direction_scene, trajectories, intrinsics, images = convert_scene_metadata(scene_subdir, intrinsics_dir, timestamps, quaternions, poses, poses_cam_to_world, selected_images, timestamps_selected) assert selected_sky_direction_scene == sky_direction_scene os.makedirs(os.path.join(out_scene_subdir, 'vga_wide'), exist_ok=True) os.makedirs(os.path.join(out_scene_subdir, 'lowres_depth'), exist_ok=True) assert isinstance(sky_direction_scene, str) for basename in images: img_out = os.path.join(out_scene_subdir, 'vga_wide', basename.replace('.png', '.jpg')) depth_out = os.path.join(out_scene_subdir, 'lowres_depth', basename) if osp.isfile(img_out) and osp.isfile(depth_out): continue vga_wide_path = osp.join(rgb_dir, basename) depth_path = osp.join(depth_dir, basename) img = Image.open(vga_wide_path) depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED) # rotate the image if sky_direction_scene == 'RIGHT': try: img = img.transpose(Image.Transpose.ROTATE_90) except Exception: img = img.transpose(Image.ROTATE_90) depth = cv2.rotate(depth, cv2.ROTATE_90_COUNTERCLOCKWISE) elif sky_direction_scene == 'LEFT': try: img = img.transpose(Image.Transpose.ROTATE_270) except Exception: img = img.transpose(Image.ROTATE_270) depth = cv2.rotate(depth, cv2.ROTATE_90_CLOCKWISE) elif sky_direction_scene == 'DOWN': try: img = img.transpose(Image.Transpose.ROTATE_180) except Exception: img = img.transpose(Image.ROTATE_180) depth = cv2.rotate(depth, cv2.ROTATE_180) W, H = img.size if not osp.isfile(img_out): img.save(img_out) depth = cv2.resize(depth, (W, H), interpolation=cv2.INTER_NEAREST_EXACT) if not osp.isfile(depth_out): # avoid destroying the base 
dataset when you mess up the paths cv2.imwrite(depth_out, depth) # save at the end np.savez(scene_metadata_path, trajectories=trajectories, intrinsics=intrinsics, images=images, pairs=pairs) outlistfile = osp.join(outsubdir, 'scene_list.json') with open(outlistfile, 'w') as f: json.dump(valid_scenes, f) # STEP 5: concat all scene_metadata.npz into a single file scene_data = {} for scene_subdir in valid_scenes: scene_metadata_path = osp.join(outsubdir, scene_subdir, 'scene_metadata.npz') with np.load(scene_metadata_path) as data: trajectories = data['trajectories'] intrinsics = data['intrinsics'] images = data['images'] pairs = data['pairs'] scene_data[scene_subdir] = {'trajectories': trajectories, 'intrinsics': intrinsics, 'images': images, 'pairs': pairs} offset = 0 counts = [] scenes = [] sceneids = [] images = [] intrinsics = [] trajectories = [] pairs = [] for scene_idx, (scene_subdir, data) in enumerate(scene_data.items()): num_imgs = data['images'].shape[0] img_pairs = data['pairs'] scenes.append(scene_subdir) sceneids.extend([scene_idx] * num_imgs) images.append(data['images']) K = np.expand_dims(np.eye(3), 0).repeat(num_imgs, 0) K[:, 0, 0] = [fx for _, _, fx, _, _, _ in data['intrinsics']] K[:, 1, 1] = [fy for _, _, _, fy, _, _ in data['intrinsics']] K[:, 0, 2] = [hw for _, _, _, _, hw, _ in data['intrinsics']] K[:, 1, 2] = [hh for _, _, _, _, _, hh in data['intrinsics']] intrinsics.append(K) trajectories.append(data['trajectories']) # offset pairs img_pairs[:, 0:2] += offset pairs.append(img_pairs) counts.append(offset) offset += num_imgs images = np.concatenate(images, axis=0) intrinsics = np.concatenate(intrinsics, axis=0) trajectories = np.concatenate(trajectories, axis=0) pairs = np.concatenate(pairs, axis=0) np.savez(osp.join(outsubdir, 'all_metadata.npz'), counts=counts, scenes=scenes, sceneids=sceneids, images=images, intrinsics=intrinsics, trajectories=trajectories, pairs=pairs) def convert_scene_metadata(scene_subdir, intrinsics_dir, timestamps, 
def find_scene_orientation(poses_cam_to_world):
    """Decide which image side points to the sky and build the matching roll correction.

    The vectors returned by ``get_up_vectors`` / ``get_right_vectors`` are averaged
    over all poses and compared against the world up axis (+Z).

    Args:
        poses_cam_to_world: sequence of camera-to-world poses (may be empty).

    Returns:
        tuple: ``(sky_direction_scene, rotated_to_cam)``; the first item is one of
        ``'UP'``, ``'DOWN'``, ``'LEFT'`` or ``'RIGHT'``, the second is the 4x4 inverse
        of the roll (about the camera Z axis) associated with that direction.
    """
    world_up = np.array([[0.0], [0.0], [1.0], [0.0]])

    num_poses = len(poses_cam_to_world)
    if num_poses > 0:
        up_vector = sum(get_up_vectors(pose) for pose in poses_cam_to_world) / num_poses
        right_vector = sum(get_right_vectors(pose) for pose in poses_cam_to_world) / num_poses
    else:
        # no pose to average: fall back to fixed axes
        up_vector = np.array([[0.0], [-1.0], [0.0], [0.0]])
        right_vector = np.array([[1.0], [0.0], [0.0], [0.0]])

    def angle_to_world_up(axis):
        # angle in degrees, value between 0 and 180
        cosine = np.clip(np.dot(np.transpose(world_up), axis), -1.0, 1.0)
        return np.arccos(cosine).item() * 180.0 / np.pi

    up_angle = angle_to_world_up(up_vector)
    right_angle = angle_to_world_up(right_vector)

    if abs(up_angle - 90.0) < abs(right_angle - 90.0):
        # the up axis is the one closest to horizontal: the image is rolled sideways
        assert abs(up_angle - 90.0) < 45.0
        if right_angle > 90.0:
            sky_direction_scene = 'LEFT'
            roll = math.pi / 2.0
        else:
            # note that in metadata.csv RIGHT does not exist, but such scenes
            # do occur, for example Training/41124801
            sky_direction_scene = 'RIGHT'
            roll = -math.pi / 2.0
        cam_to_rotated_q = quaternion.from_rotation_vector([0.0, 0.0, roll])
    else:
        # the right axis is the one closest to horizontal
        assert abs(right_angle - 90.0) < 45.0
        if up_angle > 90.0:
            sky_direction_scene = 'DOWN'
            cam_to_rotated_q = quaternion.from_rotation_vector([0.0, 0.0, math.pi])
        else:
            sky_direction_scene = 'UP'
            cam_to_rotated_q = quaternion.quaternion(1, 0, 0, 0)

    cam_to_rotated = np.eye(4)
    cam_to_rotated[:3, :3] = quaternion.as_rotation_matrix(cam_to_rotated_q)
    return sky_direction_scene, np.linalg.inv(cam_to_rotated)
def get_parser():
    """Build the command-line parser of the BlendedMVS preprocessing script.

    ``--blendedmvs_dir`` and ``--precomputed_pairs`` are mandatory;
    ``--output_dir`` defaults to ``data/blendedmvs_processed``.
    """
    from argparse import ArgumentParser

    cli = ArgumentParser()
    for mandatory_flag in ('--blendedmvs_dir', '--precomputed_pairs'):
        cli.add_argument(mandatory_flag, required=True)
    cli.add_argument('--output_dir', default='data/blendedmvs_processed')
    return cli
def _load_pose(path, ret_44=False):
    """Read the extrinsics and intrinsics stored in a ``*_cam.txt`` file.

    The file is read sequentially: one line is skipped, then a 4x4 matrix is
    parsed and inverted (world2cam to cam2world); two more lines are skipped
    and a 3x3 intrinsics matrix is parsed.

    Args:
        path: path of the camera text file.
        ret_44: when True, return the full 4x4 cam2world matrix instead of
            its rotation / translation parts.

    Returns:
        ``(K, RT)`` if ``ret_44`` else ``(K, R, t)`` with ``R = RT[:3, :3]``
        and ``t = RT[:3, 3]``.
    """
    # the context manager guarantees the handle is released even if parsing
    # fails; both reads must share the handle so the second one resumes
    # where the first one stopped
    with open(path) as f:
        RT = np.loadtxt(f, skiprows=1, max_rows=4, dtype=np.float32)
        assert RT.shape == (4, 4)
        K = np.loadtxt(f, skiprows=2, max_rows=3, dtype=np.float32)
        assert K.shape == (3, 3)

    RT = np.linalg.inv(RT)  # world2cam to cam2world
    if ret_44:
        return K, RT
    return K, RT[:3, :3], RT[:3, 3]
def convert_ndc_to_pinhole(focal_length, principal_point, image_size):
    """Turn NDC focal length / principal point into a pixel-unit 3x3 matrix.

    ``image_size`` is indexed as ``(height, width)``. Both quantities are
    scaled by half of the smaller image side, and the principal point is
    measured from the image centre with its sign flipped.

    Returns:
        np.ndarray: float32 matrix ``[[fx, 0, cx], [0, fy, cy], [0, 0, 1]]``.
    """
    height, width = image_size[0], image_size[1]
    half_wh = np.array([width, height]) / 2
    ndc_to_px = half_wh.min()

    focal_px = np.array(focal_length) * ndc_to_px
    center_px = half_wh - np.array(principal_point) * ndc_to_px

    return np.array(
        [
            [focal_px[0], 0.0, center_px[0]],
            [0.0, focal_px[1], center_px[1]],
            [0.0, 0.0, 1.0],
        ],
        dtype=np.float32,
    )
def get_set_list(category_dir, split, is_single_sequence_subset=False):
    """Collect the entries of one split from a category's ``set_lists`` files.

    Args:
        category_dir: category directory containing a ``set_lists`` folder.
        split: key to read from every selected JSON file (e.g. ``'train'``).
        is_single_sequence_subset: select the files whose name contains
            ``manyview_dev`` instead of those containing ``fewview_train``.

    Returns:
        list: concatenation of ``data[split]`` over the selected files,
        visited in sorted file-name order.

    Raises:
        KeyError: if a selected file has no ``split`` entry.
    """
    set_lists_dir = osp.join(category_dir, "set_lists")
    # not all objects have manyview_dev
    wanted = "manyview_dev" if is_single_sequence_subset else "fewview_train"
    # os.listdir returns names in arbitrary order: sort them so that the
    # result does not depend on the filesystem
    subset_list_files = sorted(name for name in os.listdir(set_lists_dir) if wanted in name)

    sequences_all = []
    for subset_list_file in subset_list_files:
        with open(osp.join(set_lists_dir, subset_list_file)) as f:
            sequences_all.extend(json.load(f)[split])
    return sequences_all
min_quality: good_quality_sequences.add(seq_data["sequence_name"]) sequences_numbers = [seq_name for seq_name in sequences_numbers if seq_name in good_quality_sequences] if len(sequences_numbers) < max_num_sequences_per_object: selected_sequences_numbers = sequences_numbers else: selected_sequences_numbers = random.sample(sequences_numbers, max_num_sequences_per_object) selected_sequences_numbers_dict = {seq_name: [] for seq_name in selected_sequences_numbers} sequences_all = [(seq_name, frame_number, filepath) for seq_name, frame_number, filepath in sequences_all if seq_name in selected_sequences_numbers_dict] for seq_name, frame_number, filepath in tqdm(sequences_all): frame_idx = int(filepath.split('/')[-1][5:-4]) selected_sequences_numbers_dict[seq_name].append(frame_idx) mask_path = filepath.replace("images", "masks").replace(".jpg", ".png") frame_data = frame_data_processed[seq_name][frame_number] focal_length = frame_data["viewpoint"]["focal_length"] principal_point = frame_data["viewpoint"]["principal_point"] image_size = frame_data["image"]["size"] K = convert_ndc_to_pinhole(focal_length, principal_point, image_size) R, tvec, camera_intrinsics = opencv_from_cameras_projection(np.array(frame_data["viewpoint"]["R"]), np.array(frame_data["viewpoint"]["T"]), np.array(focal_length), np.array(principal_point), np.array(image_size)) frame_data = frame_data_processed[seq_name][frame_number] depth_path = os.path.join(co3d_dir, frame_data["depth"]["path"]) assert frame_data["depth"]["scale_adjustment"] == 1.0 image_path = os.path.join(co3d_dir, filepath) mask_path_full = os.path.join(co3d_dir, mask_path) input_rgb_image = PIL.Image.open(image_path).convert('RGB') input_mask = plt.imread(mask_path_full) with PIL.Image.open(depth_path) as depth_pil: # the image is stored with 16-bit depth but PIL reads it as I (32 bit). 
# we cast it to uint16, then reinterpret as float16, then cast to float32 input_depthmap = ( np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16) .astype(np.float32) .reshape((depth_pil.size[1], depth_pil.size[0]))) depth_mask = np.stack((input_depthmap, input_mask), axis=-1) H, W = input_depthmap.shape camera_intrinsics = camera_intrinsics.numpy() cx, cy = camera_intrinsics[:2, 2].round().astype(int) min_margin_x = min(cx, W - cx) min_margin_y = min(cy, H - cy) # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy) l, t = cx - min_margin_x, cy - min_margin_y r, b = cx + min_margin_x, cy + min_margin_y crop_bbox = (l, t, r, b) input_rgb_image, depth_mask, input_camera_intrinsics = cropping.crop_image_depthmap( input_rgb_image, depth_mask, camera_intrinsics, crop_bbox) # try to set the lower dimension to img_size * 3/4 -> img_size=512 => 384 scale_final = ((img_size * 3 // 4) / min(H, W)) + 1e-8 output_resolution = np.floor(np.array([W, H]) * scale_final).astype(int) if max(output_resolution) < img_size: # let's put the max dimension to img_size scale_final = (img_size / max(H, W)) + 1e-8 output_resolution = np.floor(np.array([W, H]) * scale_final).astype(int) input_rgb_image, depth_mask, input_camera_intrinsics = cropping.rescale_image_depthmap( input_rgb_image, depth_mask, input_camera_intrinsics, output_resolution) input_depthmap = depth_mask[:, :, 0] input_mask = depth_mask[:, :, 1] # generate and adjust camera pose camera_pose = np.eye(4, dtype=np.float32) camera_pose[:3, :3] = R camera_pose[:3, 3] = tvec camera_pose = np.linalg.inv(camera_pose) # save crop images and depth, metadata save_img_path = os.path.join(output_dir, filepath) save_depth_path = os.path.join(output_dir, frame_data["depth"]["path"]) save_mask_path = os.path.join(output_dir, mask_path) os.makedirs(os.path.split(save_img_path)[0], exist_ok=True) os.makedirs(os.path.split(save_depth_path)[0], exist_ok=True) 
os.makedirs(os.path.split(save_mask_path)[0], exist_ok=True) input_rgb_image.save(save_img_path) scaled_depth_map = (input_depthmap / np.max(input_depthmap) * 65535).astype(np.uint16) cv2.imwrite(save_depth_path, scaled_depth_map) cv2.imwrite(save_mask_path, (input_mask * 255).astype(np.uint8)) save_meta_path = save_img_path.replace('jpg', 'npz') np.savez(save_meta_path, camera_intrinsics=input_camera_intrinsics, camera_pose=camera_pose, maximum_depth=np.max(input_depthmap)) return selected_sequences_numbers_dict if __name__ == "__main__": parser = get_parser() args = parser.parse_args() assert args.co3d_dir != args.output_dir if args.category is None: if args.single_sequence_subset: categories = SINGLE_SEQUENCE_CATEGORIES else: categories = CATEGORIES else: categories = [args.category] os.makedirs(args.output_dir, exist_ok=True) for split in ['train', 'test']: selected_sequences_path = os.path.join(args.output_dir, f'selected_seqs_{split}.json') if os.path.isfile(selected_sequences_path): continue all_selected_sequences = {} for category in categories: category_output_dir = osp.join(args.output_dir, category) os.makedirs(category_output_dir, exist_ok=True) category_selected_sequences_path = os.path.join(category_output_dir, f'selected_seqs_{split}.json') if os.path.isfile(category_selected_sequences_path): with open(category_selected_sequences_path, 'r') as fid: category_selected_sequences = json.load(fid) else: print(f"Processing {split} - category = {category}") category_selected_sequences = prepare_sequences( category=category, co3d_dir=args.co3d_dir, output_dir=args.output_dir, img_size=args.img_size, split=split, min_quality=args.min_quality, max_num_sequences_per_object=args.num_sequences_per_object, seed=args.seed + CATEGORIES_IDX[category], is_single_sequence_subset=args.single_sequence_subset ) with open(category_selected_sequences_path, 'w') as file: json.dump(category_selected_sequences, file) all_selected_sequences[category] = 
def get_parser():
    """Create the argument parser used by the MegaDepth preprocessing script.

    Options, in declaration order: ``--megadepth_dir`` (required),
    ``--precomputed_pairs`` (required) and ``--output_dir``
    (default ``data/megadepth_processed``).
    """
    import argparse

    option_table = (
        ('--megadepth_dir', dict(required=True)),
        ('--precomputed_pairs', dict(required=True)),
        ('--output_dir', dict(default='data/megadepth_processed')),
    )
    parser = argparse.ArgumentParser()
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)
    return parser
def _load_kpts_and_poses(root, scene_id, subscene, z_only=False, intrinsics=False):
    """Parse ``images.txt`` (and optionally ``cameras.txt``) of one sub-scene.

    Files are read from ``<root>/<scene_id>/sparse/manhattan/<subscene>/``.

    Args:
        root: dataset root directory.
        scene_id: scene folder name.
        subscene: sub-scene folder name.
        z_only: if True each pose is reduced to its principal (z) axis,
            otherwise the full 4x4 world-to-camera matrix is kept.
        intrinsics: if True also parse ``cameras.txt`` and return the
            per-image intrinsics.

    Returns:
        ``(points3D_idxs, poses)`` or, when ``intrinsics`` is True,
        ``(points3D_idxs, poses, image_intrinsics)``. All dictionaries are
        keyed by the last token of the image line. ``image_intrinsics``
        values are ``((width, height), K, (k0, 0, 0, 0))``.
    """
    sparse_dir = os.path.join(root, scene_id, 'sparse', 'manhattan', subscene)

    camera_intrinsics = {}
    if intrinsics:
        with open(os.path.join(sparse_dir, 'cameras.txt'), 'r') as f:
            camera_lines = f.read().splitlines()[3:]  # skip the header
        for line in camera_lines:
            if not line.strip():
                continue  # tolerate blank (e.g. trailing) lines
            fields = line.split(' ')
            width, height, focal, cx, cy, k0 = [float(elem) for elem in fields[2:]]
            K = np.eye(3)
            K[0, 0] = focal
            K[1, 1] = focal
            K[0, 2] = cx
            K[1, 2] = cy
            camera_intrinsics[int(fields[0])] = ((int(width), int(height)), K, (k0, 0, 0, 0))

    with open(os.path.join(sparse_dir, 'images.txt'), 'r') as f:
        raw = f.read().splitlines()[4:]  # skip the header

    extract_pose = colmap_raw_pose_to_principal_axis if z_only else colmap_raw_pose_to_RT

    poses = {}
    points3D_idxs = {}
    image_cameras = {}
    # every image is described by two consecutive lines: pose, then 2D points
    for image, points in zip(raw[::2], raw[1::2]):
        image = image.split(' ')
        points = points.split(' ')

        image_id = image[-1]
        # keyed by image so that intrinsics can never be paired with the
        # wrong image, even if an image name is listed more than once
        image_cameras[image_id] = int(image[-2])

        raw_pose = [float(elem) for elem in image[1:-2]]
        poses[image_id] = extract_pose(raw_pose)

        # every third token is a 3D point index; '-1' marks "no 3D point"
        points3D_idxs[image_id] = {int(i) for i in points[2::3] if i != '-1'}

    if intrinsics:
        image_intrinsics = {im_id: camera_intrinsics[cam] for im_id, cam in image_cameras.items()}
        return points3D_idxs, poses, image_intrinsics
    return points3D_idxs, poses


def colmap_raw_pose_to_principal_axis(image_pose):
    """Return the third row of the rotation encoded by ``image_pose[:4]``.

    The first four values are read as a ``(w, x, y, z)`` quaternion and
    normalised before use. The result is a float32 vector of length 3.
    """
    qvec = np.asarray(image_pose[:4])
    qvec = qvec / np.linalg.norm(qvec)
    w, x, y, z = qvec
    z_axis = np.float32([
        2 * x * z - 2 * y * w,
        2 * y * z + 2 * x * w,
        1 - 2 * x * x - 2 * y * y
    ])
    return z_axis


def colmap_raw_pose_to_RT(image_pose):
    """Build the 4x4 world-to-camera matrix from ``(qw, qx, qy, qz, tx, ty, tz)``.

    The quaternion is normalised before being converted to a rotation
    matrix; the translation is copied as is.
    """
    qvec = np.asarray(image_pose[:4])
    qvec = qvec / np.linalg.norm(qvec)
    w, x, y, z = qvec
    R = np.array([
        [1 - 2 * y * y - 2 * z * z, 2 * x * y - 2 * z * w, 2 * x * z + 2 * y * w],
        [2 * x * y + 2 * z * w, 1 - 2 * x * x - 2 * z * z, 2 * y * z - 2 * x * w],
        [2 * x * z - 2 * y * w, 2 * y * z + 2 * x * w, 1 - 2 * x * x - 2 * y * y],
    ])
    t = image_pose[4:7]
    # World-to-Camera pose
    current_pose = np.eye(4)
    current_pose[:3, :3] = R
    current_pose[:3, 3] = t
    return current_pose
def pose_from_qwxyz_txyz(elems):
    """Convert ``(qw, qx, qy, qz, tx, ty, tz)`` into an inverted 4x4 pose.

    The seven elements are cast to float, assembled into a rigid transform
    (rotation from the quaternion, translation from the last three values)
    and the inverse of that transform is returned (cam2world).
    """
    qw, qx, qy, qz, tx, ty, tz = (float(elem) for elem in elems)

    # scipy expects the scalar component last
    rotation = Rotation.from_quat((qx, qy, qz, qw))

    world_to_cam = np.eye(4)
    world_to_cam[:3, 3] = (tx, ty, tz)
    world_to_cam[:3, :3] = rotation.as_matrix()

    cam_to_world = np.linalg.inv(world_to_cam)
    return cam_to_world
def subsample_img_infos(img_infos, num_images, allowed_name_subset=None):
    """Keep at most ``num_images`` entries of ``img_infos``.

    Entries whose ``'path'`` is not in ``allowed_name_subset`` are dropped
    first (when a subset is given). If more than ``num_images`` entries
    remain, they are ordered by ``'frame_id'`` and sampled at evenly spaced
    positions; otherwise they are returned in their original order.
    """
    candidates = list(img_infos.items())
    if allowed_name_subset is not None:
        candidates = [entry for entry in candidates if entry[1]['path'] in allowed_name_subset]

    if len(candidates) <= num_images:
        # nothing to subsample
        return dict(candidates)

    candidates.sort(key=lambda entry: entry[1]['frame_id'])
    positions = np.round(np.linspace(0, len(candidates) - 1, num_images)).astype(int).tolist()
    return {candidates[pos][0]: candidates[pos][1] for pos in positions}
def _concat_scene_metadata(output_dir, scene_names):
    """Merge every ``<scene>/scene_metadata.npz`` into ``all_metadata.npz``."""
    scene_data = {}
    for scene_subdir in scene_names:
        scene_metadata_path = osp.join(output_dir, scene_subdir, 'scene_metadata.npz')
        with np.load(scene_metadata_path) as data:
            scene_data[scene_subdir] = {'trajectories': data['trajectories'],
                                        'intrinsics': data['intrinsics'],
                                        'images': data['images'],
                                        'pairs': data['pairs']}

    offset = 0
    counts = []
    scenes = []
    sceneids = []
    images = []
    intrinsics = []
    trajectories = []
    pairs = []
    for scene_idx, (scene_subdir, data) in enumerate(scene_data.items()):
        num_imgs = data['images'].shape[0]
        img_pairs = data['pairs']

        scenes.append(scene_subdir)
        sceneids.extend([scene_idx] * num_imgs)

        images.append(data['images'])
        intrinsics.append(data['intrinsics'])
        trajectories.append(data['trajectories'])

        # pair indices are local to the scene: shift them by the number of
        # images that precede this scene in the concatenated arrays
        img_pairs[:, 0:2] += offset
        pairs.append(img_pairs)
        counts.append(offset)

        offset += num_imgs

    images = np.concatenate(images, axis=0)
    intrinsics = np.concatenate(intrinsics, axis=0)
    trajectories = np.concatenate(trajectories, axis=0)
    pairs = np.concatenate(pairs, axis=0)
    np.savez(osp.join(output_dir, 'all_metadata.npz'),
             counts=counts,
             scenes=scenes,
             sceneids=sceneids,
             images=images,
             intrinsics=intrinsics,
             trajectories=trajectories,
             pairs=pairs)


def process_scenes(root, pairsdir, output_dir, target_resolution):
    """Undistort and rescale the selected images of every scene and render their depth.

    For each scene listed in ``<pairsdir>/scene_list.json`` that has no
    ``scene_metadata.npz`` yet: the selected DSLR / iPhone images are
    undistorted and rescaled, a depth map is rendered from the scene mesh
    for each of them, and poses / intrinsics / pairs are stored in
    ``<output_dir>/<scene>/scene_metadata.npz``. Finally all per-scene
    metadata files are merged into ``<output_dir>/all_metadata.npz``.

    Args:
        root: dataset root; scenes are read from ``<root>/data/<scene>``.
        pairsdir: directory holding ``scene_list.json`` and, per scene,
            ``selected_pairs.npz``.
        output_dir: destination directory (created if missing).
        target_resolution: width requested from ``rescale_image_depthmap``;
            the requested height is 3/4 of it.

    Raises:
        ValueError: if a selected image name contains neither ``frame_``
            nor ``DSC``.
    """
    os.makedirs(output_dir, exist_ok=True)

    # default values from
    # https://github.com/scannetpp/scannetpp/blob/main/common/configs/render.yml
    znear = 0.05
    zfar = 20.0

    listfile = osp.join(pairsdir, 'scene_list.json')
    with open(listfile, 'r') as f:
        scenes = json.load(f)

    # for each of these, we will select some dslr images and some iphone images
    # we will undistort them and render their depth
    renderer = pyrender.OffscreenRenderer(0, 0)
    try:
        for scene in tqdm(scenes, position=0, leave=True):
            data_dir = os.path.join(root, 'data', scene)
            dir_dslr = os.path.join(data_dir, 'dslr')
            dir_iphone = os.path.join(data_dir, 'iphone')
            dir_scans = os.path.join(data_dir, 'scans')

            assert os.path.isdir(data_dir) and os.path.isdir(dir_dslr) \
                and os.path.isdir(dir_iphone) and os.path.isdir(dir_scans)

            output_dir_scene = os.path.join(output_dir, scene)
            scene_metadata_path = osp.join(output_dir_scene, 'scene_metadata.npz')
            if osp.isfile(scene_metadata_path):
                continue  # scene already processed

            pairs_dir_scene = os.path.join(pairsdir, scene)
            pairs_dir_scene_selected_pairs = os.path.join(pairs_dir_scene, 'selected_pairs.npz')
            assert osp.isfile(pairs_dir_scene_selected_pairs)
            # close the archive as soon as both arrays are read
            with np.load(pairs_dir_scene_selected_pairs) as selected_npz:
                selection, pairs = selected_npz['selection'], selected_npz['pairs']

            # set up the output paths
            output_dir_scene_rgb = os.path.join(output_dir_scene, 'images')
            output_dir_scene_depth = os.path.join(output_dir_scene, 'depth')
            os.makedirs(output_dir_scene_rgb, exist_ok=True)
            os.makedirs(output_dir_scene_depth, exist_ok=True)

            ply_path = os.path.join(dir_scans, 'mesh_aligned_0.05.ply')

            sfm_dir_dslr = os.path.join(dir_dslr, 'colmap')
            rgb_dir_dslr = os.path.join(dir_dslr, 'resized_images')
            mask_dir_dslr = os.path.join(dir_dslr, 'resized_anon_masks')

            sfm_dir_iphone = os.path.join(dir_iphone, 'colmap')
            rgb_dir_iphone = os.path.join(dir_iphone, 'rgb')
            mask_dir_iphone = os.path.join(dir_iphone, 'rgb_masks')

            # load the mesh
            with open(ply_path, 'rb') as f:
                mesh_kwargs = trimesh.exchange.ply.load_ply(f)
            mesh_scene = trimesh.Trimesh(**mesh_kwargs)

            # read colmap reconstruction, we will only use the intrinsics and pose here
            img_idx_dslr, img_infos_dslr, _, _ = load_sfm(sfm_dir_dslr, cam_type='dslr')
            img_idx_iphone, img_infos_iphone, _, _ = load_sfm(sfm_dir_iphone, cam_type='iphone')

            mesh = pyrender.Mesh.from_trimesh(mesh_scene, smooth=False)
            pyrender_scene = pyrender.Scene()
            pyrender_scene.add(mesh)

            selection_iphone = [imgname + '.jpg' for imgname in selection if 'frame_' in imgname]
            selection_dslr = [imgname + '.JPG' for imgname in selection if 'frame_' not in imgname]

            # resize the image to a more manageable size and render depth
            per_camera = [(selection_dslr, img_idx_dslr, img_infos_dslr, rgb_dir_dslr, mask_dir_dslr),
                          (selection_iphone, img_idx_iphone, img_infos_iphone, rgb_dir_iphone, mask_dir_iphone)]
            for selection_cam, img_idx, img_infos, rgb_dir, mask_dir in per_camera:
                for imgname in tqdm(selection_cam, position=1, leave=False):
                    imgidx = img_idx[imgname]
                    img_infos_idx = img_infos[imgidx]
                    rgb = np.array(Image.open(os.path.join(rgb_dir, img_infos_idx['path'])))
                    mask = np.array(Image.open(os.path.join(mask_dir, img_infos_idx['path'][:-3] + 'png')))

                    _, _, K, rgb, mask = undistort_images(img_infos_idx['intrinsics'], rgb, mask)

                    # rescale_image_depthmap assumes opencv intrinsics
                    intrinsics = geometry.colmap_to_opencv_intrinsics(K)
                    image, mask, intrinsics = rescale_image_depthmap(
                        rgb, mask, intrinsics, (target_resolution, target_resolution * 3.0 / 4))

                    W, H = image.size
                    intrinsics = geometry.opencv_to_colmap_intrinsics(intrinsics)

                    # update img_infos_idx in place: the rescaled intrinsics
                    # are the ones stored in the scene metadata below
                    img_infos_idx['intrinsics'] = intrinsics
                    rgb_outpath = os.path.join(output_dir_scene_rgb, img_infos_idx['path'][:-3] + 'jpg')
                    image.save(rgb_outpath)

                    depth_outpath = os.path.join(output_dir_scene_depth, img_infos_idx['path'][:-3] + 'png')
                    # render depth image
                    renderer.viewport_width, renderer.viewport_height = W, H
                    fx, fy, cx, cy = intrinsics[0, 0], intrinsics[1, 1], intrinsics[0, 2], intrinsics[1, 2]
                    camera = pyrender.camera.IntrinsicsCamera(fx, fy, cx, cy, znear=znear, zfar=zfar)
                    camera_node = pyrender_scene.add(camera, pose=img_infos_idx['cam_to_world'] @ OPENGL_TO_OPENCV)

                    _, depth = renderer.render(pyrender_scene, flags=pyrender.RenderFlags.SKIP_CULL_FACES)
                    pyrender_scene.remove_node(camera_node)  # dont forget to remove camera

                    depth = (depth * 1000).astype('uint16')
                    # invalidate depth from mask before saving
                    depth_mask = (mask < 255)
                    depth[depth_mask] = 0
                    Image.fromarray(depth).save(depth_outpath)

            trajectories = []
            intrinsics = []
            for imgname in selection:
                if 'frame_' in imgname:
                    imgidx = img_idx_iphone[imgname + '.jpg']
                    img_infos_idx = img_infos_iphone[imgidx]
                elif 'DSC' in imgname:
                    imgidx = img_idx_dslr[imgname + '.JPG']
                    img_infos_idx = img_infos_dslr[imgidx]
                else:
                    raise ValueError(f'invalid image name {imgname}')

                intrinsics.append(img_infos_idx['intrinsics'])
                trajectories.append(img_infos_idx['cam_to_world'])

            intrinsics = np.stack(intrinsics, axis=0)
            trajectories = np.stack(trajectories, axis=0)
            # save metadata for this scene
            np.savez(scene_metadata_path,
                     trajectories=trajectories,
                     intrinsics=intrinsics,
                     images=selection,
                     pairs=pairs)

            del img_infos
            del pyrender_scene
    finally:
        # free the OpenGL resources, also when a scene fails
        renderer.delete()

    # concat all scene_metadata.npz into a single file
    _concat_scene_metadata(output_dir, scenes)
    print('all done')
# # -------------------------------------------------------- # Preprocessing code for the StaticThings3D dataset # dataset at https://github.com/lmb-freiburg/robustmvd/blob/master/rmvd/data/README.md#staticthings3d # 1) Download StaticThings3D in /path/to/StaticThings3D/ # with the script at https://github.com/lmb-freiburg/robustmvd/blob/master/rmvd/data/scripts/download_staticthings3d.sh # --> depths.tar.bz2 frames_finalpass.tar.bz2 poses.tar.bz2 frames_cleanpass.tar.bz2 intrinsics.tar.bz2 # 2) unzip everything in the same /path/to/StaticThings3D/ directory # 5) python datasets_preprocess/preprocess_staticthings3d.py --StaticThings3D_dir /path/to/tmp/StaticThings3D/ # -------------------------------------------------------- import os import os.path as osp import re from tqdm import tqdm import numpy as np os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" import cv2 import path_to_root # noqa from dust3r.utils.parallel import parallel_threads from dust3r.datasets.utils import cropping # noqa def get_parser(): import argparse parser = argparse.ArgumentParser() parser.add_argument('--StaticThings3D_dir', required=True) parser.add_argument('--precomputed_pairs', required=True) parser.add_argument('--output_dir', default='data/staticthings3d_processed') return parser def main(db_root, pairs_path, output_dir): all_scenes = _list_all_scenes(db_root) # crop images args = [(db_root, osp.join(split, subsplit, seq), camera, f'{n:04d}', output_dir) for split, subsplit, seq in all_scenes for camera in ['left', 'right'] for n in range(6, 16)] parallel_threads(load_crop_and_save, args, star_args=True, front_num=1) # verify that all images are there CAM = {b'l': 'left', b'r': 'right'} pairs = np.load(pairs_path) for scene, seq, cam1, im1, cam2, im2 in tqdm(pairs): seq_path = osp.join('TRAIN', scene.decode('ascii'), f'{seq:04d}') for cam, idx in [(CAM[cam1], im1), (CAM[cam2], im2)]: for ext in ['clean', 'final']: impath = osp.join(output_dir, seq_path, cam, f"{idx:04n}_{ext}.jpg") 
assert osp.isfile(impath), f'missing an image at {impath=}' print(f'>> Saved all data to {output_dir}!') def load_crop_and_save(db_root, relpath_, camera, num, out_dir): relpath = osp.join(relpath_, camera, num) if osp.isfile(osp.join(out_dir, relpath + '.npz')): return os.makedirs(osp.join(out_dir, relpath_, camera), exist_ok=True) # load everything intrinsics_in = readFloat(osp.join(db_root, 'intrinsics', relpath_, num + '.float3')) cam2world = np.linalg.inv(readFloat(osp.join(db_root, 'poses', relpath + '.float3'))) depthmap_in = readFloat(osp.join(db_root, 'depths', relpath + '.float3')) img_clean = cv2.cvtColor(cv2.imread(osp.join(db_root, 'frames_cleanpass', relpath + '.png'), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) img_final = cv2.cvtColor(cv2.imread(osp.join(db_root, 'frames_finalpass', relpath + '.png'), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) # do the crop assert img_clean.shape[:2] == (540, 960) assert img_final.shape[:2] == (540, 960) (clean_out, final_out), depthmap, intrinsics_out, R_in2out = _crop_image( intrinsics_in, (img_clean, img_final), depthmap_in, (512, 384)) # write everything clean_out.save(osp.join(out_dir, relpath + '_clean.jpg'), quality=80) final_out.save(osp.join(out_dir, relpath + '_final.jpg'), quality=80) cv2.imwrite(osp.join(out_dir, relpath + '.exr'), depthmap) # New camera parameters cam2world[:3, :3] = cam2world[:3, :3] @ R_in2out.T np.savez(osp.join(out_dir, relpath + '.npz'), intrinsics=intrinsics_out, cam2world=cam2world) def _crop_image(intrinsics_in, color_image_in, depthmap_in, resolution_out=(512, 512)): image, depthmap, intrinsics_out = cropping.rescale_image_depthmap( color_image_in, depthmap_in, intrinsics_in, resolution_out) R_in2out = np.eye(3) return image, depthmap, intrinsics_out, R_in2out def _list_all_scenes(path): print('>> Listing all scenes') res = [] for split in ['TRAIN']: for subsplit in 'ABC': for seq in os.listdir(osp.join(path, 'intrinsics', split, subsplit)): res.append((split, subsplit, seq)) print(f' 
(found ({len(res)}) scenes)') assert res, f'Did not find anything at {path=}' return res def readFloat(name): with open(name, 'rb') as f: if (f.readline().decode("utf-8")) != 'float\n': raise Exception('float file %s did not contain keyword' % name) dim = int(f.readline()) dims = [] count = 1 for i in range(0, dim): d = int(f.readline()) dims.append(d) count *= d dims = list(reversed(dims)) data = np.fromfile(f, np.float32, count).reshape(dims) return data # Hxw or CxHxW NxCxHxW if __name__ == '__main__': parser = get_parser() args = parser.parse_args() main(args.StaticThings3D_dir, args.precomputed_pairs, args.output_dir) ================================================ FILE: datasets_preprocess/preprocess_waymo.py ================================================ #!/usr/bin/env python3 # Copyright (C) 2024-present Naver Corporation. All rights reserved. # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). # # -------------------------------------------------------- # Preprocessing code for the WayMo Open dataset # dataset at https://github.com/waymo-research/waymo-open-dataset # 1) Accept the license # 2) download all training/*.tfrecord files from Perception Dataset, version 1.4.2 # 3) put all .tfrecord files in '/path/to/waymo_dir' # 4) install the waymo_open_dataset package with # `python3 -m pip install gcsfs waymo-open-dataset-tf-2-12-0==1.6.4` # 5) execute this script as `python preprocess_waymo.py --waymo_dir /path/to/waymo_dir` # -------------------------------------------------------- import sys import os import os.path as osp import shutil import json from tqdm import tqdm import PIL.Image import numpy as np os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" import cv2 import tensorflow.compat.v1 as tf tf.enable_eager_execution() import path_to_root # noqa from dust3r.utils.geometry import geotrf, inv from dust3r.utils.image import imread_cv2 from dust3r.utils.parallel import parallel_processes as parallel_map from dust3r.datasets.utils import cropping 
from dust3r.viz import show_raw_pointcloud def get_parser(): import argparse parser = argparse.ArgumentParser() parser.add_argument('--waymo_dir', required=True) parser.add_argument('--precomputed_pairs', required=True) parser.add_argument('--output_dir', default='data/waymo_processed') parser.add_argument('--workers', type=int, default=1) return parser def main(waymo_root, pairs_path, output_dir, workers=1): extract_frames(waymo_root, output_dir, workers=workers) make_crops(output_dir, workers=args.workers) # make sure all pairs are there with np.load(pairs_path) as data: scenes = data['scenes'] frames = data['frames'] pairs = data['pairs'] # (array of (scene_id, img1_id, img2_id) for scene_id, im1_id, im2_id in pairs: for im_id in (im1_id, im2_id): path = osp.join(output_dir, scenes[scene_id], frames[im_id] + '.jpg') assert osp.isfile(path), f'Missing a file at {path=}\nDid you download all .tfrecord files?' shutil.rmtree(osp.join(output_dir, 'tmp')) print('Done! all data generated at', output_dir) def _list_sequences(db_root): print('>> Looking for sequences in', db_root) res = sorted(f for f in os.listdir(db_root) if f.endswith('.tfrecord')) print(f' found {len(res)} sequences') return res def extract_frames(db_root, output_dir, workers=8): sequences = _list_sequences(db_root) output_dir = osp.join(output_dir, 'tmp') print('>> outputing result to', output_dir) args = [(db_root, output_dir, seq) for seq in sequences] parallel_map(process_one_seq, args, star_args=True, workers=workers) def process_one_seq(db_root, output_dir, seq): out_dir = osp.join(output_dir, seq) os.makedirs(out_dir, exist_ok=True) calib_path = osp.join(out_dir, 'calib.json') if osp.isfile(calib_path): return try: with tf.device('/CPU:0'): calib, frames = extract_frames_one_seq(osp.join(db_root, seq)) except RuntimeError: print(f'/!\\ Error with sequence {seq} /!\\', file=sys.stderr) return # nothing is saved for f, (frame_name, views) in enumerate(tqdm(frames, leave=False)): for cam_idx, 
view in views.items(): img = PIL.Image.fromarray(view.pop('img')) img.save(osp.join(out_dir, f'{f:05d}_{cam_idx}.jpg')) np.savez(osp.join(out_dir, f'{f:05d}_{cam_idx}.npz'), **view) with open(calib_path, 'w') as f: json.dump(calib, f) def extract_frames_one_seq(filename): from waymo_open_dataset import dataset_pb2 as open_dataset from waymo_open_dataset.utils import frame_utils print('>> Opening', filename) dataset = tf.data.TFRecordDataset(filename, compression_type='') calib = None frames = [] for data in tqdm(dataset, leave=False): frame = open_dataset.Frame() frame.ParseFromString(bytearray(data.numpy())) content = frame_utils.parse_range_image_and_camera_projection(frame) range_images, camera_projections, _, range_image_top_pose = content views = {} frames.append((frame.context.name, views)) # once in a sequence, read camera calibration info if calib is None: calib = [] for cam in frame.context.camera_calibrations: calib.append((cam.name, dict(width=cam.width, height=cam.height, intrinsics=list(cam.intrinsic), extrinsics=list(cam.extrinsic.transform)))) # convert LIDAR to pointcloud points, cp_points = frame_utils.convert_range_image_to_point_cloud( frame, range_images, camera_projections, range_image_top_pose) # 3d points in vehicle frame. points_all = np.concatenate(points, axis=0) cp_points_all = np.concatenate(cp_points, axis=0) # The distance between lidar points and vehicle frame origin. 
cp_points_all_tensor = tf.constant(cp_points_all, dtype=tf.int32) for i, image in enumerate(frame.images): # select relevant 3D points for this view mask = tf.equal(cp_points_all_tensor[..., 0], image.name) cp_points_msk_tensor = tf.cast(tf.gather_nd(cp_points_all_tensor, tf.where(mask)), dtype=tf.float32) pose = np.asarray(image.pose.transform).reshape(4, 4) timestamp = image.pose_timestamp rgb = tf.image.decode_jpeg(image.image).numpy() pix = cp_points_msk_tensor[..., 1:3].numpy().round().astype(np.int16) pts3d = points_all[mask.numpy()] views[image.name] = dict(img=rgb, pose=pose, pixels=pix, pts3d=pts3d, timestamp=timestamp) if not 'show full point cloud': show_raw_pointcloud([v['pts3d'] for v in views.values()], [v['img'] for v in views.values()]) return calib, frames def make_crops(output_dir, workers=16, **kw): tmp_dir = osp.join(output_dir, 'tmp') sequences = _list_sequences(tmp_dir) args = [(tmp_dir, output_dir, seq) for seq in sequences] parallel_map(crop_one_seq, args, star_args=True, workers=workers, front_num=0) def crop_one_seq(input_dir, output_dir, seq, resolution=512): seq_dir = osp.join(input_dir, seq) out_dir = osp.join(output_dir, seq) if osp.isfile(osp.join(out_dir, '00100_1.jpg')): return os.makedirs(out_dir, exist_ok=True) # load calibration file try: with open(osp.join(seq_dir, 'calib.json')) as f: calib = json.load(f) except IOError: print(f'/!\\ Error: Missing calib.json in sequence {seq} /!\\', file=sys.stderr) return axes_transformation = np.array([ [0, -1, 0, 0], [0, 0, -1, 0], [1, 0, 0, 0], [0, 0, 0, 1]]) cam_K = {} cam_distortion = {} cam_res = {} cam_to_car = {} for cam_idx, cam_info in calib: cam_idx = str(cam_idx) cam_res[cam_idx] = (W, H) = (cam_info['width'], cam_info['height']) f1, f2, cx, cy, k1, k2, p1, p2, k3 = cam_info['intrinsics'] cam_K[cam_idx] = np.asarray([(f1, 0, cx), (0, f2, cy), (0, 0, 1)]) cam_distortion[cam_idx] = np.asarray([k1, k2, p1, p2, k3]) cam_to_car[cam_idx] = np.asarray(cam_info['extrinsics']).reshape(4, 
4) # cam-to-vehicle frames = sorted(f[:-3] for f in os.listdir(seq_dir) if f.endswith('.jpg')) # from dust3r.viz import SceneViz # viz = SceneViz() for frame in tqdm(frames, leave=False): cam_idx = frame[-2] # cam index assert cam_idx in '12345', f'bad {cam_idx=} in {frame=}' data = np.load(osp.join(seq_dir, frame + 'npz')) car_to_world = data['pose'] W, H = cam_res[cam_idx] # load depthmap pos2d = data['pixels'].round().astype(np.uint16) x, y = pos2d.T pts3d = data['pts3d'] # already in the car frame pts3d = geotrf(axes_transformation @ inv(cam_to_car[cam_idx]), pts3d) # X=LEFT_RIGHT y=ALTITUDE z=DEPTH # load image image = imread_cv2(osp.join(seq_dir, frame + 'jpg')) # downscale image output_resolution = (resolution, 1) if W > H else (1, resolution) image, _, intrinsics2 = cropping.rescale_image_depthmap(image, None, cam_K[cam_idx], output_resolution) image.save(osp.join(out_dir, frame + 'jpg'), quality=80) # save as an EXR file? yes it's smaller (and easier to load) W, H = image.size depthmap = np.zeros((H, W), dtype=np.float32) pos2d = geotrf(intrinsics2 @ inv(cam_K[cam_idx]), pos2d).round().astype(np.int16) x, y = pos2d.T depthmap[y.clip(min=0, max=H - 1), x.clip(min=0, max=W - 1)] = pts3d[:, 2] cv2.imwrite(osp.join(out_dir, frame + 'exr'), depthmap) # save camera parametes cam2world = car_to_world @ cam_to_car[cam_idx] @ inv(axes_transformation) np.savez(osp.join(out_dir, frame + 'npz'), intrinsics=intrinsics2, cam2world=cam2world, distortion=cam_distortion[cam_idx]) # viz.add_rgbd(np.asarray(image), depthmap, intrinsics2, cam2world) # viz.show() if __name__ == '__main__': parser = get_parser() args = parser.parse_args() main(args.waymo_dir, args.precomputed_pairs, args.output_dir, workers=args.workers) ================================================ FILE: datasets_preprocess/preprocess_wildrgbd.py ================================================ #!/usr/bin/env python3 # Copyright (C) 2024-present Naver Corporation. All rights reserved. 
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Script to pre-process the WildRGB-D dataset.
# Usage:
# python3 datasets_preprocess/preprocess_wildrgbd.py --wildrgbd_dir /path/to/wildrgbd
# --------------------------------------------------------

import argparse
import random
import json
import os
import os.path as osp

import PIL.Image
import numpy as np
import cv2

from tqdm.auto import tqdm
import matplotlib.pyplot as plt

import path_to_root  # noqa
import dust3r.datasets.utils.cropping as cropping  # noqa
from dust3r.utils.image import imread_cv2


def get_parser():
    """Build the command-line parser for the WildRGB-D preprocessing script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, default="data/wildrgbd_processed")
    parser.add_argument("--wildrgbd_dir", type=str, required=True)
    parser.add_argument("--train_num_sequences_per_object", type=int, default=50)
    parser.add_argument("--test_num_sequences_per_object", type=int, default=10)
    parser.add_argument("--num_frames", type=int, default=100)
    parser.add_argument("--seed", type=int, default=42)

    parser.add_argument("--img_size", type=int, default=512,
                        help=("lower dimension will be >= img_size * 3/4, and max dimension will be >= img_size"))
    return parser


def get_set_list(category_dir, split):
    """Return the set of sequence names of one category for the requested split.

    Reads ``camera_eval_list.json`` and ``nvs_list.json`` in *category_dir*.
    For ``split == "train"`` the result is the sequences listed as ``train`` in
    both files; for any other value it is every listed sequence (``train`` or
    ``val``, either file) that is not in that train intersection.
    """
    listfiles = ["camera_eval_list.json", "nvs_list.json"]

    sequences_all = {s: {k: set() for k in listfiles} for s in ['train', 'val']}
    for listfile in listfiles:
        with open(osp.join(category_dir, listfile)) as f:
            subset_lists_data = json.load(f)
        for s in ['train', 'val']:
            sequences_all[s][listfile].update(subset_lists_data[s])

    train_intersection = set.intersection(*sequences_all['train'].values())
    if split == "train":
        return train_intersection

    all_seqs = set.union(*sequences_all['train'].values(), *sequences_all['val'].values())
    return all_seqs.difference(train_intersection)


def prepare_sequences(category, wildrgbd_dir, output_dir, img_size, split, max_num_sequences_per_object,
                      output_num_frames, seed):
    """Select sequences of one category, then crop, rescale and save a subset of their frames.

    Args:
        category: category folder name inside *wildrgbd_dir*.
        wildrgbd_dir: dataset root directory.
        output_dir: output root; results go to ``<output_dir>/<category>/<sequence>``.
        img_size: target size used to compute the output resolution.
        split: split name forwarded to ``get_set_list``.
        max_num_sequences_per_object: at most this many sequences are kept (random sample).
        output_num_frames: number of evenly spaced frames kept per sequence.
        seed: seed given to ``random.seed`` before sampling.

    Returns:
        dict mapping each selected sequence name to the list of selected frame indices.

    Raises:
        AssertionError: if a sequence has fewer than *output_num_frames* poses, or if
            the frame indices of ``cam_poses.txt`` are not ``0..num_frames-1``.
    """
    random.seed(seed)
    category_dir = osp.join(wildrgbd_dir, category)
    category_output_dir = osp.join(output_dir, category)
    sequences_all = get_set_list(category_dir, split)
    sequences_all = sorted(sequences_all)

    sequences_all_tmp = []
    for seq_name in sequences_all:
        # category_dir already contains wildrgbd_dir: joining it again doubled the
        # prefix for relative dataset paths, so every sequence looked missing
        scene_dir = osp.join(category_dir, seq_name)
        if not os.path.isdir(scene_dir):
            print(f'{scene_dir} does not exist, skipped')
            continue
        sequences_all_tmp.append(seq_name)
    sequences_all = sequences_all_tmp
    if len(sequences_all) <= max_num_sequences_per_object:
        selected_sequences = sequences_all
    else:
        selected_sequences = random.sample(sequences_all, max_num_sequences_per_object)

    selected_sequences_numbers_dict = {}
    for seq_name in tqdm(selected_sequences, leave=False):
        scene_dir = osp.join(category_dir, seq_name)
        scene_output_dir = osp.join(category_output_dir, seq_name)
        with open(osp.join(scene_dir, 'metadata'), 'r') as f:
            metadata = json.load(f)

        K = np.array(metadata["K"]).reshape(3, 3).T
        fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
        w, h = metadata["w"], metadata["h"]

        camera_intrinsics = np.array(
            [[fx, 0, cx],
             [0, fy, cy],
             [0, 0, 1]]
        )
        camera_to_world_path = os.path.join(scene_dir, 'cam_poses.txt')
        camera_to_world_content = np.genfromtxt(camera_to_world_path)
        # each row: frame index followed by the 16 values of a 4x4 matrix
        camera_to_world = camera_to_world_content[:, 1:].reshape(-1, 4, 4)

        frame_idx = camera_to_world_content[:, 0]
        num_frames = frame_idx.shape[0]
        assert num_frames >= output_num_frames
        assert np.all(frame_idx == np.arange(num_frames))

        # selected_sequences_numbers_dict[seq_name] = num_frames

        selected_frames = np.round(np.linspace(0, num_frames - 1, output_num_frames)).astype(int).tolist()
        selected_sequences_numbers_dict[seq_name] = selected_frames

        # create the output folders once per sequence instead of once per frame
        for subdir in ('rgb', 'depth', 'masks', 'metadata'):
            os.makedirs(os.path.join(scene_output_dir, subdir), exist_ok=True)

        for frame_id in tqdm(selected_frames):
            depth_path = os.path.join(scene_dir, 'depth', f'{frame_id:0>5d}.png')
            masks_path = os.path.join(scene_dir, 'masks', f'{frame_id:0>5d}.png')
            rgb_path = os.path.join(scene_dir, 'rgb', f'{frame_id:0>5d}.png')

            input_rgb_image = PIL.Image.open(rgb_path).convert('RGB')
            input_mask = plt.imread(masks_path)
            input_depthmap = imread_cv2(depth_path, cv2.IMREAD_UNCHANGED).astype(np.float64)
            # depth and mask are stacked so that they go through the same crop/rescale
            depth_mask = np.stack((input_depthmap, input_mask), axis=-1)
            H, W = input_depthmap.shape

            min_margin_x = min(cx, W - cx)
            min_margin_y = min(cy, H - cy)

            # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy)
            l, t = int(cx - min_margin_x), int(cy - min_margin_y)
            r, b = int(cx + min_margin_x), int(cy + min_margin_y)
            crop_bbox = (l, t, r, b)
            input_rgb_image, depth_mask, input_camera_intrinsics = cropping.crop_image_depthmap(
                input_rgb_image, depth_mask, camera_intrinsics, crop_bbox)

            # try to set the lower dimension to img_size * 3/4 -> img_size=512 => 384
            # NOTE(review): H and W are the pre-crop dimensions here - confirm this is intended
            scale_final = ((img_size * 3 // 4) / min(H, W)) + 1e-8
            output_resolution = np.floor(np.array([W, H]) * scale_final).astype(int)
            if max(output_resolution) < img_size:
                # let's put the max dimension to img_size
                scale_final = (img_size / max(H, W)) + 1e-8
                output_resolution = np.floor(np.array([W, H]) * scale_final).astype(int)

            input_rgb_image, depth_mask, input_camera_intrinsics = cropping.rescale_image_depthmap(
                input_rgb_image, depth_mask, input_camera_intrinsics, output_resolution)
            input_depthmap = depth_mask[:, :, 0]
            input_mask = depth_mask[:, :, 1]

            camera_pose = camera_to_world[frame_id]

            # save crop images and depth, metadata
            save_img_path = os.path.join(scene_output_dir, 'rgb', f'{frame_id:0>5d}.jpg')
            save_depth_path = os.path.join(scene_output_dir, 'depth', f'{frame_id:0>5d}.png')
            save_mask_path = os.path.join(scene_output_dir, 'masks', f'{frame_id:0>5d}.png')

            input_rgb_image.save(save_img_path)
            cv2.imwrite(save_depth_path, input_depthmap.astype(np.uint16))
            cv2.imwrite(save_mask_path, (input_mask * 255).astype(np.uint8))

            save_meta_path = os.path.join(scene_output_dir, 'metadata', f'{frame_id:0>5d}.npz')
            np.savez(save_meta_path, camera_intrinsics=input_camera_intrinsics,
                     camera_pose=camera_pose)

    return selected_sequences_numbers_dict


if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()
    assert args.wildrgbd_dir != args.output_dir

    # a category is any sub-folder of the dataset root that contains a 'scenes' folder
    categories = sorted([
        dirname for dirname in os.listdir(args.wildrgbd_dir)
        if os.path.isdir(os.path.join(args.wildrgbd_dir, dirname, 'scenes'))
    ])

    os.makedirs(args.output_dir, exist_ok=True)

    splits_num_sequences_per_object = [args.train_num_sequences_per_object, args.test_num_sequences_per_object]
    for split, num_sequences_per_object in zip(['train', 'test'], splits_num_sequences_per_object):
        selected_sequences_path = os.path.join(args.output_dir, f'selected_seqs_{split}.json')
        if os.path.isfile(selected_sequences_path):
            # this split was already fully processed
            continue
        all_selected_sequences = {}
        for category in categories:
            category_output_dir = osp.join(args.output_dir, category)
            os.makedirs(category_output_dir, exist_ok=True)
            category_selected_sequences_path = os.path.join(category_output_dir, f'selected_seqs_{split}.json')
            if os.path.isfile(category_selected_sequences_path):
                # resume: reuse the selection saved by a previous run
                with open(category_selected_sequences_path, 'r') as fid:
                    category_selected_sequences = json.load(fid)
            else:
                print(f"Processing {split} - category = {category}")
                category_selected_sequences = prepare_sequences(
                    category=category,
                    wildrgbd_dir=args.wildrgbd_dir,
                    output_dir=args.output_dir,
                    img_size=args.img_size,
                    split=split,
                    max_num_sequences_per_object=num_sequences_per_object,
                    output_num_frames=args.num_frames,
                    # NOTE(review): the literal string "category" is hashed, not the
                    # `category` variable, so the offset is the same for every category.
                    # Kept as is so that already published selections stay reproducible.
                    seed=args.seed + int("category".encode('ascii').hex(), 16),
                )
                with open(category_selected_sequences_path, 'w') as file:
                    json.dump(category_selected_sequences, file)

            all_selected_sequences[category] = category_selected_sequences
        with open(selected_sequences_path, 'w') as file:
            json.dump(all_selected_sequences, file)
================================================ FILE: demo.py ================================================ #!/usr/bin/env python3 # Copyright (C) 2024-present Naver Corporation. All rights reserved. # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). # # -------------------------------------------------------- # dust3r gradio demo executable # -------------------------------------------------------- import os import torch import tempfile from dust3r.model import AsymmetricCroCo3DStereo from dust3r.demo import get_args_parser, main_demo, set_print_with_timestamp import matplotlib.pyplot as pl pl.ion() torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12 if __name__ == '__main__': parser = get_args_parser() args = parser.parse_args() set_print_with_timestamp() if args.tmp_dir is not None: tmp_path = args.tmp_dir os.makedirs(tmp_path, exist_ok=True) tempfile.tempdir = tmp_path if args.server_name is not None: server_name = args.server_name else: server_name = '0.0.0.0' if args.local_network else '127.0.0.1' if args.weights is not None: weights_path = args.weights else: weights_path = "naver/" + args.model_name model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(args.device) # dust3r will write the 3D model inside tmpdirname with tempfile.TemporaryDirectory(suffix='dust3r_gradio_demo') as tmpdirname: if not args.silent: print('Outputing stuff in', tmpdirname) main_demo(tmpdirname, model, args.device, args.image_size, server_name, args.server_port, silent=args.silent) ================================================ FILE: docker/docker-compose-cpu.yml ================================================ version: '3.8' services: dust3r-demo: build: context: ./files dockerfile: cpu.Dockerfile ports: - "7860:7860" volumes: - ./files/checkpoints:/dust3r/checkpoints environment: - DEVICE=cpu - MODEL=${MODEL:-DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth} cap_add: - IPC_LOCK - SYS_RESOURCE 
================================================ FILE: docker/docker-compose-cuda.yml ================================================ version: '3.8' services: dust3r-demo: build: context: ./files dockerfile: cuda.Dockerfile ports: - "7860:7860" environment: - DEVICE=cuda - MODEL=${MODEL:-DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth} volumes: - ./files/checkpoints:/dust3r/checkpoints cap_add: - IPC_LOCK - SYS_RESOURCE deploy: resources: reservations: devices: - driver: nvidia count: 1 capabilities: [gpu] ================================================ FILE: docker/files/cpu.Dockerfile ================================================ FROM python:3.11-slim LABEL description="Docker container for DUSt3R with dependencies installed. CPU VERSION" ENV DEVICE="cpu" ENV MODEL="DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth" ARG DEBIAN_FRONTEND=noninteractive RUN apt-get update && apt-get install -y \ git \ libgl1-mesa-glx \ libegl1-mesa \ libxrandr2 \ libxrandr2 \ libxss1 \ libxcursor1 \ libxcomposite1 \ libasound2 \ libxi6 \ libxtst6 \ libglib2.0-0 \ && apt-get clean \ && rm -rf /var/lib/apt/lists/* RUN git clone --recursive https://github.com/naver/dust3r /dust3r WORKDIR /dust3r RUN pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu RUN pip install -r requirements.txt RUN pip install -r requirements_optional.txt RUN pip install opencv-python==4.8.0.74 WORKDIR /dust3r COPY entrypoint.sh /entrypoint.sh RUN chmod +x /entrypoint.sh ENTRYPOINT ["/entrypoint.sh"] ================================================ FILE: docker/files/cuda.Dockerfile ================================================ FROM nvcr.io/nvidia/pytorch:24.01-py3 LABEL description="Docker container for DUSt3R with dependencies installed. 
CUDA VERSION" ENV DEVICE="cuda" ENV MODEL="DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth" ARG DEBIAN_FRONTEND=noninteractive RUN apt-get update && apt-get install -y \ git=1:2.34.1-1ubuntu1.10 \ libglib2.0-0=2.72.4-0ubuntu2.2 \ && apt-get clean \ && rm -rf /var/lib/apt/lists/* RUN git clone --recursive https://github.com/naver/dust3r /dust3r WORKDIR /dust3r RUN pip install -r requirements.txt RUN pip install -r requirements_optional.txt RUN pip install opencv-python==4.8.0.74 WORKDIR /dust3r/croco/models/curope/ RUN python setup.py build_ext --inplace WORKDIR /dust3r COPY entrypoint.sh /entrypoint.sh RUN chmod +x /entrypoint.sh ENTRYPOINT ["/entrypoint.sh"] ================================================ FILE: docker/files/entrypoint.sh ================================================ #!/bin/bash set -eux DEVICE=${DEVICE:-cuda} MODEL=${MODEL:-DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth} exec python3 demo.py --weights "checkpoints/$MODEL" --device "$DEVICE" --local_network "$@" ================================================ FILE: docker/run.sh ================================================ #!/bin/bash set -eux # Default model name model_name="DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth" check_docker() { if ! command -v docker &>/dev/null; then echo "Docker could not be found. Please install Docker and try again." exit 1 fi } download_model_checkpoint() { if [ -f "./files/checkpoints/${model_name}" ]; then echo "Model checkpoint ${model_name} already exists. Skipping download." return fi echo "Downloading model checkpoint ${model_name}..." wget "https://download.europe.naverlabs.com/ComputerVision/DUSt3R/${model_name}" -P ./files/checkpoints } set_dcomp() { if command -v docker-compose &>/dev/null; then dcomp="docker-compose" elif command -v docker &>/dev/null && docker compose version &>/dev/null; then dcomp="docker compose" else echo "Docker Compose could not be found. Please install Docker Compose and try again." 
def global_aligner(dust3r_output, device, mode=GlobalAlignerMode.PointCloudOptimizer, **optim_kw):
    """Build the global-alignment optimizer for a set of pairwise predictions.

    Args:
        dust3r_output: mapping providing the keys 'view1', 'view2', 'pred1' and 'pred2'.
        device: device the optimizer is moved to with `.to(device)`.
        mode: a `GlobalAlignerMode` member, or its string value
            (e.g. "PointCloudOptimizer").
        **optim_kw: forwarded unchanged to the optimizer constructor.

    Returns:
        The constructed optimizer, already moved to `device`.

    Raises:
        NotImplementedError: if `mode` does not designate a known optimizer.
    """
    # extract all inputs
    view1, view2, pred1, pred2 = (dust3r_output[k] for k in ('view1', 'view2', 'pred1', 'pred2'))

    # accept both the enum member and its plain string value
    try:
        selected = GlobalAlignerMode(mode)
    except ValueError:
        raise NotImplementedError(f'Unknown mode {mode}') from None

    # build the optimizer
    optimizers = {
        GlobalAlignerMode.PointCloudOptimizer: PointCloudOptimizer,
        GlobalAlignerMode.ModularPointCloudOptimizer: ModularPointCloudOptimizer,
        GlobalAlignerMode.PairViewer: PairViewer,
    }
    net = optimizers[selected](view1, view2, pred1, pred2, **optim_kw).to(device)
    return net
def __init__(self, *args, **kwargs):
    """Build the optimizer from pairwise predictions, or as a copy of another optimizer.

    When called with exactly one positional argument and no keyword argument,
    that argument is taken to be an existing optimizer: it is deep-copied and
    the state set up by `_init_from_views` is transferred to this instance.
    Any other call is forwarded unchanged to `_init_from_views`.
    """
    if len(args) == 1 and len(kwargs) == 0:
        other = deepcopy(args[0])
        # Attributes assigned by _init_from_views. `conf_trf` (not `conf_thr`)
        # is the name actually set there, and `pw_break` is read by
        # get_adaptors(), so both must be part of the copy.
        attrs = '''edges is_symmetrized dist n_imgs pred_i pred_j imshapes
                    min_conf_thr conf_trf conf_i conf_j im_conf
                    base_scale norm_pw_scale pw_break POSE_DIM pw_poses
                    pw_adaptors has_im_poses rand_pose imgs verbose'''.split()
        # nn.Module must be initialised before parameters or sub-modules can
        # be assigned; setattr (rather than writing into __dict__) lets the
        # module register them, and getattr is required because a module is
        # not subscriptable.
        super().__init__()
        for k in attrs:
            setattr(self, k, getattr(other, k))
    else:
        self._init_from_views(*args, **kwargs)
def get_adaptors(self):
    """Return the per-edge (scale_xy, scale_xy, scale_z) adaptation factors.

    The two stored columns of `pw_adaptors` are expanded to three by
    duplicating the first one. When `norm_pw_scale` is set, each row is
    centred so that the product of its three factors equals 1. The result is
    the exponential of the values divided by `pw_break`.
    """
    xy_column = self.pw_adaptors[:, 0:1]
    log_scales = torch.cat((xy_column, self.pw_adaptors), dim=-1)
    if self.norm_pw_scale:
        # zero-mean in log space <=> unit product after exp
        log_scales = log_scales - log_scales.mean(dim=1, keepdim=True)
    return torch.exp(log_scales / self.pw_break)
def _set_pose(self, poses, idx, R, T=None, scale=None, force=False):
    """Write a cam-to-world pose into `poses[idx]`, in place.

    Args:
        poses: indexable container of pose parameters.
        idx: index of the pose to modify.
        R: a 3x3 rotation matrix, a full 4x4 transform (in which case `T`
            must be None and is taken from the last column), or None to leave
            the rotation untouched.
        T: translation, or None to leave it untouched.
        scale: optional scale. When given, the translation is divided by it
            and log(scale) is stored in the last slot of the pose.
        force: modify the pose even if it does not require grad.

    Returns:
        `poses[idx]`, unchanged when it does not require grad and `force` is
        False.
    """
    # all poses == cam-to-world
    pose = poses[idx]
    if not (pose.requires_grad or force):
        return pose

    # R may legitimately be None (see the check below), so only look at its
    # shape when it is actually provided.
    if R is not None and R.shape == (4, 4):
        assert T is None
        T = R[:3, 3]
        R = R[:3, :3]

    if R is not None:
        pose.data[0:4] = roma.rotmat_to_unitquat(R)
    if T is not None:
        pose.data[4:7] = signed_log1p(T / (scale or 1))  # translation is function of scale

    if scale is not None:
        assert poses.shape[-1] in (8, 13)
        pose.data[-1] = np.log(float(scale))
    return pose
@torch.cuda.amp.autocast(enabled=False)
def compute_global_alignment(self, init=None, niter_PnP=10, **kw):
    """Optionally initialise the scene, then run the global alignment loop.

    Args:
        init: None (no initialisation), 'msp' or 'mst' (minimum spanning
            tree initialisation), or 'known_poses'.
        niter_PnP: forwarded to the selected initialisation function.
        **kw: forwarded to `global_alignment_loop`.

    Returns:
        Whatever `global_alignment_loop` returns.

    Raises:
        ValueError: for any other value of `init`.
    """
    if init in ('msp', 'mst'):
        init_fun.init_minimum_spanning_tree(self, niter_PnP=niter_PnP)
    elif init == 'known_poses':
        init_fun.init_from_known_poses(self, min_conf_thr=self.min_conf_thr,
                                       niter_PnP=niter_PnP)
    elif init is not None:
        raise ValueError(f'bad value for {init=}')

    return global_alignment_loop(self, **kw)
def global_alignment_loop(net, lr=0.01, niter=300, schedule='cosine', lr_min=1e-6):
    """Optimise every trainable parameter of `net` with Adam for `niter` steps.

    Args:
        net: module whose forward pass returns the loss; must expose `verbose`.
        lr: base learning rate handed to the schedule.
        niter: number of optimisation steps.
        schedule: learning-rate schedule name, forwarded to `global_alignment_iter`.
        lr_min: final learning rate of the schedule.

    Returns:
        The loss of the last step as a float (inf when `niter` is 0).
        NOTE: when `net` has no trainable parameter, `net` itself is returned
        instead of a loss.
    """
    trainable = [p for p in net.parameters() if p.requires_grad]
    if len(trainable) == 0:
        return net

    show_progress = net.verbose
    if show_progress:
        print('Global alignement - optimizing for:')
        print([name for name, value in net.named_parameters() if value.requires_grad])

    base_lr = lr
    optimizer = torch.optim.Adam(trainable, lr=lr, betas=(0.9, 0.9))

    loss = float('inf')
    if not show_progress:
        for step in range(niter):
            loss, _ = global_alignment_iter(net, step, niter, base_lr, lr_min, optimizer, schedule)
        return loss

    # verbose mode: the progress bar doubles as the iteration counter
    with tqdm.tqdm(total=niter) as bar:
        while bar.n < bar.total:
            loss, lr = global_alignment_iter(net, bar.n, niter, base_lr, lr_min, optimizer, schedule)
            bar.set_postfix_str(f'{lr=:g} loss={loss:g}')
            bar.update()
    return loss
@torch.no_grad()
def clean_pointcloud(im_confs, K, cams, depthmaps, all_pts3d,
                     tol=0.001, bad_conf=0, dbg=()):
    """ Lower the confidence of 3D points that contradict another view's depthmap.

    Method:
    1) express all 3d points in each camera coordinate frame
    2) if they're in front of a depthmap --> then lower their confidence

    Args:
        im_confs: per-image confidence maps; their shapes drive the reshaping
            of `all_pts3d` and `depthmaps`.
        K: per-image matrices applied to the projected points to obtain (u, v)
            (presumably camera intrinsics -- confirm against caller).
        cams: per-image transforms applied to the 3D points of the other images
            (BasePCOptimizer.clean_pointcloud passes inv(get_im_poses())).
        depthmaps: per-image depthmaps.
        all_pts3d: per-image 3D points.
        tol: relative tolerance, in [0, 1), before a point counts as being in
            front of the other view's depth.
        bad_conf: upper bound applied to the confidence of offending points.
        dbg: unused in this function.

    Returns:
        A list of new confidence maps (clones of `im_confs`, updated).
    """
    assert len(im_confs) == len(cams) == len(K) == len(depthmaps) == len(all_pts3d)
    assert 0 <= tol < 1
    # work on copies: the input confidences are left untouched
    res = [c.clone() for c in im_confs]

    # reshape appropriately
    all_pts3d = [p.view(*c.shape, 3) for p, c in zip(all_pts3d, im_confs)]
    depthmaps = [d.view(*c.shape) for d, c in zip(depthmaps, im_confs)]

    for i, pts3d in enumerate(all_pts3d):
        for j in range(len(all_pts3d)):
            if i == j:
                continue

            # project 3dpts in other view
            proj = geotrf(cams[j], pts3d)
            proj_depth = proj[:, :, 2]
            u, v = geotrf(K[j], proj, norm=1, ncol=2).round().long().unbind(-1)

            # check which points are actually in the visible cone
            H, W = im_confs[j].shape
            msk_i = (proj_depth > 0) & (0 <= u) & (u < W) & (0 <= v) & (v < H)
            # (row, col) indices into image j for the points of i that land inside it
            msk_j = v[msk_i], u[msk_i]

            # find bad points = those in front but less confident
            # NOTE: compares against `res`, i.e. confidences already updated
            # by earlier (i, j) pairs, so the iteration order matters.
            bad_points = (proj_depth[msk_i] < (1-tol) * depthmaps[j][msk_j]) & (res[i][msk_i] < res[j][msk_j])

            # scatter the per-visible-point verdict back to a full-image mask
            bad_msk_i = msk_i.clone()
            bad_msk_i[msk_i] = bad_points
            res[i][bad_msk_i] = res[i][bad_msk_i].clip_(max=bad_conf)

    return res
def edge_str(i, j):
    """Return the string key identifying the edge (i, j)."""
    return '{}_{}'.format(i, j)


def i_j_ij(ij):
    """Return (edge key, edge) for an (i, j) pair."""
    i, j = ij
    return edge_str(i, j), ij


def edge_conf(conf_i, conf_j, edge):
    """Score an edge as the product of the mean confidences of its two maps."""
    mean_i = conf_i[edge].mean()
    mean_j = conf_j[edge].mean()
    return float(mean_i * mean_j)


def compute_edge_scores(edges, conf_i, conf_j):
    """Map every (i, j) pair to its confidence score.

    `edges` yields (edge key, (i, j)) items, as produced by `i_j_ij`.
    """
    scores = {}
    for key, (i, j) in edges:
        scores[(i, j)] = edge_conf(conf_i, conf_j, key)
    return scores


def NoGradParamDict(x):
    """Wrap a dict into an nn.ParameterDict whose entries do not require grad."""
    assert isinstance(x, dict)
    params = nn.ParameterDict(x)
    return params.requires_grad_(False)


def get_imshapes(edges, pred_i, pred_j):
    """Collect the (H, W) shape of every image from the pairwise predictions.

    The first two dimensions of `pred_i[e]` / `pred_j[e]` give the shape of
    the first / second image of edge number `e`; all edges seeing the same
    image must agree.
    """
    imshapes = [None] * (max(max(pair) for pair in edges) + 1)
    for e, (i, j) in enumerate(edges):
        observed = ((i, tuple(pred_i[e].shape[0:2])),
                    (j, tuple(pred_j[e].shape[0:2])))
        # check both images against what is already known, then record both
        for img, shape in observed:
            if imshapes[img]:
                assert imshapes[img] == shape, f'incorrect shape for image {img}'
        for img, shape in observed:
            imshapes[img] = shape
    return imshapes


def get_conf_trf(mode):
    """Return the function applied to raw confidences for the given mode.

    Supported modes: 'log', 'sqrt', 'm1' (subtract 1), 'id' / 'none' (identity).
    """
    if mode == 'log':
        return lambda x: x.log()
    if mode == 'sqrt':
        return lambda x: x.sqrt()
    if mode == 'm1':
        return lambda x: x-1
    if mode in ('id', 'none'):
        return lambda x: x
    raise ValueError(f'bad mode for {mode=}')


def l2_dist(a, b, weight):
    """Weighted squared Euclidean distance along the last dimension."""
    delta = a - b
    return delta.square().sum(dim=-1) * weight


def l1_dist(a, b, weight):
    """Weighted (non-squared) norm of the difference along the last dimension."""
    delta = a - b
    return delta.norm(dim=-1) * weight


ALL_DISTS = {'l1': l1_dist, 'l2': l2_dist}


def signed_log1p(x):
    """Sign-preserving log1p: sign(x) * log(1 + |x|)."""
    return torch.sign(x) * torch.log1p(x.abs())


def signed_expm1(x):
    """Sign-preserving expm1: sign(x) * (exp(|x|) - 1)."""
    return torch.sign(x) * torch.expm1(x.abs())


def cosine_schedule(t, lr_start, lr_end):
    """Cosine interpolation from lr_start (t=0) to lr_end (t=1)."""
    assert 0 <= t <= 1
    amplitude = lr_start - lr_end
    phase = 1+np.cos(t * np.pi)
    return lr_end + amplitude * phase/2


def linear_schedule(t, lr_start, lr_end):
    """Linear interpolation from lr_start (t=0) to lr_end (t=1)."""
    assert 0 <= t <= 1
    span = lr_end - lr_start
    return lr_start + span * t
@torch.no_grad()
def init_from_known_poses(self, niter_PnP=10, min_conf_thr=3):
    """Initialise pairwise poses and depthmaps when every image pose and focal is known.

    For each edge, the pose of the second camera relative to the first is
    recovered with PnP, the resulting camera pair is aligned onto the two
    known poses, and the alignment is stored as the pairwise pose. Each image
    then receives the depthmap of its most confident edge, rescaled by that
    edge's scale.
    """
    device = self.device

    # all image poses must have been preset
    n_known_poses, known_poses_msk, known_poses = get_known_poses(self)
    assert n_known_poses == self.n_imgs, 'not all poses are known'

    # all focals must have been preset too
    n_known_focals, _, im_focals = get_known_focals(self)
    assert n_known_focals == self.n_imgs
    im_pp = self.get_principal_points()

    # image index -> (score, edge key, scale) of its best edge so far
    best_depthmaps = {}

    for e, (i, j) in enumerate(tqdm(self.edges, disable=not self.verbose)):
        i_j = edge_str(i, j)
        conf = self.conf_i[i_j]

        # relative pose of this pair: first camera at the origin, second by PnP
        P1 = torch.eye(4, device=device)
        # the threshold is capped just below the lowest confidence of the map
        thr = min(min_conf_thr, conf.min() - 0.1)
        msk = conf > thr
        focal_i = float(im_focals[i].mean())
        _, P2 = fast_pnp(self.pred_j[i_j], focal_i, pp=im_pp[i],
                         msk=msk, device=device, niter_PnP=niter_PnP)

        # align the two predicted cameras with the two known cameras
        pair_poses = torch.stack((P1, P2))
        s, R, T = align_multiple_poses(pair_poses, known_poses[[i, j]])
        self._set_pose(self.pw_poses, e, R, T, scale=s)

        # keep, for image i, the edge with the highest mean confidence
        score = float(conf.mean())
        best_score = best_depthmaps[i][0] if i in best_depthmaps else 0
        if score > best_score:
            best_depthmaps[i] = score, i_j, s

    # init all image depthmaps from their best edge
    for n in range(self.n_imgs):
        assert known_poses_msk[n]
        _, best_edge, scale = best_depthmaps[n]
        depth = self.pred_i[best_edge][:, :, 2]
        self._set_depthmap(n, depth * scale)
def minimum_spanning_tree(imshapes, edges, pred_i, pred_j, conf_i, conf_j, im_conf, min_conf_thr,
                          device, has_im_poses=True, niter_PnP=10, verbose=True):
    """Chain the pairwise predictions into one pointcloud along a spanning tree.

    Edges are scored with `compute_edge_scores`; the spanning tree of the
    negated scores is walked from its strongest edge, and every new image is
    attached by registering the pairwise pointmap of an already placed image
    onto its current 3D points.

    Args:
        imshapes: per-image shapes (only their count is used here).
        edges: list of (i, j) image pairs.
        pred_i, pred_j: pointmaps of the first / second image of each edge,
            keyed by `edge_str(i, j)`.
        conf_i, conf_j: matching confidence maps, keyed the same way.
        im_conf: per-image confidence maps, used to mask points for PnP.
        min_conf_thr: confidence threshold for that mask.
        device: device of the created pose matrices.
        has_im_poses: when False, no image pose or focal is estimated.
        niter_PnP: forwarded to `fast_pnp`.
        verbose: print the order in which edges are consumed.

    Returns:
        (pts3d, msp_edges, im_focals, im_poses); the last two are None when
        `has_im_poses` is False.

    Raises:
        ValueError: if some tree edges can never be attached to the initial
            edge, i.e. the pair graph is not connected.
    """
    n_imgs = len(imshapes)
    sparse_graph = -dict_to_sparse_graph(compute_edge_scores(map(i_j_ij, edges), conf_i, conf_j))
    msp = sp.csgraph.minimum_spanning_tree(sparse_graph).tocoo()

    # temp variable to store 3d points
    pts3d = [None] * n_imgs

    todo = sorted(zip(-msp.data, msp.row, msp.col))  # sorted edges
    im_poses = [None] * n_imgs
    im_focals = [None] * n_imgs

    # init with strongest edge
    score, i, j = todo.pop()
    if verbose:
        print(f' init edge ({i}*,{j}*) {score=}')
    i_j = edge_str(i, j)
    pts3d[i] = pred_i[i_j].clone()
    pts3d[j] = pred_j[i_j].clone()
    done = {i, j}
    if has_im_poses:
        im_poses[i] = torch.eye(4, device=device)
        im_focals[i] = estimate_focal(pred_i[i_j])

    # set initial pointcloud based on pairwise graph
    msp_edges = [(i, j)]
    n_deferred = 0  # consecutive edges pushed back without attaching anything
    while todo:
        # each time, predict the next one
        score, i, j = todo.pop()
        # The key must be refreshed *before* it is used: the focal of image i
        # has to come from the pointmap of the current edge, not from the one
        # handled in the previous iteration.
        i_j = edge_str(i, j)

        # focals are discarded when has_im_poses is False, so skip the work
        if has_im_poses and im_focals[i] is None:
            im_focals[i] = estimate_focal(pred_i[i_j])

        if i in done:
            if verbose:
                print(f' init edge ({i},{j}*) {score=}')
            assert j not in done
            # align pred[i] with pts3d[i], and then set j accordingly
            s, R, T = rigid_points_registration(pred_i[i_j], pts3d[i], conf=conf_i[i_j])
            trf = sRT_to_4x4(s, R, T, device)
            pts3d[j] = geotrf(trf, pred_j[i_j])
            done.add(j)
            msp_edges.append((i, j))

            if has_im_poses and im_poses[i] is None:
                im_poses[i] = sRT_to_4x4(1, R, T, device)

        elif j in done:
            if verbose:
                print(f' init edge ({i}*,{j}) {score=}')
            assert i not in done
            s, R, T = rigid_points_registration(pred_j[i_j], pts3d[j], conf=conf_j[i_j])
            trf = sRT_to_4x4(s, R, T, device)
            pts3d[i] = geotrf(trf, pred_i[i_j])
            done.add(i)
            msp_edges.append((i, j))

            if has_im_poses and im_poses[i] is None:
                im_poses[i] = sRT_to_4x4(1, R, T, device)
        else:
            # let's try again later
            todo.insert(0, (score, i, j))
            n_deferred += 1
            if n_deferred >= len(todo):
                # every remaining edge was deferred in a row: none of them
                # touches the placed images, so looping again cannot help
                raise ValueError('pair graph is not connected: '
                                 f'{len(todo)} edge(s) cannot be attached to the initial pointcloud')
            continue

        n_deferred = 0  # progress was made

    if has_im_poses:
        # complete all missing informations
        pair_scores = list(sparse_graph.values())  # already negative scores: less is best
        edges_from_best_to_worse = np.array(list(sparse_graph.keys()))[np.argsort(pair_scores)]
        for i, j in edges_from_best_to_worse.tolist():
            if im_focals[i] is None:
                im_focals[i] = estimate_focal(pred_i[edge_str(i, j)])

        for i in range(n_imgs):
            if im_poses[i] is None:
                msk = im_conf[i] > min_conf_thr
                res = fast_pnp(pts3d[i], im_focals[i], msk=msk, device=device, niter_PnP=niter_PnP)
                if res:
                    im_focals[i], im_poses[i] = res
            if im_poses[i] is None:
                # PnP failed: fall back to the identity pose
                im_poses[i] = torch.eye(4, device=device)
        im_poses = torch.stack(im_poses)
    else:
        im_poses = im_focals = None

    return pts3d, msp_edges, im_focals, im_poses
def get_known_poses(self):
    """Return (count, mask, poses) describing which image poses are fixed.

    A pose counts as known when its parameter does not require grad. Without
    image poses (`has_im_poses` false) the result is `(0, None, None)`.
    """
    if not self.has_im_poses:
        return 0, None, None

    is_frozen = [not pose.requires_grad for pose in self.im_poses]
    known_poses_msk = torch.tensor(is_frozen)
    known_poses = self.get_im_poses()
    return known_poses_msk.sum(), known_poses_msk, known_poses
class ModularPointCloudOptimizer (BasePCOptimizer):
    """ Optimize a global scene, given a list of pairwise observations.
    Unlike PointCloudOptimizer, you can fix parts of the optimization process (partial poses/intrinsics)
    Graph node: images
    Graph edges: observations = (pred1, pred2)
    """

    def __init__(self, *args, optimize_pp=False, fx_and_fy=False, focal_brake=20, **kwargs):
        """Create the per-image parameters on top of the pairwise ones.

        Args:
            optimize_pp: whether the principal points are optimized.
            fx_and_fy: store two focal values per image instead of one.
            focal_brake: factor applied to log(focal) in the stored parameter.
            *args, **kwargs: forwarded to BasePCOptimizer.
        """
        super().__init__(*args, **kwargs)
        self.has_im_poses = True  # by definition of this class
        self.focal_brake = focal_brake

        # adding thing to optimize
        self.im_depthmaps = nn.ParameterList(torch.randn(H, W)/10-3 for H, W in self.imshapes)  # log(depth)
        self.im_poses = nn.ParameterList(self.rand_pose(self.POSE_DIM) for _ in range(self.n_imgs))  # camera poses
        default_focals = [self.focal_brake * np.log(max(H, W)) for H, W in self.imshapes]
        self.im_focals = nn.ParameterList(torch.FloatTensor([f, f] if fx_and_fy else [
                                          f]) for f in default_focals)  # camera intrinsics
        self.im_pp = nn.ParameterList(torch.zeros((2,)) for _ in range(self.n_imgs))  # camera intrinsics
        self.im_pp.requires_grad_(optimize_pp)

    def preset_pose(self, known_poses, pose_msk=None):  # cam-to-world
        """Set and freeze the poses of the images selected by `pose_msk`."""
        if isinstance(known_poses, torch.Tensor) and known_poses.ndim == 2:
            known_poses = [known_poses]
        for idx, pose in zip(self._get_msk_indices(pose_msk), known_poses):
            if self.verbose:
                print(f' (setting pose #{idx} = {pose[:3,3]})')
            self._no_grad(self._set_pose(self.im_poses, idx, torch.tensor(pose), force=True))

        # normalize scale only if there is at most 1 known pose
        n_known_poses = sum((p.requires_grad is False) for p in self.im_poses)
        self.norm_pw_scale = (n_known_poses <= 1)

    def preset_intrinsics(self, known_intrinsics, msk=None):
        """Set and freeze focals and principal points from 3x3 matrices.

        The focal is the mean of the first two diagonal entries, the
        principal point is the first two entries of the last column.
        """
        if isinstance(known_intrinsics, torch.Tensor) and known_intrinsics.ndim == 2:
            known_intrinsics = [known_intrinsics]
        for K in known_intrinsics:
            assert K.shape == (3, 3)
        self.preset_focal([K.diagonal()[:2].mean() for K in known_intrinsics], msk)
        self.preset_principal_point([K[:2, 2] for K in known_intrinsics], msk)

    def preset_focal(self, known_focals, msk=None):
        """Set and freeze the focals of the images selected by `msk`."""
        for idx, focal in zip(self._get_msk_indices(msk), known_focals):
            if self.verbose:
                print(f' (setting focal #{idx} = {focal})')
            self._no_grad(self._set_focal(idx, focal, force=True))

    def preset_principal_point(self, known_pp, msk=None):
        """Set and freeze the principal points of the images selected by `msk`."""
        for idx, pp in zip(self._get_msk_indices(msk), known_pp):
            if self.verbose:
                print(f' (setting principal point #{idx} = {pp})')
            self._no_grad(self._set_principal_point(idx, pp, force=True))

    def _no_grad(self, tensor):
        """Exclude `tensor` from the optimization and return it."""
        return tensor.requires_grad_(False)

    def _get_msk_indices(self, msk):
        """Convert an image selection into an iterable of image indices.

        Accepts None (all images), a single int, a tuple/list, a boolean mask
        of length n_imgs, or an integer array (returned as is).
        """
        if msk is None:
            return range(self.n_imgs)
        elif isinstance(msk, int):
            return [msk]
        elif isinstance(msk, (tuple, list)):
            return self._get_msk_indices(np.array(msk))
        elif msk.dtype in (bool, torch.bool, np.bool_):
            assert len(msk) == self.n_imgs
            return np.where(msk)[0]
        elif np.issubdtype(msk.dtype, np.integer):
            return msk
        else:
            raise ValueError(f'bad {msk=}')

    def _set_focal(self, idx, focal, force=False):
        """Store focal_brake * log(focal) in the focal parameter of image `idx`."""
        param = self.im_focals[idx]
        if param.requires_grad or force:  # can only init a parameter not already initialized
            param.data[:] = self.focal_brake * np.log(focal)
        return param

    def get_focals(self):
        """Return the focals, stacked over images (inverse of `_set_focal`)."""
        log_focals = torch.stack(list(self.im_focals), dim=0)
        return (log_focals / self.focal_brake).exp()

    def get_known_focal_mask(self):
        """Return a boolean tensor flagging the images whose focal is frozen.

        `get_known_focals` (used by the 'known_poses' initialisation) calls
        this method; without it the BasePCOptimizer stub raises
        NotImplementedError. A focal is known when its parameter no longer
        requires grad, which is what `preset_focal` arranges.
        """
        return torch.tensor([not (p.requires_grad) for p in self.im_focals])

    def _set_principal_point(self, idx, pp, force=False):
        """Store the principal point of image `idx` as an offset to the image center, divided by 10."""
        param = self.im_pp[idx]
        H, W = self.imshapes[idx]
        if param.requires_grad or force:  # can only init a parameter not already initialized
            param.data[:] = to_cpu(to_numpy(pp) - (W/2, H/2)) / 10
        return param

    def get_principal_points(self):
        """Return the principal points, stacked over images (inverse of `_set_principal_point`)."""
        return torch.stack([pp.new((W/2, H/2))+10*pp for pp, (H, W) in zip(self.im_pp, self.imshapes)])

    def get_intrinsics(self):
        """Return the (n_imgs, 3, 3) matrices built from focals and principal points."""
        K = torch.zeros((self.n_imgs, 3, 3), device=self.device)
        focals = self.get_focals().view(self.n_imgs, -1)
        K[:, 0, 0] = focals[:, 0]
        K[:, 1, 1] = focals[:, -1]  # same column as above when a single focal is stored
        K[:, :2, 2] = self.get_principal_points()
        K[:, 2, 2] = 1
        return K

    def get_im_poses(self):  # cam to world
        """Return the image poses as homogeneous cam-to-world matrices."""
        cam2world = self._get_poses(torch.stack(list(self.im_poses)))
        return cam2world

    def _set_depthmap(self, idx, depth, force=False):
        """Store log(depth) in the depthmap parameter of image `idx`."""
        param = self.im_depthmaps[idx]
        if param.requires_grad or force:  # can only init a parameter not already initialized
            # log(0) = -inf is replaced by 0
            param.data[:] = depth.log().nan_to_num(neginf=0)
        return param

    def get_depthmaps(self):
        """Return the per-image depthmaps (exponential of the stored log-depths)."""
        return [d.exp() for d in self.im_depthmaps]

    def depth_to_pts3d(self):
        """Back-project every depthmap and move the points to the world frame."""
        # Get depths and projection params if not provided
        focals = self.get_focals()
        pp = self.get_principal_points()
        im_poses = self.get_im_poses()
        depth = self.get_depthmaps()

        # convert focal to (1,2,H,W) constant field
        def focal_ex(i): return focals[i][..., None, None].expand(1, *focals[i].shape, *self.imshapes[i])
        # get pointmaps in camera frame
        rel_ptmaps = [depthmap_to_pts3d(depth[i][None], focal_ex(i), pp=pp[i:i+1])[0] for i in range(im_poses.shape[0])]
        # project to world frame
        return [geotrf(pose, ptmap) for pose, ptmap in zip(im_poses, rel_ptmaps)]

    def get_pts3d(self):
        """Return the per-image 3D points in the world frame."""
        return self.depth_to_pts3d()
class PointCloudOptimizer(BasePCOptimizer):
    """ Optimize a global scene, given a list of pairwise observations.
    Graph node: images
    Graph edges: observations = (pred1, pred2)

    Per-image unknowns are held as parameters:
      - `im_depthmaps`: log-depth, flattened over H*W and zero-padded to `max_area`
      - `im_poses`:     one pose vector of size `POSE_DIM` per image
      - `im_focals`:    `focal_break * log(focal)`
      - `im_pp`:        principal-point offset from the image center, divided by 10
    Flattening + padding lets images of different sizes live in one stacked tensor
    (see `ParameterStack` / `_ravel_hw`).
    """

    def __init__(self, *args, optimize_pp=False, focal_break=20, **kwargs):
        """
        Args:
            *args, **kwargs: forwarded unchanged to `BasePCOptimizer.__init__`.
            optimize_pp: if True, the principal-point offsets require gradients.
            focal_break: scale applied to log(focal) in the focal parameterization.
        """
        super().__init__(*args, **kwargs)

        self.has_im_poses = True  # by definition of this class
        self.focal_break = focal_break

        # adding thing to optimize (random / default initialization, one entry per image)
        self.im_depthmaps = nn.ParameterList(torch.randn(H, W)/10-3 for H, W in self.imshapes)  # log(depth)
        self.im_poses = nn.ParameterList(self.rand_pose(self.POSE_DIM) for _ in range(self.n_imgs))  # camera poses
        self.im_focals = nn.ParameterList(torch.FloatTensor(
            [self.focal_break*np.log(max(H, W))]) for H, W in self.imshapes)  # camera intrinsics
        self.im_pp = nn.ParameterList(torch.zeros((2,)) for _ in range(self.n_imgs))  # camera intrinsics
        self.im_pp.requires_grad_(optimize_pp)

        self.imshape = self.imshapes[0]
        im_areas = [h*w for h, w in self.imshapes]
        self.max_area = max(im_areas)

        # replace the per-image lists by single stacked parameters
        self.im_depthmaps = ParameterStack(self.im_depthmaps, is_param=True, fill=self.max_area)
        self.im_poses = ParameterStack(self.im_poses, is_param=True)
        self.im_focals = ParameterStack(self.im_focals, is_param=True)
        self.im_pp = ParameterStack(self.im_pp, is_param=True)
        # image centers (x, y) and flattened pixel grids, one row per image
        self.register_buffer('_pp', torch.tensor([(w/2, h/2) for h, w in self.imshapes]))
        self.register_buffer('_grid', ParameterStack(
            [xy_grid(W, H, device=self.device) for H, W in self.imshapes], fill=self.max_area))

        # pre-compute pixel weights (one row per edge, ordered like self.str_edges)
        self.register_buffer('_weight_i', ParameterStack(
            [self.conf_trf(self.conf_i[i_j]) for i_j in self.str_edges], fill=self.max_area))
        self.register_buffer('_weight_j', ParameterStack(
            [self.conf_trf(self.conf_j[i_j]) for i_j in self.str_edges], fill=self.max_area))

        # pre-stack the pairwise predictions (one row per edge) and the edge -> image index tables
        self.register_buffer('_stacked_pred_i', ParameterStack(self.pred_i, self.str_edges, fill=self.max_area))
        self.register_buffer('_stacked_pred_j', ParameterStack(self.pred_j, self.str_edges, fill=self.max_area))
        self.register_buffer('_ei', torch.tensor([i for i, j in self.edges]))
        self.register_buffer('_ej', torch.tensor([j for i, j in self.edges]))
        # normalizers used in forward(): number of real (non-padded) pixels on each side
        self.total_area_i = sum([im_areas[i] for i, j in self.edges])
        self.total_area_j = sum([im_areas[j] for i, j in self.edges])

    def _check_all_imgs_are_selected(self, msk):
        """ Assert that `msk` selects every image, in order (0 .. n_imgs-1). """
        assert np.all(self._get_msk_indices(msk) == np.arange(self.n_imgs)), 'incomplete mask!'

    def preset_pose(self, known_poses, pose_msk=None):  # cam-to-world
        """ Initialize the image poses from `known_poses` and freeze all poses.

        A single 2-D tensor is accepted and treated as a list of one pose.
        `pose_msk` must select all images (see `_check_all_imgs_are_selected`).
        """
        self._check_all_imgs_are_selected(pose_msk)

        if isinstance(known_poses, torch.Tensor) and known_poses.ndim == 2:
            known_poses = [known_poses]
        for idx, pose in zip(self._get_msk_indices(pose_msk), known_poses):
            if self.verbose:
                print(f' (setting pose #{idx} = {pose[:3,3]})')
            self._no_grad(self._set_pose(self.im_poses, idx, torch.tensor(pose)))

        # normalize scale if there's less than 1 known pose
        n_known_poses = sum((p.requires_grad is False) for p in self.im_poses)
        self.norm_pw_scale = (n_known_poses <= 1)

        self.im_poses.requires_grad_(False)
        # NOTE(review): this unconditionally overrides the value computed just above
        self.norm_pw_scale = False

    def preset_focal(self, known_focals, msk=None):
        """ Initialize the focals from `known_focals` and freeze them (`msk` must select all images). """
        self._check_all_imgs_are_selected(msk)

        for idx, focal in zip(self._get_msk_indices(msk), known_focals):
            if self.verbose:
                print(f' (setting focal #{idx} = {focal})')
            self._no_grad(self._set_focal(idx, focal))

        self.im_focals.requires_grad_(False)

    def preset_principal_point(self, known_pp, msk=None):
        """ Initialize the principal points from `known_pp` and freeze them (`msk` must select all images). """
        self._check_all_imgs_are_selected(msk)

        for idx, pp in zip(self._get_msk_indices(msk), known_pp):
            if self.verbose:
                print(f' (setting principal point #{idx} = {pp})')
            self._no_grad(self._set_principal_point(idx, pp))

        self.im_pp.requires_grad_(False)

    def _get_msk_indices(self, msk):
        """ Convert a mask specification into image indices.

        Accepted forms: None (all images), a single int, a tuple/list (converted to
        a numpy array first), a boolean array/tensor of length `n_imgs`, or an
        integer array (returned as-is).

        Raises:
            ValueError: for any other dtype.
        """
        if msk is None:
            return range(self.n_imgs)
        elif isinstance(msk, int):
            return [msk]
        elif isinstance(msk, (tuple, list)):
            return self._get_msk_indices(np.array(msk))
        elif msk.dtype in (bool, torch.bool, np.bool_):
            assert len(msk) == self.n_imgs
            return np.where(msk)[0]
        elif np.issubdtype(msk.dtype, np.integer):
            return msk
        else:
            raise ValueError(f'bad {msk=}')

    def _no_grad(self, tensor):
        """ Sanity check only: assert that `tensor` still requires grad.

        Despite its name, this method does not modify `tensor`.
        """
        assert tensor.requires_grad, 'it must be True at this point, otherwise no modification occurs'

    def _set_focal(self, idx, focal, force=False):
        """ Write `focal_break * log(focal)` into focal #idx (only if it requires grad, or `force`). """
        param = self.im_focals[idx]
        if param.requires_grad or force:  # can only init a parameter not already initialized
            param.data[:] = self.focal_break * np.log(focal)
        return param

    def get_focals(self):
        """ Return the focals, i.e. exp(im_focals / focal_break). """
        log_focals = torch.stack(list(self.im_focals), dim=0)
        return (log_focals / self.focal_break).exp()

    def get_known_focal_mask(self):
        """ Return a boolean tensor, True where the focal entry does not require grad. """
        return torch.tensor([not (p.requires_grad) for p in self.im_focals])

    def _set_principal_point(self, idx, pp, force=False):
        """ Write `(pp - image_center) / 10` into principal point #idx (only if it requires grad, or `force`). """
        param = self.im_pp[idx]
        H, W = self.imshapes[idx]
        if param.requires_grad or force:  # can only init a parameter not already initialized
            param.data[:] = to_cpu(to_numpy(pp) - (W/2, H/2)) / 10
        return param

    def get_principal_points(self):
        """ Return the principal points: image centers plus 10x the stored offsets. """
        return self._pp + 10 * self.im_pp

    def get_intrinsics(self):
        """ Build one 3x3 intrinsics matrix per image (same focal on both axes, zero skew). """
        K = torch.zeros((self.n_imgs, 3, 3), device=self.device)
        focals = self.get_focals().flatten()
        K[:, 0, 0] = K[:, 1, 1] = focals
        K[:, :2, 2] = self.get_principal_points()
        K[:, 2, 2] = 1
        return K

    def get_im_poses(self):  # cam to world
        """ Return the image poses as produced by the base-class `_get_poses`. """
        cam2world = self._get_poses(self.im_poses)
        return cam2world

    def _set_depthmap(self, idx, depth, force=False):
        """ Write log(depth), flattened and zero-padded, into depthmap #idx (only if it requires grad, or `force`). """
        depth = _ravel_hw(depth, self.max_area)

        param = self.im_depthmaps[idx]
        if param.requires_grad or force:  # can only init a parameter not already initialized
            # log(0) = -inf is mapped to 0 (this also covers the zero padding)
            param.data[:] = depth.log().nan_to_num(neginf=0)
        return param

    def get_depthmaps(self, raw=False):
        """ Return the depthmaps, i.e. exp(im_depthmaps).

        With `raw=True` the stacked, flattened (and padded) tensor is returned;
        otherwise a list of per-image (h, w) views without the padding.
        """
        res = self.im_depthmaps.exp()
        if not raw:
            res = [dm[:h*w].view(h, w) for dm, (h, w) in zip(res, self.imshapes)]
        return res

    def depth_to_pts3d(self):
        """ Back-project the current depthmaps and map them through the image poses (flattened layout). """
        # Get depths and projection params if not provided
        focals = self.get_focals()
        pp = self.get_principal_points()
        im_poses = self.get_im_poses()
        depth = self.get_depthmaps(raw=True)

        # get pointmaps in camera frame
        rel_ptmaps = _fast_depthmap_to_pts3d(depth, self._grid, focals, pp=pp)
        # project to world frame
        return geotrf(im_poses, rel_ptmaps)

    def get_pts3d(self, raw=False):
        """ Return the 3D points of `depth_to_pts3d`.

        With `raw=True` the stacked, flattened tensor is returned; otherwise a
        list of per-image (h, w, 3) views without the padding.
        """
        res = self.depth_to_pts3d()
        if not raw:
            res = [dm[:h*w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)]
        return res

    def forward(self):
        """ Compute the alignment loss between the per-image points and the transformed pairwise predictions.

        Each side is the weighted `self.dist` summed over all edges, divided by
        the total number of real pixels on that side; the two sides are added.
        """
        pw_poses = self.get_pw_poses()  # cam-to-world
        pw_adapt = self.get_adaptors().unsqueeze(1)
        proj_pts3d = self.get_pts3d(raw=True)

        # rotate pairwise prediction according to pw_poses
        aligned_pred_i = geotrf(pw_poses, pw_adapt * self._stacked_pred_i)
        aligned_pred_j = geotrf(pw_poses, pw_adapt * self._stacked_pred_j)

        # compute the loss
        li = self.dist(proj_pts3d[self._ei], aligned_pred_i, weight=self._weight_i).sum() / self.total_area_i
        lj = self.dist(proj_pts3d[self._ej], aligned_pred_j, weight=self._weight_j).sum() / self.total_area_j

        return li + lj
def _fast_depthmap_to_pts3d(depth, pixel_grid, focal, pp):
    """ Back-project flattened depthmaps to camera-frame 3D points.

    Args:
        depth: (B, N) depth values.
        pixel_grid: (B, N, 2) pixel coordinates matching `depth`.
        focal: (B, 1) focal lengths.
        pp: (B, 2) principal points.

    Returns:
        (B, N, 3) tensor `(depth * (pixel - pp) / focal, depth)`.
    """
    pp = pp.unsqueeze(1)
    focal = focal.unsqueeze(1)
    assert focal.shape == (len(depth), 1, 1)
    assert pp.shape == (len(depth), 1, 2)
    assert pixel_grid.shape == depth.shape + (2,)
    depth = depth.unsqueeze(-1)
    return torch.cat((depth * (pixel_grid - pp) / focal, depth), dim=-1)


def ParameterStack(params, keys=None, is_param=None, fill=0):
    """ Stack a collection of tensors into a single float tensor (optionally an nn.Parameter).

    Args:
        params: sequence of tensors, or a mapping when `keys` is given.
        keys: if not None, `params[k] for k in keys` is stacked, in that order.
        is_param: if truthy, the result is wrapped in `nn.Parameter`.
        fill: if > 0, each tensor has its two leading dims flattened and is
            zero-padded to `fill` rows first (see `_ravel_hw`).

    Returns:
        The stacked, detached float tensor. It is an `nn.Parameter` when
        `is_param` is truthy or when the inputs require grad; its
        `requires_grad` flag is the one shared by all the inputs.
    """
    if keys is not None:
        params = [params[k] for k in keys]

    if fill > 0:
        params = [_ravel_hw(p, fill) for p in params]

    # all inputs must agree on requires_grad, since the flag is carried over to the stack
    requires_grad = params[0].requires_grad
    assert all(p.requires_grad == requires_grad for p in params)

    params = torch.stack(list(params)).float().detach()
    if is_param or requires_grad:
        params = nn.Parameter(params)
        params.requires_grad_(requires_grad)
    return params


def _ravel_hw(tensor, fill=0):
    """ Flatten the two leading dims (H, W) of `tensor` and zero-pad to `fill` rows.

    No padding happens when the flattened length is already >= `fill`.
    """
    # ravel H,W -- reshape (rather than view) so that non-contiguous inputs are accepted too
    tensor = tensor.reshape((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])

    if len(tensor) < fill:
        tensor = torch.cat((tensor, tensor.new_zeros((fill - len(tensor),)+tensor.shape[1:])))
    return tensor


def acceptable_focal_range(H, W, minf=0.5, maxf=3.5):
    """ Return (min, max) focal bounds as multiples of the focal giving a 60-degree field of view
    over the larger image side.
    """
    focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2))  # size / 1.1547005383792515
    return minf*focal_base, maxf*focal_base


def apply_mask(img, msk):
    """ Return a copy of `img` with the entries selected by `msk` set to 0 (the input is untouched). """
    img = img.copy()
    img[msk] = 0
    return img
class PairViewer (BasePCOptimizer):
    """
    This a Dummy Optimizer.
    To use only when the goal is to visualize the results for a pair of images (with is_symmetrized)

    Nothing is optimized: focals, principal points, relative poses and depthmaps
    are all computed once in `__init__` from the raw pairwise predictions, and
    every parameter is frozen.
    """

    def __init__(self, *args, **kwargs):
        """ Forward all arguments to `BasePCOptimizer`, then derive all cameras from the two edges.

        Requires a symmetrized pair (exactly 2 edges).
        """
        super().__init__(*args, **kwargs)
        assert self.is_symmetrized and self.n_edges == 2
        self.has_im_poses = True

        # compute all parameters directly from raw input
        self.focals = []
        self.pp = []
        rel_poses = []
        confs = []
        for i in range(self.n_imgs):
            # confidence of edge i -> 1-i: product of the two mean confidences
            conf = float(self.conf_i[edge_str(i, 1-i)].mean() * self.conf_j[edge_str(i, 1-i)].mean())
            if self.verbose:
                print(f' - {conf=:.3} for edge {i}-{1-i}')
            confs.append(conf)

            H, W = self.imshapes[i]
            pts3d = self.pred_i[edge_str(i, 1-i)]
            pp = torch.tensor((W/2, H/2))  # principal point fixed at the image center
            focal = float(estimate_focal_knowing_depth(pts3d[None], pp, focal_mode='weiszfeld'))
            self.focals.append(focal)
            self.pp.append(pp)

            # estimate the pose of pts1 in image 2
            pixels = np.mgrid[:W, :H].T.astype(np.float32)
            pts3d = self.pred_j[edge_str(1-i, i)].numpy()
            assert pts3d.shape[:2] == (H, W)
            msk = self.get_masks()[i].numpy()
            K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)])

            try:
                res = cv2.solvePnPRansac(pts3d[msk], pixels[msk], K, None,
                                         iterationsCount=100, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP)
                success, R, T, inliers = res
                assert success

                R = cv2.Rodrigues(R)[0]  # world to cam
                pose = inv(np.r_[np.c_[R, T], [(0, 0, 0, 1)]])  # cam to world
            except Exception as err:
                # Best-effort: fall back to the identity pose when PnP fails.
                # `Exception` (rather than a bare except) lets KeyboardInterrupt/SystemExit propagate.
                if self.verbose:
                    print(f' - pose estimation failed for edge {1-i}-{i} ({err!r}), using identity')
                pose = np.eye(4)
            rel_poses.append(torch.from_numpy(pose.astype(np.float32)))

        # let's use the pair with the most confidence
        if confs[0] > confs[1]:
            # ptcloud is expressed in camera1
            self.im_poses = [torch.eye(4), rel_poses[1]]  # I, cam2-to-cam1
            self.depth = [self.pred_i['0_1'][..., 2], geotrf(inv(rel_poses[1]), self.pred_j['0_1'])[..., 2]]
        else:
            # ptcloud is expressed in camera2
            self.im_poses = [rel_poses[0], torch.eye(4)]  # I, cam1-to-cam2
            self.depth = [geotrf(inv(rel_poses[0]), self.pred_j['1_0'])[..., 2], self.pred_i['1_0'][..., 2]]

        # freeze everything: this class only serves visualization
        self.im_poses = nn.Parameter(torch.stack(self.im_poses, dim=0), requires_grad=False)
        self.focals = nn.Parameter(torch.tensor(self.focals), requires_grad=False)
        self.pp = nn.Parameter(torch.stack(self.pp, dim=0), requires_grad=False)
        self.depth = nn.ParameterList(self.depth)
        for p in self.parameters():
            p.requires_grad = False

    def _set_depthmap(self, idx, depth, force=False):
        """ No-op: depthmaps are derived in `__init__` and never overwritten. """
        if self.verbose:
            print('_set_depthmap is ignored in PairViewer')
        return

    def get_depthmaps(self, raw=False):
        """ Return the list of depthmaps moved to `self.device` (`raw` is ignored). """
        depth = [d.to(self.device) for d in self.depth]
        return depth

    def _set_focal(self, idx, focal, force=False):
        """ Overwrite focal #idx in place (`force` is ignored). """
        self.focals[idx] = focal

    def get_focals(self):
        """ Return the focals tensor. """
        return self.focals

    def get_known_focal_mask(self):
        """ Return a boolean tensor, True where the focal entry does not require grad. """
        return torch.tensor([not (p.requires_grad) for p in self.focals])

    def get_principal_points(self):
        """ Return the principal points tensor. """
        return self.pp

    def get_intrinsics(self):
        """ Build one 3x3 intrinsics matrix per image (same focal on both axes, zero skew). """
        focals = self.get_focals()
        pps = self.get_principal_points()
        K = torch.zeros((len(focals), 3, 3), device=self.device)
        for i in range(len(focals)):
            K[i, 0, 0] = K[i, 1, 1] = focals[i]
            K[i, :2, 2] = pps[i]
            K[i, 2, 2] = 1
        return K

    def get_im_poses(self):
        """ Return the stacked image poses. """
        return self.im_poses

    def depth_to_pts3d(self):
        """ Back-project each depthmap with its intrinsics and pose; returns a list of tensors on `self.device`. """
        pts3d = []
        for d, intrinsics, im_pose in zip(self.depth, self.get_intrinsics(), self.get_im_poses()):
            pts, _ = depthmap_to_absolute_camera_coordinates(d.cpu().numpy(),
                                                             intrinsics.cpu().numpy(),
                                                             im_pose.cpu().numpy())
            pts3d.append(torch.from_numpy(pts).to(device=self.device))
        return pts3d

    def forward(self):
        """ There is nothing to optimize, hence no loss: always NaN. """
        return float('nan')
def get_data_loader(dataset, batch_size, num_workers=8, shuffle=True, drop_last=True, pin_mem=True):
    """ Wrap `dataset` in a `torch.utils.data.DataLoader`.

    Args:
        dataset: a dataset object, or a string that is evaluated to build one.
        batch_size, num_workers, drop_last: forwarded to the DataLoader.
        shuffle: forwarded to the sampler.
        pin_mem: forwarded to the DataLoader as `pin_memory`.

    The sampler comes from `dataset.make_sampler(...)` when the dataset provides
    it; otherwise a distributed / random / sequential torch sampler is used.
    """
    import torch
    from croco.utils.misc import get_world_size, get_rank

    # pytorch dataset
    if isinstance(dataset, str):
        # NOTE(security): eval() executes arbitrary code -- only pass trusted dataset specs
        dataset = eval(dataset)

    n_replicas, proc_rank = get_world_size(), get_rank()
    data_utils = torch.utils.data

    try:
        chosen_sampler = dataset.make_sampler(batch_size, shuffle=shuffle, world_size=n_replicas,
                                              rank=proc_rank, drop_last=drop_last)
    except (AttributeError, NotImplementedError):
        # not avail for this dataset: fall back on the stock torch samplers
        if torch.distributed.is_initialized():
            chosen_sampler = data_utils.DistributedSampler(
                dataset, num_replicas=n_replicas, rank=proc_rank, shuffle=shuffle, drop_last=drop_last
            )
        else:
            sampler_cls = data_utils.RandomSampler if shuffle else data_utils.SequentialSampler
            chosen_sampler = sampler_cls(dataset)

    return data_utils.DataLoader(
        dataset,
        sampler=chosen_sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_mem,
        drop_last=drop_last,
    )
class ARKitScenes(BaseStereoViewDataset):
    """ Pairs of views from the preprocessed ARKitScenes data found under `ROOT`. """

    def __init__(self, *args, split, ROOT, **kwargs):
        """
        Args:
            split: 'train' or 'test'; mapped to the 'Training' / 'Test' sub-directory of `ROOT`.
            ROOT: root directory of the preprocessed dataset.
            *args, **kwargs: forwarded to `BaseStereoViewDataset`.

        Raises:
            ValueError: if `split` is neither 'train' nor 'test'.
        """
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)
        if split == "train":
            self.split = "Training"
        elif split == "test":
            self.split = "Test"
        else:
            raise ValueError(f"unknown {split=} for ARKitScenes, expected 'train' or 'test'")

        # _load_data() fills attributes and returns None; the attribute is kept for compatibility
        self.loaded_data = self._load_data(self.split)

    def _load_data(self, split):
        """ Load `all_metadata.npz` of the given split directory into attributes. """
        with np.load(osp.join(self.ROOT, split, 'all_metadata.npz')) as data:
            self.scenes = data['scenes']
            self.sceneids = data['sceneids']
            self.images = data['images']
            self.intrinsics = data['intrinsics'].astype(np.float32)
            self.trajectories = data['trajectories'].astype(np.float32)
            self.pairs = data['pairs'][:, :2].astype(int)  # only the first two columns are used

    def __len__(self):
        """ Number of image pairs. """
        return len(self.pairs)

    def _get_views(self, idx, resolution, rng):
        """ Load, crop and resize the two views of pair `idx`; returns a list of 2 view dicts. """
        image_idx1, image_idx2 = self.pairs[idx]

        views = []
        for view_idx in [image_idx1, image_idx2]:
            scene_id = self.sceneids[view_idx]
            scene_dir = osp.join(self.ROOT, self.split, self.scenes[scene_id])

            intrinsics = self.intrinsics[view_idx]
            camera_pose = self.trajectories[view_idx]
            basename = self.images[view_idx]

            # Load RGB image
            rgb_image = imread_cv2(osp.join(scene_dir, 'vga_wide', basename.replace('.png', '.jpg')))
            # Load depthmap
            depthmap = imread_cv2(osp.join(scene_dir, 'lowres_depth', basename), cv2.IMREAD_UNCHANGED)
            depthmap = depthmap.astype(np.float32) / 1000  # presumably millimeters -> meters; TODO confirm
            depthmap[~np.isfinite(depthmap)] = 0  # invalid

            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx)

            views.append(dict(
                img=rgb_image,
                depthmap=depthmap.astype(np.float32),
                camera_pose=camera_pose.astype(np.float32),
                camera_intrinsics=intrinsics.astype(np.float32),
                dataset='arkitscenes',
                label=self.scenes[scene_id] + '_' + basename,
                instance=f'{str(idx)}_{str(view_idx)}',
            ))

        return views


if __name__ == "__main__":
    from dust3r.datasets.base.base_stereo_view_dataset import view_name
    from dust3r.viz import SceneViz, auto_cam_size
    from dust3r.utils.image import rgb

    dataset = ARKitScenes(split='train', ROOT="data/arkitscenes_processed", resolution=224, aug_crop=16)

    for idx in np.random.permutation(len(dataset)):
        views = dataset[idx]
        assert len(views) == 2
        print(view_name(views[0]), view_name(views[1]))
        viz = SceneViz()
        poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]]
        cam_size = max(auto_cam_size(poses), 0.001)
        for view_idx in [0, 1]:
            pts3d = views[view_idx]['pts3d']
            valid_mask = views[view_idx]['valid_mask']
            colors = rgb(views[view_idx]['img'])
            viz.add_pointcloud(pts3d, colors, valid_mask)
            # one color per view (0 or 1); the dataset index `idx` would give out-of-range colors
            viz.add_camera(pose_c2w=views[view_idx]['camera_pose'],
                           focal=views[view_idx]['camera_intrinsics'][0, 0],
                           color=(view_idx * 255, (1 - view_idx) * 255, 0),
                           image=colors,
                           cam_size=cam_size)
        viz.show()
class BaseStereoViewDataset (EasyDataset):
    """ Define all basic options.

    Usage:
        class MyDataset (BaseStereoViewDataset):
            def _get_views(self, idx, resolution, rng):
                # overload here
                views = []
                views.append(dict(img=, ...))
                return views
    """

    def __init__(self, *,  # only keyword arguments
                 split=None,
                 resolution=None,  # square_size or (width, height) or list of [(width,height), ...]
                 transform=ImgNorm,
                 aug_crop=False,
                 seed=None):
        """
        Args:
            split: stored as-is in `self.split`.
            resolution: int, (width, height) or list of those (see `_set_resolutions`).
            transform: callable applied to each image, or a string evaluated to get one.
            aug_crop: when > 1, a random number of pixels below this value is added
                to the target resolution before the final crop.
            seed: if not None, the RNG is re-seeded with `seed + idx` at every `__getitem__`.
        """
        self.num_views = 2
        self.split = split
        self._set_resolutions(resolution)

        if isinstance(transform, str):
            # NOTE(security): eval() executes arbitrary code -- only pass trusted transform specs
            transform = eval(transform)
        self.transform = transform

        self.aug_crop = aug_crop
        self.seed = seed

    def __len__(self):
        """ Default length: number of scenes (subclasses usually override). """
        return len(self.scenes)

    def get_stats(self):
        """ Short human-readable size summary, used by `__repr__`. """
        return f"{len(self)} pairs"

    def __repr__(self):
        resolutions_str = '[' + ';'.join(f'{w}x{h}' for w, h in self._resolutions) + ']'
        # NOTE(review): this text reached review with whitespace collapsed; verify the layout of the
        # f-string and the width of the last replace() pattern against the original file
        return f"""{type(self).__name__}({self.get_stats()}, {self.split=}, {self.seed=}, resolutions={resolutions_str}, {self.transform=})""".replace('self.', '').replace('\n', '').replace(' ', '')

    def _get_views(self, idx, resolution, rng):
        """ To be overloaded: return the list of `num_views` view dicts for sample `idx`. """
        raise NotImplementedError()

    def __getitem__(self, idx):
        """ Return the views of sample `idx`, completed with 3D points and validity masks.

        `idx` is either an int (only valid with a single resolution) or a tuple
        `(idx, ar_idx)` where `ar_idx` selects one of `self._resolutions`.
        """
        if isinstance(idx, tuple):
            # the idx is specifying the aspect-ratio
            idx, ar_idx = idx
        else:
            assert len(self._resolutions) == 1
            ar_idx = 0

        # set-up the rng
        if self.seed is not None:  # reseed for each __getitem__ (seed=0 is a valid seed)
            self._rng = np.random.default_rng(seed=self.seed + idx)
        elif not hasattr(self, '_rng'):
            seed = torch.initial_seed()  # this is different for each dataloader process
            self._rng = np.random.default_rng(seed=seed)

        # over-loaded code
        resolution = self._resolutions[ar_idx]  # DO NOT CHANGE THIS (compatible with BatchedRandomSampler)
        views = self._get_views(idx, resolution, self._rng)
        assert len(views) == self.num_views

        # check data-types
        for v, view in enumerate(views):
            assert 'pts3d' not in view, f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}"
            view['idx'] = (idx, ar_idx, v)

            # encode the image
            width, height = view['img'].size
            view['true_shape'] = np.int32((height, width))
            view['img'] = self.transform(view['img'])

            assert 'camera_intrinsics' in view
            if 'camera_pose' not in view:
                view['camera_pose'] = np.full((4, 4), np.nan, dtype=np.float32)
            else:
                assert np.isfinite(view['camera_pose']).all(), f'NaN in camera pose for view {view_name(view)}'
            assert 'pts3d' not in view
            assert 'valid_mask' not in view
            assert np.isfinite(view['depthmap']).all(), f'NaN in depthmap for view {view_name(view)}'
            pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)

            view['pts3d'] = pts3d
            view['valid_mask'] = valid_mask & np.isfinite(pts3d).all(axis=-1)

            # check all datatypes
            for key, val in view.items():
                res, err_msg = is_good_type(key, val)
                assert res, f"{err_msg} with {key}={val} for view {view_name(view)}"

        # last thing done!
        for view in views:
            # transpose to make sure all views are the same size
            transpose_to_landscape(view)
            # this allows to check whether the RNG is in the same state each time
            view['rng'] = int.from_bytes(self._rng.bytes(4), 'big')
        return views

    def _set_resolutions(self, resolutions):
        """ Normalize `resolutions` into `self._resolutions`, a list of (width, height) int tuples.

        Accepts an int (square), a (width, height) pair, or a list of those.
        Each resolution must satisfy width >= height.
        """
        assert resolutions is not None, 'undefined resolution'

        if not isinstance(resolutions, list):
            resolutions = [resolutions]

        self._resolutions = []
        for resolution in resolutions:
            if isinstance(resolution, int):
                width = height = resolution
            else:
                width, height = resolution
            assert isinstance(width, int), f'Bad type for {width=} {type(width)=}, should be int'
            assert isinstance(height, int), f'Bad type for {height=} {type(height)=}, should be int'
            assert width >= height
            self._resolutions.append((width, height))

    def _crop_resize_if_necessary(self, image, depthmap, intrinsics, resolution, rng=None, info=None):
        """ Crop and rescale an (image, depthmap, intrinsics) triplet to `resolution`.

        Steps:
          1. crop the largest window centered on the principal point;
          2. swap `resolution` to portrait if the crop is portrait (or randomly, using
             `rng`, if the crop is roughly square and `resolution` is not);
          3. rescale to the target resolution (enlarged by a random amount below
             `self.aug_crop` when `self.aug_crop > 1`);
          4. crop to exactly `resolution`.

        Returns:
            (image, depthmap, intrinsics) after cropping/rescaling.
        """
        if not isinstance(image, PIL.Image.Image):
            image = PIL.Image.fromarray(image)

        # downscale with lanczos interpolation so that image.size == resolution
        # cropping centered on the principal point
        W, H = image.size
        cx, cy = intrinsics[:2, 2].round().astype(int)
        min_margin_x = min(cx, W - cx)
        min_margin_y = min(cy, H - cy)
        # assert min_margin_x > W/5, f'Bad principal point in view={info}'
        # assert min_margin_y > H/5, f'Bad principal point in view={info}'
        # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy)
        l, t = cx - min_margin_x, cy - min_margin_y
        r, b = cx + min_margin_x, cy + min_margin_y
        crop_bbox = (l, t, r, b)
        image, depthmap, intrinsics = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox)

        # transpose the resolution if necessary
        W, H = image.size  # new size
        assert resolution[0] >= resolution[1]
        if H > 1.1 * W:
            # image is portrait mode
            resolution = resolution[::-1]
        elif 0.9 < H / W < 1.1 and resolution[0] != resolution[1]:
            # image is square, so we chose (portrait, landscape) randomly
            if rng.integers(2):
                resolution = resolution[::-1]

        # high-quality Lanczos down-scaling
        target_resolution = np.array(resolution)
        if self.aug_crop > 1:
            target_resolution += rng.integers(0, self.aug_crop)
        image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution)

        # actual cropping (if necessary) with bilinear interpolation
        intrinsics2 = cropping.camera_matrix_of_crop(intrinsics, image.size, resolution, offset_factor=0.5)
        crop_bbox = cropping.bbox_from_intrinsics_in_out(intrinsics, intrinsics2, resolution)
        image, depthmap, intrinsics2 = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox)

        return image, depthmap, intrinsics2
def is_good_type(key, v):
    """ returns (is_good, err_msg)

    str, int and tuple values are always accepted. Anything else must expose a
    `dtype` among float32 (numpy or torch), bool, int32, int64 and uint8.
    Values without a `dtype` (e.g. float, list, None) are reported as bad
    instead of raising AttributeError.
    """
    if isinstance(v, (str, int, tuple)):
        return True, None
    if getattr(v, 'dtype', None) is None:
        return False, f"bad {type(v)=}"
    if v.dtype not in (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8):
        return False, f"bad {v.dtype=}"
    return True, None


def view_name(view, batch_index=None):
    """ Return '<dataset>/<label>/<instance>' for a view.

    When `batch_index` is given (and is not a full slice), each of the three
    fields is indexed with it first, which is how one element of a batched
    view is named.
    """
    def sel(x): return x[batch_index] if batch_index not in (None, slice(None)) else x
    db = sel(view['dataset'])
    label = sel(view['label'])
    instance = sel(view['instance'])
    return f"{db}/{label}/{instance}"


def transpose_to_landscape(view):
    """ In place, swap the height/width axes of a portrait view so that it becomes landscape.

    `view['true_shape']` is left unchanged (it still holds the original
    (height, width)); the first two rows of the intrinsics are exchanged.
    Views that are already landscape or square are left untouched.
    """
    height, width = view['true_shape']

    if width < height:
        # rectify portrait to landscape
        assert view['img'].shape == (3, height, width)
        view['img'] = view['img'].swapaxes(1, 2)

        assert view['valid_mask'].shape == (height, width)
        view['valid_mask'] = view['valid_mask'].swapaxes(0, 1)

        assert view['depthmap'].shape == (height, width)
        view['depthmap'] = view['depthmap'].swapaxes(0, 1)

        assert view['pts3d'].shape == (height, width, 3)
        view['pts3d'] = view['pts3d'].swapaxes(0, 1)

        # transpose x and y pixels
        view['camera_intrinsics'] = view['camera_intrinsics'][[1, 0, 2]]
class BatchedRandomSampler:
    """ Random sampling under a constraint: each sample in the batch has the same feature,
    which is chosen randomly from a known pool of 'features' for each batch.

    For instance, the 'feature' could be the image aspect-ratio.

    The index returned is a tuple (sample_idx, feat_idx).
    This sampler ensures that each series of `batch_size` indices has the same `feat_idx`.
    """

    def __init__(self, dataset, batch_size, pool_size, world_size=1, rank=0, drop_last=True):
        n_samples = len(dataset)
        self.batch_size = batch_size
        self.pool_size = pool_size
        self.len_dataset = n_samples
        if drop_last:
            # keep only full batches for every process
            self.total_size = round_by(n_samples, batch_size * world_size)
        else:
            self.total_size = n_samples
        assert world_size == 1 or drop_last, 'must drop the last batch in distributed mode'

        # distributed sampler
        self.world_size = world_size
        self.rank = rank
        self.epoch = None

    def __len__(self):
        """ Number of indices served to one process. """
        return self.total_size // self.world_size

    def set_epoch(self, epoch):
        """ Fix the epoch; the shuffling then only depends on it. """
        self.epoch = epoch

    def __iter__(self):
        """ Yield `(sample_idx, feat_idx)` tuples for this process. """
        # prepare RNG
        if self.epoch is not None:
            seed = self.epoch + 777
        else:
            assert self.world_size == 1 and self.rank == 0, 'use set_epoch() if distributed mode is used'
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
        rng = np.random.default_rng(seed=seed)

        # shuffled sample indices
        order = np.arange(self.total_size)
        rng.shuffle(order)

        # one random feature per batch, repeated for every sample of that batch
        bs = self.batch_size
        n_batches = -(-self.total_size // bs)  # ceil division
        per_batch_feat = rng.integers(self.pool_size, size=n_batches)
        feat_per_sample = np.repeat(per_batch_feat, bs)[:self.total_size]

        # put them together
        pairs = np.c_[order, feat_per_sample]  # shape = (total_size, 2)

        # Distributed sampler: each process takes a contiguous, batch-aligned slice
        size_per_proc = bs * -(-self.total_size // (self.world_size * bs))
        start = self.rank * size_per_proc
        for row in pairs[start:start + size_per_proc]:
            yield tuple(row)


def round_by(total, multiple, up=False):
    """ Round `total` down (or up, if `up`) to a multiple of `multiple`. """
    bumped = total + multiple - 1 if up else total
    return (bumped // multiple) * multiple
class EasyDataset:
    """ a dataset that you can easily resize and combine.
    Examples:
    ---------
        2 * dataset ==> duplicate each element 2x

        10 @ dataset ==> set the size to 10 (random sampling, duplicates if necessary)

        dataset1 + dataset2 ==> concatenate datasets
    """

    def __add__(self, other):
        return CatDataset([self, other])

    def __rmul__(self, factor):
        return MulDataset(factor, self)

    def __rmatmul__(self, factor):
        return ResizedDataset(factor, self)

    def set_epoch(self, epoch):
        """ Hook called at each epoch; nothing to do by default. """
        pass

    def make_sampler(self, batch_size, shuffle=True, world_size=1, rank=0, drop_last=True):
        """ Build a `BatchedRandomSampler` whose feature pool is the list of resolutions.

        Raises:
            NotImplementedError: when `shuffle` is false (not supported yet).
        """
        if not shuffle:
            raise NotImplementedError()  # cannot deal yet
        n_aspect_ratios = len(self._resolutions)
        return BatchedRandomSampler(self, batch_size, n_aspect_ratios,
                                    world_size=world_size, rank=rank, drop_last=drop_last)


class MulDataset (EasyDataset):
    """ Artificially augmenting the size of a dataset: every element is repeated `multiplicator` times.
    """
    multiplicator: int

    def __init__(self, multiplicator, dataset):
        assert isinstance(multiplicator, int) and multiplicator > 0
        self.multiplicator = multiplicator
        self.dataset = dataset

    def __len__(self):
        return self.multiplicator * len(self.dataset)

    def __repr__(self):
        return f'{self.multiplicator}*{repr(self.dataset)}'

    def __getitem__(self, idx):
        # a tuple index carries an extra component that is passed through untouched
        if not isinstance(idx, tuple):
            return self.dataset[idx // self.multiplicator]
        sample_idx, extra = idx
        return self.dataset[sample_idx // self.multiplicator, extra]

    @property
    def _resolutions(self):
        return self.dataset._resolutions


class ResizedDataset (EasyDataset):
    """ Artificially changing the size of a dataset: `new_size` elements taken from a per-epoch shuffle.
    """
    new_size: int

    def __init__(self, new_size, dataset):
        assert isinstance(new_size, int) and new_size > 0
        self.new_size = new_size
        self.dataset = dataset

    def __len__(self):
        return self.new_size

    def __repr__(self):
        # group the digits by 3 from the right, separated by underscores
        digits = str(self.new_size)
        head = len(digits) % 3 or 3
        groups = [digits[:head]] + [digits[k:k + 3] for k in range(head, len(digits), 3)]
        size_str = '_'.join(groups)
        return f'{size_str} @ {repr(self.dataset)}'

    def set_epoch(self, epoch):
        """ Draw the index mapping for this epoch (it only depends on `epoch`). """
        # this random shuffle only depends on the epoch
        rng = np.random.default_rng(seed=epoch+777)

        # shuffle all indices
        perm = rng.permutation(len(self.dataset))

        # rotary extension until target size is met
        n_copies = 1 + (len(self) - 1) // len(self.dataset)
        extended = np.concatenate([perm] * n_copies)
        self._idxs_mapping = extended[:self.new_size]

        assert len(self._idxs_mapping) == self.new_size

    def __getitem__(self, idx):
        assert hasattr(self, '_idxs_mapping'), 'You need to call dataset.set_epoch() to use ResizedDataset.__getitem__()'
        # a tuple index carries an extra component that is passed through untouched
        if not isinstance(idx, tuple):
            return self.dataset[self._idxs_mapping[idx]]
        sample_idx, extra = idx
        return self.dataset[self._idxs_mapping[sample_idx], extra]

    @property
    def _resolutions(self):
        return self.dataset._resolutions


class CatDataset (EasyDataset):
    """ Concatenation of several datasets
    """

    def __init__(self, datasets):
        for member in datasets:
            assert isinstance(member, EasyDataset)
        self.datasets = datasets
        # _cum_sizes[k] = total number of elements in datasets[0..k]
        self._cum_sizes = np.cumsum([len(member) for member in datasets])

    def __len__(self):
        return self._cum_sizes[-1]

    def __repr__(self):
        # remove uselessly long transform
        useless = ',transform=Compose( ToTensor() Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))'
        parts = [repr(member).replace(useless, '') for member in self.datasets]
        return ' + '.join(parts)

    def set_epoch(self, epoch):
        for member in self.datasets:
            member.set_epoch(epoch)

    def __getitem__(self, idx):
        extra = None
        if isinstance(idx, tuple):
            idx, extra = idx

        if idx < 0 or idx >= len(self):
            raise IndexError()

        # locate the dataset holding idx, then shift idx to be local to it
        db_idx = np.searchsorted(self._cum_sizes, idx, 'right')
        offset = self._cum_sizes[db_idx - 1] if db_idx > 0 else 0
        local_idx = idx - offset

        if extra is not None:
            local_idx = (local_idx, extra)
        return self.datasets[db_idx][local_idx]

    @property
    def _resolutions(self):
        # all members must share the same resolutions
        first = self.datasets[0]._resolutions
        for member in self.datasets[1:]:
            assert tuple(member._resolutions) == tuple(first)
        return first
class BlendedMVS (BaseStereoViewDataset):
    """ Dataset of outdoor street scenes, 5 images each time
    """

    def __init__(self, *args, ROOT, split=None, **kwargs):
        """
        Args:
            ROOT: root directory of the preprocessed dataset.
            split: None (all pairs), 'train' or 'val'.
            *args, **kwargs: forwarded to `BaseStereoViewDataset`.

        Raises:
            ValueError: if `split` is not one of the values above.
        """
        self.ROOT = ROOT
        # forward split so that self.split (shown by __repr__) reflects the actual selection
        super().__init__(*args, split=split, **kwargs)
        self._load_data(split)

    def _load_data(self, split):
        """ Load `blendedmvs_pairs.npy` and keep the pairs of the requested split.

        The split is decided per scene from `seq_low % 10`: 0 goes to 'val',
        everything else to 'train'.
        """
        pairs = np.load(osp.join(self.ROOT, 'blendedmvs_pairs.npy'))
        if split is None:
            selection = slice(None)
        elif split == 'train':
            # select 90% of all scenes
            selection = (pairs['seq_low'] % 10) > 0
        elif split == 'val':
            # select 10% of all scenes
            selection = (pairs['seq_low'] % 10) == 0
        else:
            # used to fail later with an unbound `selection`
            raise ValueError(f"unknown {split=} for BlendedMVS, expected None, 'train' or 'val'")
        self.pairs = pairs[selection]

        # list of all scenes
        self.scenes = np.unique(self.pairs['seq_low'])  # low is unique enough

    def __len__(self):
        """ Number of image pairs. """
        return len(self.pairs)

    def get_stats(self):
        """ Short human-readable size summary, used by `__repr__`. """
        return f'{len(self)} pairs from {len(self.scenes)} scenes'

    def _get_views(self, pair_idx, resolution, rng):
        """ Load, crop and resize the two views of pair `pair_idx`; returns a list of 2 view dicts. """
        seqh, seql, img1, img2, score = self.pairs[pair_idx]

        # the scene directory name is the hex concatenation of the two sequence ids
        seq = f"{seqh:08x}{seql:016x}"
        seq_path = osp.join(self.ROOT, seq)

        views = []

        for view_index in [img1, img2]:
            impath = f"{view_index:08n}"
            image = imread_cv2(osp.join(seq_path, impath + ".jpg"))
            depthmap = imread_cv2(osp.join(seq_path, impath + ".exr"))
            camera_params = np.load(osp.join(seq_path, impath + ".npz"))

            intrinsics = np.float32(camera_params['intrinsics'])
            camera_pose = np.eye(4, dtype=np.float32)
            camera_pose[:3, :3] = camera_params['R_cam2world']
            camera_pose[:3, 3] = camera_params['t_cam2world']

            image, depthmap, intrinsics = self._crop_resize_if_necessary(
                image, depthmap, intrinsics, resolution, rng, info=(seq_path, impath))

            views.append(dict(
                img=image,
                depthmap=depthmap,
                camera_pose=camera_pose,  # cam2world
                camera_intrinsics=intrinsics,
                dataset='BlendedMVS',
                label=osp.relpath(seq_path, self.ROOT),
                instance=impath))

        return views


if __name__ == '__main__':
    from dust3r.datasets.base.base_stereo_view_dataset import view_name
    from dust3r.viz import SceneViz, auto_cam_size
    from dust3r.utils.image import rgb

    dataset = BlendedMVS(split='train', ROOT="data/blendedmvs_processed", resolution=224, aug_crop=16)

    for idx in np.random.permutation(len(dataset)):
        views = dataset[idx]
        assert len(views) == 2
        print(idx, view_name(views[0]), view_name(views[1]))
        viz = SceneViz()
        poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]]
        cam_size = max(auto_cam_size(poses), 0.001)
        for view_idx in [0, 1]:
            pts3d = views[view_idx]['pts3d']
            valid_mask = views[view_idx]['valid_mask']
            colors = rgb(views[view_idx]['img'])
            viz.add_pointcloud(pts3d, colors, valid_mask)
            # one color per view (0 or 1); the dataset index `idx` would give out-of-range colors
            viz.add_camera(pose_c2w=views[view_idx]['camera_pose'],
                           focal=views[view_idx]['camera_intrinsics'][0, 0],
                           color=(view_idx * 255, (1 - view_idx) * 255, 0),
                           image=colors,
                           cam_size=cam_size)
        viz.show()
class Co3d(BaseStereoViewDataset):
    """ Image pairs from the preprocessed Co3d_v2 data stored under ROOT.

    Each item is one (scene, frame-pair) combination; see `__len__`.
    `mask_bg` is True, False or 'rand' (decide per item with the rng) and
    controls whether depth outside the object mask is zeroed.
    """

    def __init__(self, mask_bg=True, *args, ROOT, **kwargs):
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)
        assert mask_bg in (True, False, 'rand')
        self.mask_bg = mask_bg
        self.dataset_label = 'Co3d_v2'

        # load all scenes
        with open(osp.join(self.ROOT, f'selected_seqs_{self.split}.json'), 'r') as f:
            self.scenes = json.load(f)
            # drop empty top-level entries, then flatten the two-level mapping
            # into a dict keyed by (outer key, inner key)
            self.scenes = {k: v for k, v in self.scenes.items() if len(v) > 0}
            self.scenes = {(k, k2): v2 for k, v in self.scenes.items()
                           for k2, v2 in v.items()}
        self.scene_list = list(self.scenes.keys())

        # for each scene, we have 100 images ==> 360 degrees (so 25 frames ~= 90 degrees)
        # keep the ordered pairs (i < j) whose frame gap is a multiple of 5, up to 30 frames
        self.combinations = [(i, j)
                             for i, j in itertools.combinations(range(100), 2)
                             if 0 < abs(i - j) <= 30 and abs(i - j) % 5 == 0]

        # per scene: {resolution: list of per-image "known to be invalid" flags},
        # filled lazily by _get_views
        self.invalidate = {scene: {} for scene in self.scene_list}

    def __len__(self):
        return len(self.scene_list) * len(self.combinations)

    def _get_metadatapath(self, obj, instance, view_idx):
        return osp.join(self.ROOT, obj, instance, 'images', f'frame{view_idx:06n}.npz')

    def _get_impath(self, obj, instance, view_idx):
        return osp.join(self.ROOT, obj, instance, 'images', f'frame{view_idx:06n}.jpg')

    def _get_depthpath(self, obj, instance, view_idx):
        return osp.join(self.ROOT, obj, instance, 'depths', f'frame{view_idx:06n}.jpg.geometric.png')

    def _get_maskpath(self, obj, instance, view_idx):
        return osp.join(self.ROOT, obj, instance, 'masks', f'frame{view_idx:06n}.png')

    def _read_depthmap(self, depthpath, input_metadata):
        """ Read the stored depth image and rescale it by the metadata's 'maximum_depth'. """
        depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED)
        # NOTE(review): the division by 65535 suggests a 16-bit normalised encoding - confirm
        depthmap = (depthmap.astype(np.float32) / 65535) * np.nan_to_num(input_metadata['maximum_depth'])
        return depthmap

    def _get_views(self, idx, resolution, rng):
        """ Load the two views for item `idx`, replacing images that end up with no valid depth. """
        # choose a scene
        obj, instance = self.scene_list[idx // len(self.combinations)]
        image_pool = self.scenes[obj, instance]
        im1_idx, im2_idx = self.combinations[idx % len(self.combinations)]

        # highest usable position in image_pool (used to clamp the jittered indices below)
        last = len(image_pool) - 1

        if resolution not in self.invalidate[obj, instance]:  # flag invalid images
            self.invalidate[obj, instance][resolution] = [False for _ in range(len(image_pool))]

        # decide now if we mask the bg
        mask_bg = (self.mask_bg == True) or (self.mask_bg == 'rand' and rng.choice(2))

        views = []
        # add a bit of randomness: shift each index by an offset drawn with
        # rng.integers(-4, 5), then clamp to [0, last]
        imgs_idxs = [max(0, min(im_idx + rng.integers(-4, 5), last)) for im_idx in [im2_idx, im1_idx]]
        # pop() takes from the right end, so the entry derived from im1_idx is handled first
        imgs_idxs = deque(imgs_idxs)
        while len(imgs_idxs) > 0:  # some images (few) have zero depth
            im_idx = imgs_idxs.pop()

            if self.invalidate[obj, instance][resolution][im_idx]:
                # search for a valid image, walking the pool in a random direction
                # (with wrap-around) until an index not flagged invalid is found
                random_direction = 2 * rng.choice(2) - 1
                for offset in range(1, len(image_pool)):
                    tentative_im_idx = (im_idx + (random_direction * offset)) % len(image_pool)
                    if not self.invalidate[obj, instance][resolution][tentative_im_idx]:
                        im_idx = tentative_im_idx
                        break

            view_idx = image_pool[im_idx]

            impath = self._get_impath(obj, instance, view_idx)
            depthpath = self._get_depthpath(obj, instance, view_idx)

            # load camera params
            metadata_path = self._get_metadatapath(obj, instance, view_idx)
            input_metadata = np.load(metadata_path)
            camera_pose = input_metadata['camera_pose'].astype(np.float32)
            intrinsics = input_metadata['camera_intrinsics'].astype(np.float32)

            # load image and depth
            rgb_image = imread_cv2(impath)
            depthmap = self._read_depthmap(depthpath, input_metadata)

            if mask_bg:
                # load object mask
                maskpath = self._get_maskpath(obj, instance, view_idx)
                maskmap = imread_cv2(maskpath, cv2.IMREAD_UNCHANGED).astype(np.float32)
                maskmap = (maskmap / 255.0) > 0.1

                # update the depthmap with mask
                depthmap *= maskmap

            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath)

            num_valid = (depthmap > 0.0).sum()
            if num_valid == 0:
                # problem, invalidate image and retry: the same index is pushed back and
                # the search above picks a replacement on the next iteration
                self.invalidate[obj, instance][resolution][im_idx] = True
                imgs_idxs.append(im_idx)
                continue

            views.append(dict(
                img=rgb_image,
                depthmap=depthmap,
                camera_pose=camera_pose,
                camera_intrinsics=intrinsics,
                dataset=self.dataset_label,
                label=osp.join(obj, instance),
                instance=osp.split(impath)[1],
            ))
        return views
class Habitat(BaseStereoViewDataset):
    """ Image pairs from the preprocessed Habitat renderings stored under ROOT.

    The scene list is read from `Habitat_{size}_scenes_{split}.txt`. Each item
    pairs view 0 of a scene with one other view picked at random.

    Args:
        size: value interpolated into the scene-list file name.
        ROOT: directory holding the preprocessed data.
    """

    def __init__(self, size, *args, ROOT, **kwargs):
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)
        assert self.split is not None
        # loading list of scenes
        with open(osp.join(self.ROOT, f'Habitat_{size}_scenes_{self.split}.txt')) as f:
            self.scenes = f.read().splitlines()
        # candidate indices for the second view (the first view is always index 0)
        self.instances = list(range(1, 5))

    def filter_scene(self, label, instance=None):
        """ Keep only the scenes whose name starts with `label`.

        When `instance` ("<subscene>_<n>") is given, the subscene is appended to
        the label and the second view is pinned to index n - 1.

        Raises:
            AssertionError: if no scene matches.
        """
        if instance:
            subscene, instance = instance.split('_')
            label += '/' + subscene
            self.instances = [int(instance) - 1]
        valid = [scene.startswith(label) for scene in self.scenes]
        # the message was missing its f-prefix and printed the placeholders literally
        assert any(valid), f'no scene was selected for {label=} {instance=}'
        self.scenes = [scene for scene, keep in zip(self.scenes, valid) if keep]

    def _get_views(self, idx, resolution, rng):
        """ Load view 0 and one random other view of scene `idx`. """
        scene = self.scenes[idx]
        data_path, key = osp.split(osp.join(self.ROOT, scene))
        views = []
        two_random_views = [0, rng.choice(self.instances)]  # view 0 is connected with all other views
        for view_index in two_random_views:
            # load the view (and use the next one if this one's broken):
            # try up to five indices, wrapping modulo 5, until the pose is all finite
            for ii in range(view_index, view_index + 5):
                image, depthmap, intrinsics, camera_pose = self._load_one_view(data_path, key, ii % 5, resolution, rng)
                if np.isfinite(camera_pose).all():
                    break
            views.append(dict(
                img=image,
                depthmap=depthmap,
                camera_pose=camera_pose,  # cam2world
                camera_intrinsics=intrinsics,
                dataset='Habitat',
                label=osp.relpath(data_path, self.ROOT),
                # NOTE: this is the requested index, even if a fallback view was loaded
                instance=f"{key}_{view_index}"))
        return views

    def _load_one_view(self, data_path, key, view_index, resolution, rng):
        """ Load image, depthmap, intrinsics and cam2world pose of one view (0-based index). """
        view_index += 1  # file indices starts at 1
        impath = osp.join(data_path, f"{key}_{view_index}.jpeg")
        image = Image.open(impath)

        depthmap_filename = osp.join(data_path, f"{key}_{view_index}_depth.exr")
        depthmap = cv2.imread(depthmap_filename, cv2.IMREAD_GRAYSCALE | cv2.IMREAD_ANYDEPTH)

        camera_params_filename = osp.join(data_path, f"{key}_{view_index}_camera_params.json")
        with open(camera_params_filename, 'r') as f:
            camera_params = json.load(f)

        intrinsics = np.float32(camera_params['camera_intrinsics'])
        # assemble the 4x4 cam2world matrix from rotation and translation
        camera_pose = np.eye(4, dtype=np.float32)
        camera_pose[:3, :3] = camera_params['R_cam2world']
        camera_pose[:3, 3] = camera_params['t_cam2world']

        image, depthmap, intrinsics = self._crop_resize_if_necessary(
            image, depthmap, intrinsics, resolution, rng, info=impath)
        return image, depthmap, intrinsics, camera_pose
class MegaDepth(BaseStereoViewDataset):
    """ Image pairs from the preprocessed MegaDepth data stored under ROOT.

    Args:
        split: None (all pairs), 'train' (everything except scenes starting
            with '0015'/'0022') or 'val' (only those scenes).
        ROOT: directory holding `all_metadata.npz` and the scene folders.

    Raises:
        ValueError: if `split` is not None, 'train' or 'val'.
    """

    def __init__(self, *args, split, ROOT, **kwargs):
        self.ROOT = ROOT
        # `split` is a named parameter here, hence absent from **kwargs: forward it
        # explicitly, otherwise self.split never receives the requested value and
        # the train/val selection below is silently skipped
        super().__init__(*args, split=split, **kwargs)
        # _load_data fills attributes and returns None; the attribute is kept as-is
        self.loaded_data = self._load_data(self.split)

        if self.split is None:
            pass
        elif self.split == 'train':
            self.select_scene(('0015', '0022'), opposite=True)
        elif self.split == 'val':
            self.select_scene(('0015', '0022'))
        else:
            raise ValueError(f'bad {self.split=}')

    def _load_data(self, split):
        """ Read scene names, image names and the pair table from `all_metadata.npz`. """
        with np.load(osp.join(self.ROOT, 'all_metadata.npz')) as data:
            self.all_scenes = data['scenes']
            self.all_images = data['images']
            self.pairs = data['pairs']

    def __len__(self):
        return len(self.pairs)

    def get_stats(self):
        return f'{len(self)} pairs from {len(self.all_scenes)} scenes'

    def select_scene(self, scene, *instances, opposite=False):
        """ Restrict `self.pairs` to the given scene(s) and, optionally, image(s).

        Args:
            scene: one scene-name prefix, or an iterable of prefixes.
            *instances: image-name prefixes. With exactly two, both images of a
                pair must match; otherwise at least one must.
            opposite: keep the complement of the selection instead.

        Raises:
            AssertionError: if nothing matches or the selection ends up empty.
        """
        scenes = (scene,) if isinstance(scene, str) else tuple(scene)
        scene_id = [s.startswith(scenes) for s in self.all_scenes]
        assert any(scene_id), 'no scene found'

        # np.isin replaces the deprecated np.in1d (same result on 1-D input)
        valid = np.isin(self.pairs['scene_id'], np.nonzero(scene_id)[0])
        if instances:
            image_id = [i.startswith(instances) for i in self.all_images]
            image_id = np.nonzero(image_id)[0]
            assert len(image_id), 'no instance found'
            # both together?
            if len(instances) == 2:
                valid &= np.isin(self.pairs['im1_id'], image_id) & np.isin(self.pairs['im2_id'], image_id)
            else:
                valid &= np.isin(self.pairs['im1_id'], image_id) | np.isin(self.pairs['im2_id'], image_id)

        if opposite:
            valid = ~valid
        assert valid.any()
        self.pairs = self.pairs[valid]

    def _get_views(self, pair_idx, resolution, rng):
        """ Load the two views of pair `pair_idx`, cropped/resized to `resolution`.

        Raises:
            OSError: if an image, depthmap or camera file cannot be loaded.
        """
        scene_id, im1_id, im2_id, score = self.pairs[pair_idx]

        # each scene entry is "<scene> <subscene>", separated by whitespace
        scene, subscene = self.all_scenes[scene_id].split()
        seq_path = osp.join(self.ROOT, scene, subscene)

        views = []

        for im_id in [im1_id, im2_id]:
            img = self.all_images[im_id]
            try:
                image = imread_cv2(osp.join(seq_path, img + '.jpg'))
                depthmap = imread_cv2(osp.join(seq_path, img + ".exr"))
                camera_params = np.load(osp.join(seq_path, img + ".npz"))
            except Exception as e:
                # keep the original traceback attached to the re-raised error
                raise OSError(f'cannot load {img}, got exception {e}') from e

            intrinsics = np.float32(camera_params['intrinsics'])
            camera_pose = np.float32(camera_params['cam2world'])

            image, depthmap, intrinsics = self._crop_resize_if_necessary(
                image, depthmap, intrinsics, resolution, rng, info=(seq_path, img))

            views.append(dict(
                img=image,
                depthmap=depthmap,
                camera_pose=camera_pose,  # cam2world
                camera_intrinsics=intrinsics,
                dataset='MegaDepth',
                label=osp.relpath(seq_path, self.ROOT),
                instance=img))

        return views
class ScanNetpp(BaseStereoViewDataset):
    """ Image pairs from the preprocessed scannet++ data stored under ROOT.

    Only the 'train' split is accepted.
    """

    def __init__(self, *args, ROOT, **kwargs):
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)
        assert self.split == 'train'
        self.loaded_data = self._load_data()

    def _load_data(self):
        """ Fill the per-image tables and the pair list from `all_metadata.npz`. """
        metadata_file = osp.join(self.ROOT, 'all_metadata.npz')
        with np.load(metadata_file) as meta:
            self.scenes = meta['scenes']
            self.sceneids = meta['sceneids']
            self.images = meta['images']
            self.intrinsics = meta['intrinsics'].astype(np.float32)
            self.trajectories = meta['trajectories'].astype(np.float32)
            # only the first two columns (the two image indices) are kept
            self.pairs = meta['pairs'][:, :2].astype(int)

    def __len__(self):
        return len(self.pairs)

    def _get_views(self, idx, resolution, rng):
        """ Load both views of pair `idx`, cropped/resized to `resolution`. """
        first_idx, second_idx = self.pairs[idx]

        views = []
        for view_idx in (first_idx, second_idx):
            scene_id = self.sceneids[view_idx]
            scene_name = self.scenes[scene_id]
            scene_dir = osp.join(self.ROOT, scene_name)

            basename = self.images[view_idx]
            camera_pose = self.trajectories[view_idx]
            intrinsics = self.intrinsics[view_idx]

            # colour image
            rgb_path = osp.join(scene_dir, 'images', basename + '.jpg')
            rgb_image = imread_cv2(rgb_path)

            # depthmap: stored values are divided by 1000, non-finite entries become 0 (invalid)
            depth_path = osp.join(scene_dir, 'depth', basename + '.png')
            depthmap = imread_cv2(depth_path, cv2.IMREAD_UNCHANGED)
            depthmap = depthmap.astype(np.float32) / 1000
            not_finite = ~np.isfinite(depthmap)
            depthmap[not_finite] = 0

            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx)

            view = dict(
                img=rgb_image,
                depthmap=depthmap.astype(np.float32),
                camera_pose=camera_pose.astype(np.float32),
                camera_intrinsics=intrinsics.astype(np.float32),
                dataset='ScanNet++',
                label=self.scenes[scene_id] + '_' + basename,
                instance=f'{str(idx)}_{str(view_idx)}',
            )
            views.append(view)
        return views
class StaticThings3D(BaseStereoViewDataset):
    """ Image pairs from the preprocessed StaticThings3D data stored under ROOT.

    `mask_bg` is True, False or 'rand' (decide per item with the rng); when
    active, depth values above 200 are zeroed. No split is supported.
    """

    def __init__(self, ROOT, *args, mask_bg='rand', **kwargs):
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)

        assert mask_bg in (True, False, 'rand')
        self.mask_bg = mask_bg

        # the whole pair table is loaded; there is no split handling
        assert self.split is None
        pairs_file = osp.join(ROOT, 'staticthings_pairs.npy')
        self.pairs = np.load(pairs_file)

    def __len__(self):
        return len(self.pairs)

    def get_stats(self):
        return f'{len(self)} pairs'

    def _get_views(self, pair_idx, resolution, rng):
        """ Load both views of pair `pair_idx`, cropped/resized to `resolution`. """
        scene, seq, cam1, im1, cam2, im2 = self.pairs[pair_idx]
        seq_path = osp.join('TRAIN', scene.decode('ascii'), f'{seq:04d}')

        # one background-masking decision shared by both views
        mask_bg = (self.mask_bg == True) or (self.mask_bg == 'rand' and rng.choice(2))

        # camera codes of the pair table mapped to folder names
        CAM = {b'l': 'left', b'r': 'right'}
        selected = [(CAM[cam1], im1), (CAM[cam2], im2)]

        views = []
        for cam, idx in selected:
            num = f"{idx:04n}"
            # pick one of the two stored renderings at random
            if rng.choice(2):
                img = num + "_clean.jpg"
            else:
                img = num + "_final.jpg"

            image = imread_cv2(osp.join(self.ROOT, seq_path, cam, img))
            depthmap = imread_cv2(osp.join(self.ROOT, seq_path, cam, num + ".exr"))
            camera_params = np.load(osp.join(self.ROOT, seq_path, cam, num + ".npz"))

            intrinsics = camera_params['intrinsics']
            camera_pose = camera_params['cam2world']

            if mask_bg:
                far_away = depthmap > 200
                depthmap[far_away] = 0

            image, depthmap, intrinsics = self._crop_resize_if_necessary(
                image, depthmap, intrinsics, resolution, rng, info=(seq_path, cam, img))

            views.append(dict(
                img=image,
                depthmap=depthmap,
                camera_pose=camera_pose,  # cam2world
                camera_intrinsics=intrinsics,
                dataset='StaticThings3D',
                label=seq_path,
                instance=cam + '_' + img))

        return views
class ImageList:
    """ Thin wrapper that applies the same PIL operation to every image of a set.

    Accepts a single image or a tuple/list/set of images; anything that is not
    already a PIL image is converted with `PIL.Image.fromarray`.
    """

    def __init__(self, images):
        is_collection = isinstance(images, (tuple, list, set))
        candidates = images if is_collection else [images]
        self.images = [
            item if isinstance(item, PIL.Image.Image) else PIL.Image.fromarray(item)
            for item in candidates
        ]

    def __len__(self):
        return len(self.images)

    def to_pil(self):
        """ Return the lone image, or a tuple of images when there are several. """
        if len(self.images) > 1:
            return tuple(self.images)
        return self.images[0]

    @property
    def size(self):
        """ Size shared by all images (asserts that they agree). """
        sizes = [img.size for img in self.images]
        reference = sizes[0]
        assert all(reference == current for current in sizes)
        return reference

    def resize(self, *args, **kwargs):
        resized = self._dispatch('resize', *args, **kwargs)
        return ImageList(resized)

    def crop(self, *args, **kwargs):
        cropped = self._dispatch('crop', *args, **kwargs)
        return ImageList(cropped)

    def _dispatch(self, func, *args, **kwargs):
        """ Call method `func` on every image and collect the results in a list. """
        results = []
        for img in self.images:
            method = getattr(img, func)
            results.append(method(*args, **kwargs))
        return results
def camera_matrix_of_crop(input_camera_matrix, input_resolution, output_resolution, scaling=1, offset_factor=0.5, offset=None):
    """ Intrinsics after rescaling by `scaling` and cropping to `output_resolution`.

    The crop origin is `offset` when given; otherwise it is `offset_factor`
    times the margin left after scaling. The scaled input must be at least as
    large as the output (asserted).
    """
    # room left around the crop once the input has been rescaled
    spare = np.asarray(input_resolution) * scaling - output_resolution
    assert np.all(spare >= 0.0)

    crop_origin = offset_factor * spare if offset is None else offset

    # scale and shift in the colmap convention, then convert back
    colmap_matrix = opencv_to_colmap_intrinsics(input_camera_matrix)
    colmap_matrix[:2, :] *= scaling
    colmap_matrix[:2, 2] -= crop_origin

    return colmap_to_opencv_intrinsics(colmap_matrix)
def bbox_from_intrinsics_in_out(input_camera_matrix, output_camera_matrix, output_resolution):
    """ Crop box (left, top, right, bottom) that maps the input intrinsics onto the output ones.

    The top-left corner is the rounded shift of the principal point; the box
    has the size given by `output_resolution` (width, height).
    """
    out_width, out_height = output_resolution
    principal_shift = input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2]
    left, top = np.int32(np.round(principal_shift))
    right = left + out_width
    bottom = top + out_height
    return (left, top, right, bottom)
class Waymo(BaseStereoViewDataset):
    """ Image pairs from the preprocessed Waymo data stored under ROOT.
    """

    def __init__(self, *args, ROOT, **kwargs):
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)
        self._load_data()

    def _load_data(self):
        """ Read scene names, frame names and the pair table from `waymo_pairs.npz`. """
        archive_path = osp.join(self.ROOT, 'waymo_pairs.npz')
        with np.load(archive_path) as archive:
            self.scenes = archive['scenes']
            self.frames = archive['frames']
            # reverse lookup: frame name -> position in `frames`
            self.inv_frames = {name: pos for pos, name in enumerate(archive['frames'])}
            self.pairs = archive['pairs']  # rows of (scene_id, img1_id, img2_id)
            # the largest scene id in the pairs must be the last scene
            assert self.pairs[:, 0].max() == len(self.scenes) - 1

    def __len__(self):
        return len(self.pairs)

    def get_stats(self):
        return f'{len(self)} pairs from {len(self.scenes)} scenes'

    def _get_views(self, pair_idx, resolution, rng):
        """ Load both views of pair `pair_idx`, cropped/resized to `resolution`. """
        scene_id, first_id, second_id = self.pairs[pair_idx]
        seq_path = osp.join(self.ROOT, self.scenes[scene_id])

        views = []
        for frame_id in (first_id, second_id):
            impath = self.frames[frame_id]

            image = imread_cv2(osp.join(seq_path, impath + ".jpg"))
            depthmap = imread_cv2(osp.join(seq_path, impath + ".exr"))
            camera_params = np.load(osp.join(seq_path, impath + ".npz"))

            intrinsics = np.float32(camera_params['intrinsics'])
            camera_pose = np.float32(camera_params['cam2world'])

            image, depthmap, intrinsics = self._crop_resize_if_necessary(
                image, depthmap, intrinsics, resolution, rng, info=(seq_path, impath))

            view = dict(
                img=image,
                depthmap=depthmap,
                camera_pose=camera_pose,  # cam2world
                camera_intrinsics=intrinsics,
                dataset='Waymo',
                label=osp.relpath(seq_path, self.ROOT),
                instance=impath)
            views.append(view)

        return views
dust3r.datasets.base.base_stereo_view_dataset import view_name from dust3r.viz import SceneViz, auto_cam_size from dust3r.utils.image import rgb dataset = Waymo(split='train', ROOT="data/megadepth_processed", resolution=224, aug_crop=16) for idx in np.random.permutation(len(dataset)): views = dataset[idx] assert len(views) == 2 print(idx, view_name(views[0]), view_name(views[1])) viz = SceneViz() poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] cam_size = max(auto_cam_size(poses), 0.001) for view_idx in [0, 1]: pts3d = views[view_idx]['pts3d'] valid_mask = views[view_idx]['valid_mask'] colors = rgb(views[view_idx]['img']) viz.add_pointcloud(pts3d, colors, valid_mask) viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], focal=views[view_idx]['camera_intrinsics'][0, 0], color=(idx * 255, (1 - idx) * 255, 0), image=colors, cam_size=cam_size) viz.show() ================================================ FILE: dust3r/datasets/wildrgbd.py ================================================ # Copyright (C) 2024-present Naver Corporation. All rights reserved. # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
class WildRGBD(Co3d):
    """ Co3d-style pair loader adapted to the WildRGB-D folder layout.

    Only the dataset label, the file locations and the depth decoding differ
    from the parent class.
    """

    def __init__(self, mask_bg=True, *args, ROOT, **kwargs):
        super().__init__(mask_bg, *args, ROOT=ROOT, **kwargs)
        self.dataset_label = 'WildRGBD'

    def _get_metadatapath(self, obj, instance, view_idx):
        filename = f'{view_idx:0>5d}.npz'
        return osp.join(self.ROOT, obj, instance, 'metadata', filename)

    def _get_impath(self, obj, instance, view_idx):
        filename = f'{view_idx:0>5d}.jpg'
        return osp.join(self.ROOT, obj, instance, 'rgb', filename)

    def _get_depthpath(self, obj, instance, view_idx):
        filename = f'{view_idx:0>5d}.png'
        return osp.join(self.ROOT, obj, instance, 'depth', filename)

    def _get_maskpath(self, obj, instance, view_idx):
        filename = f'{view_idx:0>5d}.png'
        return osp.join(self.ROOT, obj, instance, 'masks', filename)

    def _read_depthmap(self, depthpath, input_metadata):
        """ Read a depth image and convert it to meters. """
        # depths are stored with a scale of 1000: dividing the loaded values
        # by 1000 gives depth in meters
        raw_depth = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED)
        return raw_depth.astype(np.float32) / 1000.0
#
# --------------------------------------------------------
# gradio demo
# --------------------------------------------------------
import argparse
import math
import builtins
import datetime
import gradio
import os
import torch
import numpy as np
import functools
import trimesh
import copy
from scipy.spatial.transform import Rotation

from dust3r.inference import inference
from dust3r.image_pairs import make_pairs
from dust3r.utils.image import load_images, rgb
from dust3r.utils.device import to_numpy
from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode

import matplotlib.pyplot as pl


def get_args_parser():
    """Build the command-line parser of the demo.

    ``--local_network`` and ``--server_name`` are mutually exclusive, and exactly
    one of ``--weights`` / ``--model_name`` is required.
    """
    parser = argparse.ArgumentParser()
    parser_url = parser.add_mutually_exclusive_group()
    parser_url.add_argument("--local_network", action='store_true', default=False,
                            help="make app accessible on local network: address will be set to 0.0.0.0")
    parser_url.add_argument("--server_name", type=str, default=None, help="server url, default is 127.0.0.1")
    parser.add_argument("--image_size", type=int, default=512, choices=[512, 224], help="image size")
    parser.add_argument("--server_port", type=int, help=("will start gradio app on this port (if available). "
                                                         "If None, will search for an available port starting at 7860."),
                        default=None)
    parser_weights = parser.add_mutually_exclusive_group(required=True)
    parser_weights.add_argument("--weights", type=str, help="path to the model weights", default=None)
    parser_weights.add_argument("--model_name", type=str, help="name of the model weights",
                                choices=["DUSt3R_ViTLarge_BaseDecoder_512_dpt",
                                         "DUSt3R_ViTLarge_BaseDecoder_512_linear",
                                         "DUSt3R_ViTLarge_BaseDecoder_224_linear"])
    parser.add_argument("--device", type=str, default='cuda', help="pytorch device")
    parser.add_argument("--tmp_dir", type=str, default=None, help="value for tempfile.tempdir")
    parser.add_argument("--silent", action='store_true', default=False, help="silence logs")
    return parser


def set_print_with_timestamp(time_format="%Y-%m-%d %H:%M:%S"):
    """Replace ``builtins.print`` with a version that prefixes a timestamp."""
    builtin_print = builtins.print

    def print_with_timestamp(*args, **kwargs):
        now = datetime.datetime.now()
        formatted_date_time = now.strftime(time_format)

        builtin_print(f'[{formatted_date_time}] ', end='')  # print with time stamp
        builtin_print(*args, **kwargs)

    builtins.print = print_with_timestamp


def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
                                 cam_color=None, as_pointcloud=False,
                                 transparent_cams=False, silent=False):
    """Export points (or meshes) and cameras to ``<outdir>/scene.glb``.

    Returns the path of the written file.
    """
    assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
    pts3d = to_numpy(pts3d)
    imgs = to_numpy(imgs)
    focals = to_numpy(focals)
    cams2world = to_numpy(cams2world)

    scene = trimesh.Scene()

    # full pointcloud
    if as_pointcloud:
        pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
        col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
        pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
        scene.add_geometry(pct)
    else:
        # zip stops at the shortest input: the assert above allows more images
        # than point maps, and indexing by len(imgs) would then overrun pts3d/mask.
        meshes = [pts3d_to_trimesh(img, pts, msk) for img, pts, msk in zip(imgs, pts3d, mask)]
        mesh = trimesh.Trimesh(**cat_meshes(meshes))
        scene.add_geometry(mesh)

    # add each camera
    for i, pose_c2w in enumerate(cams2world):
        if isinstance(cam_color, list):
            camera_edge_color = cam_color[i]
        else:
            camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
        add_scene_cam(scene, pose_c2w, camera_edge_color,
                      None if transparent_cams else imgs[i], focals[i],
                      imsize=imgs[i].shape[1::-1], screen_width=cam_size)

    # express the scene relative to the first camera, with an extra 180° turn about y
    rot = np.eye(4)
    rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
    scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
    outfile = os.path.join(outdir, 'scene.glb')
    if not silent:
        print('(exporting 3D scene to', outfile, ')')
    scene.export(file_obj=outfile)
    return outfile


def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=3, as_pointcloud=False, mask_sky=False,
                            clean_depth=False, transparent_cams=False, cam_size=0.05):
    """
    extract 3D_model (glb file) from a reconstructed scene

    Returns None when ``scene`` is None, otherwise the path of the exported file.
    """
    if scene is None:
        return None
    # post processes
    if clean_depth:
        scene = scene.clean_pointcloud()
    if mask_sky:
        scene = scene.mask_sky()

    # get optimized values from scene
    rgbimg = scene.imgs
    focals = scene.get_focals().cpu()
    cams2world = scene.get_im_poses().cpu()
    # 3D pointcloud from depthmap, poses and intrinsics
    pts3d = to_numpy(scene.get_pts3d())
    # the threshold must be set before the masks are queried
    scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
    msk = to_numpy(scene.get_masks())
    return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
                                        transparent_cams=transparent_cams, cam_size=cam_size, silent=silent)


def get_reconstructed_scene(outdir, model, device, silent, image_size, filelist, schedule, niter, min_conf_thr,
                            as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
                            scenegraph_type, winsize, refid):
    """
    from a list of images, run dust3r inference, global aligner.
    then run get_3D_model_from_scene

    Returns ``(scene, outfile, imgs)`` where ``imgs`` interleaves, per view, the
    rgb image, the normalized depth and the colour-mapped confidence.
    """
    # models that do not expose `square_ok` are treated as not supporting it
    square_ok = getattr(model, 'square_ok', False)
    imgs = load_images(filelist, size=image_size, verbose=not silent, patch_size=model.patch_size, square_ok=square_ok)
    if len(imgs) == 1:
        # a single image is paired with a copy of itself
        imgs = [imgs[0], copy.deepcopy(imgs[0])]
        imgs[1]['idx'] = 1
    if scenegraph_type == "swin":
        scenegraph_type = scenegraph_type + "-" + str(winsize)
    elif scenegraph_type == "oneref":
        scenegraph_type = scenegraph_type + "-" + str(refid)

    pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True)
    output = inference(pairs, model, device, batch_size=1, verbose=not silent)

    mode = GlobalAlignerMode.PointCloudOptimizer if len(imgs) > 2 else GlobalAlignerMode.PairViewer
    scene = global_aligner(output, device=device, mode=mode, verbose=not silent)
    lr = 0.01

    if mode == GlobalAlignerMode.PointCloudOptimizer:
        # the returned loss value is not needed here
        scene.compute_global_alignment(init='mst', niter=niter, schedule=schedule, lr=lr)

    outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
                                      clean_depth, transparent_cams, cam_size)

    # also return rgb, depth and confidence imgs
    # depth is normalized with the max value for all images
    # we apply the jet colormap on the confidence maps
    rgbimg = scene.imgs
    depths = to_numpy(scene.get_depthmaps())
    confs = to_numpy([c for c in scene.im_conf])
    cmap = pl.get_cmap('jet')
    depths_max = max(d.max() for d in depths)
    depths = [d / depths_max for d in depths]
    confs_max = max(d.max() for d in confs)
    confs = [cmap(d / confs_max) for d in confs]

    imgs = []
    for color_img, depth_img, conf_img in zip(rgbimg, depths, confs):
        imgs.extend((color_img, rgb(depth_img), rgb(conf_img)))

    return scene, outfile, imgs


def set_scenegraph_options(inputfiles, winsize, refid, scenegraph_type):
    """Rebuild the window-size and reference-id sliders for the current inputs.

    Only the slider matching ``scenegraph_type`` is visible ("swin" shows the
    window size, "oneref" shows the reference id). The incoming ``winsize`` and
    ``refid`` values are ignored; both sliders are reset.
    """
    # None or an empty upload list both count as one file, otherwise the refid
    # slider would be created with maximum=-1
    num_files = len(inputfiles) if inputfiles else 1
    max_winsize = max(1, math.ceil((num_files - 1) / 2))
    winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize,
                            minimum=1, maximum=max_winsize, step=1,
                            visible=(scenegraph_type == "swin"))
    refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
                          maximum=num_files - 1, step=1,
                          visible=(scenegraph_type == "oneref"))
    return winsize, refid


def main_demo(tmpdirname, model, device, image_size, server_name, server_port, silent=False):
    """Build the gradio interface and launch it (the call blocks in ``demo.launch``)."""
    recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, model, device, silent, image_size)
    model_from_scene_fun = functools.partial(get_3D_model_from_scene, tmpdirname, silent)
    with gradio.Blocks(css=""".gradio-container {margin: 0 !important; min-width: 100%};""", title="DUSt3R Demo") as demo:
        # scene state is save so that you can change conf_thr, cam_size... without rerunning the inference
        scene = gradio.State(None)
        # NOTE(review): the banner markup looks stripped in this copy of the file;
        # the visible text is kept verbatim - confirm against the upstream source.
        gradio.HTML('''

DUSt3R Demo

''')
        with gradio.Column():
            inputfiles = gradio.File(file_count="multiple")
            with gradio.Row():
                schedule = gradio.Dropdown(["linear", "cosine"],
                                           value='linear', label="schedule", info="For global alignment!")
                niter = gradio.Number(value=300, precision=0, minimum=0, maximum=5000,
                                      label="num_iterations", info="For global alignment!")
                scenegraph_type = gradio.Dropdown([("complete: all possible image pairs", "complete"),
                                                   ("swin: sliding window", "swin"),
                                                   ("oneref: match one image with all", "oneref")],
                                                  value='complete', label="Scenegraph",
                                                  info="Define how to make pairs",
                                                  interactive=True)
                winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
                                        minimum=1, maximum=1, step=1, visible=False)
                refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False)

            run_btn = gradio.Button("Run")

            with gradio.Row():
                # adjust the confidence threshold
                min_conf_thr = gradio.Slider(label="min_conf_thr", value=3.0, minimum=1.0, maximum=20, step=0.1)
                # adjust the camera size in the output pointcloud
                cam_size = gradio.Slider(label="cam_size", value=0.05, minimum=0.001, maximum=0.1, step=0.001)
            with gradio.Row():
                as_pointcloud = gradio.Checkbox(value=False, label="As pointcloud")
                # two post process implemented
                mask_sky = gradio.Checkbox(value=False, label="Mask sky")
                clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps")
                transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras")

            outmodel = gradio.Model3D()
            outgallery = gradio.Gallery(label='rgb,depth,confidence', columns=3, height="100%")

            # events
            graph_inputs = [inputfiles, winsize, refid, scenegraph_type]
            scenegraph_type.change(set_scenegraph_options,
                                   inputs=graph_inputs,
                                   outputs=[winsize, refid])
            inputfiles.change(set_scenegraph_options,
                              inputs=graph_inputs,
                              outputs=[winsize, refid])
            run_btn.click(fn=recon_fun,
                          inputs=[inputfiles, schedule, niter, min_conf_thr, as_pointcloud,
                                  mask_sky, clean_depth, transparent_cams, cam_size,
                                  scenegraph_type, winsize, refid],
                          outputs=[scene, outmodel, outgallery])
            # every display option re-exports the model from the cached scene
            model_inputs = [scene, min_conf_thr, as_pointcloud, mask_sky,
                            clean_depth, transparent_cams, cam_size]
            min_conf_thr.release(fn=model_from_scene_fun,
                                 inputs=model_inputs,
                                 outputs=outmodel)
            for option in (cam_size, as_pointcloud, mask_sky, clean_depth, transparent_cams):
                option.change(fn=model_from_scene_fun,
                              inputs=model_inputs,
                              outputs=outmodel)
    demo.launch(share=False, server_name=server_name, server_port=server_port)

# ================================================
# FILE: dust3r/heads/__init__.py
# ================================================
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# head factory
# --------------------------------------------------------
from .linear_head import LinearPts3d
from .dpt_head import create_dpt_head


def head_factory(head_type, output_mode, net, has_conf=False):
    """ build a prediction head for the decoder

    Supported combinations are ('linear', 'pts3d') and ('dpt', 'pts3d');
    anything else raises NotImplementedError.
    """
    if head_type == 'linear' and output_mode == 'pts3d':
        return LinearPts3d(net, has_conf)
    elif head_type == 'dpt' and output_mode == 'pts3d':
        return create_dpt_head(net, has_conf=has_conf)
    else:
        raise NotImplementedError(f"unexpected {head_type=} and {output_mode=}")

# ================================================
# FILE: dust3r/heads/dpt_head.py
# ================================================
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# dpt head implementation for DUST3R
# Downstream heads assume inputs of size B x N x C (where N is the number of tokens) ;
# or if it takes as input the output at every layer, the attribute return_all_layers should be set to True
# the forward function also takes as input a dictionary img_info with key "height" and "width"
# for PixelwiseTask, the output will be of dimension B x num_channels x H x W
# --------------------------------------------------------
from einops import rearrange
from typing import List
import torch
import torch.nn as nn
from dust3r.heads.postprocess import postprocess
import dust3r.utils.path_to_croco  # noqa: F401
from models.dpt_block import DPTOutputAdapter  # noqa


class DPTOutputAdapter_fix(DPTOutputAdapter):
    """
    Adapt croco's DPTOutputAdapter implementation for dust3r:
    remove duplicated weights, and fix forward for dust3r
    """

    def init(self, dim_tokens_enc=768):
        super().init(dim_tokens_enc)
        # these are duplicated weights
        del self.act_1_postprocess
        del self.act_2_postprocess
        del self.act_3_postprocess
        del self.act_4_postprocess

    def forward(self, encoder_tokens: List[torch.Tensor], image_size=None):
        # encoder_tokens is indexed by self.hooks; image_size falls back to
        # self.image_size when not given
        assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
        # H, W = input_info['image_size']
        image_size = self.image_size if image_size is None else image_size
        H, W = image_size
        # Number of patches in height and width
        N_H = H // (self.stride_level * self.P_H)
        N_W = W // (self.stride_level * self.P_W)

        # Hook decoder onto 4 layers from specified ViT layers
        layers = [encoder_tokens[hook] for hook in self.hooks]

        # Extract only task-relevant tokens and ignore global tokens.
        layers = [self.adapt_tokens(l) for l in layers]

        # Reshape tokens to spatial representation
        layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers]

        layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]
        # Project layers to chosen feature dim
        layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)]

        # Fuse layers using refinement stages
        # (the first stage output is cropped to the spatial size of layers[2])
        path_4 = self.scratch.refinenet4(layers[3])[:, :, :layers[2].shape[2], :layers[2].shape[3]]
        path_3 = self.scratch.refinenet3(path_4, layers[2])
        path_2 = self.scratch.refinenet2(path_3, layers[1])
        path_1 = self.scratch.refinenet1(path_2, layers[0])

        # Output head
        out = self.head(path_1)

        return out


class PixelwiseTaskWithDPT(nn.Module):
    """ DPT module for dust3r, can return 3D points + confidence for all pixels"""

    def __init__(self, *, n_cls_token=0, hooks_idx=None, dim_tokens=None,
                 output_width_ratio=1, num_channels=1, postprocess=None, depth_mode=None, conf_mode=None, **kwargs):
        super(PixelwiseTaskWithDPT, self).__init__()
        self.return_all_layers = True  # backbone needs to return all layers
        self.postprocess = postprocess
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode

        assert n_cls_token == 0, "Not implemented"
        # remaining keyword arguments are forwarded to the DPT adapter as-is
        dpt_args = dict(output_width_ratio=output_width_ratio,
                        num_channels=num_channels,
                        **kwargs)
        if hooks_idx is not None:
            dpt_args.update(hooks=hooks_idx)
        self.dpt = DPTOutputAdapter_fix(**dpt_args)
        # without dim_tokens, the adapter's own init() default is used
        dpt_init_args = {} if dim_tokens is None else {'dim_tokens_enc': dim_tokens}
        self.dpt.init(**dpt_init_args)

    def forward(self, x, img_info):
        # img_info is indexed as (height, width)
        out = self.dpt(x, image_size=(img_info[0], img_info[1]))
        if self.postprocess:
            out = self.postprocess(out, self.depth_mode, self.conf_mode)
        return out


def create_dpt_head(net, has_conf=False):
    """
    return PixelwiseTaskWithDPT for given net params
    """
    assert net.dec_depth > 9
    l2 = net.dec_depth
    feature_dim = 256
    last_dim = feature_dim//2
    out_nchan = 3
    ed = net.enc_embed_dim
    dd = net.dec_embed_dim
    # one extra output channel is added when has_conf is set
    return PixelwiseTaskWithDPT(num_channels=out_nchan + has_conf,
                                feature_dim=feature_dim,
                                last_dim=last_dim,
                                hooks_idx=[0, l2*2//4, l2*3//4, l2],
                                dim_tokens=[ed, dd, dd, dd],
                                postprocess=postprocess,
                                depth_mode=net.depth_mode,
                                conf_mode=net.conf_mode,
                                head_type='regression')

# ================================================
# FILE: dust3r/heads/linear_head.py
# ================================================
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# linear head implementation for DUST3R
# --------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F
from dust3r.heads.postprocess import postprocess


class LinearPts3d (nn.Module):
    """
    Linear head for dust3r.
    Each token is projected to a patch_size x patch_size block of
    3D points (+ one confidence channel when has_conf is set).
    """

    def __init__(self, net, has_conf=False):
        super().__init__()
        self.patch_size = net.patch_embed.patch_size[0]
        self.depth_mode = net.depth_mode
        self.conf_mode = net.conf_mode
        self.has_conf = has_conf

        out_features = (3 + has_conf) * self.patch_size**2
        self.proj = nn.Linear(net.dec_embed_dim, out_features)

    def setup(self, croconet):
        """No-op hook; this head needs nothing from the backbone."""
        pass

    def forward(self, decout, img_shape):
        """Project the last decoder output to dense per-pixel predictions."""
        height, width = img_shape
        tokens = decout[-1]
        # unpacking three values also requires a 3-dimensional token tensor
        batch, _, _ = tokens.shape

        # one feature vector per token, then tokens laid out on the patch grid
        per_token = self.proj(tokens)
        grid = per_token.transpose(-1, -2).view(
            batch, -1, height // self.patch_size, width // self.patch_size)
        # expand every grid cell into its patch_size x patch_size pixels
        dense = F.pixel_shuffle(grid, self.patch_size)

        # permute + norm depth
        return postprocess(dense, self.depth_mode, self.conf_mode)

# ================================================
# FILE: dust3r/heads/postprocess.py
# ================================================
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# post process function for all heads: extract 3D points/confidence from output
# --------------------------------------------------------
import torch


def postprocess(out, depth_mode, conf_mode):
    """
    extract 3D points/confidence from prediction head output

    The channel axis (dim 1) is moved last; channels 0..2 become 'pts3d' and,
    when conf_mode is not None, channel 3 becomes 'conf'.
    """
    channels_last = out.permute(0, 2, 3, 1)  # B,H,W,C
    result = {'pts3d': reg_dense_depth(channels_last[:, :, :, 0:3], mode=depth_mode)}

    if conf_mode is None:
        return result
    result['conf'] = reg_dense_conf(channels_last[:, :, :, 3], mode=conf_mode)
    return result


def reg_dense_depth(xyz, mode):
    """
    extract 3D points from prediction head output

    mode is a (name, vmin, vmax) triple; name is 'linear', 'square' or 'exp'.
    Finite bounds are rejected by the assertion below.
    """
    mode, lower, upper = mode

    no_bounds = (lower == -float('inf')) and (upper == float('inf'))
    assert no_bounds

    if mode == 'linear':
        # [-inf, +inf] means the raw values are returned untouched
        return xyz if no_bounds else xyz.clip(min=lower, max=upper)

    # split each point into its distance to the origin and a unit direction
    dist = xyz.norm(dim=-1, keepdim=True)
    direction = xyz / dist.clip(min=1e-8)

    if mode == 'square':
        return direction * dist.square()
    if mode == 'exp':
        return direction * torch.expm1(dist)
    raise ValueError(f'bad {mode=}')


def reg_dense_conf(x, mode):
    """
    extract confidence from prediction head output

    mode is a (name, vmin, vmax) triple; name is 'exp' or 'sigmoid'.
    """
    mode, vmin, vmax = mode
    if mode == 'sigmoid':
        return (vmax - vmin) * torch.sigmoid(x) + vmin
    if mode == 'exp':
        return vmin + x.exp().clip(max=vmax-vmin)
    raise ValueError(f'bad {mode=}')

# ================================================
# FILE: dust3r/image_pairs.py
# ================================================
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
# # -------------------------------------------------------- # utilities needed to load image pairs # -------------------------------------------------------- import numpy as np import torch def make_pairs(imgs, scene_graph='complete', prefilter=None, symmetrize=True): pairs = [] if scene_graph == 'complete': # complete graph for i in range(len(imgs)): for j in range(i): pairs.append((imgs[i], imgs[j])) elif scene_graph.startswith('swin'): iscyclic = not scene_graph.endswith('noncyclic') try: winsize = int(scene_graph.split('-')[1]) except Exception as e: winsize = 3 pairsid = set() for i in range(len(imgs)): for j in range(1, winsize + 1): idx = (i + j) if iscyclic: idx = idx % len(imgs) # explicit loop closure if idx >= len(imgs): continue pairsid.add((i, idx) if i < idx else (idx, i)) for i, j in pairsid: pairs.append((imgs[i], imgs[j])) elif scene_graph.startswith('logwin'): iscyclic = not scene_graph.endswith('noncyclic') try: winsize = int(scene_graph.split('-')[1]) except Exception as e: winsize = 3 offsets = [2**i for i in range(winsize)] pairsid = set() for i in range(len(imgs)): ixs_l = [i - off for off in offsets] ixs_r = [i + off for off in offsets] for j in ixs_l + ixs_r: if iscyclic: j = j % len(imgs) # Explicit loop closure if j < 0 or j >= len(imgs) or j == i: continue pairsid.add((i, j) if i < j else (j, i)) for i, j in pairsid: pairs.append((imgs[i], imgs[j])) elif scene_graph.startswith('oneref'): refid = int(scene_graph.split('-')[1]) if '-' in scene_graph else 0 for j in range(len(imgs)): if j != refid: pairs.append((imgs[refid], imgs[j])) if symmetrize: pairs += [(img2, img1) for img1, img2 in pairs] # now, remove edges if isinstance(prefilter, str) and prefilter.startswith('seq'): pairs = filter_pairs_seq(pairs, int(prefilter[3:])) if isinstance(prefilter, str) and prefilter.startswith('cyc'): pairs = filter_pairs_seq(pairs, int(prefilter[3:]), cyclic=True) return pairs def sel(x, kept): if isinstance(x, dict): return {k: sel(v, kept) for k, 
v in x.items()} if isinstance(x, (torch.Tensor, np.ndarray)): return x[kept] if isinstance(x, (tuple, list)): return type(x)([x[k] for k in kept]) def _filter_edges_seq(edges, seq_dis_thr, cyclic=False): # number of images n = max(max(e) for e in edges) + 1 kept = [] for e, (i, j) in enumerate(edges): dis = abs(i - j) if cyclic: dis = min(dis, abs(i + n - j), abs(i - n - j)) if dis <= seq_dis_thr: kept.append(e) return kept def filter_pairs_seq(pairs, seq_dis_thr, cyclic=False): edges = [(img1['idx'], img2['idx']) for img1, img2 in pairs] kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic) return [pairs[i] for i in kept] def filter_edges_seq(view1, view2, pred1, pred2, seq_dis_thr, cyclic=False): edges = [(int(i), int(j)) for i, j in zip(view1['idx'], view2['idx'])] kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic) print(f'>> Filtering edges more than {seq_dis_thr} frames apart: kept {len(kept)}/{len(edges)} edges') return sel(view1, kept), sel(view2, kept), sel(pred1, kept), sel(pred2, kept) ================================================ FILE: dust3r/inference.py ================================================ # Copyright (C) 2024-present Naver Corporation. All rights reserved. # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
#
# --------------------------------------------------------
# utilities needed for the inference
# --------------------------------------------------------
import tqdm
import torch
from dust3r.utils.device import to_cpu, collate_with_cat
from dust3r.utils.misc import invalid_to_nans
from dust3r.utils.geometry import depthmap_to_pts3d, geotrf


def _interleave_imgs(img1, img2):
    """Merge two views key by key, alternating their entries (1, 2, 1, 2, ...)."""
    res = {}
    for key, value1 in img1.items():
        value2 = img2[key]
        if isinstance(value1, torch.Tensor):
            value = torch.stack((value1, value2), dim=1).flatten(0, 1)
        else:
            value = [x for pair in zip(value1, value2) for x in pair]
        res[key] = value
    return res


def make_batch_symmetric(batch):
    """Return (view1, view2) where each pair also appears with its views swapped."""
    view1, view2 = batch
    view1, view2 = (_interleave_imgs(view1, view2), _interleave_imgs(view2, view1))
    return view1, view2


def loss_of_one_batch(batch, model, criterion, device, symmetrize_batch=False, use_amp=False, ret=None):
    """Run the model (and optionally the criterion) on one batch.

    The entries of both views are moved to ``device`` in place, except for the
    keys listed in ``ignore_keys``. Returns a dict with view1, view2, pred1,
    pred2 and loss, or only ``result[ret]`` when ``ret`` is given. ``loss`` is
    None when ``criterion`` is None.
    """
    view1, view2 = batch
    ignore_keys = {'depthmap', 'dataset', 'label', 'instance', 'idx', 'true_shape', 'rng'}
    for view in batch:
        for name in view.keys():  # pseudo_focal
            if name in ignore_keys:
                continue
            view[name] = view[name].to(device, non_blocking=True)

    if symmetrize_batch:
        view1, view2 = make_batch_symmetric(batch)

    with torch.cuda.amp.autocast(enabled=bool(use_amp)):
        pred1, pred2 = model(view1, view2)

        # loss is supposed to be symmetric
        with torch.cuda.amp.autocast(enabled=False):
            loss = criterion(view1, view2, pred1, pred2) if criterion is not None else None

    result = dict(view1=view1, view2=view2, pred1=pred1, pred2=pred2, loss=loss)
    return result[ret] if ret else result


@torch.no_grad()
def inference(pairs, model, device, batch_size=8, verbose=True):
    """Run the model on every pair and return the collated results on CPU."""
    if verbose:
        print(f'>> Inference with model on {len(pairs)} image pairs')
    result = []

    # first, check if all images have the same size
    multiple_shapes = not (check_if_same_size(pairs))
    if multiple_shapes:  # force bs=1
        batch_size = 1

    for i in tqdm.trange(0, len(pairs), batch_size, disable=not verbose):
        res = loss_of_one_batch(collate_with_cat(pairs[i:i + batch_size]), model, None, device)
        result.append(to_cpu(res))

    result = collate_with_cat(result, lists=multiple_shapes)

    return result


def check_if_same_size(pairs):
    """Tell whether all first images share one size and all second images share one size."""
    shapes1 = [img1['img'].shape[-2:] for img1, img2 in pairs]
    shapes2 = [img2['img'].shape[-2:] for img1, img2 in pairs]
    return all(shapes1[0] == s for s in shapes1) and all(shapes2[0] == s for s in shapes2)


def get_pred_pts3d(gt, pred, use_pose=False):
    """Return the predicted 3D points stored in ``pred``.

    Three prediction layouts are handled: 'depth' + 'pseudo_focal', 'pts3d',
    or 'pts3d_in_other_view' (returned as-is, requires ``use_pose``). With
    ``use_pose``, the points are additionally transformed by ``pred['camera_pose']``.

    Raises:
        ValueError: if ``pred`` contains none of the supported entries.
    """
    if 'depth' in pred and 'pseudo_focal' in pred:
        try:
            pp = gt['camera_intrinsics'][..., :2, 2]
        except KeyError:
            pp = None
        pts3d = depthmap_to_pts3d(**pred, pp=pp)

    elif 'pts3d' in pred:
        # pts3d from my camera
        pts3d = pred['pts3d']

    elif 'pts3d_in_other_view' in pred:
        # pts3d from the other camera, already transformed
        assert use_pose is True
        return pred['pts3d_in_other_view']  # return!

    else:
        # previously fell through to an UnboundLocalError on `pts3d`
        raise ValueError(f'no 3D points in prediction, got keys {sorted(pred)}')

    if use_pose:
        camera_pose = pred.get('camera_pose')
        assert camera_pose is not None
        pts3d = geotrf(camera_pose, pts3d)

    return pts3d


def find_opt_scaling(gt_pts1, gt_pts2, pr_pts1, pr_pts2=None, fit_mode='weiszfeld_stop_grad', valid1=None, valid2=None):
    """Estimate, per batch element, the scale s such that pr is close to s * gt.

    ``fit_mode`` starts with 'avg', 'median' or 'weiszfeld'; a 'stop_grad'
    suffix detaches the result. The scale is clipped to be at least 1e-3.

    Raises:
        ValueError: for an unknown ``fit_mode``.
    """
    assert gt_pts1.ndim == pr_pts1.ndim == 4
    assert gt_pts1.shape == pr_pts1.shape
    if gt_pts2 is not None:
        assert gt_pts2.ndim == pr_pts2.ndim == 4
        assert gt_pts2.shape == pr_pts2.shape

    # concat the pointcloud
    nan_gt_pts1 = invalid_to_nans(gt_pts1, valid1).flatten(1, 2)
    nan_gt_pts2 = invalid_to_nans(gt_pts2, valid2).flatten(1, 2) if gt_pts2 is not None else None

    pr_pts1 = invalid_to_nans(pr_pts1, valid1).flatten(1, 2)
    pr_pts2 = invalid_to_nans(pr_pts2, valid2).flatten(1, 2) if pr_pts2 is not None else None

    all_gt = torch.cat((nan_gt_pts1, nan_gt_pts2), dim=1) if gt_pts2 is not None else nan_gt_pts1
    all_pr = torch.cat((pr_pts1, pr_pts2), dim=1) if pr_pts2 is not None else pr_pts1

    dot_gt_pr = (all_pr * all_gt).sum(dim=-1)
    dot_gt_gt = all_gt.square().sum(dim=-1)

    if fit_mode.startswith('avg'):
        # scaling = (all_pr / all_gt).view(B, -1).mean(dim=1)
        scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1)
    elif fit_mode.startswith('median'):
        scaling = (dot_gt_pr / dot_gt_gt).nanmedian(dim=1).values
    elif fit_mode.startswith('weiszfeld'):
        # init scaling with l2 closed form
        scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1)
        # iterative re-weighted least-squares
        for _ in range(10):
            # re-weighting by inverse of distance
            dis = (all_pr - scaling.view(-1, 1, 1) * all_gt).norm(dim=-1)
            # print(dis.nanmean(-1))
            w = dis.clip_(min=1e-8).reciprocal()
            # update the scaling with the new weights
            scaling = (w * dot_gt_pr).nanmean(dim=1) / (w * dot_gt_gt).nanmean(dim=1)
    else:
        raise ValueError(f'bad {fit_mode=}')

    if fit_mode.endswith('stop_grad'):
        scaling = scaling.detach()

    scaling = scaling.clip(min=1e-3)
    # assert scaling.isfinite().all(), bb()
    return scaling

# ================================================
# FILE: dust3r/losses.py
# ================================================
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
# # -------------------------------------------------------- # Implementation of DUSt3R training losses # -------------------------------------------------------- from copy import copy, deepcopy import torch import torch.nn as nn from dust3r.inference import get_pred_pts3d, find_opt_scaling from dust3r.utils.geometry import inv, geotrf, normalize_pointcloud from dust3r.utils.geometry import get_joint_pointcloud_depth, get_joint_pointcloud_center_scale def Sum(*losses_and_masks): loss, mask = losses_and_masks[0] if loss.ndim > 0: # we are actually returning the loss for every pixels return losses_and_masks else: # we are returning the global loss for loss2, mask2 in losses_and_masks[1:]: loss = loss + loss2 return loss class BaseCriterion(nn.Module): def __init__(self, reduction='mean'): super().__init__() self.reduction = reduction class LLoss (BaseCriterion): """ L-norm loss """ def forward(self, a, b): assert a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 3, f'Bad shape = {a.shape}' dist = self.distance(a, b) assert dist.ndim == a.ndim - 1 # one dimension less if self.reduction == 'none': return dist if self.reduction == 'sum': return dist.sum() if self.reduction == 'mean': return dist.mean() if dist.numel() > 0 else dist.new_zeros(()) raise ValueError(f'bad {self.reduction=} mode') def distance(self, a, b): raise NotImplementedError() class L21Loss (LLoss): """ Euclidean distance between 3d points """ def distance(self, a, b): return torch.norm(a - b, dim=-1) # normalized L2 distance L21 = L21Loss() class Criterion (nn.Module): def __init__(self, criterion=None): super().__init__() assert isinstance(criterion, BaseCriterion), f'{criterion} is not a proper criterion!' 
self.criterion = copy(criterion) def get_name(self): return f'{type(self).__name__}({self.criterion})' def with_reduction(self, mode='none'): res = loss = deepcopy(self) while loss is not None: assert isinstance(loss, Criterion) loss.criterion.reduction = mode # make it return the loss for each sample loss = loss._loss2 # we assume loss is a Multiloss return res class MultiLoss (nn.Module): """ Easily combinable losses (also keep track of individual loss values): loss = MyLoss1() + 0.1*MyLoss2() Usage: Inherit from this class and override get_name() and compute_loss() """ def __init__(self): super().__init__() self._alpha = 1 self._loss2 = None def compute_loss(self, *args, **kwargs): raise NotImplementedError() def get_name(self): raise NotImplementedError() def __mul__(self, alpha): assert isinstance(alpha, (int, float)) res = copy(self) res._alpha = alpha return res __rmul__ = __mul__ # same def __add__(self, loss2): assert isinstance(loss2, MultiLoss) res = cur = copy(self) # find the end of the chain while cur._loss2 is not None: cur = cur._loss2 cur._loss2 = loss2 return res def __repr__(self): name = self.get_name() if self._alpha != 1: name = f'{self._alpha:g}*{name}' if self._loss2: name = f'{name} + {self._loss2}' return name def forward(self, *args, **kwargs): loss = self.compute_loss(*args, **kwargs) if isinstance(loss, tuple): loss, details = loss elif loss.ndim == 0: details = {self.get_name(): float(loss)} else: details = {} loss = loss * self._alpha if self._loss2: loss2, details2 = self._loss2(*args, **kwargs) loss = loss + loss2 details |= details2 return loss, details class Regr3D (Criterion, MultiLoss): """ Ensure that all 3D points are correct. Asymmetric loss: view1 is supposed to be the anchor. 
class ConfLoss (MultiLoss):
    """ Confidence-weighted regression loss.

    Wraps a pixel-level regression loss and re-weights each valid pixel by
    the predicted confidence:

        conf_loss = pixel_loss * conf - alpha * log(conf)

    e.g. conf = 0.1 ==> conf_loss = x / 10 + alpha*log(10)
         conf = 10  ==> conf_loss = x * 10 - alpha*log(10)

    alpha: strictly positive hyperparameter weighting the log-confidence term.
    """

    def __init__(self, pixel_loss, alpha=1):
        super().__init__()
        assert alpha > 0
        self.alpha = alpha
        # the wrapped loss must return un-reduced (per-pixel) values
        self.pixel_loss = pixel_loss.with_reduction('none')

    def get_name(self):
        return f'ConfLoss({self.pixel_loss})'

    def get_conf_log(self, x):
        """ Return the confidence and its logarithm. """
        return x, torch.log(x)

    def compute_loss(self, gt1, gt2, pred1, pred2, **kw):
        """ Return (loss, details) where loss sums the mean weighted loss of both views. """
        # per-pixel regression loss and validity mask of each view
        (view1, view2), details = self.pixel_loss(gt1, gt2, pred1, pred2, **kw)
        pix1, valid1 = view1
        pix2, valid2 = view2

        # NOTE(review): `force=True` is not a builtin print() keyword; presumably
        # print is patched elsewhere (distributed setup) - confirm.
        if pix1.numel() == 0:
            print('NO VALID POINTS in img1', force=True)
        if pix2.numel() == 0:
            print('NO VALID POINTS in img2', force=True)

        # confidence weighting, restricted to valid pixels
        c1, log_c1 = self.get_conf_log(pred1['conf'][valid1])
        c2, log_c2 = self.get_conf_log(pred2['conf'][valid2])
        weighted1 = pix1 * c1 - self.alpha * log_c1
        weighted2 = pix2 * c2 - self.alpha * log_c2

        # mean over pixels; fall back to 0 when a view has no valid pixel
        # (the mean of an empty tensor would be nan)
        if weighted1.numel() > 0:
            avg1 = weighted1.mean()
        else:
            avg1 = 0
        if weighted2.numel() > 0:
            avg2 = weighted2.mean()
        else:
            avg2 = 0

        monitoring = dict(conf_loss_1=float(avg1), conf_loss2=float(avg2), **details)
        return avg1 + avg2, monitoring
class Regr3D_ScaleInv (Regr3D):
    """ Same than Regr3D but invariant to scene scale.
        if gt_scale == True: enforce the prediction to take the same scale than GT
        otherwise: ground-truth and prediction are each divided by their own scale.
    """

    def get_all_pts3d(self, gt1, gt2, pred1, pred2):
        # get the points from the parent class in the MRO (Regr3D, or
        # Regr3D_ShiftInv when used through Regr3D_ScaleShiftInv)
        gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring = super().get_all_pts3d(gt1, gt2, pred1, pred2)

        # measure scene scale (jointly over both views, valid pixels only -
        # NOTE(review): exact definition lives in get_joint_pointcloud_center_scale)
        _, gt_scale = get_joint_pointcloud_center_scale(gt_pts1, gt_pts2, mask1, mask2)
        _, pred_scale = get_joint_pointcloud_center_scale(pred_pts1, pred_pts2, mask1, mask2)

        # prevent predictions to be in a ridiculous range
        pred_scale = pred_scale.clip(min=1e-3, max=1e3)

        # rescale the point clouds (in-place on the tensors returned by the parent)
        if self.gt_scale:
            # bring the prediction to the ground-truth scale; GT is left untouched
            pred_pts1 *= gt_scale / pred_scale
            pred_pts2 *= gt_scale / pred_scale
            # monitoring = dict(monitoring, pred_scale=(pred_scale/gt_scale).mean())
        else:
            # normalize GT and prediction independently by their own scale
            gt_pts1 /= gt_scale
            gt_pts2 /= gt_scale
            pred_pts1 /= pred_scale
            pred_pts2 /= pred_scale
            # monitoring = dict(monitoring, gt_scale=gt_scale.mean(), pred_scale=pred_scale.mean().detach())

        return gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring
class AsymmetricCroCo3DStereo (
    CroCoNet,
    huggingface_hub.PyTorchModelHubMixin,
    library_name="dust3r",
    repo_url="https://github.com/naver/dust3r",
    tags=["image-to-3d"],
):
    """ Two siamese encoders, followed by two decoders.
    The goal is to output 3d points directly, both images in view1's frame
    (hence the asymmetry).
    """

    def __init__(self,
                 output_mode='pts3d',
                 head_type='linear',
                 depth_mode=('exp', -inf, inf),
                 conf_mode=('exp', 1, inf),
                 freeze='none',
                 landscape_only=True,
                 patch_embed_cls='PatchEmbedDust3R',  # PatchEmbedDust3R or ManyAR_PatchEmbed
                 **croco_kwargs):
        # must be set before super().__init__(): it is read by _set_patch_embed()
        # (presumably invoked by the CroCoNet constructor - confirm)
        self.patch_embed_cls = patch_embed_cls
        self.croco_args = fill_default_args(croco_kwargs, super().__init__)
        super().__init__(**croco_kwargs)

        # dust3r specific initialization
        self.dec_blocks2 = deepcopy(self.dec_blocks)
        self.set_downstream_head(output_mode, head_type, landscape_only, depth_mode, conf_mode, **croco_kwargs)
        self.set_freeze(freeze)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kw):
        """ Load a model from a local checkpoint file, or else through huggingface_hub.

        Raises:
            Exception: when the huggingface loading path fails with a TypeError
                (the original error is chained as the cause).
        """
        if os.path.isfile(pretrained_model_name_or_path):
            return load_model(pretrained_model_name_or_path, device='cpu')
        else:
            try:
                model = super(AsymmetricCroCo3DStereo, cls).from_pretrained(pretrained_model_name_or_path, **kw)
            except TypeError as e:
                # keep the root cause attached to the re-raised exception
                raise Exception(f'tried to load {pretrained_model_name_or_path} from huggingface, but failed') from e
            return model

    def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
        self.patch_size = patch_size
        self.patch_embed = get_patch_embed(self.patch_embed_cls, img_size, patch_size, enc_embed_dim)

    def load_state_dict(self, ckpt, **kw):
        """ Same as the parent load_state_dict, but fills in the second decoder if missing. """
        # duplicate all weights for the second decoder if not present
        new_ckpt = dict(ckpt)
        if not any(k.startswith('dec_blocks2') for k in ckpt):
            for key, value in ckpt.items():
                if key.startswith('dec_blocks'):
                    new_ckpt[key.replace('dec_blocks', 'dec_blocks2')] = value
        return super().load_state_dict(new_ckpt, **kw)

    def set_freeze(self, freeze):  # this is for use by downstream models
        """ Freeze a group of parameters; `freeze` is one of 'none', 'mask', 'encoder'. """
        self.freeze = freeze
        to_be_frozen = {
            'none': [],
            'mask': [self.mask_token],
            'encoder': [self.mask_token, self.patch_embed, self.enc_blocks],
        }
        freeze_all_params(to_be_frozen[freeze])

    def _set_prediction_head(self, *args, **kwargs):
        """ No prediction head """
        return

    def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode, conf_mode, patch_size, img_size,
                            **kw):
        """ Create the two output heads (one per view) and their orientation wrappers. """
        assert img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0, \
            f'{img_size=} must be multiple of {patch_size=}'
        self.output_mode = output_mode
        self.head_type = head_type
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode
        # allocate heads
        self.downstream_head1 = head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))
        self.downstream_head2 = head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))
        # magic wrapper
        self.head1 = transpose_to_landscape(self.downstream_head1, activate=landscape_only)
        self.head2 = transpose_to_landscape(self.downstream_head2, activate=landscape_only)

    def _encode_image(self, image, true_shape):
        """ Run the encoder on a batch of images; returns (tokens, positions, None). """
        # embed the image into patches  (x has size B x Npatches x C)
        x, pos = self.patch_embed(image, true_shape=true_shape)

        # add positional embedding without cls token
        assert self.enc_pos_embed is None

        # now apply the transformer encoder and normalization
        for blk in self.enc_blocks:
            x = blk(x, pos)

        x = self.enc_norm(x)
        return x, pos, None

    def _encode_image_pairs(self, img1, img2, true_shape1, true_shape2):
        """ Encode both images, in a single batched pass when their spatial sizes match. """
        if img1.shape[-2:] == img2.shape[-2:]:
            out, pos, _ = self._encode_image(torch.cat((img1, img2), dim=0),
                                             torch.cat((true_shape1, true_shape2), dim=0))
            out, out2 = out.chunk(2, dim=0)
            pos, pos2 = pos.chunk(2, dim=0)
        else:
            out, pos, _ = self._encode_image(img1, true_shape1)
            out2, pos2, _ = self._encode_image(img2, true_shape2)
        return out, out2, pos, pos2

    def _encode_symmetrized(self, view1, view2):
        """ Encode two views; returns ((shape1, shape2), (feat1, feat2), (pos1, pos2)). """
        img1 = view1['img']
        img2 = view2['img']
        B = img1.shape[0]
        # Recover true_shape when available, otherwise assume that the img shape is the true one
        shape1 = view1.get('true_shape', torch.tensor(img1.shape[-2:])[None].repeat(B, 1))
        shape2 = view2.get('true_shape', torch.tensor(img2.shape[-2:])[None].repeat(B, 1))
        # warning! maybe the images have different portrait/landscape orientations

        if is_symmetrized(view1, view2):
            # computing half of forward pass!
            # only every other sample is encoded, the rest is rebuilt by interleaving
            feat1, feat2, pos1, pos2 = self._encode_image_pairs(img1[::2], img2[::2], shape1[::2], shape2[::2])
            feat1, feat2 = interleave(feat1, feat2)
            pos1, pos2 = interleave(pos1, pos2)
        else:
            feat1, feat2, pos1, pos2 = self._encode_image_pairs(img1, img2, shape1, shape2)

        return (shape1, shape2), (feat1, feat2), (pos1, pos2)

    def _decoder(self, f1, pos1, f2, pos2):
        """ Run both decoders; returns an iterator over (per-layer outputs of view1, of view2). """
        final_output = [(f1, f2)]  # before projection

        # project to decoder dim
        f1 = self.decoder_embed(f1)
        f2 = self.decoder_embed(f2)

        final_output.append((f1, f2))
        for blk1, blk2 in zip(self.dec_blocks, self.dec_blocks2):
            # img1 side
            f1, _ = blk1(*final_output[-1][::+1], pos1, pos2)
            # img2 side
            f2, _ = blk2(*final_output[-1][::-1], pos2, pos1)
            # store the result
            final_output.append((f1, f2))

        # normalize last output
        del final_output[1]  # duplicate with final_output[0]
        final_output[-1] = tuple(map(self.dec_norm, final_output[-1]))
        return zip(*final_output)

    def _downstream_head(self, head_num, decout, img_shape):
        # the unpacking doubles as a check that the last tokens are 3-dimensional
        B, S, D = decout[-1].shape
        # img_shape = tuple(map(int, img_shape))
        head = getattr(self, f'head{head_num}')
        return head(decout, img_shape)

    def forward(self, view1, view2):
        """ Return (res1, res2), the head outputs for view1 and view2.

        In res2 the 'pts3d' entry is renamed 'pts3d_in_other_view'.
        """
        # encode the two images --> B,S,D
        (shape1, shape2), (feat1, feat2), (pos1, pos2) = self._encode_symmetrized(view1, view2)

        # combine all ref images into object-centric representation
        dec1, dec2 = self._decoder(feat1, pos1, feat2, pos2)

        # heads always run in float32, with autocast disabled
        with torch.cuda.amp.autocast(enabled=False):
            res1 = self._downstream_head(1, [tok.float() for tok in dec1], shape1)
            res2 = self._downstream_head(2, [tok.float() for tok in dec2], shape2)

        res2['pts3d_in_other_view'] = res2.pop('pts3d')  # predict view2's pts3d in view1's frame
        return res1, res2
class ManyAR_PatchEmbed (PatchEmbed):
    """ Handle images with non-square aspect ratio.
        All images in the same batch have the same aspect ratio.
        true_shape = [(height, width) ...] indicates the actual shape of each image.

    Input images must be given in landscape layout (W >= H); samples whose
    true_shape is portrait are transposed before the projection.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
        self.embed_dim = embed_dim
        super().__init__(img_size, patch_size, in_chans, embed_dim, norm_layer, flatten)

    def forward(self, img, true_shape):
        """ Return (x, pos): tokens of shape (B, n_tokens, embed_dim) and int64 positions (B, n_tokens, 2). """
        B, C, H, W = img.shape
        assert W >= H, f'img should be in landscape mode, but got {W=} {H=}'
        assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})."
        assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})."
        assert true_shape.shape == (B, 2), f"true_shape has the wrong shape={true_shape.shape}"

        # size expressed in tokens
        # (height uses patch_size[0] and width patch_size[1], as in the asserts above)
        H //= self.patch_size[0]
        W //= self.patch_size[1]
        n_tokens = H * W

        height, width = true_shape.T
        is_landscape = (width >= height)
        is_portrait = ~is_landscape

        # allocate result
        x = img.new_zeros((B, n_tokens, self.embed_dim))
        pos = img.new_zeros((B, n_tokens, 2), dtype=torch.int64)

        # linear projection, transposed if necessary
        x[is_landscape] = self.proj(img[is_landscape]).permute(0, 2, 3, 1).flatten(1, 2).float()
        x[is_portrait] = self.proj(img[is_portrait].swapaxes(-1, -2)).permute(0, 2, 3, 1).flatten(1, 2).float()

        # token positions follow the (possibly transposed) token grid
        pos[is_landscape] = self.position_getter(1, H, W, pos.device)
        pos[is_portrait] = self.position_getter(1, W, H, pos.device)

        x = self.norm(x)
        return x, pos
def estimate_focal_knowing_depth(pts3d, pp, focal_mode='median', min_focal=0., max_focal=np.inf):
    """ Estimate one focal length per batch element from a 3D point map.

    pts3d: (B, H, W, 3) points.
    pp: principal point(s), viewed as (-1, 1, 2) and subtracted from the pixel grid.
    focal_mode:
        'median': median of the per-pixel focal votes u*z/x and v*z/y
            (square pixels assumed, hence one focal for X and Y).
        'weiszfeld': closed-form L2 initialisation followed by 10 iterations
            of re-weighted least-squares.
    min_focal, max_focal: clipping bounds, expressed as multiples of the focal
        of a camera whose longest side spans a 60 degree field of view.

    Raises ValueError on an unknown focal_mode.
    """
    batch, height, width, n_coords = pts3d.shape
    assert n_coords == 3

    # pixel grid centered on the principal point
    grid = xy_grid(width, height, device=pts3d.device).view(1, -1, 2)
    centered_px = grid - pp.view(-1, 1, 2)  # B,HW,2
    flat_pts = pts3d.flatten(1, 2)  # (B, HW, 3)

    if focal_mode == 'median':
        with torch.no_grad():
            # every pixel votes for a focal along each axis
            px_u, px_v = centered_px.unbind(dim=-1)
            pt_x, pt_y, pt_z = flat_pts.unbind(dim=-1)
            votes_x = (px_u * pt_z) / pt_x
            votes_y = (px_v * pt_z) / pt_y

            # square pixels: pool the votes of both axes
            all_votes = torch.cat((votes_x.view(batch, -1), votes_y.view(batch, -1)), dim=-1)
            focal = torch.nanmedian(all_votes, dim=-1).values

    elif focal_mode == 'weiszfeld':
        # we look for focal = argmin Sum | pixel - focal * (x,y)/z |
        rays = (flat_pts[..., :2] / flat_pts[..., 2:3]).nan_to_num(posinf=0, neginf=0)  # (x,y)/z

        # closed-form l2 solution as starting point
        numer = (rays * centered_px).sum(dim=-1)
        denom = rays.square().sum(dim=-1)
        focal = numer.mean(dim=1) / denom.mean(dim=1)

        # iterative re-weighted least-squares
        for _ in range(10):
            # weights are the inverse of the current reprojection distance
            residual = (centered_px - focal.view(-1, 1, 1) * rays).norm(dim=-1)
            weights = residual.clip(min=1e-8).reciprocal()
            # weighted update of the focal
            focal = (weights * numer).mean(dim=1) / (weights * denom).mean(dim=1)
    else:
        raise ValueError(f'bad {focal_mode=}')

    # reference focal: 60 degree field of view along the longest side
    longest_side = max(height, width)
    focal_base = longest_side / (2 * np.tan(np.deg2rad(60) / 2))  # size / 1.1547005383792515
    focal = focal.clip(min=min_focal*focal_base, max=max_focal*focal_base)
    return focal
parser.add_argument('--test_criterion', default=None, type=str, help="test criterion") # dataset parser.add_argument('--train_dataset', required=True, type=str, help="training set") parser.add_argument('--test_dataset', default='[None]', type=str, help="testing set") # training parser.add_argument('--seed', default=0, type=int, help="Random seed") parser.add_argument('--batch_size', default=64, type=int, help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus") parser.add_argument('--accum_iter', default=1, type=int, help="Accumulate gradient iterations (for increasing the effective batch size under memory constraints)") parser.add_argument('--epochs', default=800, type=int, help="Maximum number of epochs for the scheduler") parser.add_argument('--weight_decay', type=float, default=0.05, help="weight decay (default: 0.05)") parser.add_argument('--lr', type=float, default=None, metavar='LR', help='learning rate (absolute lr)') parser.add_argument('--blr', type=float, default=1.5e-4, metavar='LR', help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') parser.add_argument('--min_lr', type=float, default=0., metavar='LR', help='lower lr bound for cyclic schedulers that hit 0') parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', help='epochs to warmup LR') parser.add_argument('--amp', type=int, default=0, choices=[0, 1], help="Use Automatic Mixed Precision for pretraining") parser.add_argument("--disable_cudnn_benchmark", action='store_true', default=False, help="set cudnn.benchmark = False") # others parser.add_argument('--num_workers', default=8, type=int) parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') parser.add_argument('--local_rank', default=-1, type=int) parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') parser.add_argument('--eval_freq', type=int, default=1, help='Test loss evaluation frequency') 
def train(args):
    """ Full training procedure: setup, auto-resume, train/test loop, final checkpoint.

    args: namespace produced by get_args_parser(). `args.resume` and `args.lr`
    are set/updated here; `args.distributed`, `args.gpu` and `args.start_epoch`
    are read but not defined by the parser (presumably set by
    misc.init_distributed_mode / misc.load_model - confirm).
    """
    misc.init_distributed_mode(args)
    global_rank = misc.get_rank()
    world_size = misc.get_world_size()

    print("output_dir: " + args.output_dir)
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # auto resume
    last_ckpt_fname = os.path.join(args.output_dir, 'checkpoint-last.pth')
    args.resume = last_ckpt_fname if os.path.isfile(last_ckpt_fname) else None

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # fix the seed (different on each process)
    seed = args.seed + global_rank
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = not args.disable_cudnn_benchmark

    # training dataset and loader
    print('Building train dataset {:s}'.format(args.train_dataset))
    data_loader_train = build_dataset(args.train_dataset, args.batch_size, args.num_workers, test=False)
    # test datasets and loaders: one loader per '+'-separated dataset
    print('Building test dataset {:s}'.format(args.test_dataset))
    data_loader_test = {dataset.split('(')[0]: build_dataset(dataset, args.batch_size, args.num_workers, test=True)
                        for dataset in args.test_dataset.split('+')}

    # model
    # NOTE: model and criteria are built by eval() of command-line strings;
    # only run this with trusted arguments.
    print('Loading model: {:s}'.format(args.model))
    model = eval(args.model)
    print(f'>> Creating train criterion = {args.train_criterion}')
    train_criterion = eval(args.train_criterion).to(device)
    print(f'>> Creating test criterion = {args.test_criterion or args.train_criterion}')
    # fall back to the train criterion when no test criterion is given
    test_criterion = eval(args.test_criterion or args.train_criterion).to(device)

    model.to(device)
    model_without_ddp = model
    print("Model = %s" % str(model_without_ddp))

    if args.pretrained and not args.resume:
        print('Loading pretrained: ', args.pretrained)
        ckpt = torch.load(args.pretrained, map_location=device)
        print(model.load_state_dict(ckpt['model'], strict=False))
        del ckpt  # in case it occupies memory

    eff_batch_size = args.batch_size * args.accum_iter * world_size
    if args.lr is None:  # only base_lr is specified
        args.lr = args.blr * eff_batch_size / 256
    print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
    print("actual lr: %.2e" % args.lr)
    print("accumulate grad iterations: %d" % args.accum_iter)
    print("effective batch size: %d" % eff_batch_size)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu], find_unused_parameters=True, static_graph=True)
        model_without_ddp = model.module

    # following timm: set wd as 0 for bias and norm layers
    param_groups = misc.get_parameter_groups(model_without_ddp, args.weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
    print(optimizer)
    loss_scaler = NativeScaler()

    def write_log_stats(epoch, train_stats, test_stats):
        # append one json line per epoch to <output_dir>/log.txt (main process only)
        if misc.is_main_process():
            if log_writer is not None:
                log_writer.flush()

            log_stats = dict(epoch=epoch, **{f'train_{k}': v for k, v in train_stats.items()})
            for test_name in data_loader_test:
                if test_name not in test_stats:
                    continue
                log_stats.update({test_name + '_' + k: v for k, v in test_stats[test_name].items()})

            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

    def save_model(epoch, fname, best_so_far):
        misc.save_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer,
                        loss_scaler=loss_scaler, epoch=epoch, fname=fname, best_so_far=best_so_far)

    best_so_far = misc.load_model(args=args, model_without_ddp=model_without_ddp,
                                  optimizer=optimizer, loss_scaler=loss_scaler)
    if best_so_far is None:
        best_so_far = float('inf')
    if global_rank == 0 and args.output_dir is not None:
        log_writer = SummaryWriter(log_dir=args.output_dir)
    else:
        log_writer = None

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    train_stats = test_stats = {}
    # one extra iteration (epoch == args.epochs) to save/test the last trained epoch
    for epoch in range(args.start_epoch, args.epochs + 1):

        # Save immediately the last checkpoint
        if epoch > args.start_epoch:
            if args.save_freq and epoch % args.save_freq == 0 or epoch == args.epochs:
                save_model(epoch - 1, 'last', best_so_far)

        # Test on multiple datasets
        new_best = False
        if (epoch > 0 and args.eval_freq > 0 and epoch % args.eval_freq == 0):
            test_stats = {}
            for test_name, testset in data_loader_test.items():
                stats = test_one_epoch(model, test_criterion, testset,
                                       device, epoch, log_writer=log_writer, args=args, prefix=test_name)
                test_stats[test_name] = stats

                # Save best of all
                if stats['loss_med'] < best_so_far:
                    best_so_far = stats['loss_med']
                    new_best = True

        # Save more stuff
        write_log_stats(epoch, train_stats, test_stats)

        if epoch > args.start_epoch:
            if args.keep_freq and epoch % args.keep_freq == 0:
                save_model(epoch - 1, str(epoch), best_so_far)
            if new_best:
                save_model(epoch - 1, 'best', best_so_far)
        if epoch >= args.epochs:
            break  # exit after writing last test to disk

        # Train
        train_stats = train_one_epoch(
            model, train_criterion, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            log_writer=log_writer,
            args=args)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))

    save_final_model(args, args.epochs, model_without_ddp, best_so_far=best_so_far)
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Sized, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler,
                    args,
                    log_writer=None):
    """ Train `model` for one epoch with gradient accumulation.

    The learning rate is adjusted per iteration (at the start of every
    accumulation cycle). The process exits with status 1 as soon as a
    non-finite loss is met.

    Returns:
        dict mapping each meter name to its global average over the epoch.
    """
    assert torch.backends.cuda.matmul.allow_tf32 == True

    model.train(True)
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    accum_iter = args.accum_iter

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    # let the dataset / sampler know the current epoch, when they support it
    if hasattr(data_loader, 'dataset') and hasattr(data_loader.dataset, 'set_epoch'):
        data_loader.dataset.set_epoch(epoch)
    if hasattr(data_loader, 'sampler') and hasattr(data_loader.sampler, 'set_epoch'):
        data_loader.sampler.set_epoch(epoch)

    optimizer.zero_grad()

    for data_iter_step, batch in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
        # fractional epoch, used by the lr schedule and as logging x-axis
        epoch_f = epoch + data_iter_step / len(data_loader)

        # we use a per iteration (instead of per epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            misc.adjust_learning_rate(optimizer, epoch_f, args)

        loss_tuple = loss_of_one_batch(batch, model, criterion, device,
                                       symmetrize_batch=True,
                                       use_amp=bool(args.amp), ret='loss')
        loss, loss_details = loss_tuple  # criterion returns two values
        loss_value = float(loss)

        if not math.isfinite(loss_value):
            # NOTE(review): `force=True` is not a builtin print() keyword; presumably
            # print is patched by misc.init_distributed_mode - confirm.
            print("Loss is {}, stopping training".format(loss_value), force=True)
            sys.exit(1)

        # scale the loss so that accumulated gradients average over accum_iter steps
        loss /= accum_iter
        loss_scaler(loss, optimizer, parameters=model.parameters(),
                    update_grad=(data_iter_step + 1) % accum_iter == 0)
        # gradients are reset only at the end of an accumulation cycle
        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()

        del loss
        del batch

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(epoch=epoch_f)
        metric_logger.update(lr=lr)
        metric_logger.update(loss=loss_value, **loss_details)

        if (data_iter_step + 1) % accum_iter == 0 and ((data_iter_step + 1) % (accum_iter * args.print_freq)) == 0:
            loss_value_reduce = misc.all_reduce_mean(loss_value)  # MUST BE EXECUTED BY ALL NODES
            # only processes owning a writer log to tensorboard (after the collective above)
            if log_writer is None:
                continue
            """ We use epoch_1000x as the x-axis in tensorboard. This calibrates different curves when batch size changes. """
            epoch_1000x = int(epoch_f * 1000)
            log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('train_lr', lr, epoch_1000x)
            log_writer.add_scalar('train_iter', epoch_1000x, epoch_1000x)
            for name, val in loss_details.items():
                log_writer.add_scalar('train_' + name, val, epoch_1000x)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
def todevice(batch, device, callback=None, non_blocking=False):
    ''' Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy).

    batch: list, tuple, dict of tensors or other things
    device: pytorch device or 'numpy'
    callback: function that would be called on every sub-elements
        (containers included, before they are traversed).
    non_blocking: forwarded to Tensor.to() for every tensor, at any nesting depth.

    Returns a structure of the same layout. With device='numpy', torch tensors
    become numpy arrays; otherwise numpy arrays become torch tensors on
    `device`. Anything else (None, str, numbers, ...) is returned as is.
    '''
    if callback:
        batch = callback(batch)

    # containers: recurse, forwarding all the options to the sub-elements
    if isinstance(batch, dict):
        return {k: todevice(v, device, callback=callback, non_blocking=non_blocking)
                for k, v in batch.items()}

    if isinstance(batch, (tuple, list)):
        return type(batch)(todevice(x, device, callback=callback, non_blocking=non_blocking)
                           for x in batch)

    # leaves
    x = batch
    if device == 'numpy':
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
    elif x is not None:
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if torch.is_tensor(x):
            x = x.to(device, non_blocking=non_blocking)
    return x
def xy_grid(W, H, device=None, origin=(0, 0), unsqueeze=None, cat_dim=-1, homogeneous=False, **arange_kw):
    """ Output a (H,W,2) array of pixel coordinates with
        output[j,i,0] = i + origin[0]
        output[j,i,1] = j + origin[1]

    device: None --> numpy arrays, otherwise torch tensors created on `device`.
    unsqueeze: if not None, a singleton axis is inserted at this position in
        every coordinate plane before stacking.
    cat_dim: axis along which the planes are stacked; None returns the tuple of
        planes (x, y[, ones]) instead of a single array.
    homogeneous: append a plane of ones of shape (H, W).
    arange_kw: forwarded to arange (e.g. dtype).
    """
    if device is None:
        # numpy
        arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones
        expand = np.expand_dims
    else:
        # torch
        def arange(*a, **kw):
            return torch.arange(*a, device=device, **kw)

        def ones(*a):
            return torch.ones(*a, device=device)
        meshgrid, stack = torch.meshgrid, torch.stack
        expand = torch.unsqueeze

    tw, th = [arange(o, o + s, **arange_kw) for s, o in zip((W, H), origin)]
    # np.meshgrid may return a list: normalise to a tuple so that planes can be appended
    grid = tuple(meshgrid(tw, th, indexing='xy'))
    if homogeneous:
        grid = grid + (ones((H, W)),)
    if unsqueeze is not None:
        # apply to every plane (including the homogeneous one), numpy or torch
        grid = tuple(expand(g, unsqueeze) for g in grid)
    if cat_dim is not None:
        grid = stack(grid, cat_dim)
    return grid


def geotrf(Trf, pts, ncol=None, norm=False):
    """ Apply a geometric transformation to a set of points.

    Trf: numpy array or torch tensor with at least 2 dimensions; the last two
        are the transformation matrix, leading ones are batch dimensions.
    pts: numpy/torch/tuple of coordinates. Converted to the same library as Trf
        (and to Trf's dtype in the torch case). When its last dimension is one
        less than the matrix size, the last row/column of Trf is added as a
        translation.
    ncol: int. number of columns of the result; defaults to pts.shape[-1].
    norm: float. if != 0, the result is divided by its last coordinate, then
        multiplied by `norm` when norm != 1 (projection on the z=norm plane).

    Returns an array of shape (*pts.shape[:-1], ncol).
    """
    assert Trf.ndim >= 2
    if isinstance(Trf, np.ndarray):
        pts = np.asarray(pts)
    elif isinstance(Trf, torch.Tensor):
        pts = torch.as_tensor(pts, dtype=Trf.dtype)

    # adapt shape if necessary
    output_reshape = pts.shape[:-1]
    ncol = ncol or pts.shape[-1]

    # optimized code
    if (isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and
            Trf.ndim == 3 and pts.ndim == 4):
        d = pts.shape[3]
        if Trf.shape[-1] == d:
            pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
        elif Trf.shape[-1] == d + 1:
            pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d]
        else:
            raise ValueError(f'bad shape, not ending with 3 or 4, for {pts.shape=}')
    else:
        if Trf.ndim >= 3:
            n = Trf.ndim - 2
            assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match'
            Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])

            if pts.ndim > Trf.ndim:
                # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
                pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
            elif pts.ndim == 2:
                # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
                pts = pts[:, None, :]

        if pts.shape[-1] + 1 == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
        elif pts.shape[-1] == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf
        else:
            pts = Trf @ pts.T
            if pts.ndim >= 2:
                pts = pts.swapaxes(-1, -2)

    if norm:
        pts = pts / pts[..., -1:]  # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
        if norm != 1:
            pts *= norm

    res = pts[..., :ncol].reshape(*output_reshape, ncol)
    return res


def inv(mat):
    """ Invert a torch or numpy matrix.

    Raises ValueError for any other input type.
    """
    if isinstance(mat, torch.Tensor):
        return torch.linalg.inv(mat)
    if isinstance(mat, np.ndarray):
        return np.linalg.inv(mat)
    raise ValueError(f'bad matrix type = {type(mat)}')
def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
    """ Back-project a depthmap into the camera frame.

    Args:
        - depthmap (HxW array):
        - camera_intrinsics: a 3x3 matrix (cast to float32, must have no skew)
        - pseudo_focal: optional HxW array used as focal for both axes,
          instead of the focals stored in camera_intrinsics
    Returns:
        float32 pointmap in camera coordinates (HxWx3 array),
        and a mask specifying valid pixels (depth > 0).
    """
    K = np.float32(camera_intrinsics)
    height, width = depthmap.shape

    # Strong assumption: there are no skew terms
    assert K[0, 1] == 0.0
    assert K[1, 0] == 0.0

    if pseudo_focal is not None:
        assert pseudo_focal.shape == (height, width)
        focal_u = focal_v = pseudo_focal
    else:
        focal_u, focal_v = K[0, 0], K[1, 1]
    center_u, center_v = K[0, 2], K[1, 2]

    # 3D ray associated with each pixel, scaled by its depth
    cols, rows = np.meshgrid(np.arange(width), np.arange(height))
    x_cam = (cols - center_u) * depthmap / focal_u
    y_cam = (rows - center_v) * depthmap / focal_v
    pts_cam = np.stack((x_cam, y_cam, depthmap), axis=-1).astype(np.float32)

    return pts_cam, (depthmap > 0.0)


def depthmap_to_absolute_camera_coordinates(depthmap, camera_intrinsics, camera_pose, **kw):
    """ Back-project a depthmap and express the points in the world frame.

    Args:
        - depthmap (HxW array):
        - camera_intrinsics: a 3x3 matrix
        - camera_pose: a cam2world matrix (only [:3, :3] and [:3, 3] are read),
          or None to keep the points in the camera frame
        - kw: accepted and ignored
    Returns:
        pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.
    """
    pts_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)
    if camera_pose is None:
        return pts_cam, valid_mask

    rotation = camera_pose[:3, :3]
    translation = camera_pose[:3, 3]
    # rotate every pixel's point, then shift by the camera center
    pts_world = np.einsum("ik, vuk -> vui", rotation, pts_cam) + translation[None, None, :]
    return pts_world, valid_mask


def colmap_to_opencv_intrinsics(K):
    """ Return a copy of K with the principal point shifted by -0.5 pixel.

    Coordinates of the center of the top-left pixels are by default:
    - (0.5, 0.5) in Colmap
    - (0,0) in OpenCV
    """
    shifted = K.copy()
    shifted[:2, 2] -= 0.5
    return shifted
def normalize_pointcloud(pts1, pts2, norm_mode='avg_dis', valid1=None, valid2=None, ret_factor=False):
    """ renorm pointmaps pts1, pts2 with norm_mode

    pts1: tensor of 3D points, at least 3-dimensional, last dim == 3.
    pts2: same, or None to normalize pts1 alone.
    norm_mode: '<norm>_<dis>' string. <norm> is 'avg', 'median' or 'sqrt';
        <dis> ('dis', 'log1p' or 'warp-log1p') is only interpreted for 'avg'.
    valid1, valid2: optional validity masks; invalid points are ignored.
    ret_factor: also return the normalization factor.

    Returns the normalized pts1, or a tuple (pts1, pts2) when pts2 is given,
    with the norm factor appended as last tuple element if ret_factor is True.
    """
    assert pts1.ndim >= 3 and pts1.shape[-1] == 3
    assert pts2 is None or (pts2.ndim >= 3 and pts2.shape[-1] == 3)
    norm_mode, dis_mode = norm_mode.split('_')

    if norm_mode == 'avg':
        # gather all points together (joint normalization)
        nan_pts1, nnz1 = invalid_to_zeros(pts1, valid1, ndim=3)
        nan_pts2, nnz2 = invalid_to_zeros(pts2, valid2, ndim=3) if pts2 is not None else (None, 0)
        all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1

        # compute distance to origin
        all_dis = all_pts.norm(dim=-1)
        if dis_mode == 'dis':
            pass  # do nothing
        elif dis_mode == 'log1p':
            all_dis = torch.log1p(all_dis)
        elif dis_mode == 'warp-log1p':
            # actually warp input points before normalizing them
            log_dis = torch.log1p(all_dis)
            warp_factor = log_dis / all_dis.clip(min=1e-8)
            H1, W1 = pts1.shape[1:-1]
            pts1 = pts1 * warp_factor[:, :W1 * H1].view(-1, H1, W1, 1)
            if pts2 is not None:
                H2, W2 = pts2.shape[1:-1]
                pts2 = pts2 * warp_factor[:, W1 * H1:].view(-1, H2, W2, 1)
            all_dis = log_dis  # this is their true distance afterwards
        else:
            raise ValueError(f'bad {dis_mode=}')

        # invalid points were zeroed, so the sum only counts valid ones
        norm_factor = all_dis.sum(dim=1) / (nnz1 + nnz2 + 1e-8)
    else:
        # gather all points together (joint normalization)
        nan_pts1 = invalid_to_nans(pts1, valid1, ndim=3)
        nan_pts2 = invalid_to_nans(pts2, valid2, ndim=3) if pts2 is not None else None
        all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1

        # compute distance to origin
        all_dis = all_pts.norm(dim=-1)

        if norm_mode == 'avg':
            # NOTE(review): cannot be reached, 'avg' is handled by the branch above
            norm_factor = all_dis.nanmean(dim=1)
        elif norm_mode == 'median':
            norm_factor = all_dis.nanmedian(dim=1).values.detach()
        elif norm_mode == 'sqrt':
            norm_factor = all_dis.sqrt().nanmean(dim=1)**2
        else:
            raise ValueError(f'bad {norm_mode=}')

    norm_factor = norm_factor.clip(min=1e-8)
    while norm_factor.ndim < pts1.ndim:
        norm_factor.unsqueeze_(-1)

    res = pts1 / norm_factor
    if pts2 is not None:
        res = (res, pts2 / norm_factor)
    if ret_factor:
        # `res` is a bare tensor when pts2 is None: wrap it before appending,
        # tensor + tuple is not a concatenation
        if not isinstance(res, tuple):
            res = (res,)
        res = res + (norm_factor,)
    return res


@torch.no_grad()
def get_joint_pointcloud_depth(z1, z2, valid_mask1, valid_mask2=None, quantile=0.5):
    """ Per-sample quantile (median by default) of the depths of z1 and z2 taken jointly.

    Invalid entries (according to the masks) are ignored. z2 may be None.
    """
    # set invalid points to NaN
    _z1 = invalid_to_nans(z1, valid_mask1).reshape(len(z1), -1)
    _z2 = invalid_to_nans(z2, valid_mask2).reshape(len(z2), -1) if z2 is not None else None
    _z = torch.cat((_z1, _z2), dim=-1) if z2 is not None else _z1

    # compute median depth overall (ignoring nans)
    if quantile == 0.5:
        shift_z = torch.nanmedian(_z, dim=-1).values
    else:
        shift_z = torch.nanquantile(_z, quantile, dim=-1)
    return shift_z  # (B,)


@torch.no_grad()
def get_joint_pointcloud_center_scale(pts1, pts2, valid_mask1=None, valid_mask2=None, z_only=False, center=True):
    """ Median center and median norm of the points of pts1 and pts2 taken jointly.

    z_only: zero the X and Y components of the center.
    center: measure norms relative to the center (True) or to the origin (False).
    Returns (center, scale), both reshaped to 4 dimensions; pts2 may be None.
    """
    # set invalid points to NaN
    _pts1 = invalid_to_nans(pts1, valid_mask1).reshape(len(pts1), -1, 3)
    _pts2 = invalid_to_nans(pts2, valid_mask2).reshape(len(pts2), -1, 3) if pts2 is not None else None
    _pts = torch.cat((_pts1, _pts2), dim=1) if pts2 is not None else _pts1

    # compute median center
    _center = torch.nanmedian(_pts, dim=1, keepdim=True).values  # (B,1,3)
    if z_only:
        _center[..., :2] = 0  # do not center X and Y

    # compute median norm
    _norm = ((_pts - _center) if center else _pts).norm(dim=-1)
    scale = torch.nanmedian(_norm, dim=1).values
    return _center[:, None, :, :], scale[:, None, None, None]
def get_med_dist_between_poses(poses):
    """ Median of the pairwise distances between the translations (p[:3, 3]) of the given poses. """
    from scipy.spatial.distance import pdist

    translations = [to_numpy(pose[:3, 3]) for pose in poses]
    pairwise_distances = pdist(translations)
    return np.median(pairwise_distances)
def rgb(ftensor, true_shape=None):
    """ Convert an image (or a list of images) to a channel-last array with values in [0, 1].

    Torch tensors are detached and moved to numpy. A leading (3D input) or
    second (4D input) axis of size 3 is moved to the last position.
    true_shape: optional (H, W), the result is cropped to its first H rows and W columns.
    uint8 inputs are divided by 255, anything else is mapped with x * 0.5 + 0.5.
    """
    if isinstance(ftensor, list):
        return [rgb(item, true_shape=true_shape) for item in ftensor]

    arr = ftensor.detach().cpu().numpy() if isinstance(ftensor, torch.Tensor) else ftensor

    # channel-first --> channel-last
    if arr.ndim == 3 and arr.shape[0] == 3:
        arr = np.moveaxis(arr, 0, -1)
    elif arr.ndim == 4 and arr.shape[1] == 3:
        arr = np.moveaxis(arr, 1, -1)

    if true_shape is not None:
        height, width = true_shape
        arr = arr[:height, :width]

    img = np.float32(arr) / 255 if arr.dtype == np.uint8 else (arr * 0.5) + 0.5
    return img.clip(min=0, max=1)


def _resize_pil_image(img, long_edge_size):
    """ Resize a PIL image so that its longest edge equals long_edge_size.

    LANCZOS is used when shrinking, BICUBIC otherwise.
    """
    longest = max(img.size)
    interp = PIL.Image.LANCZOS if longest > long_edge_size else PIL.Image.BICUBIC
    target = tuple(int(round(side * long_edge_size / longest)) for side in img.size)
    return img.resize(target, interp)
exif_transpose(PIL.Image.open(os.path.join(root, path))).convert('RGB') W1, H1 = img.size if size == 224: # resize short side to 224 (then crop) img = _resize_pil_image(img, round(size * max(W1/H1, H1/W1))) else: # resize long side to 512 img = _resize_pil_image(img, size) W, H = img.size cx, cy = W//2, H//2 if size == 224: half = min(cx, cy) img = img.crop((cx-half, cy-half, cx+half, cy+half)) else: halfw = ((2 * cx) // patch_size) * patch_size / 2 halfh = ((2 * cy) // patch_size) * patch_size / 2 if not (square_ok) and W == H: halfh = 3*halfw/4 img = img.crop((cx-halfw, cy-halfh, cx+halfw, cy+halfh)) W2, H2 = img.size if verbose: print(f' - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}') imgs.append(dict(img=ImgNorm(img)[None], true_shape=np.int32( [img.size[::-1]]), idx=len(imgs), instance=str(len(imgs)))) assert imgs, 'no images foud at '+root if verbose: print(f' (Found {len(imgs)} images)') return imgs ================================================ FILE: dust3r/utils/misc.py ================================================ # Copyright (C) 2024-present Naver Corporation. All rights reserved. # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
def fill_default_args(kwargs, func):
    """ Complete `kwargs` (in place) with the default values found in the signature of `func`.

    Keys already present are left untouched. Returns the same dict.
    """
    import inspect  # a bit hacky but it works reliably
    signature = inspect.signature(func)

    for k, v in signature.parameters.items():
        if v.default is inspect.Parameter.empty:
            continue
        kwargs.setdefault(k, v.default)

    return kwargs


def freeze_all_params(modules):
    """ Set requires_grad = False on every parameter of the given modules.

    An element without `named_parameters` is treated as a parameter itself.
    """
    for module in modules:
        try:
            for _, param in module.named_parameters():
                param.requires_grad = False
        except AttributeError:
            # module is directly a parameter
            module.requires_grad = False


def is_symmetrized(gt1, gt2):
    """ Tell whether the batch is made of consecutive swapped pairs, i.e.
    gt1['instance'][2k] == gt2['instance'][2k+1] and gt1['instance'][2k+1] == gt2['instance'][2k].
    """
    x = gt1['instance']
    y = gt2['instance']
    if len(x) == len(y) and len(x) == 1:
        return False  # special case of batchsize 1
    if len(x) % 2:
        # an odd number of elements cannot be split into pairs
        # (and would index past the end below)
        return False
    ok = True
    for i in range(0, len(x), 2):
        ok = ok and (x[i] == y[i + 1]) and (x[i + 1] == y[i])
    return ok


def flip(tensor):
    """ flip so that tensor[0::2] <=> tensor[1::2] """
    return torch.stack((tensor[1::2], tensor[0::2]), dim=1).flatten(0, 1)


def interleave(tensor1, tensor2):
    """ Interleave two tensors along their first dimension.

    Returns (t1[0], t2[0], t1[1], t2[1], ...) and (t2[0], t1[0], t2[1], t1[1], ...).
    """
    res1 = torch.stack((tensor1, tensor2), dim=1).flatten(0, 1)
    res2 = torch.stack((tensor2, tensor1), dim=1).flatten(0, 1)
    return res1, res2


def transpose_to_landscape(head, activate=True):
    """ Predict in the correct aspect-ratio,
        then transpose the result in landscape
        and stack everything back together.

    head: callable(decout, (H, W)) returning a dict of tensors.
    activate: if False, the returned wrapper only checks that all shapes of the
        batch are identical and calls `head` once.
    Returns a callable(decout, true_shape).
    """
    def wrapper_no(decout, true_shape):
        assert true_shape[0:1].allclose(true_shape), 'true_shape must be all identical'

        H, W = true_shape[0].cpu().tolist()
        res = head(decout, (H, W))
        return res

    def wrapper_yes(decout, true_shape):
        B = len(true_shape)
        # by definition, the batch is in landscape mode so W >= H
        H, W = int(true_shape.min()), int(true_shape.max())

        height, width = true_shape.T
        is_landscape = (width >= height)
        is_portrait = ~is_landscape

        if is_landscape.all():
            return head(decout, (H, W))
        if is_portrait.all():
            return transposed(head(decout, (W, H)))

        # batch is a mix of both portrait & landscape
        def selout(ar): return [d[ar] for d in decout]
        l_result = head(selout(is_landscape), (H, W))
        p_result = transposed(head(selout(is_portrait), (W, H)))

        # allocate full result
        result = {}
        for k in l_result | p_result:
            x = l_result[k].new(B, *l_result[k].shape[1:])
            x[is_landscape] = l_result[k]
            x[is_portrait] = p_result[k]
            result[k] = x

        return result

    return wrapper_yes if activate else wrapper_no


def transposed(dic):
    """ Swap axes 1 and 2 of every value of the dict. """
    return {k: v.swapaxes(1, 2) for k, v in dic.items()}


def invalid_to_nans(arr, valid_mask, ndim=999):
    """ Return a copy of `arr` where entries outside `valid_mask` are NaN.

    With valid_mask=None, `arr` is not copied. If `arr` has more than `ndim`
    dimensions, the extra ones are flattened (the last dimension is kept).
    """
    if valid_mask is not None:
        arr = arr.clone()
        arr[~valid_mask] = float('nan')
    if arr.ndim > ndim:
        arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
    return arr


def invalid_to_zeros(arr, valid_mask, ndim=999):
    """ Same as invalid_to_nans but filling with zeros.

    Also returns nnz: the per-sample count of valid entries (a tensor), or,
    without mask, arr.numel() // len(arr) as a plain int (0 for an empty batch).
    """
    if valid_mask is not None:
        arr = arr.clone()
        arr[~valid_mask] = 0
        nnz = valid_mask.view(len(valid_mask), -1).sum(1)
    else:
        nnz = arr.numel() // len(arr) if len(arr) else 0  # number of point per image
    if arr.ndim > ndim:
        arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
    return arr, nnz
def parallel_threads(function, args, workers=0, star_args=False, kw_args=False,
                     front_num=1, Pool=ThreadPool, **tqdm_kw):
    """ tqdm but with parallel execution.

    Will essentially return
      res = [ function(arg) # default
              function(*arg) # if star_args is True
              function(**arg) # if kw_args is True
              for arg in args]

    workers: number of workers; values <= 0 are offset by the cpu count.
        With a single worker, everything is executed sequentially.
    Pool: pool class instantiated as Pool(workers).
    tqdm_kw: forwarded to tqdm.

    Note:
        the first `front_num` elements of args will not be parallelized.
        This can be useful for debugging.
    """
    while workers <= 0:
        workers += cpu_count()
    if workers == 1:
        # never leave the sequential stage
        front_num = float('inf')

    # remember how many items are left for the progress bar (if args is sized)
    try:
        n_args_parallel = len(args) - front_num
    except TypeError:
        n_args_parallel = None
    args = iter(args)

    def call_directly(arg):
        if star_args:
            return function(*arg)
        if kw_args:
            return function(**arg)
        return function(arg)

    # sequential execution first
    front = []
    while len(front) < front_num:
        try:
            arg = next(args)
        except StopIteration:
            return front  # end of the iterable
        front.append(call_directly(arg))

    # then parallel execution
    with Pool(workers) as pool:
        if star_args:
            futures = pool.imap(starcall, [(function, a) for a in args])
        elif kw_args:
            futures = pool.imap(starstarcall, [(function, a) for a in args])
        else:
            futures = pool.imap(function, args)
        # the progress bar advances as tasks complete
        out = list(tqdm(futures, total=n_args_parallel, **tqdm_kw))
    return front + out


def parallel_processes(*args, **kwargs):
    """ Same as parallel_threads, with processes
    """
    import multiprocessing as mp
    return parallel_threads(*args, **{**kwargs, 'Pool': mp.Pool})


def starcall(args):
    """ convenient wrapper for Process.Pool: unpack (function, positional_args) and call """
    function, positional = args
    return function(*positional)


def starstarcall(args):
    """ convenient wrapper for Process.Pool: unpack (function, keyword_args) and call """
    function, keywords = args
    return function(**keywords)
import sys
import os.path as path

# Location of the CroCo submodule: two directories above this file, in 'croco'
HERE_PATH = path.normpath(path.dirname(__file__))
CROCO_REPO_PATH = path.normpath(path.join(HERE_PATH, '../../croco'))
CROCO_MODELS_PATH = path.join(CROCO_REPO_PATH, 'models')

# check the presence of models directory in repo to be sure its cloned
if not path.isdir(CROCO_MODELS_PATH):
    raise ImportError(f"croco is not initialized, could not find: {CROCO_MODELS_PATH}.\n "
                      "Did you forget to run 'git submodule update --init --recursive' ?")

# workaround for sibling import
sys.path.insert(0, CROCO_REPO_PATH)
def cat_3d(vecs):
    """ Concatenate one or several arrays/tensors into a single (N, 3) numpy array. """
    if isinstance(vecs, (np.ndarray, torch.Tensor)):
        vecs = [vecs]
    return np.concatenate([p.reshape(-1, 3) for p in to_numpy(vecs)])


def show_raw_pointcloud(pts3d, colors, point_size=2):
    """ Open a trimesh viewer showing the given points with their colors. """
    scene = trimesh.Scene()

    pct = trimesh.PointCloud(cat_3d(pts3d), colors=cat_3d(colors))
    scene.add_geometry(pct)

    scene.show(line_settings={'point_size': point_size})


def pts3d_to_trimesh(img, pts3d, valid=None):
    """ Triangulate a pointmap on its pixel grid.

    img: (H, W, 3) colors.
    pts3d: (H, W, 3) points, same shape as img.
    valid: optional (H, W) boolean mask; faces touching an invalid pixel are dropped.

    Returns a dict with 'vertices' (H*W, 3), 'faces' (F, 3) and 'face_colors' (F, 3).
    """
    H, W, THREE = img.shape
    assert THREE == 3
    assert img.shape == pts3d.shape

    vertices = pts3d.reshape(-1, 3)

    # make squares: each pixel == 2 triangles
    idx = np.arange(len(vertices)).reshape(H, W)
    idx1 = idx[:-1, :-1].ravel()  # top-left corner
    idx2 = idx[:-1, +1:].ravel()  # top-right corner
    idx3 = idx[+1:, :-1].ravel()  # bottom-left corner
    idx4 = idx[+1:, +1:].ravel()  # bottom-right corner
    faces = np.concatenate((
        np.c_[idx1, idx2, idx3],
        np.c_[idx3, idx2, idx1],  # same triangle, but backward (cheap solution to cancel face culling)
        np.c_[idx2, idx3, idx4],
        np.c_[idx4, idx3, idx2],  # same triangle, but backward (cheap solution to cancel face culling)
    ), axis=0)

    # prepare triangle colors
    face_colors = np.concatenate((
        img[:-1, :-1].reshape(-1, 3),
        img[:-1, :-1].reshape(-1, 3),
        img[+1:, +1:].reshape(-1, 3),
        img[+1:, +1:].reshape(-1, 3)
    ), axis=0)

    # remove invalid faces
    if valid is not None:
        assert valid.shape == (H, W)
        valid_idxs = valid.ravel()
        valid_faces = valid_idxs[faces].all(axis=-1)
        faces = faces[valid_faces]
        face_colors = face_colors[valid_faces]

    assert len(faces) == len(face_colors)
    return dict(vertices=vertices, face_colors=face_colors, faces=faces)


def cat_meshes(meshes):
    """ Merge several meshes (dicts with 'vertices', 'faces', 'face_colors') into one.

    Face indices of each mesh are shifted by the number of vertices that
    precede it. The input meshes are left untouched.
    """
    vertices, faces, colors = zip(*[(m['vertices'], m['faces'], m['face_colors']) for m in meshes])

    n_vertices = np.cumsum([0]+[len(v) for v in vertices])
    # shift into new arrays: an in-place += would corrupt the caller's faces
    # (and shift twice when the same mesh appears more than once)
    faces = [f + int(offset) for f, offset in zip(faces, n_vertices)]

    vertices = np.concatenate(vertices)
    colors = np.concatenate(colors)
    faces = np.concatenate(faces)
    return dict(vertices=vertices, face_colors=colors, faces=faces)
def auto_cam_size(im_poses):
    """ Camera size for display: one tenth of the median distance between the given poses. """
    median_distance = get_med_dist_between_poses(im_poses)
    return 0.1 * median_distance
├─ GreatCourt │ ├─ pairsfile/query │ │ └─ AP-GeM-LM18_top50.txt # https://github.com/naver/deep-image-retrieval/blob/master/dirtorch/extract_kapture.py followed by https://github.com/naver/kapture-localization/blob/main/tools/kapture_compute_image_pairs.py │ ├─ seq1 │ ... ... ``` ### 7Scenes Each subscene should look like this: ``` 7-scenes ├─ chess │ ├─ mapping/ # https://github.com/naver/kapture/blob/main/doc/datasets.adoc#1-7-scenes │ ├─ query/ # https://github.com/naver/kapture/blob/main/doc/datasets.adoc#1-7-scenes │ └─ pairsfile/query/ │ └─ APGeM-LM18_top20.txt # https://github.com/naver/deep-image-retrieval/blob/master/dirtorch/extract_kapture.py followed by https://github.com/naver/kapture-localization/blob/main/tools/kapture_compute_image_pairs.py ... ``` ### Aachen-Day-Night ``` Aachen-Day-Night-v1.1 ├─ mapping │ ├─ colmap/reconstruction │ │ ├─ cameras.txt │ │ ├─ images.txt │ │ └─ points3D.txt ├─ kapture │ └─ query # https://github.com/naver/kapture/blob/main/doc/datasets.adoc#2-aachen-day-night-v11 ├─ images │ ├─ db │ ├─ query │ └─ sequences └─ pairsfile/query └─ fire_top50.txt # https://github.com/naver/fire/blob/main/kapture_compute_pairs.py ``` ### InLoc ``` InLoc ├─ mapping # https://github.com/naver/kapture/blob/main/doc/datasets.adoc#6-inloc ├─ query # https://github.com/naver/kapture/blob/main/doc/datasets.adoc#6-inloc └─ pairsfile/query └─ pairs-query-netvlad40-temporal.txt # https://github.com/cvg/Hierarchical-Localization/blob/master/pairs/inloc/pairs-query-netvlad40-temporal.txt ``` ## Example Commands With `visloc.py` you can run our visual localization experiments on Aachen-Day-Night, InLoc, Cambridge Landmarks and 7 Scenes. 
```bash # Aachen-Day-Night-v1.1: # scene in 'day' 'night' # scene can also be 'all' python3 visloc.py --model_name DUSt3R_ViTLarge_BaseDecoder_512_dpt --dataset "VislocAachenDayNight('/path/to/prepared/Aachen-Day-Night-v1.1/', subscene='${scene}', pairsfile='fire_top50', topk=20)" --pnp_mode poselib --reprojection_error_diag_ratio 0.008 --output_dir /path/to/output/Aachen-Day-Night-v1.1/${scene}/loc # InLoc python3 visloc.py --model_name DUSt3R_ViTLarge_BaseDecoder_512_dpt --dataset "VislocInLoc('/path/to/prepared/InLoc/', pairsfile='pairs-query-netvlad40-temporal', topk=20)" --pnp_mode poselib --reprojection_error_diag_ratio 0.008 --output_dir /path/to/output/InLoc/loc # 7-scenes: # scene in 'chess' 'fire' 'heads' 'office' 'pumpkin' 'redkitchen' 'stairs' python3 visloc.py --model_name DUSt3R_ViTLarge_BaseDecoder_512_dpt --dataset "VislocSevenScenes('/path/to/prepared/7-scenes/', subscene='${scene}', pairsfile='APGeM-LM18_top20', topk=1)" --pnp_mode poselib --reprojection_error_diag_ratio 0.008 --output_dir /path/to/output/7-scenes/${scene}/loc # Cambridge Landmarks: # scene in 'ShopFacade' 'GreatCourt' 'KingsCollege' 'OldHospital' 'StMarysChurch' python3 visloc.py --model_name DUSt3R_ViTLarge_BaseDecoder_512_dpt --dataset "VislocCambridgeLandmarks('/path/to/prepared/Cambridge_Landmarks/', subscene='${scene}', pairsfile='APGeM-LM18_top50', topk=20)" --pnp_mode poselib --reprojection_error_diag_ratio 0.008 --output_dir /path/to/output/Cambridge_Landmarks/${scene}/loc ``` ================================================ FILE: dust3r_visloc/__init__.py ================================================ # Copyright (C) 2024-present Naver Corporation. All rights reserved. # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). ================================================ FILE: dust3r_visloc/datasets/__init__.py ================================================ # Copyright (C) 2024-present Naver Corporation. All rights reserved. 
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). from .sevenscenes import VislocSevenScenes from .cambridge_landmarks import VislocCambridgeLandmarks from .aachen_day_night import VislocAachenDayNight from .inloc import VislocInLoc ================================================ FILE: dust3r_visloc/datasets/aachen_day_night.py ================================================ # Copyright (C) 2024-present Naver Corporation. All rights reserved. # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). # # -------------------------------------------------------- # AachenDayNight dataloader # -------------------------------------------------------- import os from dust3r_visloc.datasets.base_colmap import BaseVislocColmapDataset class VislocAachenDayNight(BaseVislocColmapDataset): def __init__(self, root, subscene, pairsfile, topk=1, cache_sfm=False): assert subscene in [None, '', 'day', 'night', 'all'] self.subscene = subscene image_path = os.path.join(root, 'images') map_path = os.path.join(root, 'mapping/colmap/reconstruction') query_path = os.path.join(root, 'kapture', 'query') pairsfile_path = os.path.join(root, 'pairsfile/query', pairsfile + '.txt') super().__init__(image_path=image_path, map_path=map_path, query_path=query_path, pairsfile_path=pairsfile_path, topk=topk, cache_sfm=cache_sfm) self.scenes = [filename for filename in self.scenes if filename in self.pairs] if self.subscene == 'day' or self.subscene == 'night': self.scenes = [filename for filename in self.scenes if self.subscene in filename] ================================================ FILE: dust3r_visloc/datasets/base_colmap.py ================================================ # Copyright (C) 2024-present Naver Corporation. All rights reserved. # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
# # -------------------------------------------------------- # Base class for colmap / kapture # -------------------------------------------------------- import os import numpy as np from tqdm import tqdm import collections import pickle import PIL.Image import torch from scipy.spatial.transform import Rotation import torchvision.transforms as tvf from kapture.core import CameraType from kapture.io.csv import kapture_from_dir from kapture_localization.utils.pairsfile import get_ordered_pairs_from_file from dust3r_visloc.datasets.utils import cam_to_world_from_kapture, get_resize_function, rescale_points3d from dust3r_visloc.datasets.base_dataset import BaseVislocDataset from dust3r.datasets.utils.transforms import ImgNorm from dust3r.utils.geometry import colmap_to_opencv_intrinsics KaptureSensor = collections.namedtuple('Sensor', 'sensor_params camera_params') def kapture_to_opencv_intrinsics(sensor): """ Convert from Kapture to OpenCV parameters. Warning: we assume that the camera and pixel coordinates follow Colmap conventions here. Args: sensor: Kapture sensor """ sensor_type = sensor.sensor_params[0] if sensor_type == "SIMPLE_PINHOLE": # Simple pinhole model. # We still call OpenCV undistorsion however for code simplicity. w, h, f, cx, cy = sensor.camera_params k1 = 0 k2 = 0 p1 = 0 p2 = 0 fx = fy = f elif sensor_type == "PINHOLE": w, h, fx, fy, cx, cy = sensor.camera_params k1 = 0 k2 = 0 p1 = 0 p2 = 0 elif sensor_type == "SIMPLE_RADIAL": w, h, f, cx, cy, k1 = sensor.camera_params k2 = 0 p1 = 0 p2 = 0 fx = fy = f elif sensor_type == "RADIAL": w, h, f, cx, cy, k1, k2 = sensor.camera_params p1 = 0 p2 = 0 fx = fy = f elif sensor_type == "OPENCV": w, h, fx, fy, cx, cy, k1, k2, p1, p2 = sensor.camera_params else: raise NotImplementedError(f"Sensor type {sensor_type} is not supported yet.") cameraMatrix = np.asarray([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32) # We assume that Kapture data comes from Colmap: the origin is different. 
cameraMatrix = colmap_to_opencv_intrinsics(cameraMatrix) distCoeffs = np.asarray([k1, k2, p1, p2], dtype=np.float32) return cameraMatrix, distCoeffs, (w, h) def K_from_colmap(elems): sensor = KaptureSensor(elems, tuple(map(float, elems[1:]))) cameraMatrix, distCoeffs, (w, h) = kapture_to_opencv_intrinsics(sensor) res = dict(resolution=(w, h), intrinsics=cameraMatrix, distortion=distCoeffs) return res def pose_from_qwxyz_txyz(elems): qw, qx, qy, qz, tx, ty, tz = map(float, elems) pose = np.eye(4) pose[:3, :3] = Rotation.from_quat((qx, qy, qz, qw)).as_matrix() pose[:3, 3] = (tx, ty, tz) return np.linalg.inv(pose) # returns cam2world class BaseVislocColmapDataset(BaseVislocDataset): def __init__(self, image_path, map_path, query_path, pairsfile_path, topk=1, cache_sfm=False): super().__init__() self.topk = topk self.num_views = self.topk + 1 self.image_path = image_path self.cache_sfm = cache_sfm self._load_sfm(map_path) kdata_query = kapture_from_dir(query_path) assert kdata_query.records_camera is not None and kdata_query.trajectories is not None kdata_query_searchindex = {kdata_query.records_camera[(timestamp, sensor_id)]: (timestamp, sensor_id) for timestamp, sensor_id in kdata_query.records_camera.key_pairs()} self.query_data = {'kdata': kdata_query, 'searchindex': kdata_query_searchindex} self.pairs = get_ordered_pairs_from_file(pairsfile_path) self.scenes = kdata_query.records_camera.data_list() def _load_sfm(self, sfm_dir): sfm_cache_path = os.path.join(sfm_dir, 'dust3r_cache.pkl') if os.path.isfile(sfm_cache_path) and self.cache_sfm: with open(sfm_cache_path, "rb") as f: data = pickle.load(f) self.img_infos = data['img_infos'] self.points3D = data['points3D'] return # load cameras with open(os.path.join(sfm_dir, 'cameras.txt'), 'r') as f: raw = f.read().splitlines()[3:] # skip header intrinsics = {} for camera in tqdm(raw): camera = camera.split(' ') intrinsics[int(camera[0])] = K_from_colmap(camera[1:]) # load images with open(os.path.join(sfm_dir, 
'images.txt'), 'r') as f: raw = f.read().splitlines() raw = [line for line in raw if not line.startswith('#')] # skip header self.img_infos = {} for image, points in tqdm(zip(raw[0::2], raw[1::2]), total=len(raw) // 2): image = image.split(' ') points = points.split(' ') img_name = image[-1] current_points2D = {int(i): (float(x), float(y)) for i, x, y in zip(points[2::3], points[0::3], points[1::3]) if i != '-1'} self.img_infos[img_name] = dict(intrinsics[int(image[-2])], path=img_name, camera_pose=pose_from_qwxyz_txyz(image[1: -2]), sparse_pts2d=current_points2D) # load 3D points with open(os.path.join(sfm_dir, 'points3D.txt'), 'r') as f: raw = f.read().splitlines() raw = [line for line in raw if not line.startswith('#')] # skip header self.points3D = {} for point in tqdm(raw): point = point.split() self.points3D[int(point[0])] = tuple(map(float, point[1:4])) if self.cache_sfm: to_save = \ { 'img_infos': self.img_infos, 'points3D': self.points3D } with open(sfm_cache_path, "wb") as f: pickle.dump(to_save, f) def __len__(self): return len(self.scenes) def _get_view_query(self, imgname): kdata, searchindex = map(self.query_data.get, ['kdata', 'searchindex']) timestamp, camera_id = searchindex[imgname] camera_params = kdata.sensors[camera_id].camera_params if kdata.sensors[camera_id].camera_type == CameraType.SIMPLE_PINHOLE: W, H, f, cx, cy = camera_params k1 = 0 fx = fy = f elif kdata.sensors[camera_id].camera_type == CameraType.SIMPLE_RADIAL: W, H, f, cx, cy, k1 = camera_params fx = fy = f else: raise NotImplementedError('not implemented') W, H = int(W), int(H) intrinsics = np.float32([(fx, 0, cx), (0, fy, cy), (0, 0, 1)]) intrinsics = colmap_to_opencv_intrinsics(intrinsics) distortion = [k1, 0, 0, 0] if kdata.trajectories is not None and (timestamp, camera_id) in kdata.trajectories: cam_to_world = cam_to_world_from_kapture(kdata, timestamp, camera_id) else: cam_to_world = np.eye(4, dtype=np.float32) # Load RGB image rgb_image = 
PIL.Image.open(os.path.join(self.image_path, imgname)).convert('RGB') rgb_image.load() resize_func, _, to_orig = get_resize_function(self.maxdim, self.patch_size, H, W) rgb_tensor = resize_func(ImgNorm(rgb_image)) view = { 'intrinsics': intrinsics, 'distortion': distortion, 'cam_to_world': cam_to_world, 'rgb': rgb_image, 'rgb_rescaled': rgb_tensor, 'to_orig': to_orig, 'idx': 0, 'image_name': imgname } return view def _get_view_map(self, imgname, idx): infos = self.img_infos[imgname] rgb_image = PIL.Image.open(os.path.join(self.image_path, infos['path'])).convert('RGB') rgb_image.load() W, H = rgb_image.size intrinsics = infos['intrinsics'] intrinsics = colmap_to_opencv_intrinsics(intrinsics) distortion_coefs = infos['distortion'] pts2d = infos['sparse_pts2d'] sparse_pos2d = np.float32(list(pts2d.values())).reshape((-1, 2)) # pts2d from colmap sparse_pts3d = np.float32([self.points3D[i] for i in pts2d]).reshape((-1, 3)) # store full resolution 2D->3D sparse_pos2d_cv2 = sparse_pos2d.copy() sparse_pos2d_cv2[:, 0] -= 0.5 sparse_pos2d_cv2[:, 1] -= 0.5 sparse_pos2d_int = sparse_pos2d_cv2.round().astype(np.int64) valid = (sparse_pos2d_int[:, 0] >= 0) & (sparse_pos2d_int[:, 0] < W) & ( sparse_pos2d_int[:, 1] >= 0) & (sparse_pos2d_int[:, 1] < H) sparse_pos2d_int = sparse_pos2d_int[valid] # nan => invalid pts3d = np.full((H, W, 3), np.nan, dtype=np.float32) pts3d[sparse_pos2d_int[:, 1], sparse_pos2d_int[:, 0]] = sparse_pts3d[valid] pts3d = torch.from_numpy(pts3d) cam_to_world = infos['camera_pose'] # cam2world # also store resized resolution 2D->3D resize_func, to_resize, to_orig = get_resize_function(self.maxdim, self.patch_size, H, W) rgb_tensor = resize_func(ImgNorm(rgb_image)) HR, WR = rgb_tensor.shape[1:] _, _, pts3d_rescaled, valid_rescaled = rescale_points3d(sparse_pos2d_cv2, sparse_pts3d, to_resize, HR, WR) pts3d_rescaled = torch.from_numpy(pts3d_rescaled) valid_rescaled = torch.from_numpy(valid_rescaled) view = { 'intrinsics': intrinsics, 'distortion': 
distortion_coefs, 'cam_to_world': cam_to_world, 'rgb': rgb_image, "pts3d": pts3d, "valid": pts3d.sum(dim=-1).isfinite(), 'rgb_rescaled': rgb_tensor, "pts3d_rescaled": pts3d_rescaled, "valid_rescaled": valid_rescaled, 'to_orig': to_orig, 'idx': idx, 'image_name': imgname } return view def __getitem__(self, idx): assert self.maxdim is not None and self.patch_size is not None query_image = self.scenes[idx] map_images = [p[0] for p in self.pairs[query_image][:self.topk]] views = [] views.append(self._get_view_query(query_image)) for idx, map_image in enumerate(map_images): views.append(self._get_view_map(map_image, idx + 1)) return views ================================================ FILE: dust3r_visloc/datasets/base_dataset.py ================================================ # Copyright (C) 2024-present Naver Corporation. All rights reserved. # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). # # -------------------------------------------------------- # Base class # -------------------------------------------------------- class BaseVislocDataset: def __init__(self): pass def set_resolution(self, model): self.maxdim = max(model.patch_embed.img_size) self.patch_size = model.patch_embed.patch_size def __len__(self): raise NotImplementedError() def __getitem__(self, idx): raise NotImplementedError() ================================================ FILE: dust3r_visloc/datasets/cambridge_landmarks.py ================================================ # Copyright (C) 2024-present Naver Corporation. All rights reserved. # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
# # -------------------------------------------------------- # Cambridge Landmarks dataloader # -------------------------------------------------------- import os from dust3r_visloc.datasets.base_colmap import BaseVislocColmapDataset class VislocCambridgeLandmarks (BaseVislocColmapDataset): def __init__(self, root, subscene, pairsfile, topk=1, cache_sfm=False): image_path = os.path.join(root, subscene) map_path = os.path.join(root, 'mapping', subscene, 'colmap/reconstruction') query_path = os.path.join(root, 'kapture', subscene, 'query') pairsfile_path = os.path.join(root, subscene, 'pairsfile/query', pairsfile + '.txt') super().__init__(image_path=image_path, map_path=map_path, query_path=query_path, pairsfile_path=pairsfile_path, topk=topk, cache_sfm=cache_sfm) ================================================ FILE: dust3r_visloc/datasets/inloc.py ================================================ # Copyright (C) 2024-present Naver Corporation. All rights reserved. # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
# # -------------------------------------------------------- # InLoc dataloader # -------------------------------------------------------- import os import numpy as np import torch import PIL.Image import scipy.io import kapture from kapture.io.csv import kapture_from_dir from kapture_localization.utils.pairsfile import get_ordered_pairs_from_file from dust3r_visloc.datasets.utils import cam_to_world_from_kapture, get_resize_function, rescale_points3d from dust3r_visloc.datasets.base_dataset import BaseVislocDataset from dust3r.datasets.utils.transforms import ImgNorm from dust3r.utils.geometry import xy_grid, geotrf def read_alignments(path_to_alignment): aligns = {} with open(path_to_alignment, "r") as fid: while True: line = fid.readline() if not line: break if len(line) == 4: trans_nr = line[:-1] while line != 'After general icp:\n': line = fid.readline() line = fid.readline() p = [] for i in range(4): elems = line.split(' ') line = fid.readline() for e in elems: if len(e) != 0: p.append(float(e)) P = np.array(p).reshape(4, 4) aligns[trans_nr] = P return aligns class VislocInLoc(BaseVislocDataset): def __init__(self, root, pairsfile, topk=1): super().__init__() self.root = root self.topk = topk self.num_views = self.topk + 1 self.maxdim = None self.patch_size = None query_path = os.path.join(self.root, 'query') kdata_query = kapture_from_dir(query_path) assert kdata_query.records_camera is not None kdata_query_searchindex = {kdata_query.records_camera[(timestamp, sensor_id)]: (timestamp, sensor_id) for timestamp, sensor_id in kdata_query.records_camera.key_pairs()} self.query_data = {'path': query_path, 'kdata': kdata_query, 'searchindex': kdata_query_searchindex} map_path = os.path.join(self.root, 'mapping') kdata_map = kapture_from_dir(map_path) assert kdata_map.records_camera is not None and kdata_map.trajectories is not None kdata_map_searchindex = {kdata_map.records_camera[(timestamp, sensor_id)]: (timestamp, sensor_id) for timestamp, sensor_id in 
kdata_map.records_camera.key_pairs()} self.map_data = {'path': map_path, 'kdata': kdata_map, 'searchindex': kdata_map_searchindex} try: self.pairs = get_ordered_pairs_from_file(os.path.join(self.root, 'pairfiles/query', pairsfile + '.txt')) except Exception as e: # if using pairs from hloc self.pairs = {} with open(os.path.join(self.root, 'pairfiles/query', pairsfile + '.txt'), 'r') as fid: lines = fid.readlines() for line in lines: splits = line.rstrip("\n\r").split(" ") self.pairs.setdefault(splits[0].replace('query/', ''), []).append( (splits[1].replace('database/cutouts/', ''), 1.0) ) self.scenes = kdata_query.records_camera.data_list() self.aligns_DUC1 = read_alignments(os.path.join(self.root, 'mapping/DUC1_alignment/all_transformations.txt')) self.aligns_DUC2 = read_alignments(os.path.join(self.root, 'mapping/DUC2_alignment/all_transformations.txt')) def __len__(self): return len(self.scenes) def __getitem__(self, idx): assert self.maxdim is not None and self.patch_size is not None query_image = self.scenes[idx] map_images = [p[0] for p in self.pairs[query_image][:self.topk]] views = [] dataarray = [(query_image, self.query_data, False)] + [(map_image, self.map_data, True) for map_image in map_images] for idx, (imgname, data, should_load_depth) in enumerate(dataarray): imgpath, kdata, searchindex = map(data.get, ['path', 'kdata', 'searchindex']) timestamp, camera_id = searchindex[imgname] # for InLoc, SIMPLE_PINHOLE camera_params = kdata.sensors[camera_id].camera_params W, H, f, cx, cy = camera_params distortion = [0, 0, 0, 0] intrinsics = np.float32([(f, 0, cx), (0, f, cy), (0, 0, 1)]) if kdata.trajectories is not None and (timestamp, camera_id) in kdata.trajectories: cam_to_world = cam_to_world_from_kapture(kdata, timestamp, camera_id) else: cam_to_world = np.eye(4, dtype=np.float32) # Load RGB image rgb_image = PIL.Image.open(os.path.join(imgpath, 'sensors/records_data', imgname)).convert('RGB') rgb_image.load() W, H = rgb_image.size resize_func, 
to_resize, to_orig = get_resize_function(self.maxdim, self.patch_size, H, W) rgb_tensor = resize_func(ImgNorm(rgb_image)) view = { 'intrinsics': intrinsics, 'distortion': distortion, 'cam_to_world': cam_to_world, 'rgb': rgb_image, 'rgb_rescaled': rgb_tensor, 'to_orig': to_orig, 'idx': idx, 'image_name': imgname } # Load depthmap if should_load_depth: depthmap_filename = os.path.join(imgpath, 'sensors/records_data', imgname + '.mat') depthmap = scipy.io.loadmat(depthmap_filename) pt3d_cut = depthmap['XYZcut'] scene_id = imgname.replace('\\', '/').split('/')[1] if imgname.startswith('DUC1'): pts3d_full = geotrf(self.aligns_DUC1[scene_id], pt3d_cut) else: pts3d_full = geotrf(self.aligns_DUC2[scene_id], pt3d_cut) pts3d_valid = np.isfinite(pts3d_full.sum(axis=-1)) pts3d = pts3d_full[pts3d_valid] pts2d_int = xy_grid(W, H)[pts3d_valid] pts2d = pts2d_int.astype(np.float64) # nan => invalid pts3d_full[~pts3d_valid] = np.nan pts3d_full = torch.from_numpy(pts3d_full) view['pts3d'] = pts3d_full view["valid"] = pts3d_full.sum(dim=-1).isfinite() HR, WR = rgb_tensor.shape[1:] _, _, pts3d_rescaled, valid_rescaled = rescale_points3d(pts2d, pts3d, to_resize, HR, WR) pts3d_rescaled = torch.from_numpy(pts3d_rescaled) valid_rescaled = torch.from_numpy(valid_rescaled) view['pts3d_rescaled'] = pts3d_rescaled view["valid_rescaled"] = valid_rescaled views.append(view) return views ================================================ FILE: dust3r_visloc/datasets/sevenscenes.py ================================================ # Copyright (C) 2024-present Naver Corporation. All rights reserved. # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
# # -------------------------------------------------------- # 7 Scenes dataloader # -------------------------------------------------------- import os import numpy as np import torch import PIL.Image import kapture from kapture.io.csv import kapture_from_dir from kapture_localization.utils.pairsfile import get_ordered_pairs_from_file from kapture.io.records import depth_map_from_file from dust3r_visloc.datasets.utils import cam_to_world_from_kapture, get_resize_function, rescale_points3d from dust3r_visloc.datasets.base_dataset import BaseVislocDataset from dust3r.datasets.utils.transforms import ImgNorm from dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates, xy_grid, geotrf class VislocSevenScenes(BaseVislocDataset): def __init__(self, root, subscene, pairsfile, topk=1): super().__init__() self.root = root self.subscene = subscene self.topk = topk self.num_views = self.topk + 1 self.maxdim = None self.patch_size = None query_path = os.path.join(self.root, subscene, 'query') kdata_query = kapture_from_dir(query_path) assert kdata_query.records_camera is not None and kdata_query.trajectories is not None and kdata_query.rigs is not None kapture.rigs_remove_inplace(kdata_query.trajectories, kdata_query.rigs) kdata_query_searchindex = {kdata_query.records_camera[(timestamp, sensor_id)]: (timestamp, sensor_id) for timestamp, sensor_id in kdata_query.records_camera.key_pairs()} self.query_data = {'path': query_path, 'kdata': kdata_query, 'searchindex': kdata_query_searchindex} map_path = os.path.join(self.root, subscene, 'mapping') kdata_map = kapture_from_dir(map_path) assert kdata_map.records_camera is not None and kdata_map.trajectories is not None and kdata_map.rigs is not None kapture.rigs_remove_inplace(kdata_map.trajectories, kdata_map.rigs) kdata_map_searchindex = {kdata_map.records_camera[(timestamp, sensor_id)]: (timestamp, sensor_id) for timestamp, sensor_id in kdata_map.records_camera.key_pairs()} self.map_data = {'path': map_path, 
'kdata': kdata_map, 'searchindex': kdata_map_searchindex} self.pairs = get_ordered_pairs_from_file(os.path.join(self.root, subscene, 'pairfiles/query', pairsfile + '.txt')) self.scenes = kdata_query.records_camera.data_list() def __len__(self): return len(self.scenes) def __getitem__(self, idx): assert self.maxdim is not None and self.patch_size is not None query_image = self.scenes[idx] map_images = [p[0] for p in self.pairs[query_image][:self.topk]] views = [] dataarray = [(query_image, self.query_data, False)] + [(map_image, self.map_data, True) for map_image in map_images] for idx, (imgname, data, should_load_depth) in enumerate(dataarray): imgpath, kdata, searchindex = map(data.get, ['path', 'kdata', 'searchindex']) timestamp, camera_id = searchindex[imgname] # for 7scenes, SIMPLE_PINHOLE camera_params = kdata.sensors[camera_id].camera_params W, H, f, cx, cy = camera_params distortion = [0, 0, 0, 0] intrinsics = np.float32([(f, 0, cx), (0, f, cy), (0, 0, 1)]) cam_to_world = cam_to_world_from_kapture(kdata, timestamp, camera_id) # Load RGB image rgb_image = PIL.Image.open(os.path.join(imgpath, 'sensors/records_data', imgname)).convert('RGB') rgb_image.load() W, H = rgb_image.size resize_func, to_resize, to_orig = get_resize_function(self.maxdim, self.patch_size, H, W) rgb_tensor = resize_func(ImgNorm(rgb_image)) view = { 'intrinsics': intrinsics, 'distortion': distortion, 'cam_to_world': cam_to_world, 'rgb': rgb_image, 'rgb_rescaled': rgb_tensor, 'to_orig': to_orig, 'idx': idx, 'image_name': imgname } # Load depthmap if should_load_depth: depthmap_filename = os.path.join(imgpath, 'sensors/records_data', imgname.replace('color.png', 'depth.reg')) depthmap = depth_map_from_file(depthmap_filename, (int(W), int(H))).astype(np.float32) pts3d_full, pts3d_valid = depthmap_to_absolute_camera_coordinates(depthmap, intrinsics, cam_to_world) pts3d = pts3d_full[pts3d_valid] pts2d_int = xy_grid(W, H)[pts3d_valid] pts2d = pts2d_int.astype(np.float64) # nan => invalid 
pts3d_full[~pts3d_valid] = np.nan pts3d_full = torch.from_numpy(pts3d_full) view['pts3d'] = pts3d_full view["valid"] = pts3d_full.sum(dim=-1).isfinite() HR, WR = rgb_tensor.shape[1:] _, _, pts3d_rescaled, valid_rescaled = rescale_points3d(pts2d, pts3d, to_resize, HR, WR) pts3d_rescaled = torch.from_numpy(pts3d_rescaled) valid_rescaled = torch.from_numpy(valid_rescaled) view['pts3d_rescaled'] = pts3d_rescaled view["valid_rescaled"] = valid_rescaled views.append(view) return views ================================================ FILE: dust3r_visloc/datasets/utils.py ================================================ # Copyright (C) 2024-present Naver Corporation. All rights reserved. # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). # # -------------------------------------------------------- # dataset utilities # -------------------------------------------------------- import numpy as np import quaternion import torchvision.transforms as tvf from dust3r.utils.geometry import geotrf def cam_to_world_from_kapture(kdata, timestamp, camera_id): camera_to_world = kdata.trajectories[timestamp, camera_id].inverse() camera_pose = np.eye(4, dtype=np.float32) camera_pose[:3, :3] = quaternion.as_rotation_matrix(camera_to_world.r) camera_pose[:3, 3] = camera_to_world.t_raw return camera_pose ratios_resolutions = { 224: {1.0: [224, 224]}, 512: {4 / 3: [512, 384], 32 / 21: [512, 336], 16 / 9: [512, 288], 2 / 1: [512, 256], 16 / 5: [512, 160]} } def get_HW_resolution(H, W, maxdim, patchsize=16): assert maxdim in ratios_resolutions, "Error, maxdim can only be 224 or 512 for now. Other maxdims not implemented yet." 
ratios_resolutions_maxdim = ratios_resolutions[maxdim] mindims = set([min(res) for res in ratios_resolutions_maxdim.values()]) ratio = W / H ref_ratios = np.array([*(ratios_resolutions_maxdim.keys())]) islandscape = (W >= H) if islandscape: diff = np.abs(ratio - ref_ratios) else: diff = np.abs(ratio - (1 / ref_ratios)) selkey = ref_ratios[np.argmin(diff)] res = ratios_resolutions_maxdim[selkey] # check patchsize and make sure output resolution is a multiple of patchsize if isinstance(patchsize, tuple): assert len(patchsize) == 2 and isinstance(patchsize[0], int) and isinstance( patchsize[1], int), "What is your patchsize format? Expected a single int or a tuple of two ints." assert patchsize[0] == patchsize[1], "Error, non square patches not managed" patchsize = patchsize[0] assert max(res) == maxdim assert min(res) in mindims return res[::-1] if islandscape else res # return HW def get_resize_function(maxdim, patch_size, H, W, is_mask=False): if [max(H, W), min(H, W)] in ratios_resolutions[maxdim].values(): return lambda x: x, np.eye(3), np.eye(3) else: target_HW = get_HW_resolution(H, W, maxdim=maxdim, patchsize=patch_size) ratio = W / H target_ratio = target_HW[1] / target_HW[0] to_orig_crop = np.eye(3) to_rescaled_crop = np.eye(3) if abs(ratio - target_ratio) < np.finfo(np.float32).eps: crop_W = W crop_H = H elif ratio - target_ratio < 0: crop_W = W crop_H = int(W / target_ratio) to_orig_crop[1, 2] = (H - crop_H) / 2.0 to_rescaled_crop[1, 2] = -(H - crop_H) / 2.0 else: crop_W = int(H * target_ratio) crop_H = H to_orig_crop[0, 2] = (W - crop_W) / 2.0 to_rescaled_crop[0, 2] = - (W - crop_W) / 2.0 crop_op = tvf.CenterCrop([crop_H, crop_W]) if is_mask: resize_op = tvf.Resize(size=target_HW, interpolation=tvf.InterpolationMode.NEAREST_EXACT) else: resize_op = tvf.Resize(size=target_HW) to_orig_resize = np.array([[crop_W / target_HW[1], 0, 0], [0, crop_H / target_HW[0], 0], [0, 0, 1]]) to_rescaled_resize = np.array([[target_HW[1] / crop_W, 0, 0], [0, target_HW[0] / 
crop_H, 0], [0, 0, 1]]) op = tvf.Compose([crop_op, resize_op]) return op, to_rescaled_resize @ to_rescaled_crop, to_orig_crop @ to_orig_resize def rescale_points3d(pts2d, pts3d, to_resize, HR, WR): # rescale pts2d as floats # to colmap, so that the image is in [0, D] -> [0, NewD] pts2d = pts2d.copy() pts2d[:, 0] += 0.5 pts2d[:, 1] += 0.5 pts2d_rescaled = geotrf(to_resize, pts2d, norm=True) pts2d_rescaled_int = pts2d_rescaled.copy() # convert back to cv2 before round [-0.5, 0.5] -> pixel 0 pts2d_rescaled_int[:, 0] -= 0.5 pts2d_rescaled_int[:, 1] -= 0.5 pts2d_rescaled_int = pts2d_rescaled_int.round().astype(np.int64) # update valid (remove cropped regions) valid_rescaled = (pts2d_rescaled_int[:, 0] >= 0) & (pts2d_rescaled_int[:, 0] < WR) & ( pts2d_rescaled_int[:, 1] >= 0) & (pts2d_rescaled_int[:, 1] < HR) pts2d_rescaled_int = pts2d_rescaled_int[valid_rescaled] # rebuild pts3d from rescaled ps2d poses pts3d_rescaled = np.full((HR, WR, 3), np.nan, dtype=np.float32) # pts3d in 512 x something pts3d_rescaled[pts2d_rescaled_int[:, 1], pts2d_rescaled_int[:, 0]] = pts3d[valid_rescaled] return pts2d_rescaled, pts2d_rescaled_int, pts3d_rescaled, np.isfinite(pts3d_rescaled.sum(axis=-1)) ================================================ FILE: dust3r_visloc/evaluation.py ================================================ # Copyright (C) 2024-present Naver Corporation. All rights reserved. # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
# # -------------------------------------------------------- # evaluation utilities # -------------------------------------------------------- import numpy as np import quaternion import torch import roma import collections import os def aggregate_stats(info_str, pose_errors, angular_errors): stats = collections.Counter() median_pos_error = np.median(pose_errors) median_angular_error = np.median(angular_errors) out_str = f'{info_str}: {len(pose_errors)} images - {median_pos_error=}, {median_angular_error=}' for trl_thr, ang_thr in [(0.1, 1), (0.25, 2), (0.5, 5), (5, 10)]: for pose_error, angular_error in zip(pose_errors, angular_errors): correct_for_this_threshold = (pose_error < trl_thr) and (angular_error < ang_thr) stats[trl_thr, ang_thr] += correct_for_this_threshold stats = {f'acc@{key[0]:g}m,{key[1]}deg': 100 * val / len(pose_errors) for key, val in stats.items()} for metric, perf in stats.items(): out_str += f' - {metric:12s}={float(perf):.3f}' return out_str def get_pose_error(pr_camtoworld, gt_cam_to_world): abs_transl_error = torch.linalg.norm(torch.tensor(pr_camtoworld[:3, 3]) - torch.tensor(gt_cam_to_world[:3, 3])) abs_angular_error = roma.rotmat_geodesic_distance(torch.tensor(pr_camtoworld[:3, :3]), torch.tensor(gt_cam_to_world[:3, :3])) * 180 / np.pi return abs_transl_error, abs_angular_error def export_results(output_dir, xp_label, query_names, poses_pred): if output_dir is not None: os.makedirs(output_dir, exist_ok=True) lines = "" lines_ltvl = "" for query_name, pr_querycam_to_world in zip(query_names, poses_pred): if pr_querycam_to_world is None: pr_world_to_querycam = np.eye(4) else: pr_world_to_querycam = np.linalg.inv(pr_querycam_to_world) query_shortname = os.path.basename(query_name) pr_world_to_querycam_q = quaternion.from_rotation_matrix(pr_world_to_querycam[:3, :3]) pr_world_to_querycam_t = pr_world_to_querycam[:3, 3] line_pose = quaternion.as_float_array(pr_world_to_querycam_q).tolist() + \ pr_world_to_querycam_t.flatten().tolist() 
def _to_colmap_pixel_coords(pts2D):
    """Return a copy of ``pts2D`` with 0.5 added to its first two columns.

    The input array is not modified. This is the cv2 -> colmap pixel
    convention shift applied before calling poselib / pycolmap.
    """
    shifted = np.copy(pts2D)
    shifted[:, 0] += 0.5
    shifted[:, 1] += 0.5
    return shifted


def _build_colmap_camera(K, distortion, img_size):
    """Build a COLMAP-style camera dict (PINHOLE or OPENCV model) from ``K``.

    When ``img_size`` is None, width/height are guessed as twice the
    principal point.
    """
    colmap_intrinsics = opencv_to_colmap_intrinsics(K)
    fx = colmap_intrinsics[0, 0]
    fy = colmap_intrinsics[1, 1]
    cx = colmap_intrinsics[0, 2]
    cy = colmap_intrinsics[1, 2]
    if img_size is not None:
        width, height = img_size[0], img_size[1]
    else:
        width, height = int(cx * 2), int(cy * 2)

    if distortion is None:
        return {'model': 'PINHOLE', 'width': width, 'height': height,
                'params': [fx, fy, cx, cy]}
    # list(...) so that ndarray / tuple distortions are concatenated instead
    # of being added element-wise (ndarray) or raising (tuple).
    return {'model': 'OPENCV', 'width': width, 'height': height,
            'params': [fx, fy, cx, cy] + list(distortion)}


def _world2cam_to_cam2world(Rt):
    """Complete a 3x4 world2cam ``[R|t]`` to 4x4 and return its inverse."""
    return np.linalg.inv(np.r_[Rt, [(0, 0, 0, 1)]])


def run_pnp(pts2D, pts3D, K, distortion = None, mode='cv2', reprojectionError=5, img_size = None):
    """
    Estimate the camera pose from 2D-3D correspondences with RANSAC PnP.

    use OPENCV model for distortion (4 values)

    Args:
        pts2D: array of 2D points, indexed as ``pts2D[:, 0]`` / ``pts2D[:, 1]``.
        pts3D: array of the corresponding 3D points.
        K: camera intrinsics matrix (passed to cv2 as-is, converted with
            ``opencv_to_colmap_intrinsics`` for the other backends).
        distortion: optional sequence of distortion coefficients, or None.
        mode: backend to use, one of 'cv2', 'poselib', 'pycolmap'.
        reprojectionError: RANSAC inlier threshold given to the backend.
        img_size: optional ``(width, height)``; required by 'pycolmap'.

    Returns:
        ``(success, pose)`` where ``pose`` is the 4x4 cam2world matrix, or
        ``(False, None)`` when there are 4 points or fewer, when the solver
        fails, or when any exception is raised while solving (the error is
        printed, not propagated).

    Raises:
        AssertionError: if ``mode`` is not one of the supported backends.
    """
    assert mode in ['cv2', 'poselib', 'pycolmap'], f'unknown pnp mode: {mode}'
    confidence = 0.9999
    try:
        if len(pts2D) <= 4:
            return False, None

        if mode == "cv2":
            if distortion is not None:
                # undistort the observations once, then solve without distortion
                undistorted = cv2.undistortPoints(np.copy(pts2D), K, np.array(distortion), R=None, P=K)
                pts2D = undistorted.reshape((-1, 2))

            success, r_pose, t_pose, _ = cv2.solvePnPRansac(pts3D, pts2D, K, None, flags=cv2.SOLVEPNP_SQPNP,
                                                            iterationsCount=10_000,
                                                            reprojectionError=reprojectionError,
                                                            confidence=confidence)
            if not success:
                return False, None
            r_pose = cv2.Rodrigues(r_pose)[0]  # world2cam == world2cam2
            return True, _world2cam_to_cam2world(np.c_[r_pose, t_pose])  # cam2toworld

        if mode == "poselib":
            if not HAS_POSELIB:
                raise RuntimeError('poselib is not available')
            # NOTE: `Camera` struct currently contains `width`/`height` fields,
            # however these are not used anywhere in the code-base and are provided simply to be consistent with COLMAP.
            # so we put garbage in there
            camera = _build_colmap_camera(K, distortion, img_size)
            pose, _ = poselib.estimate_absolute_pose(_to_colmap_pixel_coords(pts2D), pts3D, camera,
                                                     {'max_reproj_error': reprojectionError,
                                                      'max_iterations': 10_000,
                                                      'success_prob': confidence}, {})
            if pose is None:
                return False, None
            return True, _world2cam_to_cam2world(pose.Rt)  # pose.Rt is (3x4) world2cam

        # mode == "pycolmap"
        if not HAS_PYCOLMAP:
            raise RuntimeError('pycolmap (>= 0.5.0) is not available')
        if img_size is None:
            raise ValueError('img_size is required for pycolmap mode')
        camera_dict = _build_colmap_camera(K, distortion, img_size)
        pycolmap_camera = pycolmap.Camera(
            model=camera_dict['model'], width=camera_dict['width'], height=camera_dict['height'],
            params=camera_dict['params'])

        pycolmap_estimation_options = dict(ransac=dict(max_error=reprojectionError, min_inlier_ratio=0.01,
                                                       min_num_trials=1000, max_num_trials=100000,
                                                       confidence=confidence))
        pycolmap_refinement_options = dict(refine_focal_length=False, refine_extra_params=False)
        ret = pycolmap.absolute_pose_estimation(_to_colmap_pixel_coords(pts2D), pts3D, pycolmap_camera,
                                                estimation_options=pycolmap_estimation_options,
                                                refinement_options=pycolmap_refinement_options)
        if ret is None or not ret['num_inliers'] > 0:
            return False, None

        # `matrix` is a method in some pycolmap versions and a plain attribute
        # in others; support both everywhere it is read.
        cam_from_world = ret['cam_from_world'].matrix
        if callable(cam_from_world):
            cam_from_world = cam_from_world()
        return True, _world2cam_to_cam2world(cam_from_world)
    except Exception as e:
        # best effort: a failing solver counts as a failed localization
        print(f'error during pnp: {e}')
        return False, None
def get_args_parser():
    """Build the command-line parser of the visloc evaluation script.

    Exactly one of ``--weights`` / ``--model_name`` must be given, and
    ``--reprojection_error`` / ``--reprojection_error_diag_ratio`` are
    mutually exclusive.

    Returns:
        argparse.ArgumentParser: the configured (not yet parsed) parser.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, required=True, help="visloc dataset to eval")
    parser_weights = parser.add_mutually_exclusive_group(required=True)
    parser_weights.add_argument("--weights", type=str, help="path to the model weights", default=None)
    parser_weights.add_argument("--model_name", type=str, help="name of the model weights",
                                choices=["DUSt3R_ViTLarge_BaseDecoder_512_dpt",
                                         "DUSt3R_ViTLarge_BaseDecoder_512_linear",
                                         "DUSt3R_ViTLarge_BaseDecoder_224_linear"])
    # points are kept when conf >= threshold, so the *lower* values are the invalid ones
    parser.add_argument("--confidence_threshold", type=float, default=3.0,
                        help="confidence values lower than threshold are invalid")
    parser.add_argument("--device", type=str, default='cuda', help="pytorch device")
    parser.add_argument("--pnp_mode", type=str, default="cv2", choices=['cv2', 'poselib', 'pycolmap'],
                        help="pnp lib to use")
    parser_reproj = parser.add_mutually_exclusive_group()
    parser_reproj.add_argument("--reprojection_error", type=float, default=5.0, help="pnp reprojection error")
    parser_reproj.add_argument("--reprojection_error_diag_ratio", type=float, default=None,
                               help="pnp reprojection error as a ratio of the diagonal of the image")

    parser.add_argument("--pnp_max_points", type=int, default=100_000, help="pnp maximum number of points kept")
    parser.add_argument("--viz_matches", type=int, default=0, help="debug matches")

    parser.add_argument("--output_dir", type=str, default=None, help="output path")
    parser.add_argument("--output_label", type=str, default='', help="prefix for results files")
    return parser


def _rescale_to_original(to_orig, pts2d):
    """Map pixel coordinates through ``to_orig`` and return them as float64.

    The points are shifted by +0.5 (cv2 -> colmap convention), transformed
    with ``geotrf(to_orig, ., norm=True)``, then shifted back by -0.5.
    """
    pts = pts2d.astype(np.float64)
    pts[:, 0] += 0.5
    pts[:, 1] += 0.5
    pts = geotrf(to_orig, pts, norm=True)
    pts[:, 0] -= 0.5
    pts[:, 1] -= 0.5
    return pts


def _show_matches(query_rgb, map_rgb, matches_im0, matches_im1, num_matches, n_viz):
    """Plot ``n_viz`` evenly spaced matches on the two images side by side.

    Blocks until the matplotlib window is closed.
    """
    from matplotlib import pyplot as pl

    viz_imgs = [np.array(query_rgb), np.array(map_rgb)]
    match_idx_to_viz = np.round(np.linspace(0, num_matches - 1, n_viz)).astype(int)
    viz_matches_im0, viz_matches_im1 = matches_im0[match_idx_to_viz], matches_im1[match_idx_to_viz]

    # pad the shorter image at the bottom so both can be concatenated horizontally
    H0, W0, H1, W1 = *viz_imgs[0].shape[:2], *viz_imgs[1].shape[:2]
    img0 = np.pad(viz_imgs[0], ((0, max(H1 - H0, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
    img1 = np.pad(viz_imgs[1], ((0, max(H0 - H1, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
    img = np.concatenate((img0, img1), axis=1)
    pl.figure()
    pl.imshow(img)
    cmap = pl.get_cmap('jet')
    color_denominator = max(n_viz - 1, 1)  # avoid a division by zero when n_viz == 1
    for i in range(n_viz):
        (x0, y0), (x1, y1) = viz_matches_im0[i].T, viz_matches_im1[i].T
        pl.plot([x0, x1 + W0], [y0, y1], '-+', color=cmap(i / color_denominator), scalex=False, scaley=False)
    pl.show(block=True)


if __name__ == '__main__':
    parser = get_args_parser()
    args = parser.parse_args()
    conf_thr = args.confidence_threshold
    device = args.device
    pnp_mode = args.pnp_mode
    reprojection_error = args.reprojection_error
    reprojection_error_diag_ratio = args.reprojection_error_diag_ratio
    pnp_max_points = args.pnp_max_points
    viz_matches = args.viz_matches

    if args.weights is not None:
        weights_path = args.weights
    else:
        weights_path = "naver/" + args.model_name
    model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(args.device)

    # SECURITY NOTE: --dataset is evaluated as a Python expression (it is meant to
    # be a dataset constructor call); never run this script with untrusted arguments.
    dataset = eval(args.dataset)
    dataset.set_resolution(model)

    query_names = []
    poses_pred = []
    pose_errors = []
    angular_errors = []
    for idx in tqdm(range(len(dataset))):
        views = dataset[(idx)]  # 0 is the query
        query_view = views[0]
        map_views = views[1:]
        query_names.append(query_view['image_name'])

        query_pts2d = []
        query_pts3d = []
        for map_view in map_views:
            # prepare batch (dedicated loop variable: do not rebind the dataset index)
            imgs = []
            for view_idx, img in enumerate([query_view['rgb_rescaled'], map_view['rgb_rescaled']]):
                imgs.append(dict(img=img.unsqueeze(0), true_shape=np.int32([img.shape[1:]]),
                                 idx=view_idx, instance=str(view_idx)))
            output = inference([tuple(imgs)], model, device, batch_size=1, verbose=False)
            pred1, pred2 = output['pred1'], output['pred2']
            confidence_masks = [pred1['conf'].squeeze(0) >= conf_thr,
                                (pred2['conf'].squeeze(0) >= conf_thr) & map_view['valid_rescaled']]
            pts3d = [pred1['pts3d'].squeeze(0), pred2['pts3d_in_other_view'].squeeze(0)]

            # find 2D-2D matches between the two images
            pts2d_list, pts3d_list = [], []
            for i in range(2):
                conf_i = confidence_masks[i].cpu().numpy()
                true_shape_i = imgs[i]['true_shape'][0]
                pts2d_list.append(xy_grid(true_shape_i[1], true_shape_i[0])[conf_i])
                pts3d_list.append(pts3d[i].detach().cpu().numpy()[conf_i])

            PQ, PM = pts3d_list[0], pts3d_list[1]
            if len(PQ) == 0 or len(PM) == 0:
                continue
            reciprocal_in_PM, nnM_in_PQ, num_matches = find_reciprocal_matches(PQ, PM)
            if viz_matches > 0:
                print(f'found {num_matches} matches')
            matches_im1 = pts2d_list[1][reciprocal_in_PM]
            matches_im0 = pts2d_list[0][nnM_in_PQ][reciprocal_in_PM]
            # look up the 3D points with the still-integer (x, y) map pixels
            valid_pts3d = map_view['pts3d_rescaled'][matches_im1[:, 1], matches_im1[:, 0]]

            # rescale coordinates back to the original image
            matches_im0 = _rescale_to_original(query_view['to_orig'], matches_im0)
            # NOTE(review): the map-image matches are rescaled with the *query* transform;
            # this only affects the visualization below. map_view['to_orig'] looks like the
            # intended transform - confirm that the dataset provides it before changing.
            matches_im1 = _rescale_to_original(query_view['to_orig'], matches_im1)

            # visualize a few matches (nothing to index when there is no match)
            if viz_matches > 0 and num_matches > 0:
                _show_matches(query_view['rgb'], map_view['rgb'], matches_im0, matches_im1,
                              num_matches, viz_matches)

            if len(valid_pts3d) > 0:
                query_pts3d.append(valid_pts3d.cpu().numpy())
                query_pts2d.append(matches_im0)

        if len(query_pts2d) == 0:
            success = False
            pr_querycam_to_world = None
        else:
            query_pts2d = np.concatenate(query_pts2d, axis=0).astype(np.float32)
            query_pts3d = np.concatenate(query_pts3d, axis=0)
            if len(query_pts2d) > pnp_max_points:
                # random subsample to bound the pnp cost
                idxs = random.sample(range(len(query_pts2d)), pnp_max_points)
                query_pts3d = query_pts3d[idxs]
                query_pts2d = query_pts2d[idxs]

            W, H = query_view['rgb'].size
            if reprojection_error_diag_ratio is not None:
                reprojection_error_img = reprojection_error_diag_ratio * math.sqrt(W**2 + H**2)
            else:
                reprojection_error_img = reprojection_error
            success, pr_querycam_to_world = run_pnp(query_pts2d, query_pts3d,
                                                    query_view['intrinsics'], query_view['distortion'],
                                                    pnp_mode, reprojection_error_img, img_size=[W, H])

        if not success:
            # failed localizations count as infinite errors
            abs_transl_error = float('inf')
            abs_angular_error = float('inf')
        else:
            abs_transl_error, abs_angular_error = get_pose_error(pr_querycam_to_world, query_view['cam_to_world'])

        pose_errors.append(abs_transl_error)
        angular_errors.append(abs_angular_error)
        poses_pred.append(pr_querycam_to_world)

    xp_label = f'tol_conf_{conf_thr}'
    if args.output_label:
        xp_label = args.output_label + '_' + xp_label
    if reprojection_error_diag_ratio is not None:
        xp_label = xp_label + f'_reproj_diag_{reprojection_error_diag_ratio}'
    else:
        xp_label = xp_label + f'_reproj_err_{reprojection_error}'
    export_results(args.output_dir, xp_label, query_names, poses_pred)
    out_string = aggregate_stats(f'{args.dataset}', pose_errors, angular_errors)
    print(out_string)