[
  {
    "path": ".gitignore",
    "content": "data/\ncheckpoints/\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n"
  },
  {
    "path": ".gitmodules",
    "content": "[submodule \"croco\"]\n\tpath = croco\n\turl = https://github.com/naver/croco\n"
  },
  {
    "path": "LICENSE",
    "content": "DUSt3R, Copyright (c) 2024-present Naver Corporation, is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license.\n\nA summary of the CC BY-NC-SA 4.0 license is located here:\n\thttps://creativecommons.org/licenses/by-nc-sa/4.0/\n\nThe CC BY-NC-SA 4.0 license is located here:\n\thttps://creativecommons.org/licenses/by-nc-sa/4.0/legalcode\n"
  },
  {
    "path": "NOTICE",
    "content": "DUSt3R\nCopyright 2024-present NAVER Corp.\n\nThis project contains subcomponents with separate copyright notices and license terms. \nYour use of the source code for these subcomponents is subject to the terms and conditions of the following licenses.\n\n====\n\nnaver/croco\nhttps://github.com/naver/croco/\n\nCreative Commons Attribution-NonCommercial-ShareAlike 4.0\n"
  },
  {
    "path": "README.md",
    "content": "![demo](assets/dust3r.jpg)\n\nOfficial implementation of `DUSt3R: Geometric 3D Vision Made Easy`  \n[[Project page](https://dust3r.europe.naverlabs.com/)], [[DUSt3R arxiv](https://arxiv.org/abs/2312.14132)]  \n\n> Make sure to also check our other works:  \n> [Grounding Image Matching in 3D with MASt3R](https://github.com/naver/mast3r): DUSt3R with a local feature head, metric pointmaps, and a more scalable global alignment!  \n> [Pow3R: Empowering Unconstrained 3D Reconstruction with Camera and Scene Priors](https://github.com/naver/pow3r): DUSt3R with known depth / focal length / poses.  \n> [MUSt3R: Multi-view Network for Stereo 3D Reconstruction](https://github.com/naver/must3r): Multi-view predictions (RGB SLAM/SfM) without any global alignment.    \n\n![Example of reconstruction from two images](assets/pipeline1.jpg)\n\n![High level overview of DUSt3R capabilities](assets/dust3r_archi.jpg)\n\n```bibtex\n@inproceedings{dust3r_cvpr24,\n      title={DUSt3R: Geometric 3D Vision Made Easy}, \n      author={Shuzhe Wang and Vincent Leroy and Yohann Cabon and Boris Chidlovskii and Jerome Revaud},\n      booktitle = {CVPR},\n      year = {2024}\n}\n\n@misc{dust3r_arxiv23,\n      title={DUSt3R: Geometric 3D Vision Made Easy}, \n      author={Shuzhe Wang and Vincent Leroy and Yohann Cabon and Boris Chidlovskii and Jerome Revaud},\n      year={2023},\n      eprint={2312.14132},\n      archivePrefix={arXiv},\n      primaryClass={cs.CV}\n}\n```\n\n## Table of Contents\n\n- [Table of Contents](#table-of-contents)\n- [License](#license)\n- [Get Started](#get-started)\n  - [Installation](#installation)\n  - [Checkpoints](#checkpoints)\n  - [Interactive demo](#interactive-demo)\n  - [Interactive demo with docker](#interactive-demo-with-docker)\n- [Usage](#usage)\n- [Training](#training)\n  - [Datasets](#datasets)\n  - [Demo](#demo)\n  - [Our Hyperparameters](#our-hyperparameters)\n\n## License\n\nThe code is distributed under the CC BY-NC-SA 4.0 License.\nSee 
[LICENSE](LICENSE) for more information.\n\n```python\n# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n```\n\n## Get Started\n\n### Installation\n\n1. Clone DUSt3R.\n```bash\ngit clone --recursive https://github.com/naver/dust3r\ncd dust3r\n# if you have already cloned dust3r:\n# git submodule update --init --recursive\n```\n\n2. Create the environment. Here we show an example using conda.\n```bash\nconda create -n dust3r python=3.11 cmake=3.14.0\nconda activate dust3r \nconda install pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia  # use the correct version of cuda for your system\npip install -r requirements.txt\n# Optional: you can also install additional packages to:\n# - add support for HEIC images\n# - add pyrender, used to render depthmap in some datasets preprocessing\n# - add required packages for visloc.py\npip install -r requirements_optional.txt\n```\n\n3. Optional: compile the cuda kernels for RoPE (as in CroCo v2).\n```bash\n# DUSt3R relies on RoPE positional embeddings for which you can compile some cuda kernels for faster runtime.\ncd croco/models/curope/\npython setup.py build_ext --inplace\ncd ../../../\n```\n\n### Checkpoints\n\nYou can obtain the checkpoints in two ways:\n\n1) You can use our huggingface_hub integration: the models will be downloaded automatically.\n\n2) Otherwise, we provide several pre-trained models:\n\n| Model name  | Training resolutions | Head | Encoder | Decoder |\n|-------------|----------------------|------|---------|---------|\n| [`DUSt3R_ViTLarge_BaseDecoder_224_linear.pth`](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_224_linear.pth) | 224x224 | Linear | ViT-L | ViT-B |\n| [`DUSt3R_ViTLarge_BaseDecoder_512_linear.pth`](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_linear.pth)   | 512x384, 512x336, 512x288, 512x256, 512x160 | Linear | 
ViT-L | ViT-B |\n| [`DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth`](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth) | 512x384, 512x336, 512x288, 512x256, 512x160 | DPT | ViT-L | ViT-B |\n\nYou can check the hyperparameters we used to train these models in the [section: Our Hyperparameters](#our-hyperparameters).\n\nTo download a specific model, for example `DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth`:\n```bash\nmkdir -p checkpoints/\nwget https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth -P checkpoints/\n```\n\nFor the checkpoints, make sure to agree to the licenses of all the public training datasets and base checkpoints we used, in addition to CC-BY-NC-SA 4.0. Again, see [section: Our Hyperparameters](#our-hyperparameters) for details.\n\n### Interactive demo\n\nIn this demo, you should be able to run DUSt3R on your machine to reconstruct a scene.\nFirst, select images that depict the same scene.\n\nYou can adjust the global alignment schedule and its number of iterations.\n\n> [!NOTE]\n> If you select one or two images, the global alignment procedure is skipped (mode=GlobalAlignerMode.PairViewer)\n\nHit \"Run\" and wait.\nWhen the global alignment ends, the reconstruction appears.\nUse the slider \"min_conf_thr\" to show or remove low confidence areas.\n\n```bash\npython3 demo.py --model_name DUSt3R_ViTLarge_BaseDecoder_512_dpt\n\n# Use --weights to load a checkpoint from a local file, e.g. --weights checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth\n# Use --image_size to select the correct resolution for the selected checkpoint. 
512 (default) or 224\n# Use --local_network to make it accessible on the local network, or --server_name to specify the url manually\n# Use --server_port to change the port, by default it will search for an available port starting at 7860\n# Use --device to use a different device, by default it's \"cuda\"\n```\n\n### Interactive demo with docker\n\nTo run DUSt3R using Docker, including with NVIDIA CUDA support, follow these instructions:\n\n1. **Install Docker**: If not already installed, download and install `docker` and `docker compose` from the [Docker website](https://www.docker.com/get-started).\n\n2. **Install NVIDIA Docker Toolkit**: For GPU support, install the NVIDIA Docker toolkit from the [NVIDIA website](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html).\n\n3. **Build the Docker image and run it**: `cd` into the `./docker` directory and run the following commands: \n\n```bash\ncd docker\nbash run.sh --with-cuda --model_name=\"DUSt3R_ViTLarge_BaseDecoder_512_dpt\"\n```\n\nOr, if you want to run the demo without CUDA support, run the following commands:\n\n```bash\ncd docker\nbash run.sh --model_name=\"DUSt3R_ViTLarge_BaseDecoder_512_dpt\"\n```\n\nBy default, `demo.py` is launched with the option `--local_network`.  \nVisit `http://localhost:7860/` to access the web UI (or replace `localhost` with the machine's name to access it from the network).  
\n\n`run.sh` will launch docker-compose using either the [docker-compose-cuda.yml](docker/docker-compose-cuda.yml) or [docker-compose-cpu.yml](docker/docker-compose-cpu.yml) config file, then start the demo using [entrypoint.sh](docker/files/entrypoint.sh).\n\n\n![demo](assets/demo.jpg)\n\n## Usage\n\n```python\nfrom dust3r.inference import inference\nfrom dust3r.model import AsymmetricCroCo3DStereo\nfrom dust3r.utils.image import load_images\nfrom dust3r.image_pairs import make_pairs\nfrom dust3r.cloud_opt import global_aligner, GlobalAlignerMode\n\nif __name__ == '__main__':\n    device = 'cuda'\n    batch_size = 1\n    schedule = 'cosine'\n    lr = 0.01\n    niter = 300\n\n    model_name = \"naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt\"\n    # you can put the path to a local checkpoint in model_name if needed\n    model = AsymmetricCroCo3DStereo.from_pretrained(model_name).to(device)\n    # load_images can take a list of images or a directory\n    images = load_images(['croco/assets/Chateau1.png', 'croco/assets/Chateau2.png'], size=512)\n    pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)\n    output = inference(pairs, model, device, batch_size=batch_size)\n\n    # at this stage, you have the raw dust3r predictions\n    view1, pred1 = output['view1'], output['pred1']\n    view2, pred2 = output['view2'], output['pred2']\n    # here, view1, pred1, view2, pred2 are dicts of lists of len(2)\n    #  -> because we symmetrize we have (im1, im2) and (im2, im1) pairs\n    # in each view you have:\n    # an integer image identifier: view1['idx'] and view2['idx']\n    # the img: view1['img'] and view2['img']\n    # the image shape: view1['true_shape'] and view2['true_shape']\n    # an instance string output by the dataloader: view1['instance'] and view2['instance']\n    # pred1 and pred2 contain the confidence values: pred1['conf'] and pred2['conf']\n    # pred1 contains 3D points for view1['img'] in view1['img'] space: 
pred1['pts3d']\n    # pred2 contains 3D points for view2['img'] in view1['img'] space: pred2['pts3d_in_other_view']\n\n    # next we'll use the global_aligner to align the predictions\n    # depending on your task, you may be fine with the raw output and not need it\n    # with only two input images, you could use GlobalAlignerMode.PairViewer: it would just convert the output\n    # if using GlobalAlignerMode.PairViewer, no need to run compute_global_alignment\n    scene = global_aligner(output, device=device, mode=GlobalAlignerMode.PointCloudOptimizer)\n    loss = scene.compute_global_alignment(init=\"mst\", niter=niter, schedule=schedule, lr=lr)\n\n    # retrieve useful values from scene:\n    imgs = scene.imgs\n    focals = scene.get_focals()\n    poses = scene.get_im_poses()\n    pts3d = scene.get_pts3d()\n    confidence_masks = scene.get_masks()\n\n    # visualize reconstruction\n    scene.show()\n\n    # find 2D-2D matches between the two images\n    from dust3r.utils.geometry import find_reciprocal_matches, xy_grid\n    pts2d_list, pts3d_list = [], []\n    for i in range(2):\n        conf_i = confidence_masks[i].cpu().numpy()\n        pts2d_list.append(xy_grid(*imgs[i].shape[:2][::-1])[conf_i])  # imgs[i].shape[:2] = (H, W)\n        pts3d_list.append(pts3d[i].detach().cpu().numpy()[conf_i])\n    reciprocal_in_P2, nn2_in_P1, num_matches = find_reciprocal_matches(*pts3d_list)\n    print(f'found {num_matches} matches')\n    matches_im1 = pts2d_list[1][reciprocal_in_P2]\n    matches_im0 = pts2d_list[0][nn2_in_P1][reciprocal_in_P2]\n\n    # visualize a few matches\n    import numpy as np\n    from matplotlib import pyplot as pl\n    n_viz = 10\n    match_idx_to_viz = np.round(np.linspace(0, num_matches-1, n_viz)).astype(int)\n    viz_matches_im0, viz_matches_im1 = matches_im0[match_idx_to_viz], matches_im1[match_idx_to_viz]\n\n    H0, W0, H1, W1 = *imgs[0].shape[:2], *imgs[1].shape[:2]\n    img0 = np.pad(imgs[0], ((0, max(H1 - H0, 0)), (0, 0), (0, 0)), 
'constant', constant_values=0)\n    img1 = np.pad(imgs[1], ((0, max(H0 - H1, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)\n    img = np.concatenate((img0, img1), axis=1)\n    pl.figure()\n    pl.imshow(img)\n    cmap = pl.get_cmap('jet')\n    for i in range(n_viz):\n        (x0, y0), (x1, y1) = viz_matches_im0[i].T, viz_matches_im1[i].T\n        pl.plot([x0, x1 + W0], [y0, y1], '-+', color=cmap(i / (n_viz - 1)), scalex=False, scaley=False)\n    pl.show(block=True)\n\n```\n![matching example on croco pair](assets/matching.jpg)\n\n## Training\n\nIn this section, we present a short demonstration to get started with training DUSt3R.\n\n### Datasets\nAt this moment, we have added the following training datasets:\n  - [CO3Dv2](https://github.com/facebookresearch/co3d) - [Creative Commons Attribution-NonCommercial 4.0 International](https://github.com/facebookresearch/co3d/blob/main/LICENSE)\n  - [ARKitScenes](https://github.com/apple/ARKitScenes) - [Creative Commons Attribution-NonCommercial-ShareAlike 4.0](https://github.com/apple/ARKitScenes/tree/main?tab=readme-ov-file#license)\n  - [ScanNet++](https://kaldir.vc.in.tum.de/scannetpp/) - [non-commercial research and educational purposes](https://kaldir.vc.in.tum.de/scannetpp/static/scannetpp-terms-of-use.pdf)\n  - [BlendedMVS](https://github.com/YoYo000/BlendedMVS) - [Creative Commons Attribution 4.0 International License](https://creativecommons.org/licenses/by/4.0/)\n  - [WayMo Open dataset](https://github.com/waymo-research/waymo-open-dataset) - [Non-Commercial Use](https://waymo.com/open/terms/)\n  - [Habitat-Sim](https://github.com/facebookresearch/habitat-sim/blob/main/DATASETS.md)\n  - [MegaDepth](https://www.cs.cornell.edu/projects/megadepth/)\n  - [StaticThings3D](https://github.com/lmb-freiburg/robustmvd/blob/master/rmvd/data/README.md#staticthings3d)\n  - [WildRGB-D](https://github.com/wildrgbd/wildrgbd/)\n\nFor each dataset, we provide a preprocessing script in the `datasets_preprocess` directory and 
an archive containing the list of pairs when needed.\nYou have to download the datasets yourself from their official sources, agree to their license, download our list of pairs, and run the preprocessing script.\n\nLinks:  \n  \n[ARKitScenes pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/arkitscenes_pairs.zip)  \n[ScanNet++ v1 pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/scannetpp_pairs.zip)  \n[ScanNet++ v2 pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/scannetpp_v2_pairs.zip)  \n[BlendedMVS pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/blendedmvs_pairs.npy)  \n[WayMo Open dataset pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/waymo_pairs.npz)  \n[Habitat metadata](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/habitat_5views_v1_512x512_metadata.tar.gz)  \n[MegaDepth pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/megadepth_pairs.npz)  \n[StaticThings3D pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/staticthings_pairs.npy)  \n\n> [!NOTE]\n> They are not strictly equivalent to what was used to train DUSt3R, but they should be close enough.\n\n### Demo\nFor this training demo, we're going to download and prepare a subset of [CO3Dv2](https://github.com/facebookresearch/co3d) - [Creative Commons Attribution-NonCommercial 4.0 International](https://github.com/facebookresearch/co3d/blob/main/LICENSE) and launch the training code on it.\nThe demo model will be trained for a few epochs on a very small dataset.\nIt will not be very good.\n\n```bash\n# download and prepare the co3d subset\nmkdir -p data/co3d_subset\ncd data/co3d_subset\ngit clone https://github.com/facebookresearch/co3d\ncd co3d\npython3 ./co3d/download_dataset.py --download_folder ../ --single_sequence_subset\nrm ../*.zip\ncd ../../..\n\npython3 datasets_preprocess/preprocess_co3d.py --co3d_dir data/co3d_subset --output_dir data/co3d_subset_processed  
--single_sequence_subset\n\n# download the pretrained croco v2 checkpoint\nmkdir -p checkpoints/\nwget https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTLarge_BaseDecoder.pth -P checkpoints/\n\n# the training of dust3r is done in 3 steps.\n# for this example we'll do fewer epochs, for the actual hyperparameters we used in the paper, see the next section: \"Our Hyperparameters\"\n# step 1 - train dust3r for 224 resolution\ntorchrun --nproc_per_node=4 train.py \\\n    --train_dataset \"1000 @ Co3d(split='train', ROOT='data/co3d_subset_processed', aug_crop=16, mask_bg='rand', resolution=224, transform=ColorJitter)\" \\\n    --test_dataset \"100 @ Co3d(split='test', ROOT='data/co3d_subset_processed', resolution=224, seed=777)\" \\\n    --model \"AsymmetricCroCo3DStereo(pos_embed='RoPE100', img_size=(224, 224), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)\" \\\n    --train_criterion \"ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)\" \\\n    --test_criterion \"Regr3D_ScaleShiftInv(L21, gt_scale=True)\" \\\n    --pretrained \"checkpoints/CroCo_V2_ViTLarge_BaseDecoder.pth\" \\\n    --lr 0.0001 --min_lr 1e-06 --warmup_epochs 1 --epochs 10 --batch_size 16 --accum_iter 1 \\\n    --save_freq 1 --keep_freq 5 --eval_freq 1 \\\n    --output_dir \"checkpoints/dust3r_demo_224\"\t  \n\n# step 2 - train dust3r for 512 resolution\ntorchrun --nproc_per_node=4 train.py \\\n    --train_dataset \"1000 @ Co3d(split='train', ROOT='data/co3d_subset_processed', aug_crop=16, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter)\" \\\n    --test_dataset \"100 @ Co3d(split='test', ROOT='data/co3d_subset_processed', resolution=(512,384), seed=777)\" \\\n    --model \"AsymmetricCroCo3DStereo(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', 
img_size=(512, 512), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)\" \\\n    --train_criterion \"ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)\" \\\n    --test_criterion \"Regr3D_ScaleShiftInv(L21, gt_scale=True)\" \\\n    --pretrained \"checkpoints/dust3r_demo_224/checkpoint-best.pth\" \\\n    --lr 0.0001 --min_lr 1e-06 --warmup_epochs 1 --epochs 10 --batch_size 4 --accum_iter 4 \\\n    --save_freq 1 --keep_freq 5 --eval_freq 1 \\\n    --output_dir \"checkpoints/dust3r_demo_512\"\n\n# step 3 - train dust3r for 512 resolution with dpt\ntorchrun --nproc_per_node=4 train.py \\\n    --train_dataset \"1000 @ Co3d(split='train', ROOT='data/co3d_subset_processed', aug_crop=16, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter)\" \\\n    --test_dataset \"100 @ Co3d(split='test', ROOT='data/co3d_subset_processed', resolution=(512,384), seed=777)\" \\\n    --model \"AsymmetricCroCo3DStereo(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='dpt', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)\" \\\n    --train_criterion \"ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)\" \\\n    --test_criterion \"Regr3D_ScaleShiftInv(L21, gt_scale=True)\" \\\n    --pretrained \"checkpoints/dust3r_demo_512/checkpoint-best.pth\" \\\n    --lr 0.0001 --min_lr 1e-06 --warmup_epochs 1 --epochs 10 --batch_size 2 --accum_iter 8 \\\n    --save_freq 1 --keep_freq 5 --eval_freq 1 --disable_cudnn_benchmark \\\n    --output_dir \"checkpoints/dust3r_demo_512dpt\"\n\n```\n\n### Our Hyperparameters\n\nHere are the commands we used for training the models:\n\n```bash\n# NOTE: ROOT path omitted for datasets\n# 
224 linear\ntorchrun --nproc_per_node 8 train.py \\\n    --train_dataset=\" + 100_000 @ Habitat(1_000_000, split='train', aug_crop=16, resolution=224, transform=ColorJitter) + 100_000 @ BlendedMVS(split='train', aug_crop=16, resolution=224, transform=ColorJitter) + 100_000 @ MegaDepth(split='train', aug_crop=16, resolution=224, transform=ColorJitter) + 100_000 @ ARKitScenes(aug_crop=256, resolution=224, transform=ColorJitter) + 100_000 @ Co3d(split='train', aug_crop=16, mask_bg='rand', resolution=224, transform=ColorJitter) + 100_000 @ StaticThings3D(aug_crop=256, mask_bg='rand', resolution=224, transform=ColorJitter) + 100_000 @ ScanNetpp(split='train', aug_crop=256, resolution=224, transform=ColorJitter) + 100_000 @ InternalUnreleasedDataset(aug_crop=128, resolution=224, transform=ColorJitter) \" \\\n    --test_dataset=\" Habitat(1_000, split='val', resolution=224, seed=777) + 1_000 @ BlendedMVS(split='val', resolution=224, seed=777) + 1_000 @ MegaDepth(split='val', resolution=224, seed=777) + 1_000 @ Co3d(split='test', mask_bg='rand', resolution=224, seed=777) \" \\\n    --train_criterion=\"ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)\" \\\n    --test_criterion=\"Regr3D_ScaleShiftInv(L21, gt_scale=True)\" \\\n    --model=\"AsymmetricCroCo3DStereo(pos_embed='RoPE100', img_size=(224, 224), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)\" \\\n    --pretrained=\"checkpoints/CroCo_V2_ViTLarge_BaseDecoder.pth\" \\\n    --lr=0.0001 --min_lr=1e-06 --warmup_epochs=10 --epochs=100 --batch_size=16 --accum_iter=1 \\\n    --save_freq=5 --keep_freq=10 --eval_freq=1 \\\n    --output_dir=\"checkpoints/dust3r_224\"\n\n# 512 linear\ntorchrun --nproc_per_node 8 train.py \\\n    --train_dataset=\" + 10_000 @ Habitat(1_000_000, split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 
160)], transform=ColorJitter) + 10_000 @ BlendedMVS(split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ MegaDepth(split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ ARKitScenes(aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ Co3d(split='train', aug_crop=16, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ StaticThings3D(aug_crop=256, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ ScanNetpp(split='train', aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ InternalUnreleasedDataset(aug_crop=128, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) \" \\\n    --test_dataset=\" Habitat(1_000, split='val', resolution=(512,384), seed=777) + 1_000 @ BlendedMVS(split='val', resolution=(512,384), seed=777) + 1_000 @ MegaDepth(split='val', resolution=(512,336), seed=777) + 1_000 @ Co3d(split='test', resolution=(512,384), seed=777) \" \\\n    --train_criterion=\"ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)\" \\\n    --test_criterion=\"Regr3D_ScaleShiftInv(L21, gt_scale=True)\" \\\n    --model=\"AsymmetricCroCo3DStereo(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)\" \\\n    --pretrained=\"checkpoints/dust3r_224/checkpoint-best.pth\" \\\n    --lr=0.0001 --min_lr=1e-06 --warmup_epochs=20 --epochs=100 --batch_size=4 --accum_iter=2 \\\n    
--save_freq=10 --keep_freq=10 --eval_freq=1 --print_freq=10 \\\n    --output_dir=\"checkpoints/dust3r_512\"\n\n# 512 dpt\ntorchrun --nproc_per_node 8 train.py \\\n    --train_dataset=\" + 10_000 @ Habitat(1_000_000, split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ BlendedMVS(split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ MegaDepth(split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ ARKitScenes(aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ Co3d(split='train', aug_crop=16, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ StaticThings3D(aug_crop=256, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ ScanNetpp(split='train', aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ InternalUnreleasedDataset(aug_crop=128, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) \" \\\n    --test_dataset=\" Habitat(1_000, split='val', resolution=(512,384), seed=777) + 1_000 @ BlendedMVS(split='val', resolution=(512,384), seed=777) + 1_000 @ MegaDepth(split='val', resolution=(512,336), seed=777) + 1_000 @ Co3d(split='test', resolution=(512,384), seed=777) \" \\\n    --train_criterion=\"ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)\" \\\n    --test_criterion=\"Regr3D_ScaleShiftInv(L21, gt_scale=True)\" \\\n    --model=\"AsymmetricCroCo3DStereo(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='dpt', output_mode='pts3d', depth_mode=('exp', -inf, inf), 
conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)\" \\\n    --pretrained=\"checkpoints/dust3r_512/checkpoint-best.pth\" \\\n    --lr=0.0001 --min_lr=1e-06 --warmup_epochs=15 --epochs=90 --batch_size=4 --accum_iter=2 \\\n    --save_freq=5 --keep_freq=10 --eval_freq=1 --print_freq=10 --disable_cudnn_benchmark \\\n    --output_dir=\"checkpoints/dust3r_512dpt\"\n\n```\n"
  },
  {
    "path": "datasets_preprocess/habitat/README.md",
    "content": "## Steps to reproduce synthetic training data using the Habitat-Sim simulator\n\n### Create a conda environment\n```bash\nconda create -n habitat python=3.8 habitat-sim=0.2.1 headless=2.0 -c aihabitat -c conda-forge\nconda active habitat\nconda install pytorch -c pytorch\npip install opencv-python tqdm\n```\n\nor (if you get the error `For headless systems, compile with --headless for EGL support`)\n```\ngit clone --branch stable https://github.com/facebookresearch/habitat-sim.git\ncd habitat-sim\n\nconda create -n habitat python=3.9 cmake=3.14.0\nconda activate habitat\npip install . -v\nconda install pytorch -c pytorch\npip install opencv-python tqdm\n```\n\n### Download Habitat-Sim scenes\nDownload Habitat-Sim scenes:\n- Download links can be found here: https://github.com/facebookresearch/habitat-sim/blob/main/DATASETS.md\n- We used scenes from the HM3D, habitat-test-scenes, ReplicaCad and ScanNet datasets.\n- Please put the scenes in a directory `$SCENES_DIR` following the structure below:\n(Note: the habitat-sim dataset installer may install an incompatible version for ReplicaCAD backed lighting.\nThe correct scene dataset can be dowloaded from Huggingface: `git clone git@hf.co:datasets/ai-habitat/ReplicaCAD_baked_lighting`).\n```\n$SCENES_DIR/\n├──hm3d/\n├──gibson/\n├──habitat-test-scenes/\n├──ReplicaCAD_baked_lighting/\n└──scannet/\n```\n\n### Download renderings metadata \n\nDownload metadata corresponding to each scene and extract them into a directory `$METADATA_DIR`\n```bash\nwget https://download.europe.naverlabs.com/ComputerVision/DUSt3R/habitat_5views_v1_512x512_metadata.tar.gz\ntar -xvzf habitat_5views_v1_512x512_metadata.tar.gz\n```\n\n### Render the scenes\n\nRender the scenes in an output directory `$OUTPUT_DIR`\n```bash\nexport METADATA_DIR=\"/path/to/habitat/5views_v1_512x512_metadata\"\nexport SCENES_DIR=\"/path/to/habitat/data/scene_datasets/\"\nexport OUTPUT_DIR=\"data/habitat_processed\"\ncd 
datasets_preprocess/habitat/\nexport PYTHONPATH=$(pwd)\n# Print commandlines to generate images corresponding to each scene\npython preprocess_habitat.py --scenes_dir=$SCENES_DIR --metadata_dir=$METADATA_DIR --output_dir=$OUTPUT_DIR\n# Launch these commandlines in parallel e.g. using GNU-Parallel as follows:\npython preprocess_habitat.py --scenes_dir=$SCENES_DIR --metadata_dir=$METADATA_DIR --output_dir=$OUTPUT_DIR | parallel -j 16\n```\n\n### Make a list of scenes\n\n```bash\npython find_scenes.py --root $OUTPUT_DIR\n```"
  },
  {
    "path": "datasets_preprocess/habitat/find_scenes.py",
    "content": "#!/usr/bin/env python3\n# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# Script to export the list of scenes for habitat (after having rendered them).\n# Usage:\n# python3 datasets_preprocess/preprocess_co3d.py --root data/habitat_processed\n# --------------------------------------------------------\nimport numpy as np\nimport os\nfrom collections import defaultdict\nfrom tqdm import tqdm\n\n\ndef find_all_scenes(habitat_root, n_scenes=[100000]):\n    np.random.seed(777)\n\n    try:\n        fpath = os.path.join(habitat_root, f'Habitat_all_scenes.txt')\n        list_subscenes = open(fpath).read().splitlines()\n\n    except IOError:\n        if input('parsing sub-folders to find scenes? (y/n) ') != 'y':\n            return\n        list_subscenes = []\n        for root, dirs, files in tqdm(os.walk(habitat_root)):\n            for f in files:\n                if not f.endswith('_1_depth.exr'):\n                    continue\n                scene = os.path.join(os.path.relpath(root, habitat_root), f.replace('_1_depth.exr', ''))\n                if hash(scene) % 1000 == 0:\n                    print('... 
adding', scene)\n                list_subscenes.append(scene)\n\n        with open(fpath, 'w') as f:\n            f.write('\\n'.join(list_subscenes))\n        print(f'>> wrote {fpath}')\n\n    print(f'Loaded {len(list_subscenes)} sub-scenes')\n\n    # separate scenes\n    list_scenes = defaultdict(list)\n    for scene in list_subscenes:\n        scene, id = os.path.split(scene)\n        list_scenes[scene].append(id)\n\n    list_scenes = list(list_scenes.items())\n    print(f'from {len(list_scenes)} scenes in total')\n\n    np.random.shuffle(list_scenes)\n    train_scenes = list_scenes[len(list_scenes) // 10:]\n    val_scenes = list_scenes[:len(list_scenes) // 10]\n\n    def write_scene_list(scenes, n, fpath):\n        sub_scenes = [os.path.join(scene, id) for scene, ids in scenes for id in ids]\n        np.random.shuffle(sub_scenes)\n\n        if len(sub_scenes) < n:\n            return\n\n        with open(fpath, 'w') as f:\n            f.write('\\n'.join(sub_scenes[:n]))\n        print(f'>> wrote {fpath}')\n\n    for n in n_scenes:\n        write_scene_list(train_scenes, n, os.path.join(habitat_root, f'Habitat_{n}_scenes_train.txt'))\n        write_scene_list(val_scenes, n // 10, os.path.join(habitat_root, f'Habitat_{n//10}_scenes_val.txt'))\n\n\nif __name__ == \"__main__\":\n    import argparse\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--root\", required=True)\n    parser.add_argument(\"--n_scenes\", nargs='+', default=[1_000, 10_000, 100_000, 1_000_000], type=int)\n\n    args = parser.parse_args()\n    find_all_scenes(args.root, args.n_scenes)\n"
  },
  {
    "path": "datasets_preprocess/habitat/habitat_renderer/__init__.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n"
  },
  {
    "path": "datasets_preprocess/habitat/habitat_renderer/habitat_sim_envmaps_renderer.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# Render environment maps from 3D meshes using the Habitat Sim simulator.\n# --------------------------------------------------------\nimport numpy as np\nimport habitat_sim\nimport math\nfrom habitat_renderer import projections\n\n# OpenCV to habitat camera convention transformation\nR_OPENCV2HABITAT = np.stack((habitat_sim.geo.RIGHT, -habitat_sim.geo.UP, habitat_sim.geo.FRONT), axis=0)\n\nCUBEMAP_FACE_LABELS = [\"left\", \"front\", \"right\", \"back\", \"up\", \"down\"]\n# Rotation vectors expressed in the Habitat coordinate system\nCUBEMAP_FACE_ORIENTATIONS_ROTVEC = [\n    [0, math.pi / 2, 0],  # Left\n    [0, 0, 0],  # Front\n    [0, -math.pi / 2, 0],  # Right\n    [0, math.pi, 0],  # Back\n    [math.pi / 2, 0, 0],  # Up\n    [-math.pi / 2, 0, 0],  # Down\n]\n\nclass NoNaviguableSpaceError(RuntimeError):\n    def __init__(self, *args):\n        super().__init__(*args)\n\nclass HabitatEnvironmentMapRenderer:\n    def __init__(self,\n                 scene,\n                 navmesh,\n                 scene_dataset_config_file,\n                 render_equirectangular=False,\n                 equirectangular_resolution=(512, 1024),\n                 render_cubemap=False,\n                 cubemap_resolution=(512, 512),\n                 render_depth=False,\n                 gpu_id=0):\n        self.scene = scene\n        self.navmesh = navmesh\n        self.scene_dataset_config_file = scene_dataset_config_file\n        self.gpu_id = gpu_id\n\n        self.render_equirectangular = render_equirectangular\n        self.equirectangular_resolution = equirectangular_resolution\n        self.equirectangular_projection = projections.EquirectangularProjection(*equirectangular_resolution)\n        # 3D unit ray associated to 
each pixel of the equirectangular map\n        equirectangular_rays = projections.get_projection_rays(self.equirectangular_projection)\n        # Not needed, but just in case.\n        equirectangular_rays /= np.linalg.norm(equirectangular_rays, axis=-1, keepdims=True)\n        # Depth map created by Habitat are produced by warping a cubemap,\n        # so the values do not correspond to distance to the center and need some scaling.\n        self.equirectangular_depth_scale_factors = 1.0 / np.max(np.abs(equirectangular_rays), axis=-1)\n\n        self.render_cubemap = render_cubemap\n        self.cubemap_resolution = cubemap_resolution\n\n        self.render_depth = render_depth\n\n        self.seed = None\n        self._lazy_initialization()\n\n    def _lazy_initialization(self):\n        # Lazy random seeding and instantiation of the simulator to deal with multiprocessing properly\n        if self.seed == None:\n            # Re-seed numpy generator\n            np.random.seed()\n            self.seed = np.random.randint(2**32-1)\n            sim_cfg = habitat_sim.SimulatorConfiguration()\n            sim_cfg.scene_id = self.scene\n            if self.scene_dataset_config_file is not None and self.scene_dataset_config_file != \"\":\n                sim_cfg.scene_dataset_config_file = self.scene_dataset_config_file\n            sim_cfg.random_seed = self.seed\n            sim_cfg.load_semantic_mesh = False\n            sim_cfg.gpu_device_id = self.gpu_id\n\n            sensor_specifications = []\n\n            # Add cubemaps\n            if self.render_cubemap:\n                for face_id, orientation in enumerate(CUBEMAP_FACE_ORIENTATIONS_ROTVEC):\n                    rgb_sensor_spec = habitat_sim.CameraSensorSpec()\n                    rgb_sensor_spec.uuid = f\"color_cubemap_{CUBEMAP_FACE_LABELS[face_id]}\"\n                    rgb_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR\n                    rgb_sensor_spec.resolution = self.cubemap_resolution\n   
                 rgb_sensor_spec.hfov = 90\n                    rgb_sensor_spec.position = [0.0, 0.0, 0.0]\n                    rgb_sensor_spec.orientation = orientation\n                    sensor_specifications.append(rgb_sensor_spec)\n\n                    if self.render_depth:\n                        depth_sensor_spec = habitat_sim.CameraSensorSpec()\n                        depth_sensor_spec.uuid = f\"depth_cubemap_{CUBEMAP_FACE_LABELS[face_id]}\"\n                        depth_sensor_spec.sensor_type = habitat_sim.SensorType.DEPTH\n                        depth_sensor_spec.resolution = self.cubemap_resolution\n                        depth_sensor_spec.hfov = 90\n                        depth_sensor_spec.position = [0.0, 0.0, 0.0]\n                        depth_sensor_spec.orientation = orientation\n                        sensor_specifications.append(depth_sensor_spec)\n\n            # Add equirectangular map\n            if self.render_equirectangular:\n                rgb_sensor_spec = habitat_sim.bindings.EquirectangularSensorSpec()\n                rgb_sensor_spec.uuid = \"color_equirectangular\"\n                rgb_sensor_spec.resolution = self.equirectangular_resolution\n                rgb_sensor_spec.position = [0.0, 0.0, 0.0]\n                sensor_specifications.append(rgb_sensor_spec)\n\n                if self.render_depth:\n                    depth_sensor_spec = habitat_sim.bindings.EquirectangularSensorSpec()\n                    depth_sensor_spec.uuid = \"depth_equirectangular\"\n                    depth_sensor_spec.sensor_type = habitat_sim.SensorType.DEPTH\n                    depth_sensor_spec.resolution = self.equirectangular_resolution\n                    depth_sensor_spec.position = [0.0, 0.0, 0.0]\n                    depth_sensor_spec.orientation\n                    sensor_specifications.append(depth_sensor_spec)\n\n            agent_cfg = habitat_sim.agent.AgentConfiguration(sensor_specifications=sensor_specifications)\n\n       
     cfg = habitat_sim.Configuration(sim_cfg, [agent_cfg])\n            self.sim = habitat_sim.Simulator(cfg)\n            if self.navmesh is not None and self.navmesh != \"\":\n                # Use pre-computed navmesh (the one generated automatically does some weird stuffs like going on top of the roof)\n                # See https://youtu.be/kunFMRJAu2U?t=1522 regarding navmeshes\n                self.sim.pathfinder.load_nav_mesh(self.navmesh)\n\n            # Check that the navmesh is not empty\n            if not self.sim.pathfinder.is_loaded:\n                # Try to compute a navmesh\n                navmesh_settings = habitat_sim.NavMeshSettings()\n                navmesh_settings.set_defaults()\n                self.sim.recompute_navmesh(self.sim.pathfinder, navmesh_settings, True)\n\n            # Check that the navmesh is not empty\n            if not self.sim.pathfinder.is_loaded:\n                raise NoNaviguableSpaceError(f\"No naviguable location (scene: {self.scene} -- navmesh: {self.navmesh})\")\n\n            self.agent = self.sim.initialize_agent(agent_id=0)\n\n    def close(self):\n        if hasattr(self, 'sim'):\n            self.sim.close()\n\n    def __del__(self):\n        self.close()\n\n    def render_viewpoint(self, viewpoint_position):\n        agent_state = habitat_sim.AgentState()\n        agent_state.position = viewpoint_position\n        # agent_state.rotation = viewpoint_orientation\n        self.agent.set_state(agent_state)\n        viewpoint_observations = self.sim.get_sensor_observations(agent_ids=0)\n\n        try:\n            # Depth map values have been obtained using cubemap rendering internally,\n            # so they do not really correspond to distance to the viewpoint in practice\n            # and they need some scaling\n            viewpoint_observations[\"depth_equirectangular\"] *= self.equirectangular_depth_scale_factors\n        except KeyError:\n            pass\n\n        data = 
dict(observations=viewpoint_observations, position=viewpoint_position)\n        return data\n\n    def up_direction(self):\n        return np.asarray(habitat_sim.geo.UP).tolist()\n    \n    def R_cam_to_world(self):\n        return R_OPENCV2HABITAT.tolist()\n"
  },
  {
    "path": "datasets_preprocess/habitat/habitat_renderer/multiview_crop_generator.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# Generate pairs of crops from a dataset of environment maps.\n# --------------------------------------------------------\nimport os\nimport numpy as np\nos.environ[\"OPENCV_IO_ENABLE_OPENEXR\"] = \"1\"  # noqa\nimport cv2\nimport collections\nfrom habitat_renderer import projections, projections_conversions\nfrom habitat_renderer.habitat_sim_envmaps_renderer import HabitatEnvironmentMapRenderer\n\nViewpointData = collections.namedtuple(\"ViewpointData\", [\"colormap\", \"distancemap\", \"pointmap\", \"position\"])\n\nclass HabitatMultiviewCrops:\n    def __init__(self,\n                 scene,\n                 navmesh,\n                 scene_dataset_config_file,\n                 equirectangular_resolution=(400, 800),\n                 crop_resolution=(240, 320),\n                 pixel_jittering_iterations=5,\n                 jittering_noise_level=1.0):\n        self.crop_resolution = crop_resolution\n\n        self.pixel_jittering_iterations = pixel_jittering_iterations\n        self.jittering_noise_level = jittering_noise_level\n\n        # Instantiate the low resolution habitat sim renderer\n        self.lowres_envmap_renderer = HabitatEnvironmentMapRenderer(scene=scene,\n                                                                    navmesh=navmesh,\n                                                                    scene_dataset_config_file=scene_dataset_config_file,\n                                                                    equirectangular_resolution=equirectangular_resolution,\n                                                                    render_depth=True,\n                                                                    render_equirectangular=True)\n        self.R_cam_to_world = 
np.asarray(self.lowres_envmap_renderer.R_cam_to_world())\n        self.up_direction = np.asarray(self.lowres_envmap_renderer.up_direction())\n\n        # Projection applied by each environment map\n        self.envmap_height, self.envmap_width = self.lowres_envmap_renderer.equirectangular_resolution\n        base_projection = projections.EquirectangularProjection(self.envmap_height, self.envmap_width)\n        self.envmap_projection = projections.RotatedProjection(base_projection, self.R_cam_to_world.T)\n        # 3D Rays map associated to each envmap\n        self.envmap_rays = projections.get_projection_rays(self.envmap_projection)\n\n    def compute_pointmap(self, distancemap, position):\n        # Point cloud associated to each ray\n        return self.envmap_rays * distancemap[:, :, None] + position\n\n    def render_viewpoint_data(self, position):\n        data = self.lowres_envmap_renderer.render_viewpoint(np.asarray(position))\n        colormap = data['observations']['color_equirectangular'][..., :3]  # Ignore the alpha channel\n        distancemap = data['observations']['depth_equirectangular']\n        pointmap = self.compute_pointmap(distancemap, position)\n        return ViewpointData(colormap=colormap, distancemap=distancemap, pointmap=pointmap, position=position)\n\n    def extract_cropped_camera(self, projection, color_image, distancemap, pointmap, voxelmap=None):\n        remapper = projections_conversions.RemapProjection(input_projection=self.envmap_projection, output_projection=projection,\n                                                           pixel_jittering_iterations=self.pixel_jittering_iterations, jittering_noise_level=self.jittering_noise_level)\n        cropped_color_image = remapper.convert(\n            color_image, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_WRAP, single_map=False)\n        cropped_distancemap = remapper.convert(\n            distancemap, interpolation=cv2.INTER_NEAREST, borderMode=cv2.BORDER_WRAP, 
single_map=True)\n        cropped_pointmap = remapper.convert(pointmap, interpolation=cv2.INTER_NEAREST,\n                                            borderMode=cv2.BORDER_WRAP, single_map=True)\n        cropped_voxelmap = (None if voxelmap is None else\n                            remapper.convert(voxelmap, interpolation=cv2.INTER_NEAREST, borderMode=cv2.BORDER_WRAP, single_map=True))\n        # Convert the distance map into a depth map\n        cropped_depthmap = np.asarray(\n            cropped_distancemap / np.linalg.norm(remapper.output_rays, axis=-1), dtype=cropped_distancemap.dtype)\n\n        return cropped_color_image, cropped_depthmap, cropped_pointmap, cropped_voxelmap\n\ndef perspective_projection_to_dict(persp_projection, position):\n    \"\"\"\n    Serialization-like function.\"\"\"\n    camera_params = dict(camera_intrinsics=projections.colmap_to_opencv_intrinsics(persp_projection.base_projection.K).tolist(),\n                         size=(persp_projection.base_projection.width, persp_projection.base_projection.height),\n                         R_cam2world=persp_projection.R_to_base_projection.T.tolist(),\n                         t_cam2world=position)\n    return camera_params\n\n\ndef dict_to_perspective_projection(camera_params):\n    K = projections.opencv_to_colmap_intrinsics(np.asarray(camera_params[\"camera_intrinsics\"]))\n    size = camera_params[\"size\"]\n    R_cam2world = np.asarray(camera_params[\"R_cam2world\"])\n    projection = projections.PerspectiveProjection(K, height=size[1], width=size[0])\n    projection = projections.RotatedProjection(projection, R_to_base_projection=R_cam2world.T)\n    position = camera_params[\"t_cam2world\"]\n    return projection, position"
  },
  {
    "path": "datasets_preprocess/habitat/habitat_renderer/projections.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# Various 3D/2D projection utils, useful to sample virtual cameras.\n# --------------------------------------------------------\nimport numpy as np\n\nclass EquirectangularProjection:\n    \"\"\"\n    Convention for the central pixel of the equirectangular map similar to OpenCV perspective model:\n        +X from left to right\n        +Y from top to bottom\n        +Z going outside the camera\n    EXCEPT that the top left corner of the image is assumed to have (0,0) coordinates (OpenCV assumes (-0.5,-0.5))\n    \"\"\"\n\n    def __init__(self, height, width):\n        self.height = height\n        self.width = width\n        self.u_scaling = (2 * np.pi) / self.width\n        self.v_scaling = np.pi / self.height\n\n    def unproject(self, u, v):\n        \"\"\"\n        Args:\n            u, v: 2D coordinates\n        Returns:\n            unnormalized 3D rays.\n        \"\"\"\n        longitude = self.u_scaling * u - np.pi\n        minus_latitude = self.v_scaling * v - np.pi/2\n\n        cos_latitude = np.cos(minus_latitude)\n        x, z = np.sin(longitude) * cos_latitude, np.cos(longitude) * cos_latitude\n        y = np.sin(minus_latitude)\n\n        rays = np.stack([x, y, z], axis=-1)\n        return rays\n\n    def project(self, rays):\n        \"\"\"\n        Args:\n            rays: Bx3 array of 3D rays.\n        Returns:\n            u, v: tuple of 2D coordinates.\n        \"\"\"\n        rays = rays / np.linalg.norm(rays, axis=-1, keepdims=True)\n        x, y, z = [rays[..., i] for i in range(3)]\n\n        longitude = np.arctan2(x, z)\n        minus_latitude = np.arcsin(y)\n\n        u = (longitude + np.pi) * (1.0 / self.u_scaling)\n        v = (minus_latitude + np.pi/2) * (1.0 / self.v_scaling)\n        return u, v\n\n\nclass 
PerspectiveProjection:\n    \"\"\"\n    OpenCV convention:\n    World space:\n        +X from left to right\n        +Y from top to bottom\n        +Z going outside the camera\n    Pixel space:\n        +u from left to right\n        +v from top to bottom\n    EXCEPT that the top left corner of the image is assumed to have (0,0) coordinates (OpenCV assumes (-0.5,-0.5)).\n    \"\"\"\n\n    def __init__(self, K, height, width):\n        self.height = height\n        self.width = width\n        self.K = K\n        self.Kinv = np.linalg.inv(K)\n\n    def project(self, rays):\n        uv_homogeneous = np.einsum(\"ik, ...k -> ...i\", self.K, rays)\n        uv = uv_homogeneous[..., :2] / uv_homogeneous[..., 2, None]\n        return uv[..., 0], uv[..., 1]\n\n    def unproject(self, u, v):\n        uv_homogeneous = np.stack((u, v, np.ones_like(u)), axis=-1)\n        rays = np.einsum(\"ik, ...k -> ...i\", self.Kinv, uv_homogeneous)\n        return rays\n\n\nclass RotatedProjection:\n    def __init__(self, base_projection, R_to_base_projection):\n        self.base_projection = base_projection\n        self.R_to_base_projection = R_to_base_projection\n\n    @property\n    def width(self):\n        return self.base_projection.width\n\n    @property\n    def height(self):\n        return self.base_projection.height\n\n    def project(self, rays):\n        if self.R_to_base_projection is not None:\n            rays = np.einsum(\"ik, ...k -> ...i\", self.R_to_base_projection, rays)\n        return self.base_projection.project(rays)\n\n    def unproject(self, u, v):\n        rays = self.base_projection.unproject(u, v)\n        if self.R_to_base_projection is not None:\n            rays = np.einsum(\"ik, ...k -> ...i\", self.R_to_base_projection.T, rays)\n        return rays\n\ndef get_projection_rays(projection, noise_level=0):\n    \"\"\"\n    Return a 2D map of 3D rays corresponding to the projection.\n    If noise_level > 0, add some jittering noise to these rays.\n    \"\"\"\n  
  grid_u, grid_v = np.meshgrid(0.5 + np.arange(projection.width), 0.5 + np.arange(projection.height))\n    if noise_level > 0:\n        grid_u += np.clip(0, noise_level * np.random.uniform(-0.5, 0.5, size=grid_u.shape), projection.width)\n        grid_v += np.clip(0, noise_level * np.random.uniform(-0.5, 0.5, size=grid_v.shape), projection.height)\n    return projection.unproject(grid_u, grid_v)\n\ndef compute_camera_intrinsics(height, width, hfov):\n    f = width/2 / np.tan(hfov/2 * np.pi/180)\n    cu, cv = width/2, height/2\n    return f, cu, cv\n\ndef colmap_to_opencv_intrinsics(K):\n    \"\"\"\n    Modify camera intrinsics to follow a different convention.\n    Coordinates of the center of the top-left pixels are by default:\n    - (0.5, 0.5) in Colmap\n    - (0,0) in OpenCV\n    \"\"\"\n    K = K.copy()\n    K[0, 2] -= 0.5\n    K[1, 2] -= 0.5\n    return K\n\ndef opencv_to_colmap_intrinsics(K):\n    \"\"\"\n    Modify camera intrinsics to follow a different convention.\n    Coordinates of the center of the top-left pixels are by default:\n    - (0.5, 0.5) in Colmap\n    - (0,0) in OpenCV\n    \"\"\"\n    K = K.copy()\n    K[0, 2] += 0.5\n    K[1, 2] += 0.5\n    return K"
  },
  {
    "path": "datasets_preprocess/habitat/habitat_renderer/projections_conversions.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# Remap data from one projection to another\n# --------------------------------------------------------\nimport numpy as np\nimport cv2\nfrom habitat_renderer import projections\n\nclass RemapProjection:\n    def __init__(self, input_projection, output_projection, pixel_jittering_iterations=0, jittering_noise_level=0):\n        \"\"\"\n        Some naive random jittering can be introduced in the remapping to mitigate aliasing artifacts.\n        \"\"\"\n        assert jittering_noise_level >= 0\n        assert pixel_jittering_iterations >= 0\n\n        maps = []\n        # Initial map\n        self.output_rays = projections.get_projection_rays(output_projection)\n        map_u, map_v = input_projection.project(self.output_rays)\n        map_u, map_v = np.asarray(map_u, dtype=np.float32), np.asarray(map_v, dtype=np.float32)\n        maps.append((map_u, map_v))\n\n        for _ in range(pixel_jittering_iterations):\n            # Define multiple mappings using some coordinates jittering to mitigate aliasing effects\n            crop_rays = projections.get_projection_rays(output_projection, jittering_noise_level)\n            map_u, map_v = input_projection.project(crop_rays)\n            map_u, map_v = np.asarray(map_u, dtype=np.float32), np.asarray(map_v, dtype=np.float32)\n            maps.append((map_u, map_v))\n        self.maps = maps\n\n    def convert(self, img, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_WRAP, single_map=False):\n        remapped = []\n        for map_u, map_v in self.maps:\n            res = cv2.remap(img, map_u, map_v, interpolation=interpolation, borderMode=borderMode)\n            remapped.append(res)\n            if single_map:\n                break\n        if len(remapped) == 1:\n            res = remapped[0]\n       
 else:\n            res = np.asarray(np.mean(remapped, axis=0), dtype=img.dtype)\n        return res\n"
  },
  {
    "path": "datasets_preprocess/habitat/preprocess_habitat.py",
    "content": "#!/usr/bin/env python3\n# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# main executable for preprocessing habitat\n# export METADATA_DIR=\"/path/to/habitat/5views_v1_512x512_metadata\"\n# export SCENES_DIR=\"/path/to/habitat/data/scene_datasets/\"\n# export OUTPUT_DIR=\"data/habitat_processed\"\n# export PYTHONPATH=$(pwd)\n# python preprocess_habitat.py --scenes_dir=$SCENES_DIR --metadata_dir=$METADATA_DIR --output_dir=$OUTPUT_DIR | parallel -j 16\n# --------------------------------------------------------\nimport os\nimport glob\nimport json\n\nimport PIL.Image\nos.environ[\"OPENCV_IO_ENABLE_OPENEXR\"] = \"1\"  # noqa\nimport cv2\nfrom habitat_renderer import multiview_crop_generator\nfrom tqdm import tqdm\n\n\ndef preprocess_metadata(metadata_filename,\n                        scenes_dir,\n                        output_dir,\n                        crop_resolution=[512, 512],\n                        equirectangular_resolution=None,\n                        fix_existing_dataset=False):\n    # Load data\n    with open(metadata_filename, \"r\") as f:\n        metadata = json.load(f)\n\n    if metadata[\"scene_dataset_config_file\"] == \"\":\n        scene = os.path.join(scenes_dir, metadata[\"scene\"])\n        scene_dataset_config_file = \"\"\n    else:\n        scene = metadata[\"scene\"]\n        scene_dataset_config_file = os.path.join(scenes_dir, metadata[\"scene_dataset_config_file\"])\n    navmesh = None\n\n    if equirectangular_resolution is None:\n        # Use 4 times the crop size as resolution for rendering the environment map.\n        max_res = max(crop_resolution)\n        equirectangular_resolution = (4*max_res, 8*max_res)\n\n    
print(\"equirectangular_resolution:\", equirectangular_resolution)\n\n    if os.path.exists(output_dir) and not fix_existing_dataset:\n        raise FileExistsError(output_dir)\n\n    # Lazy initialization\n    highres_dataset = None\n\n    for batch_label, batch in tqdm(metadata[\"view_batches\"].items()):\n        for view_label, view_params in batch.items():\n\n            assert view_params[\"size\"] == crop_resolution\n            label = f\"{batch_label}_{view_label}\"\n\n            output_camera_params_filename = os.path.join(output_dir, f\"{label}_camera_params.json\")\n            if fix_existing_dataset and os.path.isfile(output_camera_params_filename):\n                # Skip generation if we are fixing a dataset and the corresponding output file already exists\n                continue\n\n            # Lazy initialization\n            if highres_dataset is None:\n                highres_dataset = multiview_crop_generator.HabitatMultiviewCrops(scene=scene,\n                                                                                 navmesh=navmesh,\n                                                                                 scene_dataset_config_file=scene_dataset_config_file,\n                                                                                 equirectangular_resolution=equirectangular_resolution,\n                                                                                 crop_resolution=crop_resolution,)\n                os.makedirs(output_dir, exist_ok=bool(fix_existing_dataset))\n\n            # Generate a higher resolution crop\n            original_projection, position = multiview_crop_generator.dict_to_perspective_projection(view_params)\n            # Render an envmap at the given position\n            viewpoint_data = highres_dataset.render_viewpoint_data(position)\n\n            projection = original_projection\n            colormap, depthmap, pointmap, _ = highres_dataset.extract_cropped_camera(\n                
projection, viewpoint_data.colormap, viewpoint_data.distancemap, viewpoint_data.pointmap)\n\n            camera_params = multiview_crop_generator.perspective_projection_to_dict(projection, position)\n\n            # Color image\n            PIL.Image.fromarray(colormap).save(os.path.join(output_dir, f\"{label}.jpeg\"))\n            # Depth image\n            cv2.imwrite(os.path.join(output_dir, f\"{label}_depth.exr\"),\n                        depthmap, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])\n            # Camera parameters\n            with open(output_camera_params_filename, \"w\") as f:\n                json.dump(camera_params, f)\n\n\nif __name__ == \"__main__\":\n    import argparse\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--metadata_dir\", required=True)\n    parser.add_argument(\"--scenes_dir\", required=True)\n    parser.add_argument(\"--output_dir\", required=True)\n    parser.add_argument(\"--metadata_filename\", default=\"\")\n\n    args = parser.parse_args()\n\n    if args.metadata_filename == \"\":\n        # Walk through the metadata dir to generate commandlines\n        for filename in glob.iglob(os.path.join(args.metadata_dir, \"**/metadata.json\"), recursive=True):\n            output_dir = os.path.join(args.output_dir, os.path.relpath(os.path.dirname(filename), args.metadata_dir))\n            if not os.path.exists(output_dir):\n                commandline = f\"python {__file__} --metadata_filename={filename} --metadata_dir={args.metadata_dir} --scenes_dir={args.scenes_dir} --output_dir={output_dir}\"\n                print(commandline)\n    else:\n        preprocess_metadata(metadata_filename=args.metadata_filename,\n                            scenes_dir=args.scenes_dir,\n                            output_dir=args.output_dir)\n"
  },
  {
    "path": "datasets_preprocess/path_to_root.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# DUSt3R repo root import\n# --------------------------------------------------------\n\nimport sys\nimport os.path as path\nHERE_PATH = path.normpath(path.dirname(__file__))\nDUST3R_REPO_PATH = path.normpath(path.join(HERE_PATH, '../'))\n# workaround for sibling import\nsys.path.insert(0, DUST3R_REPO_PATH)\n"
  },
  {
    "path": "datasets_preprocess/preprocess_arkitscenes.py",
    "content": "#!/usr/bin/env python3\n# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# Script to pre-process the arkitscenes dataset.\n# Usage:\n# python3 datasets_preprocess/preprocess_arkitscenes.py --arkitscenes_dir /path/to/arkitscenes --precomputed_pairs /path/to/arkitscenes_pairs\n# --------------------------------------------------------\nimport os\nimport json\nimport os.path as osp\nimport decimal\nimport argparse\nimport math\nfrom bisect import bisect_left\nfrom PIL import Image\nimport numpy as np\nimport quaternion\nfrom scipy import interpolate\nimport cv2\n\n\ndef get_parser():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--arkitscenes_dir', required=True)\n    parser.add_argument('--precomputed_pairs', required=True)\n    parser.add_argument('--output_dir', default='data/arkitscenes_processed')\n    return parser\n\n\ndef value_to_decimal(value, decimal_places):\n    decimal.getcontext().rounding = decimal.ROUND_HALF_UP  # define rounding method\n    return decimal.Decimal(str(float(value))).quantize(decimal.Decimal('1e-{}'.format(decimal_places)))\n\n\ndef closest(value, sorted_list):\n    index = bisect_left(sorted_list, value)\n    if index == 0:\n        return sorted_list[0]\n    elif index == len(sorted_list):\n        return sorted_list[-1]\n    else:\n        value_before = sorted_list[index - 1]\n        value_after = sorted_list[index]\n        if value_after - value < value - value_before:\n            return value_after\n        else:\n            return value_before\n\n\ndef get_up_vectors(pose_device_to_world):\n    return np.matmul(pose_device_to_world, np.array([[0.0], [-1.0], [0.0], [0.0]]))\n\n\ndef get_right_vectors(pose_device_to_world):\n    return np.matmul(pose_device_to_world, np.array([[1.0], [0.0], [0.0], [0.0]]))\n\n\ndef read_traj(traj_path):\n    
quaternions = []\n    poses = []\n    timestamps = []\n    poses_p_to_w = []\n    with open(traj_path) as f:\n        traj_lines = f.readlines()\n        for line in traj_lines:\n            tokens = line.split()\n            assert len(tokens) == 7\n            traj_timestamp = float(tokens[0])\n\n            timestamps_decimal_value = value_to_decimal(traj_timestamp, 3)\n            timestamps.append(float(timestamps_decimal_value))  # for spline interpolation\n\n            angle_axis = [float(tokens[1]), float(tokens[2]), float(tokens[3])]\n            r_w_to_p, _ = cv2.Rodrigues(np.asarray(angle_axis))\n            t_w_to_p = np.asarray([float(tokens[4]), float(tokens[5]), float(tokens[6])])\n\n            pose_w_to_p = np.eye(4)\n            pose_w_to_p[:3, :3] = r_w_to_p\n            pose_w_to_p[:3, 3] = t_w_to_p\n\n            pose_p_to_w = np.linalg.inv(pose_w_to_p)\n\n            r_p_to_w_as_quat = quaternion.from_rotation_matrix(pose_p_to_w[:3, :3])\n            t_p_to_w = pose_p_to_w[:3, 3]\n            poses_p_to_w.append(pose_p_to_w)\n            poses.append(t_p_to_w)\n            quaternions.append(r_p_to_w_as_quat)\n    return timestamps, poses, quaternions, poses_p_to_w\n\n\ndef main(rootdir, pairsdir, outdir):\n    os.makedirs(outdir, exist_ok=True)\n\n    subdirs = ['Test', 'Training']\n    for subdir in subdirs:\n        if not osp.isdir(osp.join(rootdir, subdir)):\n            continue\n        # STEP 1: list all scenes\n        outsubdir = osp.join(outdir, subdir)\n        os.makedirs(outsubdir, exist_ok=True)\n        listfile = osp.join(pairsdir, subdir, 'scene_list.json')\n        with open(listfile, 'r') as f:\n            scene_dirs = json.load(f)\n\n        valid_scenes = []\n        for scene_subdir in scene_dirs:\n            out_scene_subdir = osp.join(outsubdir, scene_subdir)\n            os.makedirs(out_scene_subdir, exist_ok=True)\n\n            scene_dir = osp.join(rootdir, subdir, scene_subdir)\n            depth_dir = 
osp.join(scene_dir, 'lowres_depth')\n            rgb_dir = osp.join(scene_dir, 'vga_wide')\n            intrinsics_dir = osp.join(scene_dir, 'vga_wide_intrinsics')\n            traj_path = osp.join(scene_dir, 'lowres_wide.traj')\n\n            # STEP 2: read selected_pairs.npz\n            selected_pairs_path = osp.join(pairsdir, subdir, scene_subdir, 'selected_pairs.npz')\n            selected_npz = np.load(selected_pairs_path)\n            selection, pairs = selected_npz['selection'], selected_npz['pairs']\n            selected_sky_direction_scene = str(selected_npz['sky_direction_scene'][0])\n            if len(selection) == 0 or len(pairs) == 0:\n                # not a valid scene\n                continue\n            valid_scenes.append(scene_subdir)\n\n            # STEP 3: parse the scene and export the list of valid (K, pose, rgb, depth) and convert images\n            scene_metadata_path = osp.join(out_scene_subdir, 'scene_metadata.npz')\n            if osp.isfile(scene_metadata_path):\n                continue\n            else:\n                print(f'parsing {scene_subdir}')\n                # loads traj\n                timestamps, poses, quaternions, poses_cam_to_world = read_traj(traj_path)\n\n                poses = np.array(poses)\n                quaternions = np.array(quaternions, dtype=np.quaternion)\n                quaternions = quaternion.unflip_rotors(quaternions)\n                timestamps = np.array(timestamps)\n\n                selected_images = [(basename, basename.split(\".png\")[0].split(\"_\")[1]) for basename in selection]\n                timestamps_selected = [float(frame_id) for _, frame_id in selected_images]\n\n                sky_direction_scene, trajectories, intrinsics, images = convert_scene_metadata(scene_subdir,\n                                                                                               intrinsics_dir,\n                                                                                               
timestamps,\n                                                                                               quaternions,\n                                                                                               poses,\n                                                                                               poses_cam_to_world,\n                                                                                               selected_images,\n                                                                                               timestamps_selected)\n                assert selected_sky_direction_scene == sky_direction_scene\n\n                os.makedirs(os.path.join(out_scene_subdir, 'vga_wide'), exist_ok=True)\n                os.makedirs(os.path.join(out_scene_subdir, 'lowres_depth'), exist_ok=True)\n                assert isinstance(sky_direction_scene, str)\n                for basename in images:\n                    img_out = os.path.join(out_scene_subdir, 'vga_wide', basename.replace('.png', '.jpg'))\n                    depth_out = os.path.join(out_scene_subdir, 'lowres_depth', basename)\n                    if osp.isfile(img_out) and osp.isfile(depth_out):\n                        continue\n\n                    vga_wide_path = osp.join(rgb_dir, basename)\n                    depth_path = osp.join(depth_dir, basename)\n\n                    img = Image.open(vga_wide_path)\n                    depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)\n\n                    # rotate the image\n                    if sky_direction_scene == 'RIGHT':\n                        try:\n                            img = img.transpose(Image.Transpose.ROTATE_90)\n                        except Exception:\n                            img = img.transpose(Image.ROTATE_90)\n                        depth = cv2.rotate(depth, cv2.ROTATE_90_COUNTERCLOCKWISE)\n                    elif sky_direction_scene == 'LEFT':\n                        try:\n                       
     img = img.transpose(Image.Transpose.ROTATE_270)\n                        except Exception:\n                            img = img.transpose(Image.ROTATE_270)\n                        depth = cv2.rotate(depth, cv2.ROTATE_90_CLOCKWISE)\n                    elif sky_direction_scene == 'DOWN':\n                        try:\n                            img = img.transpose(Image.Transpose.ROTATE_180)\n                        except Exception:\n                            img = img.transpose(Image.ROTATE_180)\n                        depth = cv2.rotate(depth, cv2.ROTATE_180)\n\n                    W, H = img.size\n                    if not osp.isfile(img_out):\n                        img.save(img_out)\n\n                    depth = cv2.resize(depth, (W, H), interpolation=cv2.INTER_NEAREST_EXACT)\n                    if not osp.isfile(depth_out):  # avoid destroying the base dataset when you mess up the paths\n                        cv2.imwrite(depth_out, depth)\n\n                # save at the end\n                np.savez(scene_metadata_path,\n                         trajectories=trajectories,\n                         intrinsics=intrinsics,\n                         images=images,\n                         pairs=pairs)\n\n        outlistfile = osp.join(outsubdir, 'scene_list.json')\n        with open(outlistfile, 'w') as f:\n            json.dump(valid_scenes, f)\n\n        # STEP 5: concat all scene_metadata.npz into a single file\n        scene_data = {}\n        for scene_subdir in valid_scenes:\n            scene_metadata_path = osp.join(outsubdir, scene_subdir, 'scene_metadata.npz')\n            with np.load(scene_metadata_path) as data:\n                trajectories = data['trajectories']\n                intrinsics = data['intrinsics']\n                images = data['images']\n                pairs = data['pairs']\n            scene_data[scene_subdir] = {'trajectories': trajectories,\n                                        'intrinsics': intrinsics,\n    
                                    'images': images,\n                                        'pairs': pairs}\n        offset = 0\n        counts = []\n        scenes = []\n        sceneids = []\n        images = []\n        intrinsics = []\n        trajectories = []\n        pairs = []\n        for scene_idx, (scene_subdir, data) in enumerate(scene_data.items()):\n            num_imgs = data['images'].shape[0]\n            img_pairs = data['pairs']\n\n            scenes.append(scene_subdir)\n            sceneids.extend([scene_idx] * num_imgs)\n\n            images.append(data['images'])\n\n            K = np.expand_dims(np.eye(3), 0).repeat(num_imgs, 0)\n            K[:, 0, 0] = [fx for _, _, fx, _, _, _ in data['intrinsics']]\n            K[:, 1, 1] = [fy for _, _, _, fy, _, _ in data['intrinsics']]\n            K[:, 0, 2] = [hw for _, _, _, _, hw, _ in data['intrinsics']]\n            K[:, 1, 2] = [hh for _, _, _, _, _, hh in data['intrinsics']]\n\n            intrinsics.append(K)\n            trajectories.append(data['trajectories'])\n\n            # offset pairs\n            img_pairs[:, 0:2] += offset\n            pairs.append(img_pairs)\n            counts.append(offset)\n\n            offset += num_imgs\n\n        images = np.concatenate(images, axis=0)\n        intrinsics = np.concatenate(intrinsics, axis=0)\n        trajectories = np.concatenate(trajectories, axis=0)\n        pairs = np.concatenate(pairs, axis=0)\n        np.savez(osp.join(outsubdir, 'all_metadata.npz'),\n                 counts=counts,\n                 scenes=scenes,\n                 sceneids=sceneids,\n                 images=images,\n                 intrinsics=intrinsics,\n                 trajectories=trajectories,\n                 pairs=pairs)\n\n\ndef convert_scene_metadata(scene_subdir, intrinsics_dir,\n                           timestamps, quaternions, poses, poses_cam_to_world,\n                           selected_images, timestamps_selected):\n    # find scene 
orientation\n    sky_direction_scene, rotated_to_cam = find_scene_orientation(poses_cam_to_world)\n\n    # find/compute pose for selected timestamps\n    # most images have a valid timestamp / exact pose associated\n    timestamps_selected = np.array(timestamps_selected)\n    spline = interpolate.interp1d(timestamps, poses, kind='linear', axis=0)\n    interpolated_rotations = quaternion.squad(quaternions, timestamps, timestamps_selected)\n    interpolated_positions = spline(timestamps_selected)\n\n    trajectories = []\n    intrinsics = []\n    images = []\n    for i, (basename, frame_id) in enumerate(selected_images):\n        intrinsic_fn = osp.join(intrinsics_dir, f\"{scene_subdir}_{frame_id}.pincam\")\n        if not osp.exists(intrinsic_fn):\n            intrinsic_fn = osp.join(intrinsics_dir, f\"{scene_subdir}_{float(frame_id) - 0.001:.3f}.pincam\")\n        if not osp.exists(intrinsic_fn):\n            intrinsic_fn = osp.join(intrinsics_dir, f\"{scene_subdir}_{float(frame_id) + 0.001:.3f}.pincam\")\n        assert osp.exists(intrinsic_fn)\n        w, h, fx, fy, hw, hh = np.loadtxt(intrinsic_fn)  # PINHOLE\n\n        pose = np.eye(4)\n        pose[:3, :3] = quaternion.as_rotation_matrix(interpolated_rotations[i])\n        pose[:3, 3] = interpolated_positions[i]\n\n        images.append(basename)\n        if sky_direction_scene == 'RIGHT' or sky_direction_scene == 'LEFT':\n            intrinsics.append([h, w, fy, fx, hh, hw])  # swapped intrinsics\n        else:\n            intrinsics.append([w, h, fx, fy, hw, hh])\n        trajectories.append(pose  @ rotated_to_cam)  # pose_cam_to_world @ rotated_to_cam = rotated(cam) to world\n\n    return sky_direction_scene, trajectories, intrinsics, images\n\n\ndef find_scene_orientation(poses_cam_to_world):\n    if len(poses_cam_to_world) > 0:\n        up_vector = sum(get_up_vectors(p) for p in poses_cam_to_world) / len(poses_cam_to_world)\n        right_vector = sum(get_right_vectors(p) for p in poses_cam_to_world) / 
len(poses_cam_to_world)\n        up_world = np.array([[0.0], [0.0], [1.0], [0.0]])\n    else:\n        up_vector = np.array([[0.0], [-1.0], [0.0], [0.0]])\n        right_vector = np.array([[1.0], [0.0], [0.0], [0.0]])\n        up_world = np.array([[0.0], [0.0], [1.0], [0.0]])\n\n    # value between 0, 180\n    device_up_to_world_up_angle = np.arccos(np.clip(np.dot(np.transpose(up_world),\n                                                           up_vector), -1.0, 1.0)).item() * 180.0 / np.pi\n    device_right_to_world_up_angle = np.arccos(np.clip(np.dot(np.transpose(up_world),\n                                                              right_vector), -1.0, 1.0)).item() * 180.0 / np.pi\n\n    up_closest_to_90 = abs(device_up_to_world_up_angle - 90.0) < abs(device_right_to_world_up_angle - 90.0)\n    if up_closest_to_90:\n        assert abs(device_up_to_world_up_angle - 90.0) < 45.0\n        # LEFT\n        if device_right_to_world_up_angle > 90.0:\n            sky_direction_scene = 'LEFT'\n            cam_to_rotated_q = quaternion.from_rotation_vector([0.0, 0.0, math.pi / 2.0])\n        else:\n            # note that in metadata.csv RIGHT does not exist, but again it's not accurate...\n            # well, turns out there are scenes oriented like this\n            # for example Training/41124801\n            sky_direction_scene = 'RIGHT'\n            cam_to_rotated_q = quaternion.from_rotation_vector([0.0, 0.0, -math.pi / 2.0])\n    else:\n        # right is close to 90\n        assert abs(device_right_to_world_up_angle - 90.0) < 45.0\n        if device_up_to_world_up_angle > 90.0:\n            sky_direction_scene = 'DOWN'\n            cam_to_rotated_q = quaternion.from_rotation_vector([0.0, 0.0, math.pi])\n        else:\n            sky_direction_scene = 'UP'\n            cam_to_rotated_q = quaternion.quaternion(1, 0, 0, 0)\n    cam_to_rotated = np.eye(4)\n    cam_to_rotated[:3, :3] = quaternion.as_rotation_matrix(cam_to_rotated_q)\n    rotated_to_cam = 
np.linalg.inv(cam_to_rotated)\n    return sky_direction_scene, rotated_to_cam\n\n\nif __name__ == '__main__':\n    parser = get_parser()\n    args = parser.parse_args()\n    main(args.arkitscenes_dir, args.precomputed_pairs, args.output_dir)\n"
  },
  {
    "path": "datasets_preprocess/preprocess_blendedMVS.py",
    "content": "#!/usr/bin/env python3\n# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# Preprocessing code for the BlendedMVS dataset\n# dataset at https://github.com/YoYo000/BlendedMVS\n# 1) Download BlendedMVS.zip\n# 2) Download BlendedMVS+.zip\n# 3) Download BlendedMVS++.zip\n# 4) Unzip everything in the same /path/to/tmp/blendedMVS/ directory\n# 5) python datasets_preprocess/preprocess_blendedMVS.py --blendedmvs_dir /path/to/tmp/blendedMVS/ --precomputed_pairs /path/to/blendedmvs_pairs.npy\n# --------------------------------------------------------\nimport os\nimport os.path as osp\nimport re\nfrom tqdm import tqdm\nimport numpy as np\nos.environ[\"OPENCV_IO_ENABLE_OPENEXR\"] = \"1\"\nimport cv2\n\nimport path_to_root  # noqa\nfrom dust3r.utils.parallel import parallel_threads\nfrom dust3r.datasets.utils import cropping  # noqa\n\n\ndef get_parser():\n    import argparse\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--blendedmvs_dir', required=True)\n    parser.add_argument('--precomputed_pairs', required=True)\n    parser.add_argument('--output_dir', default='data/blendedmvs_processed')\n    return parser\n\n\ndef main(db_root, pairs_path, output_dir):\n    print('>> Listing all sequences')\n    sequences = [f for f in os.listdir(db_root) if len(f) == 24]\n    # should find 502 scenes\n    assert sequences, f'did not find any sequences at {db_root}'\n    print(f'   (found {len(sequences)} sequences)')\n\n    for i, seq in enumerate(tqdm(sequences)):\n        out_dir = osp.join(output_dir, seq)\n        os.makedirs(out_dir, exist_ok=True)\n\n        # generate the crops\n        root = osp.join(db_root, seq)\n        cam_dir = osp.join(root, 'cams')\n        func_args = [(root, f[:-8], out_dir) for f in os.listdir(cam_dir) if not f.startswith('pair')]\n        parallel_threads(load_crop_and_save, func_args, star_args=True, leave=False)\n\n    # 
verify that all pairs are there\n    pairs = np.load(pairs_path)\n    for seqh, seql, img1, img2, score in tqdm(pairs):\n        for view_index in [img1, img2]:\n            impath = osp.join(output_dir, f\"{seqh:08x}{seql:016x}\", f\"{view_index:08n}.jpg\")\n            assert osp.isfile(impath), f'missing image at {impath=}'\n\n    print(f'>> Done, saved everything in {output_dir}/')\n\n\ndef load_crop_and_save(root, img, out_dir):\n    if osp.isfile(osp.join(out_dir, img + '.npz')):\n        return  # already done\n\n    # load everything\n    intrinsics_in, R_camin2world, t_camin2world = _load_pose(osp.join(root, 'cams', img + '_cam.txt'))\n    color_image_in = cv2.cvtColor(cv2.imread(osp.join(root, 'blended_images', img +\n                                  '.jpg'), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)\n    depthmap_in = load_pfm_file(osp.join(root, 'rendered_depth_maps', img + '.pfm'))\n\n    # do the crop\n    H, W = color_image_in.shape[:2]\n    assert H * 4 == W * 3\n    image, depthmap, intrinsics_out, R_in2out = _crop_image(intrinsics_in, color_image_in, depthmap_in, (512, 384))\n\n    # write everything\n    image.save(osp.join(out_dir, img + '.jpg'), quality=80)\n    cv2.imwrite(osp.join(out_dir, img + '.exr'), depthmap)\n\n    # New camera parameters\n    R_camout2world = R_camin2world @ R_in2out.T\n    t_camout2world = t_camin2world\n    np.savez(osp.join(out_dir, img + '.npz'), intrinsics=intrinsics_out,\n             R_cam2world=R_camout2world, t_cam2world=t_camout2world)\n\n\ndef _crop_image(intrinsics_in, color_image_in, depthmap_in, resolution_out=(800, 800)):\n    image, depthmap, intrinsics_out = cropping.rescale_image_depthmap(\n        color_image_in, depthmap_in, intrinsics_in, resolution_out)\n    R_in2out = np.eye(3)\n    return image, depthmap, intrinsics_out, R_in2out\n\n\ndef _load_pose(path, ret_44=False):\n    f = open(path)\n    RT = np.loadtxt(f, skiprows=1, max_rows=4, dtype=np.float32)\n    assert RT.shape == (4, 4)\n    RT = 
np.linalg.inv(RT)  # world2cam to cam2world\n\n    K = np.loadtxt(f, skiprows=2, max_rows=3, dtype=np.float32)\n    f.close()  # release the file handle once both matrices have been read\n    assert K.shape == (3, 3)\n\n    if ret_44:\n        return K, RT\n    return K, RT[:3, :3], RT[:3, 3]  # , depth_uint8_to_f32\n\n\ndef load_pfm_file(file_path):\n    with open(file_path, 'rb') as file:\n        header = file.readline().decode('UTF-8').strip()\n\n        if header == 'PF':\n            is_color = True\n        elif header == 'Pf':\n            is_color = False\n        else:\n            raise ValueError('The provided file is not a valid PFM file.')\n\n        dimensions = re.match(r'^(\\d+)\\s(\\d+)\\s$', file.readline().decode('UTF-8'))\n        if dimensions:\n            img_width, img_height = map(int, dimensions.groups())\n        else:\n            raise ValueError('Invalid PFM header format.')\n\n        endian_scale = float(file.readline().decode('UTF-8').strip())\n        if endian_scale < 0:\n            dtype = '<f'  # little-endian\n        else:\n            dtype = '>f'  # big-endian\n\n        data_buffer = file.read()\n        img_data = np.frombuffer(data_buffer, dtype=dtype)\n\n        if is_color:\n            img_data = np.reshape(img_data, (img_height, img_width, 3))\n        else:\n            img_data = np.reshape(img_data, (img_height, img_width))\n\n        img_data = cv2.flip(img_data, 0)\n\n    return img_data\n\n\nif __name__ == '__main__':\n    parser = get_parser()\n    args = parser.parse_args()\n    main(args.blendedmvs_dir, args.precomputed_pairs, args.output_dir)\n"
  },
  {
    "path": "datasets_preprocess/preprocess_co3d.py",
    "content": "#!/usr/bin/env python3\n# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# Script to pre-process the CO3D dataset.\n# Usage:\n# python3 datasets_preprocess/preprocess_co3d.py --co3d_dir /path/to/co3d\n# --------------------------------------------------------\n\nimport argparse\nimport random\nimport gzip\nimport json\nimport os\nimport os.path as osp\n\nimport torch\nimport PIL.Image\nimport numpy as np\nimport cv2\n\nfrom tqdm.auto import tqdm\nimport matplotlib.pyplot as plt\n\nimport path_to_root  # noqa\nimport dust3r.datasets.utils.cropping as cropping  # noqa\n\n\nCATEGORIES = [\n    \"apple\", \"backpack\", \"ball\", \"banana\", \"baseballbat\", \"baseballglove\",\n    \"bench\", \"bicycle\", \"book\", \"bottle\", \"bowl\", \"broccoli\", \"cake\", \"car\", \"carrot\",\n    \"cellphone\", \"chair\", \"couch\", \"cup\", \"donut\", \"frisbee\", \"hairdryer\", \"handbag\",\n    \"hotdog\", \"hydrant\", \"keyboard\", \"kite\", \"laptop\", \"microwave\",\n    \"motorcycle\",\n    \"mouse\", \"orange\", \"parkingmeter\", \"pizza\", \"plant\", \"remote\", \"sandwich\",\n    \"skateboard\", \"stopsign\",\n    \"suitcase\", \"teddybear\", \"toaster\", \"toilet\", \"toybus\",\n    \"toyplane\", \"toytrain\", \"toytruck\", \"tv\",\n    \"umbrella\", \"vase\", \"wineglass\",\n]\nCATEGORIES_IDX = {cat: i for i, cat in enumerate(CATEGORIES)}  # for seeding\n\nSINGLE_SEQUENCE_CATEGORIES = sorted(set(CATEGORIES) - set([\"microwave\", \"stopsign\", \"tv\"]))\n\n\ndef get_parser():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--category\", type=str, default=None)\n    parser.add_argument('--single_sequence_subset', default=False, action='store_true',\n                        help=\"prepare the single_sequence_subset instead.\")\n    parser.add_argument(\"--output_dir\", type=str, 
default=\"data/co3d_processed\")\n    parser.add_argument(\"--co3d_dir\", type=str, required=True)\n    parser.add_argument(\"--num_sequences_per_object\", type=int, default=50)\n    parser.add_argument(\"--seed\", type=int, default=42)\n    parser.add_argument(\"--min_quality\", type=float, default=0.5, help=\"Minimum viewpoint quality score.\")\n\n    parser.add_argument(\"--img_size\", type=int, default=512,\n                        help=(\"lower dimension will be >= img_size * 3/4, and max dimension will be >= img_size\"))\n    return parser\n\n\ndef convert_ndc_to_pinhole(focal_length, principal_point, image_size):\n    focal_length = np.array(focal_length)\n    principal_point = np.array(principal_point)\n    image_size_wh = np.array([image_size[1], image_size[0]])\n    half_image_size = image_size_wh / 2\n    rescale = half_image_size.min()\n    principal_point_px = half_image_size - principal_point * rescale\n    focal_length_px = focal_length * rescale\n    fx, fy = focal_length_px[0], focal_length_px[1]\n    cx, cy = principal_point_px[0], principal_point_px[1]\n    K = np.array([[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]], dtype=np.float32)\n    return K\n\n\ndef opencv_from_cameras_projection(R, T, focal, p0, image_size):\n    R = torch.from_numpy(R)[None, :, :]\n    T = torch.from_numpy(T)[None, :]\n    focal = torch.from_numpy(focal)[None, :]\n    p0 = torch.from_numpy(p0)[None, :]\n    image_size = torch.from_numpy(image_size)[None, :]\n\n    R_pytorch3d = R.clone()\n    T_pytorch3d = T.clone()\n    focal_pytorch3d = focal\n    p0_pytorch3d = p0\n    T_pytorch3d[:, :2] *= -1\n    R_pytorch3d[:, :, :2] *= -1\n    tvec = T_pytorch3d\n    R = R_pytorch3d.permute(0, 2, 1)\n\n    # Retype the image_size correctly and flip to width, height.\n    image_size_wh = image_size.to(R).flip(dims=(1,))\n\n    # NDC to screen conversion.\n    scale = image_size_wh.to(R).min(dim=1, keepdim=True)[0] / 2.0\n    scale = scale.expand(-1, 2)\n    c0 = image_size_wh / 
2.0\n\n    principal_point = -p0_pytorch3d * scale + c0\n    focal_length = focal_pytorch3d * scale\n\n    camera_matrix = torch.zeros_like(R)\n    camera_matrix[:, :2, 2] = principal_point\n    camera_matrix[:, 2, 2] = 1.0\n    camera_matrix[:, 0, 0] = focal_length[:, 0]\n    camera_matrix[:, 1, 1] = focal_length[:, 1]\n    return R[0], tvec[0], camera_matrix[0]\n\n\ndef get_set_list(category_dir, split, is_single_sequence_subset=False):\n    listfiles = os.listdir(osp.join(category_dir, \"set_lists\"))\n    if is_single_sequence_subset:\n        # not all objects have manyview_dev\n        subset_list_files = [f for f in listfiles if \"manyview_dev\" in f]\n    else:\n        subset_list_files = [f for f in listfiles if f\"fewview_train\" in f]\n\n    sequences_all = []\n    for subset_list_file in subset_list_files:\n        with open(osp.join(category_dir, \"set_lists\", subset_list_file)) as f:\n            subset_lists_data = json.load(f)\n            sequences_all.extend(subset_lists_data[split])\n\n    return sequences_all\n\n\ndef prepare_sequences(category, co3d_dir, output_dir, img_size, split, min_quality, max_num_sequences_per_object,\n                      seed, is_single_sequence_subset=False):\n    random.seed(seed)\n    category_dir = osp.join(co3d_dir, category)\n    category_output_dir = osp.join(output_dir, category)\n    sequences_all = get_set_list(category_dir, split, is_single_sequence_subset)\n    sequences_numbers = sorted(set(seq_name for seq_name, _, _ in sequences_all))\n\n    frame_file = osp.join(category_dir, \"frame_annotations.jgz\")\n    sequence_file = osp.join(category_dir, \"sequence_annotations.jgz\")\n\n    with gzip.open(frame_file, \"r\") as fin:\n        frame_data = json.loads(fin.read())\n    with gzip.open(sequence_file, \"r\") as fin:\n        sequence_data = json.loads(fin.read())\n\n    frame_data_processed = {}\n    for f_data in frame_data:\n        sequence_name = f_data[\"sequence_name\"]\n        
frame_data_processed.setdefault(sequence_name, {})[f_data[\"frame_number\"]] = f_data\n\n    good_quality_sequences = set()\n    for seq_data in sequence_data:\n        if seq_data[\"viewpoint_quality_score\"] > min_quality:\n            good_quality_sequences.add(seq_data[\"sequence_name\"])\n\n    sequences_numbers = [seq_name for seq_name in sequences_numbers if seq_name in good_quality_sequences]\n    if len(sequences_numbers) < max_num_sequences_per_object:\n        selected_sequences_numbers = sequences_numbers\n    else:\n        selected_sequences_numbers = random.sample(sequences_numbers, max_num_sequences_per_object)\n\n    selected_sequences_numbers_dict = {seq_name: [] for seq_name in selected_sequences_numbers}\n    sequences_all = [(seq_name, frame_number, filepath)\n                     for seq_name, frame_number, filepath in sequences_all\n                     if seq_name in selected_sequences_numbers_dict]\n\n    for seq_name, frame_number, filepath in tqdm(sequences_all):\n        frame_idx = int(filepath.split('/')[-1][5:-4])\n        selected_sequences_numbers_dict[seq_name].append(frame_idx)\n        mask_path = filepath.replace(\"images\", \"masks\").replace(\".jpg\", \".png\")\n        frame_data = frame_data_processed[seq_name][frame_number]\n        focal_length = frame_data[\"viewpoint\"][\"focal_length\"]\n        principal_point = frame_data[\"viewpoint\"][\"principal_point\"]\n        image_size = frame_data[\"image\"][\"size\"]\n        K = convert_ndc_to_pinhole(focal_length, principal_point, image_size)\n        R, tvec, camera_intrinsics = opencv_from_cameras_projection(np.array(frame_data[\"viewpoint\"][\"R\"]),\n                                                                    np.array(frame_data[\"viewpoint\"][\"T\"]),\n                                                                    np.array(focal_length),\n                                                                    np.array(principal_point),\n                       
                                             np.array(image_size))\n\n        frame_data = frame_data_processed[seq_name][frame_number]\n        depth_path = os.path.join(co3d_dir, frame_data[\"depth\"][\"path\"])\n        assert frame_data[\"depth\"][\"scale_adjustment\"] == 1.0\n        image_path = os.path.join(co3d_dir, filepath)\n        mask_path_full = os.path.join(co3d_dir, mask_path)\n\n        input_rgb_image = PIL.Image.open(image_path).convert('RGB')\n        input_mask = plt.imread(mask_path_full)\n\n        with PIL.Image.open(depth_path) as depth_pil:\n            # the image is stored with 16-bit depth but PIL reads it as I (32 bit).\n            # we cast it to uint16, then reinterpret as float16, then cast to float32\n            input_depthmap = (\n                np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16)\n                .astype(np.float32)\n                .reshape((depth_pil.size[1], depth_pil.size[0])))\n        depth_mask = np.stack((input_depthmap, input_mask), axis=-1)\n        H, W = input_depthmap.shape\n\n        camera_intrinsics = camera_intrinsics.numpy()\n        cx, cy = camera_intrinsics[:2, 2].round().astype(int)\n        min_margin_x = min(cx, W - cx)\n        min_margin_y = min(cy, H - cy)\n\n        # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy)\n        l, t = cx - min_margin_x, cy - min_margin_y\n        r, b = cx + min_margin_x, cy + min_margin_y\n        crop_bbox = (l, t, r, b)\n        input_rgb_image, depth_mask, input_camera_intrinsics = cropping.crop_image_depthmap(\n            input_rgb_image, depth_mask, camera_intrinsics, crop_bbox)\n\n        # try to set the lower dimension to img_size * 3/4 -> img_size=512 => 384\n        scale_final = ((img_size * 3 // 4) / min(H, W)) + 1e-8\n        output_resolution = np.floor(np.array([W, H]) * scale_final).astype(int)\n        if max(output_resolution) < img_size:\n            # let's put the 
max dimension to img_size\n            scale_final = (img_size / max(H, W)) + 1e-8\n            output_resolution = np.floor(np.array([W, H]) * scale_final).astype(int)\n\n        input_rgb_image, depth_mask, input_camera_intrinsics = cropping.rescale_image_depthmap(\n            input_rgb_image, depth_mask, input_camera_intrinsics, output_resolution)\n        input_depthmap = depth_mask[:, :, 0]\n        input_mask = depth_mask[:, :, 1]\n\n        # generate and adjust camera pose\n        camera_pose = np.eye(4, dtype=np.float32)\n        camera_pose[:3, :3] = R\n        camera_pose[:3, 3] = tvec\n        camera_pose = np.linalg.inv(camera_pose)\n\n        # save crop images and depth, metadata\n        save_img_path = os.path.join(output_dir, filepath)\n        save_depth_path = os.path.join(output_dir, frame_data[\"depth\"][\"path\"])\n        save_mask_path = os.path.join(output_dir, mask_path)\n        os.makedirs(os.path.split(save_img_path)[0], exist_ok=True)\n        os.makedirs(os.path.split(save_depth_path)[0], exist_ok=True)\n        os.makedirs(os.path.split(save_mask_path)[0], exist_ok=True)\n\n        input_rgb_image.save(save_img_path)\n        scaled_depth_map = (input_depthmap / np.max(input_depthmap) * 65535).astype(np.uint16)\n        cv2.imwrite(save_depth_path, scaled_depth_map)\n        cv2.imwrite(save_mask_path, (input_mask * 255).astype(np.uint8))\n\n        save_meta_path = save_img_path.replace('jpg', 'npz')\n        np.savez(save_meta_path, camera_intrinsics=input_camera_intrinsics,\n                 camera_pose=camera_pose, maximum_depth=np.max(input_depthmap))\n\n    return selected_sequences_numbers_dict\n\n\nif __name__ == \"__main__\":\n    parser = get_parser()\n    args = parser.parse_args()\n    assert args.co3d_dir != args.output_dir\n    if args.category is None:\n        if args.single_sequence_subset:\n            categories = SINGLE_SEQUENCE_CATEGORIES\n        else:\n            categories = CATEGORIES\n    else:\n        
categories = [args.category]\n    os.makedirs(args.output_dir, exist_ok=True)\n\n    for split in ['train', 'test']:\n        selected_sequences_path = os.path.join(args.output_dir, f'selected_seqs_{split}.json')\n        if os.path.isfile(selected_sequences_path):\n            continue\n\n        all_selected_sequences = {}\n        for category in categories:\n            category_output_dir = osp.join(args.output_dir, category)\n            os.makedirs(category_output_dir, exist_ok=True)\n            category_selected_sequences_path = os.path.join(category_output_dir, f'selected_seqs_{split}.json')\n            if os.path.isfile(category_selected_sequences_path):\n                with open(category_selected_sequences_path, 'r') as fid:\n                    category_selected_sequences = json.load(fid)\n            else:\n                print(f\"Processing {split} - category = {category}\")\n                category_selected_sequences = prepare_sequences(\n                    category=category,\n                    co3d_dir=args.co3d_dir,\n                    output_dir=args.output_dir,\n                    img_size=args.img_size,\n                    split=split,\n                    min_quality=args.min_quality,\n                    max_num_sequences_per_object=args.num_sequences_per_object,\n                    seed=args.seed + CATEGORIES_IDX[category],\n                    is_single_sequence_subset=args.single_sequence_subset\n                )\n                with open(category_selected_sequences_path, 'w') as file:\n                    json.dump(category_selected_sequences, file)\n\n            all_selected_sequences[category] = category_selected_sequences\n        with open(selected_sequences_path, 'w') as file:\n            json.dump(all_selected_sequences, file)\n"
  },
  {
    "path": "datasets_preprocess/preprocess_megadepth.py",
    "content": "#!/usr/bin/env python3\n# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# Preprocessing code for the MegaDepth dataset\n# dataset at https://www.cs.cornell.edu/projects/megadepth/\n# --------------------------------------------------------\nimport os\nimport os.path as osp\nimport collections\nfrom tqdm import tqdm\nimport numpy as np\nos.environ[\"OPENCV_IO_ENABLE_OPENEXR\"] = \"1\"\nimport cv2\nimport h5py\n\nimport path_to_root  # noqa\nfrom dust3r.utils.parallel import parallel_threads\nfrom dust3r.datasets.utils import cropping  # noqa\n\n\ndef get_parser():\n    import argparse\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--megadepth_dir', required=True)\n    parser.add_argument('--precomputed_pairs', required=True)\n    parser.add_argument('--output_dir', default='data/megadepth_processed')\n    return parser\n\n\ndef main(db_root, pairs_path, output_dir):\n    os.makedirs(output_dir, exist_ok=True)\n\n    # load all pairs\n    data = np.load(pairs_path, allow_pickle=True)\n    scenes = data['scenes']\n    images = data['images']\n    pairs = data['pairs']\n\n    # enumerate all unique images\n    todo = collections.defaultdict(set)\n    for scene, im1, im2, score in pairs:\n        todo[scene].add(im1)\n        todo[scene].add(im2)\n\n    # for each scene, load intrinsics and then parallel crops\n    for scene, im_idxs in tqdm(todo.items(), desc='Overall'):\n        scene, subscene = scenes[scene].split()\n        out_dir = osp.join(output_dir, scene, subscene)\n        os.makedirs(out_dir, exist_ok=True)\n\n        # load all camera params\n        _, pose_w2cam, intrinsics = _load_kpts_and_poses(db_root, scene, subscene, intrinsics=True)\n\n        in_dir = osp.join(db_root, scene, 'dense' + subscene)\n        args = [(in_dir, img, intrinsics[img], pose_w2cam[img], out_dir)\n 
               for img in [images[im_id] for im_id in im_idxs]]\n        parallel_threads(resize_one_image, args, star_args=True, front_num=0, leave=False, desc=f'{scene}/{subscene}')\n\n    # save pairs\n    print('Done! prepared all pairs in', output_dir)\n\n\ndef resize_one_image(root, tag, K_pre_rectif, pose_w2cam, out_dir):\n    if osp.isfile(osp.join(out_dir, tag + '.npz')):\n        return\n\n    # load image\n    img = cv2.cvtColor(cv2.imread(osp.join(root, 'imgs', tag), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)\n    H, W = img.shape[:2]\n\n    # load depth\n    with h5py.File(osp.join(root, 'depths', osp.splitext(tag)[0] + '.h5'), 'r') as hd5:\n        depthmap = np.asarray(hd5['depth'])\n\n    # rectify = undistort the intrinsics\n    imsize_pre, K_pre, distortion = K_pre_rectif\n    imsize_post = img.shape[1::-1]\n    K_post = cv2.getOptimalNewCameraMatrix(K_pre, distortion, imsize_pre, alpha=0,\n                                           newImgSize=imsize_post, centerPrincipalPoint=True)[0]\n\n    # downscale\n    img_out, depthmap_out, intrinsics_out, R_in2out = _downscale_image(K_post, img, depthmap, resolution_out=(800, 600))\n\n    # write everything\n    img_out.save(osp.join(out_dir, tag + '.jpg'), quality=90)\n    cv2.imwrite(osp.join(out_dir, tag + '.exr'), depthmap_out)\n\n    camout2world = np.linalg.inv(pose_w2cam)\n    camout2world[:3, :3] = camout2world[:3, :3] @ R_in2out.T\n    np.savez(osp.join(out_dir, tag + '.npz'), intrinsics=intrinsics_out, cam2world=camout2world)\n\n\ndef _downscale_image(camera_intrinsics, image, depthmap, resolution_out=(512, 384)):\n    H, W = image.shape[:2]\n    resolution_out = sorted(resolution_out)[::+1 if W < H else -1]\n\n    image, depthmap, intrinsics_out = cropping.rescale_image_depthmap(\n        image, depthmap, camera_intrinsics, resolution_out, force=False)\n    R_in2out = np.eye(3)\n\n    return image, depthmap, intrinsics_out, R_in2out\n\n\ndef _load_kpts_and_poses(root, scene_id, subscene, 
z_only=False, intrinsics=False):\n    if intrinsics:\n        with open(os.path.join(root, scene_id, 'sparse', 'manhattan', subscene, 'cameras.txt'), 'r') as f:\n            raw = f.readlines()[3:]  # skip the header\n\n        camera_intrinsics = {}\n        for camera in raw:\n            camera = camera.split(' ')\n            width, height, focal, cx, cy, k0 = [float(elem) for elem in camera[2:]]\n            K = np.eye(3)\n            K[0, 0] = focal\n            K[1, 1] = focal\n            K[0, 2] = cx\n            K[1, 2] = cy\n            camera_intrinsics[int(camera[0])] = ((int(width), int(height)), K, (k0, 0, 0, 0))\n\n    with open(os.path.join(root, scene_id, 'sparse', 'manhattan', subscene, 'images.txt'), 'r') as f:\n        raw = f.read().splitlines()[4:]  # skip the header\n\n    extract_pose = colmap_raw_pose_to_principal_axis if z_only else colmap_raw_pose_to_RT\n\n    poses = {}\n    points3D_idxs = {}\n    camera = []\n\n    for image, points in zip(raw[:: 2], raw[1:: 2]):\n        image = image.split(' ')\n        points = points.split(' ')\n\n        image_id = image[-1]\n        camera.append(int(image[-2]))\n\n        # find the principal axis\n        raw_pose = [float(elem) for elem in image[1: -2]]\n        poses[image_id] = extract_pose(raw_pose)\n\n        current_points3D_idxs = {int(i) for i in points[2:: 3] if i != '-1'}\n        assert -1 not in current_points3D_idxs, bb()\n        points3D_idxs[image_id] = current_points3D_idxs\n\n    if intrinsics:\n        image_intrinsics = {im_id: camera_intrinsics[cam] for im_id, cam in zip(poses, camera)}\n        return points3D_idxs, poses, image_intrinsics\n    else:\n        return points3D_idxs, poses\n\n\ndef colmap_raw_pose_to_principal_axis(image_pose):\n    qvec = image_pose[: 4]\n    qvec = qvec / np.linalg.norm(qvec)\n    w, x, y, z = qvec\n    z_axis = np.float32([\n        2 * x * z - 2 * y * w,\n        2 * y * z + 2 * x * w,\n        1 - 2 * x * x - 2 * y * y\n    ])\n    
return z_axis\n\n\ndef colmap_raw_pose_to_RT(image_pose):\n    qvec = image_pose[: 4]\n    qvec = qvec / np.linalg.norm(qvec)\n    w, x, y, z = qvec\n    R = np.array([\n        [\n            1 - 2 * y * y - 2 * z * z,\n            2 * x * y - 2 * z * w,\n            2 * x * z + 2 * y * w\n        ],\n        [\n            2 * x * y + 2 * z * w,\n            1 - 2 * x * x - 2 * z * z,\n            2 * y * z - 2 * x * w\n        ],\n        [\n            2 * x * z - 2 * y * w,\n            2 * y * z + 2 * x * w,\n            1 - 2 * x * x - 2 * y * y\n        ]\n    ])\n    # principal_axis.append(R[2, :])\n    t = image_pose[4: 7]\n    # World-to-Camera pose\n    current_pose = np.eye(4)\n    current_pose[: 3, : 3] = R\n    current_pose[: 3, 3] = t\n    return current_pose\n\n\nif __name__ == '__main__':\n    parser = get_parser()\n    args = parser.parse_args()\n    main(args.megadepth_dir, args.precomputed_pairs, args.output_dir)\n"
  },
  {
    "path": "datasets_preprocess/preprocess_scannetpp.py",
    "content": "#!/usr/bin/env python3\n# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# Script to pre-process the scannet++ dataset.\n# Usage:\n# python3 datasets_preprocess/preprocess_scannetpp.py --scannetpp_dir /path/to/scannetpp --precomputed_pairs /path/to/scannetpp_pairs --pyopengl-platform egl\n# --------------------------------------------------------\nimport os\nimport argparse\nimport os.path as osp\nimport re\nfrom tqdm import tqdm\nimport json\nfrom scipy.spatial.transform import Rotation\nimport pyrender\nimport trimesh\nimport trimesh.exchange.ply\nimport numpy as np\nimport cv2\nimport PIL.Image as Image\n\nfrom dust3r.datasets.utils.cropping import rescale_image_depthmap\nimport dust3r.utils.geometry as geometry\n\ninv = np.linalg.inv\nnorm = np.linalg.norm\nREGEXPR_DSLR = re.compile(r'^.*DSC(?P<frameid>\\d+).JPG$')\nREGEXPR_IPHONE = re.compile(r'.*frame_(?P<frameid>\\d+).jpg$')\n\nDEBUG_VIZ = None  # 'iou'\nif DEBUG_VIZ is not None:\n    import matplotlib.pyplot as plt  # noqa\n\n\nOPENGL_TO_OPENCV = np.float32([[1, 0, 0, 0],\n                               [0, -1, 0, 0],\n                               [0, 0, -1, 0],\n                               [0, 0, 0, 1]])\n\n\ndef get_parser():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--scannetpp_dir', required=True)\n    parser.add_argument('--precomputed_pairs', required=True)\n    parser.add_argument('--output_dir', default='data/scannetpp_processed')\n    parser.add_argument('--target_resolution', default=920, type=int, help=\"images resolution\")\n    parser.add_argument('--pyopengl-platform', type=str, default='', help='PyOpenGL env variable')\n    return parser\n\n\ndef pose_from_qwxyz_txyz(elems):\n    qw, qx, qy, qz, tx, ty, tz = map(float, elems)\n    pose = np.eye(4)\n    pose[:3, :3] = Rotation.from_quat((qx, qy, qz, 
qw)).as_matrix()\n    pose[:3, 3] = (tx, ty, tz)\n    return np.linalg.inv(pose)  # returns cam2world\n\n\ndef get_frame_number(name, cam_type='dslr'):\n    if cam_type == 'dslr':\n        regex_expr = REGEXPR_DSLR\n    elif cam_type == 'iphone':\n        regex_expr = REGEXPR_IPHONE\n    else:\n        raise NotImplementedError(f'wrong {cam_type=} for get_frame_number')\n    try:\n        matches = re.match(regex_expr, name)\n        return matches['frameid']\n    except Exception as e:\n        print(f'Error when parsing {name}')\n        raise ValueError(f'Invalid name {name}') from e\n\n\ndef load_sfm(sfm_dir, cam_type='dslr'):\n    # load cameras\n    with open(osp.join(sfm_dir, 'cameras.txt'), 'r') as f:\n        raw = f.read().splitlines()[3:]  # skip header\n\n    intrinsics = {}\n    for camera in tqdm(raw, position=1, leave=False):\n        camera = camera.split(' ')\n        intrinsics[int(camera[0])] = [camera[1]] + [float(cam) for cam in camera[2:]]\n\n    # load images\n    with open(os.path.join(sfm_dir, 'images.txt'), 'r') as f:\n        raw = f.read().splitlines()\n        raw = [line for line in raw if not line.startswith('#')]  # skip header\n\n    img_idx = {}\n    img_infos = {}\n    for image, points in tqdm(zip(raw[0::2], raw[1::2]), total=len(raw) // 2, position=1, leave=False):\n        image = image.split(' ')\n        points = points.split(' ')\n\n        idx = image[0]\n        img_name = image[-1]\n        prefixes = ['iphone/', 'video/']\n        for prefix in prefixes:\n            if img_name.startswith(prefix):\n                img_name = img_name[len(prefix):]\n        assert img_name not in img_idx, 'duplicate db image: ' + img_name\n        img_idx[img_name] = idx  # register image name\n\n        current_points2D = {int(i): (float(x), float(y))\n                            for i, x, y in zip(points[2::3], points[0::3], points[1::3]) if i != '-1'}\n        img_infos[idx] = dict(intrinsics=intrinsics[int(image[-2])],\n                     
         path=img_name,\n                              frame_id=get_frame_number(img_name, cam_type),\n                              cam_to_world=pose_from_qwxyz_txyz(image[1: -2]),\n                              sparse_pts2d=current_points2D)\n\n    # load 3D points\n    with open(os.path.join(sfm_dir, 'points3D.txt'), 'r') as f:\n        raw = f.read().splitlines()\n        raw = [line for line in raw if not line.startswith('#')]  # skip header\n\n    points3D = {}\n    observations = {idx: [] for idx in img_infos.keys()}\n    for point in tqdm(raw, position=1, leave=False):\n        point = point.split()\n        point_3d_idx = int(point[0])\n        points3D[point_3d_idx] = tuple(map(float, point[1:4]))\n        if len(point) > 8:\n            for idx, point_2d_idx in zip(point[8::2], point[9::2]):\n                if idx not in observations:\n                    continue\n                observations[idx].append((point_3d_idx, int(point_2d_idx)))\n\n    return img_idx, img_infos, points3D, observations\n\n\ndef subsample_img_infos(img_infos, num_images, allowed_name_subset=None):\n    img_infos_val = [(idx, val) for idx, val in img_infos.items()]\n    if allowed_name_subset is not None:\n        img_infos_val = [(idx, val) for idx, val in img_infos_val if val['path'] in allowed_name_subset]\n\n    if len(img_infos_val) > num_images:\n        img_infos_val = sorted(img_infos_val, key=lambda x: x[1]['frame_id'])\n        kept_idx = np.round(np.linspace(0, len(img_infos_val) - 1, num_images)).astype(int).tolist()\n        img_infos_val = [img_infos_val[idx] for idx in kept_idx]\n    return {idx: val for idx, val in img_infos_val}\n\n\ndef undistort_images(intrinsics, rgb, mask):\n    camera_type = intrinsics[0]\n\n    width = int(intrinsics[1])\n    height = int(intrinsics[2])\n    fx = intrinsics[3]\n    fy = intrinsics[4]\n    cx = intrinsics[5]\n    cy = intrinsics[6]\n    distortion = np.array(intrinsics[7:])\n\n    K = np.zeros([3, 3])\n    K[0, 0] = fx\n    
K[0, 2] = cx\n    K[1, 1] = fy\n    K[1, 2] = cy\n    K[2, 2] = 1\n\n    K = geometry.colmap_to_opencv_intrinsics(K)\n    if camera_type == \"OPENCV_FISHEYE\":\n        assert len(distortion) == 4\n\n        new_K = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify(\n            K,\n            distortion,\n            (width, height),\n            np.eye(3),\n            balance=0.0,\n        )\n        # Make the cx and cy to be the center of the image\n        new_K[0, 2] = width / 2.0\n        new_K[1, 2] = height / 2.0\n\n        map1, map2 = cv2.fisheye.initUndistortRectifyMap(K, distortion, np.eye(3), new_K, (width, height), cv2.CV_32FC1)\n    else:\n        new_K, _ = cv2.getOptimalNewCameraMatrix(K, distortion, (width, height), 1, (width, height), True)\n        map1, map2 = cv2.initUndistortRectifyMap(K, distortion, np.eye(3), new_K, (width, height), cv2.CV_32FC1)\n\n    undistorted_image = cv2.remap(rgb, map1, map2, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT_101)\n    undistorted_mask = cv2.remap(mask, map1, map2, interpolation=cv2.INTER_LINEAR,\n                                 borderMode=cv2.BORDER_CONSTANT, borderValue=255)\n    new_K = geometry.opencv_to_colmap_intrinsics(new_K)\n    return width, height, new_K, undistorted_image, undistorted_mask\n\n\ndef process_scenes(root, pairsdir, output_dir, target_resolution):\n    os.makedirs(output_dir, exist_ok=True)\n\n    # default values from\n    # https://github.com/scannetpp/scannetpp/blob/main/common/configs/render.yml\n    znear = 0.05\n    zfar = 20.0\n\n    listfile = osp.join(pairsdir, 'scene_list.json')\n    with open(listfile, 'r') as f:\n        scenes = json.load(f)\n\n    # for each of these, we will select some dslr images and some iphone images\n    # we will undistort them and render their depth\n    renderer = pyrender.OffscreenRenderer(0, 0)\n    for scene in tqdm(scenes, position=0, leave=True):\n        data_dir = os.path.join(root, 'data', scene)\n        
dir_dslr = os.path.join(data_dir, 'dslr')\n        dir_iphone = os.path.join(data_dir, 'iphone')\n        dir_scans = os.path.join(data_dir, 'scans')\n\n        assert os.path.isdir(data_dir) and os.path.isdir(dir_dslr) \\\n            and os.path.isdir(dir_iphone) and os.path.isdir(dir_scans)\n\n        output_dir_scene = os.path.join(output_dir, scene)\n        scene_metadata_path = osp.join(output_dir_scene, 'scene_metadata.npz')\n        if osp.isfile(scene_metadata_path):\n            continue\n\n        pairs_dir_scene = os.path.join(pairsdir, scene)\n        pairs_dir_scene_selected_pairs = os.path.join(pairs_dir_scene, 'selected_pairs.npz')\n        assert osp.isfile(pairs_dir_scene_selected_pairs)\n        selected_npz = np.load(pairs_dir_scene_selected_pairs)\n        selection, pairs = selected_npz['selection'], selected_npz['pairs']\n\n        # set up the output paths\n        output_dir_scene_rgb = os.path.join(output_dir_scene, 'images')\n        output_dir_scene_depth = os.path.join(output_dir_scene, 'depth')\n        os.makedirs(output_dir_scene_rgb, exist_ok=True)\n        os.makedirs(output_dir_scene_depth, exist_ok=True)\n\n        ply_path = os.path.join(dir_scans, 'mesh_aligned_0.05.ply')\n\n        sfm_dir_dslr = os.path.join(dir_dslr, 'colmap')\n        rgb_dir_dslr = os.path.join(dir_dslr, 'resized_images')\n        mask_dir_dslr = os.path.join(dir_dslr, 'resized_anon_masks')\n\n        sfm_dir_iphone = os.path.join(dir_iphone, 'colmap')\n        rgb_dir_iphone = os.path.join(dir_iphone, 'rgb')\n        mask_dir_iphone = os.path.join(dir_iphone, 'rgb_masks')\n\n        # load the mesh\n        with open(ply_path, 'rb') as f:\n            mesh_kwargs = trimesh.exchange.ply.load_ply(f)\n        mesh_scene = trimesh.Trimesh(**mesh_kwargs)\n\n        # read colmap reconstruction, we will only use the intrinsics and pose here\n        img_idx_dslr, img_infos_dslr, points3D_dslr, observations_dslr = load_sfm(sfm_dir_dslr, cam_type='dslr')\n       
 dslr_paths = {\n            \"in_colmap\": sfm_dir_dslr,\n            \"in_rgb\": rgb_dir_dslr,\n            \"in_mask\": mask_dir_dslr,\n        }\n\n        img_idx_iphone, img_infos_iphone, points3D_iphone, observations_iphone = load_sfm(\n            sfm_dir_iphone, cam_type='iphone')\n        iphone_paths = {\n            \"in_colmap\": sfm_dir_iphone,\n            \"in_rgb\": rgb_dir_iphone,\n            \"in_mask\": mask_dir_iphone,\n        }\n\n        mesh = pyrender.Mesh.from_trimesh(mesh_scene, smooth=False)\n        pyrender_scene = pyrender.Scene()\n        pyrender_scene.add(mesh)\n\n        selection_iphone = [imgname + '.jpg' for imgname in selection if 'frame_' in imgname]\n        selection_dslr = [imgname + '.JPG' for imgname in selection if not 'frame_' in imgname]\n\n        # resize the image to a more manageable size and render depth\n        for selection_cam, img_idx, img_infos, paths_data in [(selection_dslr, img_idx_dslr, img_infos_dslr, dslr_paths),\n                                                              (selection_iphone, img_idx_iphone, img_infos_iphone, iphone_paths)]:\n            rgb_dir = paths_data['in_rgb']\n            mask_dir = paths_data['in_mask']\n            for imgname in tqdm(selection_cam, position=1, leave=False):\n                imgidx = img_idx[imgname]\n                img_infos_idx = img_infos[imgidx]\n                rgb = np.array(Image.open(os.path.join(rgb_dir, img_infos_idx['path'])))\n                mask = np.array(Image.open(os.path.join(mask_dir, img_infos_idx['path'][:-3] + 'png')))\n\n                _, _, K, rgb, mask = undistort_images(img_infos_idx['intrinsics'], rgb, mask)\n\n                # rescale_image_depthmap assumes opencv intrinsics\n                intrinsics = geometry.colmap_to_opencv_intrinsics(K)\n                image, mask, intrinsics = rescale_image_depthmap(\n                    rgb, mask, intrinsics, (target_resolution, target_resolution * 3.0 / 4))\n\n                W, 
H = image.size\n                intrinsics = geometry.opencv_to_colmap_intrinsics(intrinsics)\n\n                # update inpace img_infos_idx\n                img_infos_idx['intrinsics'] = intrinsics\n                rgb_outpath = os.path.join(output_dir_scene_rgb, img_infos_idx['path'][:-3] + 'jpg')\n                image.save(rgb_outpath)\n\n                depth_outpath = os.path.join(output_dir_scene_depth, img_infos_idx['path'][:-3] + 'png')\n                # render depth image\n                renderer.viewport_width, renderer.viewport_height = W, H\n                fx, fy, cx, cy = intrinsics[0, 0], intrinsics[1, 1], intrinsics[0, 2], intrinsics[1, 2]\n                camera = pyrender.camera.IntrinsicsCamera(fx, fy, cx, cy, znear=znear, zfar=zfar)\n                camera_node = pyrender_scene.add(camera, pose=img_infos_idx['cam_to_world'] @ OPENGL_TO_OPENCV)\n\n                _, depth = renderer.render(pyrender_scene, flags=pyrender.RenderFlags.SKIP_CULL_FACES)\n                pyrender_scene.remove_node(camera_node)  # dont forget to remove camera\n\n                depth = (depth * 1000).astype('uint16')\n                # invalidate depth from mask before saving\n                depth_mask = (mask < 255)\n                depth[depth_mask] = 0\n                Image.fromarray(depth).save(depth_outpath)\n\n        trajectories = []\n        intrinsics = []\n        for imgname in selection:\n            if 'frame_' in imgname:\n                imgidx = img_idx_iphone[imgname + '.jpg']\n                img_infos_idx = img_infos_iphone[imgidx]\n            elif 'DSC' in imgname:\n                imgidx = img_idx_dslr[imgname + '.JPG']\n                img_infos_idx = img_infos_dslr[imgidx]\n            else:\n                raise ValueError(f'invalid image name {imgname}')\n\n            intrinsics.append(img_infos_idx['intrinsics'])\n            trajectories.append(img_infos_idx['cam_to_world'])\n\n        intrinsics = np.stack(intrinsics, axis=0)\n     
   trajectories = np.stack(trajectories, axis=0)\n        # save metadata for this scene\n        np.savez(scene_metadata_path,\n                 trajectories=trajectories,\n                 intrinsics=intrinsics,\n                 images=selection,\n                 pairs=pairs)\n\n        del img_infos\n        del pyrender_scene\n\n    # concat all scene_metadata.npz into a single file\n    scene_data = {}\n    for scene_subdir in scenes:\n        scene_metadata_path = osp.join(output_dir, scene_subdir, 'scene_metadata.npz')\n        with np.load(scene_metadata_path) as data:\n            trajectories = data['trajectories']\n            intrinsics = data['intrinsics']\n            images = data['images']\n            pairs = data['pairs']\n        scene_data[scene_subdir] = {'trajectories': trajectories,\n                                    'intrinsics': intrinsics,\n                                    'images': images,\n                                    'pairs': pairs}\n\n    offset = 0\n    counts = []\n    scenes = []\n    sceneids = []\n    images = []\n    intrinsics = []\n    trajectories = []\n    pairs = []\n    for scene_idx, (scene_subdir, data) in enumerate(scene_data.items()):\n        num_imgs = data['images'].shape[0]\n        img_pairs = data['pairs']\n\n        scenes.append(scene_subdir)\n        sceneids.extend([scene_idx] * num_imgs)\n\n        images.append(data['images'])\n\n        intrinsics.append(data['intrinsics'])\n        trajectories.append(data['trajectories'])\n\n        # offset pairs\n        img_pairs[:, 0:2] += offset\n        pairs.append(img_pairs)\n        counts.append(offset)\n\n        offset += num_imgs\n\n    images = np.concatenate(images, axis=0)\n    intrinsics = np.concatenate(intrinsics, axis=0)\n    trajectories = np.concatenate(trajectories, axis=0)\n    pairs = np.concatenate(pairs, axis=0)\n    np.savez(osp.join(output_dir, 'all_metadata.npz'),\n             counts=counts,\n             scenes=scenes,\n       
      sceneids=sceneids,\n             images=images,\n             intrinsics=intrinsics,\n             trajectories=trajectories,\n             pairs=pairs)\n    print('all done')\n\n\nif __name__ == '__main__':\n    parser = get_parser()\n    args = parser.parse_args()\n    if args.pyopengl_platform.strip():\n        os.environ['PYOPENGL_PLATFORM'] = args.pyopengl_platform\n    process_scenes(args.scannetpp_dir, args.precomputed_pairs, args.output_dir, args.target_resolution)\n"
  },
  {
    "path": "datasets_preprocess/preprocess_staticthings3d.py",
    "content": "#!/usr/bin/env python3\n# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# Preprocessing code for the StaticThings3D dataset\n# dataset at https://github.com/lmb-freiburg/robustmvd/blob/master/rmvd/data/README.md#staticthings3d\n# 1) Download StaticThings3D in /path/to/StaticThings3D/\n#    with the script at https://github.com/lmb-freiburg/robustmvd/blob/master/rmvd/data/scripts/download_staticthings3d.sh\n#    --> depths.tar.bz2 frames_finalpass.tar.bz2 poses.tar.bz2 frames_cleanpass.tar.bz2 intrinsics.tar.bz2\n# 2) unzip everything in the same /path/to/StaticThings3D/ directory\n# 3) python datasets_preprocess/preprocess_staticthings3d.py --StaticThings3D_dir /path/to/tmp/StaticThings3D/ --precomputed_pairs /path/to/precomputed_pairs\n# --------------------------------------------------------\nimport os\nimport os.path as osp\nimport re\nfrom tqdm import tqdm\nimport numpy as np\nos.environ[\"OPENCV_IO_ENABLE_OPENEXR\"] = \"1\"\nimport cv2\n\nimport path_to_root  # noqa\nfrom dust3r.utils.parallel import parallel_threads\nfrom dust3r.datasets.utils import cropping  # noqa\n\n\ndef get_parser():\n    import argparse\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--StaticThings3D_dir', required=True)\n    parser.add_argument('--precomputed_pairs', required=True)\n    parser.add_argument('--output_dir', default='data/staticthings3d_processed')\n    return parser\n\n\ndef main(db_root, pairs_path, output_dir):\n    all_scenes = _list_all_scenes(db_root)\n\n    # crop images\n    args = [(db_root, osp.join(split, subsplit, seq), camera, f'{n:04d}', output_dir)\n            for split, subsplit, seq in all_scenes for camera in ['left', 'right'] for n in range(6, 16)]\n    parallel_threads(load_crop_and_save, args, star_args=True, front_num=1)\n\n    # verify that all images are there\n    CAM = {b'l': 'left', b'r': 'right'}\n    pairs = 
np.load(pairs_path)\n    for scene, seq, cam1, im1, cam2, im2 in tqdm(pairs):\n        seq_path = osp.join('TRAIN', scene.decode('ascii'), f'{seq:04d}')\n        for cam, idx in [(CAM[cam1], im1), (CAM[cam2], im2)]:\n            for ext in ['clean', 'final']:\n                impath = osp.join(output_dir, seq_path, cam, f\"{idx:04d}_{ext}.jpg\")\n                assert osp.isfile(impath), f'missing an image at {impath=}'\n\n    print(f'>> Saved all data to {output_dir}!')\n\n\ndef load_crop_and_save(db_root, relpath_, camera, num, out_dir):\n    relpath = osp.join(relpath_, camera, num)\n    if osp.isfile(osp.join(out_dir, relpath + '.npz')):\n        return\n    os.makedirs(osp.join(out_dir, relpath_, camera), exist_ok=True)\n\n    # load everything\n    intrinsics_in = readFloat(osp.join(db_root, 'intrinsics', relpath_, num + '.float3'))\n    cam2world = np.linalg.inv(readFloat(osp.join(db_root, 'poses', relpath + '.float3')))\n    depthmap_in = readFloat(osp.join(db_root, 'depths', relpath + '.float3'))\n    img_clean = cv2.cvtColor(cv2.imread(osp.join(db_root, 'frames_cleanpass',\n                             relpath + '.png'), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)\n    img_final = cv2.cvtColor(cv2.imread(osp.join(db_root, 'frames_finalpass',\n                             relpath + '.png'), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)\n\n    # do the crop\n    assert img_clean.shape[:2] == (540, 960)\n    assert img_final.shape[:2] == (540, 960)\n    (clean_out, final_out), depthmap, intrinsics_out, R_in2out = _crop_image(\n        intrinsics_in, (img_clean, img_final), depthmap_in, (512, 384))\n\n    # write everything\n    clean_out.save(osp.join(out_dir, relpath + '_clean.jpg'), quality=80)\n    final_out.save(osp.join(out_dir, relpath + '_final.jpg'), quality=80)\n    cv2.imwrite(osp.join(out_dir, relpath + '.exr'), depthmap)\n\n    # New camera parameters\n    cam2world[:3, :3] = cam2world[:3, :3] @ R_in2out.T\n    np.savez(osp.join(out_dir, relpath + '.npz'), 
intrinsics=intrinsics_out, cam2world=cam2world)\n\n\ndef _crop_image(intrinsics_in, color_image_in, depthmap_in, resolution_out=(512, 512)):\n    image, depthmap, intrinsics_out = cropping.rescale_image_depthmap(\n        color_image_in, depthmap_in, intrinsics_in, resolution_out)\n    R_in2out = np.eye(3)\n    return image, depthmap, intrinsics_out, R_in2out\n\n\ndef _list_all_scenes(path):\n    print('>> Listing all scenes')\n\n    res = []\n    for split in ['TRAIN']:\n        for subsplit in 'ABC':\n            for seq in os.listdir(osp.join(path, 'intrinsics', split, subsplit)):\n                res.append((split, subsplit, seq))\n    print(f'   (found {len(res)} scenes)')\n    assert res, f'Did not find anything at {path=}'\n    return res\n\n\ndef readFloat(name):\n    with open(name, 'rb') as f:\n        if (f.readline().decode(\"utf-8\")) != 'float\\n':\n            raise Exception('float file %s did not contain <float> keyword' % name)\n\n        dim = int(f.readline())\n\n        dims = []\n        count = 1\n        for i in range(0, dim):\n            d = int(f.readline())\n            dims.append(d)\n            count *= d\n\n        dims = list(reversed(dims))\n        data = np.fromfile(f, np.float32, count).reshape(dims)\n    return data  # HxW, CxHxW or NxCxHxW\n\n\nif __name__ == '__main__':\n    parser = get_parser()\n    args = parser.parse_args()\n    main(args.StaticThings3D_dir, args.precomputed_pairs, args.output_dir)\n"
  },
  {
    "path": "datasets_preprocess/preprocess_waymo.py",
    "content": "#!/usr/bin/env python3\n# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# Preprocessing code for the Waymo Open dataset\n# dataset at https://github.com/waymo-research/waymo-open-dataset\n# 1) Accept the license\n# 2) download all training/*.tfrecord files from Perception Dataset, version 1.4.2\n# 3) put all .tfrecord files in '/path/to/waymo_dir'\n# 4) install the waymo_open_dataset package with\n#    `python3 -m pip install gcsfs waymo-open-dataset-tf-2-12-0==1.6.4`\n# 5) execute this script as `python preprocess_waymo.py --waymo_dir /path/to/waymo_dir --precomputed_pairs /path/to/precomputed_pairs`\n# --------------------------------------------------------\nimport sys\nimport os\nimport os.path as osp\nimport shutil\nimport json\nfrom tqdm import tqdm\nimport PIL.Image\nimport numpy as np\nos.environ[\"OPENCV_IO_ENABLE_OPENEXR\"] = \"1\"\nimport cv2\n\nimport tensorflow.compat.v1 as tf\ntf.enable_eager_execution()\n\nimport path_to_root  # noqa\nfrom dust3r.utils.geometry import geotrf, inv\nfrom dust3r.utils.image import imread_cv2\nfrom dust3r.utils.parallel import parallel_processes as parallel_map\nfrom dust3r.datasets.utils import cropping\nfrom dust3r.viz import show_raw_pointcloud\n\n\ndef get_parser():\n    import argparse\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--waymo_dir', required=True)\n    parser.add_argument('--precomputed_pairs', required=True)\n    parser.add_argument('--output_dir', default='data/waymo_processed')\n    parser.add_argument('--workers', type=int, default=1)\n    return parser\n\n\ndef main(waymo_root, pairs_path, output_dir, workers=1):\n    extract_frames(waymo_root, output_dir, workers=workers)\n    make_crops(output_dir, workers=workers)\n\n    # make sure all pairs are there\n    with np.load(pairs_path) as data:\n        scenes = data['scenes']\n        frames = data['frames']\n   
     pairs = data['pairs']  # array of (scene_id, img1_id, img2_id)\n\n    for scene_id, im1_id, im2_id in pairs:\n        for im_id in (im1_id, im2_id):\n            path = osp.join(output_dir, scenes[scene_id], frames[im_id] + '.jpg')\n            assert osp.isfile(path), f'Missing a file at {path=}\\nDid you download all .tfrecord files?'\n\n    shutil.rmtree(osp.join(output_dir, 'tmp'))\n    print('Done! all data generated at', output_dir)\n\n\ndef _list_sequences(db_root):\n    print('>> Looking for sequences in', db_root)\n    res = sorted(f for f in os.listdir(db_root) if f.endswith('.tfrecord'))\n    print(f'    found {len(res)} sequences')\n    return res\n\n\ndef extract_frames(db_root, output_dir, workers=8):\n    sequences = _list_sequences(db_root)\n    output_dir = osp.join(output_dir, 'tmp')\n    print('>> outputting result to', output_dir)\n    args = [(db_root, output_dir, seq) for seq in sequences]\n    parallel_map(process_one_seq, args, star_args=True, workers=workers)\n\n\ndef process_one_seq(db_root, output_dir, seq):\n    out_dir = osp.join(output_dir, seq)\n    os.makedirs(out_dir, exist_ok=True)\n    calib_path = osp.join(out_dir, 'calib.json')\n    if osp.isfile(calib_path):\n        return\n\n    try:\n        with tf.device('/CPU:0'):\n            calib, frames = extract_frames_one_seq(osp.join(db_root, seq))\n    except RuntimeError:\n        print(f'/!\\\\ Error with sequence {seq} /!\\\\', file=sys.stderr)\n        return  # nothing is saved\n\n    for f, (frame_name, views) in enumerate(tqdm(frames, leave=False)):\n        for cam_idx, view in views.items():\n            img = PIL.Image.fromarray(view.pop('img'))\n            img.save(osp.join(out_dir, f'{f:05d}_{cam_idx}.jpg'))\n            np.savez(osp.join(out_dir, f'{f:05d}_{cam_idx}.npz'), **view)\n\n    with open(calib_path, 'w') as f:\n        json.dump(calib, f)\n\n\ndef extract_frames_one_seq(filename):\n    from waymo_open_dataset import dataset_pb2 as open_dataset\n    
from waymo_open_dataset.utils import frame_utils\n\n    print('>> Opening', filename)\n    dataset = tf.data.TFRecordDataset(filename, compression_type='')\n\n    calib = None\n    frames = []\n\n    for data in tqdm(dataset, leave=False):\n        frame = open_dataset.Frame()\n        frame.ParseFromString(bytearray(data.numpy()))\n\n        content = frame_utils.parse_range_image_and_camera_projection(frame)\n        range_images, camera_projections, _, range_image_top_pose = content\n\n        views = {}\n        frames.append((frame.context.name, views))\n\n        # once in a sequence, read camera calibration info\n        if calib is None:\n            calib = []\n            for cam in frame.context.camera_calibrations:\n                calib.append((cam.name,\n                              dict(width=cam.width,\n                                   height=cam.height,\n                                   intrinsics=list(cam.intrinsic),\n                                   extrinsics=list(cam.extrinsic.transform))))\n\n        # convert LIDAR to pointcloud\n        points, cp_points = frame_utils.convert_range_image_to_point_cloud(\n            frame,\n            range_images,\n            camera_projections,\n            range_image_top_pose)\n\n        # 3d points in vehicle frame.\n        points_all = np.concatenate(points, axis=0)\n        cp_points_all = np.concatenate(cp_points, axis=0)\n\n        # The distance between lidar points and vehicle frame origin.\n        cp_points_all_tensor = tf.constant(cp_points_all, dtype=tf.int32)\n\n        for i, image in enumerate(frame.images):\n            # select relevant 3D points for this view\n            mask = tf.equal(cp_points_all_tensor[..., 0], image.name)\n            cp_points_msk_tensor = tf.cast(tf.gather_nd(cp_points_all_tensor, tf.where(mask)), dtype=tf.float32)\n\n            pose = np.asarray(image.pose.transform).reshape(4, 4)\n            timestamp = image.pose_timestamp\n\n            rgb = 
tf.image.decode_jpeg(image.image).numpy()\n\n            pix = cp_points_msk_tensor[..., 1:3].numpy().round().astype(np.int16)\n            pts3d = points_all[mask.numpy()]\n\n            views[image.name] = dict(img=rgb, pose=pose, pixels=pix, pts3d=pts3d, timestamp=timestamp)\n\n        if not 'show full point cloud':\n            show_raw_pointcloud([v['pts3d'] for v in views.values()], [v['img'] for v in views.values()])\n\n    return calib, frames\n\n\ndef make_crops(output_dir, workers=16, **kw):\n    tmp_dir = osp.join(output_dir, 'tmp')\n    sequences = _list_sequences(tmp_dir)\n    args = [(tmp_dir, output_dir, seq) for seq in sequences]\n    parallel_map(crop_one_seq, args, star_args=True, workers=workers, front_num=0)\n\n\ndef crop_one_seq(input_dir, output_dir, seq, resolution=512):\n    seq_dir = osp.join(input_dir, seq)\n    out_dir = osp.join(output_dir, seq)\n    if osp.isfile(osp.join(out_dir, '00100_1.jpg')):\n        return\n    os.makedirs(out_dir, exist_ok=True)\n\n    # load calibration file\n    try:\n        with open(osp.join(seq_dir, 'calib.json')) as f:\n            calib = json.load(f)\n    except IOError:\n        print(f'/!\\\\ Error: Missing calib.json in sequence {seq} /!\\\\', file=sys.stderr)\n        return\n\n    axes_transformation = np.array([\n        [0, -1, 0, 0],\n        [0, 0, -1, 0],\n        [1, 0, 0, 0],\n        [0, 0, 0, 1]])\n\n    cam_K = {}\n    cam_distortion = {}\n    cam_res = {}\n    cam_to_car = {}\n    for cam_idx, cam_info in calib:\n        cam_idx = str(cam_idx)\n        cam_res[cam_idx] = (W, H) = (cam_info['width'], cam_info['height'])\n        f1, f2, cx, cy, k1, k2, p1, p2, k3 = cam_info['intrinsics']\n        cam_K[cam_idx] = np.asarray([(f1, 0, cx), (0, f2, cy), (0, 0, 1)])\n        cam_distortion[cam_idx] = np.asarray([k1, k2, p1, p2, k3])\n        cam_to_car[cam_idx] = np.asarray(cam_info['extrinsics']).reshape(4, 4)  # cam-to-vehicle\n\n    frames = sorted(f[:-3] for f in os.listdir(seq_dir) if 
f.endswith('.jpg'))\n\n    # from dust3r.viz import SceneViz\n    # viz = SceneViz()\n\n    for frame in tqdm(frames, leave=False):\n        cam_idx = frame[-2]  # cam index\n        assert cam_idx in '12345', f'bad {cam_idx=} in {frame=}'\n        data = np.load(osp.join(seq_dir, frame + 'npz'))\n        car_to_world = data['pose']\n        W, H = cam_res[cam_idx]\n\n        # load depthmap\n        pos2d = data['pixels'].round().astype(np.uint16)\n        x, y = pos2d.T\n        pts3d = data['pts3d']  # already in the car frame\n        pts3d = geotrf(axes_transformation @ inv(cam_to_car[cam_idx]), pts3d)\n        # X=LEFT_RIGHT y=ALTITUDE z=DEPTH\n\n        # load image\n        image = imread_cv2(osp.join(seq_dir, frame + 'jpg'))\n\n        # downscale image\n        output_resolution = (resolution, 1) if W > H else (1, resolution)\n        image, _, intrinsics2 = cropping.rescale_image_depthmap(image, None, cam_K[cam_idx], output_resolution)\n        image.save(osp.join(out_dir, frame + 'jpg'), quality=80)\n\n        # save as an EXR file? yes it's smaller (and easier to load)\n        W, H = image.size\n        depthmap = np.zeros((H, W), dtype=np.float32)\n        pos2d = geotrf(intrinsics2 @ inv(cam_K[cam_idx]), pos2d).round().astype(np.int16)\n        x, y = pos2d.T\n        depthmap[y.clip(min=0, max=H - 1), x.clip(min=0, max=W - 1)] = pts3d[:, 2]\n        cv2.imwrite(osp.join(out_dir, frame + 'exr'), depthmap)\n\n        # save camera parameters\n        cam2world = car_to_world @ cam_to_car[cam_idx] @ inv(axes_transformation)\n        np.savez(osp.join(out_dir, frame + 'npz'), intrinsics=intrinsics2,\n                 cam2world=cam2world, distortion=cam_distortion[cam_idx])\n\n        # viz.add_rgbd(np.asarray(image), depthmap, intrinsics2, cam2world)\n    # viz.show()\n\n\nif __name__ == '__main__':\n    parser = get_parser()\n    args = parser.parse_args()\n    main(args.waymo_dir, args.precomputed_pairs, args.output_dir, workers=args.workers)\n"
  },
  {
    "path": "datasets_preprocess/preprocess_wildrgbd.py",
    "content": "#!/usr/bin/env python3\n# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# Script to pre-process the WildRGB-D dataset.\n# Usage:\n# python3 datasets_preprocess/preprocess_wildrgbd.py --wildrgbd_dir /path/to/wildrgbd\n# --------------------------------------------------------\n\nimport argparse\nimport random\nimport json\nimport os\nimport os.path as osp\n\nimport PIL.Image\nimport numpy as np\nimport cv2\n\nfrom tqdm.auto import tqdm\nimport matplotlib.pyplot as plt\n\nimport path_to_root  # noqa\nimport dust3r.datasets.utils.cropping as cropping  # noqa\nfrom dust3r.utils.image import imread_cv2\n\n\ndef get_parser():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--output_dir\", type=str, default=\"data/wildrgbd_processed\")\n    parser.add_argument(\"--wildrgbd_dir\", type=str, required=True)\n    parser.add_argument(\"--train_num_sequences_per_object\", type=int, default=50)\n    parser.add_argument(\"--test_num_sequences_per_object\", type=int, default=10)\n    parser.add_argument(\"--num_frames\", type=int, default=100)\n    parser.add_argument(\"--seed\", type=int, default=42)\n\n    parser.add_argument(\"--img_size\", type=int, default=512,\n                        help=(\"lower dimension will be >= img_size * 3/4, and max dimension will be >= img_size\"))\n    return parser\n\n\ndef get_set_list(category_dir, split):\n    listfiles = [\"camera_eval_list.json\", \"nvs_list.json\"]\n\n    sequences_all = {s: {k: set() for k in listfiles} for s in ['train', 'val']}\n    for listfile in listfiles:\n        with open(osp.join(category_dir, listfile)) as f:\n            subset_lists_data = json.load(f)\n            for s in ['train', 'val']:\n                sequences_all[s][listfile].update(subset_lists_data[s])\n    train_intersection = 
set.intersection(*list(sequences_all['train'].values()))\n    if split == \"train\":\n        return train_intersection\n    else:\n        all_seqs = set.union(*list(sequences_all['train'].values()), *list(sequences_all['val'].values()))\n        return all_seqs.difference(train_intersection)\n\n\ndef prepare_sequences(category, wildrgbd_dir, output_dir, img_size, split, max_num_sequences_per_object,\n                      output_num_frames, seed):\n    random.seed(seed)\n    category_dir = osp.join(wildrgbd_dir, category)\n    category_output_dir = osp.join(output_dir, category)\n    sequences_all = get_set_list(category_dir, split)\n    sequences_all = sorted(sequences_all)\n\n    sequences_all_tmp = []\n    for seq_name in sequences_all:\n        scene_dir = osp.join(wildrgbd_dir, category_dir, seq_name)\n        if not os.path.isdir(scene_dir):\n            print(f'{scene_dir} does not exist, skipped')\n            continue\n        sequences_all_tmp.append(seq_name)\n    sequences_all = sequences_all_tmp\n    if len(sequences_all) <= max_num_sequences_per_object:\n        selected_sequences = sequences_all\n    else:\n        selected_sequences = random.sample(sequences_all, max_num_sequences_per_object)\n\n    selected_sequences_numbers_dict = {}\n    for seq_name in tqdm(selected_sequences, leave=False):\n        scene_dir = osp.join(category_dir, seq_name)\n        scene_output_dir = osp.join(category_output_dir, seq_name)\n        with open(osp.join(scene_dir, 'metadata'), 'r') as f:\n            metadata = json.load(f)\n\n        K = np.array(metadata[\"K\"]).reshape(3, 3).T\n        fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]\n        w, h = metadata[\"w\"], metadata[\"h\"]\n\n        camera_intrinsics = np.array(\n            [[fx, 0, cx],\n             [0, fy, cy],\n             [0, 0, 1]]\n        )\n        camera_to_world_path = os.path.join(scene_dir, 'cam_poses.txt')\n        camera_to_world_content = np.genfromtxt(camera_to_world_path)\n 
       camera_to_world = camera_to_world_content[:, 1:].reshape(-1, 4, 4)\n\n        frame_idx = camera_to_world_content[:, 0]\n        num_frames = frame_idx.shape[0]\n        assert num_frames >= output_num_frames\n        assert np.all(frame_idx == np.arange(num_frames))\n\n        # selected_sequences_numbers_dict[seq_name] = num_frames\n\n        selected_frames = np.round(np.linspace(0, num_frames - 1, output_num_frames)).astype(int).tolist()\n        selected_sequences_numbers_dict[seq_name] = selected_frames\n\n        for frame_id in tqdm(selected_frames):\n            depth_path = os.path.join(scene_dir, 'depth', f'{frame_id:0>5d}.png')\n            masks_path = os.path.join(scene_dir, 'masks', f'{frame_id:0>5d}.png')\n            rgb_path = os.path.join(scene_dir, 'rgb', f'{frame_id:0>5d}.png')\n\n            input_rgb_image = PIL.Image.open(rgb_path).convert('RGB')\n            input_mask = plt.imread(masks_path)\n            input_depthmap = imread_cv2(depth_path, cv2.IMREAD_UNCHANGED).astype(np.float64)\n            depth_mask = np.stack((input_depthmap, input_mask), axis=-1)\n            H, W = input_depthmap.shape\n\n            min_margin_x = min(cx, W - cx)\n            min_margin_y = min(cy, H - cy)\n\n            # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy)\n            l, t = int(cx - min_margin_x), int(cy - min_margin_y)\n            r, b = int(cx + min_margin_x), int(cy + min_margin_y)\n            crop_bbox = (l, t, r, b)\n            input_rgb_image, depth_mask, input_camera_intrinsics = cropping.crop_image_depthmap(\n                input_rgb_image, depth_mask, camera_intrinsics, crop_bbox)\n\n            # try to set the lower dimension to img_size * 3/4 -> img_size=512 => 384\n            scale_final = ((img_size * 3 // 4) / min(H, W)) + 1e-8\n            output_resolution = np.floor(np.array([W, H]) * scale_final).astype(int)\n            if max(output_resolution) < img_size:\n        
        # let's put the max dimension to img_size\n                scale_final = (img_size / max(H, W)) + 1e-8\n                output_resolution = np.floor(np.array([W, H]) * scale_final).astype(int)\n\n            input_rgb_image, depth_mask, input_camera_intrinsics = cropping.rescale_image_depthmap(\n                input_rgb_image, depth_mask, input_camera_intrinsics, output_resolution)\n            input_depthmap = depth_mask[:, :, 0]\n            input_mask = depth_mask[:, :, 1]\n\n            camera_pose = camera_to_world[frame_id]\n\n            # save crop images and depth, metadata\n            save_img_path = os.path.join(scene_output_dir, 'rgb', f'{frame_id:0>5d}.jpg')\n            save_depth_path = os.path.join(scene_output_dir, 'depth', f'{frame_id:0>5d}.png')\n            save_mask_path = os.path.join(scene_output_dir, 'masks', f'{frame_id:0>5d}.png')\n            os.makedirs(os.path.split(save_img_path)[0], exist_ok=True)\n            os.makedirs(os.path.split(save_depth_path)[0], exist_ok=True)\n            os.makedirs(os.path.split(save_mask_path)[0], exist_ok=True)\n\n            input_rgb_image.save(save_img_path)\n            cv2.imwrite(save_depth_path, input_depthmap.astype(np.uint16))\n            cv2.imwrite(save_mask_path, (input_mask * 255).astype(np.uint8))\n\n            save_meta_path = os.path.join(scene_output_dir, 'metadata', f'{frame_id:0>5d}.npz')\n            os.makedirs(os.path.split(save_meta_path)[0], exist_ok=True)\n            np.savez(save_meta_path, camera_intrinsics=input_camera_intrinsics,\n                     camera_pose=camera_pose)\n\n    return selected_sequences_numbers_dict\n\n\nif __name__ == \"__main__\":\n    parser = get_parser()\n    args = parser.parse_args()\n    assert args.wildrgbd_dir != args.output_dir\n\n    categories = sorted([\n        dirname for dirname in os.listdir(args.wildrgbd_dir)\n        if os.path.isdir(os.path.join(args.wildrgbd_dir, dirname, 'scenes'))\n    ])\n\n    
os.makedirs(args.output_dir, exist_ok=True)\n\n    splits_num_sequences_per_object = [args.train_num_sequences_per_object, args.test_num_sequences_per_object]\n    for split, num_sequences_per_object in zip(['train', 'test'], splits_num_sequences_per_object):\n        selected_sequences_path = os.path.join(args.output_dir, f'selected_seqs_{split}.json')\n        if os.path.isfile(selected_sequences_path):\n            continue\n        all_selected_sequences = {}\n        for category in categories:\n            category_output_dir = osp.join(args.output_dir, category)\n            os.makedirs(category_output_dir, exist_ok=True)\n            category_selected_sequences_path = os.path.join(category_output_dir, f'selected_seqs_{split}.json')\n            if os.path.isfile(category_selected_sequences_path):\n                with open(category_selected_sequences_path, 'r') as fid:\n                    category_selected_sequences = json.load(fid)\n            else:\n                print(f\"Processing {split} - category = {category}\")\n                category_selected_sequences = prepare_sequences(\n                    category=category,\n                    wildrgbd_dir=args.wildrgbd_dir,\n                    output_dir=args.output_dir,\n                    img_size=args.img_size,\n                    split=split,\n                    max_num_sequences_per_object=num_sequences_per_object,\n                    output_num_frames=args.num_frames,\n                    seed=args.seed + int(category.encode('ascii').hex(), 16),\n                )\n                with open(category_selected_sequences_path, 'w') as file:\n                    json.dump(category_selected_sequences, file)\n\n            all_selected_sequences[category] = category_selected_sequences\n        with open(selected_sequences_path, 'w') as file:\n            json.dump(all_selected_sequences, file)\n"
  },
  {
    "path": "demo.py",
    "content": "#!/usr/bin/env python3\n# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# dust3r gradio demo executable\n# --------------------------------------------------------\nimport os\nimport torch\nimport tempfile\n\nfrom dust3r.model import AsymmetricCroCo3DStereo\nfrom dust3r.demo import get_args_parser, main_demo, set_print_with_timestamp\n\nimport matplotlib.pyplot as pl\npl.ion()\n\ntorch.backends.cuda.matmul.allow_tf32 = True  # for gpu >= Ampere and pytorch >= 1.12\n\nif __name__ == '__main__':\n    parser = get_args_parser()\n    args = parser.parse_args()\n    set_print_with_timestamp()\n\n    if args.tmp_dir is not None:\n        tmp_path = args.tmp_dir\n        os.makedirs(tmp_path, exist_ok=True)\n        tempfile.tempdir = tmp_path\n\n    if args.server_name is not None:\n        server_name = args.server_name\n    else:\n        server_name = '0.0.0.0' if args.local_network else '127.0.0.1'\n\n    if args.weights is not None:\n        weights_path = args.weights\n    else:\n        weights_path = \"naver/\" + args.model_name\n    model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(args.device)\n\n    # dust3r will write the 3D model inside tmpdirname\n    with tempfile.TemporaryDirectory(suffix='dust3r_gradio_demo') as tmpdirname:\n        if not args.silent:\n            print('Outputing stuff in', tmpdirname)\n        main_demo(tmpdirname, model, args.device, args.image_size, server_name, args.server_port, silent=args.silent)\n"
  },
  {
    "path": "docker/docker-compose-cpu.yml",
    "content": "version: '3.8'\nservices:\n  dust3r-demo:\n    build:\n      context: ./files\n      dockerfile: cpu.Dockerfile \n    ports:\n      - \"7860:7860\"\n    volumes:\n      - ./files/checkpoints:/dust3r/checkpoints\n    environment:\n      - DEVICE=cpu\n      - MODEL=${MODEL:-DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth}\n    cap_add:\n      - IPC_LOCK\n      - SYS_RESOURCE\n"
  },
  {
    "path": "docker/docker-compose-cuda.yml",
    "content": "version: '3.8'\nservices:\n  dust3r-demo:\n    build:\n      context: ./files\n      dockerfile: cuda.Dockerfile \n    ports:\n      - \"7860:7860\"\n    environment:\n      - DEVICE=cuda\n      - MODEL=${MODEL:-DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth}\n    volumes:\n      - ./files/checkpoints:/dust3r/checkpoints\n    cap_add:\n      - IPC_LOCK\n      - SYS_RESOURCE\n    deploy:\n      resources:\n        reservations:\n          devices:\n            - driver: nvidia\n              count: 1\n              capabilities: [gpu]\n"
  },
  {
    "path": "docker/files/cpu.Dockerfile",
    "content": "FROM python:3.11-slim\n\nLABEL description=\"Docker container for DUSt3R with dependencies installed. CPU VERSION\"\n\nENV DEVICE=\"cpu\"\nENV MODEL=\"DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth\"\nARG DEBIAN_FRONTEND=noninteractive\n\nRUN apt-get update && apt-get install -y \\\n    git \\\n    libgl1-mesa-glx \\\n    libegl1-mesa \\\n    libxrandr2 \\\n    libxrandr2 \\\n    libxss1 \\\n    libxcursor1 \\\n    libxcomposite1 \\\n    libasound2 \\\n    libxi6 \\\n    libxtst6 \\\n    libglib2.0-0 \\\n    && apt-get clean \\\n    && rm -rf /var/lib/apt/lists/*\n\nRUN git clone --recursive https://github.com/naver/dust3r /dust3r\nWORKDIR /dust3r\n\nRUN pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu\nRUN pip install -r requirements.txt\nRUN pip install -r requirements_optional.txt\nRUN pip install opencv-python==4.8.0.74\n\nWORKDIR /dust3r\n\nCOPY entrypoint.sh /entrypoint.sh\nRUN chmod +x /entrypoint.sh\n\nENTRYPOINT [\"/entrypoint.sh\"]\n"
  },
  {
    "path": "docker/files/cuda.Dockerfile",
    "content": "FROM nvcr.io/nvidia/pytorch:24.01-py3\n\nLABEL description=\"Docker container for DUSt3R with dependencies installed. CUDA VERSION\"\nENV DEVICE=\"cuda\"\nENV MODEL=\"DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth\"\nARG DEBIAN_FRONTEND=noninteractive\n\nRUN apt-get update && apt-get install -y \\\n    git=1:2.34.1-1ubuntu1.10 \\\n    libglib2.0-0=2.72.4-0ubuntu2.2 \\\n    && apt-get clean \\\n    && rm -rf /var/lib/apt/lists/*\n\nRUN git clone --recursive https://github.com/naver/dust3r /dust3r\nWORKDIR /dust3r\nRUN pip install -r requirements.txt\nRUN pip install -r requirements_optional.txt\nRUN pip install opencv-python==4.8.0.74\n\nWORKDIR /dust3r/croco/models/curope/\nRUN python setup.py build_ext --inplace\n\nWORKDIR /dust3r\nCOPY entrypoint.sh /entrypoint.sh\nRUN chmod +x /entrypoint.sh\n\nENTRYPOINT [\"/entrypoint.sh\"]\n"
  },
  {
    "path": "docker/files/entrypoint.sh",
    "content": "#!/bin/bash\n\nset -eux\n\nDEVICE=${DEVICE:-cuda}\nMODEL=${MODEL:-DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth}\n\nexec python3 demo.py --weights \"checkpoints/$MODEL\" --device \"$DEVICE\" --local_network \"$@\"\n"
  },
  {
    "path": "docker/run.sh",
    "content": "#!/bin/bash\n\nset -eux\n\n# Default model name\nmodel_name=\"DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth\"\n\ncheck_docker() {\n    if ! command -v docker &>/dev/null; then\n        echo \"Docker could not be found. Please install Docker and try again.\"\n        exit 1\n    fi\n}\n\ndownload_model_checkpoint() { \n    if [ -f \"./files/checkpoints/${model_name}\" ]; then\n        echo \"Model checkpoint ${model_name} already exists. Skipping download.\"\n        return\n    fi\n    echo \"Downloading model checkpoint ${model_name}...\"\n    wget \"https://download.europe.naverlabs.com/ComputerVision/DUSt3R/${model_name}\" -P ./files/checkpoints\n}\n\nset_dcomp() {\n    if command -v docker-compose &>/dev/null; then\n        dcomp=\"docker-compose\"\n    elif command -v docker &>/dev/null && docker compose version &>/dev/null; then\n        dcomp=\"docker compose\"\n    else\n        echo \"Docker Compose could not be found. Please install Docker Compose and try again.\"\n        exit 1\n    fi\n}\n\nrun_docker() {\n    export MODEL=${model_name}\n    if [ \"$with_cuda\" -eq 1 ]; then\n        $dcomp -f docker-compose-cuda.yml up --build\n    else\n        $dcomp -f docker-compose-cpu.yml up --build\n    fi\n}\n\nwith_cuda=0\nfor arg in \"$@\"; do\n    case $arg in\n        --with-cuda)\n            with_cuda=1\n            ;;\n        --model_name=*)\n            model_name=\"${arg#*=}.pth\"\n            ;;\n        *)\n            echo \"Unknown parameter passed: $arg\"\n            exit 1\n            ;;\n    esac\ndone\n\n\nmain() {\n    check_docker\n    download_model_checkpoint\n    set_dcomp\n    run_docker\n}\n\nmain\n"
  },
  {
    "path": "dust3r/__init__.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n"
  },
  {
    "path": "dust3r/cloud_opt/__init__.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# global alignment optimization wrapper function\n# --------------------------------------------------------\nfrom enum import Enum\n\nfrom .optimizer import PointCloudOptimizer\nfrom .modular_optimizer import ModularPointCloudOptimizer\nfrom .pair_viewer import PairViewer\n\n\nclass GlobalAlignerMode(Enum):\n    PointCloudOptimizer = \"PointCloudOptimizer\"\n    ModularPointCloudOptimizer = \"ModularPointCloudOptimizer\"\n    PairViewer = \"PairViewer\"\n\n\ndef global_aligner(dust3r_output, device, mode=GlobalAlignerMode.PointCloudOptimizer, **optim_kw):\n    # extract all inputs\n    view1, view2, pred1, pred2 = [dust3r_output[k] for k in 'view1 view2 pred1 pred2'.split()]\n    # build the optimizer\n    if mode == GlobalAlignerMode.PointCloudOptimizer:\n        net = PointCloudOptimizer(view1, view2, pred1, pred2, **optim_kw).to(device)\n    elif mode == GlobalAlignerMode.ModularPointCloudOptimizer:\n        net = ModularPointCloudOptimizer(view1, view2, pred1, pred2, **optim_kw).to(device)\n    elif mode == GlobalAlignerMode.PairViewer:\n        net = PairViewer(view1, view2, pred1, pred2, **optim_kw).to(device)\n    else:\n        raise NotImplementedError(f'Unknown mode {mode}')\n\n    return net\n"
  },
  {
    "path": "dust3r/cloud_opt/base_opt.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# Base class for the global alignement procedure\n# --------------------------------------------------------\nfrom copy import deepcopy\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport roma\nfrom copy import deepcopy\nimport tqdm\n\nfrom dust3r.utils.geometry import inv, geotrf\nfrom dust3r.utils.device import to_numpy\nfrom dust3r.utils.image import rgb\nfrom dust3r.viz import SceneViz, segment_sky, auto_cam_size\nfrom dust3r.optim_factory import adjust_learning_rate_by_lr\n\nfrom dust3r.cloud_opt.commons import (edge_str, ALL_DISTS, NoGradParamDict, get_imshapes, signed_expm1, signed_log1p,\n                                      cosine_schedule, linear_schedule, get_conf_trf)\nimport dust3r.cloud_opt.init_im_poses as init_fun\n\n\nclass BasePCOptimizer (nn.Module):\n    \"\"\" Optimize a global scene, given a list of pairwise observations.\n    Graph node: images\n    Graph edges: observations = (pred1, pred2)\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        if len(args) == 1 and len(kwargs) == 0:\n            other = deepcopy(args[0])\n            attrs = '''edges is_symmetrized dist n_imgs pred_i pred_j imshapes \n                        min_conf_thr conf_thr conf_i conf_j im_conf\n                        base_scale norm_pw_scale POSE_DIM pw_poses \n                        pw_adaptors pw_adaptors has_im_poses rand_pose imgs verbose'''.split()\n            self.__dict__.update({k: other[k] for k in attrs})\n        else:\n            self._init_from_views(*args, **kwargs)\n\n    def _init_from_views(self, view1, view2, pred1, pred2,\n                         dist='l1',\n                         conf='log',\n                         min_conf_thr=3,\n                         base_scale=0.5,\n                         
allow_pw_adaptors=False,\n                         pw_break=20,\n                         rand_pose=torch.randn,\n                         iterationsCount=None,\n                         verbose=True):\n        super().__init__()\n        if not isinstance(view1['idx'], list):\n            view1['idx'] = view1['idx'].tolist()\n        if not isinstance(view2['idx'], list):\n            view2['idx'] = view2['idx'].tolist()\n        self.edges = [(int(i), int(j)) for i, j in zip(view1['idx'], view2['idx'])]\n        self.is_symmetrized = set(self.edges) == {(j, i) for i, j in self.edges}\n        self.dist = ALL_DISTS[dist]\n        self.verbose = verbose\n\n        self.n_imgs = self._check_edges()\n\n        # input data\n        pred1_pts = pred1['pts3d']\n        pred2_pts = pred2['pts3d_in_other_view']\n        self.pred_i = NoGradParamDict({ij: pred1_pts[n] for n, ij in enumerate(self.str_edges)})\n        self.pred_j = NoGradParamDict({ij: pred2_pts[n] for n, ij in enumerate(self.str_edges)})\n        self.imshapes = get_imshapes(self.edges, pred1_pts, pred2_pts)\n\n        # work in log-scale with conf\n        pred1_conf = pred1['conf']\n        pred2_conf = pred2['conf']\n        self.min_conf_thr = min_conf_thr\n        self.conf_trf = get_conf_trf(conf)\n\n        self.conf_i = NoGradParamDict({ij: pred1_conf[n] for n, ij in enumerate(self.str_edges)})\n        self.conf_j = NoGradParamDict({ij: pred2_conf[n] for n, ij in enumerate(self.str_edges)})\n        self.im_conf = self._compute_img_conf(pred1_conf, pred2_conf)\n        for i in range(len(self.im_conf)):\n            self.im_conf[i].requires_grad = False\n\n        # pairwise pose parameters\n        self.base_scale = base_scale\n        self.norm_pw_scale = True\n        self.pw_break = pw_break\n        self.POSE_DIM = 7\n        self.pw_poses = nn.Parameter(rand_pose((self.n_edges, 1+self.POSE_DIM)))  # pairwise poses\n        self.pw_adaptors = nn.Parameter(torch.zeros((self.n_edges, 2)))  # 
slight xy/z adaptation\n        self.pw_adaptors.requires_grad_(allow_pw_adaptors)\n        self.has_im_poses = False\n        self.rand_pose = rand_pose\n\n        # possibly store images for show_pointcloud\n        self.imgs = None\n        if 'img' in view1 and 'img' in view2:\n            imgs = [torch.zeros((3,)+hw) for hw in self.imshapes]\n            for v in range(len(self.edges)):\n                idx = view1['idx'][v]\n                imgs[idx] = view1['img'][v]\n                idx = view2['idx'][v]\n                imgs[idx] = view2['img'][v]\n            self.imgs = rgb(imgs)\n\n    @property\n    def n_edges(self):\n        return len(self.edges)\n\n    @property\n    def str_edges(self):\n        return [edge_str(i, j) for i, j in self.edges]\n\n    @property\n    def imsizes(self):\n        return [(w, h) for h, w in self.imshapes]\n\n    @property\n    def device(self):\n        return next(iter(self.parameters())).device\n\n    def state_dict(self, trainable=True):\n        all_params = super().state_dict()\n        return {k: v for k, v in all_params.items() if k.startswith(('_', 'pred_i.', 'pred_j.', 'conf_i.', 'conf_j.')) != trainable}\n\n    def load_state_dict(self, data):\n        return super().load_state_dict(self.state_dict(trainable=False) | data)\n\n    def _check_edges(self):\n        indices = sorted({i for edge in self.edges for i in edge})\n        assert indices == list(range(len(indices))), 'bad pair indices: missing values '\n        return len(indices)\n\n    @torch.no_grad()\n    def _compute_img_conf(self, pred1_conf, pred2_conf):\n        im_conf = nn.ParameterList([torch.zeros(hw, device=self.device) for hw in self.imshapes])\n        for e, (i, j) in enumerate(self.edges):\n            im_conf[i] = torch.maximum(im_conf[i], pred1_conf[e])\n            im_conf[j] = torch.maximum(im_conf[j], pred2_conf[e])\n        return im_conf\n\n    def get_adaptors(self):\n        adapt = self.pw_adaptors\n        adapt = 
torch.cat((adapt[:, 0:1], adapt), dim=-1)  # (scale_xy, scale_xy, scale_z)\n        if self.norm_pw_scale:  # normalize so that the product == 1\n            adapt = adapt - adapt.mean(dim=1, keepdim=True)\n        return (adapt / self.pw_break).exp()\n\n    def _get_poses(self, poses):\n        # normalize rotation\n        Q = poses[:, :4]\n        T = signed_expm1(poses[:, 4:7])\n        RT = roma.RigidUnitQuat(Q, T).normalize().to_homogeneous()\n        return RT\n\n    def _set_pose(self, poses, idx, R, T=None, scale=None, force=False):\n        # all poses == cam-to-world\n        pose = poses[idx]\n        if not (pose.requires_grad or force):\n            return pose\n\n        if R.shape == (4, 4):\n            assert T is None\n            T = R[:3, 3]\n            R = R[:3, :3]\n\n        if R is not None:\n            pose.data[0:4] = roma.rotmat_to_unitquat(R)\n        if T is not None:\n            pose.data[4:7] = signed_log1p(T / (scale or 1))  # translation is function of scale\n\n        if scale is not None:\n            assert poses.shape[-1] in (8, 13)\n            pose.data[-1] = np.log(float(scale))\n        return pose\n\n    def get_pw_norm_scale_factor(self):\n        if self.norm_pw_scale:\n            # normalize scales so that things cannot go south\n            # we want that exp(scale) ~= self.base_scale\n            return (np.log(self.base_scale) - self.pw_poses[:, -1].mean()).exp()\n        else:\n            return 1  # don't norm scale for known poses\n\n    def get_pw_scale(self):\n        scale = self.pw_poses[:, -1].exp()  # (n_edges,)\n        scale = scale * self.get_pw_norm_scale_factor()\n        return scale\n\n    def get_pw_poses(self):  # cam to world\n        RT = self._get_poses(self.pw_poses)\n        scaled_RT = RT.clone()\n        scaled_RT[:, :3] *= self.get_pw_scale().view(-1, 1, 1)  # scale the rotation AND translation\n        return scaled_RT\n\n    def get_masks(self):\n        return [(conf > 
self.min_conf_thr) for conf in self.im_conf]\n\n    def depth_to_pts3d(self):\n        raise NotImplementedError()\n\n    def get_pts3d(self, raw=False):\n        res = self.depth_to_pts3d()\n        if not raw:\n            res = [dm[:h*w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)]\n        return res\n\n    def _set_focal(self, idx, focal, force=False):\n        raise NotImplementedError()\n\n    def get_focals(self):\n        raise NotImplementedError()\n\n    def get_known_focal_mask(self):\n        raise NotImplementedError()\n\n    def get_principal_points(self):\n        raise NotImplementedError()\n\n    def get_conf(self, mode=None):\n        trf = self.conf_trf if mode is None else get_conf_trf(mode)\n        return [trf(c) for c in self.im_conf]\n\n    def get_im_poses(self):\n        raise NotImplementedError()\n\n    def _set_depthmap(self, idx, depth, force=False):\n        raise NotImplementedError()\n\n    def get_depthmaps(self, raw=False):\n        raise NotImplementedError()\n\n    def clean_pointcloud(self, **kw):\n        cams = inv(self.get_im_poses())\n        K = self.get_intrinsics()\n        depthmaps = self.get_depthmaps()\n        all_pts3d = self.get_pts3d()\n\n        new_im_confs = clean_pointcloud(self.im_conf, K, cams, depthmaps, all_pts3d, **kw)\n\n        for i, new_conf in enumerate(new_im_confs):\n            self.im_conf[i].data[:] = new_conf\n        return self\n\n    def forward(self, ret_details=False):\n        pw_poses = self.get_pw_poses()  # cam-to-world\n        pw_adapt = self.get_adaptors()\n        proj_pts3d = self.get_pts3d()\n        # pre-compute pixel weights\n        weight_i = {i_j: self.conf_trf(c) for i_j, c in self.conf_i.items()}\n        weight_j = {i_j: self.conf_trf(c) for i_j, c in self.conf_j.items()}\n\n        loss = 0\n        if ret_details:\n            details = -torch.ones((self.n_imgs, self.n_imgs))\n\n        for e, (i, j) in enumerate(self.edges):\n            i_j = 
edge_str(i, j)\n            # distance in image i and j\n            aligned_pred_i = geotrf(pw_poses[e], pw_adapt[e] * self.pred_i[i_j])\n            aligned_pred_j = geotrf(pw_poses[e], pw_adapt[e] * self.pred_j[i_j])\n            li = self.dist(proj_pts3d[i], aligned_pred_i, weight=weight_i[i_j]).mean()\n            lj = self.dist(proj_pts3d[j], aligned_pred_j, weight=weight_j[i_j]).mean()\n            loss = loss + li + lj\n\n            if ret_details:\n                details[i, j] = li + lj\n        loss /= self.n_edges  # average over all pairs\n\n        if ret_details:\n            return loss, details\n        return loss\n\n    @torch.cuda.amp.autocast(enabled=False)\n    def compute_global_alignment(self, init=None, niter_PnP=10, **kw):\n        if init is None:\n            pass\n        elif init == 'msp' or init == 'mst':\n            init_fun.init_minimum_spanning_tree(self, niter_PnP=niter_PnP)\n        elif init == 'known_poses':\n            init_fun.init_from_known_poses(self, min_conf_thr=self.min_conf_thr,\n                                           niter_PnP=niter_PnP)\n        else:\n            raise ValueError(f'bad value for {init=}')\n\n        return global_alignment_loop(self, **kw)\n\n    @torch.no_grad()\n    def mask_sky(self):\n        res = deepcopy(self)\n        for i in range(self.n_imgs):\n            sky = segment_sky(self.imgs[i])\n            res.im_conf[i][sky] = 0\n        return res\n\n    def show(self, show_pw_cams=False, show_pw_pts3d=False, cam_size=None, **kw):\n        viz = SceneViz()\n        if self.imgs is None:\n            colors = np.random.randint(0, 256, size=(self.n_imgs, 3))\n            colors = list(map(tuple, colors.tolist()))\n            for n in range(self.n_imgs):\n                viz.add_pointcloud(self.get_pts3d()[n], colors[n], self.get_masks()[n])\n        else:\n            viz.add_pointcloud(self.get_pts3d(), self.imgs, self.get_masks())\n            colors = np.random.randint(256, 
size=(self.n_imgs, 3))\n\n        # camera poses\n        im_poses = to_numpy(self.get_im_poses())\n        if cam_size is None:\n            cam_size = auto_cam_size(im_poses)\n        viz.add_cameras(im_poses, self.get_focals(), colors=colors,\n                        images=self.imgs, imsizes=self.imsizes, cam_size=cam_size)\n        if show_pw_cams:\n            pw_poses = self.get_pw_poses()\n            viz.add_cameras(pw_poses, color=(192, 0, 192), cam_size=cam_size)\n\n            if show_pw_pts3d:\n                pts = [geotrf(pw_poses[e], self.pred_i[edge_str(i, j)]) for e, (i, j) in enumerate(self.edges)]\n                viz.add_pointcloud(pts, (128, 0, 128))\n\n        viz.show(**kw)\n        return viz\n\n\ndef global_alignment_loop(net, lr=0.01, niter=300, schedule='cosine', lr_min=1e-6):\n    params = [p for p in net.parameters() if p.requires_grad]\n    if not params:\n        return net\n\n    verbose = net.verbose\n    if verbose:\n        print('Global alignment - optimizing for:')\n        print([name for name, value in net.named_parameters() if value.requires_grad])\n\n    lr_base = lr\n    optimizer = torch.optim.Adam(params, lr=lr, betas=(0.9, 0.9))\n\n    loss = float('inf')\n    if verbose:\n        with tqdm.tqdm(total=niter) as bar:\n            while bar.n < bar.total:\n                loss, lr = global_alignment_iter(net, bar.n, niter, lr_base, lr_min, optimizer, schedule)\n                bar.set_postfix_str(f'{lr=:g} loss={loss:g}')\n                bar.update()\n    else:\n        for n in range(niter):\n            loss, _ = global_alignment_iter(net, n, niter, lr_base, lr_min, optimizer, schedule)\n    return loss\n\n\ndef global_alignment_iter(net, cur_iter, niter, lr_base, lr_min, optimizer, schedule):\n    t = cur_iter / niter\n    if schedule == 'cosine':\n        lr = cosine_schedule(t, lr_base, lr_min)\n    elif schedule == 'linear':\n        lr = linear_schedule(t, lr_base, lr_min)\n    else:\n        raise 
ValueError(f'bad lr {schedule=}')\n    adjust_learning_rate_by_lr(optimizer, lr)\n    optimizer.zero_grad()\n    loss = net()\n    loss.backward()\n    optimizer.step()\n\n    return float(loss), lr\n\n\n@torch.no_grad()\ndef clean_pointcloud(im_confs, K, cams, depthmaps, all_pts3d,\n                     tol=0.001, bad_conf=0, dbg=()):\n    \"\"\" Method:\n    1) express all 3d points in each camera coordinate frame\n    2) if they're in front of a depthmap --> then lower their confidence\n    \"\"\"\n    assert len(im_confs) == len(cams) == len(K) == len(depthmaps) == len(all_pts3d)\n    assert 0 <= tol < 1\n    res = [c.clone() for c in im_confs]\n\n    # reshape appropriately\n    all_pts3d = [p.view(*c.shape, 3) for p, c in zip(all_pts3d, im_confs)]\n    depthmaps = [d.view(*c.shape) for d, c in zip(depthmaps, im_confs)]\n\n    for i, pts3d in enumerate(all_pts3d):\n        for j in range(len(all_pts3d)):\n            if i == j:\n                continue\n\n            # project 3dpts in other view\n            proj = geotrf(cams[j], pts3d)\n            proj_depth = proj[:, :, 2]\n            u, v = geotrf(K[j], proj, norm=1, ncol=2).round().long().unbind(-1)\n\n            # check which points are actually in the visible cone\n            H, W = im_confs[j].shape\n            msk_i = (proj_depth > 0) & (0 <= u) & (u < W) & (0 <= v) & (v < H)\n            msk_j = v[msk_i], u[msk_i]\n\n            # find bad points = those in front but less confident\n            bad_points = (proj_depth[msk_i] < (1 - tol) * depthmaps[j][msk_j]) & (res[i][msk_i] < res[j][msk_j])\n\n            bad_msk_i = msk_i.clone()\n            bad_msk_i[msk_i] = bad_points\n            res[i][bad_msk_i] = res[i][bad_msk_i].clip_(max=bad_conf)\n\n    return res\n"
  },
  {
    "path": "dust3r/cloud_opt/commons.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# utility functions for global alignment\n# --------------------------------------------------------\nimport torch\nimport torch.nn as nn\nimport numpy as np\n\n\ndef edge_str(i, j):\n    return f'{i}_{j}'\n\n\ndef i_j_ij(ij):\n    return edge_str(*ij), ij\n\n\ndef edge_conf(conf_i, conf_j, edge):\n    return float(conf_i[edge].mean() * conf_j[edge].mean())\n\n\ndef compute_edge_scores(edges, conf_i, conf_j):\n    return {(i, j): edge_conf(conf_i, conf_j, e) for e, (i, j) in edges}\n\n\ndef NoGradParamDict(x):\n    assert isinstance(x, dict)\n    return nn.ParameterDict(x).requires_grad_(False)\n\n\ndef get_imshapes(edges, pred_i, pred_j):\n    n_imgs = max(max(e) for e in edges) + 1\n    imshapes = [None] * n_imgs\n    for e, (i, j) in enumerate(edges):\n        shape_i = tuple(pred_i[e].shape[0:2])\n        shape_j = tuple(pred_j[e].shape[0:2])\n        if imshapes[i]:\n            assert imshapes[i] == shape_i, f'incorrect shape for image {i}'\n        if imshapes[j]:\n            assert imshapes[j] == shape_j, f'incorrect shape for image {j}'\n        imshapes[i] = shape_i\n        imshapes[j] = shape_j\n    return imshapes\n\n\ndef get_conf_trf(mode):\n    if mode == 'log':\n        def conf_trf(x): return x.log()\n    elif mode == 'sqrt':\n        def conf_trf(x): return x.sqrt()\n    elif mode == 'm1':\n        def conf_trf(x): return x-1\n    elif mode in ('id', 'none'):\n        def conf_trf(x): return x\n    else:\n        raise ValueError(f'bad mode for {mode=}')\n    return conf_trf\n\n\ndef l2_dist(a, b, weight):\n    return ((a - b).square().sum(dim=-1) * weight)\n\n\ndef l1_dist(a, b, weight):\n    return ((a - b).norm(dim=-1) * weight)\n\n\nALL_DISTS = dict(l1=l1_dist, l2=l2_dist)\n\n\ndef signed_log1p(x):\n    sign = torch.sign(x)\n    return 
sign * torch.log1p(torch.abs(x))\n\n\ndef signed_expm1(x):\n    sign = torch.sign(x)\n    return sign * torch.expm1(torch.abs(x))\n\n\ndef cosine_schedule(t, lr_start, lr_end):\n    assert 0 <= t <= 1\n    return lr_end + (lr_start - lr_end) * (1+np.cos(t * np.pi))/2\n\n\ndef linear_schedule(t, lr_start, lr_end):\n    assert 0 <= t <= 1\n    return lr_start + (lr_end - lr_start) * t\n"
  },
  {
    "path": "dust3r/cloud_opt/init_im_poses.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# Initialization functions for global alignment\n# --------------------------------------------------------\nfrom functools import cache\n\nimport numpy as np\nimport scipy.sparse as sp\nimport torch\nimport cv2\nimport roma\nfrom tqdm import tqdm\n\nfrom dust3r.utils.geometry import geotrf, inv, get_med_dist_between_poses\nfrom dust3r.post_process import estimate_focal_knowing_depth\nfrom dust3r.viz import to_numpy\n\nfrom dust3r.cloud_opt.commons import edge_str, i_j_ij, compute_edge_scores\n\n\n@torch.no_grad()\ndef init_from_known_poses(self, niter_PnP=10, min_conf_thr=3):\n    device = self.device\n\n    # indices of known poses\n    nkp, known_poses_msk, known_poses = get_known_poses(self)\n    assert nkp == self.n_imgs, 'not all poses are known'\n\n    # get all focals\n    nkf, _, im_focals = get_known_focals(self)\n    assert nkf == self.n_imgs\n    im_pp = self.get_principal_points()\n\n    best_depthmaps = {}\n    # init all pairwise poses\n    for e, (i, j) in enumerate(tqdm(self.edges, disable=not self.verbose)):\n        i_j = edge_str(i, j)\n\n        # find relative pose for this pair\n        P1 = torch.eye(4, device=device)\n        msk = self.conf_i[i_j] > min(min_conf_thr, self.conf_i[i_j].min() - 0.1)\n        _, P2 = fast_pnp(self.pred_j[i_j], float(im_focals[i].mean()),\n                         pp=im_pp[i], msk=msk, device=device, niter_PnP=niter_PnP)\n\n        # align the two predicted camera with the two gt cameras\n        s, R, T = align_multiple_poses(torch.stack((P1, P2)), known_poses[[i, j]])\n        # normally we have known_poses[i] ~= sRT_to_4x4(s,R,T,device) @ P1\n        # and geotrf(sRT_to_4x4(1,R,T,device), s*P2[:3,3])\n        self._set_pose(self.pw_poses, e, R, T, scale=s)\n\n        # remember if this is a good 
depthmap\n        score = float(self.conf_i[i_j].mean())\n        if score > best_depthmaps.get(i, (0,))[0]:\n            best_depthmaps[i] = score, i_j, s\n\n    # init all image poses\n    for n in range(self.n_imgs):\n        assert known_poses_msk[n]\n        _, i_j, scale = best_depthmaps[n]\n        depth = self.pred_i[i_j][:, :, 2]\n        self._set_depthmap(n, depth * scale)\n\n\n@torch.no_grad()\ndef init_minimum_spanning_tree(self, **kw):\n    \"\"\" Init all camera poses (image-wise and pairwise poses) given\n        an initial set of pairwise estimations.\n    \"\"\"\n    device = self.device\n    pts3d, _, im_focals, im_poses = minimum_spanning_tree(self.imshapes, self.edges,\n                                                          self.pred_i, self.pred_j, self.conf_i, self.conf_j, self.im_conf, self.min_conf_thr,\n                                                          device, has_im_poses=self.has_im_poses, verbose=self.verbose,\n                                                          **kw)\n\n    return init_from_pts3d(self, pts3d, im_focals, im_poses)\n\n\ndef init_from_pts3d(self, pts3d, im_focals, im_poses):\n    # init poses\n    nkp, known_poses_msk, known_poses = get_known_poses(self)\n    if nkp == 1:\n        raise NotImplementedError(\"Would be simpler to just align everything afterwards on the single known pose\")\n    elif nkp > 1:\n        # global rigid SE3 alignment\n        s, R, T = align_multiple_poses(im_poses[known_poses_msk], known_poses[known_poses_msk])\n        trf = sRT_to_4x4(s, R, T, device=known_poses.device)\n\n        # rotate everything\n        im_poses = trf @ im_poses\n        im_poses[:, :3, :3] /= s  # undo scaling on the rotation part\n        for img_pts3d in pts3d:\n            img_pts3d[:] = geotrf(trf, img_pts3d)\n\n    # set all pairwise poses\n    for e, (i, j) in enumerate(self.edges):\n        i_j = edge_str(i, j)\n        # compute transform that goes from cam to world\n        s, R, T = 
rigid_points_registration(self.pred_i[i_j], pts3d[i], conf=self.conf_i[i_j])\n        self._set_pose(self.pw_poses, e, R, T, scale=s)\n\n    # take into account the scale normalization\n    s_factor = self.get_pw_norm_scale_factor()\n    im_poses[:, :3, 3] *= s_factor  # apply downscaling factor\n    for img_pts3d in pts3d:\n        img_pts3d *= s_factor\n\n    # init all image poses\n    if self.has_im_poses:\n        for i in range(self.n_imgs):\n            cam2world = im_poses[i]\n            depth = geotrf(inv(cam2world), pts3d[i])[..., 2]\n            self._set_depthmap(i, depth)\n            self._set_pose(self.im_poses, i, cam2world)\n            if im_focals[i] is not None:\n                self._set_focal(i, im_focals[i])\n\n    if self.verbose:\n        print(' init loss =', float(self()))\n\n\ndef minimum_spanning_tree(imshapes, edges, pred_i, pred_j, conf_i, conf_j, im_conf, min_conf_thr,\n                          device, has_im_poses=True, niter_PnP=10, verbose=True):\n    n_imgs = len(imshapes)\n    sparse_graph = -dict_to_sparse_graph(compute_edge_scores(map(i_j_ij, edges), conf_i, conf_j))\n    msp = sp.csgraph.minimum_spanning_tree(sparse_graph).tocoo()\n\n    # temp variable to store 3d points\n    pts3d = [None] * len(imshapes)\n\n    todo = sorted(zip(-msp.data, msp.row, msp.col))  # sorted edges\n    im_poses = [None] * n_imgs\n    im_focals = [None] * n_imgs\n\n    # init with strongest edge\n    score, i, j = todo.pop()\n    if verbose:\n        print(f' init edge ({i}*,{j}*) {score=}')\n    i_j = edge_str(i, j)\n    pts3d[i] = pred_i[i_j].clone()\n    pts3d[j] = pred_j[i_j].clone()\n    done = {i, j}\n    if has_im_poses:\n        im_poses[i] = torch.eye(4, device=device)\n        im_focals[i] = estimate_focal(pred_i[i_j])\n\n    # set initial pointcloud based on pairwise graph\n    msp_edges = [(i, j)]\n    while todo:\n        # each time, predict the next one\n        score, i, j = todo.pop()\n\n        if im_focals[i] is None:\n        
    im_focals[i] = estimate_focal(pred_i[edge_str(i, j)])\n\n        if i in done:\n            if verbose:\n                print(f' init edge ({i},{j}*) {score=}')\n            assert j not in done\n            # align pred[i] with pts3d[i], and then set j accordingly\n            i_j = edge_str(i, j)\n            s, R, T = rigid_points_registration(pred_i[i_j], pts3d[i], conf=conf_i[i_j])\n            trf = sRT_to_4x4(s, R, T, device)\n            pts3d[j] = geotrf(trf, pred_j[i_j])\n            done.add(j)\n            msp_edges.append((i, j))\n\n            if has_im_poses and im_poses[i] is None:\n                im_poses[i] = sRT_to_4x4(1, R, T, device)\n\n        elif j in done:\n            if verbose:\n                print(f' init edge ({i}*,{j}) {score=}')\n            assert i not in done\n            i_j = edge_str(i, j)\n            s, R, T = rigid_points_registration(pred_j[i_j], pts3d[j], conf=conf_j[i_j])\n            trf = sRT_to_4x4(s, R, T, device)\n            pts3d[i] = geotrf(trf, pred_i[i_j])\n            done.add(i)\n            msp_edges.append((i, j))\n\n            if has_im_poses and im_poses[i] is None:\n                im_poses[i] = sRT_to_4x4(1, R, T, device)\n        else:\n            # let's try again later\n            todo.insert(0, (score, i, j))\n\n    if has_im_poses:\n        # complete all missing information\n        pair_scores = list(sparse_graph.values())  # already negative scores: less is best\n        edges_from_best_to_worse = np.array(list(sparse_graph.keys()))[np.argsort(pair_scores)]\n        for i, j in edges_from_best_to_worse.tolist():\n            if im_focals[i] is None:\n                im_focals[i] = estimate_focal(pred_i[edge_str(i, j)])\n\n        for i in range(n_imgs):\n            if im_poses[i] is None:\n                msk = im_conf[i] > min_conf_thr\n                res = fast_pnp(pts3d[i], im_focals[i], msk=msk, device=device, niter_PnP=niter_PnP)\n                if res:\n                    im_focals[i], 
im_poses[i] = res\n            if im_poses[i] is None:\n                im_poses[i] = torch.eye(4, device=device)\n        im_poses = torch.stack(im_poses)\n    else:\n        im_poses = im_focals = None\n\n    return pts3d, msp_edges, im_focals, im_poses\n\n\ndef dict_to_sparse_graph(dic):\n    n_imgs = max(max(e) for e in dic) + 1\n    res = sp.dok_array((n_imgs, n_imgs))\n    for edge, value in dic.items():\n        res[edge] = value\n    return res\n\n\ndef rigid_points_registration(pts1, pts2, conf):\n    R, T, s = roma.rigid_points_registration(\n        pts1.reshape(-1, 3), pts2.reshape(-1, 3), weights=conf.ravel(), compute_scaling=True)\n    return s, R, T  # return un-scaled (R, T)\n\n\ndef sRT_to_4x4(scale, R, T, device):\n    trf = torch.eye(4, device=device)\n    trf[:3, :3] = R * scale\n    trf[:3, 3] = T.ravel()  # doesn't need scaling\n    return trf\n\n\ndef estimate_focal(pts3d_i, pp=None):\n    if pp is None:\n        H, W, THREE = pts3d_i.shape\n        assert THREE == 3\n        pp = torch.tensor((W/2, H/2), device=pts3d_i.device)\n    focal = estimate_focal_knowing_depth(pts3d_i.unsqueeze(0), pp.unsqueeze(0), focal_mode='weiszfeld').ravel()\n    return float(focal)\n\n\n@cache\ndef pixel_grid(H, W):\n    return np.mgrid[:W, :H].T.astype(np.float32)\n\n\ndef fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10):\n    # extract camera poses and focals with RANSAC-PnP\n    if msk.sum() < 4:\n        return None  # we need at least 4 points for PnP\n    pts3d, msk = map(to_numpy, (pts3d, msk))\n\n    H, W, THREE = pts3d.shape\n    assert THREE == 3\n    pixels = pixel_grid(H, W)\n\n    if focal is None:\n        S = max(W, H)\n        tentative_focals = np.geomspace(S/2, S*3, 21)\n    else:\n        tentative_focals = [focal]\n\n    if pp is None:\n        pp = (W/2, H/2)\n    else:\n        pp = to_numpy(pp)\n\n    best = 0,\n    for focal in tentative_focals:\n        K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)])\n\n      
  success, R, T, inliers = cv2.solvePnPRansac(pts3d[msk], pixels[msk], K, None,\n                                                    iterationsCount=niter_PnP, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP)\n        if not success:\n            continue\n\n        score = len(inliers)\n        if success and score > best[0]:\n            best = score, R, T, focal\n\n    if not best[0]:\n        return None\n\n    _, R, T, best_focal = best\n    R = cv2.Rodrigues(R)[0]  # world to cam\n    R, T = map(torch.from_numpy, (R, T))\n    return best_focal, inv(sRT_to_4x4(1, R, T, device))  # cam to world\n\n\ndef get_known_poses(self):\n    if self.has_im_poses:\n        known_poses_msk = torch.tensor([not (p.requires_grad) for p in self.im_poses])\n        known_poses = self.get_im_poses()\n        return known_poses_msk.sum(), known_poses_msk, known_poses\n    else:\n        return 0, None, None\n\n\ndef get_known_focals(self):\n    if self.has_im_poses:\n        known_focal_msk = self.get_known_focal_mask()\n        known_focals = self.get_focals()\n        return known_focal_msk.sum(), known_focal_msk, known_focals\n    else:\n        return 0, None, None\n\n\ndef align_multiple_poses(src_poses, target_poses):\n    N = len(src_poses)\n    assert src_poses.shape == target_poses.shape == (N, 4, 4)\n\n    def center_and_z(poses):\n        eps = get_med_dist_between_poses(poses) / 100\n        return torch.cat((poses[:, :3, 3], poses[:, :3, 3] + eps*poses[:, :3, 2]))\n    R, T, s = roma.rigid_points_registration(center_and_z(src_poses), center_and_z(target_poses), compute_scaling=True)\n    return s, R, T\n"
  },
  {
    "path": "dust3r/cloud_opt/modular_optimizer.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# Slower implementation of the global alignment that allows to freeze partial poses/intrinsics\n# --------------------------------------------------------\nimport numpy as np\nimport torch\nimport torch.nn as nn\n\nfrom dust3r.cloud_opt.base_opt import BasePCOptimizer\nfrom dust3r.utils.geometry import geotrf\nfrom dust3r.utils.device import to_cpu, to_numpy\nfrom dust3r.utils.geometry import depthmap_to_pts3d\n\n\nclass ModularPointCloudOptimizer (BasePCOptimizer):\n    \"\"\" Optimize a global scene, given a list of pairwise observations.\n    Unlike PointCloudOptimizer, you can fix parts of the optimization process (partial poses/intrinsics)\n    Graph node: images\n    Graph edges: observations = (pred1, pred2)\n    \"\"\"\n\n    def __init__(self, *args, optimize_pp=False, fx_and_fy=False, focal_brake=20, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.has_im_poses = True  # by definition of this class\n        self.focal_brake = focal_brake\n\n        # adding thing to optimize\n        self.im_depthmaps = nn.ParameterList(torch.randn(H, W)/10-3 for H, W in self.imshapes)  # log(depth)\n        self.im_poses = nn.ParameterList(self.rand_pose(self.POSE_DIM) for _ in range(self.n_imgs))  # camera poses\n        default_focals = [self.focal_brake * np.log(max(H, W)) for H, W in self.imshapes]\n        self.im_focals = nn.ParameterList(torch.FloatTensor([f, f] if fx_and_fy else [\n                                          f]) for f in default_focals)  # camera intrinsics\n        self.im_pp = nn.ParameterList(torch.zeros((2,)) for _ in range(self.n_imgs))  # camera intrinsics\n        self.im_pp.requires_grad_(optimize_pp)\n\n    def preset_pose(self, known_poses, pose_msk=None):  # cam-to-world\n        if isinstance(known_poses, 
torch.Tensor) and known_poses.ndim == 2:\n            known_poses = [known_poses]\n        for idx, pose in zip(self._get_msk_indices(pose_msk), known_poses):\n            if self.verbose:\n                print(f' (setting pose #{idx} = {pose[:3,3]})')\n            self._no_grad(self._set_pose(self.im_poses, idx, torch.tensor(pose), force=True))\n\n        # normalize scale if there's less than 1 known pose\n        n_known_poses = sum((p.requires_grad is False) for p in self.im_poses)\n        self.norm_pw_scale = (n_known_poses <= 1)\n\n    def preset_intrinsics(self, known_intrinsics, msk=None):\n        if isinstance(known_intrinsics, torch.Tensor) and known_intrinsics.ndim == 2:\n            known_intrinsics = [known_intrinsics]\n        for K in known_intrinsics:\n            assert K.shape == (3, 3)\n        self.preset_focal([K.diagonal()[:2].mean() for K in known_intrinsics], msk)\n        self.preset_principal_point([K[:2, 2] for K in known_intrinsics], msk)\n\n    def preset_focal(self, known_focals, msk=None):\n        for idx, focal in zip(self._get_msk_indices(msk), known_focals):\n            if self.verbose:\n                print(f' (setting focal #{idx} = {focal})')\n            self._no_grad(self._set_focal(idx, focal, force=True))\n\n    def preset_principal_point(self, known_pp, msk=None):\n        for idx, pp in zip(self._get_msk_indices(msk), known_pp):\n            if self.verbose:\n                print(f' (setting principal point #{idx} = {pp})')\n            self._no_grad(self._set_principal_point(idx, pp, force=True))\n\n    def _no_grad(self, tensor):\n        return tensor.requires_grad_(False)\n\n    def _get_msk_indices(self, msk):\n        if msk is None:\n            return range(self.n_imgs)\n        elif isinstance(msk, int):\n            return [msk]\n        elif isinstance(msk, (tuple, list)):\n            return self._get_msk_indices(np.array(msk))\n        elif msk.dtype in (bool, torch.bool, np.bool_):\n            assert 
len(msk) == self.n_imgs\n            return np.where(msk)[0]\n        elif np.issubdtype(msk.dtype, np.integer):\n            return msk\n        else:\n            raise ValueError(f'bad {msk=}')\n\n    def _set_focal(self, idx, focal, force=False):\n        param = self.im_focals[idx]\n        if param.requires_grad or force:  # can only init a parameter not already initialized\n            param.data[:] = self.focal_brake * np.log(focal)\n        return param\n\n    def get_focals(self):\n        log_focals = torch.stack(list(self.im_focals), dim=0)\n        return (log_focals / self.focal_brake).exp()\n\n    def _set_principal_point(self, idx, pp, force=False):\n        param = self.im_pp[idx]\n        H, W = self.imshapes[idx]\n        if param.requires_grad or force:  # can only init a parameter not already initialized\n            param.data[:] = to_cpu(to_numpy(pp) - (W/2, H/2)) / 10\n        return param\n\n    def get_principal_points(self):\n        return torch.stack([pp.new((W/2, H/2))+10*pp for pp, (H, W) in zip(self.im_pp, self.imshapes)])\n\n    def get_intrinsics(self):\n        K = torch.zeros((self.n_imgs, 3, 3), device=self.device)\n        focals = self.get_focals().view(self.n_imgs, -1)\n        K[:, 0, 0] = focals[:, 0]\n        K[:, 1, 1] = focals[:, -1]\n        K[:, :2, 2] = self.get_principal_points()\n        K[:, 2, 2] = 1\n        return K\n\n    def get_im_poses(self):  # cam to world\n        cam2world = self._get_poses(torch.stack(list(self.im_poses)))\n        return cam2world\n\n    def _set_depthmap(self, idx, depth, force=False):\n        param = self.im_depthmaps[idx]\n        if param.requires_grad or force:  # can only init a parameter not already initialized\n            param.data[:] = depth.log().nan_to_num(neginf=0)\n        return param\n\n    def get_depthmaps(self):\n        return [d.exp() for d in self.im_depthmaps]\n\n    def depth_to_pts3d(self):\n        # Get depths and  projection params if not provided\n        
focals = self.get_focals()\n        pp = self.get_principal_points()\n        im_poses = self.get_im_poses()\n        depth = self.get_depthmaps()\n\n        # convert focal to (1,2,H,W) constant field\n        def focal_ex(i): return focals[i][..., None, None].expand(1, *focals[i].shape, *self.imshapes[i])\n        # get pointmaps in camera frame\n        rel_ptmaps = [depthmap_to_pts3d(depth[i][None], focal_ex(i), pp=pp[i:i+1])[0] for i in range(im_poses.shape[0])]\n        # project to world frame\n        return [geotrf(pose, ptmap) for pose, ptmap in zip(im_poses, rel_ptmaps)]\n\n    def get_pts3d(self):\n        return self.depth_to_pts3d()\n"
  },
  {
    "path": "dust3r/cloud_opt/optimizer.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# Main class for the implementation of the global alignment\n# --------------------------------------------------------\nimport numpy as np\nimport torch\nimport torch.nn as nn\n\nfrom dust3r.cloud_opt.base_opt import BasePCOptimizer\nfrom dust3r.utils.geometry import xy_grid, geotrf\nfrom dust3r.utils.device import to_cpu, to_numpy\n\n\nclass PointCloudOptimizer(BasePCOptimizer):\n    \"\"\" Optimize a global scene, given a list of pairwise observations.\n    Graph node: images\n    Graph edges: observations = (pred1, pred2)\n    \"\"\"\n\n    def __init__(self, *args, optimize_pp=False, focal_break=20, **kwargs):\n        super().__init__(*args, **kwargs)\n\n        self.has_im_poses = True  # by definition of this class\n        self.focal_break = focal_break\n\n        # adding thing to optimize\n        self.im_depthmaps = nn.ParameterList(torch.randn(H, W)/10-3 for H, W in self.imshapes)  # log(depth)\n        self.im_poses = nn.ParameterList(self.rand_pose(self.POSE_DIM) for _ in range(self.n_imgs))  # camera poses\n        self.im_focals = nn.ParameterList(torch.FloatTensor(\n            [self.focal_break*np.log(max(H, W))]) for H, W in self.imshapes)  # camera intrinsics\n        self.im_pp = nn.ParameterList(torch.zeros((2,)) for _ in range(self.n_imgs))  # camera intrinsics\n        self.im_pp.requires_grad_(optimize_pp)\n\n        self.imshape = self.imshapes[0]\n        im_areas = [h*w for h, w in self.imshapes]\n        self.max_area = max(im_areas)\n\n        # adding thing to optimize\n        self.im_depthmaps = ParameterStack(self.im_depthmaps, is_param=True, fill=self.max_area)\n        self.im_poses = ParameterStack(self.im_poses, is_param=True)\n        self.im_focals = ParameterStack(self.im_focals, is_param=True)\n        self.im_pp = 
ParameterStack(self.im_pp, is_param=True)\n        self.register_buffer('_pp', torch.tensor([(w/2, h/2) for h, w in self.imshapes]))\n        self.register_buffer('_grid', ParameterStack(\n            [xy_grid(W, H, device=self.device) for H, W in self.imshapes], fill=self.max_area))\n\n        # pre-compute pixel weights\n        self.register_buffer('_weight_i', ParameterStack(\n            [self.conf_trf(self.conf_i[i_j]) for i_j in self.str_edges], fill=self.max_area))\n        self.register_buffer('_weight_j', ParameterStack(\n            [self.conf_trf(self.conf_j[i_j]) for i_j in self.str_edges], fill=self.max_area))\n\n        # precompute aa\n        self.register_buffer('_stacked_pred_i', ParameterStack(self.pred_i, self.str_edges, fill=self.max_area))\n        self.register_buffer('_stacked_pred_j', ParameterStack(self.pred_j, self.str_edges, fill=self.max_area))\n        self.register_buffer('_ei', torch.tensor([i for i, j in self.edges]))\n        self.register_buffer('_ej', torch.tensor([j for i, j in self.edges]))\n        self.total_area_i = sum([im_areas[i] for i, j in self.edges])\n        self.total_area_j = sum([im_areas[j] for i, j in self.edges])\n\n    def _check_all_imgs_are_selected(self, msk):\n        assert np.all(self._get_msk_indices(msk) == np.arange(self.n_imgs)), 'incomplete mask!'\n\n    def preset_pose(self, known_poses, pose_msk=None):  # cam-to-world\n        self._check_all_imgs_are_selected(pose_msk)\n\n        if isinstance(known_poses, torch.Tensor) and known_poses.ndim == 2:\n            known_poses = [known_poses]\n        for idx, pose in zip(self._get_msk_indices(pose_msk), known_poses):\n            if self.verbose:\n                print(f' (setting pose #{idx} = {pose[:3,3]})')\n            self._no_grad(self._set_pose(self.im_poses, idx, torch.tensor(pose)))\n\n        # normalize scale if there's less than 1 known pose\n        n_known_poses = sum((p.requires_grad is False) for p in self.im_poses)\n        
self.norm_pw_scale = (n_known_poses <= 1)\n\n        self.im_poses.requires_grad_(False)\n        self.norm_pw_scale = False\n\n    def preset_focal(self, known_focals, msk=None):\n        self._check_all_imgs_are_selected(msk)\n\n        for idx, focal in zip(self._get_msk_indices(msk), known_focals):\n            if self.verbose:\n                print(f' (setting focal #{idx} = {focal})')\n            self._no_grad(self._set_focal(idx, focal))\n\n        self.im_focals.requires_grad_(False)\n\n    def preset_principal_point(self, known_pp, msk=None):\n        self._check_all_imgs_are_selected(msk)\n\n        for idx, pp in zip(self._get_msk_indices(msk), known_pp):\n            if self.verbose:\n                print(f' (setting principal point #{idx} = {pp})')\n            self._no_grad(self._set_principal_point(idx, pp))\n\n        self.im_pp.requires_grad_(False)\n\n    def _get_msk_indices(self, msk):\n        if msk is None:\n            return range(self.n_imgs)\n        elif isinstance(msk, int):\n            return [msk]\n        elif isinstance(msk, (tuple, list)):\n            return self._get_msk_indices(np.array(msk))\n        elif msk.dtype in (bool, torch.bool, np.bool_):\n            assert len(msk) == self.n_imgs\n            return np.where(msk)[0]\n        elif np.issubdtype(msk.dtype, np.integer):\n            return msk\n        else:\n            raise ValueError(f'bad {msk=}')\n\n    def _no_grad(self, tensor):\n        assert tensor.requires_grad, 'it must be True at this point, otherwise no modification occurs'\n\n    def _set_focal(self, idx, focal, force=False):\n        param = self.im_focals[idx]\n        if param.requires_grad or force:  # can only init a parameter not already initialized\n            param.data[:] = self.focal_break * np.log(focal)\n        return param\n\n    def get_focals(self):\n        log_focals = torch.stack(list(self.im_focals), dim=0)\n        return (log_focals / self.focal_break).exp()\n\n    def 
get_known_focal_mask(self):\n        return torch.tensor([not (p.requires_grad) for p in self.im_focals])\n\n    def _set_principal_point(self, idx, pp, force=False):\n        param = self.im_pp[idx]\n        H, W = self.imshapes[idx]\n        if param.requires_grad or force:  # can only init a parameter not already initialized\n            param.data[:] = to_cpu(to_numpy(pp) - (W/2, H/2)) / 10\n        return param\n\n    def get_principal_points(self):\n        return self._pp + 10 * self.im_pp\n\n    def get_intrinsics(self):\n        K = torch.zeros((self.n_imgs, 3, 3), device=self.device)\n        focals = self.get_focals().flatten()\n        K[:, 0, 0] = K[:, 1, 1] = focals\n        K[:, :2, 2] = self.get_principal_points()\n        K[:, 2, 2] = 1\n        return K\n\n    def get_im_poses(self):  # cam to world\n        cam2world = self._get_poses(self.im_poses)\n        return cam2world\n\n    def _set_depthmap(self, idx, depth, force=False):\n        depth = _ravel_hw(depth, self.max_area)\n\n        param = self.im_depthmaps[idx]\n        if param.requires_grad or force:  # can only init a parameter not already initialized\n            param.data[:] = depth.log().nan_to_num(neginf=0)\n        return param\n\n    def get_depthmaps(self, raw=False):\n        res = self.im_depthmaps.exp()\n        if not raw:\n            res = [dm[:h*w].view(h, w) for dm, (h, w) in zip(res, self.imshapes)]\n        return res\n\n    def depth_to_pts3d(self):\n        # Get depths and  projection params if not provided\n        focals = self.get_focals()\n        pp = self.get_principal_points()\n        im_poses = self.get_im_poses()\n        depth = self.get_depthmaps(raw=True)\n\n        # get pointmaps in camera frame\n        rel_ptmaps = _fast_depthmap_to_pts3d(depth, self._grid, focals, pp=pp)\n        # project to world frame\n        return geotrf(im_poses, rel_ptmaps)\n\n    def get_pts3d(self, raw=False):\n        res = self.depth_to_pts3d()\n        if not raw:\n  
          res = [dm[:h*w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)]\n        return res\n\n    def forward(self):\n        pw_poses = self.get_pw_poses()  # cam-to-world\n        pw_adapt = self.get_adaptors().unsqueeze(1)\n        proj_pts3d = self.get_pts3d(raw=True)\n\n        # rotate pairwise prediction according to pw_poses\n        aligned_pred_i = geotrf(pw_poses, pw_adapt * self._stacked_pred_i)\n        aligned_pred_j = geotrf(pw_poses, pw_adapt * self._stacked_pred_j)\n\n        # compute the loss\n        li = self.dist(proj_pts3d[self._ei], aligned_pred_i, weight=self._weight_i).sum() / self.total_area_i\n        lj = self.dist(proj_pts3d[self._ej], aligned_pred_j, weight=self._weight_j).sum() / self.total_area_j\n\n        return li + lj\n\n\ndef _fast_depthmap_to_pts3d(depth, pixel_grid, focal, pp):\n    pp = pp.unsqueeze(1)\n    focal = focal.unsqueeze(1)\n    assert focal.shape == (len(depth), 1, 1)\n    assert pp.shape == (len(depth), 1, 2)\n    assert pixel_grid.shape == depth.shape + (2,)\n    depth = depth.unsqueeze(-1)\n    return torch.cat((depth * (pixel_grid - pp) / focal, depth), dim=-1)\n\n\ndef ParameterStack(params, keys=None, is_param=None, fill=0):\n    if keys is not None:\n        params = [params[k] for k in keys]\n\n    if fill > 0:\n        params = [_ravel_hw(p, fill) for p in params]\n\n    requires_grad = params[0].requires_grad\n    assert all(p.requires_grad == requires_grad for p in params)\n\n    params = torch.stack(list(params)).float().detach()\n    if is_param or requires_grad:\n        params = nn.Parameter(params)\n        params.requires_grad_(requires_grad)\n    return params\n\n\ndef _ravel_hw(tensor, fill=0):\n    # ravel H,W\n    tensor = tensor.view((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])\n\n    if len(tensor) < fill:\n        tensor = torch.cat((tensor, tensor.new_zeros((fill - len(tensor),)+tensor.shape[1:])))\n    return tensor\n\n\ndef acceptable_focal_range(H, W, minf=0.5, 
maxf=3.5):\n    focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2))  # size / 1.1547005383792515\n    return minf*focal_base, maxf*focal_base\n\n\ndef apply_mask(img, msk):\n    img = img.copy()\n    img[msk] = 0\n    return img\n"
  },
  {
    "path": "dust3r/cloud_opt/pair_viewer.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# Dummy optimizer for visualizing pairs\n# --------------------------------------------------------\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport cv2\n\nfrom dust3r.cloud_opt.base_opt import BasePCOptimizer\nfrom dust3r.utils.geometry import inv, geotrf, depthmap_to_absolute_camera_coordinates\nfrom dust3r.cloud_opt.commons import edge_str\nfrom dust3r.post_process import estimate_focal_knowing_depth\n\n\nclass PairViewer (BasePCOptimizer):\n    \"\"\"\n    This is a dummy optimizer.\n    Use it only when the goal is to visualize the results for a pair of images (with is_symmetrized)\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        assert self.is_symmetrized and self.n_edges == 2\n        self.has_im_poses = True\n\n        # compute all parameters directly from raw input\n        self.focals = []\n        self.pp = []\n        rel_poses = []\n        confs = []\n        for i in range(self.n_imgs):\n            conf = float(self.conf_i[edge_str(i, 1-i)].mean() * self.conf_j[edge_str(i, 1-i)].mean())\n            if self.verbose:\n                print(f'  - {conf=:.3} for edge {i}-{1-i}')\n            confs.append(conf)\n\n            H, W = self.imshapes[i]\n            pts3d = self.pred_i[edge_str(i, 1-i)]\n            pp = torch.tensor((W/2, H/2))\n            focal = float(estimate_focal_knowing_depth(pts3d[None], pp, focal_mode='weiszfeld'))\n            self.focals.append(focal)\n            self.pp.append(pp)\n\n            # estimate the pose of pts1 in image 2\n            pixels = np.mgrid[:W, :H].T.astype(np.float32)\n            pts3d = self.pred_j[edge_str(1-i, i)].numpy()\n            assert pts3d.shape[:2] == (H, W)\n            msk = self.get_masks()[i].numpy()\n       
     K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)])\n\n            try:\n                res = cv2.solvePnPRansac(pts3d[msk], pixels[msk], K, None,\n                                         iterationsCount=100, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP)\n                success, R, T, inliers = res\n                assert success\n\n                R = cv2.Rodrigues(R)[0]  # world to cam\n                pose = inv(np.r_[np.c_[R, T], [(0, 0, 0, 1)]])  # cam to world\n            except:\n                pose = np.eye(4)\n            rel_poses.append(torch.from_numpy(pose.astype(np.float32)))\n\n        # let's use the pair with the most confidence\n        if confs[0] > confs[1]:\n            # ptcloud is expressed in camera1\n            self.im_poses = [torch.eye(4), rel_poses[1]]  # I, cam2-to-cam1\n            self.depth = [self.pred_i['0_1'][..., 2], geotrf(inv(rel_poses[1]), self.pred_j['0_1'])[..., 2]]\n        else:\n            # ptcloud is expressed in camera2\n            self.im_poses = [rel_poses[0], torch.eye(4)]  # I, cam1-to-cam2\n            self.depth = [geotrf(inv(rel_poses[0]), self.pred_j['1_0'])[..., 2], self.pred_i['1_0'][..., 2]]\n\n        self.im_poses = nn.Parameter(torch.stack(self.im_poses, dim=0), requires_grad=False)\n        self.focals = nn.Parameter(torch.tensor(self.focals), requires_grad=False)\n        self.pp = nn.Parameter(torch.stack(self.pp, dim=0), requires_grad=False)\n        self.depth = nn.ParameterList(self.depth)\n        for p in self.parameters():\n            p.requires_grad = False\n\n    def _set_depthmap(self, idx, depth, force=False):\n        if self.verbose:\n            print('_set_depthmap is ignored in PairViewer')\n        return\n\n    def get_depthmaps(self, raw=False):\n        depth = [d.to(self.device) for d in self.depth]\n        return depth\n\n    def _set_focal(self, idx, focal, force=False):\n        self.focals[idx] = focal\n\n    def get_focals(self):\n        return 
self.focals\n\n    def get_known_focal_mask(self):\n        return torch.tensor([not (p.requires_grad) for p in self.focals])\n\n    def get_principal_points(self):\n        return self.pp\n\n    def get_intrinsics(self):\n        focals = self.get_focals()\n        pps = self.get_principal_points()\n        K = torch.zeros((len(focals), 3, 3), device=self.device)\n        for i in range(len(focals)):\n            K[i, 0, 0] = K[i, 1, 1] = focals[i]\n            K[i, :2, 2] = pps[i]\n            K[i, 2, 2] = 1\n        return K\n\n    def get_im_poses(self):\n        return self.im_poses\n\n    def depth_to_pts3d(self):\n        pts3d = []\n        for d, intrinsics, im_pose in zip(self.depth, self.get_intrinsics(), self.get_im_poses()):\n            pts, _ = depthmap_to_absolute_camera_coordinates(d.cpu().numpy(),\n                                                             intrinsics.cpu().numpy(),\n                                                             im_pose.cpu().numpy())\n            pts3d.append(torch.from_numpy(pts).to(device=self.device))\n        return pts3d\n\n    def forward(self):\n        return float('nan')\n"
  },
  {
    "path": "dust3r/datasets/__init__.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\nfrom .utils.transforms import *\nfrom .base.batched_sampler import BatchedRandomSampler  # noqa\nfrom .arkitscenes import ARKitScenes  # noqa\nfrom .blendedmvs import BlendedMVS  # noqa\nfrom .co3d import Co3d  # noqa\nfrom .habitat import Habitat  # noqa\nfrom .megadepth import MegaDepth  # noqa\nfrom .scannetpp import ScanNetpp  # noqa\nfrom .staticthings3d import StaticThings3D  # noqa\nfrom .waymo import Waymo  # noqa\nfrom .wildrgbd import WildRGBD  # noqa\n\n\ndef get_data_loader(dataset, batch_size, num_workers=8, shuffle=True, drop_last=True, pin_mem=True):\n    import torch\n    from croco.utils.misc import get_world_size, get_rank\n\n    # pytorch dataset\n    if isinstance(dataset, str):\n        dataset = eval(dataset)\n\n    world_size = get_world_size()\n    rank = get_rank()\n\n    try:\n        sampler = dataset.make_sampler(batch_size, shuffle=shuffle, world_size=world_size,\n                                       rank=rank, drop_last=drop_last)\n    except (AttributeError, NotImplementedError):\n        # not avail for this dataset\n        if torch.distributed.is_initialized():\n            sampler = torch.utils.data.DistributedSampler(\n                dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, drop_last=drop_last\n            )\n        elif shuffle:\n            sampler = torch.utils.data.RandomSampler(dataset)\n        else:\n            sampler = torch.utils.data.SequentialSampler(dataset)\n\n    data_loader = torch.utils.data.DataLoader(\n        dataset,\n        sampler=sampler,\n        batch_size=batch_size,\n        num_workers=num_workers,\n        pin_memory=pin_mem,\n        drop_last=drop_last,\n    )\n\n    return data_loader\n"
  },
  {
    "path": "dust3r/datasets/arkitscenes.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# Dataloader for preprocessed arkitscenes\n# dataset at https://github.com/apple/ARKitScenes - Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License https://github.com/apple/ARKitScenes/tree/main?tab=readme-ov-file#license\n# See datasets_preprocess/preprocess_arkitscenes.py\n# --------------------------------------------------------\nimport os.path as osp\nimport cv2\nimport numpy as np\n\nfrom dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset\nfrom dust3r.utils.image import imread_cv2\n\n\nclass ARKitScenes(BaseStereoViewDataset):\n    def __init__(self, *args, split, ROOT, **kwargs):\n        self.ROOT = ROOT\n        super().__init__(*args, **kwargs)\n        if split == \"train\":\n            self.split = \"Training\"\n        elif split == \"test\":\n            self.split = \"Test\"\n        else:\n            raise ValueError(\"\")\n\n        self.loaded_data = self._load_data(self.split)\n\n    def _load_data(self, split):\n        with np.load(osp.join(self.ROOT, split, 'all_metadata.npz')) as data:\n            self.scenes = data['scenes']\n            self.sceneids = data['sceneids']\n            self.images = data['images']\n            self.intrinsics = data['intrinsics'].astype(np.float32)\n            self.trajectories = data['trajectories'].astype(np.float32)\n            self.pairs = data['pairs'][:, :2].astype(int)\n\n    def __len__(self):\n        return len(self.pairs)\n\n    def _get_views(self, idx, resolution, rng):\n\n        image_idx1, image_idx2 = self.pairs[idx]\n\n        views = []\n        for view_idx in [image_idx1, image_idx2]:\n            scene_id = self.sceneids[view_idx]\n            scene_dir = osp.join(self.ROOT, self.split, self.scenes[scene_id])\n\n          
  intrinsics = self.intrinsics[view_idx]\n            camera_pose = self.trajectories[view_idx]\n            basename = self.images[view_idx]\n\n            # Load RGB image\n            rgb_image = imread_cv2(osp.join(scene_dir, 'vga_wide', basename.replace('.png', '.jpg')))\n            # Load depthmap\n            depthmap = imread_cv2(osp.join(scene_dir, 'lowres_depth', basename), cv2.IMREAD_UNCHANGED)\n            depthmap = depthmap.astype(np.float32) / 1000\n            depthmap[~np.isfinite(depthmap)] = 0  # invalid\n\n            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(\n                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx)\n\n            views.append(dict(\n                img=rgb_image,\n                depthmap=depthmap.astype(np.float32),\n                camera_pose=camera_pose.astype(np.float32),\n                camera_intrinsics=intrinsics.astype(np.float32),\n                dataset='arkitscenes',\n                label=self.scenes[scene_id] + '_' + basename,\n                instance=f'{str(idx)}_{str(view_idx)}',\n            ))\n\n        return views\n\n\nif __name__ == \"__main__\":\n    from dust3r.datasets.base.base_stereo_view_dataset import view_name\n    from dust3r.viz import SceneViz, auto_cam_size\n    from dust3r.utils.image import rgb\n\n    dataset = ARKitScenes(split='train', ROOT=\"data/arkitscenes_processed\", resolution=224, aug_crop=16)\n\n    for idx in np.random.permutation(len(dataset)):\n        views = dataset[idx]\n        assert len(views) == 2\n        print(view_name(views[0]), view_name(views[1]))\n        viz = SceneViz()\n        poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]]\n        cam_size = max(auto_cam_size(poses), 0.001)\n        for view_idx in [0, 1]:\n            pts3d = views[view_idx]['pts3d']\n            valid_mask = views[view_idx]['valid_mask']\n            colors = rgb(views[view_idx]['img'])\n            
viz.add_pointcloud(pts3d, colors, valid_mask)\n            viz.add_camera(pose_c2w=views[view_idx]['camera_pose'],\n                           focal=views[view_idx]['camera_intrinsics'][0, 0],\n                           color=(view_idx * 255, (1 - view_idx) * 255, 0),\n                           image=colors,\n                           cam_size=cam_size)\n        viz.show()\n"
  },
  {
    "path": "dust3r/datasets/base/__init__.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n"
  },
  {
    "path": "dust3r/datasets/base/base_stereo_view_dataset.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# base class for implementing datasets\n# --------------------------------------------------------\nimport PIL\nimport numpy as np\nimport torch\n\nfrom dust3r.datasets.base.easy_dataset import EasyDataset\nfrom dust3r.datasets.utils.transforms import ImgNorm\nfrom dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates\nimport dust3r.datasets.utils.cropping as cropping\n\n\nclass BaseStereoViewDataset (EasyDataset):\n    \"\"\" Define all basic options.\n\n    Usage:\n        class MyDataset (BaseStereoViewDataset):\n            def _get_views(self, idx, resolution, rng):\n                # overload here\n                views = []\n                views.append(dict(img=, ...))\n                return views\n    \"\"\"\n\n    def __init__(self, *,  # only keyword arguments\n                 split=None,\n                 resolution=None,  # square_size or (width, height) or list of [(width,height), ...]\n                 transform=ImgNorm,\n                 aug_crop=False,\n                 seed=None):\n        self.num_views = 2\n        self.split = split\n        self._set_resolutions(resolution)\n\n        if isinstance(transform, str):\n            transform = eval(transform)\n        self.transform = transform\n\n        self.aug_crop = aug_crop\n        self.seed = seed\n\n    def __len__(self):\n        return len(self.scenes)\n\n    def get_stats(self):\n        return f\"{len(self)} pairs\"\n\n    def __repr__(self):\n        resolutions_str = '[' + ';'.join(f'{w}x{h}' for w, h in self._resolutions) + ']'\n        return f\"\"\"{type(self).__name__}({self.get_stats()},\n            {self.split=},\n            {self.seed=},\n            resolutions={resolutions_str},\n            {self.transform=})\"\"\".replace('self.', '').replace('\\n', 
'').replace('   ', '')\n\n    def _get_views(self, idx, resolution, rng):\n        raise NotImplementedError()\n\n    def __getitem__(self, idx):\n        if isinstance(idx, tuple):\n            # the idx is specifying the aspect-ratio\n            idx, ar_idx = idx\n        else:\n            assert len(self._resolutions) == 1\n            ar_idx = 0\n\n        # set-up the rng\n        if self.seed:  # reseed for each __getitem__\n            self._rng = np.random.default_rng(seed=self.seed + idx)\n        elif not hasattr(self, '_rng'):\n            seed = torch.initial_seed()  # this is different for each dataloader process\n            self._rng = np.random.default_rng(seed=seed)\n\n        # over-loaded code\n        resolution = self._resolutions[ar_idx]  # DO NOT CHANGE THIS (compatible with BatchedRandomSampler)\n        views = self._get_views(idx, resolution, self._rng)\n        assert len(views) == self.num_views\n\n        # check data-types\n        for v, view in enumerate(views):\n            assert 'pts3d' not in view, f\"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}\"\n            view['idx'] = (idx, ar_idx, v)\n\n            # encode the image\n            width, height = view['img'].size\n            view['true_shape'] = np.int32((height, width))\n            view['img'] = self.transform(view['img'])\n\n            assert 'camera_intrinsics' in view\n            if 'camera_pose' not in view:\n                view['camera_pose'] = np.full((4, 4), np.nan, dtype=np.float32)\n            else:\n                assert np.isfinite(view['camera_pose']).all(), f'NaN in camera pose for view {view_name(view)}'\n            assert 'pts3d' not in view\n            assert 'valid_mask' not in view\n            assert np.isfinite(view['depthmap']).all(), f'NaN in depthmap for view {view_name(view)}'\n            pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)\n\n          
  view['pts3d'] = pts3d\n            view['valid_mask'] = valid_mask & np.isfinite(pts3d).all(axis=-1)\n\n            # check all datatypes\n            for key, val in view.items():\n                res, err_msg = is_good_type(key, val)\n                assert res, f\"{err_msg} with {key}={val} for view {view_name(view)}\"\n            K = view['camera_intrinsics']\n\n        # last thing done!\n        for view in views:\n            # transpose to make sure all views are the same size\n            transpose_to_landscape(view)\n            # this allows checking whether the RNG is in the same state each time\n            view['rng'] = int.from_bytes(self._rng.bytes(4), 'big')\n        return views\n\n    def _set_resolutions(self, resolutions):\n        assert resolutions is not None, 'undefined resolution'\n\n        if not isinstance(resolutions, list):\n            resolutions = [resolutions]\n\n        self._resolutions = []\n        for resolution in resolutions:\n            if isinstance(resolution, int):\n                width = height = resolution\n            else:\n                width, height = resolution\n            assert isinstance(width, int), f'Bad type for {width=} {type(width)=}, should be int'\n            assert isinstance(height, int), f'Bad type for {height=} {type(height)=}, should be int'\n            assert width >= height\n            self._resolutions.append((width, height))\n\n    def _crop_resize_if_necessary(self, image, depthmap, intrinsics, resolution, rng=None, info=None):\n        \"\"\" This function:\n            - first downsizes the image with LANCZOS interpolation,\n              which is better than bilinear interpolation in\n        \"\"\"\n        if not isinstance(image, PIL.Image.Image):\n            image = PIL.Image.fromarray(image)\n\n        # downscale with lanczos interpolation so that image.size == resolution\n        # cropping centered on the principal point\n        W, H = image.size\n        cx, cy = 
intrinsics[:2, 2].round().astype(int)\n        min_margin_x = min(cx, W - cx)\n        min_margin_y = min(cy, H - cy)\n        # assert min_margin_x > W/5, f'Bad principal point in view={info}'\n        # assert min_margin_y > H/5, f'Bad principal point in view={info}'\n        # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy)\n        l, t = cx - min_margin_x, cy - min_margin_y\n        r, b = cx + min_margin_x, cy + min_margin_y\n        crop_bbox = (l, t, r, b)\n        image, depthmap, intrinsics = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox)\n\n        # transpose the resolution if necessary\n        W, H = image.size  # new size\n        assert resolution[0] >= resolution[1]\n        if H > 1.1 * W:\n            # image is portrait mode\n            resolution = resolution[::-1]\n        elif 0.9 < H / W < 1.1 and resolution[0] != resolution[1]:\n            # image is square, so we choose (portrait, landscape) randomly\n            if rng.integers(2):\n                resolution = resolution[::-1]\n\n        # high-quality Lanczos down-scaling\n        target_resolution = np.array(resolution)\n        if self.aug_crop > 1:\n            target_resolution += rng.integers(0, self.aug_crop)\n        image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution)\n\n        # actual cropping (if necessary) with bilinear interpolation\n        intrinsics2 = cropping.camera_matrix_of_crop(intrinsics, image.size, resolution, offset_factor=0.5)\n        crop_bbox = cropping.bbox_from_intrinsics_in_out(intrinsics, intrinsics2, resolution)\n        image, depthmap, intrinsics2 = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox)\n\n        return image, depthmap, intrinsics2\n\n\ndef is_good_type(key, v):\n    \"\"\" returns (is_good, err_msg) \n    \"\"\"\n    if isinstance(v, (str, int, tuple)):\n        return True, None\n    if 
v.dtype not in (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8):\n        return False, f\"bad {v.dtype=}\"\n    return True, None\n\n\ndef view_name(view, batch_index=None):\n    def sel(x): return x[batch_index] if batch_index not in (None, slice(None)) else x\n    db = sel(view['dataset'])\n    label = sel(view['label'])\n    instance = sel(view['instance'])\n    return f\"{db}/{label}/{instance}\"\n\n\ndef transpose_to_landscape(view):\n    height, width = view['true_shape']\n\n    if width < height:\n        # rectify portrait to landscape\n        assert view['img'].shape == (3, height, width)\n        view['img'] = view['img'].swapaxes(1, 2)\n\n        assert view['valid_mask'].shape == (height, width)\n        view['valid_mask'] = view['valid_mask'].swapaxes(0, 1)\n\n        assert view['depthmap'].shape == (height, width)\n        view['depthmap'] = view['depthmap'].swapaxes(0, 1)\n\n        assert view['pts3d'].shape == (height, width, 3)\n        view['pts3d'] = view['pts3d'].swapaxes(0, 1)\n\n        # transpose x and y pixels\n        view['camera_intrinsics'] = view['camera_intrinsics'][[1, 0, 2]]\n"
  },
  {
    "path": "dust3r/datasets/base/batched_sampler.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# Random sampling under a constraint\n# --------------------------------------------------------\nimport numpy as np\nimport torch\n\n\nclass BatchedRandomSampler:\n    \"\"\" Random sampling under a constraint: each sample in the batch has the same feature, \n    which is chosen randomly from a known pool of 'features' for each batch.\n\n    For instance, the 'feature' could be the image aspect-ratio.\n\n    The index returned is a tuple (sample_idx, feat_idx).\n    This sampler ensures that each series of `batch_size` indices has the same `feat_idx`.\n    \"\"\"\n\n    def __init__(self, dataset, batch_size, pool_size, world_size=1, rank=0, drop_last=True):\n        self.batch_size = batch_size\n        self.pool_size = pool_size\n\n        self.len_dataset = N = len(dataset)\n        self.total_size = round_by(N, batch_size*world_size) if drop_last else N\n        assert world_size == 1 or drop_last, 'must drop the last batch in distributed mode'\n\n        # distributed sampler\n        self.world_size = world_size\n        self.rank = rank\n        self.epoch = None\n\n    def __len__(self):\n        return self.total_size // self.world_size\n\n    def set_epoch(self, epoch):\n        self.epoch = epoch\n\n    def __iter__(self):\n        # prepare RNG\n        if self.epoch is None:\n            assert self.world_size == 1 and self.rank == 0, 'use set_epoch() if distributed mode is used'\n            seed = int(torch.empty((), dtype=torch.int64).random_().item())\n        else:\n            seed = self.epoch + 777\n        rng = np.random.default_rng(seed=seed)\n\n        # random indices (will restart from 0 if not drop_last)\n        sample_idxs = np.arange(self.total_size)\n        rng.shuffle(sample_idxs)\n\n        # random feat_idxs (same across 
each batch)\n        n_batches = (self.total_size+self.batch_size-1) // self.batch_size\n        feat_idxs = rng.integers(self.pool_size, size=n_batches)\n        feat_idxs = np.broadcast_to(feat_idxs[:, None], (n_batches, self.batch_size))\n        feat_idxs = feat_idxs.ravel()[:self.total_size]\n\n        # put them together\n        idxs = np.c_[sample_idxs, feat_idxs]  # shape = (total_size, 2)\n\n        # Distributed sampler: we select a subset of batches\n        # make sure the slice for each node is aligned with batch_size\n        size_per_proc = self.batch_size * ((self.total_size + self.world_size *\n                                           self.batch_size-1) // (self.world_size * self.batch_size))\n        idxs = idxs[self.rank*size_per_proc: (self.rank+1)*size_per_proc]\n\n        yield from (tuple(idx) for idx in idxs)\n\n\ndef round_by(total, multiple, up=False):\n    if up:\n        total = total + multiple-1\n    return (total//multiple) * multiple\n"
  },
  {
    "path": "dust3r/datasets/base/easy_dataset.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# A dataset base class that you can easily resize and combine.\n# --------------------------------------------------------\nimport numpy as np\nfrom dust3r.datasets.base.batched_sampler import BatchedRandomSampler\n\n\nclass EasyDataset:\n    \"\"\" a dataset that you can easily resize and combine.\n    Examples:\n    ---------\n        2 * dataset ==> duplicate each element 2x\n\n        10 @ dataset ==> set the size to 10 (random sampling, duplicates if necessary)\n\n        dataset1 + dataset2 ==> concatenate datasets\n    \"\"\"\n\n    def __add__(self, other):\n        return CatDataset([self, other])\n\n    def __rmul__(self, factor):\n        return MulDataset(factor, self)\n\n    def __rmatmul__(self, factor):\n        return ResizedDataset(factor, self)\n\n    def set_epoch(self, epoch):\n        pass  # nothing to do by default\n\n    def make_sampler(self, batch_size, shuffle=True, world_size=1, rank=0, drop_last=True):\n        if not (shuffle):\n            raise NotImplementedError()  # cannot deal yet\n        num_of_aspect_ratios = len(self._resolutions)\n        return BatchedRandomSampler(self, batch_size, num_of_aspect_ratios, world_size=world_size, rank=rank, drop_last=drop_last)\n\n\nclass MulDataset (EasyDataset):\n    \"\"\" Artificially augmenting the size of a dataset.\n    \"\"\"\n    multiplicator: int\n\n    def __init__(self, multiplicator, dataset):\n        assert isinstance(multiplicator, int) and multiplicator > 0\n        self.multiplicator = multiplicator\n        self.dataset = dataset\n\n    def __len__(self):\n        return self.multiplicator * len(self.dataset)\n\n    def __repr__(self):\n        return f'{self.multiplicator}*{repr(self.dataset)}'\n\n    def __getitem__(self, idx):\n        if isinstance(idx, tuple):\n   
         idx, other = idx\n            return self.dataset[idx // self.multiplicator, other]\n        else:\n            return self.dataset[idx // self.multiplicator]\n\n    @property\n    def _resolutions(self):\n        return self.dataset._resolutions\n\n\nclass ResizedDataset (EasyDataset):\n    \"\"\" Artificially changing the size of a dataset.\n    \"\"\"\n    new_size: int\n\n    def __init__(self, new_size, dataset):\n        assert isinstance(new_size, int) and new_size > 0\n        self.new_size = new_size\n        self.dataset = dataset\n\n    def __len__(self):\n        return self.new_size\n\n    def __repr__(self):\n        size_str = str(self.new_size)\n        for i in range((len(size_str)-1) // 3):\n            sep = -4*i-3\n            size_str = size_str[:sep] + '_' + size_str[sep:]\n        return f'{size_str} @ {repr(self.dataset)}'\n\n    def set_epoch(self, epoch):\n        # this random shuffle only depends on the epoch\n        rng = np.random.default_rng(seed=epoch+777)\n\n        # shuffle all indices\n        perm = rng.permutation(len(self.dataset))\n\n        # rotary extension until target size is met\n        shuffled_idxs = np.concatenate([perm] * (1 + (len(self)-1) // len(self.dataset)))\n        self._idxs_mapping = shuffled_idxs[:self.new_size]\n\n        assert len(self._idxs_mapping) == self.new_size\n\n    def __getitem__(self, idx):\n        assert hasattr(self, '_idxs_mapping'), 'You need to call dataset.set_epoch() to use ResizedDataset.__getitem__()'\n        if isinstance(idx, tuple):\n            idx, other = idx\n            return self.dataset[self._idxs_mapping[idx], other]\n        else:\n            return self.dataset[self._idxs_mapping[idx]]\n\n    @property\n    def _resolutions(self):\n        return self.dataset._resolutions\n\n\nclass CatDataset (EasyDataset):\n    \"\"\" Concatenation of several datasets \n    \"\"\"\n\n    def __init__(self, datasets):\n        for dataset in datasets:\n            assert 
isinstance(dataset, EasyDataset)\n        self.datasets = datasets\n        self._cum_sizes = np.cumsum([len(dataset) for dataset in datasets])\n\n    def __len__(self):\n        return self._cum_sizes[-1]\n\n    def __repr__(self):\n        # remove uselessly long transform\n        return ' + '.join(repr(dataset).replace(',transform=Compose( ToTensor() Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))', '') for dataset in self.datasets)\n\n    def set_epoch(self, epoch):\n        for dataset in self.datasets:\n            dataset.set_epoch(epoch)\n\n    def __getitem__(self, idx):\n        other = None\n        if isinstance(idx, tuple):\n            idx, other = idx\n\n        if not (0 <= idx < len(self)):\n            raise IndexError()\n\n        db_idx = np.searchsorted(self._cum_sizes, idx, 'right')\n        dataset = self.datasets[db_idx]\n        new_idx = idx - (self._cum_sizes[db_idx - 1] if db_idx > 0 else 0)\n\n        if other is not None:\n            new_idx = (new_idx, other)\n        return dataset[new_idx]\n\n    @property\n    def _resolutions(self):\n        resolutions = self.datasets[0]._resolutions\n        for dataset in self.datasets[1:]:\n            assert tuple(dataset._resolutions) == tuple(resolutions)\n        return resolutions\n"
  },
  {
    "path": "dust3r/datasets/blendedmvs.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# Dataloader for preprocessed BlendedMVS\n# dataset at https://github.com/YoYo000/BlendedMVS\n# See datasets_preprocess/preprocess_blendedmvs.py\n# --------------------------------------------------------\nimport os.path as osp\nimport numpy as np\n\nfrom dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset\nfrom dust3r.utils.image import imread_cv2\n\n\nclass BlendedMVS (BaseStereoViewDataset):\n    \"\"\" Dataset of outdoor street scenes, 5 images each time\n    \"\"\"\n\n    def __init__(self, *args, ROOT, split=None, **kwargs):\n        self.ROOT = ROOT\n        super().__init__(*args, **kwargs)\n        self._load_data(split)\n\n    def _load_data(self, split):\n        pairs = np.load(osp.join(self.ROOT, 'blendedmvs_pairs.npy'))\n        if split is None:\n            selection = slice(None)\n        if split == 'train':\n            # select 90% of all scenes\n            selection = (pairs['seq_low'] % 10) > 0\n        if split == 'val':\n            # select 10% of all scenes\n            selection = (pairs['seq_low'] % 10) == 0\n        self.pairs = pairs[selection]\n\n        # list of all scenes\n        self.scenes = np.unique(self.pairs['seq_low'])  # low is unique enough\n\n    def __len__(self):\n        return len(self.pairs)\n\n    def get_stats(self):\n        return f'{len(self)} pairs from {len(self.scenes)} scenes'\n\n    def _get_views(self, pair_idx, resolution, rng):\n        seqh, seql, img1, img2, score = self.pairs[pair_idx]\n\n        seq = f\"{seqh:08x}{seql:016x}\"\n        seq_path = osp.join(self.ROOT, seq)\n\n        views = []\n\n        for view_index in [img1, img2]:\n            impath = f\"{view_index:08n}\"\n            image = imread_cv2(osp.join(seq_path, impath + \".jpg\"))\n            depthmap 
= imread_cv2(osp.join(seq_path, impath + \".exr\"))\n            camera_params = np.load(osp.join(seq_path, impath + \".npz\"))\n\n            intrinsics = np.float32(camera_params['intrinsics'])\n            camera_pose = np.eye(4, dtype=np.float32)\n            camera_pose[:3, :3] = camera_params['R_cam2world']\n            camera_pose[:3, 3] = camera_params['t_cam2world']\n\n            image, depthmap, intrinsics = self._crop_resize_if_necessary(\n                image, depthmap, intrinsics, resolution, rng, info=(seq_path, impath))\n\n            views.append(dict(\n                img=image,\n                depthmap=depthmap,\n                camera_pose=camera_pose,  # cam2world\n                camera_intrinsics=intrinsics,\n                dataset='BlendedMVS',\n                label=osp.relpath(seq_path, self.ROOT),\n                instance=impath))\n\n        return views\n\n\nif __name__ == '__main__':\n    from dust3r.datasets.base.base_stereo_view_dataset import view_name\n    from dust3r.viz import SceneViz, auto_cam_size\n    from dust3r.utils.image import rgb\n\n    dataset = BlendedMVS(split='train', ROOT=\"data/blendedmvs_processed\", resolution=224, aug_crop=16)\n\n    for idx in np.random.permutation(len(dataset)):\n        views = dataset[idx]\n        assert len(views) == 2\n        print(idx, view_name(views[0]), view_name(views[1]))\n        viz = SceneViz()\n        poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]]\n        cam_size = max(auto_cam_size(poses), 0.001)\n        for view_idx in [0, 1]:\n            pts3d = views[view_idx]['pts3d']\n            valid_mask = views[view_idx]['valid_mask']\n            colors = rgb(views[view_idx]['img'])\n            viz.add_pointcloud(pts3d, colors, valid_mask)\n            viz.add_camera(pose_c2w=views[view_idx]['camera_pose'],\n                           focal=views[view_idx]['camera_intrinsics'][0, 0],\n                           color=(idx * 255, (1 - idx) * 255, 0),\n       
                    image=colors,\n                           cam_size=cam_size)\n        viz.show()\n"
  },
  {
    "path": "dust3r/datasets/co3d.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# Dataloader for preprocessed Co3d_v2\n# dataset at https://github.com/facebookresearch/co3d - Creative Commons Attribution-NonCommercial 4.0 International\n# See datasets_preprocess/preprocess_co3d.py\n# --------------------------------------------------------\nimport os.path as osp\nimport json\nimport itertools\nfrom collections import deque\n\nimport cv2\nimport numpy as np\n\nfrom dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset\nfrom dust3r.utils.image import imread_cv2\n\n\nclass Co3d(BaseStereoViewDataset):\n    def __init__(self, mask_bg=True, *args, ROOT, **kwargs):\n        self.ROOT = ROOT\n        super().__init__(*args, **kwargs)\n        assert mask_bg in (True, False, 'rand')\n        self.mask_bg = mask_bg\n        self.dataset_label = 'Co3d_v2'\n\n        # load all scenes\n        with open(osp.join(self.ROOT, f'selected_seqs_{self.split}.json'), 'r') as f:\n            self.scenes = json.load(f)\n            self.scenes = {k: v for k, v in self.scenes.items() if len(v) > 0}\n            self.scenes = {(k, k2): v2 for k, v in self.scenes.items()\n                           for k2, v2 in v.items()}\n        self.scene_list = list(self.scenes.keys())\n\n        # for each scene, we have 100 images ==> 360 degrees (so 25 frames ~= 90 degrees)\n        # we prepare all combinations such that i-j = +/- [5, 10, .., 90] degrees\n        self.combinations = [(i, j)\n                             for i, j in itertools.combinations(range(100), 2)\n                             if 0 < abs(i - j) <= 30 and abs(i - j) % 5 == 0]\n\n        self.invalidate = {scene: {} for scene in self.scene_list}\n\n    def __len__(self):\n        return len(self.scene_list) * len(self.combinations)\n\n    def _get_metadatapath(self, obj, 
instance, view_idx):\n        return osp.join(self.ROOT, obj, instance, 'images', f'frame{view_idx:06n}.npz')\n\n    def _get_impath(self, obj, instance, view_idx):\n        return osp.join(self.ROOT, obj, instance, 'images', f'frame{view_idx:06n}.jpg')\n\n    def _get_depthpath(self, obj, instance, view_idx):\n        return osp.join(self.ROOT, obj, instance, 'depths', f'frame{view_idx:06n}.jpg.geometric.png')\n\n    def _get_maskpath(self, obj, instance, view_idx):\n        return osp.join(self.ROOT, obj, instance, 'masks', f'frame{view_idx:06n}.png')\n\n    def _read_depthmap(self, depthpath, input_metadata):\n        depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED)\n        depthmap = (depthmap.astype(np.float32) / 65535) * np.nan_to_num(input_metadata['maximum_depth'])\n        return depthmap\n\n    def _get_views(self, idx, resolution, rng):\n        # choose a scene\n        obj, instance = self.scene_list[idx // len(self.combinations)]\n        image_pool = self.scenes[obj, instance]\n        im1_idx, im2_idx = self.combinations[idx % len(self.combinations)]\n\n        # add a bit of randomness\n        last = len(image_pool) - 1\n\n        if resolution not in self.invalidate[obj, instance]:  # flag invalid images\n            self.invalidate[obj, instance][resolution] = [False for _ in range(len(image_pool))]\n\n        # decide now if we mask the bg\n        mask_bg = (self.mask_bg == True) or (self.mask_bg == 'rand' and rng.choice(2))\n\n        views = []\n        imgs_idxs = [max(0, min(im_idx + rng.integers(-4, 5), last)) for im_idx in [im2_idx, im1_idx]]\n        imgs_idxs = deque(imgs_idxs)\n        while len(imgs_idxs) > 0:  # some images (few) have zero depth\n            im_idx = imgs_idxs.pop()\n\n            if self.invalidate[obj, instance][resolution][im_idx]:\n                # search for a valid image\n                random_direction = 2 * rng.choice(2) - 1\n                for offset in range(1, len(image_pool)):\n                 
   tentative_im_idx = (im_idx + (random_direction * offset)) % len(image_pool)\n                    if not self.invalidate[obj, instance][resolution][tentative_im_idx]:\n                        im_idx = tentative_im_idx\n                        break\n\n            view_idx = image_pool[im_idx]\n\n            impath = self._get_impath(obj, instance, view_idx)\n            depthpath = self._get_depthpath(obj, instance, view_idx)\n\n            # load camera params\n            metadata_path = self._get_metadatapath(obj, instance, view_idx)\n            input_metadata = np.load(metadata_path)\n            camera_pose = input_metadata['camera_pose'].astype(np.float32)\n            intrinsics = input_metadata['camera_intrinsics'].astype(np.float32)\n\n            # load image and depth\n            rgb_image = imread_cv2(impath)\n            depthmap = self._read_depthmap(depthpath, input_metadata)\n\n            if mask_bg:\n                # load object mask\n                maskpath = self._get_maskpath(obj, instance, view_idx)\n                maskmap = imread_cv2(maskpath, cv2.IMREAD_UNCHANGED).astype(np.float32)\n                maskmap = (maskmap / 255.0) > 0.1\n\n                # update the depthmap with mask\n                depthmap *= maskmap\n\n            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(\n                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath)\n\n            num_valid = (depthmap > 0.0).sum()\n            if num_valid == 0:\n                # problem, invalidate image and retry\n                self.invalidate[obj, instance][resolution][im_idx] = True\n                imgs_idxs.append(im_idx)\n                continue\n\n            views.append(dict(\n                img=rgb_image,\n                depthmap=depthmap,\n                camera_pose=camera_pose,\n                camera_intrinsics=intrinsics,\n                dataset=self.dataset_label,\n                label=osp.join(obj, 
instance),\n                instance=osp.split(impath)[1],\n            ))\n        return views\n\n\nif __name__ == \"__main__\":\n    from dust3r.datasets.base.base_stereo_view_dataset import view_name\n    from dust3r.viz import SceneViz, auto_cam_size\n    from dust3r.utils.image import rgb\n\n    dataset = Co3d(split='train', ROOT=\"data/co3d_subset_processed\", resolution=224, aug_crop=16)\n\n    for idx in np.random.permutation(len(dataset)):\n        views = dataset[idx]\n        assert len(views) == 2\n        print(view_name(views[0]), view_name(views[1]))\n        viz = SceneViz()\n        poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]]\n        cam_size = max(auto_cam_size(poses), 0.001)\n        for view_idx in [0, 1]:\n            pts3d = views[view_idx]['pts3d']\n            valid_mask = views[view_idx]['valid_mask']\n            colors = rgb(views[view_idx]['img'])\n            viz.add_pointcloud(pts3d, colors, valid_mask)\n            viz.add_camera(pose_c2w=views[view_idx]['camera_pose'],\n                           focal=views[view_idx]['camera_intrinsics'][0, 0],\n                           color=(view_idx * 255, (1 - view_idx) * 255, 0),\n                           image=colors,\n                           cam_size=cam_size)\n        viz.show()\n"
  },
  {
    "path": "dust3r/datasets/habitat.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# Dataloader for preprocessed habitat\n# dataset at https://github.com/facebookresearch/habitat-sim/blob/main/DATASETS.md\n# See datasets_preprocess/habitat for more details\n# --------------------------------------------------------\nimport os.path as osp\nimport os\nos.environ[\"OPENCV_IO_ENABLE_OPENEXR\"] = \"1\"  # noqa\nimport cv2  # noqa\nimport numpy as np\nfrom PIL import Image\nimport json\n\nfrom dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset\n\n\nclass Habitat(BaseStereoViewDataset):\n    def __init__(self, size, *args, ROOT, **kwargs):\n        self.ROOT = ROOT\n        super().__init__(*args, **kwargs)\n        assert self.split is not None\n        # loading list of scenes\n        with open(osp.join(self.ROOT, f'Habitat_{size}_scenes_{self.split}.txt')) as f:\n            self.scenes = f.read().splitlines()\n        self.instances = list(range(1, 5))\n\n    def filter_scene(self, label, instance=None):\n        if instance:\n            subscene, instance = instance.split('_')\n            label += '/' + subscene\n            self.instances = [int(instance) - 1]\n        valid = np.bool_([scene.startswith(label) for scene in self.scenes])\n        assert sum(valid), f'no scene was selected for {label=} {instance=}'\n        self.scenes = [scene for i, scene in enumerate(self.scenes) if valid[i]]\n\n    def _get_views(self, idx, resolution, rng):\n        scene = self.scenes[idx]\n        data_path, key = osp.split(osp.join(self.ROOT, scene))\n        views = []\n        two_random_views = [0, rng.choice(self.instances)]  # view 0 is connected with all other views\n        for view_index in two_random_views:\n            # load the view (and use the next one if this one's broken)\n            for ii in 
range(view_index, view_index + 5):\n                image, depthmap, intrinsics, camera_pose = self._load_one_view(data_path, key, ii % 5, resolution, rng)\n                if np.isfinite(camera_pose).all():\n                    break\n            views.append(dict(\n                img=image,\n                depthmap=depthmap,\n                camera_pose=camera_pose,  # cam2world\n                camera_intrinsics=intrinsics,\n                dataset='Habitat',\n                label=osp.relpath(data_path, self.ROOT),\n                instance=f\"{key}_{view_index}\"))\n        return views\n\n    def _load_one_view(self, data_path, key, view_index, resolution, rng):\n        view_index += 1  # file indices starts at 1\n        impath = osp.join(data_path, f\"{key}_{view_index}.jpeg\")\n        image = Image.open(impath)\n\n        depthmap_filename = osp.join(data_path, f\"{key}_{view_index}_depth.exr\")\n        depthmap = cv2.imread(depthmap_filename, cv2.IMREAD_GRAYSCALE | cv2.IMREAD_ANYDEPTH)\n\n        camera_params_filename = osp.join(data_path, f\"{key}_{view_index}_camera_params.json\")\n        with open(camera_params_filename, 'r') as f:\n            camera_params = json.load(f)\n\n        intrinsics = np.float32(camera_params['camera_intrinsics'])\n        camera_pose = np.eye(4, dtype=np.float32)\n        camera_pose[:3, :3] = camera_params['R_cam2world']\n        camera_pose[:3, 3] = camera_params['t_cam2world']\n\n        image, depthmap, intrinsics = self._crop_resize_if_necessary(\n            image, depthmap, intrinsics, resolution, rng, info=impath)\n        return image, depthmap, intrinsics, camera_pose\n\n\nif __name__ == \"__main__\":\n    from dust3r.datasets.base.base_stereo_view_dataset import view_name\n    from dust3r.viz import SceneViz, auto_cam_size\n    from dust3r.utils.image import rgb\n\n    dataset = Habitat(1_000_000, split='train', ROOT=\"data/habitat_processed\",\n                      resolution=224, aug_crop=16)\n\n    
for idx in np.random.permutation(len(dataset)):\n        views = dataset[idx]\n        assert len(views) == 2\n        print(view_name(views[0]), view_name(views[1]))\n        viz = SceneViz()\n        poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]]\n        cam_size = max(auto_cam_size(poses), 0.001)\n        for view_idx in [0, 1]:\n            pts3d = views[view_idx]['pts3d']\n            valid_mask = views[view_idx]['valid_mask']\n            colors = rgb(views[view_idx]['img'])\n            viz.add_pointcloud(pts3d, colors, valid_mask)\n            viz.add_camera(pose_c2w=views[view_idx]['camera_pose'],\n                           focal=views[view_idx]['camera_intrinsics'][0, 0],\n                           color=(idx * 255, (1 - idx) * 255, 0),\n                           image=colors,\n                           cam_size=cam_size)\n        viz.show()\n"
  },
  {
    "path": "dust3r/datasets/megadepth.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# Dataloader for preprocessed MegaDepth\n# dataset at https://www.cs.cornell.edu/projects/megadepth/\n# See datasets_preprocess/preprocess_megadepth.py\n# --------------------------------------------------------\nimport os.path as osp\nimport numpy as np\n\nfrom dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset\nfrom dust3r.utils.image import imread_cv2\n\n\nclass MegaDepth(BaseStereoViewDataset):\n    def __init__(self, *args, split, ROOT, **kwargs):\n        self.ROOT = ROOT\n        super().__init__(*args, **kwargs)\n        self.loaded_data = self._load_data(self.split)\n\n        if self.split is None:\n            pass\n        elif self.split == 'train':\n            self.select_scene(('0015', '0022'), opposite=True)\n        elif self.split == 'val':\n            self.select_scene(('0015', '0022'))\n        else:\n            raise ValueError(f'bad {self.split=}')\n\n    def _load_data(self, split):\n        with np.load(osp.join(self.ROOT, 'all_metadata.npz')) as data:\n            self.all_scenes = data['scenes']\n            self.all_images = data['images']\n            self.pairs = data['pairs']\n\n    def __len__(self):\n        return len(self.pairs)\n\n    def get_stats(self):\n        return f'{len(self)} pairs from {len(self.all_scenes)} scenes'\n\n    def select_scene(self, scene, *instances, opposite=False):\n        scenes = (scene,) if isinstance(scene, str) else tuple(scene)\n        scene_id = [s.startswith(scenes) for s in self.all_scenes]\n        assert any(scene_id), 'no scene found'\n\n        valid = np.in1d(self.pairs['scene_id'], np.nonzero(scene_id)[0])\n        if instances:\n            image_id = [i.startswith(instances) for i in self.all_images]\n            image_id = np.nonzero(image_id)[0]\n      
      assert len(image_id), 'no instance found'\n            # both together?\n            if len(instances) == 2:\n                valid &= np.in1d(self.pairs['im1_id'], image_id) & np.in1d(self.pairs['im2_id'], image_id)\n            else:\n                valid &= np.in1d(self.pairs['im1_id'], image_id) | np.in1d(self.pairs['im2_id'], image_id)\n\n        if opposite:\n            valid = ~valid\n        assert valid.any()\n        self.pairs = self.pairs[valid]\n\n    def _get_views(self, pair_idx, resolution, rng):\n        scene_id, im1_id, im2_id, score = self.pairs[pair_idx]\n\n        scene, subscene = self.all_scenes[scene_id].split()\n        seq_path = osp.join(self.ROOT, scene, subscene)\n\n        views = []\n\n        for im_id in [im1_id, im2_id]:\n            img = self.all_images[im_id]\n            try:\n                image = imread_cv2(osp.join(seq_path, img + '.jpg'))\n                depthmap = imread_cv2(osp.join(seq_path, img + \".exr\"))\n                camera_params = np.load(osp.join(seq_path, img + \".npz\"))\n            except Exception as e:\n                raise OSError(f'cannot load {img}, got exception {e}')\n\n            intrinsics = np.float32(camera_params['intrinsics'])\n            camera_pose = np.float32(camera_params['cam2world'])\n\n            image, depthmap, intrinsics = self._crop_resize_if_necessary(\n                image, depthmap, intrinsics, resolution, rng, info=(seq_path, img))\n\n            views.append(dict(\n                img=image,\n                depthmap=depthmap,\n                camera_pose=camera_pose,  # cam2world\n                camera_intrinsics=intrinsics,\n                dataset='MegaDepth',\n                label=osp.relpath(seq_path, self.ROOT),\n                instance=img))\n\n        return views\n\n\nif __name__ == \"__main__\":\n    from dust3r.datasets.base.base_stereo_view_dataset import view_name\n    from dust3r.viz import SceneViz, auto_cam_size\n    from dust3r.utils.image 
import rgb\n\n    dataset = MegaDepth(split='train', ROOT=\"data/megadepth_processed\", resolution=224, aug_crop=16)\n\n    for idx in np.random.permutation(len(dataset)):\n        views = dataset[idx]\n        assert len(views) == 2\n        print(idx, view_name(views[0]), view_name(views[1]))\n        viz = SceneViz()\n        poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]]\n        cam_size = max(auto_cam_size(poses), 0.001)\n        for view_idx in [0, 1]:\n            pts3d = views[view_idx]['pts3d']\n            valid_mask = views[view_idx]['valid_mask']\n            colors = rgb(views[view_idx]['img'])\n            viz.add_pointcloud(pts3d, colors, valid_mask)\n            viz.add_camera(pose_c2w=views[view_idx]['camera_pose'],\n                           focal=views[view_idx]['camera_intrinsics'][0, 0],\n                           color=(view_idx * 255, (1 - view_idx) * 255, 0),\n                           image=colors,\n                           cam_size=cam_size)\n        viz.show()\n"
  },
  {
    "path": "dust3r/datasets/scannetpp.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# Dataloader for preprocessed scannet++\n# dataset at https://github.com/scannetpp/scannetpp - non-commercial research and educational purposes\n# https://kaldir.vc.in.tum.de/scannetpp/static/scannetpp-terms-of-use.pdf\n# See datasets_preprocess/preprocess_scannetpp.py\n# --------------------------------------------------------\nimport os.path as osp\nimport cv2\nimport numpy as np\n\nfrom dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset\nfrom dust3r.utils.image import imread_cv2\n\n\nclass ScanNetpp(BaseStereoViewDataset):\n    def __init__(self, *args, ROOT, **kwargs):\n        self.ROOT = ROOT\n        super().__init__(*args, **kwargs)\n        assert self.split == 'train'\n        self.loaded_data = self._load_data()\n\n    def _load_data(self):\n        with np.load(osp.join(self.ROOT, 'all_metadata.npz')) as data:\n            self.scenes = data['scenes']\n            self.sceneids = data['sceneids']\n            self.images = data['images']\n            self.intrinsics = data['intrinsics'].astype(np.float32)\n            self.trajectories = data['trajectories'].astype(np.float32)\n            self.pairs = data['pairs'][:, :2].astype(int)\n\n    def __len__(self):\n        return len(self.pairs)\n\n    def _get_views(self, idx, resolution, rng):\n\n        image_idx1, image_idx2 = self.pairs[idx]\n\n        views = []\n        for view_idx in [image_idx1, image_idx2]:\n            scene_id = self.sceneids[view_idx]\n            scene_dir = osp.join(self.ROOT, self.scenes[scene_id])\n\n            intrinsics = self.intrinsics[view_idx]\n            camera_pose = self.trajectories[view_idx]\n            basename = self.images[view_idx]\n\n            # Load RGB image\n            rgb_image = imread_cv2(osp.join(scene_dir, 
'images', basename + '.jpg'))\n            # Load depthmap\n            depthmap = imread_cv2(osp.join(scene_dir, 'depth', basename + '.png'), cv2.IMREAD_UNCHANGED)\n            depthmap = depthmap.astype(np.float32) / 1000\n            depthmap[~np.isfinite(depthmap)] = 0  # invalid\n\n            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(\n                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx)\n\n            views.append(dict(\n                img=rgb_image,\n                depthmap=depthmap.astype(np.float32),\n                camera_pose=camera_pose.astype(np.float32),\n                camera_intrinsics=intrinsics.astype(np.float32),\n                dataset='ScanNet++',\n                label=self.scenes[scene_id] + '_' + basename,\n                instance=f'{str(idx)}_{str(view_idx)}',\n            ))\n        return views\n\n\nif __name__ == \"__main__\":\n    from dust3r.datasets.base.base_stereo_view_dataset import view_name\n    from dust3r.viz import SceneViz, auto_cam_size\n    from dust3r.utils.image import rgb\n\n    dataset = ScanNetpp(split='train', ROOT=\"data/scannetpp_processed\", resolution=224, aug_crop=16)\n\n    for idx in np.random.permutation(len(dataset)):\n        views = dataset[idx]\n        assert len(views) == 2\n        print(view_name(views[0]), view_name(views[1]))\n        viz = SceneViz()\n        poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]]\n        cam_size = max(auto_cam_size(poses), 0.001)\n        for view_idx in [0, 1]:\n            pts3d = views[view_idx]['pts3d']\n            valid_mask = views[view_idx]['valid_mask']\n            colors = rgb(views[view_idx]['img'])\n            viz.add_pointcloud(pts3d, colors, valid_mask)\n            viz.add_camera(pose_c2w=views[view_idx]['camera_pose'],\n                           focal=views[view_idx]['camera_intrinsics'][0, 0],\n                           color=(idx*255, (1 - idx)*255, 0),\n                  
         image=colors,\n                           cam_size=cam_size)\n        viz.show()\n"
  },
  {
    "path": "dust3r/datasets/staticthings3d.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# Dataloader for preprocessed StaticThings3D\n# dataset at https://github.com/lmb-freiburg/robustmvd/\n# See datasets_preprocess/preprocess_staticthings3d.py\n# --------------------------------------------------------\nimport os.path as osp\nimport numpy as np\n\nfrom dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset\nfrom dust3r.utils.image import imread_cv2\n\n\nclass StaticThings3D (BaseStereoViewDataset):\n    \"\"\" Dataset of indoor scenes, 5 images each time\n    \"\"\"\n    def __init__(self, ROOT, *args, mask_bg='rand', **kwargs):\n        self.ROOT = ROOT\n        super().__init__(*args, **kwargs)\n\n        assert mask_bg in (True, False, 'rand')\n        self.mask_bg = mask_bg\n\n        # loading all pairs\n        assert self.split is None\n        self.pairs = np.load(osp.join(ROOT, 'staticthings_pairs.npy'))\n\n    def __len__(self):\n        return len(self.pairs)\n\n    def get_stats(self):\n        return f'{len(self)} pairs'\n\n    def _get_views(self, pair_idx, resolution, rng):\n        scene, seq, cam1, im1, cam2, im2 = self.pairs[pair_idx]\n        seq_path = osp.join('TRAIN', scene.decode('ascii'), f'{seq:04d}')\n\n        views = []\n\n        mask_bg = (self.mask_bg == True) or (self.mask_bg == 'rand' and rng.choice(2))\n\n        CAM = {b'l':'left', b'r':'right'}\n        for cam, idx in [(CAM[cam1], im1), (CAM[cam2], im2)]:\n            num = f\"{idx:04n}\"\n            img = num+\"_clean.jpg\" if rng.choice(2) else num+\"_final.jpg\"\n            image = imread_cv2(osp.join(self.ROOT, seq_path, cam, img))\n            depthmap = imread_cv2(osp.join(self.ROOT, seq_path, cam, num+\".exr\"))\n            camera_params = np.load(osp.join(self.ROOT, seq_path, cam, num+\".npz\"))\n\n            intrinsics = 
camera_params['intrinsics']\n            camera_pose = camera_params['cam2world']\n\n            if mask_bg:\n                depthmap[depthmap > 200] = 0\n\n            image, depthmap, intrinsics = self._crop_resize_if_necessary(image, depthmap, intrinsics, resolution, rng, info=(seq_path, cam, img))\n\n            views.append(dict(\n                img=image,\n                depthmap=depthmap,\n                camera_pose=camera_pose,  # cam2world\n                camera_intrinsics=intrinsics,\n                dataset='StaticThings3D',\n                label=seq_path,\n                instance=cam + '_' + img))\n\n        return views\n\n\nif __name__ == '__main__':\n    from dust3r.datasets.base.base_stereo_view_dataset import view_name\n    from dust3r.viz import SceneViz, auto_cam_size\n    from dust3r.utils.image import rgb\n\n    dataset = StaticThings3D(ROOT=\"data/staticthings3d_processed\", resolution=224, aug_crop=16)\n\n    for idx in np.random.permutation(len(dataset)):\n        views = dataset[idx]\n        assert len(views) == 2\n        print(idx, view_name(views[0]), view_name(views[1]))\n        viz = SceneViz()\n        poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]]\n        cam_size = max(auto_cam_size(poses), 0.001)\n        for view_idx in [0, 1]:\n            pts3d = views[view_idx]['pts3d']\n            valid_mask = views[view_idx]['valid_mask']\n            colors = rgb(views[view_idx]['img'])\n            viz.add_pointcloud(pts3d, colors, valid_mask)\n            viz.add_camera(pose_c2w=views[view_idx]['camera_pose'],\n                           focal=views[view_idx]['camera_intrinsics'][0, 0],\n                           color=(view_idx * 255, (1 - view_idx) * 255, 0),\n                           image=colors,\n                           cam_size=cam_size)\n        viz.show()\n"
  },
  {
    "path": "dust3r/datasets/utils/__init__.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n"
  },
  {
    "path": "dust3r/datasets/utils/cropping.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# cropping utilities\n# --------------------------------------------------------\nimport PIL.Image\nimport os\nos.environ[\"OPENCV_IO_ENABLE_OPENEXR\"] = \"1\"\nimport cv2  # noqa\nimport numpy as np  # noqa\nfrom dust3r.utils.geometry import colmap_to_opencv_intrinsics, opencv_to_colmap_intrinsics  # noqa\ntry:\n    lanczos = PIL.Image.Resampling.LANCZOS\n    bicubic = PIL.Image.Resampling.BICUBIC\nexcept AttributeError:\n    lanczos = PIL.Image.LANCZOS\n    bicubic = PIL.Image.BICUBIC\n\n\nclass ImageList:\n    \"\"\" Convenience class to apply the same operation to a whole set of images.\n    \"\"\"\n\n    def __init__(self, images):\n        if not isinstance(images, (tuple, list, set)):\n            images = [images]\n        self.images = []\n        for image in images:\n            if not isinstance(image, PIL.Image.Image):\n                image = PIL.Image.fromarray(image)\n            self.images.append(image)\n\n    def __len__(self):\n        return len(self.images)\n\n    def to_pil(self):\n        return tuple(self.images) if len(self.images) > 1 else self.images[0]\n\n    @property\n    def size(self):\n        sizes = [im.size for im in self.images]\n        assert all(sizes[0] == s for s in sizes)\n        return sizes[0]\n\n    def resize(self, *args, **kwargs):\n        return ImageList(self._dispatch('resize', *args, **kwargs))\n\n    def crop(self, *args, **kwargs):\n        return ImageList(self._dispatch('crop', *args, **kwargs))\n\n    def _dispatch(self, func, *args, **kwargs):\n        return [getattr(im, func)(*args, **kwargs) for im in self.images]\n\n\ndef rescale_image_depthmap(image, depthmap, camera_intrinsics, output_resolution, force=True):\n    \"\"\" Jointly rescale an (image, depthmap) pair\n        so that (out_width, 
out_height) >= output_res\n    \"\"\"\n    image = ImageList(image)\n    input_resolution = np.array(image.size)  # (W,H)\n    output_resolution = np.array(output_resolution)\n    if depthmap is not None:\n        # can also use this with masks instead of depthmaps\n        assert tuple(depthmap.shape[:2]) == image.size[::-1]\n\n    # define output resolution\n    assert output_resolution.shape == (2,)\n    scale_final = max(output_resolution / image.size) + 1e-8\n    if scale_final >= 1 and not force:  # image is already smaller than what is asked\n        return (image.to_pil(), depthmap, camera_intrinsics)\n    output_resolution = np.floor(input_resolution * scale_final).astype(int)\n\n    # first rescale the image so that it contains the crop\n    image = image.resize(tuple(output_resolution), resample=lanczos if scale_final < 1 else bicubic)\n    if depthmap is not None:\n        depthmap = cv2.resize(depthmap, output_resolution, fx=scale_final,\n                              fy=scale_final, interpolation=cv2.INTER_NEAREST)\n\n    # no offset here; simple rescaling\n    camera_intrinsics = camera_matrix_of_crop(\n        camera_intrinsics, input_resolution, output_resolution, scaling=scale_final)\n\n    return image.to_pil(), depthmap, camera_intrinsics\n\n\ndef camera_matrix_of_crop(input_camera_matrix, input_resolution, output_resolution, scaling=1, offset_factor=0.5, offset=None):\n    # Margins to offset the origin\n    margins = np.asarray(input_resolution) * scaling - output_resolution\n    assert np.all(margins >= 0.0)\n    if offset is None:\n        offset = offset_factor * margins\n\n    # Generate new camera parameters\n    output_camera_matrix_colmap = opencv_to_colmap_intrinsics(input_camera_matrix)\n    output_camera_matrix_colmap[:2, :] *= scaling\n    output_camera_matrix_colmap[:2, 2] -= offset\n    output_camera_matrix = colmap_to_opencv_intrinsics(output_camera_matrix_colmap)\n\n    return output_camera_matrix\n\n\ndef 
crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox):\n    \"\"\"\n    Return a crop of the input view.\n    \"\"\"\n    image = ImageList(image)\n    l, t, r, b = crop_bbox\n\n    image = image.crop((l, t, r, b))\n    depthmap = depthmap[t:b, l:r]\n\n    camera_intrinsics = camera_intrinsics.copy()\n    camera_intrinsics[0, 2] -= l\n    camera_intrinsics[1, 2] -= t\n\n    return image.to_pil(), depthmap, camera_intrinsics\n\n\ndef bbox_from_intrinsics_in_out(input_camera_matrix, output_camera_matrix, output_resolution):\n    out_width, out_height = output_resolution\n    l, t = np.int32(np.round(input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2]))\n    crop_bbox = (l, t, l + out_width, t + out_height)\n    return crop_bbox\n"
  },
  {
    "path": "dust3r/datasets/utils/transforms.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# DUST3R default transforms\n# --------------------------------------------------------\nimport torchvision.transforms as tvf\nfrom dust3r.utils.image import ImgNorm\n\n# define the standard image transforms\nColorJitter = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm])\n"
  },
  {
    "path": "dust3r/datasets/waymo.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# Dataloader for preprocessed WayMo\n# dataset at https://github.com/waymo-research/waymo-open-dataset\n# See datasets_preprocess/preprocess_waymo.py\n# --------------------------------------------------------\nimport os.path as osp\nimport numpy as np\n\nfrom dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset\nfrom dust3r.utils.image import imread_cv2\n\n\nclass Waymo (BaseStereoViewDataset):\n    \"\"\" Dataset of outdoor street scenes, 5 images each time\n    \"\"\"\n\n    def __init__(self, *args, ROOT, **kwargs):\n        self.ROOT = ROOT\n        super().__init__(*args, **kwargs)\n        self._load_data()\n\n    def _load_data(self):\n        with np.load(osp.join(self.ROOT, 'waymo_pairs.npz')) as data:\n            self.scenes = data['scenes']\n            self.frames = data['frames']\n            self.inv_frames = {frame: i for i, frame in enumerate(data['frames'])}\n            self.pairs = data['pairs']  # (array of (scene_id, img1_id, img2_id)\n            assert self.pairs[:, 0].max() == len(self.scenes) - 1\n\n    def __len__(self):\n        return len(self.pairs)\n\n    def get_stats(self):\n        return f'{len(self)} pairs from {len(self.scenes)} scenes'\n\n    def _get_views(self, pair_idx, resolution, rng):\n        seq, img1, img2 = self.pairs[pair_idx]\n        seq_path = osp.join(self.ROOT, self.scenes[seq])\n\n        views = []\n\n        for view_index in [img1, img2]:\n            impath = self.frames[view_index]\n            image = imread_cv2(osp.join(seq_path, impath + \".jpg\"))\n            depthmap = imread_cv2(osp.join(seq_path, impath + \".exr\"))\n            camera_params = np.load(osp.join(seq_path, impath + \".npz\"))\n\n            intrinsics = np.float32(camera_params['intrinsics'])\n            
camera_pose = np.float32(camera_params['cam2world'])\n\n            image, depthmap, intrinsics = self._crop_resize_if_necessary(\n                image, depthmap, intrinsics, resolution, rng, info=(seq_path, impath))\n\n            views.append(dict(\n                img=image,\n                depthmap=depthmap,\n                camera_pose=camera_pose,  # cam2world\n                camera_intrinsics=intrinsics,\n                dataset='Waymo',\n                label=osp.relpath(seq_path, self.ROOT),\n                instance=impath))\n\n        return views\n\n\nif __name__ == '__main__':\n    from dust3r.datasets.base.base_stereo_view_dataset import view_name\n    from dust3r.viz import SceneViz, auto_cam_size\n    from dust3r.utils.image import rgb\n\n    dataset = Waymo(split='train', ROOT=\"data/waymo_processed\", resolution=224, aug_crop=16)\n\n    for idx in np.random.permutation(len(dataset)):\n        views = dataset[idx]\n        assert len(views) == 2\n        print(idx, view_name(views[0]), view_name(views[1]))\n        viz = SceneViz()\n        poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]]\n        cam_size = max(auto_cam_size(poses), 0.001)\n        for view_idx in [0, 1]:\n            pts3d = views[view_idx]['pts3d']\n            valid_mask = views[view_idx]['valid_mask']\n            colors = rgb(views[view_idx]['img'])\n            viz.add_pointcloud(pts3d, colors, valid_mask)\n            viz.add_camera(pose_c2w=views[view_idx]['camera_pose'],\n                           focal=views[view_idx]['camera_intrinsics'][0, 0],\n                           color=(view_idx * 255, (1 - view_idx) * 255, 0),\n                           image=colors,\n                           cam_size=cam_size)\n        viz.show()\n"
  },
  {
    "path": "dust3r/datasets/wildrgbd.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# Dataloader for preprocessed WildRGB-D\n# dataset at https://github.com/wildrgbd/wildrgbd/\n# See datasets_preprocess/preprocess_wildrgbd.py\n# --------------------------------------------------------\nimport os.path as osp\n\nimport cv2\nimport numpy as np\n\nfrom dust3r.datasets.co3d import Co3d\nfrom dust3r.utils.image import imread_cv2\n\n\nclass WildRGBD(Co3d):\n    def __init__(self, mask_bg=True, *args, ROOT, **kwargs):\n        super().__init__(mask_bg, *args, ROOT=ROOT, **kwargs)\n        self.dataset_label = 'WildRGBD'\n\n    def _get_metadatapath(self, obj, instance, view_idx):\n        return osp.join(self.ROOT, obj, instance, 'metadata', f'{view_idx:0>5d}.npz')\n\n    def _get_impath(self, obj, instance, view_idx):\n        return osp.join(self.ROOT, obj, instance, 'rgb', f'{view_idx:0>5d}.jpg')\n\n    def _get_depthpath(self, obj, instance, view_idx):\n        return osp.join(self.ROOT, obj, instance, 'depth', f'{view_idx:0>5d}.png')\n\n    def _get_maskpath(self, obj, instance, view_idx):\n        return osp.join(self.ROOT, obj, instance, 'masks', f'{view_idx:0>5d}.png')\n\n    def _read_depthmap(self, depthpath, input_metadata):\n        # We store depths in the depth scale of 1000.\n        # That is, when we load depth image and divide by 1000, we could get depth in meters.\n        depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED)\n        depthmap = depthmap.astype(np.float32) / 1000.0\n        return depthmap\n\n\nif __name__ == \"__main__\":\n    from dust3r.datasets.base.base_stereo_view_dataset import view_name\n    from dust3r.viz import SceneViz, auto_cam_size\n    from dust3r.utils.image import rgb\n\n    dataset = WildRGBD(split='train', ROOT=\"data/wildrgbd_processed\", resolution=224, aug_crop=16)\n\n    for idx in 
np.random.permutation(len(dataset)):\n        views = dataset[idx]\n        assert len(views) == 2\n        print(view_name(views[0]), view_name(views[1]))\n        viz = SceneViz()\n        poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]]\n        cam_size = max(auto_cam_size(poses), 0.001)\n        for view_idx in [0, 1]:\n            pts3d = views[view_idx]['pts3d']\n            valid_mask = views[view_idx]['valid_mask']\n            colors = rgb(views[view_idx]['img'])\n            viz.add_pointcloud(pts3d, colors, valid_mask)\n            viz.add_camera(pose_c2w=views[view_idx]['camera_pose'],\n                           focal=views[view_idx]['camera_intrinsics'][0, 0],\n                           color=(view_idx * 255, (1 - view_idx) * 255, 0),\n                           image=colors,\n                           cam_size=cam_size)\n        viz.show()\n"
  },
  {
    "path": "dust3r/demo.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# gradio demo\n# --------------------------------------------------------\nimport argparse\nimport math\nimport builtins\nimport datetime\nimport gradio\nimport os\nimport torch\nimport numpy as np\nimport functools\nimport trimesh\nimport copy\nfrom scipy.spatial.transform import Rotation\n\nfrom dust3r.inference import inference\nfrom dust3r.image_pairs import make_pairs\nfrom dust3r.utils.image import load_images, rgb\nfrom dust3r.utils.device import to_numpy\nfrom dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes\nfrom dust3r.cloud_opt import global_aligner, GlobalAlignerMode\n\nimport matplotlib.pyplot as pl\n\n\ndef get_args_parser():\n    parser = argparse.ArgumentParser()\n    parser_url = parser.add_mutually_exclusive_group()\n    parser_url.add_argument(\"--local_network\", action='store_true', default=False,\n                            help=\"make app accessible on local network: address will be set to 0.0.0.0\")\n    parser_url.add_argument(\"--server_name\", type=str, default=None, help=\"server url, default is 127.0.0.1\")\n    parser.add_argument(\"--image_size\", type=int, default=512, choices=[512, 224], help=\"image size\")\n    parser.add_argument(\"--server_port\", type=int, help=(\"will start gradio app on this port (if available). 
\"\n                                                         \"If None, will search for an available port starting at 7860.\"),\n                        default=None)\n    parser_weights = parser.add_mutually_exclusive_group(required=True)\n    parser_weights.add_argument(\"--weights\", type=str, help=\"path to the model weights\", default=None)\n    parser_weights.add_argument(\"--model_name\", type=str, help=\"name of the model weights\",\n                                choices=[\"DUSt3R_ViTLarge_BaseDecoder_512_dpt\",\n                                         \"DUSt3R_ViTLarge_BaseDecoder_512_linear\",\n                                         \"DUSt3R_ViTLarge_BaseDecoder_224_linear\"])\n    parser.add_argument(\"--device\", type=str, default='cuda', help=\"pytorch device\")\n    parser.add_argument(\"--tmp_dir\", type=str, default=None, help=\"value for tempfile.tempdir\")\n    parser.add_argument(\"--silent\", action='store_true', default=False,\n                        help=\"silence logs\")\n    return parser\n\n\ndef set_print_with_timestamp(time_format=\"%Y-%m-%d %H:%M:%S\"):\n    builtin_print = builtins.print\n\n    def print_with_timestamp(*args, **kwargs):\n        now = datetime.datetime.now()\n        formatted_date_time = now.strftime(time_format)\n\n        builtin_print(f'[{formatted_date_time}] ', end='')  # print with time stamp\n        builtin_print(*args, **kwargs)\n\n    builtins.print = print_with_timestamp\n\n\ndef _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,\n                                 cam_color=None, as_pointcloud=False,\n                                 transparent_cams=False, silent=False):\n    assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)\n    pts3d = to_numpy(pts3d)\n    imgs = to_numpy(imgs)\n    focals = to_numpy(focals)\n    cams2world = to_numpy(cams2world)\n\n    scene = trimesh.Scene()\n\n    # full pointcloud\n    if as_pointcloud:\n        
pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])\n        col = np.concatenate([p[m] for p, m in zip(imgs, mask)])\n        pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))\n        scene.add_geometry(pct)\n    else:\n        meshes = []\n        for i in range(len(imgs)):\n            meshes.append(pts3d_to_trimesh(imgs[i], pts3d[i], mask[i]))\n        mesh = trimesh.Trimesh(**cat_meshes(meshes))\n        scene.add_geometry(mesh)\n\n    # add each camera\n    for i, pose_c2w in enumerate(cams2world):\n        if isinstance(cam_color, list):\n            camera_edge_color = cam_color[i]\n        else:\n            camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]\n        add_scene_cam(scene, pose_c2w, camera_edge_color,\n                      None if transparent_cams else imgs[i], focals[i],\n                      imsize=imgs[i].shape[1::-1], screen_width=cam_size)\n\n    rot = np.eye(4)\n    rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()\n    scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))\n    outfile = os.path.join(outdir, 'scene.glb')\n    if not silent:\n        print('(exporting 3D scene to', outfile, ')')\n    scene.export(file_obj=outfile)\n    return outfile\n\n\ndef get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=3, as_pointcloud=False, mask_sky=False,\n                            clean_depth=False, transparent_cams=False, cam_size=0.05):\n    \"\"\"\n    extract 3D_model (glb file) from a reconstructed scene\n    \"\"\"\n    if scene is None:\n        return None\n    # post processes\n    if clean_depth:\n        scene = scene.clean_pointcloud()\n    if mask_sky:\n        scene = scene.mask_sky()\n\n    # get optimized values from scene\n    rgbimg = scene.imgs\n    focals = scene.get_focals().cpu()\n    cams2world = scene.get_im_poses().cpu()\n    # 3D pointcloud from depthmap, poses and intrinsics\n    pts3d = to_numpy(scene.get_pts3d())\n    
scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))\n    msk = to_numpy(scene.get_masks())\n    return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,\n                                        transparent_cams=transparent_cams, cam_size=cam_size, silent=silent)\n\n\ndef get_reconstructed_scene(outdir, model, device, silent, image_size, filelist, schedule, niter, min_conf_thr,\n                            as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,\n                            scenegraph_type, winsize, refid):\n    \"\"\"\n    from a list of images, run dust3r inference, global aligner.\n    then run get_3D_model_from_scene\n    \"\"\"\n    try:\n        square_ok = model.square_ok\n    except Exception as e:\n        square_ok = False\n    imgs = load_images(filelist, size=image_size, verbose=not silent, patch_size=model.patch_size, square_ok=square_ok)\n    if len(imgs) == 1:\n        imgs = [imgs[0], copy.deepcopy(imgs[0])]\n        imgs[1]['idx'] = 1\n    if scenegraph_type == \"swin\":\n        scenegraph_type = scenegraph_type + \"-\" + str(winsize)\n    elif scenegraph_type == \"oneref\":\n        scenegraph_type = scenegraph_type + \"-\" + str(refid)\n\n    pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True)\n    output = inference(pairs, model, device, batch_size=1, verbose=not silent)\n\n    mode = GlobalAlignerMode.PointCloudOptimizer if len(imgs) > 2 else GlobalAlignerMode.PairViewer\n    scene = global_aligner(output, device=device, mode=mode, verbose=not silent)\n    lr = 0.01\n\n    if mode == GlobalAlignerMode.PointCloudOptimizer:\n        loss = scene.compute_global_alignment(init='mst', niter=niter, schedule=schedule, lr=lr)\n\n    outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,\n                                      clean_depth, transparent_cams, cam_size)\n\n    # also 
return rgb, depth and confidence imgs\n    # depth is normalized with the max value for all images\n    # we apply the jet colormap on the confidence maps\n    rgbimg = scene.imgs\n    depths = to_numpy(scene.get_depthmaps())\n    confs = to_numpy([c for c in scene.im_conf])\n    cmap = pl.get_cmap('jet')\n    depths_max = max([d.max() for d in depths])\n    depths = [d / depths_max for d in depths]\n    confs_max = max([d.max() for d in confs])\n    confs = [cmap(d / confs_max) for d in confs]\n\n    imgs = []\n    for i in range(len(rgbimg)):\n        imgs.append(rgbimg[i])\n        imgs.append(rgb(depths[i]))\n        imgs.append(rgb(confs[i]))\n\n    return scene, outfile, imgs\n\n\ndef set_scenegraph_options(inputfiles, winsize, refid, scenegraph_type):\n    num_files = len(inputfiles) if inputfiles is not None else 1\n    max_winsize = max(1, math.ceil((num_files - 1) / 2))\n    if scenegraph_type == \"swin\":\n        winsize = gradio.Slider(label=\"Scene Graph: Window Size\", value=max_winsize,\n                                minimum=1, maximum=max_winsize, step=1, visible=True)\n        refid = gradio.Slider(label=\"Scene Graph: Id\", value=0, minimum=0,\n                              maximum=num_files - 1, step=1, visible=False)\n    elif scenegraph_type == \"oneref\":\n        winsize = gradio.Slider(label=\"Scene Graph: Window Size\", value=max_winsize,\n                                minimum=1, maximum=max_winsize, step=1, visible=False)\n        refid = gradio.Slider(label=\"Scene Graph: Id\", value=0, minimum=0,\n                              maximum=num_files - 1, step=1, visible=True)\n    else:\n        winsize = gradio.Slider(label=\"Scene Graph: Window Size\", value=max_winsize,\n                                minimum=1, maximum=max_winsize, step=1, visible=False)\n        refid = gradio.Slider(label=\"Scene Graph: Id\", value=0, minimum=0,\n                              maximum=num_files - 1, step=1, visible=False)\n    return winsize, 
refid\n\n\ndef main_demo(tmpdirname, model, device, image_size, server_name, server_port, silent=False):\n    recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, model, device, silent, image_size)\n    model_from_scene_fun = functools.partial(get_3D_model_from_scene, tmpdirname, silent)\n    with gradio.Blocks(css=\"\"\".gradio-container {margin: 0 !important; min-width: 100%};\"\"\", title=\"DUSt3R Demo\") as demo:\n        # scene state is save so that you can change conf_thr, cam_size... without rerunning the inference\n        scene = gradio.State(None)\n        gradio.HTML('<h2 style=\"text-align: center;\">DUSt3R Demo</h2>')\n        with gradio.Column():\n            inputfiles = gradio.File(file_count=\"multiple\")\n            with gradio.Row():\n                schedule = gradio.Dropdown([\"linear\", \"cosine\"],\n                                           value='linear', label=\"schedule\", info=\"For global alignment!\")\n                niter = gradio.Number(value=300, precision=0, minimum=0, maximum=5000,\n                                      label=\"num_iterations\", info=\"For global alignment!\")\n                scenegraph_type = gradio.Dropdown([(\"complete: all possible image pairs\", \"complete\"),\n                                                   (\"swin: sliding window\", \"swin\"),\n                                                   (\"oneref: match one image with all\", \"oneref\")],\n                                                  value='complete', label=\"Scenegraph\",\n                                                  info=\"Define how to make pairs\",\n                                                  interactive=True)\n                winsize = gradio.Slider(label=\"Scene Graph: Window Size\", value=1,\n                                        minimum=1, maximum=1, step=1, visible=False)\n                refid = gradio.Slider(label=\"Scene Graph: Id\", value=0, minimum=0, maximum=0, step=1, visible=False)\n\n         
   run_btn = gradio.Button(\"Run\")\n\n            with gradio.Row():\n                # adjust the confidence threshold\n                min_conf_thr = gradio.Slider(label=\"min_conf_thr\", value=3.0, minimum=1.0, maximum=20, step=0.1)\n                # adjust the camera size in the output pointcloud\n                cam_size = gradio.Slider(label=\"cam_size\", value=0.05, minimum=0.001, maximum=0.1, step=0.001)\n            with gradio.Row():\n                as_pointcloud = gradio.Checkbox(value=False, label=\"As pointcloud\")\n                # two post process implemented\n                mask_sky = gradio.Checkbox(value=False, label=\"Mask sky\")\n                clean_depth = gradio.Checkbox(value=True, label=\"Clean-up depthmaps\")\n                transparent_cams = gradio.Checkbox(value=False, label=\"Transparent cameras\")\n\n            outmodel = gradio.Model3D()\n            outgallery = gradio.Gallery(label='rgb,depth,confidence', columns=3, height=\"100%\")\n\n            # events\n            scenegraph_type.change(set_scenegraph_options,\n                                   inputs=[inputfiles, winsize, refid, scenegraph_type],\n                                   outputs=[winsize, refid])\n            inputfiles.change(set_scenegraph_options,\n                              inputs=[inputfiles, winsize, refid, scenegraph_type],\n                              outputs=[winsize, refid])\n            run_btn.click(fn=recon_fun,\n                          inputs=[inputfiles, schedule, niter, min_conf_thr, as_pointcloud,\n                                  mask_sky, clean_depth, transparent_cams, cam_size,\n                                  scenegraph_type, winsize, refid],\n                          outputs=[scene, outmodel, outgallery])\n            min_conf_thr.release(fn=model_from_scene_fun,\n                                 inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,\n                                         clean_depth, transparent_cams, 
cam_size],\n                                 outputs=outmodel)\n            cam_size.change(fn=model_from_scene_fun,\n                            inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,\n                                    clean_depth, transparent_cams, cam_size],\n                            outputs=outmodel)\n            as_pointcloud.change(fn=model_from_scene_fun,\n                                 inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,\n                                         clean_depth, transparent_cams, cam_size],\n                                 outputs=outmodel)\n            mask_sky.change(fn=model_from_scene_fun,\n                            inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,\n                                    clean_depth, transparent_cams, cam_size],\n                            outputs=outmodel)\n            clean_depth.change(fn=model_from_scene_fun,\n                               inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,\n                                       clean_depth, transparent_cams, cam_size],\n                               outputs=outmodel)\n            transparent_cams.change(model_from_scene_fun,\n                                    inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,\n                                            clean_depth, transparent_cams, cam_size],\n                                    outputs=outmodel)\n    demo.launch(share=False, server_name=server_name, server_port=server_port)\n"
  },
  {
    "path": "dust3r/heads/__init__.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# head factory\n# --------------------------------------------------------\nfrom .linear_head import LinearPts3d\nfrom .dpt_head import create_dpt_head\n\n\ndef head_factory(head_type, output_mode, net, has_conf=False):\n    \"\"\" build a prediction head for the decoder\n    \"\"\"\n    if head_type == 'linear' and output_mode == 'pts3d':\n        return LinearPts3d(net, has_conf)\n    elif head_type == 'dpt' and output_mode == 'pts3d':\n        return create_dpt_head(net, has_conf=has_conf)\n    else:\n        raise NotImplementedError(f\"unexpected {head_type=} and {output_mode=}\")\n"
  },
  {
    "path": "dust3r/heads/dpt_head.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# dpt head implementation for DUST3R\n# Downstream heads assume inputs of size B x N x C (where N is the number of tokens) ;\n# or if it takes as input the output at every layer, the attribute return_all_layers should be set to True\n# the forward function also takes as input a dictionary img_info with key \"height\" and \"width\"\n# for PixelwiseTask, the output will be of dimension B x num_channels x H x W\n# --------------------------------------------------------\nfrom einops import rearrange\nfrom typing import List\nimport torch\nimport torch.nn as nn\nfrom dust3r.heads.postprocess import postprocess\nimport dust3r.utils.path_to_croco  # noqa: F401\nfrom models.dpt_block import DPTOutputAdapter  # noqa\n\n\nclass DPTOutputAdapter_fix(DPTOutputAdapter):\n    \"\"\"\n    Adapt croco's DPTOutputAdapter implementation for dust3r:\n    remove duplicated weights, and fix forward for dust3r\n    \"\"\"\n\n    def init(self, dim_tokens_enc=768):\n        super().init(dim_tokens_enc)\n        # these are duplicated weights\n        del self.act_1_postprocess\n        del self.act_2_postprocess\n        del self.act_3_postprocess\n        del self.act_4_postprocess\n\n    def forward(self, encoder_tokens: List[torch.Tensor], image_size=None):\n        assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'\n        # H, W = input_info['image_size']\n        image_size = self.image_size if image_size is None else image_size\n        H, W = image_size\n        # Number of patches in height and width\n        N_H = H // (self.stride_level * self.P_H)\n        N_W = W // (self.stride_level * self.P_W)\n\n        # Hook decoder onto 4 layers from specified ViT layers\n        layers = [encoder_tokens[hook] for hook in 
self.hooks]\n\n        # Extract only task-relevant tokens and ignore global tokens.\n        layers = [self.adapt_tokens(l) for l in layers]\n\n        # Reshape tokens to spatial representation\n        layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers]\n\n        layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]\n        # Project layers to chosen feature dim\n        layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)]\n\n        # Fuse layers using refinement stages\n        path_4 = self.scratch.refinenet4(layers[3])[:, :, :layers[2].shape[2], :layers[2].shape[3]]\n        path_3 = self.scratch.refinenet3(path_4, layers[2])\n        path_2 = self.scratch.refinenet2(path_3, layers[1])\n        path_1 = self.scratch.refinenet1(path_2, layers[0])\n\n        # Output head\n        out = self.head(path_1)\n\n        return out\n\n\nclass PixelwiseTaskWithDPT(nn.Module):\n    \"\"\" DPT module for dust3r, can return 3D points + confidence for all pixels\"\"\"\n\n    def __init__(self, *, n_cls_token=0, hooks_idx=None, dim_tokens=None,\n                 output_width_ratio=1, num_channels=1, postprocess=None, depth_mode=None, conf_mode=None, **kwargs):\n        super(PixelwiseTaskWithDPT, self).__init__()\n        self.return_all_layers = True  # backbone needs to return all layers\n        self.postprocess = postprocess\n        self.depth_mode = depth_mode\n        self.conf_mode = conf_mode\n\n        assert n_cls_token == 0, \"Not implemented\"\n        dpt_args = dict(output_width_ratio=output_width_ratio,\n                        num_channels=num_channels,\n                        **kwargs)\n        if hooks_idx is not None:\n            dpt_args.update(hooks=hooks_idx)\n        self.dpt = DPTOutputAdapter_fix(**dpt_args)\n        dpt_init_args = {} if dim_tokens is None else {'dim_tokens_enc': dim_tokens}\n        self.dpt.init(**dpt_init_args)\n\n    def forward(self, x, img_info):\n   
     out = self.dpt(x, image_size=(img_info[0], img_info[1]))\n        if self.postprocess:\n            out = self.postprocess(out, self.depth_mode, self.conf_mode)\n        return out\n\n\ndef create_dpt_head(net, has_conf=False):\n    \"\"\"\n    return PixelwiseTaskWithDPT for given net params\n    \"\"\"\n    assert net.dec_depth > 9\n    l2 = net.dec_depth\n    feature_dim = 256\n    last_dim = feature_dim//2\n    out_nchan = 3\n    ed = net.enc_embed_dim\n    dd = net.dec_embed_dim\n    return PixelwiseTaskWithDPT(num_channels=out_nchan + has_conf,\n                                feature_dim=feature_dim,\n                                last_dim=last_dim,\n                                hooks_idx=[0, l2*2//4, l2*3//4, l2],\n                                dim_tokens=[ed, dd, dd, dd],\n                                postprocess=postprocess,\n                                depth_mode=net.depth_mode,\n                                conf_mode=net.conf_mode,\n                                head_type='regression')\n"
  },
  {
    "path": "dust3r/heads/linear_head.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# linear head implementation for DUST3R\n# --------------------------------------------------------\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom dust3r.heads.postprocess import postprocess\n\n\nclass LinearPts3d (nn.Module):\n    \"\"\" \n    Linear head for dust3r\n    Each token outputs: - 16x16 3D points (+ confidence)\n    \"\"\"\n\n    def __init__(self, net, has_conf=False):\n        super().__init__()\n        self.patch_size = net.patch_embed.patch_size[0]\n        self.depth_mode = net.depth_mode\n        self.conf_mode = net.conf_mode\n        self.has_conf = has_conf\n\n        self.proj = nn.Linear(net.dec_embed_dim, (3 + has_conf)*self.patch_size**2)\n\n    def setup(self, croconet):\n        pass\n\n    def forward(self, decout, img_shape):\n        H, W = img_shape\n        tokens = decout[-1]\n        B, S, D = tokens.shape\n\n        # extract 3D points\n        feat = self.proj(tokens)  # B,S,D\n        feat = feat.transpose(-1, -2).view(B, -1, H//self.patch_size, W//self.patch_size)\n        feat = F.pixel_shuffle(feat, self.patch_size)  # B,3,H,W\n\n        # permute + norm depth\n        return postprocess(feat, self.depth_mode, self.conf_mode)\n"
  },
  {
    "path": "dust3r/heads/postprocess.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# post process function for all heads: extract 3D points/confidence from output\n# --------------------------------------------------------\nimport torch\n\n\ndef postprocess(out, depth_mode, conf_mode):\n    \"\"\"\n    extract 3D points/confidence from prediction head output\n    \"\"\"\n    fmap = out.permute(0, 2, 3, 1)  # B,H,W,3\n    res = dict(pts3d=reg_dense_depth(fmap[:, :, :, 0:3], mode=depth_mode))\n\n    if conf_mode is not None:\n        res['conf'] = reg_dense_conf(fmap[:, :, :, 3], mode=conf_mode)\n    return res\n\n\ndef reg_dense_depth(xyz, mode):\n    \"\"\"\n    extract 3D points from prediction head output\n    \"\"\"\n    mode, vmin, vmax = mode\n\n    no_bounds = (vmin == -float('inf')) and (vmax == float('inf'))\n    assert no_bounds\n\n    if mode == 'linear':\n        if no_bounds:\n            return xyz  # [-inf, +inf]\n        return xyz.clip(min=vmin, max=vmax)\n\n    # distance to origin\n    d = xyz.norm(dim=-1, keepdim=True)\n    xyz = xyz / d.clip(min=1e-8)\n\n    if mode == 'square':\n        return xyz * d.square()\n\n    if mode == 'exp':\n        return xyz * torch.expm1(d)\n\n    raise ValueError(f'bad {mode=}')\n\n\ndef reg_dense_conf(x, mode):\n    \"\"\"\n    extract confidence from prediction head output\n    \"\"\"\n    mode, vmin, vmax = mode\n    if mode == 'exp':\n        return vmin + x.exp().clip(max=vmax-vmin)\n    if mode == 'sigmoid':\n        return (vmax - vmin) * torch.sigmoid(x) + vmin\n    raise ValueError(f'bad {mode=}')\n"
  },
  {
    "path": "dust3r/image_pairs.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# utilities needed to load image pairs\n# --------------------------------------------------------\nimport numpy as np\nimport torch\n\n\ndef make_pairs(imgs, scene_graph='complete', prefilter=None, symmetrize=True):\n    pairs = []\n    if scene_graph == 'complete':  # complete graph\n        for i in range(len(imgs)):\n            for j in range(i):\n                pairs.append((imgs[i], imgs[j]))\n    elif scene_graph.startswith('swin'):\n        iscyclic = not scene_graph.endswith('noncyclic')\n        try:\n            winsize = int(scene_graph.split('-')[1])\n        except Exception as e:\n            winsize = 3\n        pairsid = set()\n        for i in range(len(imgs)):\n            for j in range(1, winsize + 1):\n                idx = (i + j)\n                if iscyclic:\n                    idx = idx % len(imgs)  # explicit loop closure\n                if idx >= len(imgs):\n                    continue\n                pairsid.add((i, idx) if i < idx else (idx, i))\n        for i, j in pairsid:\n            pairs.append((imgs[i], imgs[j]))\n    elif scene_graph.startswith('logwin'):\n        iscyclic = not scene_graph.endswith('noncyclic')\n        try:\n            winsize = int(scene_graph.split('-')[1])\n        except Exception as e:\n            winsize = 3\n        offsets = [2**i for i in range(winsize)]\n        pairsid = set()\n        for i in range(len(imgs)):\n            ixs_l = [i - off for off in offsets]\n            ixs_r = [i + off for off in offsets]\n            for j in ixs_l + ixs_r:\n                if iscyclic:\n                    j = j % len(imgs)  # Explicit loop closure\n                if j < 0 or j >= len(imgs) or j == i:\n                    continue\n                pairsid.add((i, j) if i < j else (j, i))\n     
   for i, j in pairsid:\n            pairs.append((imgs[i], imgs[j]))\n    elif scene_graph.startswith('oneref'):\n        refid = int(scene_graph.split('-')[1]) if '-' in scene_graph else 0\n        for j in range(len(imgs)):\n            if j != refid:\n                pairs.append((imgs[refid], imgs[j]))\n    if symmetrize:\n        pairs += [(img2, img1) for img1, img2 in pairs]\n\n    # now, remove edges\n    if isinstance(prefilter, str) and prefilter.startswith('seq'):\n        pairs = filter_pairs_seq(pairs, int(prefilter[3:]))\n\n    if isinstance(prefilter, str) and prefilter.startswith('cyc'):\n        pairs = filter_pairs_seq(pairs, int(prefilter[3:]), cyclic=True)\n\n    return pairs\n\n\ndef sel(x, kept):\n    if isinstance(x, dict):\n        return {k: sel(v, kept) for k, v in x.items()}\n    if isinstance(x, (torch.Tensor, np.ndarray)):\n        return x[kept]\n    if isinstance(x, (tuple, list)):\n        return type(x)([x[k] for k in kept])\n\n\ndef _filter_edges_seq(edges, seq_dis_thr, cyclic=False):\n    # number of images\n    n = max(max(e) for e in edges) + 1\n\n    kept = []\n    for e, (i, j) in enumerate(edges):\n        dis = abs(i - j)\n        if cyclic:\n            dis = min(dis, abs(i + n - j), abs(i - n - j))\n        if dis <= seq_dis_thr:\n            kept.append(e)\n    return kept\n\n\ndef filter_pairs_seq(pairs, seq_dis_thr, cyclic=False):\n    edges = [(img1['idx'], img2['idx']) for img1, img2 in pairs]\n    kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic)\n    return [pairs[i] for i in kept]\n\n\ndef filter_edges_seq(view1, view2, pred1, pred2, seq_dis_thr, cyclic=False):\n    edges = [(int(i), int(j)) for i, j in zip(view1['idx'], view2['idx'])]\n    kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic)\n    print(f'>> Filtering edges more than {seq_dis_thr} frames apart: kept {len(kept)}/{len(edges)} edges')\n    return sel(view1, kept), sel(view2, kept), sel(pred1, kept), sel(pred2, kept)\n"
  },
  {
    "path": "dust3r/inference.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# utilities needed for the inference\n# --------------------------------------------------------\nimport tqdm\nimport torch\nfrom dust3r.utils.device import to_cpu, collate_with_cat\nfrom dust3r.utils.misc import invalid_to_nans\nfrom dust3r.utils.geometry import depthmap_to_pts3d, geotrf\n\n\ndef _interleave_imgs(img1, img2):\n    res = {}\n    for key, value1 in img1.items():\n        value2 = img2[key]\n        if isinstance(value1, torch.Tensor):\n            value = torch.stack((value1, value2), dim=1).flatten(0, 1)\n        else:\n            value = [x for pair in zip(value1, value2) for x in pair]\n        res[key] = value\n    return res\n\n\ndef make_batch_symmetric(batch):\n    view1, view2 = batch\n    view1, view2 = (_interleave_imgs(view1, view2), _interleave_imgs(view2, view1))\n    return view1, view2\n\n\ndef loss_of_one_batch(batch, model, criterion, device, symmetrize_batch=False, use_amp=False, ret=None):\n    view1, view2 = batch\n    ignore_keys = set(['depthmap', 'dataset', 'label', 'instance', 'idx', 'true_shape', 'rng'])\n    for view in batch:\n        for name in view.keys():  # pseudo_focal\n            if name in ignore_keys:\n                continue\n            view[name] = view[name].to(device, non_blocking=True)\n\n    if symmetrize_batch:\n        view1, view2 = make_batch_symmetric(batch)\n\n    with torch.cuda.amp.autocast(enabled=bool(use_amp)):\n        pred1, pred2 = model(view1, view2)\n\n        # loss is supposed to be symmetric\n        with torch.cuda.amp.autocast(enabled=False):\n            loss = criterion(view1, view2, pred1, pred2) if criterion is not None else None\n\n    result = dict(view1=view1, view2=view2, pred1=pred1, pred2=pred2, loss=loss)\n    return result[ret] if ret else 
result\n\n\n@torch.no_grad()\ndef inference(pairs, model, device, batch_size=8, verbose=True):\n    if verbose:\n        print(f'>> Inference with model on {len(pairs)} image pairs')\n    result = []\n\n    # first, check if all images have the same size\n    multiple_shapes = not (check_if_same_size(pairs))\n    if multiple_shapes:  # force bs=1\n        batch_size = 1\n\n    for i in tqdm.trange(0, len(pairs), batch_size, disable=not verbose):\n        res = loss_of_one_batch(collate_with_cat(pairs[i:i + batch_size]), model, None, device)\n        result.append(to_cpu(res))\n\n    result = collate_with_cat(result, lists=multiple_shapes)\n\n    return result\n\n\ndef check_if_same_size(pairs):\n    shapes1 = [img1['img'].shape[-2:] for img1, img2 in pairs]\n    shapes2 = [img2['img'].shape[-2:] for img1, img2 in pairs]\n    return all(shapes1[0] == s for s in shapes1) and all(shapes2[0] == s for s in shapes2)\n\n\ndef get_pred_pts3d(gt, pred, use_pose=False):\n    if 'depth' in pred and 'pseudo_focal' in pred:\n        try:\n            pp = gt['camera_intrinsics'][..., :2, 2]\n        except KeyError:\n            pp = None\n        pts3d = depthmap_to_pts3d(**pred, pp=pp)\n\n    elif 'pts3d' in pred:\n        # pts3d from my camera\n        pts3d = pred['pts3d']\n\n    elif 'pts3d_in_other_view' in pred:\n        # pts3d from the other camera, already transformed\n        assert use_pose is True\n        return pred['pts3d_in_other_view']  # return!\n\n    if use_pose:\n        camera_pose = pred.get('camera_pose')\n        assert camera_pose is not None\n        pts3d = geotrf(camera_pose, pts3d)\n\n    return pts3d\n\n\ndef find_opt_scaling(gt_pts1, gt_pts2, pr_pts1, pr_pts2=None, fit_mode='weiszfeld_stop_grad', valid1=None, valid2=None):\n    assert gt_pts1.ndim == pr_pts1.ndim == 4\n    assert gt_pts1.shape == pr_pts1.shape\n    if gt_pts2 is not None:\n        assert gt_pts2.ndim == pr_pts2.ndim == 4\n        assert gt_pts2.shape == pr_pts2.shape\n\n    # 
concat the pointcloud\n    nan_gt_pts1 = invalid_to_nans(gt_pts1, valid1).flatten(1, 2)\n    nan_gt_pts2 = invalid_to_nans(gt_pts2, valid2).flatten(1, 2) if gt_pts2 is not None else None\n\n    pr_pts1 = invalid_to_nans(pr_pts1, valid1).flatten(1, 2)\n    pr_pts2 = invalid_to_nans(pr_pts2, valid2).flatten(1, 2) if pr_pts2 is not None else None\n\n    all_gt = torch.cat((nan_gt_pts1, nan_gt_pts2), dim=1) if gt_pts2 is not None else nan_gt_pts1\n    all_pr = torch.cat((pr_pts1, pr_pts2), dim=1) if pr_pts2 is not None else pr_pts1\n\n    dot_gt_pr = (all_pr * all_gt).sum(dim=-1)\n    dot_gt_gt = all_gt.square().sum(dim=-1)\n\n    if fit_mode.startswith('avg'):\n        # scaling = (all_pr / all_gt).view(B, -1).mean(dim=1)\n        scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1)\n    elif fit_mode.startswith('median'):\n        scaling = (dot_gt_pr / dot_gt_gt).nanmedian(dim=1).values\n    elif fit_mode.startswith('weiszfeld'):\n        # init scaling with l2 closed form\n        scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1)\n        # iterative re-weighted least-squares\n        for iter in range(10):\n            # re-weighting by inverse of distance\n            dis = (all_pr - scaling.view(-1, 1, 1) * all_gt).norm(dim=-1)\n            # print(dis.nanmean(-1))\n            w = dis.clip_(min=1e-8).reciprocal()\n            # update the scaling with the new weights\n            scaling = (w * dot_gt_pr).nanmean(dim=1) / (w * dot_gt_gt).nanmean(dim=1)\n    else:\n        raise ValueError(f'bad {fit_mode=}')\n\n    if fit_mode.endswith('stop_grad'):\n        scaling = scaling.detach()\n\n    scaling = scaling.clip(min=1e-3)\n    # assert scaling.isfinite().all(), bb()\n    return scaling\n"
  },
  {
    "path": "dust3r/losses.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# Implementation of DUSt3R training losses\n# --------------------------------------------------------\nfrom copy import copy, deepcopy\nimport torch\nimport torch.nn as nn\n\nfrom dust3r.inference import get_pred_pts3d, find_opt_scaling\nfrom dust3r.utils.geometry import inv, geotrf, normalize_pointcloud\nfrom dust3r.utils.geometry import get_joint_pointcloud_depth, get_joint_pointcloud_center_scale\n\n\ndef Sum(*losses_and_masks):\n    loss, mask = losses_and_masks[0]\n    if loss.ndim > 0:\n        # we are actually returning the loss for every pixels\n        return losses_and_masks\n    else:\n        # we are returning the global loss\n        for loss2, mask2 in losses_and_masks[1:]:\n            loss = loss + loss2\n        return loss\n\n\nclass BaseCriterion(nn.Module):\n    def __init__(self, reduction='mean'):\n        super().__init__()\n        self.reduction = reduction\n\n\nclass LLoss (BaseCriterion):\n    \"\"\" L-norm loss\n    \"\"\"\n\n    def forward(self, a, b):\n        assert a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 3, f'Bad shape = {a.shape}'\n        dist = self.distance(a, b)\n        assert dist.ndim == a.ndim - 1  # one dimension less\n        if self.reduction == 'none':\n            return dist\n        if self.reduction == 'sum':\n            return dist.sum()\n        if self.reduction == 'mean':\n            return dist.mean() if dist.numel() > 0 else dist.new_zeros(())\n        raise ValueError(f'bad {self.reduction=} mode')\n\n    def distance(self, a, b):\n        raise NotImplementedError()\n\n\nclass L21Loss (LLoss):\n    \"\"\" Euclidean distance between 3d points  \"\"\"\n\n    def distance(self, a, b):\n        return torch.norm(a - b, dim=-1)  # normalized L2 distance\n\n\nL21 = 
L21Loss()\n\n\nclass Criterion (nn.Module):\n    def __init__(self, criterion=None):\n        super().__init__()\n        assert isinstance(criterion, BaseCriterion), f'{criterion} is not a proper criterion!'\n        self.criterion = copy(criterion)\n\n    def get_name(self):\n        return f'{type(self).__name__}({self.criterion})'\n\n    def with_reduction(self, mode='none'):\n        res = loss = deepcopy(self)\n        while loss is not None:\n            assert isinstance(loss, Criterion)\n            loss.criterion.reduction = mode  # make it return the loss for each sample\n            loss = loss._loss2  # we assume loss is a Multiloss\n        return res\n\n\nclass MultiLoss (nn.Module):\n    \"\"\" Easily combinable losses (also keep track of individual loss values):\n        loss = MyLoss1() + 0.1*MyLoss2()\n    Usage:\n        Inherit from this class and override get_name() and compute_loss()\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n        self._alpha = 1\n        self._loss2 = None\n\n    def compute_loss(self, *args, **kwargs):\n        raise NotImplementedError()\n\n    def get_name(self):\n        raise NotImplementedError()\n\n    def __mul__(self, alpha):\n        assert isinstance(alpha, (int, float))\n        res = copy(self)\n        res._alpha = alpha\n        return res\n    __rmul__ = __mul__  # same\n\n    def __add__(self, loss2):\n        assert isinstance(loss2, MultiLoss)\n        res = cur = copy(self)\n        # find the end of the chain\n        while cur._loss2 is not None:\n            cur = cur._loss2\n        cur._loss2 = loss2\n        return res\n\n    def __repr__(self):\n        name = self.get_name()\n        if self._alpha != 1:\n            name = f'{self._alpha:g}*{name}'\n        if self._loss2:\n            name = f'{name} + {self._loss2}'\n        return name\n\n    def forward(self, *args, **kwargs):\n        loss = self.compute_loss(*args, **kwargs)\n        if isinstance(loss, tuple):\n 
           loss, details = loss\n        elif loss.ndim == 0:\n            details = {self.get_name(): float(loss)}\n        else:\n            details = {}\n        loss = loss * self._alpha\n\n        if self._loss2:\n            loss2, details2 = self._loss2(*args, **kwargs)\n            loss = loss + loss2\n            details |= details2\n\n        return loss, details\n\n\nclass Regr3D (Criterion, MultiLoss):\n    \"\"\" Ensure that all 3D points are correct.\n        Asymmetric loss: view1 is supposed to be the anchor.\n\n        P1 = RT1 @ D1\n        P2 = RT2 @ D2\n        loss1 = (I @ pred_D1) - (RT1^-1 @ RT1 @ D1)\n        loss2 = (RT21 @ pred_D2) - (RT1^-1 @ P2)\n              = (RT21 @ pred_D2) - (RT1^-1 @ RT2 @ D2)\n    \"\"\"\n\n    def __init__(self, criterion, norm_mode='avg_dis', gt_scale=False):\n        super().__init__(criterion)\n        self.norm_mode = norm_mode\n        self.gt_scale = gt_scale\n\n    def get_all_pts3d(self, gt1, gt2, pred1, pred2, dist_clip=None):\n        # everything is normalized w.r.t. 
camera of view1\n        in_camera1 = inv(gt1['camera_pose'])\n        gt_pts1 = geotrf(in_camera1, gt1['pts3d'])  # B,H,W,3\n        gt_pts2 = geotrf(in_camera1, gt2['pts3d'])  # B,H,W,3\n\n        valid1 = gt1['valid_mask'].clone()\n        valid2 = gt2['valid_mask'].clone()\n\n        if dist_clip is not None:\n            # points that are too far-away == invalid\n            dis1 = gt_pts1.norm(dim=-1)  # (B, H, W)\n            dis2 = gt_pts2.norm(dim=-1)  # (B, H, W)\n            valid1 = valid1 & (dis1 <= dist_clip)\n            valid2 = valid2 & (dis2 <= dist_clip)\n\n        pr_pts1 = get_pred_pts3d(gt1, pred1, use_pose=False)\n        pr_pts2 = get_pred_pts3d(gt2, pred2, use_pose=True)\n\n        # normalize 3d points\n        if self.norm_mode:\n            pr_pts1, pr_pts2 = normalize_pointcloud(pr_pts1, pr_pts2, self.norm_mode, valid1, valid2)\n        if self.norm_mode and not self.gt_scale:\n            gt_pts1, gt_pts2 = normalize_pointcloud(gt_pts1, gt_pts2, self.norm_mode, valid1, valid2)\n\n        return gt_pts1, gt_pts2, pr_pts1, pr_pts2, valid1, valid2, {}\n\n    def compute_loss(self, gt1, gt2, pred1, pred2, **kw):\n        gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring = \\\n            self.get_all_pts3d(gt1, gt2, pred1, pred2, **kw)\n        # loss on img1 side\n        l1 = self.criterion(pred_pts1[mask1], gt_pts1[mask1])\n        # loss on img2 side\n        l2 = self.criterion(pred_pts2[mask2], gt_pts2[mask2])\n        self_name = type(self).__name__\n        details = {self_name + '_pts3d_1': float(l1.mean()), self_name + '_pts3d_2': float(l2.mean())}\n        return Sum((l1, mask1), (l2, mask2)), (details | monitoring)\n\n\nclass ConfLoss (MultiLoss):\n    \"\"\" Weighted regression by learned confidence.\n        Assuming the input pixel_loss is a pixel-level regression loss.\n\n    Principle:\n        high confidence means high conf = 10  ==> conf_loss = x * 10 - alpha*log(10)\n        low  confidence means low  
conf = 0.1 ==> conf_loss = x / 10 + alpha*log(10)\n\n        alpha: hyperparameter\n    \"\"\"\n\n    def __init__(self, pixel_loss, alpha=1):\n        super().__init__()\n        assert alpha > 0\n        self.alpha = alpha\n        self.pixel_loss = pixel_loss.with_reduction('none')\n\n    def get_name(self):\n        return f'ConfLoss({self.pixel_loss})'\n\n    def get_conf_log(self, x):\n        return x, torch.log(x)\n\n    def compute_loss(self, gt1, gt2, pred1, pred2, **kw):\n        # compute per-pixel loss\n        ((loss1, msk1), (loss2, msk2)), details = self.pixel_loss(gt1, gt2, pred1, pred2, **kw)\n        if loss1.numel() == 0:\n            print('NO VALID POINTS in img1', force=True)\n        if loss2.numel() == 0:\n            print('NO VALID POINTS in img2', force=True)\n\n        # weight by confidence\n        conf1, log_conf1 = self.get_conf_log(pred1['conf'][msk1])\n        conf2, log_conf2 = self.get_conf_log(pred2['conf'][msk2])\n        conf_loss1 = loss1 * conf1 - self.alpha * log_conf1\n        conf_loss2 = loss2 * conf2 - self.alpha * log_conf2\n\n        # average + nan protection (in case of no valid pixels at all)\n        conf_loss1 = conf_loss1.mean() if conf_loss1.numel() > 0 else 0\n        conf_loss2 = conf_loss2.mean() if conf_loss2.numel() > 0 else 0\n\n        return conf_loss1 + conf_loss2, dict(conf_loss_1=float(conf_loss1), conf_loss2=float(conf_loss2), **details)\n\n\nclass Regr3D_ShiftInv (Regr3D):\n    \"\"\" Same as Regr3D but invariant to depth shift.\n    \"\"\"\n\n    def get_all_pts3d(self, gt1, gt2, pred1, pred2):\n        # compute unnormalized points\n        gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring = \\\n            super().get_all_pts3d(gt1, gt2, pred1, pred2)\n\n        # compute median depth\n        gt_z1, gt_z2 = gt_pts1[..., 2], gt_pts2[..., 2]\n        pred_z1, pred_z2 = pred_pts1[..., 2], pred_pts2[..., 2]\n        gt_shift_z = get_joint_pointcloud_depth(gt_z1, gt_z2, mask1, 
mask2)[:, None, None]\n        pred_shift_z = get_joint_pointcloud_depth(pred_z1, pred_z2, mask1, mask2)[:, None, None]\n\n        # subtract the median depth\n        gt_z1 -= gt_shift_z\n        gt_z2 -= gt_shift_z\n        pred_z1 -= pred_shift_z\n        pred_z2 -= pred_shift_z\n\n        # monitoring = dict(monitoring, gt_shift_z=gt_shift_z.mean().detach(), pred_shift_z=pred_shift_z.mean().detach())\n        return gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring\n\n\nclass Regr3D_ScaleInv (Regr3D):\n    \"\"\" Same as Regr3D but invariant to scale.\n        if gt_scale == True: enforce the prediction to take the same scale as GT\n    \"\"\"\n\n    def get_all_pts3d(self, gt1, gt2, pred1, pred2):\n        # compute depth-normalized points\n        gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring = super().get_all_pts3d(gt1, gt2, pred1, pred2)\n\n        # measure scene scale\n        _, gt_scale = get_joint_pointcloud_center_scale(gt_pts1, gt_pts2, mask1, mask2)\n        _, pred_scale = get_joint_pointcloud_center_scale(pred_pts1, pred_pts2, mask1, mask2)\n\n        # prevent the predicted scale from falling in a ridiculous range\n        pred_scale = pred_scale.clip(min=1e-3, max=1e3)\n\n        # rescale the point clouds\n        if self.gt_scale:\n            pred_pts1 *= gt_scale / pred_scale\n            pred_pts2 *= gt_scale / pred_scale\n            # monitoring = dict(monitoring, pred_scale=(pred_scale/gt_scale).mean())\n        else:\n            gt_pts1 /= gt_scale\n            gt_pts2 /= gt_scale\n            pred_pts1 /= pred_scale\n            pred_pts2 /= pred_scale\n            # monitoring = dict(monitoring, gt_scale=gt_scale.mean(), pred_scale=pred_scale.mean().detach())\n\n        return gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring\n\n\nclass Regr3D_ScaleShiftInv (Regr3D_ScaleInv, Regr3D_ShiftInv):\n    # calls Regr3D_ShiftInv first, then Regr3D_ScaleInv\n    pass\n"
  },
  {
    "path": "dust3r/model.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# DUSt3R model class\n# --------------------------------------------------------\nfrom copy import deepcopy\nimport torch\nimport os\nfrom packaging import version\nimport huggingface_hub\n\nfrom .utils.misc import fill_default_args, freeze_all_params, is_symmetrized, interleave, transpose_to_landscape\nfrom .heads import head_factory\nfrom dust3r.patch_embed import get_patch_embed\n\nimport dust3r.utils.path_to_croco  # noqa: F401\nfrom models.croco import CroCoNet  # noqa\n\ninf = float('inf')\n\nhf_version_number = huggingface_hub.__version__\nassert version.parse(hf_version_number) >= version.parse(\"0.22.0\"), (\"Outdated huggingface_hub version, \"\n                                                                     \"please reinstall requirements.txt\")\n\n\ndef load_model(model_path, device, verbose=True):\n    if verbose:\n        print('... 
loading model from', model_path)\n    ckpt = torch.load(model_path, map_location='cpu')\n    args = ckpt['args'].model.replace(\"ManyAR_PatchEmbed\", \"PatchEmbedDust3R\")\n    if 'landscape_only' not in args:\n        args = args[:-1] + ', landscape_only=False)'\n    else:\n        args = args.replace(\" \", \"\").replace('landscape_only=True', 'landscape_only=False')\n    assert \"landscape_only=False\" in args\n    if verbose:\n        print(f\"instantiating : {args}\")\n    net = eval(args)\n    s = net.load_state_dict(ckpt['model'], strict=False)\n    if verbose:\n        print(s)\n    return net.to(device)\n\n\nclass AsymmetricCroCo3DStereo (\n    CroCoNet,\n    huggingface_hub.PyTorchModelHubMixin,\n    library_name=\"dust3r\",\n    repo_url=\"https://github.com/naver/dust3r\",\n    tags=[\"image-to-3d\"],\n):\n    \"\"\" Two siamese encoders, followed by two decoders.\n    The goal is to output 3d points directly, both images in view1's frame\n    (hence the asymmetry).   \n    \"\"\"\n\n    def __init__(self,\n                 output_mode='pts3d',\n                 head_type='linear',\n                 depth_mode=('exp', -inf, inf),\n                 conf_mode=('exp', 1, inf),\n                 freeze='none',\n                 landscape_only=True,\n                 patch_embed_cls='PatchEmbedDust3R',  # PatchEmbedDust3R or ManyAR_PatchEmbed\n                 **croco_kwargs):\n        self.patch_embed_cls = patch_embed_cls\n        self.croco_args = fill_default_args(croco_kwargs, super().__init__)\n        super().__init__(**croco_kwargs)\n\n        # dust3r specific initialization\n        self.dec_blocks2 = deepcopy(self.dec_blocks)\n        self.set_downstream_head(output_mode, head_type, landscape_only, depth_mode, conf_mode, **croco_kwargs)\n        self.set_freeze(freeze)\n\n    @classmethod\n    def from_pretrained(cls, pretrained_model_name_or_path, **kw):\n        if os.path.isfile(pretrained_model_name_or_path):\n            return 
load_model(pretrained_model_name_or_path, device='cpu')\n        else:\n            try:\n                model = super(AsymmetricCroCo3DStereo, cls).from_pretrained(pretrained_model_name_or_path, **kw)\n            except TypeError as e:\n                raise Exception(f'tried to load {pretrained_model_name_or_path} from huggingface, but failed')\n            return model\n\n    def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):\n        self.patch_size = patch_size\n        self.patch_embed = get_patch_embed(self.patch_embed_cls, img_size, patch_size, enc_embed_dim)\n\n    def load_state_dict(self, ckpt, **kw):\n        # duplicate all weights for the second decoder if not present\n        new_ckpt = dict(ckpt)\n        if not any(k.startswith('dec_blocks2') for k in ckpt):\n            for key, value in ckpt.items():\n                if key.startswith('dec_blocks'):\n                    new_ckpt[key.replace('dec_blocks', 'dec_blocks2')] = value\n        return super().load_state_dict(new_ckpt, **kw)\n\n    def set_freeze(self, freeze):  # this is for use by downstream models\n        self.freeze = freeze\n        to_be_frozen = {\n            'none': [],\n            'mask': [self.mask_token],\n            'encoder': [self.mask_token, self.patch_embed, self.enc_blocks],\n        }\n        freeze_all_params(to_be_frozen[freeze])\n\n    def _set_prediction_head(self, *args, **kwargs):\n        \"\"\" No prediction head \"\"\"\n        return\n\n    def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode, conf_mode, patch_size, img_size,\n                            **kw):\n        assert img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0, \\\n            f'{img_size=} must be multiple of {patch_size=}'\n        self.output_mode = output_mode\n        self.head_type = head_type\n        self.depth_mode = depth_mode\n        self.conf_mode = conf_mode\n        # allocate heads\n        
self.downstream_head1 = head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))\n        self.downstream_head2 = head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))\n        # magic wrapper\n        self.head1 = transpose_to_landscape(self.downstream_head1, activate=landscape_only)\n        self.head2 = transpose_to_landscape(self.downstream_head2, activate=landscape_only)\n\n    def _encode_image(self, image, true_shape):\n        # embed the image into patches  (x has size B x Npatches x C)\n        x, pos = self.patch_embed(image, true_shape=true_shape)\n\n        # add positional embedding without cls token\n        assert self.enc_pos_embed is None\n\n        # now apply the transformer encoder and normalization\n        for blk in self.enc_blocks:\n            x = blk(x, pos)\n\n        x = self.enc_norm(x)\n        return x, pos, None\n\n    def _encode_image_pairs(self, img1, img2, true_shape1, true_shape2):\n        if img1.shape[-2:] == img2.shape[-2:]:\n            out, pos, _ = self._encode_image(torch.cat((img1, img2), dim=0),\n                                             torch.cat((true_shape1, true_shape2), dim=0))\n            out, out2 = out.chunk(2, dim=0)\n            pos, pos2 = pos.chunk(2, dim=0)\n        else:\n            out, pos, _ = self._encode_image(img1, true_shape1)\n            out2, pos2, _ = self._encode_image(img2, true_shape2)\n        return out, out2, pos, pos2\n\n    def _encode_symmetrized(self, view1, view2):\n        img1 = view1['img']\n        img2 = view2['img']\n        B = img1.shape[0]\n        # Recover true_shape when available, otherwise assume that the img shape is the true one\n        shape1 = view1.get('true_shape', torch.tensor(img1.shape[-2:])[None].repeat(B, 1))\n        shape2 = view2.get('true_shape', torch.tensor(img2.shape[-2:])[None].repeat(B, 1))\n        # warning! 
maybe the images have different portrait/landscape orientations\n\n        if is_symmetrized(view1, view2):\n            # computing half of forward pass!'\n            feat1, feat2, pos1, pos2 = self._encode_image_pairs(img1[::2], img2[::2], shape1[::2], shape2[::2])\n            feat1, feat2 = interleave(feat1, feat2)\n            pos1, pos2 = interleave(pos1, pos2)\n        else:\n            feat1, feat2, pos1, pos2 = self._encode_image_pairs(img1, img2, shape1, shape2)\n\n        return (shape1, shape2), (feat1, feat2), (pos1, pos2)\n\n    def _decoder(self, f1, pos1, f2, pos2):\n        final_output = [(f1, f2)]  # before projection\n\n        # project to decoder dim\n        f1 = self.decoder_embed(f1)\n        f2 = self.decoder_embed(f2)\n\n        final_output.append((f1, f2))\n        for blk1, blk2 in zip(self.dec_blocks, self.dec_blocks2):\n            # img1 side\n            f1, _ = blk1(*final_output[-1][::+1], pos1, pos2)\n            # img2 side\n            f2, _ = blk2(*final_output[-1][::-1], pos2, pos1)\n            # store the result\n            final_output.append((f1, f2))\n\n        # normalize last output\n        del final_output[1]  # duplicate with final_output[0]\n        final_output[-1] = tuple(map(self.dec_norm, final_output[-1]))\n        return zip(*final_output)\n\n    def _downstream_head(self, head_num, decout, img_shape):\n        B, S, D = decout[-1].shape\n        # img_shape = tuple(map(int, img_shape))\n        head = getattr(self, f'head{head_num}')\n        return head(decout, img_shape)\n\n    def forward(self, view1, view2):\n        # encode the two images --> B,S,D\n        (shape1, shape2), (feat1, feat2), (pos1, pos2) = self._encode_symmetrized(view1, view2)\n\n        # combine all ref images into object-centric representation\n        dec1, dec2 = self._decoder(feat1, pos1, feat2, pos2)\n\n        with torch.cuda.amp.autocast(enabled=False):\n            res1 = self._downstream_head(1, [tok.float() for tok in 
dec1], shape1)\n            res2 = self._downstream_head(2, [tok.float() for tok in dec2], shape2)\n\n        res2['pts3d_in_other_view'] = res2.pop('pts3d')  # predict view2's pts3d in view1's frame\n        return res1, res2\n"
  },
  {
    "path": "dust3r/optim_factory.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# optimization functions\n# --------------------------------------------------------\n\n\ndef adjust_learning_rate_by_lr(optimizer, lr):\n    for param_group in optimizer.param_groups:\n        if \"lr_scale\" in param_group:\n            param_group[\"lr\"] = lr * param_group[\"lr_scale\"]\n        else:\n            param_group[\"lr\"] = lr\n"
  },
  {
    "path": "dust3r/patch_embed.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# PatchEmbed implementation for DUST3R,\n# in particular ManyAR_PatchEmbed that Handle images with non-square aspect ratio\n# --------------------------------------------------------\nimport torch\nimport dust3r.utils.path_to_croco  # noqa: F401\nfrom models.blocks import PatchEmbed  # noqa\n\n\ndef get_patch_embed(patch_embed_cls, img_size, patch_size, enc_embed_dim):\n    assert patch_embed_cls in ['PatchEmbedDust3R', 'ManyAR_PatchEmbed']\n    patch_embed = eval(patch_embed_cls)(img_size, patch_size, 3, enc_embed_dim)\n    return patch_embed\n\n\nclass PatchEmbedDust3R(PatchEmbed):\n    def forward(self, x, **kw):\n        B, C, H, W = x.shape\n        assert H % self.patch_size[0] == 0, f\"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]}).\"\n        assert W % self.patch_size[1] == 0, f\"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]}).\"\n        x = self.proj(x)\n        pos = self.position_getter(B, x.size(2), x.size(3), x.device)\n        if self.flatten:\n            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC\n        x = self.norm(x)\n        return x, pos\n\n\nclass ManyAR_PatchEmbed (PatchEmbed):\n    \"\"\" Handle images with non-square aspect ratio.\n        All images in the same batch have the same aspect ratio.\n        true_shape = [(height, width) ...] 
indicates the actual shape of each image.\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):\n        self.embed_dim = embed_dim\n        super().__init__(img_size, patch_size, in_chans, embed_dim, norm_layer, flatten)\n\n    def forward(self, img, true_shape):\n        B, C, H, W = img.shape\n        assert W >= H, f'img should be in landscape mode, but got {W=} {H=}'\n        assert H % self.patch_size[0] == 0, f\"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]}).\"\n        assert W % self.patch_size[1] == 0, f\"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]}).\"\n        assert true_shape.shape == (B, 2), f\"true_shape has the wrong shape={true_shape.shape}\"\n\n        # size expressed in tokens (patch_size is (height, width), matching the asserts above)\n        H //= self.patch_size[0]\n        W //= self.patch_size[1]\n        n_tokens = H * W\n\n        height, width = true_shape.T\n        is_landscape = (width >= height)\n        is_portrait = ~is_landscape\n\n        # allocate result\n        x = img.new_zeros((B, n_tokens, self.embed_dim))\n        pos = img.new_zeros((B, n_tokens, 2), dtype=torch.int64)\n\n        # linear projection, transposed if necessary\n        x[is_landscape] = self.proj(img[is_landscape]).permute(0, 2, 3, 1).flatten(1, 2).float()\n        x[is_portrait] = self.proj(img[is_portrait].swapaxes(-1, -2)).permute(0, 2, 3, 1).flatten(1, 2).float()\n\n        pos[is_landscape] = self.position_getter(1, H, W, pos.device)\n        pos[is_portrait] = self.position_getter(1, W, H, pos.device)\n\n        x = self.norm(x)\n        return x, pos\n"
  },
  {
    "path": "dust3r/post_process.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# utilities for interpreting the DUST3R output\n# --------------------------------------------------------\nimport numpy as np\nimport torch\nfrom dust3r.utils.geometry import xy_grid\n\n\ndef estimate_focal_knowing_depth(pts3d, pp, focal_mode='median', min_focal=0., max_focal=np.inf):\n    \"\"\" Reprojection method, for when the absolute depth is known:\n        1) estimate the camera focal using a robust estimator\n        2) reproject points onto true rays, minimizing a certain error\n    \"\"\"\n    B, H, W, THREE = pts3d.shape\n    assert THREE == 3\n\n    # centered pixel grid\n    pixels = xy_grid(W, H, device=pts3d.device).view(1, -1, 2) - pp.view(-1, 1, 2)  # B,HW,2\n    pts3d = pts3d.flatten(1, 2)  # (B, HW, 3)\n\n    if focal_mode == 'median':\n        with torch.no_grad():\n            # direct estimation of focal\n            u, v = pixels.unbind(dim=-1)\n            x, y, z = pts3d.unbind(dim=-1)\n            fx_votes = (u * z) / x\n            fy_votes = (v * z) / y\n\n            # assume square pixels, hence same focal for X and Y\n            f_votes = torch.cat((fx_votes.view(B, -1), fy_votes.view(B, -1)), dim=-1)\n            focal = torch.nanmedian(f_votes, dim=-1).values\n\n    elif focal_mode == 'weiszfeld':\n        # init focal with l2 closed form\n        # we try to find focal = argmin Sum | pixel - focal * (x,y)/z|\n        xy_over_z = (pts3d[..., :2] / pts3d[..., 2:3]).nan_to_num(posinf=0, neginf=0)  # homogeneous (x,y,1)\n\n        dot_xy_px = (xy_over_z * pixels).sum(dim=-1)\n        dot_xy_xy = xy_over_z.square().sum(dim=-1)\n\n        focal = dot_xy_px.mean(dim=1) / dot_xy_xy.mean(dim=1)\n\n        # iterative re-weighted least-squares\n        for iter in range(10):\n            # re-weighting by inverse of distance\n        
    dis = (pixels - focal.view(-1, 1, 1) * xy_over_z).norm(dim=-1)\n            # print(dis.nanmean(-1))\n            w = dis.clip(min=1e-8).reciprocal()\n            # update the scaling with the new weights\n            focal = (w * dot_xy_px).mean(dim=1) / (w * dot_xy_xy).mean(dim=1)\n    else:\n        raise ValueError(f'bad {focal_mode=}')\n\n    focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2))  # size / 1.1547005383792515\n    focal = focal.clip(min=min_focal*focal_base, max=max_focal*focal_base)\n    # print(focal)\n    return focal\n"
  },
  {
    "path": "dust3r/training.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# training code for DUSt3R\n# --------------------------------------------------------\n# References:\n# MAE: https://github.com/facebookresearch/mae\n# DeiT: https://github.com/facebookresearch/deit\n# BEiT: https://github.com/microsoft/unilm/tree/master/beit\n# --------------------------------------------------------\nimport argparse\nimport datetime\nimport json\nimport numpy as np\nimport os\nimport sys\nimport time\nimport math\nfrom collections import defaultdict\nfrom pathlib import Path\nfrom typing import Sized\n\nimport torch\nimport torch.backends.cudnn as cudnn\nfrom torch.utils.tensorboard import SummaryWriter\ntorch.backends.cuda.matmul.allow_tf32 = True  # for gpu >= Ampere and pytorch >= 1.12\n\nfrom dust3r.model import AsymmetricCroCo3DStereo, inf  # noqa: F401, needed when loading the model\nfrom dust3r.datasets import get_data_loader  # noqa\nfrom dust3r.losses import *  # noqa: F401, needed when loading the model\nfrom dust3r.inference import loss_of_one_batch  # noqa\n\nimport dust3r.utils.path_to_croco  # noqa: F401\nimport croco.utils.misc as misc  # noqa\nfrom croco.utils.misc import NativeScalerWithGradNormCount as NativeScaler  # noqa\n\n\ndef get_args_parser():\n    parser = argparse.ArgumentParser('DUST3R training', add_help=False)\n    # model and criterion\n    parser.add_argument('--model', default=\"AsymmetricCroCo3DStereo(patch_embed_cls='ManyAR_PatchEmbed')\",\n                        type=str, help=\"string containing the model to build\")\n    parser.add_argument('--pretrained', default=None, help='path of a starting checkpoint')\n    parser.add_argument('--train_criterion', default=\"ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)\",\n                        type=str, help=\"train criterion\")\n    
parser.add_argument('--test_criterion', default=None, type=str, help=\"test criterion\")\n\n    # dataset\n    parser.add_argument('--train_dataset', required=True, type=str, help=\"training set\")\n    parser.add_argument('--test_dataset', default='[None]', type=str, help=\"testing set\")\n\n    # training\n    parser.add_argument('--seed', default=0, type=int, help=\"Random seed\")\n    parser.add_argument('--batch_size', default=64, type=int,\n                        help=\"Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)\")\n    parser.add_argument('--accum_iter', default=1, type=int,\n                        help=\"Accumulate gradient iterations (for increasing the effective batch size under memory constraints)\")\n    parser.add_argument('--epochs', default=800, type=int, help=\"Maximum number of epochs for the scheduler\")\n\n    parser.add_argument('--weight_decay', type=float, default=0.05, help=\"weight decay (default: 0.05)\")\n    parser.add_argument('--lr', type=float, default=None, metavar='LR', help='learning rate (absolute lr)')\n    parser.add_argument('--blr', type=float, default=1.5e-4, metavar='LR',\n                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')\n    parser.add_argument('--min_lr', type=float, default=0., metavar='LR',\n                        help='lower lr bound for cyclic schedulers that hit 0')\n    parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', help='epochs to warmup LR')\n\n    parser.add_argument('--amp', type=int, default=0,\n                        choices=[0, 1], help=\"Use Automatic Mixed Precision for pretraining\")\n    parser.add_argument(\"--disable_cudnn_benchmark\", action='store_true', default=False,\n                        help=\"set cudnn.benchmark = False\")\n    # others\n    parser.add_argument('--num_workers', default=8, type=int)\n    parser.add_argument('--world_size', default=1, type=int, help='number of 
distributed processes')\n    parser.add_argument('--local_rank', default=-1, type=int)\n    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')\n\n    parser.add_argument('--eval_freq', type=int, default=1, help='Test loss evaluation frequency')\n    parser.add_argument('--save_freq', default=1, type=int,\n                        help='frequency (number of epochs) to save checkpoint in checkpoint-last.pth')\n    parser.add_argument('--keep_freq', default=20, type=int,\n                        help='frequency (number of epochs) to save checkpoint in checkpoint-%%d.pth')\n    parser.add_argument('--print_freq', default=20, type=int,\n                        help='frequency (number of iterations) to print info while training')\n\n    # output dir\n    parser.add_argument('--output_dir', default='./output/', type=str, help=\"path where to save the output\")\n    return parser\n\n\ndef train(args):\n    misc.init_distributed_mode(args)\n    global_rank = misc.get_rank()\n    world_size = misc.get_world_size()\n\n    print(\"output_dir: \" + args.output_dir)\n    if args.output_dir:\n        Path(args.output_dir).mkdir(parents=True, exist_ok=True)\n\n    # auto resume\n    last_ckpt_fname = os.path.join(args.output_dir, 'checkpoint-last.pth')\n    args.resume = last_ckpt_fname if os.path.isfile(last_ckpt_fname) else None\n\n    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))\n    print(\"{}\".format(args).replace(', ', ',\\n'))\n\n    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n    device = torch.device(device)\n\n    # fix the seed\n    seed = args.seed + misc.get_rank()\n    torch.manual_seed(seed)\n    np.random.seed(seed)\n\n    cudnn.benchmark = not args.disable_cudnn_benchmark\n\n    # training dataset and loader\n    print('Building train dataset {:s}'.format(args.train_dataset))\n    #  dataset and loader\n    data_loader_train = build_dataset(args.train_dataset, 
args.batch_size, args.num_workers, test=False)\n    print('Building test dataset {:s}'.format(args.test_dataset))\n    data_loader_test = {dataset.split('(')[0]: build_dataset(dataset, args.batch_size, args.num_workers, test=True)\n                        for dataset in args.test_dataset.split('+')}\n\n    # model\n    print('Loading model: {:s}'.format(args.model))\n    model = eval(args.model)\n    print(f'>> Creating train criterion = {args.train_criterion}')\n    train_criterion = eval(args.train_criterion).to(device)\n    print(f'>> Creating test criterion = {args.test_criterion or args.train_criterion}')\n    # fall back to the train criterion when no test criterion is given\n    test_criterion = eval(args.test_criterion or args.train_criterion).to(device)\n\n    model.to(device)\n    model_without_ddp = model\n    print(\"Model = %s\" % str(model_without_ddp))\n\n    if args.pretrained and not args.resume:\n        print('Loading pretrained: ', args.pretrained)\n        ckpt = torch.load(args.pretrained, map_location=device)\n        print(model.load_state_dict(ckpt['model'], strict=False))\n        del ckpt  # in case it occupies memory\n\n    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()\n    if args.lr is None:  # only base_lr is specified\n        args.lr = args.blr * eff_batch_size / 256\n    print(\"base lr: %.2e\" % (args.lr * 256 / eff_batch_size))\n    print(\"actual lr: %.2e\" % args.lr)\n    print(\"accumulate grad iterations: %d\" % args.accum_iter)\n    print(\"effective batch size: %d\" % eff_batch_size)\n\n    if args.distributed:\n        model = torch.nn.parallel.DistributedDataParallel(\n            model, device_ids=[args.gpu], find_unused_parameters=True, static_graph=True)\n        model_without_ddp = model.module\n\n    # following timm: set wd as 0 for bias and norm layers\n    param_groups = misc.get_parameter_groups(model_without_ddp, args.weight_decay)\n    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))\n    print(optimizer)\n    loss_scaler = 
NativeScaler()\n\n    def write_log_stats(epoch, train_stats, test_stats):\n        if misc.is_main_process():\n            if log_writer is not None:\n                log_writer.flush()\n\n            log_stats = dict(epoch=epoch, **{f'train_{k}': v for k, v in train_stats.items()})\n            for test_name in data_loader_test:\n                if test_name not in test_stats:\n                    continue\n                log_stats.update({test_name + '_' + k: v for k, v in test_stats[test_name].items()})\n\n            with open(os.path.join(args.output_dir, \"log.txt\"), mode=\"a\", encoding=\"utf-8\") as f:\n                f.write(json.dumps(log_stats) + \"\\n\")\n\n    def save_model(epoch, fname, best_so_far):\n        misc.save_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer,\n                        loss_scaler=loss_scaler, epoch=epoch, fname=fname, best_so_far=best_so_far)\n\n    best_so_far = misc.load_model(args=args, model_without_ddp=model_without_ddp,\n                                  optimizer=optimizer, loss_scaler=loss_scaler)\n    if best_so_far is None:\n        best_so_far = float('inf')\n    if global_rank == 0 and args.output_dir is not None:\n        log_writer = SummaryWriter(log_dir=args.output_dir)\n    else:\n        log_writer = None\n\n    print(f\"Start training for {args.epochs} epochs\")\n    start_time = time.time()\n    train_stats = test_stats = {}\n    for epoch in range(args.start_epoch, args.epochs + 1):\n\n        # Save immediately the last checkpoint\n        if epoch > args.start_epoch:\n            if args.save_freq and epoch % args.save_freq == 0 or epoch == args.epochs:\n                save_model(epoch - 1, 'last', best_so_far)\n\n        # Test on multiple datasets\n        new_best = False\n        if (epoch > 0 and args.eval_freq > 0 and epoch % args.eval_freq == 0):\n            test_stats = {}\n            for test_name, testset in data_loader_test.items():\n                stats = 
test_one_epoch(model, test_criterion, testset,\n                                       device, epoch, log_writer=log_writer, args=args, prefix=test_name)\n                test_stats[test_name] = stats\n\n                # Save best of all\n                if stats['loss_med'] < best_so_far:\n                    best_so_far = stats['loss_med']\n                    new_best = True\n\n        # Save more stuff\n        write_log_stats(epoch, train_stats, test_stats)\n\n        if epoch > args.start_epoch:\n            if args.keep_freq and epoch % args.keep_freq == 0:\n                save_model(epoch - 1, str(epoch), best_so_far)\n            if new_best:\n                save_model(epoch - 1, 'best', best_so_far)\n        if epoch >= args.epochs:\n            break  # exit after writing last test to disk\n\n        # Train\n        train_stats = train_one_epoch(\n            model, train_criterion, data_loader_train,\n            optimizer, device, epoch, loss_scaler,\n            log_writer=log_writer,\n            args=args)\n\n    total_time = time.time() - start_time\n    total_time_str = str(datetime.timedelta(seconds=int(total_time)))\n    print('Training time {}'.format(total_time_str))\n\n    save_final_model(args, args.epochs, model_without_ddp, best_so_far=best_so_far)\n\n\ndef save_final_model(args, epoch, model_without_ddp, best_so_far=None):\n    output_dir = Path(args.output_dir)\n    checkpoint_path = output_dir / 'checkpoint-final.pth'\n    to_save = {\n        'args': args,\n        'model': model_without_ddp if isinstance(model_without_ddp, dict) else model_without_ddp.cpu().state_dict(),\n        'epoch': epoch\n    }\n    if best_so_far is not None:\n        to_save['best_so_far'] = best_so_far\n    print(f'>> Saving model to {checkpoint_path} ...')\n    misc.save_on_master(to_save, checkpoint_path)\n\n\ndef build_dataset(dataset, batch_size, num_workers, test=False):\n    split = ['Train', 'Test'][test]\n    print(f'Building {split} Data loader 
for dataset: ', dataset)\n    loader = get_data_loader(dataset,\n                             batch_size=batch_size,\n                             num_workers=num_workers,\n                             pin_mem=True,\n                             shuffle=not (test),\n                             drop_last=not (test))\n\n    print(f\"{split} dataset length: \", len(loader))\n    return loader\n\n\ndef train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,\n                    data_loader: Sized, optimizer: torch.optim.Optimizer,\n                    device: torch.device, epoch: int, loss_scaler,\n                    args,\n                    log_writer=None):\n    assert torch.backends.cuda.matmul.allow_tf32 == True\n\n    model.train(True)\n    metric_logger = misc.MetricLogger(delimiter=\"  \")\n    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))\n    header = 'Epoch: [{}]'.format(epoch)\n    accum_iter = args.accum_iter\n\n    if log_writer is not None:\n        print('log_dir: {}'.format(log_writer.log_dir))\n\n    if hasattr(data_loader, 'dataset') and hasattr(data_loader.dataset, 'set_epoch'):\n        data_loader.dataset.set_epoch(epoch)\n    if hasattr(data_loader, 'sampler') and hasattr(data_loader.sampler, 'set_epoch'):\n        data_loader.sampler.set_epoch(epoch)\n\n    optimizer.zero_grad()\n\n    for data_iter_step, batch in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):\n        epoch_f = epoch + data_iter_step / len(data_loader)\n\n        # we use a per iteration (instead of per epoch) lr scheduler\n        if data_iter_step % accum_iter == 0:\n            misc.adjust_learning_rate(optimizer, epoch_f, args)\n\n        loss_tuple = loss_of_one_batch(batch, model, criterion, device,\n                                       symmetrize_batch=True,\n                                       use_amp=bool(args.amp), ret='loss')\n        loss, loss_details = loss_tuple  # criterion 
returns two values\n        loss_value = float(loss)\n\n        if not math.isfinite(loss_value):\n            print(\"Loss is {}, stopping training\".format(loss_value), force=True)\n            sys.exit(1)\n\n        loss /= accum_iter\n        loss_scaler(loss, optimizer, parameters=model.parameters(),\n                    update_grad=(data_iter_step + 1) % accum_iter == 0)\n        if (data_iter_step + 1) % accum_iter == 0:\n            optimizer.zero_grad()\n\n        del loss\n        del batch\n\n        lr = optimizer.param_groups[0][\"lr\"]\n        metric_logger.update(epoch=epoch_f)\n        metric_logger.update(lr=lr)\n        metric_logger.update(loss=loss_value, **loss_details)\n\n        if (data_iter_step + 1) % accum_iter == 0 and ((data_iter_step + 1) % (accum_iter * args.print_freq)) == 0:\n            loss_value_reduce = misc.all_reduce_mean(loss_value)  # MUST BE EXECUTED BY ALL NODES\n            if log_writer is None:\n                continue\n            \"\"\" We use epoch_1000x as the x-axis in tensorboard.\n            This calibrates different curves when batch size changes.\n            \"\"\"\n            epoch_1000x = int(epoch_f * 1000)\n            log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)\n            log_writer.add_scalar('train_lr', lr, epoch_1000x)\n            log_writer.add_scalar('train_iter', epoch_1000x, epoch_1000x)\n            for name, val in loss_details.items():\n                log_writer.add_scalar('train_' + name, val, epoch_1000x)\n\n    # gather the stats from all processes\n    metric_logger.synchronize_between_processes()\n    print(\"Averaged stats:\", metric_logger)\n    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}\n\n\n@torch.no_grad()\ndef test_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,\n                   data_loader: Sized, device: torch.device, epoch: int,\n                   args, log_writer=None, prefix='test'):\n\n    
model.eval()\n    metric_logger = misc.MetricLogger(delimiter=\"  \")\n    metric_logger.meters = defaultdict(lambda: misc.SmoothedValue(window_size=9**9))\n    header = 'Test Epoch: [{}]'.format(epoch)\n\n    if log_writer is not None:\n        print('log_dir: {}'.format(log_writer.log_dir))\n\n    if hasattr(data_loader, 'dataset') and hasattr(data_loader.dataset, 'set_epoch'):\n        data_loader.dataset.set_epoch(epoch)\n    if hasattr(data_loader, 'sampler') and hasattr(data_loader.sampler, 'set_epoch'):\n        data_loader.sampler.set_epoch(epoch)\n\n    for _, batch in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):\n        loss_tuple = loss_of_one_batch(batch, model, criterion, device,\n                                       symmetrize_batch=True,\n                                       use_amp=bool(args.amp), ret='loss')\n        loss_value, loss_details = loss_tuple  # criterion returns two values\n        metric_logger.update(loss=float(loss_value), **loss_details)\n\n    # gather the stats from all processes\n    metric_logger.synchronize_between_processes()\n    print(\"Averaged stats:\", metric_logger)\n\n    aggs = [('avg', 'global_avg'), ('med', 'median')]\n    results = {f'{k}_{tag}': getattr(meter, attr) for k, meter in metric_logger.meters.items() for tag, attr in aggs}\n\n    if log_writer is not None:\n        for name, val in results.items():\n            log_writer.add_scalar(prefix + '_' + name, val, 1000 * epoch)\n\n    return results\n"
  },
  {
    "path": "dust3r/utils/__init__.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n"
  },
  {
    "path": "dust3r/utils/device.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# utilitary functions for DUSt3R\n# --------------------------------------------------------\nimport numpy as np\nimport torch\n\n\ndef todevice(batch, device, callback=None, non_blocking=False):\n    ''' Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy).\n\n    batch: list, tuple, dict of tensors or other things\n    device: pytorch device or 'numpy'\n    callback: function that would be called on every sub-elements.\n    '''\n    if callback:\n        batch = callback(batch)\n\n    if isinstance(batch, dict):\n        return {k: todevice(v, device) for k, v in batch.items()}\n\n    if isinstance(batch, (tuple, list)):\n        return type(batch)(todevice(x, device) for x in batch)\n\n    x = batch\n    if device == 'numpy':\n        if isinstance(x, torch.Tensor):\n            x = x.detach().cpu().numpy()\n    elif x is not None:\n        if isinstance(x, np.ndarray):\n            x = torch.from_numpy(x)\n        if torch.is_tensor(x):\n            x = x.to(device, non_blocking=non_blocking)\n    return x\n\n\nto_device = todevice  # alias\n\n\ndef to_numpy(x): return todevice(x, 'numpy')\ndef to_cpu(x): return todevice(x, 'cpu')\ndef to_cuda(x): return todevice(x, 'cuda')\n\n\ndef collate_with_cat(whatever, lists=False):\n    if isinstance(whatever, dict):\n        return {k: collate_with_cat(vals, lists=lists) for k, vals in whatever.items()}\n\n    elif isinstance(whatever, (tuple, list)):\n        if len(whatever) == 0:\n            return whatever\n        elem = whatever[0]\n        T = type(whatever)\n\n        if elem is None:\n            return None\n        if isinstance(elem, (bool, float, int, str)):\n            return whatever\n        if isinstance(elem, tuple):\n            return T(collate_with_cat(x, lists=lists) 
for x in zip(*whatever))\n        if isinstance(elem, dict):\n            return {k: collate_with_cat([e[k] for e in whatever], lists=lists) for k in elem}\n\n        if isinstance(elem, torch.Tensor):\n            return listify(whatever) if lists else torch.cat(whatever)\n        if isinstance(elem, np.ndarray):\n            return listify(whatever) if lists else torch.cat([torch.from_numpy(x) for x in whatever])\n\n        # otherwise, we just chain lists\n        return sum(whatever, T())\n\n\ndef listify(elems):\n    return [x for e in elems for x in e]\n"
  },
  {
    "path": "dust3r/utils/geometry.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# geometry utilitary functions\n# --------------------------------------------------------\nimport torch\nimport numpy as np\nfrom scipy.spatial import cKDTree as KDTree\n\nfrom dust3r.utils.misc import invalid_to_zeros, invalid_to_nans\nfrom dust3r.utils.device import to_numpy\n\n\ndef xy_grid(W, H, device=None, origin=(0, 0), unsqueeze=None, cat_dim=-1, homogeneous=False, **arange_kw):\n    \"\"\" Output a (H,W,2) array of int32 \n        with output[j,i,0] = i + origin[0]\n             output[j,i,1] = j + origin[1]\n    \"\"\"\n    if device is None:\n        # numpy\n        arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones\n    else:\n        # torch\n        arange = lambda *a, **kw: torch.arange(*a, device=device, **kw)\n        meshgrid, stack = torch.meshgrid, torch.stack\n        ones = lambda *a: torch.ones(*a, device=device)\n\n    tw, th = [arange(o, o + s, **arange_kw) for s, o in zip((W, H), origin)]\n    grid = meshgrid(tw, th, indexing='xy')\n    if homogeneous:\n        grid = grid + (ones((H, W)),)\n    if unsqueeze is not None:\n        grid = (grid[0].unsqueeze(unsqueeze), grid[1].unsqueeze(unsqueeze))\n    if cat_dim is not None:\n        grid = stack(grid, cat_dim)\n    return grid\n\n\ndef geotrf(Trf, pts, ncol=None, norm=False):\n    \"\"\" Apply a geometric transformation to a list of 3-D points.\n\n    H: 3x3 or 4x4 projection matrix (typically a Homography)\n    p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)\n\n    ncol: int. number of columns of the result (2 or 3)\n    norm: float. 
if != 0, the result is projected on the z=norm plane.\n\n    Returns an array of projected 2d points.\n    \"\"\"\n    assert Trf.ndim >= 2\n    if isinstance(Trf, np.ndarray):\n        pts = np.asarray(pts)\n    elif isinstance(Trf, torch.Tensor):\n        pts = torch.as_tensor(pts, dtype=Trf.dtype)\n\n    # adapt shape if necessary\n    output_reshape = pts.shape[:-1]\n    ncol = ncol or pts.shape[-1]\n\n    # optimized code\n    if (isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and\n            Trf.ndim == 3 and pts.ndim == 4):\n        d = pts.shape[3]\n        if Trf.shape[-1] == d:\n            pts = torch.einsum(\"bij, bhwj -> bhwi\", Trf, pts)\n        elif Trf.shape[-1] == d + 1:\n            pts = torch.einsum(\"bij, bhwj -> bhwi\", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d]\n        else:\n            raise ValueError(f'bad shape, not ending with 3 or 4, for {pts.shape=}')\n    else:\n        if Trf.ndim >= 3:\n            n = Trf.ndim - 2\n            assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match'\n            Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])\n\n            if pts.ndim > Trf.ndim:\n                # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)\n                pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])\n            elif pts.ndim == 2:\n                # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)\n                pts = pts[:, None, :]\n\n        if pts.shape[-1] + 1 == Trf.shape[-1]:\n            Trf = Trf.swapaxes(-1, -2)  # transpose Trf\n            pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]\n        elif pts.shape[-1] == Trf.shape[-1]:\n            Trf = Trf.swapaxes(-1, -2)  # transpose Trf\n            pts = pts @ Trf\n        else:\n            pts = Trf @ pts.T\n            if pts.ndim >= 2:\n                pts = pts.swapaxes(-1, -2)\n\n    if norm:\n        pts = pts / pts[..., -1:]  # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG\n        if norm != 1:\n            pts 
*= norm\n\n    res = pts[..., :ncol].reshape(*output_reshape, ncol)\n    return res\n\n\ndef inv(mat):\n    \"\"\" Invert a torch or numpy matrix\n    \"\"\"\n    if isinstance(mat, torch.Tensor):\n        return torch.linalg.inv(mat)\n    if isinstance(mat, np.ndarray):\n        return np.linalg.inv(mat)\n    raise ValueError(f'bad matrix type = {type(mat)}')\n\n\ndef depthmap_to_pts3d(depth, pseudo_focal, pp=None, **_):\n    \"\"\"\n    Args:\n        - depthmap (BxHxW array):\n        - pseudo_focal: [B,H,W] ; [B,2,H,W] or [B,1,H,W]\n    Returns:\n        pointmap of absolute coordinates (BxHxWx3 array)\n    \"\"\"\n\n    if len(depth.shape) == 4:\n        B, H, W, n = depth.shape\n    else:\n        B, H, W = depth.shape\n        n = None\n\n    if len(pseudo_focal.shape) == 3:  # [B,H,W]\n        pseudo_focalx = pseudo_focaly = pseudo_focal\n    elif len(pseudo_focal.shape) == 4:  # [B,2,H,W] or [B,1,H,W]\n        pseudo_focalx = pseudo_focal[:, 0]\n        if pseudo_focal.shape[1] == 2:\n            pseudo_focaly = pseudo_focal[:, 1]\n        else:\n            pseudo_focaly = pseudo_focalx\n    else:\n        raise NotImplementedError(\"Error, unknown input focal shape format.\")\n\n    assert pseudo_focalx.shape == depth.shape[:3]\n    assert pseudo_focaly.shape == depth.shape[:3]\n    grid_x, grid_y = xy_grid(W, H, cat_dim=0, device=depth.device)[:, None]\n\n    # set principal point\n    if pp is None:\n        grid_x = grid_x - (W - 1) / 2\n        grid_y = grid_y - (H - 1) / 2\n    else:\n        grid_x = grid_x.expand(B, -1, -1) - pp[:, 0, None, None]\n        grid_y = grid_y.expand(B, -1, -1) - pp[:, 1, None, None]\n\n    if n is None:\n        pts3d = torch.empty((B, H, W, 3), device=depth.device)\n        pts3d[..., 0] = depth * grid_x / pseudo_focalx\n        pts3d[..., 1] = depth * grid_y / pseudo_focaly\n        pts3d[..., 2] = depth\n    else:\n        pts3d = torch.empty((B, H, W, 3, n), device=depth.device)\n        pts3d[..., 0, :] = depth * 
(grid_x / pseudo_focalx)[..., None]\n        pts3d[..., 1, :] = depth * (grid_y / pseudo_focaly)[..., None]\n        pts3d[..., 2, :] = depth\n    return pts3d\n\n\ndef depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):\n    \"\"\"\n    Args:\n        - depthmap (HxW array):\n        - camera_intrinsics: a 3x3 matrix\n    Returns:\n        pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.\n    \"\"\"\n    camera_intrinsics = np.float32(camera_intrinsics)\n    H, W = depthmap.shape\n\n    # Compute 3D ray associated with each pixel\n    # Strong assumption: there are no skew terms\n    assert camera_intrinsics[0, 1] == 0.0\n    assert camera_intrinsics[1, 0] == 0.0\n    if pseudo_focal is None:\n        fu = camera_intrinsics[0, 0]\n        fv = camera_intrinsics[1, 1]\n    else:\n        assert pseudo_focal.shape == (H, W)\n        fu = fv = pseudo_focal\n    cu = camera_intrinsics[0, 2]\n    cv = camera_intrinsics[1, 2]\n\n    u, v = np.meshgrid(np.arange(W), np.arange(H))\n    z_cam = depthmap\n    x_cam = (u - cu) * z_cam / fu\n    y_cam = (v - cv) * z_cam / fv\n    X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)\n\n    # Mask for valid coordinates\n    valid_mask = (depthmap > 0.0)\n    return X_cam, valid_mask\n\n\ndef depthmap_to_absolute_camera_coordinates(depthmap, camera_intrinsics, camera_pose, **kw):\n    \"\"\"\n    Args:\n        - depthmap (HxW array):\n        - camera_intrinsics: a 3x3 matrix\n        - camera_pose: a 4x3 or 4x4 cam2world matrix\n    Returns:\n        pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.\"\"\"\n    X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)\n\n    X_world = X_cam # default\n    if camera_pose is not None:\n        # R_cam2world = np.float32(camera_params[\"R_cam2world\"])\n        # t_cam2world = np.float32(camera_params[\"t_cam2world\"]).squeeze()\n        R_cam2world 
= camera_pose[:3, :3]\n        t_cam2world = camera_pose[:3, 3]\n\n        # Express in absolute coordinates (invalid depth values)\n        X_world = np.einsum(\"ik, vuk -> vui\", R_cam2world, X_cam) + t_cam2world[None, None, :]\n\n    return X_world, valid_mask\n\n\ndef colmap_to_opencv_intrinsics(K):\n    \"\"\"\n    Modify camera intrinsics to follow a different convention.\n    Coordinates of the center of the top-left pixels are by default:\n    - (0.5, 0.5) in Colmap\n    - (0,0) in OpenCV\n    \"\"\"\n    K = K.copy()\n    K[0, 2] -= 0.5\n    K[1, 2] -= 0.5\n    return K\n\n\ndef opencv_to_colmap_intrinsics(K):\n    \"\"\"\n    Modify camera intrinsics to follow a different convention.\n    Coordinates of the center of the top-left pixels are by default:\n    - (0.5, 0.5) in Colmap\n    - (0,0) in OpenCV\n    \"\"\"\n    K = K.copy()\n    K[0, 2] += 0.5\n    K[1, 2] += 0.5\n    return K\n\n\ndef normalize_pointcloud(pts1, pts2, norm_mode='avg_dis', valid1=None, valid2=None, ret_factor=False):\n    \"\"\" renorm pointmaps pts1, pts2 with norm_mode\n    \"\"\"\n    assert pts1.ndim >= 3 and pts1.shape[-1] == 3\n    assert pts2 is None or (pts2.ndim >= 3 and pts2.shape[-1] == 3)\n    norm_mode, dis_mode = norm_mode.split('_')\n\n    if norm_mode == 'avg':\n        # gather all points together (joint normalization)\n        nan_pts1, nnz1 = invalid_to_zeros(pts1, valid1, ndim=3)\n        nan_pts2, nnz2 = invalid_to_zeros(pts2, valid2, ndim=3) if pts2 is not None else (None, 0)\n        all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1\n\n        # compute distance to origin\n        all_dis = all_pts.norm(dim=-1)\n        if dis_mode == 'dis':\n            pass  # do nothing\n        elif dis_mode == 'log1p':\n            all_dis = torch.log1p(all_dis)\n        elif dis_mode == 'warp-log1p':\n            # actually warp input points before normalizing them\n            log_dis = torch.log1p(all_dis)\n            warp_factor = 
log_dis / all_dis.clip(min=1e-8)\n            H1, W1 = pts1.shape[1:-1]\n            pts1 = pts1 * warp_factor[:, :W1 * H1].view(-1, H1, W1, 1)\n            if pts2 is not None:\n                H2, W2 = pts2.shape[1:-1]\n                pts2 = pts2 * warp_factor[:, W1 * H1:].view(-1, H2, W2, 1)\n            all_dis = log_dis  # this is their true distance afterwards\n        else:\n            raise ValueError(f'bad {dis_mode=}')\n\n        norm_factor = all_dis.sum(dim=1) / (nnz1 + nnz2 + 1e-8)\n    else:\n        # gather all points together (joint normalization)\n        nan_pts1 = invalid_to_nans(pts1, valid1, ndim=3)\n        nan_pts2 = invalid_to_nans(pts2, valid2, ndim=3) if pts2 is not None else None\n        all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1\n\n        # compute distance to origin\n        all_dis = all_pts.norm(dim=-1)\n\n        if norm_mode == 'avg':\n            norm_factor = all_dis.nanmean(dim=1)\n        elif norm_mode == 'median':\n            norm_factor = all_dis.nanmedian(dim=1).values.detach()\n        elif norm_mode == 'sqrt':\n            norm_factor = all_dis.sqrt().nanmean(dim=1)**2\n        else:\n            raise ValueError(f'bad {norm_mode=}')\n\n    norm_factor = norm_factor.clip(min=1e-8)\n    while norm_factor.ndim < pts1.ndim:\n        norm_factor.unsqueeze_(-1)\n\n    res = pts1 / norm_factor\n    if pts2 is not None:\n        res = (res, pts2 / norm_factor)\n    if ret_factor:\n        res = res + (norm_factor,)\n    return res\n\n\n@torch.no_grad()\ndef get_joint_pointcloud_depth(z1, z2, valid_mask1, valid_mask2=None, quantile=0.5):\n    # set invalid points to NaN\n    _z1 = invalid_to_nans(z1, valid_mask1).reshape(len(z1), -1)\n    _z2 = invalid_to_nans(z2, valid_mask2).reshape(len(z2), -1) if z2 is not None else None\n    _z = torch.cat((_z1, _z2), dim=-1) if z2 is not None else _z1\n\n    # compute median depth overall (ignoring nans)\n    if quantile == 0.5:\n        shift_z 
= torch.nanmedian(_z, dim=-1).values\n    else:\n        shift_z = torch.nanquantile(_z, quantile, dim=-1)\n    return shift_z  # (B,)\n\n\n@torch.no_grad()\ndef get_joint_pointcloud_center_scale(pts1, pts2, valid_mask1=None, valid_mask2=None, z_only=False, center=True):\n    # set invalid points to NaN\n    _pts1 = invalid_to_nans(pts1, valid_mask1).reshape(len(pts1), -1, 3)\n    _pts2 = invalid_to_nans(pts2, valid_mask2).reshape(len(pts2), -1, 3) if pts2 is not None else None\n    _pts = torch.cat((_pts1, _pts2), dim=1) if pts2 is not None else _pts1\n\n    # compute median center\n    _center = torch.nanmedian(_pts, dim=1, keepdim=True).values  # (B,1,3)\n    if z_only:\n        _center[..., :2] = 0  # do not center X and Y\n\n    # compute median norm\n    _norm = ((_pts - _center) if center else _pts).norm(dim=-1)\n    scale = torch.nanmedian(_norm, dim=1).values\n    return _center[:, None, :, :], scale[:, None, None, None]\n\n\ndef find_reciprocal_matches(P1, P2):\n    \"\"\"\n    returns 3 values:\n    1 - reciprocal_in_P2: a boolean array of size P2.shape[0], a \"True\" value indicates a match\n    2 - nn2_in_P1: a int array of size P2.shape[0], it contains the indexes of the closest points in P1\n    3 - reciprocal_in_P2.sum(): the number of matches\n    \"\"\"\n    tree1 = KDTree(P1)\n    tree2 = KDTree(P2)\n\n    _, nn1_in_P2 = tree2.query(P1, workers=8)\n    _, nn2_in_P1 = tree1.query(P2, workers=8)\n\n    reciprocal_in_P1 = (nn2_in_P1[nn1_in_P2] == np.arange(len(nn1_in_P2)))\n    reciprocal_in_P2 = (nn1_in_P2[nn2_in_P1] == np.arange(len(nn2_in_P1)))\n    assert reciprocal_in_P1.sum() == reciprocal_in_P2.sum()\n    return reciprocal_in_P2, nn2_in_P1, reciprocal_in_P2.sum()\n\n\ndef get_med_dist_between_poses(poses):\n    from scipy.spatial.distance import pdist\n    return np.median(pdist([to_numpy(p[:3, 3]) for p in poses]))\n"
  },
  {
    "path": "dust3r/utils/image.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# utilitary functions about images (loading/converting...)\n# --------------------------------------------------------\nimport os\nimport torch\nimport numpy as np\nimport PIL.Image\nfrom PIL.ImageOps import exif_transpose\nimport torchvision.transforms as tvf\nos.environ[\"OPENCV_IO_ENABLE_OPENEXR\"] = \"1\"\nimport cv2  # noqa\n\ntry:\n    from pillow_heif import register_heif_opener  # noqa\n    register_heif_opener()\n    heif_support_enabled = True\nexcept ImportError:\n    heif_support_enabled = False\n\nImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n\n\ndef img_to_arr(img):\n    if isinstance(img, str):\n        img = imread_cv2(img)\n    return img\n\n\ndef imread_cv2(path, options=cv2.IMREAD_COLOR):\n    \"\"\" Open an image or a depthmap with opencv-python.\n    \"\"\"\n    if path.endswith(('.exr', 'EXR')):\n        options = cv2.IMREAD_ANYDEPTH\n    img = cv2.imread(path, options)\n    if img is None:\n        raise IOError(f'Could not load image={path} with {options=}')\n    if img.ndim == 3:\n        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n    return img\n\n\ndef rgb(ftensor, true_shape=None):\n    if isinstance(ftensor, list):\n        return [rgb(x, true_shape=true_shape) for x in ftensor]\n    if isinstance(ftensor, torch.Tensor):\n        ftensor = ftensor.detach().cpu().numpy()  # H,W,3\n    if ftensor.ndim == 3 and ftensor.shape[0] == 3:\n        ftensor = ftensor.transpose(1, 2, 0)\n    elif ftensor.ndim == 4 and ftensor.shape[1] == 3:\n        ftensor = ftensor.transpose(0, 2, 3, 1)\n    if true_shape is not None:\n        H, W = true_shape\n        ftensor = ftensor[:H, :W]\n    if ftensor.dtype == np.uint8:\n        img = np.float32(ftensor) / 255\n    else:\n        img = (ftensor * 0.5) 
+ 0.5\n    return img.clip(min=0, max=1)\n\n\ndef _resize_pil_image(img, long_edge_size):\n    S = max(img.size)\n    if S > long_edge_size:\n        interp = PIL.Image.LANCZOS\n    elif S <= long_edge_size:\n        interp = PIL.Image.BICUBIC\n    new_size = tuple(int(round(x*long_edge_size/S)) for x in img.size)\n    return img.resize(new_size, interp)\n\n\ndef load_images(folder_or_list, size, square_ok=False, verbose=True, patch_size=16):\n    \"\"\" open and convert all images in a list or folder to proper input format for DUSt3R\n    \"\"\"\n    if isinstance(folder_or_list, str):\n        if verbose:\n            print(f'>> Loading images from {folder_or_list}')\n        root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list))\n\n    elif isinstance(folder_or_list, list):\n        if verbose:\n            print(f'>> Loading a list of {len(folder_or_list)} images')\n        root, folder_content = '', folder_or_list\n\n    else:\n        raise ValueError(f'bad {folder_or_list=} ({type(folder_or_list)})')\n\n    supported_images_extensions = ['.jpg', '.jpeg', '.png']\n    if heif_support_enabled:\n        supported_images_extensions += ['.heic', '.heif']\n    supported_images_extensions = tuple(supported_images_extensions)\n\n    imgs = []\n    for path in folder_content:\n        if not path.lower().endswith(supported_images_extensions):\n            continue\n        img = exif_transpose(PIL.Image.open(os.path.join(root, path))).convert('RGB')\n        W1, H1 = img.size\n        if size == 224:\n            # resize short side to 224 (then crop)\n            img = _resize_pil_image(img, round(size * max(W1/H1, H1/W1)))\n        else:\n            # resize long side to 512\n            img = _resize_pil_image(img, size)\n        W, H = img.size\n        cx, cy = W//2, H//2\n        if size == 224:\n            half = min(cx, cy)\n            img = img.crop((cx-half, cy-half, cx+half, cy+half))\n        else:\n            halfw = ((2 * cx) // 
patch_size) * patch_size / 2\n            halfh = ((2 * cy) // patch_size) * patch_size / 2\n            if not (square_ok) and W == H:\n                halfh = 3*halfw/4\n            img = img.crop((cx-halfw, cy-halfh, cx+halfw, cy+halfh))\n\n        W2, H2 = img.size\n        if verbose:\n            print(f' - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}')\n        imgs.append(dict(img=ImgNorm(img)[None], true_shape=np.int32(\n            [img.size[::-1]]), idx=len(imgs), instance=str(len(imgs))))\n\n    assert imgs, 'no images found at '+root\n    if verbose:\n        print(f' (Found {len(imgs)} images)')\n    return imgs\n"
  },
  {
    "path": "dust3r/utils/misc.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# utilitary functions for DUSt3R\n# --------------------------------------------------------\nimport torch\n\n\ndef fill_default_args(kwargs, func):\n    import inspect  # a bit hacky but it works reliably\n    signature = inspect.signature(func)\n\n    for k, v in signature.parameters.items():\n        if v.default is inspect.Parameter.empty:\n            continue\n        kwargs.setdefault(k, v.default)\n\n    return kwargs\n\n\ndef freeze_all_params(modules):\n    for module in modules:\n        try:\n            for n, param in module.named_parameters():\n                param.requires_grad = False\n        except AttributeError:\n            # module is directly a parameter\n            module.requires_grad = False\n\n\ndef is_symmetrized(gt1, gt2):\n    x = gt1['instance']\n    y = gt2['instance']\n    if len(x) == len(y) and len(x) == 1:\n        return False  # special case of batchsize 1\n    ok = True\n    for i in range(0, len(x), 2):\n        ok = ok and (x[i] == y[i + 1]) and (x[i + 1] == y[i])\n    return ok\n\n\ndef flip(tensor):\n    \"\"\" flip so that tensor[0::2] <=> tensor[1::2] \"\"\"\n    return torch.stack((tensor[1::2], tensor[0::2]), dim=1).flatten(0, 1)\n\n\ndef interleave(tensor1, tensor2):\n    res1 = torch.stack((tensor1, tensor2), dim=1).flatten(0, 1)\n    res2 = torch.stack((tensor2, tensor1), dim=1).flatten(0, 1)\n    return res1, res2\n\n\ndef transpose_to_landscape(head, activate=True):\n    \"\"\" Predict in the correct aspect-ratio,\n        then transpose the result in landscape \n        and stack everything back together.\n    \"\"\"\n    def wrapper_no(decout, true_shape):\n        B = len(true_shape)\n        assert true_shape[0:1].allclose(true_shape), 'true_shape must be all identical'\n        H, W = 
true_shape[0].cpu().tolist()\n        res = head(decout, (H, W))\n        return res\n\n    def wrapper_yes(decout, true_shape):\n        B = len(true_shape)\n        # by definition, the batch is in landscape mode so W >= H\n        H, W = int(true_shape.min()), int(true_shape.max())\n\n        height, width = true_shape.T\n        is_landscape = (width >= height)\n        is_portrait = ~is_landscape\n\n        # true_shape = true_shape.cpu()\n        if is_landscape.all():\n            return head(decout, (H, W))\n        if is_portrait.all():\n            return transposed(head(decout, (W, H)))\n\n        # batch is a mix of both portrait & landscape\n        def selout(ar): return [d[ar] for d in decout]\n        l_result = head(selout(is_landscape), (H, W))\n        p_result = transposed(head(selout(is_portrait), (W, H)))\n\n        # allocate full result\n        result = {}\n        for k in l_result | p_result:\n            x = l_result[k].new(B, *l_result[k].shape[1:])\n            x[is_landscape] = l_result[k]\n            x[is_portrait] = p_result[k]\n            result[k] = x\n\n        return result\n\n    return wrapper_yes if activate else wrapper_no\n\n\ndef transposed(dic):\n    return {k: v.swapaxes(1, 2) for k, v in dic.items()}\n\n\ndef invalid_to_nans(arr, valid_mask, ndim=999):\n    if valid_mask is not None:\n        arr = arr.clone()\n        arr[~valid_mask] = float('nan')\n    if arr.ndim > ndim:\n        arr = arr.flatten(-2 - (arr.ndim - ndim), -2)\n    return arr\n\n\ndef invalid_to_zeros(arr, valid_mask, ndim=999):\n    if valid_mask is not None:\n        arr = arr.clone()\n        arr[~valid_mask] = 0\n        nnz = valid_mask.view(len(valid_mask), -1).sum(1)\n    else:\n        nnz = arr.numel() // len(arr) if len(arr) else 0  # number of points per image\n    if arr.ndim > ndim:\n        arr = arr.flatten(-2 - (arr.ndim - ndim), -2)\n    return arr, nnz\n"
  },
  {
    "path": "dust3r/utils/parallel.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# utilitary functions for multiprocessing\n# --------------------------------------------------------\nfrom tqdm import tqdm\nfrom multiprocessing.dummy import Pool as ThreadPool\nfrom multiprocessing import cpu_count\n\n\ndef parallel_threads(function, args, workers=0, star_args=False, kw_args=False, front_num=1, Pool=ThreadPool, **tqdm_kw):\n    \"\"\" tqdm but with parallel execution.\n\n    Will essentially return \n      res = [ function(arg) # default\n              function(*arg) # if star_args is True\n              function(**arg) # if kw_args is True\n              for arg in args]\n\n    Note:\n        the <front_num> first elements of args will not be parallelized. \n        This can be useful for debugging.\n    \"\"\"\n    while workers <= 0:\n        workers += cpu_count()\n    if workers == 1:\n        front_num = float('inf')\n\n    # convert into an iterable\n    try:\n        n_args_parallel = len(args) - front_num\n    except TypeError:\n        n_args_parallel = None\n    args = iter(args)\n\n    # sequential execution first\n    front = []\n    while len(front) < front_num:\n        try:\n            a = next(args)\n        except StopIteration:\n            return front  # end of the iterable\n        front.append(function(*a) if star_args else function(**a) if kw_args else function(a))\n\n    # then parallel execution\n    out = []\n    with Pool(workers) as pool:\n        # Pass the elements of args into function\n        if star_args:\n            futures = pool.imap(starcall, [(function, a) for a in args])\n        elif kw_args:\n            futures = pool.imap(starstarcall, [(function, a) for a in args])\n        else:\n            futures = pool.imap(function, args)\n        # Print out the progress as tasks complete\n        for 
f in tqdm(futures, total=n_args_parallel, **tqdm_kw):\n            out.append(f)\n    return front + out\n\n\ndef parallel_processes(*args, **kwargs):\n    \"\"\" Same as parallel_threads, with processes\n    \"\"\"\n    import multiprocessing as mp\n    kwargs['Pool'] = mp.Pool\n    return parallel_threads(*args, **kwargs)\n\n\ndef starcall(args):\n    \"\"\" convenient wrapper for Process.Pool \"\"\"\n    function, args = args\n    return function(*args)\n\n\ndef starstarcall(args):\n    \"\"\" convenient wrapper for Process.Pool \"\"\"\n    function, args = args\n    return function(**args)\n"
  },
  {
    "path": "dust3r/utils/path_to_croco.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# CroCo submodule import\n# --------------------------------------------------------\n\nimport sys\nimport os.path as path\nHERE_PATH = path.normpath(path.dirname(__file__))\nCROCO_REPO_PATH = path.normpath(path.join(HERE_PATH, '../../croco'))\nCROCO_MODELS_PATH = path.join(CROCO_REPO_PATH, 'models')\n# check the presence of models directory in repo to be sure its cloned\nif path.isdir(CROCO_MODELS_PATH):\n    # workaround for sibling import\n    sys.path.insert(0, CROCO_REPO_PATH)\nelse:\n    raise ImportError(f\"croco is not initialized, could not find: {CROCO_MODELS_PATH}.\\n \"\n                      \"Did you forget to run 'git submodule update --init --recursive' ?\")\n"
  },
  {
    "path": "dust3r/viz.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# Visualization utilities using trimesh\n# --------------------------------------------------------\nimport PIL.Image\nimport numpy as np\nfrom scipy.spatial.transform import Rotation\nimport torch\n\nfrom dust3r.utils.geometry import geotrf, get_med_dist_between_poses, depthmap_to_absolute_camera_coordinates\nfrom dust3r.utils.device import to_numpy\nfrom dust3r.utils.image import rgb, img_to_arr\n\ntry:\n    import trimesh\nexcept ImportError:\n    print('/!\\\\ module trimesh is not installed, cannot visualize results /!\\\\')\n\n\n\ndef cat_3d(vecs):\n    if isinstance(vecs, (np.ndarray, torch.Tensor)):\n        vecs = [vecs]\n    return np.concatenate([p.reshape(-1, 3) for p in to_numpy(vecs)])\n\n\ndef show_raw_pointcloud(pts3d, colors, point_size=2):\n    scene = trimesh.Scene()\n\n    pct = trimesh.PointCloud(cat_3d(pts3d), colors=cat_3d(colors))\n    scene.add_geometry(pct)\n\n    scene.show(line_settings={'point_size': point_size})\n\n\ndef pts3d_to_trimesh(img, pts3d, valid=None):\n    H, W, THREE = img.shape\n    assert THREE == 3\n    assert img.shape == pts3d.shape\n\n    vertices = pts3d.reshape(-1, 3)\n\n    # make squares: each pixel == 2 triangles\n    idx = np.arange(len(vertices)).reshape(H, W)\n    idx1 = idx[:-1, :-1].ravel()  # top-left corner\n    idx2 = idx[:-1, +1:].ravel()  # right-left corner\n    idx3 = idx[+1:, :-1].ravel()  # bottom-left corner\n    idx4 = idx[+1:, +1:].ravel()  # bottom-right corner\n    faces = np.concatenate((\n        np.c_[idx1, idx2, idx3],\n        np.c_[idx3, idx2, idx1],  # same triangle, but backward (cheap solution to cancel face culling)\n        np.c_[idx2, idx3, idx4],\n        np.c_[idx4, idx3, idx2],  # same triangle, but backward (cheap solution to cancel face culling)\n    ), axis=0)\n\n    # 
prepare triangle colors\n    face_colors = np.concatenate((\n        img[:-1, :-1].reshape(-1, 3),\n        img[:-1, :-1].reshape(-1, 3),\n        img[+1:, +1:].reshape(-1, 3),\n        img[+1:, +1:].reshape(-1, 3)\n    ), axis=0)\n\n    # remove invalid faces\n    if valid is not None:\n        assert valid.shape == (H, W)\n        valid_idxs = valid.ravel()\n        valid_faces = valid_idxs[faces].all(axis=-1)\n        faces = faces[valid_faces]\n        face_colors = face_colors[valid_faces]\n\n    assert len(faces) == len(face_colors)\n    return dict(vertices=vertices, face_colors=face_colors, faces=faces)\n\n\ndef cat_meshes(meshes):\n    vertices, faces, colors = zip(*[(m['vertices'], m['faces'], m['face_colors']) for m in meshes])\n    n_vertices = np.cumsum([0]+[len(v) for v in vertices])\n    for i in range(len(faces)):\n        faces[i][:] += n_vertices[i]\n\n    vertices = np.concatenate(vertices)\n    colors = np.concatenate(colors)\n    faces = np.concatenate(faces)\n    return dict(vertices=vertices, face_colors=colors, faces=faces)\n\n\ndef show_duster_pairs(view1, view2, pred1, pred2):\n    import matplotlib.pyplot as pl\n    pl.ion()\n\n    for e in range(len(view1['instance'])):\n        i = view1['idx'][e]\n        j = view2['idx'][e]\n        img1 = rgb(view1['img'][e])\n        img2 = rgb(view2['img'][e])\n        conf1 = pred1['conf'][e].squeeze()\n        conf2 = pred2['conf'][e].squeeze()\n        score = conf1.mean()*conf2.mean()\n        print(f\">> Showing pair #{e} {i}-{j} {score=:g}\")\n        pl.clf()\n        pl.subplot(221).imshow(img1)\n        pl.subplot(223).imshow(img2)\n        pl.subplot(222).imshow(conf1, vmin=1, vmax=30)\n        pl.subplot(224).imshow(conf2, vmin=1, vmax=30)\n        pts1 = pred1['pts3d'][e]\n        pts2 = pred2['pts3d_in_other_view'][e]\n        pl.subplots_adjust(0, 0, 1, 1, 0, 0)\n        if input('show pointcloud? 
(y/n) ') == 'y':\n            show_raw_pointcloud(cat(pts1, pts2), cat(img1, img2), point_size=5)\n\n\ndef auto_cam_size(im_poses):\n    return 0.1 * get_med_dist_between_poses(im_poses)\n\n\nclass SceneViz:\n    def __init__(self):\n        self.scene = trimesh.Scene()\n\n    def add_rgbd(self, image, depth, intrinsics=None, cam2world=None, zfar=np.inf, mask=None):\n        image = img_to_arr(image)\n\n        # make up some intrinsics\n        if intrinsics is None:\n            H, W, THREE = image.shape\n            focal = max(H, W)\n            intrinsics = np.float32([[focal, 0, W/2], [0, focal, H/2], [0, 0, 1]])\n\n        # compute 3d points\n        pts3d = depthmap_to_pts3d(depth, intrinsics, cam2world=cam2world)\n\n        return self.add_pointcloud(pts3d, image, mask=(depth<zfar) if mask is None else mask)\n\n    def add_pointcloud(self, pts3d, color=(0,0,0), mask=None, denoise=False):\n        pts3d = to_numpy(pts3d)\n        mask = to_numpy(mask)\n        if not isinstance(pts3d, list):\n            pts3d = [pts3d.reshape(-1,3)]\n            if mask is not None: \n                mask = [mask.ravel()]\n        if not isinstance(color, (tuple,list)):\n            color = [color.reshape(-1,3)]\n        if mask is None:\n            mask = [slice(None)] * len(pts3d)\n\n        pts = np.concatenate([p[m] for p,m in zip(pts3d,mask)])\n        pct = trimesh.PointCloud(pts)\n\n        if isinstance(color, (list, np.ndarray, torch.Tensor)):\n            color = to_numpy(color)\n            col = np.concatenate([p[m] for p,m in zip(color,mask)])\n            assert col.shape == pts.shape, f'color shape {col.shape} != points shape {pts.shape}'\n            pct.visual.vertex_colors = uint8(col.reshape(-1,3))\n        else:\n            assert len(color) == 3\n            pct.visual.vertex_colors = np.broadcast_to(uint8(color), pts.shape)\n\n        if denoise:\n            # remove points which are noisy\n            centroid = np.median(pct.vertices, axis=0)\n            dist_to_centroid = np.linalg.norm( 
pct.vertices - centroid, axis=-1)\n            dist_thr = np.quantile(dist_to_centroid, 0.99)\n            valid = (dist_to_centroid < dist_thr)\n            # new cleaned pointcloud\n            pct = trimesh.PointCloud(pct.vertices[valid], colors=pct.visual.vertex_colors[valid])\n\n        self.scene.add_geometry(pct)\n        return self\n\n    def add_rgbd(self, image, depth, intrinsics=None, cam2world=None, zfar=np.inf, mask=None):\n        # make up some intrinsics\n        if intrinsics is None:\n            H, W, THREE = image.shape\n            focal = max(H, W)\n            intrinsics = np.float32([[focal, 0, W/2], [0, focal, H/2], [0, 0, 1]])\n\n        # compute 3d points\n        pts3d, mask2 = depthmap_to_absolute_camera_coordinates(depth, intrinsics, cam2world)\n        mask2 &= (depth<zfar) \n\n        # combine with provided mask if any\n        if mask is not None:\n            mask2 &= mask\n\n        return self.add_pointcloud(pts3d, image, mask=mask2)\n\n    def add_camera(self, pose_c2w, focal=None, color=(0, 0, 0), image=None, imsize=None, cam_size=0.03):\n        pose_c2w, focal, color, image = to_numpy((pose_c2w, focal, color, image))\n        image = img_to_arr(image)\n        if isinstance(focal, np.ndarray) and focal.shape == (3,3):\n            intrinsics = focal\n            focal = (intrinsics[0,0] * intrinsics[1,1]) ** 0.5\n            if imsize is None:\n                imsize = (2*intrinsics[0,2], 2*intrinsics[1,2])\n        \n        add_scene_cam(self.scene, pose_c2w, color, image, focal, imsize=imsize, screen_width=cam_size, marker=None)\n        return self\n\n    def add_cameras(self, poses, focals=None, images=None, imsizes=None, colors=None, **kw):\n        get = lambda arr,idx: None if arr is None else arr[idx]\n        for i, pose_c2w in enumerate(poses):\n            self.add_camera(pose_c2w, get(focals,i), image=get(images,i), color=get(colors,i), imsize=get(imsizes,i), **kw)\n        return self\n\n    def show(self, 
point_size=2):\n        self.scene.show(line_settings= {'point_size': point_size})\n\n\ndef show_raw_pointcloud_with_cams(imgs, pts3d, mask, focals, cams2world,\n                                  point_size=2, cam_size=0.05, cam_color=None):\n    \"\"\" Visualization of a pointcloud with cameras\n        imgs = (N, H, W, 3) or N-size list of [(H,W,3), ...]\n        pts3d = (N, H, W, 3) or N-size list of [(H,W,3), ...]\n        focals = (N,) or N-size list of [focal, ...]\n        cams2world = (N,4,4) or N-size list of [(4,4), ...]\n    \"\"\"\n    assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)\n    pts3d = to_numpy(pts3d)\n    imgs = to_numpy(imgs)\n    focals = to_numpy(focals)\n    cams2world = to_numpy(cams2world)\n\n    scene = trimesh.Scene()\n\n    # full pointcloud\n    pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])\n    col = np.concatenate([p[m] for p, m in zip(imgs, mask)])\n    pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))\n    scene.add_geometry(pct)\n\n    # add each camera\n    for i, pose_c2w in enumerate(cams2world):\n        if isinstance(cam_color, list):\n            camera_edge_color = cam_color[i]\n        else:\n            camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]\n        add_scene_cam(scene, pose_c2w, camera_edge_color,\n                      imgs[i] if i < len(imgs) else None, focals[i], screen_width=cam_size)\n\n    scene.show(line_settings={'point_size': point_size})\n\n\ndef add_scene_cam(scene, pose_c2w, edge_color, image=None, focal=None, imsize=None, \n                  screen_width=0.03, marker=None):\n    if image is not None:\n        image = np.asarray(image)\n        H, W, THREE = image.shape\n        assert THREE == 3\n        if image.dtype != np.uint8:\n            image = np.uint8(255*image)\n    elif imsize is not None:\n        W, H = imsize\n    elif focal is not None:\n        H = W = focal / 1.1\n    else:\n        H = W = 1\n\n   
 if isinstance(focal, np.ndarray):\n        focal = focal[0]\n    if not focal:\n        focal = min(H,W) * 1.1 # default value\n\n    # create fake camera\n    height = max( screen_width/10, focal * screen_width / H )\n    width = screen_width * 0.5**0.5\n    rot45 = np.eye(4)\n    rot45[:3, :3] = Rotation.from_euler('z', np.deg2rad(45)).as_matrix()\n    rot45[2, 3] = -height  # set the tip of the cone = optical center\n    aspect_ratio = np.eye(4)\n    aspect_ratio[0, 0] = W/H\n    transform = pose_c2w @ OPENGL @ aspect_ratio @ rot45\n    cam = trimesh.creation.cone(width, height, sections=4)  # , transform=transform)\n\n    # this is the image\n    if image is not None:\n        vertices = geotrf(transform, cam.vertices[[4, 5, 1, 3]])\n        faces = np.array([[0, 1, 2], [0, 2, 3], [2, 1, 0], [3, 2, 0]])\n        img = trimesh.Trimesh(vertices=vertices, faces=faces)\n        uv_coords = np.float32([[0, 0], [1, 0], [1, 1], [0, 1]])\n        img.visual = trimesh.visual.TextureVisuals(uv_coords, image=PIL.Image.fromarray(image))\n        scene.add_geometry(img)\n\n    # this is the camera mesh\n    rot2 = np.eye(4)\n    rot2[:3, :3] = Rotation.from_euler('z', np.deg2rad(2)).as_matrix()\n    vertices = np.r_[cam.vertices, 0.95*cam.vertices, geotrf(rot2, cam.vertices)]\n    vertices = geotrf(transform, vertices)\n    faces = []\n    for face in cam.faces:\n        if 0 in face:\n            continue\n        a, b, c = face\n        a2, b2, c2 = face + len(cam.vertices)\n        a3, b3, c3 = face + 2*len(cam.vertices)\n\n        # add 3 pseudo-edges\n        faces.append((a, b, b2))\n        faces.append((a, a2, c))\n        faces.append((c2, b, c))\n\n        faces.append((a, b, b3))\n        faces.append((a, a3, c))\n        faces.append((c3, b, c))\n\n    # no culling\n    faces += [(c, b, a) for a, b, c in faces]\n\n    cam = trimesh.Trimesh(vertices=vertices, faces=faces)\n    cam.visual.face_colors[:, :3] = edge_color\n    scene.add_geometry(cam)\n\n    if 
marker == 'o':\n        marker = trimesh.creation.icosphere(3, radius=screen_width/4)\n        marker.vertices += pose_c2w[:3,3]\n        marker.visual.face_colors[:,:3] = edge_color\n        scene.add_geometry(marker)\n\n\ndef cat(a, b):\n    return np.concatenate((a.reshape(-1, 3), b.reshape(-1, 3)))\n\n\nOPENGL = np.array([[1, 0, 0, 0],\n                   [0, -1, 0, 0],\n                   [0, 0, -1, 0],\n                   [0, 0, 0, 1]])\n\n\nCAM_COLORS = [(255, 0, 0), (0, 0, 255), (0, 255, 0), (255, 0, 255), (255, 204, 0), (0, 204, 204),\n              (128, 255, 255), (255, 128, 255), (255, 255, 128), (0, 0, 0), (128, 128, 128)]\n\n\ndef uint8(colors):\n    if not isinstance(colors, np.ndarray):\n        colors = np.array(colors)\n    if np.issubdtype(colors.dtype, np.floating):\n        colors *= 255\n    assert 0 <= colors.min() and colors.max() < 256\n    return np.uint8(colors)\n\n\ndef segment_sky(image):\n    import cv2\n    from scipy import ndimage\n\n    # Convert to HSV\n    image = to_numpy(image)\n    if np.issubdtype(image.dtype, np.floating):\n        image = np.uint8(255*image.clip(min=0, max=1))\n    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)\n\n    # Define range for blue color and create mask\n    lower_blue = np.array([0, 0, 100])\n    upper_blue = np.array([30, 255, 255])\n    mask = cv2.inRange(hsv, lower_blue, upper_blue).view(bool)\n\n    # add luminous gray\n    mask |= (hsv[:, :, 1] < 10) & (hsv[:, :, 2] > 150)\n    mask |= (hsv[:, :, 1] < 30) & (hsv[:, :, 2] > 180)\n    mask |= (hsv[:, :, 1] < 50) & (hsv[:, :, 2] > 220)\n\n    # Morphological operations\n    kernel = np.ones((5, 5), np.uint8)\n    mask2 = ndimage.binary_opening(mask, structure=kernel)\n\n    # keep only largest CC\n    _, labels, stats, _ = cv2.connectedComponentsWithStats(mask2.view(np.uint8), connectivity=8)\n    cc_sizes = stats[1:, cv2.CC_STAT_AREA]\n    order = cc_sizes.argsort()[::-1]  # bigger first\n    i = 0\n    selection = []\n    while i < len(order) 
and cc_sizes[order[i]] > cc_sizes[order[0]] / 2:\n        selection.append(1 + order[i])\n        i += 1\n    # np.in1d is deprecated in recent NumPy; np.isin keeps the shape of labels\n    mask3 = np.isin(labels, selection)\n\n    # Apply mask\n    return torch.from_numpy(mask3)\n"
  },
  {
    "path": "dust3r_visloc/README.md",
    "content": "# Visual Localization with DUSt3R\n\n## Dataset preparation\n\n### CambridgeLandmarks\n\nEach subscene should look like this:\n\n```\nCambridge_Landmarks\n├─ mapping\n│   ├─ GreatCourt\n│   │  └─ colmap/reconstruction\n│   │     ├─ cameras.txt\n│   │     ├─ images.txt\n│   │     └─ points3D.txt\n├─ kapture\n│   ├─ GreatCourt\n│   │  └─ query  # https://github.com/naver/kapture/blob/main/doc/datasets.adoc#cambridge-landmarks\n│   ...\n├─ GreatCourt\n│   ├─ pairsfile/query\n│   │     └─ APGeM-LM18_top50.txt  # https://github.com/naver/deep-image-retrieval/blob/master/dirtorch/extract_kapture.py followed by https://github.com/naver/kapture-localization/blob/main/tools/kapture_compute_image_pairs.py\n│   ├─ seq1\n│   ...\n...\n```\n\n### 7Scenes\n\nEach subscene should look like this:\n\n```\n7-scenes\n├─ chess\n│   ├─ mapping/  # https://github.com/naver/kapture/blob/main/doc/datasets.adoc#1-7-scenes\n│   ├─ query/  # https://github.com/naver/kapture/blob/main/doc/datasets.adoc#1-7-scenes\n│   └─ pairfiles/query/  # the 7Scenes loader reads from 'pairfiles', not 'pairsfile'\n│         └─ APGeM-LM18_top20.txt  # https://github.com/naver/deep-image-retrieval/blob/master/dirtorch/extract_kapture.py followed by https://github.com/naver/kapture-localization/blob/main/tools/kapture_compute_image_pairs.py\n...\n```\n\n### Aachen-Day-Night\n\n```\nAachen-Day-Night-v1.1\n├─ mapping\n│   ├─ colmap/reconstruction\n│   │  ├─ cameras.txt\n│   │  ├─ images.txt\n│   │  └─ points3D.txt\n├─ kapture\n│   └─ query  # https://github.com/naver/kapture/blob/main/doc/datasets.adoc#2-aachen-day-night-v11\n├─ images\n│   ├─ db\n│   ├─ query\n│   └─ sequences\n└─ pairsfile/query\n    └─ fire_top50.txt  # https://github.com/naver/fire/blob/main/kapture_compute_pairs.py\n```\n\n### InLoc\n\n```\nInLoc\n├─ mapping  # https://github.com/naver/kapture/blob/main/doc/datasets.adoc#6-inloc\n├─ query    # https://github.com/naver/kapture/blob/main/doc/datasets.adoc#6-inloc\n└─ pairfiles/query  # the InLoc loader reads from 'pairfiles', not 'pairsfile'\n    └─ pairs-query-netvlad40-temporal.txt  # 
https://github.com/cvg/Hierarchical-Localization/blob/master/pairs/inloc/pairs-query-netvlad40-temporal.txt\n```\n\n## Example Commands\n\nWith `visloc.py` you can run our visual localization experiments on Aachen-Day-Night, InLoc, Cambridge Landmarks and 7Scenes. In the commands below, `${scene}` is a shell variable that you set to one of the listed subscenes before running the command.\n\n```bash\n# Aachen-Day-Night-v1.1:\n# scene in 'day' 'night' 'all'\npython3 visloc.py --model_name DUSt3R_ViTLarge_BaseDecoder_512_dpt --dataset \"VislocAachenDayNight('/path/to/prepared/Aachen-Day-Night-v1.1/', subscene='${scene}', pairsfile='fire_top50', topk=20)\" --pnp_mode poselib --reprojection_error_diag_ratio 0.008 --output_dir /path/to/output/Aachen-Day-Night-v1.1/${scene}/loc\n\n# InLoc:\npython3 visloc.py --model_name DUSt3R_ViTLarge_BaseDecoder_512_dpt --dataset \"VislocInLoc('/path/to/prepared/InLoc/', pairsfile='pairs-query-netvlad40-temporal', topk=20)\" --pnp_mode poselib --reprojection_error_diag_ratio 0.008 --output_dir /path/to/output/InLoc/loc\n\n# 7-scenes:\n# scene in 'chess' 'fire' 'heads' 'office' 'pumpkin' 'redkitchen' 'stairs'\npython3 visloc.py --model_name DUSt3R_ViTLarge_BaseDecoder_512_dpt --dataset \"VislocSevenScenes('/path/to/prepared/7-scenes/', subscene='${scene}', pairsfile='APGeM-LM18_top20', topk=1)\" --pnp_mode poselib --reprojection_error_diag_ratio 0.008 --output_dir /path/to/output/7-scenes/${scene}/loc\n\n# Cambridge Landmarks:\n# scene in 'ShopFacade' 'GreatCourt' 'KingsCollege' 'OldHospital' 'StMarysChurch'\npython3 visloc.py --model_name DUSt3R_ViTLarge_BaseDecoder_512_dpt --dataset \"VislocCambridgeLandmarks('/path/to/prepared/Cambridge_Landmarks/', subscene='${scene}', pairsfile='APGeM-LM18_top50', topk=20)\" --pnp_mode poselib --reprojection_error_diag_ratio 0.008 --output_dir /path/to/output/Cambridge_Landmarks/${scene}/loc\n```\n"
  },
  {
    "path": "dust3r_visloc/__init__.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n"
  },
  {
    "path": "dust3r_visloc/datasets/__init__.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\nfrom .sevenscenes import VislocSevenScenes\nfrom .cambridge_landmarks import VislocCambridgeLandmarks\nfrom .aachen_day_night import VislocAachenDayNight\nfrom .inloc import VislocInLoc\n"
  },
  {
    "path": "dust3r_visloc/datasets/aachen_day_night.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# AachenDayNight dataloader\n# --------------------------------------------------------\nimport os\nfrom dust3r_visloc.datasets.base_colmap import BaseVislocColmapDataset\n\n\nclass VislocAachenDayNight(BaseVislocColmapDataset):\n    def __init__(self, root, subscene, pairsfile, topk=1, cache_sfm=False):\n        assert subscene in [None, '', 'day', 'night', 'all']\n        self.subscene = subscene\n        image_path = os.path.join(root, 'images')\n        map_path = os.path.join(root, 'mapping/colmap/reconstruction')\n        query_path = os.path.join(root, 'kapture', 'query')\n        pairsfile_path = os.path.join(root, 'pairsfile/query', pairsfile + '.txt')\n        super().__init__(image_path=image_path, map_path=map_path,\n                         query_path=query_path, pairsfile_path=pairsfile_path,\n                         topk=topk, cache_sfm=cache_sfm)\n        self.scenes = [filename for filename in self.scenes if filename in self.pairs]\n        if self.subscene == 'day' or self.subscene == 'night':\n            self.scenes = [filename for filename in self.scenes if self.subscene in filename]\n"
  },
  {
    "path": "dust3r_visloc/datasets/base_colmap.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# Base class for colmap / kapture\n# --------------------------------------------------------\nimport os\nimport numpy as np\nfrom tqdm import tqdm\nimport collections\nimport pickle\nimport PIL.Image\nimport torch\nfrom scipy.spatial.transform import Rotation\nimport torchvision.transforms as tvf\n\nfrom kapture.core import CameraType\nfrom kapture.io.csv import kapture_from_dir\nfrom kapture_localization.utils.pairsfile import get_ordered_pairs_from_file\n\nfrom dust3r_visloc.datasets.utils import cam_to_world_from_kapture, get_resize_function, rescale_points3d\nfrom dust3r_visloc.datasets.base_dataset import BaseVislocDataset\nfrom dust3r.datasets.utils.transforms import ImgNorm\nfrom dust3r.utils.geometry import colmap_to_opencv_intrinsics\n\nKaptureSensor = collections.namedtuple('Sensor', 'sensor_params camera_params')\n\n\ndef kapture_to_opencv_intrinsics(sensor):\n    \"\"\"\n    Convert from Kapture to OpenCV parameters.\n    Warning: we assume that the camera and pixel coordinates follow Colmap conventions here.\n    Args:\n        sensor: Kapture sensor\n    \"\"\"\n    sensor_type = sensor.sensor_params[0]\n    if sensor_type == \"SIMPLE_PINHOLE\":\n        # Simple pinhole model.\n        # We still call OpenCV undistorsion however for code simplicity.\n        w, h, f, cx, cy = sensor.camera_params\n        k1 = 0\n        k2 = 0\n        p1 = 0\n        p2 = 0\n        fx = fy = f\n    elif sensor_type == \"PINHOLE\":\n        w, h, fx, fy, cx, cy = sensor.camera_params\n        k1 = 0\n        k2 = 0\n        p1 = 0\n        p2 = 0\n    elif sensor_type == \"SIMPLE_RADIAL\":\n        w, h, f, cx, cy, k1 = sensor.camera_params\n        k2 = 0\n        p1 = 0\n        p2 = 0\n        fx = fy = f\n    elif sensor_type == \"RADIAL\":\n        w, 
h, f, cx, cy, k1, k2 = sensor.camera_params\n        p1 = 0\n        p2 = 0\n        fx = fy = f\n    elif sensor_type == \"OPENCV\":\n        w, h, fx, fy, cx, cy, k1, k2, p1, p2 = sensor.camera_params\n    else:\n        raise NotImplementedError(f\"Sensor type {sensor_type} is not supported yet.\")\n\n    cameraMatrix = np.asarray([[fx, 0, cx],\n                               [0, fy, cy],\n                               [0, 0, 1]], dtype=np.float32)\n\n    # We assume that Kapture data comes from Colmap: the origin is different.\n    cameraMatrix = colmap_to_opencv_intrinsics(cameraMatrix)\n\n    distCoeffs = np.asarray([k1, k2, p1, p2], dtype=np.float32)\n    return cameraMatrix, distCoeffs, (w, h)\n\n\ndef K_from_colmap(elems):\n    sensor = KaptureSensor(elems, tuple(map(float, elems[1:])))\n    cameraMatrix, distCoeffs, (w, h) = kapture_to_opencv_intrinsics(sensor)\n    res = dict(resolution=(w, h),\n               intrinsics=cameraMatrix,\n               distortion=distCoeffs)\n    return res\n\n\ndef pose_from_qwxyz_txyz(elems):\n    qw, qx, qy, qz, tx, ty, tz = map(float, elems)\n    pose = np.eye(4)\n    pose[:3, :3] = Rotation.from_quat((qx, qy, qz, qw)).as_matrix()\n    pose[:3, 3] = (tx, ty, tz)\n    return np.linalg.inv(pose)  # returns cam2world\n\n\nclass BaseVislocColmapDataset(BaseVislocDataset):\n    def __init__(self, image_path, map_path, query_path, pairsfile_path, topk=1, cache_sfm=False):\n        super().__init__()\n        self.topk = topk\n        self.num_views = self.topk + 1\n        self.image_path = image_path\n        self.cache_sfm = cache_sfm\n\n        self._load_sfm(map_path)\n\n        kdata_query = kapture_from_dir(query_path)\n        assert kdata_query.records_camera is not None and kdata_query.trajectories is not None\n\n        kdata_query_searchindex = {kdata_query.records_camera[(timestamp, sensor_id)]: (timestamp, sensor_id)\n                                   for timestamp, sensor_id in 
kdata_query.records_camera.key_pairs()}\n        self.query_data = {'kdata': kdata_query, 'searchindex': kdata_query_searchindex}\n\n        self.pairs = get_ordered_pairs_from_file(pairsfile_path)\n        self.scenes = kdata_query.records_camera.data_list()\n\n    def _load_sfm(self, sfm_dir):\n        sfm_cache_path = os.path.join(sfm_dir, 'dust3r_cache.pkl')\n        if os.path.isfile(sfm_cache_path) and self.cache_sfm:\n            with open(sfm_cache_path, \"rb\") as f:\n                data = pickle.load(f)\n                self.img_infos = data['img_infos']\n                self.points3D = data['points3D']\n            return\n\n        # load cameras\n        with open(os.path.join(sfm_dir, 'cameras.txt'), 'r') as f:\n            raw = f.read().splitlines()[3:]  # skip header\n\n        intrinsics = {}\n        for camera in tqdm(raw):\n            camera = camera.split(' ')\n            intrinsics[int(camera[0])] = K_from_colmap(camera[1:])\n\n        # load images\n        with open(os.path.join(sfm_dir, 'images.txt'), 'r') as f:\n            raw = f.read().splitlines()\n            raw = [line for line in raw if not line.startswith('#')]  # skip header\n\n        self.img_infos = {}\n        for image, points in tqdm(zip(raw[0::2], raw[1::2]), total=len(raw) // 2):\n            image = image.split(' ')\n            points = points.split(' ')\n\n            img_name = image[-1]\n            current_points2D = {int(i): (float(x), float(y))\n                                for i, x, y in zip(points[2::3], points[0::3], points[1::3]) if i != '-1'}\n            self.img_infos[img_name] = dict(intrinsics[int(image[-2])],\n                                            path=img_name,\n                                            camera_pose=pose_from_qwxyz_txyz(image[1: -2]),\n                                            sparse_pts2d=current_points2D)\n\n        # load 3D points\n        with open(os.path.join(sfm_dir, 'points3D.txt'), 'r') as f:\n            raw = 
f.read().splitlines()\n            raw = [line for line in raw if not line.startswith('#')]  # skip header\n\n        self.points3D = {}\n        for point in tqdm(raw):\n            point = point.split()\n            self.points3D[int(point[0])] = tuple(map(float, point[1:4]))\n\n        if self.cache_sfm:\n            to_save = \\\n                {\n                    'img_infos': self.img_infos,\n                    'points3D': self.points3D\n                }\n            with open(sfm_cache_path, \"wb\") as f:\n                pickle.dump(to_save, f)\n\n    def __len__(self):\n        return len(self.scenes)\n\n    def _get_view_query(self, imgname):\n        kdata, searchindex = map(self.query_data.get, ['kdata', 'searchindex'])\n\n        timestamp, camera_id = searchindex[imgname]\n\n        camera_params = kdata.sensors[camera_id].camera_params\n        if kdata.sensors[camera_id].camera_type == CameraType.SIMPLE_PINHOLE:\n            W, H, f, cx, cy = camera_params\n            k1 = 0\n            fx = fy = f\n        elif kdata.sensors[camera_id].camera_type == CameraType.SIMPLE_RADIAL:\n            W, H, f, cx, cy, k1 = camera_params\n            fx = fy = f\n        else:\n            raise NotImplementedError('not implemented')\n\n        W, H = int(W), int(H)\n        intrinsics = np.float32([(fx, 0, cx),\n                                 (0, fy, cy),\n                                 (0, 0, 1)])\n        intrinsics = colmap_to_opencv_intrinsics(intrinsics)\n        distortion = [k1, 0, 0, 0]\n\n        if kdata.trajectories is not None and (timestamp, camera_id) in kdata.trajectories:\n            cam_to_world = cam_to_world_from_kapture(kdata, timestamp, camera_id)\n        else:\n            cam_to_world = np.eye(4, dtype=np.float32)\n\n        # Load RGB image\n        rgb_image = PIL.Image.open(os.path.join(self.image_path, imgname)).convert('RGB')\n        rgb_image.load()\n        resize_func, _, to_orig = get_resize_function(self.maxdim, 
self.patch_size, H, W)\n        rgb_tensor = resize_func(ImgNorm(rgb_image))\n\n        view = {\n            'intrinsics': intrinsics,\n            'distortion': distortion,\n            'cam_to_world': cam_to_world,\n            'rgb': rgb_image,\n            'rgb_rescaled': rgb_tensor,\n            'to_orig': to_orig,\n            'idx': 0,\n            'image_name': imgname\n        }\n        return view\n\n    def _get_view_map(self, imgname, idx):\n        infos = self.img_infos[imgname]\n\n        rgb_image = PIL.Image.open(os.path.join(self.image_path, infos['path'])).convert('RGB')\n        rgb_image.load()\n        W, H = rgb_image.size\n        intrinsics = infos['intrinsics']\n        intrinsics = colmap_to_opencv_intrinsics(intrinsics)\n        distortion_coefs = infos['distortion']\n\n        pts2d = infos['sparse_pts2d']\n        sparse_pos2d = np.float32(list(pts2d.values())).reshape((-1, 2))  # pts2d from colmap\n        sparse_pts3d = np.float32([self.points3D[i] for i in pts2d]).reshape((-1, 3))\n\n        # store full resolution 2D->3D\n        sparse_pos2d_cv2 = sparse_pos2d.copy()\n        sparse_pos2d_cv2[:, 0] -= 0.5\n        sparse_pos2d_cv2[:, 1] -= 0.5\n        sparse_pos2d_int = sparse_pos2d_cv2.round().astype(np.int64)\n        valid = (sparse_pos2d_int[:, 0] >= 0) & (sparse_pos2d_int[:, 0] < W) & (\n            sparse_pos2d_int[:, 1] >= 0) & (sparse_pos2d_int[:, 1] < H)\n        sparse_pos2d_int = sparse_pos2d_int[valid]\n        # nan => invalid\n        pts3d = np.full((H, W, 3), np.nan, dtype=np.float32)\n        pts3d[sparse_pos2d_int[:, 1], sparse_pos2d_int[:, 0]] = sparse_pts3d[valid]\n        pts3d = torch.from_numpy(pts3d)\n\n        cam_to_world = infos['camera_pose']  # cam2world\n\n        # also store resized resolution 2D->3D\n        resize_func, to_resize, to_orig = get_resize_function(self.maxdim, self.patch_size, H, W)\n        rgb_tensor = resize_func(ImgNorm(rgb_image))\n\n        HR, WR = rgb_tensor.shape[1:]\n     
   _, _, pts3d_rescaled, valid_rescaled = rescale_points3d(sparse_pos2d_cv2, sparse_pts3d, to_resize, HR, WR)\n        pts3d_rescaled = torch.from_numpy(pts3d_rescaled)\n        valid_rescaled = torch.from_numpy(valid_rescaled)\n\n        view = {\n            'intrinsics': intrinsics,\n            'distortion': distortion_coefs,\n            'cam_to_world': cam_to_world,\n            'rgb': rgb_image,\n            \"pts3d\": pts3d,\n            \"valid\": pts3d.sum(dim=-1).isfinite(),\n            'rgb_rescaled': rgb_tensor,\n            \"pts3d_rescaled\": pts3d_rescaled,\n            \"valid_rescaled\": valid_rescaled,\n            'to_orig': to_orig,\n            'idx': idx,\n            'image_name': imgname\n        }\n        return view\n\n    def __getitem__(self, idx):\n        assert self.maxdim is not None and self.patch_size is not None\n        query_image = self.scenes[idx]\n        map_images = [p[0] for p in self.pairs[query_image][:self.topk]]\n        views = []\n        views.append(self._get_view_query(query_image))\n        for idx, map_image in enumerate(map_images):\n            views.append(self._get_view_map(map_image, idx + 1))\n        return views\n"
  },
  {
    "path": "dust3r_visloc/datasets/base_dataset.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# Base class\n# --------------------------------------------------------\nclass BaseVislocDataset:\n    def __init__(self):\n        # Unset until set_resolution() is called; subclasses assert on these in __getitem__,\n        # so they must exist as attributes even before a model is attached.\n        self.maxdim = None\n        self.patch_size = None\n\n    def set_resolution(self, model):\n        # Take the target image size and patch size from the model's patch embedding.\n        self.maxdim = max(model.patch_embed.img_size)\n        self.patch_size = model.patch_embed.patch_size\n\n    def __len__(self):\n        raise NotImplementedError()\n\n    def __getitem__(self, idx):\n        raise NotImplementedError()\n"
  },
  {
    "path": "dust3r_visloc/datasets/cambridge_landmarks.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# Cambridge Landmarks dataloader\n# --------------------------------------------------------\nimport os\nfrom dust3r_visloc.datasets.base_colmap import BaseVislocColmapDataset\n\n\nclass VislocCambridgeLandmarks(BaseVislocColmapDataset):\n    def __init__(self, root, subscene, pairsfile, topk=1, cache_sfm=False):\n        image_path = os.path.join(root, subscene)\n        map_path = os.path.join(root, 'mapping', subscene, 'colmap/reconstruction')\n        query_path = os.path.join(root, 'kapture', subscene, 'query')\n        pairsfile_path = os.path.join(root, subscene, 'pairsfile/query', pairsfile + '.txt')\n        super().__init__(image_path=image_path, map_path=map_path,\n                         query_path=query_path, pairsfile_path=pairsfile_path,\n                         topk=topk, cache_sfm=cache_sfm)\n"
  },
  {
    "path": "dust3r_visloc/datasets/inloc.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# InLoc dataloader\n# --------------------------------------------------------\nimport os\nimport numpy as np\nimport torch\nimport PIL.Image\nimport scipy.io\n\nimport kapture\nfrom kapture.io.csv import kapture_from_dir\nfrom kapture_localization.utils.pairsfile import get_ordered_pairs_from_file\n\nfrom dust3r_visloc.datasets.utils import cam_to_world_from_kapture, get_resize_function, rescale_points3d\nfrom dust3r_visloc.datasets.base_dataset import BaseVislocDataset\nfrom dust3r.datasets.utils.transforms import ImgNorm\nfrom dust3r.utils.geometry import xy_grid, geotrf\n\n\ndef read_alignments(path_to_alignment):\n    aligns = {}\n    with open(path_to_alignment, \"r\") as fid:\n        while True:\n            line = fid.readline()\n            if not line:\n                break\n            if len(line) == 4:\n                trans_nr = line[:-1]\n                while line != 'After general icp:\\n':\n                    line = fid.readline()\n                line = fid.readline()\n                p = []\n                for i in range(4):\n                    elems = line.split(' ')\n                    line = fid.readline()\n                    for e in elems:\n                        if len(e) != 0:\n                            p.append(float(e))\n                P = np.array(p).reshape(4, 4)\n                aligns[trans_nr] = P\n    return aligns\n\n\nclass VislocInLoc(BaseVislocDataset):\n    def __init__(self, root, pairsfile, topk=1):\n        super().__init__()\n        self.root = root\n        self.topk = topk\n        self.num_views = self.topk + 1\n        self.maxdim = None\n        self.patch_size = None\n\n        query_path = os.path.join(self.root, 'query')\n        kdata_query = kapture_from_dir(query_path)\n        assert 
kdata_query.records_camera is not None\n        kdata_query_searchindex = {kdata_query.records_camera[(timestamp, sensor_id)]: (timestamp, sensor_id)\n                                   for timestamp, sensor_id in kdata_query.records_camera.key_pairs()}\n        self.query_data = {'path': query_path, 'kdata': kdata_query, 'searchindex': kdata_query_searchindex}\n\n        map_path = os.path.join(self.root, 'mapping')\n        kdata_map = kapture_from_dir(map_path)\n        assert kdata_map.records_camera is not None and kdata_map.trajectories is not None\n        kdata_map_searchindex = {kdata_map.records_camera[(timestamp, sensor_id)]: (timestamp, sensor_id)\n                                 for timestamp, sensor_id in kdata_map.records_camera.key_pairs()}\n        self.map_data = {'path': map_path, 'kdata': kdata_map, 'searchindex': kdata_map_searchindex}\n\n        try:\n            self.pairs = get_ordered_pairs_from_file(os.path.join(self.root, 'pairfiles/query', pairsfile + '.txt'))\n        except Exception as e:\n            # if using pairs from hloc\n            self.pairs = {}\n            with open(os.path.join(self.root, 'pairfiles/query', pairsfile + '.txt'), 'r') as fid:\n                lines = fid.readlines()\n                for line in lines:\n                    splits = line.rstrip(\"\\n\\r\").split(\" \")\n                    self.pairs.setdefault(splits[0].replace('query/', ''), []).append(\n                        (splits[1].replace('database/cutouts/', ''), 1.0)\n                    )\n\n        self.scenes = kdata_query.records_camera.data_list()\n\n        self.aligns_DUC1 = read_alignments(os.path.join(self.root, 'mapping/DUC1_alignment/all_transformations.txt'))\n        self.aligns_DUC2 = read_alignments(os.path.join(self.root, 'mapping/DUC2_alignment/all_transformations.txt'))\n\n    def __len__(self):\n        return len(self.scenes)\n\n    def __getitem__(self, idx):\n        assert self.maxdim is not None and self.patch_size is not 
None\n        query_image = self.scenes[idx]\n        map_images = [p[0] for p in self.pairs[query_image][:self.topk]]\n        views = []\n        dataarray = [(query_image, self.query_data, False)] + [(map_image, self.map_data, True)\n                                                               for map_image in map_images]\n        for idx, (imgname, data, should_load_depth) in enumerate(dataarray):\n            imgpath, kdata, searchindex = map(data.get, ['path', 'kdata', 'searchindex'])\n\n            timestamp, camera_id = searchindex[imgname]\n\n            # for InLoc, SIMPLE_PINHOLE\n            camera_params = kdata.sensors[camera_id].camera_params\n            W, H, f, cx, cy = camera_params\n            distortion = [0, 0, 0, 0]\n            intrinsics = np.float32([(f, 0, cx),\n                                     (0, f, cy),\n                                     (0, 0, 1)])\n\n            if kdata.trajectories is not None and (timestamp, camera_id) in kdata.trajectories:\n                cam_to_world = cam_to_world_from_kapture(kdata, timestamp, camera_id)\n            else:\n                cam_to_world = np.eye(4, dtype=np.float32)\n\n            # Load RGB image\n            rgb_image = PIL.Image.open(os.path.join(imgpath, 'sensors/records_data', imgname)).convert('RGB')\n            rgb_image.load()\n\n            W, H = rgb_image.size\n            resize_func, to_resize, to_orig = get_resize_function(self.maxdim, self.patch_size, H, W)\n\n            rgb_tensor = resize_func(ImgNorm(rgb_image))\n\n            view = {\n                'intrinsics': intrinsics,\n                'distortion': distortion,\n                'cam_to_world': cam_to_world,\n                'rgb': rgb_image,\n                'rgb_rescaled': rgb_tensor,\n                'to_orig': to_orig,\n                'idx': idx,\n                'image_name': imgname\n            }\n\n            # Load depthmap\n            if should_load_depth:\n                depthmap_filename = 
os.path.join(imgpath, 'sensors/records_data', imgname + '.mat')\n                depthmap = scipy.io.loadmat(depthmap_filename)\n\n                pt3d_cut = depthmap['XYZcut']\n                scene_id = imgname.replace('\\\\', '/').split('/')[1]\n                if imgname.startswith('DUC1'):\n                    pts3d_full = geotrf(self.aligns_DUC1[scene_id], pt3d_cut)\n                else:\n                    pts3d_full = geotrf(self.aligns_DUC2[scene_id], pt3d_cut)\n\n                pts3d_valid = np.isfinite(pts3d_full.sum(axis=-1))\n\n                pts3d = pts3d_full[pts3d_valid]\n                pts2d_int = xy_grid(W, H)[pts3d_valid]\n                pts2d = pts2d_int.astype(np.float64)\n\n                # nan => invalid\n                pts3d_full[~pts3d_valid] = np.nan\n                pts3d_full = torch.from_numpy(pts3d_full)\n                view['pts3d'] = pts3d_full\n                view[\"valid\"] = pts3d_full.sum(dim=-1).isfinite()\n\n                HR, WR = rgb_tensor.shape[1:]\n                _, _, pts3d_rescaled, valid_rescaled = rescale_points3d(pts2d, pts3d, to_resize, HR, WR)\n                pts3d_rescaled = torch.from_numpy(pts3d_rescaled)\n                valid_rescaled = torch.from_numpy(valid_rescaled)\n                view['pts3d_rescaled'] = pts3d_rescaled\n                view[\"valid_rescaled\"] = valid_rescaled\n            views.append(view)\n        return views\n"
  },
  {
    "path": "dust3r_visloc/datasets/sevenscenes.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# 7 Scenes dataloader\n# --------------------------------------------------------\nimport os\nimport numpy as np\nimport torch\nimport PIL.Image\n\nimport kapture\nfrom kapture.io.csv import kapture_from_dir\nfrom kapture_localization.utils.pairsfile import get_ordered_pairs_from_file\nfrom kapture.io.records import depth_map_from_file\n\nfrom dust3r_visloc.datasets.utils import cam_to_world_from_kapture, get_resize_function, rescale_points3d\nfrom dust3r_visloc.datasets.base_dataset import BaseVislocDataset\nfrom dust3r.datasets.utils.transforms import ImgNorm\nfrom dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates, xy_grid, geotrf\n\n\nclass VislocSevenScenes(BaseVislocDataset):\n    def __init__(self, root, subscene, pairsfile, topk=1):\n        super().__init__()\n        self.root = root\n        self.subscene = subscene\n        self.topk = topk\n        self.num_views = self.topk + 1\n        self.maxdim = None\n        self.patch_size = None\n\n        query_path = os.path.join(self.root, subscene, 'query')\n        kdata_query = kapture_from_dir(query_path)\n        assert kdata_query.records_camera is not None and kdata_query.trajectories is not None and kdata_query.rigs is not None\n        kapture.rigs_remove_inplace(kdata_query.trajectories, kdata_query.rigs)\n        kdata_query_searchindex = {kdata_query.records_camera[(timestamp, sensor_id)]: (timestamp, sensor_id)\n                                   for timestamp, sensor_id in kdata_query.records_camera.key_pairs()}\n        self.query_data = {'path': query_path, 'kdata': kdata_query, 'searchindex': kdata_query_searchindex}\n\n        map_path = os.path.join(self.root, subscene, 'mapping')\n        kdata_map = kapture_from_dir(map_path)\n        assert 
kdata_map.records_camera is not None and kdata_map.trajectories is not None and kdata_map.rigs is not None\n        kapture.rigs_remove_inplace(kdata_map.trajectories, kdata_map.rigs)\n        kdata_map_searchindex = {kdata_map.records_camera[(timestamp, sensor_id)]: (timestamp, sensor_id)\n                                 for timestamp, sensor_id in kdata_map.records_camera.key_pairs()}\n        self.map_data = {'path': map_path, 'kdata': kdata_map, 'searchindex': kdata_map_searchindex}\n\n        self.pairs = get_ordered_pairs_from_file(os.path.join(self.root, subscene,\n                                                              'pairfiles/query',\n                                                              pairsfile + '.txt'))\n        self.scenes = kdata_query.records_camera.data_list()\n\n    def __len__(self):\n        return len(self.scenes)\n\n    def __getitem__(self, idx):\n        assert self.maxdim is not None and self.patch_size is not None\n        query_image = self.scenes[idx]\n        map_images = [p[0] for p in self.pairs[query_image][:self.topk]]\n        views = []\n        dataarray = [(query_image, self.query_data, False)] + [(map_image, self.map_data, True)\n                                                               for map_image in map_images]\n        for idx, (imgname, data, should_load_depth) in enumerate(dataarray):\n            imgpath, kdata, searchindex = map(data.get, ['path', 'kdata', 'searchindex'])\n\n            timestamp, camera_id = searchindex[imgname]\n\n            # for 7scenes, SIMPLE_PINHOLE\n            camera_params = kdata.sensors[camera_id].camera_params\n            W, H, f, cx, cy = camera_params\n            distortion = [0, 0, 0, 0]\n            intrinsics = np.float32([(f, 0, cx),\n                                     (0, f, cy),\n                                     (0, 0, 1)])\n\n            cam_to_world = cam_to_world_from_kapture(kdata, timestamp, camera_id)\n\n            # Load RGB image\n          
  rgb_image = PIL.Image.open(os.path.join(imgpath, 'sensors/records_data', imgname)).convert('RGB')\n            rgb_image.load()\n\n            W, H = rgb_image.size\n            resize_func, to_resize, to_orig = get_resize_function(self.maxdim, self.patch_size, H, W)\n\n            rgb_tensor = resize_func(ImgNorm(rgb_image))\n\n            view = {\n                'intrinsics': intrinsics,\n                'distortion': distortion,\n                'cam_to_world': cam_to_world,\n                'rgb': rgb_image,\n                'rgb_rescaled': rgb_tensor,\n                'to_orig': to_orig,\n                'idx': idx,\n                'image_name': imgname\n            }\n\n            # Load depthmap\n            if should_load_depth:\n                depthmap_filename = os.path.join(imgpath, 'sensors/records_data',\n                                                 imgname.replace('color.png', 'depth.reg'))\n                depthmap = depth_map_from_file(depthmap_filename, (int(W), int(H))).astype(np.float32)\n                pts3d_full, pts3d_valid = depthmap_to_absolute_camera_coordinates(depthmap, intrinsics, cam_to_world)\n\n                pts3d = pts3d_full[pts3d_valid]\n                pts2d_int = xy_grid(W, H)[pts3d_valid]\n                pts2d = pts2d_int.astype(np.float64)\n\n                # nan => invalid\n                pts3d_full[~pts3d_valid] = np.nan\n                pts3d_full = torch.from_numpy(pts3d_full)\n                view['pts3d'] = pts3d_full\n                view[\"valid\"] = pts3d_full.sum(dim=-1).isfinite()\n\n                HR, WR = rgb_tensor.shape[1:]\n                _, _, pts3d_rescaled, valid_rescaled = rescale_points3d(pts2d, pts3d, to_resize, HR, WR)\n                pts3d_rescaled = torch.from_numpy(pts3d_rescaled)\n                valid_rescaled = torch.from_numpy(valid_rescaled)\n                view['pts3d_rescaled'] = pts3d_rescaled\n                view[\"valid_rescaled\"] = valid_rescaled\n            
views.append(view)\n        return views\n"
  },
  {
    "path": "dust3r_visloc/datasets/utils.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# dataset utilities\n# --------------------------------------------------------\nimport numpy as np\nimport quaternion\nimport torchvision.transforms as tvf\nfrom dust3r.utils.geometry import geotrf\n\n\ndef cam_to_world_from_kapture(kdata, timestamp, camera_id):\n    camera_to_world = kdata.trajectories[timestamp, camera_id].inverse()\n    camera_pose = np.eye(4, dtype=np.float32)\n    camera_pose[:3, :3] = quaternion.as_rotation_matrix(camera_to_world.r)\n    camera_pose[:3, 3] = camera_to_world.t_raw\n    return camera_pose\n\n\nratios_resolutions = {\n    224: {1.0: [224, 224]},\n    512: {4 / 3: [512, 384], 32 / 21: [512, 336], 16 / 9: [512, 288], 2 / 1: [512, 256], 16 / 5: [512, 160]}\n}\n\n\ndef get_HW_resolution(H, W, maxdim, patchsize=16):\n    assert maxdim in ratios_resolutions, \"Error, maxdim can only be 224 or 512 for now. Other maxdims not implemented yet.\"\n    ratios_resolutions_maxdim = ratios_resolutions[maxdim]\n    mindims = set([min(res) for res in ratios_resolutions_maxdim.values()])\n    ratio = W / H\n    ref_ratios = np.array([*(ratios_resolutions_maxdim.keys())])\n    islandscape = (W >= H)\n    if islandscape:\n        diff = np.abs(ratio - ref_ratios)\n    else:\n        diff = np.abs(ratio - (1 / ref_ratios))\n    selkey = ref_ratios[np.argmin(diff)]\n    res = ratios_resolutions_maxdim[selkey]\n    # check patchsize and make sure output resolution is a multiple of patchsize\n    if isinstance(patchsize, tuple):\n        assert len(patchsize) == 2 and isinstance(patchsize[0], int) and isinstance(\n            patchsize[1], int), \"What is your patchsize format? 
Expected a single int or a tuple of two ints.\"\n        assert patchsize[0] == patchsize[1], \"Error, non square patches not managed\"\n        patchsize = patchsize[0]\n    assert max(res) == maxdim\n    assert min(res) in mindims\n    return res[::-1] if islandscape else res  # return HW\n\n\ndef get_resize_function(maxdim, patch_size, H, W, is_mask=False):\n    if [max(H, W), min(H, W)] in ratios_resolutions[maxdim].values():\n        return lambda x: x, np.eye(3), np.eye(3)\n    else:\n        target_HW = get_HW_resolution(H, W, maxdim=maxdim, patchsize=patch_size)\n\n        ratio = W / H\n        target_ratio = target_HW[1] / target_HW[0]\n        to_orig_crop = np.eye(3)\n        to_rescaled_crop = np.eye(3)\n        if abs(ratio - target_ratio) < np.finfo(np.float32).eps:\n            crop_W = W\n            crop_H = H\n        elif ratio - target_ratio < 0:\n            crop_W = W\n            crop_H = int(W / target_ratio)\n            to_orig_crop[1, 2] = (H - crop_H) / 2.0\n            to_rescaled_crop[1, 2] = -(H - crop_H) / 2.0\n        else:\n            crop_W = int(H * target_ratio)\n            crop_H = H\n            to_orig_crop[0, 2] = (W - crop_W) / 2.0\n            to_rescaled_crop[0, 2] = - (W - crop_W) / 2.0\n\n        crop_op = tvf.CenterCrop([crop_H, crop_W])\n\n        if is_mask:\n            resize_op = tvf.Resize(size=target_HW, interpolation=tvf.InterpolationMode.NEAREST_EXACT)\n        else:\n            resize_op = tvf.Resize(size=target_HW)\n        to_orig_resize = np.array([[crop_W / target_HW[1], 0, 0],\n                                   [0, crop_H / target_HW[0], 0],\n                                   [0, 0, 1]])\n        to_rescaled_resize = np.array([[target_HW[1] / crop_W, 0, 0],\n                                       [0, target_HW[0] / crop_H, 0],\n                                       [0, 0, 1]])\n\n        op = tvf.Compose([crop_op, resize_op])\n\n        return op, to_rescaled_resize @ to_rescaled_crop, 
to_orig_crop @ to_orig_resize\n\n\ndef rescale_points3d(pts2d, pts3d, to_resize, HR, WR):\n    # rescale pts2d as floats\n    # to colmap, so that the image is in [0, D] -> [0, NewD]\n    pts2d = pts2d.copy()\n    pts2d[:, 0] += 0.5\n    pts2d[:, 1] += 0.5\n\n    pts2d_rescaled = geotrf(to_resize, pts2d, norm=True)\n\n    pts2d_rescaled_int = pts2d_rescaled.copy()\n    # convert back to cv2 before round [-0.5, 0.5] -> pixel 0\n    pts2d_rescaled_int[:, 0] -= 0.5\n    pts2d_rescaled_int[:, 1] -= 0.5\n    pts2d_rescaled_int = pts2d_rescaled_int.round().astype(np.int64)\n\n    # update valid (remove cropped regions)\n    valid_rescaled = (pts2d_rescaled_int[:, 0] >= 0) & (pts2d_rescaled_int[:, 0] < WR) & (\n        pts2d_rescaled_int[:, 1] >= 0) & (pts2d_rescaled_int[:, 1] < HR)\n\n    pts2d_rescaled_int = pts2d_rescaled_int[valid_rescaled]\n\n    # rebuild pts3d from rescaled ps2d poses\n    pts3d_rescaled = np.full((HR, WR, 3), np.nan, dtype=np.float32)  # pts3d in 512 x something\n    pts3d_rescaled[pts2d_rescaled_int[:, 1], pts2d_rescaled_int[:, 0]] = pts3d[valid_rescaled]\n\n    return pts2d_rescaled, pts2d_rescaled_int, pts3d_rescaled, np.isfinite(pts3d_rescaled.sum(axis=-1))\n"
  },
  {
    "path": "dust3r_visloc/evaluation.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# evaluation utilities\n# --------------------------------------------------------\nimport numpy as np\nimport quaternion\nimport torch\nimport roma\nimport collections\nimport os\n\n\ndef aggregate_stats(info_str, pose_errors, angular_errors):\n    stats = collections.Counter()\n    median_pos_error = np.median(pose_errors)\n    median_angular_error = np.median(angular_errors)\n    out_str = f'{info_str}: {len(pose_errors)} images - {median_pos_error=}, {median_angular_error=}'\n\n    for trl_thr, ang_thr in [(0.1, 1), (0.25, 2), (0.5, 5), (5, 10)]:\n        for pose_error, angular_error in zip(pose_errors, angular_errors):\n            correct_for_this_threshold = (pose_error < trl_thr) and (angular_error < ang_thr)\n            stats[trl_thr, ang_thr] += correct_for_this_threshold\n    stats = {f'acc@{key[0]:g}m,{key[1]}deg': 100 * val / len(pose_errors) for key, val in stats.items()}\n    for metric, perf in stats.items():\n        out_str += f'  - {metric:12s}={float(perf):.3f}'\n    return out_str\n\n\ndef get_pose_error(pr_camtoworld, gt_cam_to_world):\n    abs_transl_error = torch.linalg.norm(torch.tensor(pr_camtoworld[:3, 3]) - torch.tensor(gt_cam_to_world[:3, 3]))\n    abs_angular_error = roma.rotmat_geodesic_distance(torch.tensor(pr_camtoworld[:3, :3]),\n                                                      torch.tensor(gt_cam_to_world[:3, :3])) * 180 / np.pi\n    return abs_transl_error, abs_angular_error\n\n\ndef export_results(output_dir, xp_label, query_names, poses_pred):\n    if output_dir is not None:\n        os.makedirs(output_dir, exist_ok=True)\n\n        lines = \"\"\n        lines_ltvl = \"\"\n        for query_name, pr_querycam_to_world in zip(query_names, poses_pred):\n            if pr_querycam_to_world is None:\n                
pr_world_to_querycam = np.eye(4)\n            else:\n                pr_world_to_querycam = np.linalg.inv(pr_querycam_to_world)\n            query_shortname = os.path.basename(query_name)\n            pr_world_to_querycam_q = quaternion.from_rotation_matrix(pr_world_to_querycam[:3, :3])\n            pr_world_to_querycam_t = pr_world_to_querycam[:3, 3]\n\n            line_pose = quaternion.as_float_array(pr_world_to_querycam_q).tolist() + \\\n                pr_world_to_querycam_t.flatten().tolist()\n\n            line_content = [query_name] + line_pose\n            lines += ' '.join(str(v) for v in line_content) + '\\n'\n\n            line_content_ltvl = [query_shortname] + line_pose\n            lines_ltvl += ' '.join(str(v) for v in line_content_ltvl) + '\\n'\n\n        with open(os.path.join(output_dir, xp_label + '_results.txt'), 'wt') as f:\n            f.write(lines)\n        with open(os.path.join(output_dir, xp_label + '_ltvl.txt'), 'wt') as f:\n            f.write(lines_ltvl)\n"
  },
  {
    "path": "dust3r_visloc/localization.py",
    "content": "# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# main pnp code\n# --------------------------------------------------------\nimport numpy as np\nimport quaternion\nimport cv2\nfrom packaging import version\n\nfrom dust3r.utils.geometry import opencv_to_colmap_intrinsics\n\ntry:\n    import poselib  # noqa\n    HAS_POSELIB = True\nexcept Exception as e:\n    HAS_POSELIB = False\n\ntry:\n    import pycolmap  # noqa\n    version_number = pycolmap.__version__\n    if version.parse(version_number) < version.parse(\"0.5.0\"):\n        HAS_PYCOLMAP = False\n    else:\n        HAS_PYCOLMAP = True\nexcept Exception as e:\n    HAS_PYCOLMAP = False\n    \ndef run_pnp(pts2D, pts3D, K, distortion = None, mode='cv2', reprojectionError=5, img_size = None):\n    \"\"\"\n    use OPENCV model for distortion (4 values)\n    \"\"\"\n    assert mode in ['cv2', 'poselib', 'pycolmap']\n    try:\n        if len(pts2D) > 4 and mode == \"cv2\":\n            confidence = 0.9999\n            iterationsCount = 10_000\n            if distortion is not None:\n                cv2_pts2ds = np.copy(pts2D)\n                cv2_pts2ds = cv2.undistortPoints(cv2_pts2ds, K, np.array(distortion), R=None, P=K)\n                pts2D = cv2_pts2ds.reshape((-1, 2))\n\n            success, r_pose, t_pose, _ = cv2.solvePnPRansac(pts3D, pts2D, K, None, flags=cv2.SOLVEPNP_SQPNP,\n                                                            iterationsCount=iterationsCount,\n                                                            reprojectionError=reprojectionError,\n                                                            confidence=confidence)\n            if not success:\n                return False, None\n            r_pose = cv2.Rodrigues(r_pose)[0]  # world2cam == world2cam2\n            RT = np.r_[np.c_[r_pose, t_pose], [(0,0,0,1)]] # 
world2cam2\n            return True, np.linalg.inv(RT)  # cam2toworld\n        elif len(pts2D) > 4 and mode == \"poselib\":\n            assert HAS_POSELIB\n            confidence = 0.9999\n            iterationsCount = 10_000\n            # NOTE: `Camera` struct currently contains `width`/`height` fields,\n            # however these are not used anywhere in the code-base and are provided simply to be consistent with COLMAP.\n            # so we put garbage in there\n            colmap_intrinsics = opencv_to_colmap_intrinsics(K)\n            fx = colmap_intrinsics[0, 0]\n            fy = colmap_intrinsics[1, 1]\n            cx = colmap_intrinsics[0, 2]\n            cy = colmap_intrinsics[1, 2]\n            width = img_size[0] if img_size is not None else int(cx*2)\n            height = img_size[1] if img_size is not None else int(cy*2)\n\n            if distortion is None:\n                camera = {'model': 'PINHOLE', 'width': width, 'height': height, 'params': [fx, fy, cx, cy]}\n            else:\n                camera = {'model': 'OPENCV', 'width': width, 'height': height,\n                          'params': [fx, fy, cx, cy] + distortion}\n            \n            pts2D = np.copy(pts2D)\n            pts2D[:, 0] += 0.5\n            pts2D[:, 1] += 0.5\n            pose, _ = poselib.estimate_absolute_pose(pts2D, pts3D, camera,\n                                                        {'max_reproj_error': reprojectionError,\n                                                        'max_iterations': iterationsCount,\n                                                        'success_prob': confidence}, {})\n            if pose is None:\n                return False, None\n            RT = pose.Rt  # (3x4)\n            RT = np.r_[RT, [(0,0,0,1)]]  # world2cam\n            return True, np.linalg.inv(RT)  # cam2toworld\n        elif len(pts2D) > 4 and mode == \"pycolmap\":\n            assert HAS_PYCOLMAP\n            assert img_size is not None\n            \n          
  pts2D = np.copy(pts2D)\n            pts2D[:, 0] += 0.5\n            pts2D[:, 1] += 0.5\n            colmap_intrinsics = opencv_to_colmap_intrinsics(K)\n            fx = colmap_intrinsics[0, 0]\n            fy = colmap_intrinsics[1, 1]\n            cx = colmap_intrinsics[0, 2]\n            cy = colmap_intrinsics[1, 2]\n            width = img_size[0]\n            height = img_size[1]\n            if distortion is None:\n                camera_dict = {'model': 'PINHOLE', 'width': width, 'height': height, 'params': [fx, fy, cx, cy]}\n            else:\n                camera_dict = {'model': 'OPENCV', 'width': width, 'height': height,\n                               'params': [fx, fy, cx, cy] + distortion}\n\n            pycolmap_camera = pycolmap.Camera(\n            model=camera_dict['model'], width=camera_dict['width'], height=camera_dict['height'],\n            params=camera_dict['params'])\n\n            pycolmap_estimation_options = dict(ransac=dict(max_error=reprojectionError, min_inlier_ratio=0.01,\n                                               min_num_trials=1000, max_num_trials=100000,\n                                            confidence=0.9999))\n            pycolmap_refinement_options=dict(refine_focal_length=False, refine_extra_params=False)\n            ret = pycolmap.absolute_pose_estimation(pts2D, pts3D, pycolmap_camera,\n                                                    estimation_options=pycolmap_estimation_options,\n                                                    refinement_options=pycolmap_refinement_options)\n            if ret is None:\n                ret = {'success': False}\n            else:\n                ret['success'] = True\n                if callable(ret['cam_from_world'].matrix):\n                    retmat = ret['cam_from_world'].matrix()\n                else:\n                    retmat = ret['cam_from_world'].matrix\n                ret['qvec'] = quaternion.from_rotation_matrix(retmat[:3, :3])\n                
ret['tvec'] = retmat[:3, 3]\n                \n            if not (ret['success'] and ret['num_inliers'] > 0):\n                success = False\n                pose = None\n            else:\n                success = True\n                # reuse retmat: 'matrix' is a method or a plain attribute depending on the pycolmap version\n                pr_world_to_querycam = np.r_[retmat, [(0,0,0,1)]]\n                pose = np.linalg.inv(pr_world_to_querycam)\n            return success, pose\n        else:\n            return False, None\n    except Exception as e:\n        print(f'error during pnp: {e}')\n        return False, None"
  },
  {
    "path": "requirements.txt",
    "content": "torch\ntorchvision\nroma\ngradio\nmatplotlib\ntqdm\nopencv-python\nscipy\neinops\ntrimesh\ntensorboard\npyglet<2\nhuggingface-hub[torch]>=0.22"
  },
  {
    "path": "requirements_optional.txt",
    "content": "pillow-heif  # add heif/heic image support\npyrender  # for rendering depths in scannetpp\nkapture  # for visloc data loading\nkapture-localization\nnumpy-quaternion\npycolmap  # for pnp\nposelib  # for pnp\n"
  },
  {
    "path": "train.py",
    "content": "#!/usr/bin/env python3\n# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# training executable for DUSt3R\n# --------------------------------------------------------\nfrom dust3r.training import get_args_parser, train\n\nif __name__ == '__main__':\n    args = get_args_parser()\n    args = args.parse_args()\n    train(args)\n"
  },
  {
    "path": "visloc.py",
    "content": "#!/usr/bin/env python3\n# Copyright (C) 2024-present Naver Corporation. All rights reserved.\n# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).\n#\n# --------------------------------------------------------\n# Simple visloc script\n# --------------------------------------------------------\nimport numpy as np\nimport random\nimport argparse\nfrom tqdm import tqdm\nimport math\n\nfrom dust3r.inference import inference\nfrom dust3r.model import AsymmetricCroCo3DStereo\nfrom dust3r.utils.geometry import find_reciprocal_matches, xy_grid, geotrf\n\nfrom dust3r_visloc.datasets import *\nfrom dust3r_visloc.localization import run_pnp\nfrom dust3r_visloc.evaluation import get_pose_error, aggregate_stats, export_results\n\n\ndef get_args_parser():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--dataset\", type=str, required=True, help=\"visloc dataset to eval\")\n    parser_weights = parser.add_mutually_exclusive_group(required=True)\n    parser_weights.add_argument(\"--weights\", type=str, help=\"path to the model weights\", default=None)\n    parser_weights.add_argument(\"--model_name\", type=str, help=\"name of the model weights\",\n                                choices=[\"DUSt3R_ViTLarge_BaseDecoder_512_dpt\",\n                                         \"DUSt3R_ViTLarge_BaseDecoder_512_linear\",\n                                         \"DUSt3R_ViTLarge_BaseDecoder_224_linear\"])\n    parser.add_argument(\"--confidence_threshold\", type=float, default=3.0,\n                        help=\"confidence values higher than threshold are invalid\")\n    parser.add_argument(\"--device\", type=str, default='cuda', help=\"pytorch device\")\n    parser.add_argument(\"--pnp_mode\", type=str, default=\"cv2\", choices=['cv2', 'poselib', 'pycolmap'],\n                        help=\"pnp lib to use\")\n    parser_reproj = parser.add_mutually_exclusive_group()\n    parser_reproj.add_argument(\"--reprojection_error\", type=float, 
default=5.0, help=\"pnp reprojection error\")\n    parser_reproj.add_argument(\"--reprojection_error_diag_ratio\", type=float, default=None,\n                               help=\"pnp reprojection error as a ratio of the diagonal of the image\")\n\n    parser.add_argument(\"--pnp_max_points\", type=int, default=100_000, help=\"pnp maximum number of points kept\")\n    parser.add_argument(\"--viz_matches\", type=int, default=0, help=\"debug matches\")\n\n    parser.add_argument(\"--output_dir\", type=str, default=None, help=\"output path\")\n    parser.add_argument(\"--output_label\", type=str, default='', help=\"prefix for results files\")\n    return parser\n\n\nif __name__ == '__main__':\n    parser = get_args_parser()\n    args = parser.parse_args()\n    conf_thr = args.confidence_threshold\n    device = args.device\n    pnp_mode = args.pnp_mode\n    reprojection_error = args.reprojection_error\n    reprojection_error_diag_ratio = args.reprojection_error_diag_ratio\n    pnp_max_points = args.pnp_max_points\n    viz_matches = args.viz_matches\n\n    if args.weights is not None:\n        weights_path = args.weights\n    else:\n        weights_path = \"naver/\" + args.model_name\n    model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(args.device)\n\n    dataset = eval(args.dataset)\n    dataset.set_resolution(model)\n\n    query_names = []\n    poses_pred = []\n    pose_errors = []\n    angular_errors = []\n    for idx in tqdm(range(len(dataset))):\n        views = dataset[(idx)]  # 0 is the query\n        query_view = views[0]\n        map_views = views[1:]\n        query_names.append(query_view['image_name'])\n\n        query_pts2d = []\n        query_pts3d = []\n        for map_view in map_views:\n            # prepare batch\n            imgs = []\n            for idx, img in enumerate([query_view['rgb_rescaled'], map_view['rgb_rescaled']]):\n                imgs.append(dict(img=img.unsqueeze(0), true_shape=np.int32([img.shape[1:]]),\n              
                   idx=idx, instance=str(idx)))\n            output = inference([tuple(imgs)], model, device, batch_size=1, verbose=False)\n            pred1, pred2 = output['pred1'], output['pred2']\n            confidence_masks = [pred1['conf'].squeeze(0) >= conf_thr,\n                                (pred2['conf'].squeeze(0) >= conf_thr) & map_view['valid_rescaled']]\n            pts3d = [pred1['pts3d'].squeeze(0), pred2['pts3d_in_other_view'].squeeze(0)]\n\n            # find 2D-2D matches between the two images\n            pts2d_list, pts3d_list = [], []\n            for i in range(2):\n                conf_i = confidence_masks[i].cpu().numpy()\n                true_shape_i = imgs[i]['true_shape'][0]\n                pts2d_list.append(xy_grid(true_shape_i[1], true_shape_i[0])[conf_i])\n                pts3d_list.append(pts3d[i].detach().cpu().numpy()[conf_i])\n\n            PQ, PM = pts3d_list[0], pts3d_list[1]\n            if len(PQ) == 0 or len(PM) == 0:\n                continue\n            reciprocal_in_PM, nnM_in_PQ, num_matches = find_reciprocal_matches(PQ, PM)\n            if viz_matches > 0:\n                print(f'found {num_matches} matches')\n            matches_im1 = pts2d_list[1][reciprocal_in_PM]\n            matches_im0 = pts2d_list[0][nnM_in_PQ][reciprocal_in_PM]\n            valid_pts3d = map_view['pts3d_rescaled'][matches_im1[:, 1], matches_im1[:, 0]]\n\n            # from cv2 to colmap\n            matches_im0 = matches_im0.astype(np.float64)\n            matches_im1 = matches_im1.astype(np.float64)\n            matches_im0[:, 0] += 0.5\n            matches_im0[:, 1] += 0.5\n            matches_im1[:, 0] += 0.5\n            matches_im1[:, 1] += 0.5\n            # rescale coordinates: each image uses its own rescaled-to-original transform\n            matches_im0 = geotrf(query_view['to_orig'], matches_im0, norm=True)\n            matches_im1 = geotrf(map_view['to_orig'], matches_im1, norm=True)\n            # from colmap back to cv2\n            matches_im0[:, 0] -= 0.5\n            
matches_im0[:, 1] -= 0.5\n            matches_im1[:, 0] -= 0.5\n            matches_im1[:, 1] -= 0.5\n\n            # visualize a few matches\n            if viz_matches > 0:\n                viz_imgs = [np.array(query_view['rgb']), np.array(map_view['rgb'])]\n                from matplotlib import pyplot as pl\n                n_viz = viz_matches\n                match_idx_to_viz = np.round(np.linspace(0, num_matches - 1, n_viz)).astype(int)\n                viz_matches_im0, viz_matches_im1 = matches_im0[match_idx_to_viz], matches_im1[match_idx_to_viz]\n\n                H0, W0, H1, W1 = *viz_imgs[0].shape[:2], *viz_imgs[1].shape[:2]\n                img0 = np.pad(viz_imgs[0], ((0, max(H1 - H0, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)\n                img1 = np.pad(viz_imgs[1], ((0, max(H0 - H1, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)\n                img = np.concatenate((img0, img1), axis=1)\n                pl.figure()\n                pl.imshow(img)\n                cmap = pl.get_cmap('jet')\n                for i in range(n_viz):\n                    (x0, y0), (x1, y1) = viz_matches_im0[i].T, viz_matches_im1[i].T\n                    pl.plot([x0, x1 + W0], [y0, y1], '-+', color=cmap(i / (n_viz - 1)), scalex=False, scaley=False)\n                pl.show(block=True)\n\n            if len(valid_pts3d) == 0:\n                pass\n            else:\n                query_pts3d.append(valid_pts3d.cpu().numpy())\n                query_pts2d.append(matches_im0)\n\n        if len(query_pts2d) == 0:\n            success = False\n            pr_querycam_to_world = None\n        else:\n            query_pts2d = np.concatenate(query_pts2d, axis=0).astype(np.float32)\n            query_pts3d = np.concatenate(query_pts3d, axis=0)\n            if len(query_pts2d) > pnp_max_points:\n                idxs = random.sample(range(len(query_pts2d)), pnp_max_points)\n                query_pts3d = query_pts3d[idxs]\n                query_pts2d = 
query_pts2d[idxs]\n\n            W, H = query_view['rgb'].size\n            if reprojection_error_diag_ratio is not None:\n                reprojection_error_img = reprojection_error_diag_ratio * math.sqrt(W**2 + H**2)\n            else:\n                reprojection_error_img = reprojection_error\n            success, pr_querycam_to_world = run_pnp(query_pts2d, query_pts3d,\n                                                    query_view['intrinsics'], query_view['distortion'],\n                                                    pnp_mode, reprojection_error_img, img_size=[W, H])\n\n        if not success:\n            abs_transl_error = float('inf')\n            abs_angular_error = float('inf')\n        else:\n            abs_transl_error, abs_angular_error = get_pose_error(pr_querycam_to_world, query_view['cam_to_world'])\n\n        pose_errors.append(abs_transl_error)\n        angular_errors.append(abs_angular_error)\n        poses_pred.append(pr_querycam_to_world)\n\n    xp_label = f'tol_conf_{conf_thr}'\n    if args.output_label:\n        xp_label = args.output_label + '_' + xp_label\n    if reprojection_error_diag_ratio is not None:\n        xp_label = xp_label + f'_reproj_diag_{reprojection_error_diag_ratio}'\n    else:\n        xp_label = xp_label + f'_reproj_err_{reprojection_error}'\n    export_results(args.output_dir, xp_label, query_names, poses_pred)\n    out_string = aggregate_stats(f'{args.dataset}', pose_errors, angular_errors)\n    print(out_string)\n"
  }
]