Repository: facebookresearch/efm3d Branch: main Commit: 07950d73d147 Files: 92 Total size: 848.9 KB Directory structure: gitextract_3a354cf0/ ├── .github/ │ ├── CODE_OF_CONDUCT.md │ ├── CONTRIBUTING.md │ └── workflows/ │ └── conda-env.yaml ├── .gitignore ├── INSTALL.md ├── LICENSE ├── README.md ├── benchmark.md ├── efm3d/ │ ├── __init__.py │ ├── aria/ │ │ ├── __init__.py │ │ ├── aria_constants.py │ │ ├── camera.py │ │ ├── obb.py │ │ ├── pose.py │ │ ├── projection_utils.py │ │ └── tensor_wrapper.py │ ├── config/ │ │ ├── efm_preprocessing_conf.yaml │ │ ├── evl_inf.yaml │ │ ├── evl_inf_desktop.yaml │ │ ├── evl_train.yaml │ │ └── taxonomy/ │ │ ├── aeo_to_efm.csv │ │ ├── ase_sem_name_to_id.csv │ │ └── atek_to_efm.csv │ ├── dataset/ │ │ ├── atek_vrs_dataset.py │ │ ├── atek_wds_dataset.py │ │ ├── augmentation.py │ │ ├── efm_model_adaptor.py │ │ ├── vrs_dataset.py │ │ └── wds_dataset.py │ ├── inference/ │ │ ├── __init__.py │ │ ├── eval.py │ │ ├── fuse.py │ │ ├── model.py │ │ ├── pipeline.py │ │ ├── track.py │ │ └── viz.py │ ├── model/ │ │ ├── __init__.py │ │ ├── cnn.py │ │ ├── dinov2_utils.py │ │ ├── dpt.py │ │ ├── evl.py │ │ ├── evl_train.py │ │ ├── image_tokenizer.py │ │ ├── lifter.py │ │ └── video_backbone.py │ ├── thirdparty/ │ │ ├── __init__.py │ │ └── mmdetection3d/ │ │ ├── LICENSE │ │ ├── __init__.py │ │ ├── cuda/ │ │ │ ├── cuda_utils.h │ │ │ ├── iou3d.cpp │ │ │ ├── iou3d.h │ │ │ ├── iou3d_kernel.cu │ │ │ ├── setup.py │ │ │ ├── sort_vert.cpp │ │ │ ├── sort_vert.h │ │ │ ├── sort_vert_kernel.cu │ │ │ └── utils.h │ │ └── iou3d.py │ └── utils/ │ ├── __init__.py │ ├── common.py │ ├── depth.py │ ├── detection_utils.py │ ├── evl_loss.py │ ├── file_utils.py │ ├── gravity.py │ ├── image.py │ ├── image_sampling.py │ ├── marching_cubes.py │ ├── mesh_utils.py │ ├── obb_csv_writer.py │ ├── obb_io.py │ ├── obb_matchers.py │ ├── obb_metrics.py │ ├── obb_trackers.py │ ├── obb_utils.py │ ├── pointcloud.py │ ├── ray.py │ ├── reconstruction.py │ ├── render.py │ ├── rescale.py │ ├── 
viz.py │ ├── voxel.py │ └── voxel_sampling.py ├── environment-mac.yml ├── environment.yml ├── eval.py ├── infer.py ├── prepare_inference.sh ├── requirements-extra.txt ├── requirements.txt ├── sbatch_run.sh └── train.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/CODE_OF_CONDUCT.md ================================================ # Code of Conduct ## Our Pledge In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to make participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. ## Our Standards Examples of behavior that contributes to creating a positive environment include: * Using welcoming and inclusive language * Being respectful of differing viewpoints and experiences * Gracefully accepting constructive criticism * Focusing on what is best for the community * Showing empathy towards other community members Examples of unacceptable behavior by participants include: * The use of sexualized language or imagery and unwelcome sexual attention or advances * Trolling, insulting/derogatory comments, and personal or political attacks * Public or private harassment * Publishing others' private information, such as a physical or electronic address, without explicit permission * Other conduct which could reasonably be considered inappropriate in a professional setting ## Our Responsibilities Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. 
Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. ## Scope This Code of Conduct applies within all project spaces, and it also applies when an individual is representing the project or its community in public spaces. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. This Code of Conduct also applies outside the project spaces when there is a reasonable belief that an individual's behavior may have a negative impact on the project or its community. ## Enforcement Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at . All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership. 
## Attribution This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html [homepage]: https://www.contributor-covenant.org For answers to common questions about this code of conduct, see https://www.contributor-covenant.org/faq ================================================ FILE: .github/CONTRIBUTING.md ================================================ # Contributing to "efm3d" We want to make contributing to this project as easy and transparent as possible. ## Pull Requests We welcome pull requests. 1. Fork the repo and create your branch from `main`. 2. If you've added code that should be tested, add tests. 3. If you've changed APIs, update the documentation in the code. 4. Ensure the test suite passes. 5. If you haven't already, complete the Contributor License Agreement ("CLA"). ## Contributor License Agreement ("CLA") In order to accept your pull request, we need you to submit a CLA. You only need to do this once to work on any of Facebook's open source projects. Complete your CLA here: <https://code.facebook.com/cla> ## Issues We use GitHub issues to track public bugs. Please ensure your description is clear and has sufficient instructions to be able to reproduce the issue. Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe disclosure of security bugs. In those cases, please go through the process outlined on that page and do not file a public issue. ## License By contributing to "efm3d", you agree that your contributions will be licensed under the [LICENSE](../LICENSE) file in the root directory of this source tree. 
================================================ FILE: .github/workflows/conda-env.yaml ================================================ name: Conda Environment CI on: push: branches: - main pull_request: branches: - main jobs: test: name: Test conda env runs-on: "ubuntu-latest" defaults: run: shell: bash -el {0} steps: - uses: actions/checkout@v4 - uses: conda-incubator/setup-miniconda@v3 with: activate-environment: efm3d environment-file: environment-mac.yml python-version: 3.9 auto-activate-base: false - run: | conda info conda list conda activate efm3d pip install -r requirements.txt ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging dist/ build/ eggs/ .eggs/ *.egg-info/ lib/ lib64/ # PyTorch specific *.pt *.pth *.ckpt *.tfevents .ipynb_checkpoints/ # Environment .env venv/ ENV/ # IDEs .vscode/ .idea/ # Miscellaneous .DS_Store Thumbs.db # artifacts *.mp4 # data *.ply data/ tb_logs/ # model weights ckpt/ # output dir *.out output/ # tensorboard output runs/ ================================================ FILE: INSTALL.md ================================================ # Installation We provide two ways to install the dependencies of EFM3D. We recommend using miniconda to manage the dependencies, which also provides an easy setup for all the additional dependencies listed in `requirements.txt` and `requirements-extra.txt`. ## Install using conda (recommended) First install [miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install), then run the following commands under the `${EFM3D_DIR}` root directory ``` conda env create --file=environment.yml conda activate efm3d cd efm3d/thirdparty/mmdetection3d/cuda/ python setup.py install ``` The commands will first create a conda environment named `efm3d`, and then build the third-party CUDA kernel required for training. 
## Install via pip Make sure you have Python>=3.9, then install the dependencies using `pip`. The packages in `requirements.txt` are needed for the basic functionalities of EFM3D, such as running the example model inference to see 3D object detection and surface reconstruction on a [vrs](https://facebookresearch.github.io/vrs/) sequence. ``` pip install -r requirements.txt ``` Additional dependencies in `requirements-extra.txt` are needed for training and eval. ``` pip install -r requirements-extra.txt ``` **Important**: For training, we also need to build a CUDA kernel from [mmdetection3d](https://github.com/open-mmlab/mmdetection3d). Compile the CUDA kernel of the IoU3d loss by running the following commands, which require the installation of the [CUDA dev toolkit](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/). ``` cd efm3d/thirdparty/mmdetection3d/cuda/ python setup.py install ``` ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ # EFM3D: A Benchmark for Measuring Progress Towards 3D Egocentric Foundation Models [[paper](https://arxiv.org/abs/2406.10224)] [[website](https://www.projectaria.com/research/efm3D/)] ## Intro This is the official release for the paper EFM3D: A Benchmark for Measuring Progress Towards 3D Egocentric Foundation Models (https://arxiv.org/abs/2406.10224). To measure progress on what we term Egocentric Foundation Models (EFMs) we establish EFM3D, a benchmark with two core 3D egocentric perception tasks. EFM3D is the first benchmark for 3D object detection and surface regression on high quality annotated egocentric data of [Project Aria](https://www.projectaria.com/). We also propose Egocentric Voxel Lifting (EVL), a baseline for 3D EFMs. We provide the following code and assets - The pretrained EVL model weights for surface reconstruction and 3D object detection on Aria sequences - The datasets included in the EFM3D benchmark, including the training and evaluation data for Aria Synthetic Datasets (ASE), Aria Everyday Objects (AEO) for 3D object detection, and the eval mesh models for surface reconstruction evaluation. - Distributed training code to train EVL. - Native integration with [Aria Training and Evaluation Kit (ATEK)](https://github.com/facebookresearch/atek). 
The following serves as a minimal example to run the model inference, including installation guide, data downloading instructions and how to run the inference code. ## Installation **Option 1**: First navigate to the root folder. The core library is written in PyTorch, with additional dependencies listed in `requirements.txt`. This needs Python>=3.9. ``` pip install -r requirements.txt ``` **Option 2**: You can choose to use conda to manage the dependencies. We recommend using [miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install) for its fast dependency solver. The runtime dependencies can be installed by running (replace `environment.yml` with `environment-mac.yml` if running on macOS) ``` conda env create --file=environment.yml conda activate efm3d ``` This should be sufficient to start using EVL model inference with the pretrained model weights. Please refer to [INSTALL.md](INSTALL.md) for a full installation, which is required for training and eval. ## Inference ### Pretrained models Download the pretrained model weights and sample data on the [EFM3D](https://www.projectaria.com/research/efm3D/#download-dataset) page (email required). We provide two model checkpoints, one for server-side GPU (>20GB GPU memory) and one for desktop GPU. There is a sample sequence attached to the model weights to facilitate using the model. Check out the [README.md](ckpt/README.md) for detailed instructions on how to download the model weights. ### Run on the sample data After downloading the model weights `evl_model_ckpt.zip`, put it under `${EFM3D_DIR}/ckpt/`, then run the command under `${EFM3D_DIR}` ``` sh prepare_inference.sh ``` This will unzip the file and make sure the model weights and sample data are put under the right paths. To run inference on the sample sequence ``` python infer.py --input ./data/seq136_sample/video.vrs ``` **Note**: the pretrained model requires ~20GB GPU memory. 
Use the following command to run the model on a desktop GPU with ~10GB memory (tested on RTX-3080). The performance is slightly degraded. ``` python infer.py --input ./data/seq136_sample/video.vrs --model_ckpt ./ckpt/model_lite.pth --model_cfg ./efm3d/config/evl_inf_desktop.yaml --voxel_res 0.08 ``` ### Run on macOS The inference demo works on macOS too. Use the following command (tested on Apple M1 MAX 64GB memory) ``` PYTORCH_ENABLE_MPS_FALLBACK=1 python infer.py --input ./data/seq136_sample/video.vrs --model_ckpt ./ckpt/model_lite.pth --model_cfg ./efm3d/config/evl_inf_desktop.yaml --voxel_res 0.08 ``` This wraps up the basic usage of the EVL model. To train the model from scratch and use the EFM3D benchmark, complete a full installation following [INSTALL.md](INSTALL.md), then read below. ### Inference with ATEK The inference also supports taking [ATEK-format](https://github.com/facebookresearch/atek) WDS sequences. First download a test ASE sequence following the `ASE eval data` section in [README.md](data/README.md), then run ``` python infer.py --input ./data/ase_eval/81022 ``` ## Datasets See [README.md](data/README.md) for instructions to work with all datasets included in the EFM3D benchmark. There are three datasets in the EFM3D benchmark - [Aria Synthetic Environments (ASE)](https://www.projectaria.com/datasets/ase/): for training and eval on 3D object detection and surface reconstruction - [Aria Digital Twin (ADT)](https://www.projectaria.com/datasets/adt/): for eval on surface reconstruction - [Aria Everyday Objects (AEO)](https://www.projectaria.com/datasets/aeo/): for eval on 3D object detection. ## Train EVL First make sure you have a full installation (see [INSTALL.md](INSTALL.md)). Training the EVL model from scratch requires downloading the full ASE training data. You can download a small subset of ASE sequences (>10 sequences) to test the training script. Check out the `ASE training data` section in [data/README.md](data/README.md). 
After following the instructions to prepare the data, run the following command. - train the EVL model from scratch on a single GPU ``` python train.py ``` - train with 8 GPUs ``` torchrun --standalone --nproc_per_node=8 train.py ``` We also provide a script to train on multi-node multi-gpu environment via [slurm](https://slurm.schedmd.com/documentation.html). The pretrained model is trained on 2 nodes with 8xH100. - train with multi-node multi-gpu using slurm ``` sbatch sbatch_run.sh ``` By default the tensorboard log is saved to `${EFM3D_DIR}/tb_logs`. ## EFM3D benchmark Please see [benchmark.md](benchmark.md) for details. ## Citing EFM3D If you find EFM3D useful, please consider citing ``` @article{straub2024efm3d, title={EFM3D: A Benchmark for Measuring Progress Towards 3D Egocentric Foundation Models}, author={Straub, Julian and DeTone, Daniel and Shen, Tianwei and Yang, Nan and Sweeney, Chris and Newcombe, Richard}, journal={arXiv preprint arXiv:2406.10224}, year={2024} } ``` If you use Aria Digital Twin (ADT) dataset in the EFM3D benchmark, please consider citing ``` @inproceedings{pan2023aria, title={Aria digital twin: A new benchmark dataset for egocentric 3d machine perception}, author={Pan, Xiaqing and Charron, Nicholas and Yang, Yongqian and Peters, Scott and Whelan, Thomas and Kong, Chen and Parkhi, Omkar and Newcombe, Richard and Ren, Yuheng Carl}, booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, pages={20133--20143}, year={2023} } ``` If you use the Aria Synthetic Environments (ASE) dataset in the EFM3D benchmark, please consider citing ``` @article{avetisyan2024scenescript, title={SceneScript: Reconstructing Scenes With An Autoregressive Structured Language Model}, author={Avetisyan, Armen and Xie, Christopher and Howard-Jenkins, Henry and Yang, Tsun-Yi and Aroudj, Samir and Patra, Suvam and Zhang, Fuyang and Frost, Duncan and Holland, Luke and Orme, Campbell and others}, journal={arXiv preprint arXiv:2403.13064}, 
year={2024} } ``` ## How to Contribute We welcome contributions! Go to [CONTRIBUTING](./.github/CONTRIBUTING.md) and our [CODE OF CONDUCT](./.github/CODE_OF_CONDUCT.md) for how to get started. ## License EFM3D is released by Meta under the [Apache 2.0 license](LICENSE). ================================================ FILE: benchmark.md ================================================ ## EFM3D Benchmark We provide three evaluation datasets for the EFM3D benchmarks. For more details on the benchmark see the [EFM3D](https://arxiv.org/abs/2406.10224) paper. ### ASE - 3D object detection and mesh reconstruction Aria Synthetic Environments (ASE) is a synthetic dataset, created from procedurally-generated interior layouts filled with 3D objects, simulated with the sensor characteristics of Aria glasses. We use ASE for both surface reconstruction and 3D object detection evaluation. First follow instructions in the dataset [README.md](data/README.md) to download the ASE eval set and eval meshes, then run the following ``` python eval.py --ase ``` Running the full evaluation on 100 eval sequences takes a long time (>10 hrs on a single GPU). To see the eval results on a reduced set, use `--num_seqs` and `--num_snips` to specify the number of sequences and the number of snippets per sequence to speed up evaluation, for example ``` # run eval on the first 10 sequences of ASE, each running for 100 snippets (10s) python eval.py --ase --num_seqs 10 --num_snips 100 ``` ### ADT - mesh reconstruction ADT is the benchmark data for surface reconstruction, containing 6 sequences. Download the ADT data and mesh files following the data instruction. Then run ``` python eval.py --adt ``` The provided script is an end-to-end solution: it runs the EVL model with the default checkpoint, finds the right GT mesh path for the ASE and ADT datasets, and then computes the mesh-to-mesh distance evaluation metrics. 
If you have your own model that generates a ply file, check [eval_mesh_to_mesh](efm3d/utils/mesh_utils.py) for how to evaluate against surface GT directly. ### AEO - 3D object detection AEO is the benchmark data for 3D object detection, with 25 sequences. Download the AEO dataset following the data instruction. Then run ``` python eval.py --aeo ``` This will run the EVL model inference using the default model checkpoint path. If you have your own model for inference, check [eval.py](efm3d/inference/eval.py) for how to evaluate against 3D object GT directly. ================================================ FILE: efm3d/__init__.py ================================================ ================================================ FILE: efm3d/aria/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from .camera import CameraTW, DEFAULT_CAM_DATA_SIZE from .obb import ObbTW, transform_obbs from .pose import PoseTW from .tensor_wrapper import smart_cat, smart_stack, TensorWrapper ================================================ FILE: efm3d/aria/aria_constants.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # High level organization of the constants: # - */time_ns is timestamp with respect to the aria clock in nanoseconds stored as torch.long() # - */snippet_time_s is the timestamp with respect to the start of the snippet in seconds stored as torch.float32() # - */t_A_B is a pose transformation from coordinate system B to A # - path-like key strings designate hierarchical relationships of data. I.e. # rgb/img/... is all data relating to the rgb image information. rgb/calib/... # is all about the calibration data. And all rgb/... is data relating to the # rgb video stream. # --------------------------------------------------------------------- # sequence level information # --------------------------------------------------------------------- ARIA_SEQ_ID = "sequence/id" # start of the sequence in ns relative to global Aria timestamp ARIA_SEQ_TIME_NS = "sequence/time_ns" # --------------------------------------------------------------------- # snippet level information # --------------------------------------------------------------------- ARIA_SNIPPET_ID = "snippet/id_in_sequence" ARIA_SNIPPET_LENGTH_S = "snippet/length_s" # start of sequence in ns relative to global Aria timestamp (sometimes unix 0) ARIA_SNIPPET_TIME_NS = "snippet/time_ns" # offset of snippet coordinate system to sequence coordinate system ARIA_SNIPPET_T_WORLD_SNIPPET = "snippet/t_world_snippet" # Ratio of where in the snippet is the origin of cosy relative to the # snippet length. E.g. 
0.5 for a 10 sec snippet would mean that 5 sec is origin, # was previously known as "frame_selection" in LocalCosyPreprocessor. ARIA_SNIPPET_ORIGIN_RATIO = "snippet/origin_ratio" # --------------------------------------------------------------------- # streamer playback time information # --------------------------------------------------------------------- ARIA_PLAY_TIME_NS = "play/time_ns" ARIA_PLAY_SEQUENCE_TIME_S = "play/sequence_time_s" ARIA_PLAY_SNIPPET_TIME_S = "play/snippet_time_s" ARIA_PLAY_FREQUENCY_HZ = "play/hz" # --------------------------------------------------------------------- # aria video stream information # --------------------------------------------------------------------- # frame id in the sequence ARIA_FRAME_ID = [ "rgb/frame_id_in_sequence", "slaml/frame_id_in_sequence", "slamr/frame_id_in_sequence", ] # timestamp within snippet ARIA_IMG_SNIPPET_TIME_S = [ "rgb/img/snippet_time_s", "slaml/img/snippet_time_s", "slamr/img/snippet_time_s", ] # timestamp within sequence ARIA_IMG_TIME_NS = [ "rgb/img/time_ns", "slaml/img/time_ns", "slamr/img/time_ns", ] # poses of the rig at the time of the respective frame capture # T x 12 ARIA_IMG_T_SNIPPET_RIG = [ "rgb/t_snippet_rig", "slaml/t_snippet_rig", "slamr/t_snippet_rig", ] # image tensors ARIA_IMG = ["rgb/img", "slaml/img", "slamr/img"] ARIA_IMG_FREQUENCY_HZ = [ "rgb/img/hz", "slaml/img/hz", "slamr/img/hz", ] # --------------------------------------------------------------------- # calibration information # --------------------------------------------------------------------- ARIA_CALIB = [ "rgb/calib", "slaml/calib", "slamr/calib", ] # timestamp within the snippet ARIA_CALIB_SNIPPET_TIME_S = [ "rgb/calib/snippet_time_s", "slaml/calib/snippet_time_s", "slamr/calib/snippet_time_s", ] # timestamp within the sequence ARIA_CALIB_TIME_NS = [ "rgb/calib/time_ns", "slaml/calib/time_ns", "slamr/calib/time_ns", ] # --------------------------------------------------------------------- # pose information # 
--------------------------------------------------------------------- # pose timestamp within snippet ARIA_POSE_SNIPPET_TIME_S = "pose/snippet_time_s" # pose timestamp within sequence ARIA_POSE_TIME_NS = "pose/time_ns" # transformation from rig to snippet coordinate system ARIA_POSE_T_SNIPPET_RIG = "pose/t_snippet_rig" # transformation from rig to world coordinate system ARIA_POSE_T_WORLD_RIG = "pose/t_world_rig" # frequency of poses ARIA_POSE_FREQUENCY_HZ = "pose/hz" # --------------------------------------------------------------------- # semidense points information # --------------------------------------------------------------------- ARIA_POINTS_WORLD = "points/p3s_world" ARIA_POINTS_TIME_NS = "points/time_ns" ARIA_POINTS_SNIPPET_TIME_S = "points/snippet_time_s" ARIA_POINTS_FREQUENCY_HZ = "points/hz" ARIA_POINTS_INV_DIST_STD = "points/inv_dist_std" ARIA_POINTS_DIST_STD = "points/dist_std" # --------------------------------------------------------------------- # imu information # --------------------------------------------------------------------- ARIA_IMU = ["imur", "imul"] ARIA_IMU_CHANNELS = [ ["imur/lin_acc_ms2", "imur/rot_vel_rads"], ["imul/lin_acc_ms2", "imul/rot_vel_rads"], ] ARIA_IMU_SNIPPET_TIME_S = ["imur/snippet_time_s", "imul/snippet_time_s"] ARIA_IMU_TIME_NS = ["imur/time_ns", "imul/time_ns"] ARIA_IMU_FACTORY_CALIB = ["imur/factory_calib", "imul/factory_calib"] ARIA_IMU_FREQUENCY_HZ = ["imur/hz", "imul/hz"] # --------------------------------------------------------------------- # audio data # --------------------------------------------------------------------- ARIA_AUDIO = "audio" # snippet time within snippet of audio sample ARIA_AUDIO_SNIPPET_TIME_S = "audio/snippet_time_s" # timestamp of audio sample in sequence ARIA_AUDIO_TIME_NS = "audio/time_ns" # frequency of audio sample ARIA_AUDIO_FREQUENCY_HZ = "audio/hz" # --------------------------------------------------------------------- # OBB # 
--------------------------------------------------------------------- # padded ObbTW tensor for oriented object bounding boxes given in *snippet coordinate system* ARIA_OBB_PADDED = "obbs/padded_snippet" # mapping of semantic id of the obb to a string name ARIA_OBB_SEM_ID_TO_NAME = "obbs/sem_id_to_name" # timestamp within the snippet ARIA_OBB_SNIPPET_TIME_S = "obbs/snippet_time_s" # timestamp within the sequence ARIA_OBB_TIME_NS = "obbs/time_ns" # frequency of object detection information ARIA_OBB_FREQUENCY_HZ = "obbs/hz" # predicted ObbTW tensor for oriented object bounding boxes ARIA_OBB_PRED = "obbs/pred" # raw predictions from the networks. ARIA_OBB_PRED_VIZ = "obbs/pred_viz" # predictions for visualization (e.g. raw predictions filtered by some criteria.) ARIA_OBB_PRED_SEM_ID_TO_NAME = "obbs/pred/sem_id_to_name" ARIA_OBB_PRED_PROBS_FULL = "obbs/pred/probs_full" ARIA_OBB_PRED_PROBS_FULL_VIZ = "obbs/pred/probs_ful_viz" # tracked ObbTW tensor for oriented object bounding boxes ARIA_OBB_TRACKED = "obbs/tracked" ARIA_OBB_TRACKED_PROBS_FULL = "obbs/tracked/probs_full" # tracked but not instantiated ObbTW tensor for oriented object bounding boxes ARIA_OBB_UNINST = "obbs/uninst" ARIA_OBB_BB2 = ["bb2s_rgb", "bb2s_slaml", "bb2s_slamr"] ARIA_OBB_BB3 = "bb3s_object" # --------------------------------------------------------------------- # depth information # --------------------------------------------------------------------- # for depth images (z-depth) in meters ARIA_DEPTH_M = ["rgb/depth_m", "slaml/depth_m", "slamr/depth_m"] # for distance images (distance along ray) in meters ARIA_DISTANCE_M = ["rgb/distance_m", "slaml/distance_m", "slamr/distance_m"] ARIA_DEPTH_TIME_NS = [ "rgb/depth/time_ns", "slaml/depth/time_ns", "slamr/depth/time_ns", ] ARIA_DEPTH_SNIPPET_TIME_S = [ "rgb/depth/snippet_time_s", "slaml/depth/snippet_time_s", "slamr/depth/snippet_time_s", ] ARIA_DEPTH_M_PRED = ["rgb/pred/depth_m", "slaml/pred/depth_m", "slamr/pred/depth_m"] # for distance
images (distance along ray) in meters ARIA_DISTANCE_M_PRED = [ "rgb/pred/distance_m", "slaml/pred/distance_m", "slamr/pred/distance_m", ] # --------------------------------------------------------------------- # SDF information # --------------------------------------------------------------------- ARIA_SDF = "snippet/sdf/sdf" ARIA_SDF_EXT = "snippet/sdf/extent" ARIA_SDF_COSY_TIME_NS = "snippet/sdf/cosy_time_ns" ARIA_SDF_MASK = "snippet/sdf/mask" ARIA_SDF_T_WORLD_VOXEL = "snippet/sdf/T_world_voxel" # --------------------------------------------------------------------- # GT Mesh information # --------------------------------------------------------------------- ARIA_MESH_VERTS_W = "snippet/mesh/verts_w" ARIA_MESH_FACES = "snippet/mesh/faces" ARIA_MESH_VERT_NORMS_W = "snippet/mesh/v_norms_w" ARIA_SCENE_MESH_VERTS_W = "scene/mesh/verts_w" ARIA_SCENE_MESH_FACES = "scene/mesh/faces" ARIA_SCENE_MESH_VERT_NORMS_W = "scene/mesh/v_norms_w" # --------------------------------------------------------------------- # Scene volume information (can be acquired from mesh or semidense points) # -------------------------------------------------------------------- ARIA_MESH_VOL_MIN = "scene/mesh/vol_min" ARIA_MESH_VOL_MAX = "scene/mesh/vol_max" ARIA_POINTS_VOL_MIN = "scene/points/vol_min" ARIA_POINTS_VOL_MAX = "scene/points/vol_max" # --------------------------------------------------------------------- # additional image constants # --------------------------------------------------------------------- # Fixed mapping of resolutions, tuple has three numbers: (RGB_HW, SLAM_W, SLAM_H) RESOLUTION_MAP = { 0: (1408, 640, 480), 1: (704, 640, 480), 2: (352, 320, 240), # 3: there is none 4: (176, 160, 112), # there is some cropping in SLAM image height 5: (480, 640, 480), 6: (336, 448, 336), # match typical internet image FOV (assume 70 deg) 7: (240, 320, 240), # match typical internet pixels e.g. 
ImageNet 8: (192, 256, 192), 9: (144, 192, 144), # divisible by 14 for ViTs that use patch size 14 10: ( 1400, 560, 420, ), # similar to 0 560x420 instead of 616x462 so that we can also get half the resolution for equivalent to 7 11: (700, 560, 420), # similar to 1 12: (420, 560, 420), # similar to 5 13: (210, 280, 210), # similar to 7 } # Fixed mapping of corresponding wh_multiple_of, for each resolution WH_MULTIPLE_OF_MAP = { 0: 16, 1: 16, 2: 16, # 3: there is none 4: 16, 5: 16, 6: 16, 7: 16, 8: 16, 9: 16, 10: 14, 11: 14, 12: 14, 13: 14, } # Helper constants for managing valid radius of the fisheye images, valid radius # defines a circle from the center of projection where project/unproject is valid RGB_RADIUS_FACTOR = 760.0 / 1408.0 SLAM_RADIUS_FACTOR = 320.0 / 640.0 ARIA_RGB_WIDTH_TO_RADIUS = { RESOLUTION_MAP[key][0]: RESOLUTION_MAP[key][0] * RGB_RADIUS_FACTOR for key in RESOLUTION_MAP } ARIA_SLAM_WIDTH_TO_RADIUS = { RESOLUTION_MAP[key][1]: RESOLUTION_MAP[key][1] * SLAM_RADIUS_FACTOR for key in RESOLUTION_MAP } ARIA_RGB_SCALE_TO_WH = { key: [RESOLUTION_MAP[key][0], RESOLUTION_MAP[key][0]] for key in RESOLUTION_MAP } ARIA_SLAM_SCALE_TO_WH = { key: [RESOLUTION_MAP[key][1], RESOLUTION_MAP[key][2]] for key in RESOLUTION_MAP } ARIA_IMG_MIN_LUX = 30.0 ARIA_IMG_MAX_LUX = 150000.0 ARIA_IMG_MAX_PERC_OVEREXPOSED = 0.02 ARIA_IMG_MAX_PERC_UNDEREXPOSED = 0.0001 # --------------------------------------------------------------------- # EFM Constants # --------------------------------------------------------------------- ARIA_EFM_OUTPUT = "efm/output" ARIA_CAM_INFO = { "name": ["rgb", "slaml", "slamr"], "stream_id": [0, 1, 2], "name_to_stream_id": { "rgb": 0, "slaml": 1, "slamr": 2, }, "width_height": { "rgb": (1408, 1408), "slaml": (640, 480), "slamr": (640, 480), }, # vrs id "id": ["214-1", "1201-1", "1201-2"], "id_to_name": { "214-1": "rgb", "1201-1": "slaml", "1201-2": "slamr", }, # display names "display": [ "RGB", "SLAM Left", "SLAM Right", ], # Physical position on 
class DefaultCameraTWData(TensorWrapper):
    """Placeholder default for `CameraTW` data that allows multiple input sizes.

    The `shape` property reports every accepted trailing size instead of a
    single `torch.Size`, so one default object can stand in for each of the
    camera layouts (34, 26 or 22 entries).
    """

    def __init__(self):
        # The placeholder is filled with -1 and sized to the largest accepted
        # layout (34 entries). It was previously 33, which is none of the
        # sizes reported by `shape` and is rejected by the size assertion in
        # `CameraTW.__init__` ({22, 26, 34}).
        # NOTE(review): how `autoinit` consumes `_data` is not visible here;
        # confirm nothing depended on the old 33-entry placeholder.
        self._data = -1 * torch.ones(34)

    @property
    def shape(self):
        """Return the tuple of accepted trailing shapes (not a single size)."""
        return (torch.Size([34]), torch.Size([26]), torch.Size([22]))
def is_pinhole(inp):
    """Return True if `inp` is one of the known pinhole camera model names.

    A name matches either in its canonical spelling or in its all-lowercase
    form; other casings (e.g. all-uppercase) are not recognized.
    """
    canonical = ("Pinhole", "Linear", "CameraModelType.LINEAR")
    accepted = canonical + tuple(alias.lower() for alias in canonical)
    return inp in accepted
@classmethod
@autoinit
def from_surreal(
    cls,
    width: torch.Tensor = -1 * torch.ones(1),
    height: torch.Tensor = -1 * torch.ones(1),
    type_str: str = "Fisheye624",
    params: Union[torch.Tensor, DefaultCameraTWParam] = DEFAULT_CAM_PARAM,
    gain: torch.Tensor = 1 * torch.ones(1),
    exposure_s: torch.Tensor = 1e-3 * torch.ones(1),
    valid_radius: torch.Tensor = 99999.0 * torch.ones(1),
    T_camera_rig: Union[torch.Tensor, PoseTW] = IdentityPose,  # 1x12.
):
    """Build a camera from a model type string and a flat parameter vector.

    The trailing size of `params` together with `type_str` selects how the
    vector is split into focal lengths, principal point and distortion:

    - fisheye624, 16 params: fx, fy, cx, cy, then 12 distortion params
    - fisheye624, 15 params: f, cx, cy, then 12 distortion params (fx = fy = f)
    - kb3, 8 params: fx, fy, cx, cy, then 4 distortion params
    - pinhole, 4 params: fx, fy, cx, cy and no distortion params

    If any `valid_radius` exceeds the image width or height (which is the
    case for the 99999 default), it is replaced: for non-pinhole models by a
    value looked up from the height/width aspect ratio, for pinhole models
    by half the image diagonal plus one pixel.

    Raises:
        NotImplementedError: if `type_str` and the number of params match
            none of the combinations above.
        ValueError: if a non-pinhole valid radius must be auto-determined
            but the aspect ratio matches none of the known ratios.

    NOTE(review): argument defaulting/expansion is handled by `@autoinit`,
    whose behavior is not visible in this block.
    """
    # Try to auto-determine the camera model.
    if (
        is_fisheye624(type_str) and params.shape[-1] == 16
    ):  # Fisheye624 double focals
        fx = params[..., 0].unsqueeze(-1)
        fy = params[..., 1].unsqueeze(-1)
        cx = params[..., 2].unsqueeze(-1)
        cy = params[..., 3].unsqueeze(-1)
        dist_params = params[..., 4:]
    elif (
        is_fisheye624(type_str) and params.shape[-1] == 15
    ):  # Fisheye624 single focal
        f = params[..., 0].unsqueeze(-1)
        cx = params[..., 1].unsqueeze(-1)
        cy = params[..., 2].unsqueeze(-1)
        dist_params = params[..., 3:]
        # A single focal length is shared by both axes.
        fx = fy = f
    elif is_kb3(type_str) and params.shape[-1] == 8:  # KB3.
        fx = params[..., 0].unsqueeze(-1)
        fy = params[..., 1].unsqueeze(-1)
        cx = params[..., 2].unsqueeze(-1)
        cy = params[..., 3].unsqueeze(-1)
        dist_params = params[..., 4:]
    elif is_pinhole(type_str) and params.shape[-1] == 4:  # Pinhole.
        fx = params[..., 0].unsqueeze(-1)
        fy = params[..., 1].unsqueeze(-1)
        cx = params[..., 2].unsqueeze(-1)
        cy = params[..., 3].unsqueeze(-1)
        # With exactly 4 params this slice is empty: no distortion terms.
        dist_params = params[..., 4:]
    else:
        raise NotImplementedError(
            "Unknown number of params entered for camera model"
        )

    # Only override the radius when it is larger than an image dimension.
    if torch.any(torch.logical_or(valid_radius > height, valid_radius > width)):
        if not is_pinhole(type_str):
            # Try to auto-determine the valid radius for fisheye cameras.
            default_radius = 99999.0
            hw_ratio = height / width
            # Reference aspect ratios, cast to match hw_ratio for comparison.
            eyevideo_camera_hw_ratio = torch.tensor(240.0 / 640.0).to(hw_ratio)
            slam_camera_hw_ratio = torch.tensor(480.0 / 640.0).to(hw_ratio)
            rgb_camera_hw_ratio = torch.tensor(2880.0 / 2880.0).to(hw_ratio)
            guess_rgb = hw_ratio == rgb_camera_hw_ratio
            guess_slam = hw_ratio == slam_camera_hw_ratio
            guess_eyevideo = hw_ratio == eyevideo_camera_hw_ratio
            # Start from the sentinel; entries still equal to it afterwards
            # matched none of the known ratios.
            valid_radius = default_radius * torch.ones_like(hw_ratio)
            valid_radius = torch.where(
                guess_rgb, 1415 * (height / 2880), valid_radius
            )
            valid_radius = torch.where(
                guess_slam, 330 * (height / 480), valid_radius
            )
            # This is for Eye Video Camera
            valid_radius = torch.where(
                guess_eyevideo, 330 * (height / 480), valid_radius
            )
            if torch.any(valid_radius == default_radius):
                raise ValueError(
                    f"Failed to auto-determine valid radius based on aspect ratios (valid_radius {valid_radius}, width {width}, height {height})"
                )
        else:
            # Note that the valid_radius for pinhole camera is not well-defined.
            # We heuristically set the valid radius to be the half of the image diagonal.
            # Add one pixel to be sure that all pixels in the image are valid.
            valid_radius = (
                torch.sqrt((width / 2.0) ** 2 + (height / 2.0) ** 2) + 1.0
            )

    # The same radius is used for both the x and y valid-radius entries.
    return cls.from_parameters(
        width=width,
        height=height,
        fx=fx,
        fy=fy,
        cx=cx,
        cy=cy,
        gain=gain,
        exposure_s=exposure_s,
        valid_radiusx=valid_radius,
        valid_radiusy=valid_radius,
        T_camera_rig=T_camera_rig,
        dist_params=dist_params,
    )
@property
def valid_radius(self) -> torch.Tensor:
    """Valid projection radii (x, y) around the principal point, with shape (..., 2).

    The two entries are the per-axis radii stored at `VALID_RADIUS_IND`.
    """
    return self._data[..., self.VALID_RADIUS_IND]
def crop(self, left_top: Tuple[float], size: Tuple[int]):
    """Return a new camera describing the image after a crop.

    Args:
        left_top: (x, y) of the crop's top-left corner in the original
            image; it is subtracted from the principal point.
        size: (width, height) of the cropped image; it replaces the stored
            image size.

    Focal lengths, gain, exposure, valid radius, extrinsics and distortion
    are carried over unchanged.
    """
    offset = self._data.new_tensor(left_top)
    new_size = self._data.new_tensor(size)

    # For a batch of cameras, tile the 2-vectors over all leading dims so
    # they can be concatenated with the per-camera fields.
    if self._data.ndim > 1:
        tiling = [*self._data.shape[:-1], 1]
        new_size = new_size.repeat(tiling)
        offset = offset.repeat(tiling)

    fields = [
        new_size,
        self.f,
        self.c - offset,
        self.gain,
        self.exposure_s,
        self.valid_radius,
        self.T_camera_rig._data,
        self.dist,
    ]
    return self.__class__(torch.cat(fields, dim=-1))
@autocast
def project(self, p3d: torch.Tensor) -> Tuple[torch.Tensor]:
    """Project 3D points to 2D pixel coordinates.

    Returns a pair `(p2d, valid)`, where `valid` is True for points that
    land inside the image, inside the valid radius, and have a positive
    last (depth) coordinate.

    Note that the camera data is converted in place to the dtype promoted
    from the camera data and `p3d`.

    Raises:
        ValueError: if the camera is neither fisheye624 nor pinhole.
    """
    # Bring camera data and points to a common dtype before projecting.
    common_dtype = torch.promote_types(self._data.dtype, p3d.dtype)
    self._data = self._data.to(common_dtype)
    p3d = p3d.to(common_dtype)

    # Pick the projection routine from the camera model.
    if self.is_fisheye624:
        project_fn = fisheye624_project
    elif self.is_linear:
        project_fn = pinhole_project
    else:
        raise ValueError(
            "only fisheye624 and pinhole implemented, kb3 not yet implemented"
        )

    # fx, fy, cx, cy followed by the distortion parameters.
    intrinsics = self.params
    if intrinsics.ndim == 1:
        # A single camera: repeat its parameters once per leading entry of p3d.
        intrinsics = intrinsics.unsqueeze(0).repeat(p3d.shape[0], 1)
    p2d = project_fn(p3d, intrinsics)

    inside_image = self.in_image(p2d)
    inside_radius = self.in_radius(p2d)
    positive_depth = p3d[..., -1] > 0
    return p2d, inside_image & inside_radius & positive_depth
params = self.params if params.ndim == 1: B = p2d.shape[0] params = params.unsqueeze(0).repeat(B, 1) rays = pinhole_unproject(p2d, params) else: raise ValueError( "only fisheye624 and pinhole implemented, kb3 not yet implemented" ) in_image = self.in_image(p2d) in_radius = self.in_radius(p2d) valid = in_image & in_radius return rays, valid def rotate_90_cw(self): return self.rotate_90(clock_wise=True) def rotate_90_ccw(self): return self.rotate_90(clock_wise=False) def rotate_90(self, clock_wise: bool): dist_params = self.dist.clone() if self.is_fisheye624: # swap thin prism and tangential distortion parameters # {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3} to # {k_0 ... k_5} {p_1 p_0} {s_2 s_3 s_0 s_1} dist_p = self.dist[..., 6:8] dist_s = self.dist[..., 8:12] dist_params[..., 6] = dist_p[..., 1] dist_params[..., 7] = dist_p[..., 0] dist_params[..., 8:10] = dist_s[..., 2:] dist_params[..., 10:12] = dist_s[..., :2] elif self.is_linear: # no need to rotate distortion parameters since there are none pass elif self.is_kb3: raise NotImplementedError("kb3 model rotation not implemented yet") else: raise NotImplementedError(f"camera model not recognized {self}") # clock-wise or counter clock-wise DIR = 1 if clock_wise else -1 # rotate camera extrinsics by 90 degree CW T_rot_z = PoseTW.from_matrix3x4(get_T_rot_z(DIR * np.pi * 0.5)).to(self.device) if clock_wise: # rotate x, y of principal point # x_rotated = height - 1 - y_before # y_rotated = x_before rot_cx = self.size[..., 1] - self.c[..., 1] - 1 rot_cy = self.c[..., 0].clone() else: rot_cx = self.c[..., 1].clone() rot_cy = self.size[..., 0] - self.c[..., 0] - 1 return CameraTW.from_parameters( # swap width and height self.size[..., 1].clone().unsqueeze(-1), self.size[..., 0].clone().unsqueeze(-1), # swap x, y of focal lengths self.f[..., 1].clone().unsqueeze(-1), self.f[..., 0].clone().unsqueeze(-1), rot_cx.unsqueeze(-1), rot_cy.unsqueeze(-1), self.gain.clone(), self.exposure_s.clone(), # swap valid radius x, y 
def grid_2d(
    width: int,
    height: int,
    output_range=(-1.0, 1.0, -1.0, 1.0),
    device="cpu",
    dtype=torch.float32,
):
    """Create a (height, width, 2) grid of (x, y) sample coordinates.

    Args:
        width: number of samples along x.
        height: number of samples along y.
        output_range: (x_start, x_end, y_start, y_end); the start values are
            included and the end values are excluded.
        device: device of the returned tensor.
        dtype: dtype of the returned tensor.

    Returns:
        Tensor whose last dimension holds (x, y) for each grid cell.
    """

    def _axis(start, stop, count):
        # Sample count + 1 evenly spaced points and drop the last one so
        # that `stop` itself is not part of the axis.
        return torch.linspace(start, stop, count + 1, device=device, dtype=dtype)[:-1]

    xs = _axis(output_range[0], output_range[1], width)
    ys = _axis(output_range[2], output_range[3], height)
    # "xy" indexing yields coordinate maps of shape (height, width).
    grid_x, grid_y = torch.meshgrid(xs, ys, indexing="xy")
    return torch.stack((grid_x, grid_y), dim=-1)
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging from typing import List, Tuple, Union import torch import torch.nn.functional as F from .camera import CameraTW from .pose import IdentityPose, PAD_VAL, PoseTW, rotation_from_euler from .tensor_wrapper import autocast, autoinit, smart_cat, smart_stack, TensorWrapper logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) # logger.setLevel(logging.DEBUG) # OBB corner numbering diagram for this implementation (the same as pytorch3d # https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/ops/iou_box3d.py#L111) # # (4) +---------+. (5) # | ` . | ` . # | (0) +---+-----+ (1) # | | | | # (7) +-----+---+. (6)| # ` . | ` . | # (3) ` +---------+ (2) # # NOTE: Throughout this implementation, we assume that boxes # are defined by their 8 corners exactly in the order specified in the # diagram above for the function to give correct results. In addition # the vertices on each plane must be coplanar. 
@classmethod
@autoinit
def from_lmc(
    cls,
    bb3_object: torch.Tensor = PAD_VAL * torch.ones(6),
    bb2_rgb: torch.Tensor = PAD_VAL * torch.ones(4),
    bb2_slaml: torch.Tensor = PAD_VAL * torch.ones(4),
    bb2_slamr: torch.Tensor = PAD_VAL * torch.ones(4),
    T_world_object: Union[torch.Tensor, PoseTW] = IdentityPose,  # 1x12.
    sem_id: torch.Tensor = PAD_VAL * torch.ones(1),
    inst_id: torch.Tensor = PAD_VAL * torch.ones(1),
    prob: torch.Tensor = 1 * torch.ones(1),
    moveable: torch.Tensor = 0 * torch.ones(1),
):
    """
    Build an ObbTW from its individual fields.

    The fields are concatenated along the last dimension in this order:
    bb3_object (6), bb2_rgb (4), bb2_slaml (4), bb2_slamr (4),
    T_world_object (12), sem_id (1), inst_id (1), prob (1), moveable (1).
    This is the 34-entry layout the accessors of this class slice into.

    Fields that are not passed keep their defaults (PAD_VAL for boxes and
    ids, IdentityPose for the pose, 1 for prob and 0 for moveable).
    """
    # bb3_object decides the device; every other field is moved onto it.
    # NOTE: the original author remarks that this fails if IdentityPose is used.
    device = bb3_object.device
    other_fields = (
        bb2_rgb,
        bb2_slaml,
        bb2_slamr,
        T_world_object,
        sem_id,
        inst_id,
        prob,
        moveable,
    )
    parts = [bb3_object]
    parts.extend(field.to(device) for field in other_fields)
    # smart_cat is used because the list can contain TensorWrapper objects
    # (the pose) next to plain tensors.
    return cls(smart_cat(parts, dim=-1))
def get_padding_mask(self) -> torch.Tensor:
    """
    Return a boolean mask that is True for padded obbs.

    An obb counts as padding when every entry of its data vector equals
    PAD_VAL. The last (data) dimension is reduced, so the mask has the
    leading shape of the underlying data.
    """
    entry_is_pad = self._data == PAD_VAL
    return torch.all(entry_is_pad, dim=-1)
def bb3edge_pts_object(self, num_samples_per_edge: int = 10) -> torch.Tensor:
    """
    Sample points along the 12 edges of the 3D BB in the object coord frame.

    Every edge listed in BB3D_LINE_ORDERS is sampled at num_samples_per_edge
    positions that are linearly spaced from its first to its second corner.
    The result has shape (..., 12 * num_samples_per_edge, 3) and is ordered
    edge by edge.

    num_samples_per_edge == 1 yields only the first corner of every edge
    (all 8 corners, some of them twice),
    num_samples_per_edge == 2 yields both end points of every edge,
    num_samples_per_edge == 3 yields the end points and the edge midpoints,
    and so on.
    """
    corners = self.bb3corners_object  # (..., 8, 3)
    # interpolation weights, shaped S x 1 so they broadcast over xyz
    w_end = torch.linspace(
        0, 1, num_samples_per_edge, device=corners.device
    ).unsqueeze(-1)
    w_start = 1 - w_end

    first_ids = [edge[0] for edge in BB3D_LINE_ORDERS]
    second_ids = [edge[1] for edge in BB3D_LINE_ORDERS]
    first = corners[..., first_ids, :].unsqueeze(-2)  # (..., 12, 1, 3)
    second = corners[..., second_ids, :].unsqueeze(-2)  # (..., 12, 1, 3)

    edge_pts = first * w_start + second * w_end  # (..., 12, S, 3)
    # merge the edge and the sample dimension, edges stay the slow index
    return edge_pts.flatten(-3, -2)
def _mark_invalid(self, invalid_mask: torch.Tensor) -> "ObbTW":
    """
    In place mark obbs in this ObbTW as invalid via a mask.

    Every obb selected by invalid_mask has all of its entries set to
    PAD_VAL, which turns it into a padding element.

    Args:
        invalid_mask: mask selecting the obbs to invalidate (expected to be
            boolean). Its shape has to equal the shape of this ObbTW without
            the trailing data dimension, e.g. (B, N) for obbs shaped
            (B, N, 34).

    Returns:
        self, to allow chaining like the other in place filters.
    """
    assert invalid_mask.ndim == self.ndim - 1, "invalid_mask must match ObbTW"
    # The mask has one dimension less than the obbs, so its full shape (not
    # its shape without the last dim) has to match the leading obb dims.
    assert invalid_mask.shape == self.shape[:-1], "invalid_mask must match ObbTW"
    self._data[invalid_mask] = PAD_VAL
    return self
""" if self.ndim == 1: is_pad = torch.all(self._data == PAD_VAL, dim=-1) return 0 if is_pad.item() else 1 elif self.ndim == 2: is_pad = torch.all(self._data == PAD_VAL, dim=-1) return self.shape[0] - is_pad.sum() elif self.ndim == 3: is_pad = torch.all(self._data == PAD_VAL, dim=-1) return self.shape[0] * self.shape[1] - is_pad.sum() elif self.ndim == 4: is_pad = torch.all(self._data == PAD_VAL, dim=-1) return self.shape[0] * self.shape[1] * self.shape[2] - is_pad.sum() else: raise NotImplementedError(f"{self.shape}") def scale_bb2(self, scale_rgb: float, scale_slam: float): """Update the 2d bb parameters after resizing the underlying images. All 2d bbs are scaled by the same scale specified for the frame of the 2d bb (RGB vs SLAM).""" # Check for padded values and leave those unchanged. pad_rgb = ( torch.all(self.bb2_rgb == PAD_VAL, dim=-1) .unsqueeze(-1) .expand(*self.bb2_rgb.shape) ) pad_slamr = ( torch.all(self.bb2_slamr == PAD_VAL, dim=-1) .unsqueeze(-1) .expand(*self.bb2_slamr.shape) ) pad_slaml = ( torch.all(self.bb2_slaml == PAD_VAL, dim=-1) .unsqueeze(-1) .expand(*self.bb2_slaml.shape) ) sc_rgb = scale_rgb * torch.ones_like(self.bb2_rgb) sc_slamr = scale_slam * torch.ones_like(self.bb2_slamr) sc_slaml = scale_slam * torch.ones_like(self.bb2_slaml) # If False, multiply by scale, if True multiply by 1. sc_rgb = torch.where(pad_rgb, torch.ones_like(sc_rgb), sc_rgb) sc_slamr = torch.where(pad_slamr, torch.ones_like(sc_slamr), sc_slamr) sc_slaml = torch.where(pad_slaml, torch.ones_like(sc_slaml), sc_slaml) data = smart_cat( [ self.bb3_object, self.bb2_rgb * sc_rgb, self.bb2_slaml * sc_slaml, self.bb2_slamr * sc_slamr, self.T_world_object, self.sem_id, self.inst_id, self.prob, self.moveable, ], dim=-1, ) return self.__class__(data) def crop_bb2(self, left_top_rgb: Tuple[float], left_top_slam: Tuple[float]): """Update the 2d bb parameters after cropping the underlying images. 
All 2d bbs are cropped by the same crop specified for the frame of the 2d bb (RGB vs SLAM). left_top_* is assumed to be a 2D tuple of the left top corner of te crop. """ # accumulate 2d bb formatting of (xmin, xmax, ymin, ymax) left_top_rgb = self._data.new_tensor( (left_top_rgb[0], left_top_rgb[0], left_top_rgb[1], left_top_rgb[1]) ) left_top_slam = self._data.new_tensor( (left_top_slam[0], left_top_slam[0], left_top_slam[1], left_top_slam[1]) ) # Expand the dimension if self._data is a tensor of CameraTW if len(self._data.shape) > 1: expand_dim = list(self._data.shape[:-1]) + [1] left_top_rgb = left_top_rgb.repeat(expand_dim) left_top_slam = left_top_slam.repeat(expand_dim) data = smart_cat( [ self.bb3_object, self.bb2_rgb - left_top_rgb, self.bb2_slaml - left_top_slam, self.bb2_slamr - left_top_slam, self.T_world_object, self.sem_id, self.inst_id, self.prob, self.moveable, ], dim=-1, ) return self.__class__(data) def rotate_bb2_cw(self, image_sizes: List[Tuple[int]]): """Update the 2d bb parameters after rotating the underlying images. Args: image_sizes: List of original image sizes before the rotation. The order of the images sizes should be [(w_rgb, h_rgb), (w_slaml, h_slaml), (w_slamr, h_slamr)]. 
""" ## Early check the input input sizes assert len(image_sizes) == 3, ( f"the image sizes of 3 video stream should be given, but only got {len(image_sizes)}" ) for s in image_sizes: assert len(s) == 2 # rotate the obbs stream by stream bb2_rgb_cw = rot_obb2_cw(self.bb2_rgb.clone(), image_sizes[0]) bb2_slaml_cw = rot_obb2_cw(self.bb2_slaml.clone(), image_sizes[1]) bb2_slamr_cw = rot_obb2_cw(self.bb2_slamr.clone(), image_sizes[2]) data = smart_cat( [ self.bb3_object, bb2_rgb_cw, bb2_slaml_cw, bb2_slamr_cw, self.T_world_object, self.sem_id, self.inst_id, self.prob, self.moveable, ], dim=-1, ) return self.__class__(data) def rectify_obb2(self, fisheye_cams: List[CameraTW], pinhole_cams: List[CameraTW]): rect_bb2s = [] for idx, (fisheye_cam, pinhole_cam) in enumerate( zip(fisheye_cams, pinhole_cams) ): if idx == 0: # rgb bb2 = self.bb2_rgb elif idx == 1: # slaml bb2 = self.bb2_slaml else: # slamr bb2 = self.bb2_slamr tl_points = bb2[..., [0, 2]].clone() # top-left bl_points = bb2[..., [0, 3]].clone() # bottom-left br_points = bb2[..., [1, 3]].clone() # bottom-right tr_points = bb2[..., [1, 2]].clone() # top-right visible_points = self.visible_bb3_ind(idx) tl_rays, _ = fisheye_cam.unproject(tl_points) br_rays, _ = fisheye_cam.unproject(br_points) bl_rays, _ = fisheye_cam.unproject(bl_points) tr_rays, _ = fisheye_cam.unproject(tr_points) rect_tl_pts, valid = pinhole_cam.project(tl_rays) rect_br_pts, valid = pinhole_cam.project(br_rays) rect_tl_pts, valid = pinhole_cam.project(bl_rays) rect_tr_pts, valid = pinhole_cam.project(tr_rays) rect_concat = torch.cat( [rect_tl_pts, rect_br_pts, rect_tl_pts, rect_tr_pts], dim=-1 ) xmin, _ = torch.min(rect_concat[..., 0::2], dim=-1, keepdim=True) xmax, _ = torch.max(rect_concat[..., 0::2], dim=-1, keepdim=True) ymin, _ = torch.min(rect_concat[..., 1::2], dim=-1, keepdim=True) ymax, _ = torch.max(rect_concat[..., 1::2], dim=-1, keepdim=True) # trim width = pinhole_cam.size.reshape(-1, 2)[0][0] height = pinhole_cam.size.reshape(-1, 
2)[0][1] xmin = torch.clamp(xmin, min=0, max=width - 1) xmax = torch.clamp(xmax, min=0, max=width - 1) ymin = torch.clamp(ymin, min=0, max=height - 1) ymax = torch.clamp(ymax, min=0, max=height - 1) rect_bb2 = torch.cat([xmin, xmax, ymin, ymax], dim=-1) # remove the ones without any area areas = (rect_bb2[..., 1] - rect_bb2[..., 0]) * ( rect_bb2[..., 3] - rect_bb2[..., 2] ) areas = areas.unsqueeze(-1) areas = areas.repeat(*([1] * (areas.ndim - 1)), 4) rect_bb2[areas <= 0] = PAD_VAL rect_bb2[~visible_points] = PAD_VAL rect_bb2s.append(rect_bb2) data = smart_cat( [ self.bb3_object, rect_bb2s[0], rect_bb2s[1], rect_bb2s[2], self.T_world_object, self.sem_id, self.inst_id, self.prob, self.moveable, ], dim=-1, ) return self.__class__(data) def get_pseudo_bb2( self, cam: CameraTW, T_world_rig: PoseTW, num_samples_per_edge: int = 1, return_frac_valids: bool = False, ): """ get the 2d bbs of the projection of the 3d bbs into all given camera view points. This is done by sampling points on the 3d bb edges (see bb3edge_pts_object), projecting them and then computing the 2d bbs from the valid projected points. The caller has to make sure the ObbTW has valid 3d bbs data num_samples_per_edge == 1 and num_samples_per_edge == 2 are equivalent (in both cases we project the obb corners into the frames to compute 2d bbs) """ assert self._data.shape[-2] > 0, "No valid 3d bbs data found!" 
def get_bb2_centers(self, cam_id):
    """
    Return the (x, y) centers of the 2d bbs in camera cam_id, shaped (..., 2).

    The center of a 2d bb [xmin,xmax,ymin,ymax] is the midpoint of its x and
    y range. Boxes that are not visible in the camera (see visible_bb3_ind)
    get a center of -1.
    """
    boxes = self.bb2(cam_id)
    visible = self.visible_bb3_ind(cam_id)
    mid_x = (boxes[..., 0] + boxes[..., 1]) / 2.0
    mid_y = (boxes[..., 2] + boxes[..., 3]) / 2.0
    centers = torch.stack([mid_x, mid_y], dim=-1)
    centers[~visible] = -1
    return centers
""" assert pts_world.shape == self.T_world_object.t.shape pts_object = self.T_world_object.inverse().batch_transform(pts_world) inside_min = (pts_object > self.bb3_min_object).all(-1) inside_max = (pts_object < self.bb3_max_object).all(-1) return torch.logical_and(inside_min, inside_max) def points_inside_bb3( self, pts_world: torch.Tensor, scale_obb: float = 1.0 ) -> torch.Tensor: """ checks if a set of points is inside the 3d bounding box """ assert self.ndim == 1 and pts_world.ndim == 2 pts_object = self.T_world_object.inverse().transform(pts_world) inside_min = (pts_object > self.bb3_min_object * scale_obb).all(-1) inside_max = (pts_object < self.bb3_max_object * scale_obb).all(-1) return torch.logical_and(inside_min, inside_max) def _transform(self, T_new_world): """ in place transform T_world_object as T_new_object = T_new_world @ T_world_object """ T_world_object = self.T_world_object T_new_object = T_new_world @ T_world_object self.set_T_world_object(T_new_object) def transform(self, T_new_world): """ transform T_world_object as T_new_object = T_new_world @ T_world_object """ obb_new = self.clone() obb_new._transform(T_new_world) return obb_new def _transform_object(self, T_object_new): """ in place transform T_world_object as T_world_new = T_world_object @ T_object_new """ T_world_object = self.T_world_object T_world_new = T_world_object @ T_object_new self.set_T_world_object(T_world_new) def filter_by_sem_id(self, keep_sem_ids): valid = self._data.new_zeros(self.shape[:-1]).bool() for si in keep_sem_ids: valid = valid | (self.sem_id == si)[..., 0] self._data[~valid] = PAD_VAL return self def filter_by_prob(self, prob_thr: float): # since PAD_VAL is -1 this will work fine with padded entries invalid = self.prob.squeeze(-1) < prob_thr self._data[invalid] = PAD_VAL return self def filter_bb2_center_by_radius(self, calib, cam_id): """ Inputs calib: CameraTW : shaped ... 
def voxel_grid(self, vD: int, vH: int, vW: int):
    """
    Sample the centers of a regular voxel grid inside every obb.

    The x, y and z extent of each box is split into vW, vH and vD cells and
    the cell centers are transformed with T_world_object.

    Input:
        Works on obbs shaped (B) x 34, i.e. a single obb or one batch
        dimension. More leading dimensions raise NotImplementedError.
    Output:
        world points sampled uniformly in a voxel grid (B) x vW*vH*vD x 3,
        flattened with x as the slowest and z as the fastest index.
    """
    if self.ndim > 2:
        raise NotImplementedError
    x_min, x_max, y_min, y_max, z_min, z_max = self.bb3_object.unbind(-1)
    dW = (x_max - x_min) / vW
    dH = (y_max - y_min) / vH
    dD = (z_max - z_min) / vD
    # take the center position of each voxel
    rng_x = tensor_linspace(
        x_min + dW / 2, x_max - dW / 2, steps=vW, device=self.device
    )  # (B) x vW
    rng_y = tensor_linspace(
        y_min + dH / 2, y_max - dH / 2, steps=vH, device=self.device
    )  # (B) x vH
    rng_z = tensor_linspace(
        z_min + dD / 2, z_max - dD / 2, steps=vD, device=self.device
    )  # (B) x vD

    # Equivalent to torch.meshgrid(..., indexing="ij") per obb, but built by
    # broadcasting so that the single obb and the batched case share one
    # code path and no Python loop over the batch is needed.
    lead = rng_x.shape[:-1]  # () or (B,)
    grid = (*lead, vW, vH, vD)
    xx = rng_x[..., :, None, None].expand(grid)
    yy = rng_y[..., None, :, None].expand(grid)
    zz = rng_z[..., None, None, :].expand(grid)
    vox_v = torch.stack([xx, yy, zz], dim=-1).reshape(*lead, -1, 3)

    T_wv = self.T_world_object
    vox_w = T_wv * vox_v
    return vox_w
def rot_obb2_cw(bb2: torch.Tensor, size: Tuple[int]):
    """
    Rotate 2d bbs by 90 degrees clockwise together with their image.

    Args:
        bb2: 2d bbs as (xmin, xmax, ymin, ymax), shaped (..., 4).
        size: (width, height) of the image before the rotation.

    Returns:
        A new tensor with the rotated 2d bbs; the input is not modified.
        Entries that were negative in the input (invalid / padding) keep
        their original value.
    """
    was_negative = bb2 < 0
    xmin, xmax, ymin, ymax = bb2.unbind(-1)
    # reorder (xmin, xmax, ymin, ymax) -> (ymax, ymin, xmin, xmax)
    rotated = torch.stack([ymax, ymin, xmin, xmax], dim=-1)
    # the old y axis becomes the flipped new x axis: x_new = height - y_old - 1
    rotated[..., :2] = size[1] - rotated[..., :2] - 1
    # bring back the invalid entries.
    rotated[was_negative] = bb2[was_negative]
    return rotated
Args: obbs (ObbTW): obbs to project; shape is (Bx)(Tx)Nx34 cam (CameraTW): camera to project to; shape is (Bx)TxC where T is the snippet dimension; T_world_rig (PoseTW): T_world_rig defining where the camera rig is; shape is (Bx)Tx12 num_samples_per_edge (int): how many points to sample per edge to compute 2d bb (1, and 2 means only corners) Returns: bb3_corners_im (Tensor): bb3 corners in the image coordinate system; shape is (Bx)TxNx8x2 bb3_valids (Tensor): valid indices of bb3_corners_im (indicates which corners lie within the images); shape is (Bx)TxNx8 """ obb_dim = obbs.dim() # support 3 sets of input shapes if obb_dim == 2: # Nx34 # cam: TxC, T_world_rig: Tx12 assert ( cam.dim() == 2 and T_world_rig.dim() == 2 and cam.shape[0] == T_world_rig.shape[0] # T dim should be the same for cam and T_world_rig ), ( f"Unsupported input shapes: obb: {obbs.shape}, cam: {cam.shape}, T_world_rig: {T_world_rig.shape}." ) # To the consistent shapes obbs = obbs.unsqueeze(0).unsqueeze(0) # expand to B(1)xT(1)xNx34 cam = cam[None, ...] # expand to B(1)xTxC T_world_rig = T_world_rig[None, ...] # expand to B(1)xTx12 B, T = cam.shape[0:2] N = obbs.shape[-2] obbs = obbs.expand(B, T, *obbs.shape[-2:]) # repeat to real T: B(1)xTxNx34 elif obb_dim == 3: # BxNx34 # cam: BxTxC, T_world_rig: BxTx12 assert cam.dim() == 3 and T_world_rig.dim() == 3 # B dim should be the same assert obbs.shape[0] == cam.shape[0] and obbs.shape[0] == T_world_rig.shape[0] # T dim of cam and pose should be the same assert cam.shape[1] == T_world_rig.shape[1] # To the consistent shapes obbs = obbs.unsqueeze(1) # expand to BxT(1)xNx34 B, T = cam.shape[0:2] obbs = obbs.expand(B, T, *obbs.shape[-2:]) elif obb_dim == 4: # BxTxNx34 pass else: raise ValueError( f"Unsupported input shapes: obb: {obbs.shape}, cam: {cam.shape}, T_world_rig: {T_world_rig.shape}." ) # check if all tensors are of correct shapes. 
def bb2d_from_project_bb3d(
    obbs: ObbTW,
    cam: CameraTW,
    T_world_rig: PoseTW,
    num_samples_per_edge: int = 1,
    return_frac_valids: bool = False,
):
    """
    Compute 2d bbs around the projection of the 3d bbs of obbs into the
    images defined by T_world_rig and camera cam. The obbs are assumed to be
    in the "world" coordinate system of T_world_rig.

    Points on the 3d bbs (see project_bb3d_onto_image) are projected and the
    2d bb is the min/max box around the validly projected points.
    Supports batched operation.

    Args:
        obbs (ObbTW): obbs to project; shape is (Bx)Nx34
        cam (CameraTW): camera to project to; shape is (Bx)TxC where T is the snippet dimension;
        T_world_rig (PoseTW): T_world_rig defining where the camera rig is; shape is (Bx)Tx12
        num_samples_per_edge (int): number of points sampled per 3d bb edge.
        return_frac_valids (bool): additionally return a validity fraction per 2d bb.
    Returns:
        bb2s (Tensor): 2d bounding boxes [xmin,xmax,ymin,ymax] in the image coordinate system; shape is (Bx)TxNx4
        bb2s_valid (Tensor): valid indices of bb2s; shape is (Bx)TxN
        frac_valid (Tensor): only if return_frac_valids is set. For linear
            cameras the IoU between the clamped and the unclamped 2d bb, for
            other cameras the fraction of validly projected points; 0 for
            empty boxes.
    """
    from torchvision.ops.boxes import box_iou

    pts_im, pts_valid = project_bb3d_onto_image(
        obbs, cam, T_world_rig, num_samples_per_edge
    )

    # Replace invalid points by +/- a large sentinel so that they can never
    # win the min / max reduction below.
    valid_xy = pts_valid.unsqueeze(-1).expand_as(pts_im)
    sentinel = 999999 * torch.ones_like(pts_im)
    pts_for_min = torch.where(valid_xy, pts_im, sentinel)
    pts_for_max = torch.where(valid_xy, pts_im, -sentinel)

    # compute 2d bounding boxes
    lower = pts_for_min.min(dim=-2).values
    upper = pts_for_max.max(dim=-2).values
    bb2s = torch.stack(
        [lower[..., 0], upper[..., 0], lower[..., 1], upper[..., 1]], dim=-1
    )

    # min < max so that it's a valid box.
    x_ok = bb2s[..., 0] < bb2s[..., 1]
    y_ok = bb2s[..., 2] < bb2s[..., 3]
    non_empty_boxes = x_ok & y_ok

    if cam.is_linear:
        bb2s_full = bb2s.clone()
        # Clamp based on the camera size for linear cameras (size of the
        # first camera entry). Note that this could generate very big/loose
        # bounding boxes if the object is badly truncated due to out of view.
        image_wh = cam.size.view(-1, 2)[0]
        bb2s[..., 0:2] = torch.clamp(bb2s[..., 0:2], min=0, max=image_wh[0] - 1)
        bb2s[..., 2:4] = torch.clamp(bb2s[..., 2:4], min=0, max=image_wh[1] - 1)
        # valid boxes are non-empty and have at least one valid point.
        bb2s_valid = torch.logical_and(non_empty_boxes, pts_valid.any(-1))
        if return_frac_valids:
            frac_valid = torch.zeros_like(bb2s_valid).float()
            clamped_xyxy = bb2_xxyy_to_xyxy(bb2s[non_empty_boxes])
            full_xyxy = bb2_xxyy_to_xyxy(bb2s_full[non_empty_boxes])
            # box_iou is pairwise; the diagonal holds box i vs. box i
            frac_valid[non_empty_boxes] = box_iou(clamped_xyxy, full_xyxy).diagonal()
    else:
        # count number of valid points
        num_points = pts_valid.count_nonzero(-1)
        # valid 2d bbs are non-empty and have more than
        # 2 * num_samples_per_edge points in the valid image region
        enough_points = num_points > num_samples_per_edge * 2
        bb2s_valid = torch.logical_and(non_empty_boxes, enough_points)
        if return_frac_valids:
            frac_valid = num_points / pts_valid.shape[-1]
            frac_valid[~non_empty_boxes] = 0.0

    if return_frac_valids:
        return bb2s, bb2s_valid, frac_valid
    return bb2s, bb2s_valid
""" return bb3s[..., [0, 3, 1, 4, 2, 5]] def bb3_xyz_xyz_to_xxyyzz(bb3s_min, bb3s_max): """ take min and max points of the bb3 and return xxyyzz format """ return torch.cat([bb3s_min, bb3s_max], -1)[..., [0, 3, 1, 4, 2, 5]] def rnd_obbs(N: int = 1, num_semcls: int = 10, bb3_min_diag=0.1, bb2_min_diag=10): pts3_min = torch.randn(N, 3) pts3_max = pts3_min + bb3_min_diag + torch.randn(N, 3).abs() pts2_min = torch.randn(N, 2) pts2_max = pts2_min + bb2_min_diag + torch.randn(N, 2).abs() obb = ObbTW.from_lmc( bb3_object=bb3_xyzxyz_to_xxyyzz(torch.cat([pts3_min, pts3_max], -1)), prob=torch.ones(N), bb2_rgb=bb2_xyxy_to_xxyy(torch.cat([pts2_min, pts2_max], -1)), sem_id=torch.randint(low=0, high=num_semcls - 1, size=[N]), T_world_object=PoseTW.from_aa(torch.randn(N, 3), 10.0 * torch.randn(N, 3)), ) return obb def obb_time_union(obbs, pad_size=128): """ Take frame level ground truth shaped BxTxNxC and take the union over the time dimensions using the instance id to extend to snippet level obbs shaped BxNxC. """ # T already merged somewhere else. if obbs.ndim == 3: return obbs assert obbs.ndim == 4, "Only B x T x N x C supported" new_obbs = [] for obb in obbs: new_obb = [] flat_time_obb = obb.clone().reshape(-1, 34) unique = flat_time_obb.inst_id.unique() for uni in unique: if uni == PAD_VAL: continue found = int(torch.argwhere(flat_time_obb.inst_id == uni)[0, 0]) found_obb = flat_time_obb[found].clone() new_obb.append(found_obb) if len(new_obb) == 0: print(f"Adding empty OBB in time_union {obbs.shape}") new_obb.append(ObbTW().reshape(-1).to(obbs._data)) new_obbs.append(torch.stack(new_obb).add_padding(pad_size)) new_obbs = torch.stack(new_obbs) # Remove all bb2 observations since we no longer know which frame in time it came from. # Note: we set the visibility for the merged obbs in order to do the evaluation on those theses obbs. 
def tensor_linspace(start, end, steps, device):
    """
    Vectorized version of torch.linspace.

    Inputs:
    - start: Tensor of any shape
    - end: Tensor of the same shape as start
    - steps: Integer, number of samples per (start, end) pair
    - device: device on which the interpolation weights are created

    Returns:
    - out: Tensor of shape start.size() + (steps,), such that
      out.select(-1, 0) == start, out.select(-1, -1) == end,
      and the other elements of out linearly interpolate between
      start and end.
    """
    assert start.size() == end.size()
    # weights of the two end points along the new last dimension, converted
    # to the dtype / device of start
    weight_start = torch.linspace(1, 0, steps=steps, device=device).to(start)
    weight_end = torch.linspace(0, 1, steps=steps, device=device).to(start)
    # a trailing singleton dim lets the weights broadcast over all inputs
    return weight_start * start.unsqueeze(-1) + weight_end * end.unsqueeze(-1)
def box_planar_dir(
    box: torch.Tensor, dot_eps: float = DOT_EPS, area_eps: float = AREA_EPS
) -> torch.Tensor:
    """
    Compute, for every face of every box, the unit normal that points towards
    the inside of the box. The faces are the ones returned by
    `get_plane_verts` (defined by `_box_planes`). Because a box is convex,
    "inside" is the side of the face on which the box center lies.

    Args:
        box: tensor of shape (B, 8, 3) of the vertices of the 3D boxes
        dot_eps: tolerance of the coplanarity check on each face
        area_eps: minimum area of the two triangles spanning each face
    Returns:
        n: tensor of shape (B, 6, 3) of the unit vectors orthogonal to the
            faces and pointing towards the interior of the box
    Raises:
        ValueError: if the vertices of a face are not coplanar or a face has
            (near) zero area.
    """
    assert box.shape[1] == 8 and box.shape[2] == 3

    face_verts = get_plane_verts(box)  # (B, 6, 4, 3)
    v0, v1, v2, v3 = face_verts.unbind(2)
    plane_ctr, n = get_plane_center_normal(face_verts)

    # Coplanarity: the edge v0 -> v3 has to be orthogonal to the face normal.
    edge = F.normalize(v3 - v0, dim=-1).reshape(-1, 1, 3)
    dots = torch.matmul(edge, n.reshape(-1, 3, 1))
    if not (dots.abs() < dot_eps).all().item():
        raise ValueError("Plane vertices are not coplanar")

    # Both triangles sharing the diagonal v0 -> v2 must have non-zero area.
    diagonal = v2 - v0
    area1 = torch.cross(v1 - v0, diagonal, dim=-1).norm(dim=-1) / 2
    area2 = torch.cross(v3 - v0, diagonal, dim=-1).norm(dim=-1) / 2
    if (area1 < area_eps).any().item() or (area2 < area_eps).any().item():
        raise ValueError("Planes have zero areas")

    # Write box_ctr = plane_ctr + a * e0 + b * e1 + c * n with e0, e1 in-plane
    # directions, i.e. <e0, n> = 0 and <e1, n> = 0. One could solve the 2x2
    # system for (a, b) first and then obtain
    #   c = <box_ctr - plane_ctr - a * e0 - b * e1, n>,
    # but since e0 and e1 are orthogonal to n this reduces to the projection
    # of the (normalized) direction towards the box center onto n.
    box_ctr = box.mean(dim=1).view(-1, 1, 3)
    towards_ctr = F.normalize(box_ctr - plane_ctr, dim=-1)  # (B, 6, 3)
    c = (towards_ctr * n).sum(-1)

    # A negative c means n points away from the center: flip those normals.
    return torch.where((c < 0.0).unsqueeze(-1), -n, n)
def get_plane_center_normal(
    planes: torch.Tensor,
) -> "Tuple[torch.Tensor, torch.Tensor]":
    """
    Returns the center and unit normal of planes given by four vertices each.

    The normal is taken from the pair of center-to-vertex vectors whose cross
    product has the largest norm, which is the best-conditioned choice.

    Args:
        planes: tensor of shape (B, P, 4, 3), or (B, 4, 3) for one plane per
            batch entry
    Returns:
        center: tensor of shape (B, P, 3) (or (B, 3) for 3-dim input)
        normal: tensor of shape (B, P, 3) (or (B, 3) for 3-dim input)
    """
    add_dim1 = planes.ndim == 3
    if add_dim1:
        planes = planes.unsqueeze(1)

    ctr = planes.mean(dim=2)  # (B, P, 3)
    rel = planes - ctr.unsqueeze(2)  # center-to-vertex vectors, (B, P, 4, 3)

    # Cross products of all 6 vertex pairs, for all planes at once. Working
    # directly on `rel` keeps the dtype and device of the input; a float32
    # scratch buffer would silently truncate float64 planes.
    pairs = ((0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3))
    ns = torch.stack(
        [torch.cross(rel[:, :, i], rel[:, :, j], dim=-1) for i, j in pairs],
        dim=2,
    )  # (B, P, 6, 3)

    # Pick, per plane, the candidate with the largest norm.
    best = ns.norm(dim=-1).argmax(dim=-1)  # (B, P)
    index = best[..., None, None].expand(-1, -1, 1, 3)
    normals = ns.gather(2, index).squeeze(2)  # (B, P, 3)

    if add_dim1:
        ctr = ctr[:, 0]
        normals = normals[:, 0]
    normals = F.normalize(normals, dim=-1)
    return ctr, normals
def skew_symmetric(v):
    """Build the skew-symmetric (cross-product) matrix of a (batched) vector.

    For an input of size (..., 3) the result has size (..., 3, 3) and
    satisfies M @ w == cross(v, w).
    """
    x, y, z = v[..., 0], v[..., 1], v[..., 2]
    zero = torch.zeros_like(x)
    # Assemble the matrix row by row and stack the rows on a new axis.
    rows = [
        torch.stack([zero, -z, y], dim=-1),
        torch.stack([z, zero, -x], dim=-1),
        torch.stack([-y, x, zero], dim=-1),
    ]
    return torch.stack(rows, dim=-2)
def interpolation_boundaries_alphas(times: torch.Tensor, interp_times: torch.Tensor):
    """
    Find the ids in the times tensor that bound each of the interp_times
    timestamps from below (lower_ids) and above (upper_ids).

    If an interpolation time lies outside the interval spanned by times, both
    ids point to the nearest boundary timestamp and the returned `good`
    boolean tensor is False for it.

    The returned alpha is the normalized distance from the lower timestamp,
    i.e. 0 at times[lower_id] and 1 at times[upper_id], so a value is
    interpolated as:
        interp_value = (1 - alpha) * value[lower_id] + alpha * value[upper_id]
    alpha is 0 where `good` is False and where an interpolation time coincides
    with a timestamp. Since both ids coincide outside the interval, the
    formula yields the boundary value there; no extrapolation is performed.

    Args:
        times: timestamps with shape (..., T); the final assertion expects
            them to be sorted in ascending order.
        interp_times: query timestamps with shape (..., M).
    Returns:
        lower_ids, upper_ids: long tensors of shape (..., M) indexing times.
        alpha: float tensor of shape (..., M).
        good: bool tensor of shape (..., M).
    """
    times = times.unsqueeze(-2)
    interp_times = interp_times.unsqueeze(-1)
    # dt[..., m, t] = times[t] - interp_times[m]
    dt = times - interp_times
    # Sentinel marking timestamps on the wrong side of the query. Use finfo
    # for every floating dtype and iinfo for every integer dtype (not just
    # int64), so that e.g. int32 timestamps work as well.
    if dt.dtype.is_floating_point:
        dt_max = torch.finfo(dt.dtype).max
    else:
        dt_max = torch.iinfo(dt.dtype).max
    sentinel = torch.full_like(dt, dt_max)
    # distance to timestamps at/after the query, and at/before the query
    dt_upper = torch.where(dt < 0, sentinel, dt)
    dt_lower = torch.where(dt > 0, sentinel, -dt)
    upper_alpha, upper_ids = torch.min(dt_upper, dim=-1)
    lower_alpha, lower_ids = torch.min(dt_lower, dim=-1)
    # A query is bracketed only if a timestamp exists on both sides.
    good = torch.logical_and(lower_alpha < dt_max, upper_alpha < dt_max)
    # Outside the interval collapse both ids onto the boundary timestamp.
    upper_ids = torch.where(good, upper_ids, torch.maximum(lower_ids, upper_ids))
    lower_ids = torch.where(good, lower_ids, torch.minimum(lower_ids, upper_ids))
    assert (lower_ids <= upper_ids).all()
    # nan_to_num handles the case where time and interpolation time are the
    # same and hence this is a 0/0. It is okay to go to floats now since the
    # critical bit is the computation of the time difference.
    alpha = torch.nan_to_num(lower_alpha.float() / (lower_alpha + upper_alpha).float())
    alpha = torch.where(good, alpha, torch.zeros_like(alpha))
    return lower_ids, upper_ids, alpha, good
""" r, i, j, k = torch.unbind(quaternions_wxyz, -1) two_s = 2.0 / (quaternions_wxyz * quaternions_wxyz).sum(-1) o = torch.stack( ( 1 - two_s * (j * j + k * k), two_s * (i * j - k * r), two_s * (i * k + j * r), two_s * (i * j + k * r), 1 - two_s * (i * i + k * k), two_s * (j * k - i * r), two_s * (i * k - j * r), two_s * (j * k + i * r), 1 - two_s * (i * i + j * j), ), -1, ) return o.reshape(quaternions_wxyz.shape[:-1] + (3, 3)) class PoseTW(TensorWrapper): @autocast @autoinit def __init__(self, data: torch.Tensor = IdentityPose): assert isinstance(data, torch.Tensor) assert data.shape[-1] == 12 super().__init__(data) @classmethod @autocast def from_Rt(cls, R: torch.Tensor, t: torch.Tensor): """Pose from a rotation matrix and translation vector. Accepts numpy arrays or PyTorch tensors. Args: R: rotation matrix with shape (..., 3, 3). t: translation vector with shape (..., 3). """ assert R.shape[-2:] == (3, 3) assert t.shape[-1] == 3 assert R.shape[:-2] == t.shape[:-1] data = torch.cat([R.flatten(start_dim=-2), t], -1) return cls(data) @classmethod @autocast def from_qt(cls, quaternion_wxyz: torch.Tensor, t: torch.Tensor): """Pose from quaternion and translation vectors. Quaternion should be wxyz format, with real part first, and imaginary part last. Args: quaternion: quaternion with shape (..., 4). t: translation vector with shape (..., 3). """ assert quaternion_wxyz.shape[:-1] == t.shape[:-1], ( f"quaternion shape {quaternion_wxyz.shape[:-1]} must match translation shape {t.shape[:-1]} expect the last dim" ) assert quaternion_wxyz.shape[-1] == 4, "quaternion must be of shape (..., 4)" assert t.shape[-1] == 3, "translation must be of shape (..., 3)" R = quaternion_to_matrix(quaternion_wxyz) data = torch.cat([R.flatten(start_dim=-2), t], -1) return cls(data) @classmethod @autocast def from_aa(cls, aa: torch.Tensor, t: torch.Tensor): """Pose from an axis-angle rotation vector and translation vector. Accepts numpy arrays or PyTorch tensors. 
@classmethod
@autocast
def from_matrix3x4(cls, T_3x4: torch.Tensor):
    """Pose from the 3x4 block [R | t] of an SE(3) transformation matrix.

    Args:
        T_3x4: transformation matrix with shape (..., 3, 4).
    """
    assert T_3x4.shape[-2:] == (3, 4)
    # The first three columns hold the rotation, the last the translation.
    rotation = T_3x4[..., :, 0:3]
    translation = T_3x4[..., :, 3]
    return cls.from_Rt(rotation, translation)
""" # following https://www.ethaneade.com/lie.pdf and http://people.csail.mit.edu/jstraub/download/straubTransformationCookbook.pdf u = u_omega[..., :3] omega = u_omega[..., 3:] theta = omega.norm(p=2, dim=-1, keepdim=True).unsqueeze(-1) small = theta < eps R = so3exp_map(omega, eps) # compute V shape = [1] * len(omega.shape[:-1]) ones = torch.ones_like(theta) # compute factors and approximate them around 0 using second order # taylor expansion (from WolframAlpha) b = torch.where( small, 0.5 * ones - theta**2 / 24.0 + theta**4 / 720.0, (ones - torch.cos(theta)) / theta**2, ) c = torch.where( small, 1.0 / 6.0 * ones - theta**2 / 120.0 + theta**4 / 5040.0, (theta - torch.sin(theta)) / theta**3, ) Identity = ( torch.eye(3).reshape(shape + [3, 3]).repeat(shape + [1, 1]).to(u_omega) ) W = skew_symmetric(omega) V = Identity + b * W + c * W @ W # compute t t = (V @ u.unsqueeze(-1)).squeeze(-1) return cls.from_Rt(R, t) # @classmethod # def from_colmap(cls, image: NamedTuple): # '''Pose from a COLMAP Image.''' # return cls.from_Rt(image.qvec2rotmat(), image.tvec) @property def R(self) -> torch.Tensor: """Underlying rotation matrix with shape (..., 3, 3).""" rvec = self._data[..., :9] return rvec.reshape(rvec.shape[:-1] + (3, 3)) @property def t(self) -> torch.Tensor: """Underlying translation vector with shape (..., 3).""" return self._data[..., -3:] @property def q(self) -> torch.Tensor: """ Convert rotations of shape (..., 3, 3) to a quaternion (..., 4). The returned quaternions have real part first, as wxyz. The function is adapted from `matrix_to_quaternion` in Pytorch3d: https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py The major difference to the original pytorch3d function is that the returned quaternions are normalized and have positive real part. """ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: """ Returns torch.sqrt(torch.max(0, x)) but with a zero subgradient where x is 0. 
""" ret = torch.zeros_like(x) positive_mask = x > 0 ret[positive_mask] = torch.sqrt(x[positive_mask]) return ret matrix = self.R batch_dim = matrix.shape[:-2] m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( matrix.reshape(batch_dim + (9,)), dim=-1 ) q_abs = _sqrt_positive_part( torch.stack( [ 1.0 + m00 + m11 + m22, 1.0 + m00 - m11 - m22, 1.0 - m00 + m11 - m22, 1.0 - m00 - m11 + m22, ], dim=-1, ) ) # we produce the desired quaternion multiplied by each of r, i, j, k quat_by_wxyz = torch.stack( [ torch.stack( [q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1 ), torch.stack( [m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1 ), torch.stack( [m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1 ), torch.stack( [m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1 ), ], dim=-2, ) # We floor here at 0.1 but the exact level is not important; if q_abs is small, # the candidate won't be picked. flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) quat_candidates = quat_by_wxyz / (2.0 * q_abs[..., None].max(flr)) # if not for numerical problems, quat_candidates[i] should be same (up to a sign), # forall i; we pick the best-conditioned one (with the largest denominator) best_quat = quat_candidates[ F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : ].reshape(batch_dim + (4,)) # normalize quaternions and make the real part to be positive for all quaternions best_quat = best_quat.reshape(-1, 4) neg_ind = torch.nonzero(best_quat[:, 0] < 0).squeeze() best_quat[neg_ind, :] *= -1 best_quat = best_quat.reshape(batch_dim + (4,)) best_quat_normalized = F.normalize(best_quat, p=2, dim=-1) return best_quat_normalized @property def q_xyzw(self) -> torch.Tensor: """ Get the quaternion representation similar to self.q, but the real part of the quaternion comes last rather than first. This is a handy function to increase interoperability, e.g. lietorch requires xyzw quaternions. 
""" quat_wxyz = self.q return torch.concat([quat_wxyz[..., 1:4], quat_wxyz[..., 0:1]], dim=-1) @property def matrix3x4(self) -> torch.Tensor: """Underlying transformation matrix with shape (..., 3, 4).""" rvec = self._data[..., :9] rmat = rvec.reshape(rvec.shape[:-1] + (3, 3)) tvec = self._data[..., -3:].unsqueeze(-1) T = torch.cat([rmat, tvec], dim=-1) return T @property def matrix(self) -> torch.Tensor: """Underlying transformation matrix with shape (..., 4, 4).""" T_3x4 = self.matrix3x4 bot_row = T_3x4.new_zeros(T_3x4.shape[:-2] + (1, 4)) bot_row[..., 0, 3] = 1 return torch.cat([T_3x4, bot_row], dim=-2) def to_euler(self, rad=True) -> torch.Tensor: """Convert the rotation matrix to Euler angles using ZYX convention.""" """Reference: http://eecs.qmul.ac.uk/~gslabaugh/publications/euler.pdf""" # Test gimbal lock (ignore rotations that are all PAD_VAL). is_pad = torch.all(torch.all(self.R == PAD_VAL, dim=-1), dim=-1) assert (~torch.abs(self.R[~is_pad][..., 2, 0]).isclose(torch.tensor(1.0))).all() Y_angle = -torch.asin(self.R[..., 2, 0]) euler_angles = ( torch.atan2(self.R[..., 2, 1], self.R[..., 2, 2]), Y_angle, torch.atan2(self.R[..., 1, 0], self.R[..., 0, 0]), ) if not rad: # return degree return torch.stack(euler_angles, -1) * 180.0 / torch.pi return torch.stack(euler_angles, -1) def to_ypr(self, rad=True) -> torch.Tensor: # yaw, pitch, roll from rotation matrix: http://lavalle.pl/planning/node103.html R = self.R yaw = torch.atan(R[..., 1, 0] / R[..., 0, 0]) pitch = torch.atan( -R[..., 2, 0] / torch.sqrt(R[..., 2, 1] * R[..., 2, 1] + R[..., 2, 2] * R[..., 2, 2]) ) roll = torch.atan(R[..., 2, 1] / R[..., 2, 2]) return yaw, pitch, roll def inverse(self) -> "PoseTW": """Invert an SE(3) pose.""" R = self.R.transpose(-1, -2) t = -(R @ self.t.unsqueeze(-1)).squeeze(-1) return self.__class__.from_Rt(R, t) def compose(self, other: "PoseTW") -> "PoseTW": """Chain two SE(3) poses: T_C_B.compose(T_B_A) -> T_C_A.""" R = self.R @ other.R t = self.t + (self.R @ 
@autocast
def batch_transform(self, p3d: torch.Tensor) -> torch.Tensor:
    """Transform N 3D points, the i-th point by the i-th pose of this PoseTW.

    Args:
        p3d: 3D points, numpy array or PyTorch tensor with the same shape as
            the translations of this PoseTW; exactly one batch dimension is
            supported.
    """
    assert p3d.shape == self.t.shape, f"shapes of p3d {p3d.shape}, t {self.t.shape}"
    # bmm only handles a single batch dimension
    assert p3d.dim() == 2, f"{p3d.shape}"
    assert self.ndim == 2, f"{self.shape}"
    # Treat every point as a row vector and multiply by R^T from the right,
    # using (Rp + t)^T = p^T R^T + t^T; this avoids transposing the points.
    row_points = p3d.unsqueeze(-2)
    rot_transposed = self.R.transpose(-1, -2)
    rotated = torch.bmm(row_points, rot_transposed).squeeze(-2)
    return rotated + self.t
""" assert p3d.shape[-1] == 3 # use more efficient right multiply that avoids transpose of the points # according to the equality: # (Rp)^T = p^T R^T where p3d = p^T and self.R = R return p3d @ self.R.transpose(-1, -2) def __mul__(self, p3D: torch.Tensor) -> torch.Tensor: """Transform a set of 3D points: T_B_A * p3D_A -> p3D_B""" return self.transform(p3D) def __matmul__(self, other: "PoseTW") -> "PoseTW": """Chain two SE(3) poses: T_C_B @ T_B_A -> T_C_A.""" return self.compose(other) def numpy(self) -> Tuple[np.ndarray]: return self.R.numpy(), self.t.numpy() def magnitude(self, deg=True, eps=0) -> Tuple[torch.Tensor]: """Magnitude of the SE(3) transformation. The `eps` has to be positive if you want to use this function as part of a training loop. Returns: dr: rotation angle in degrees (if deg=True) or in radians. dt: translation distance in meters. """ trace = torch.diagonal(self.R, dim1=-1, dim2=-2).sum(-1) cos = torch.clamp((trace - 1) / 2, min=-1.0 + eps, max=1.0 - eps) dr = torch.acos(cos) if deg: dr = dr * 180.0 / math.pi dt = torch.norm(self.t, dim=-1) return dr, dt def so3_geodesic(self, other: "PoseTW", deg=False) -> "PoseTW": """Compute the geodesic distance for rotation between this pose and another pose""" pose_e = self.compose(other.inverse()) dr, _ = pose_e.magnitude(deg=deg, eps=1e-6) return dr def log(self, eps: float = 1e-6) -> torch.Tensor: """ Compute the SE3 log map for these poses. Returns [...,6] where the last 3 entries are the so3 entires omega and the first 3 the entries for translation. 
""" # following https://www.ethaneade.com/lie.pdf and http://people.csail.mit.edu/jstraub/download/straubTransformationCookbook.pdf R, t = self.R, self.t trace = torch.diagonal(self.R, dim1=-1, dim2=-2).sum(-1) cos = torch.clamp((trace - 1.0) * 0.5, -1, 1) theta = torch.acos(cos).unsqueeze(-1).unsqueeze(-1) ones = torch.ones_like(theta) small = theta < eps # compute factors and approximate them around 0 using second order # taylor expansion (from WolframAlpha) theta_over_sin_theta = torch.where( small, ones - (theta**2) / 6.0 + 7.0 * (theta**4) / 360.0, theta / torch.sin(theta), ) c = torch.where( small, 0.08333333 + 0.001388889 * theta**2 + 0.0000330688 * theta**4, (ones - ((0.5 * theta * torch.sin(theta)) / (ones - torch.cos(theta)))) / theta**2, ) # compute log-map W of rotation R first W = 0.5 * theta_over_sin_theta * (R - R.transpose(-1, -2)) # compute V_inv to be able to get u shape = [1] * len(R.shape[:-2]) Identity = ( torch.eye(3).reshape(shape + [3, 3]).repeat(shape + [1, 1]).to(self._data) ) V_inv = Identity - 0.5 * W + c * W @ W u = (V_inv @ t.unsqueeze(-1)).squeeze(-1) omega = inv_skew_symmetric(W) return torch.cat([u, omega], -1) def interpolate(self, times: torch.Tensor, interp_times: torch.Tensor): """ Return poses at the given interpolation times interp_times based on the poses in this object and the provided associated timestamps times. If interpolation timestamps are outside the interval of times, the poses at the interval boundaries will be returned and the good boolean tensor will indicate those boundary values with a False. """ assert times.shape == self._data.shape[:-1], ( f"time stamps for the poses do not match poses shape {times.shape} vs {self._data.shape}" ) assert times.dim() <= 2, ( "The shape of the input times should be either BxT or T." 
def align(self, other, self_times=None, other_times=None):
    """Align two trajectories using the method of Horn (closed-form).

    Only the translations of the poses are used to estimate the rigid
    transform. If the trajectories differ in length, `other` is first
    interpolated to `self_times` and only timestamps inside the temporal
    intersection are used.

    Input:
        other -- second PoseTW (Nx12) trajectory to align to
        self_times -- timestamps of self, required if the lengths differ
        other_times -- timestamps of other, required if the lengths differ

    Output:
        T_self_other -- single relative SE3 transform mapping other into
            the frame of self
        mean_error -- mean over all points of the Euclidean distance between
            the aligned other translations and the self translations

    code inspired by:
    https://github.com/symao/vio_evaluation/blob/master/align.py#L6-L38
    """
    if self.t.ndim != 2:
        raise ValueError(
            f"Only Nx12 Pose supported in alignment, given {self.shape}"
        )
    if other.t.ndim != 2:
        raise ValueError(
            f"Only Nx12 Pose supported in alignment, given {other.shape}"
        )

    dtype = torch.promote_types(self.dtype, other.dtype)

    # Optionally interpolate other to match the size of self.
    if self.shape[0] != other.shape[0]:
        if self_times is None or other_times is None:
            raise ValueError(
                f"Got different length PoseTW (self {self.shape} and other {other.shape}). "
                "Must provide timestamps to support interpolation"
            )
        # Do interpolation on temporal intersection.
        other, goods = other.interpolate(other_times, self_times)
        self2 = self.clone()[goods].to(dtype)
        other2 = other.clone()[goods].to(dtype)
    else:
        self2 = self.clone().to(dtype)
        other2 = other.clone().to(dtype)

    # 3 x N point sets
    P = self2.t.transpose(0, 1)
    Q = other2.t.transpose(0, 1)
    if P.shape != Q.shape:
        raise ValueError("Matrices P and Q must be of the same dimensionality")

    # Center both point sets; broadcasting keeps dtype and device of the
    # poses (no helper tensors that would default to the CPU).
    centroids_P = torch.mean(P, dim=1)
    centroids_Q = torch.mean(Q, dim=1)
    A = P - centroids_P.unsqueeze(1)
    B = Q - centroids_Q.unsqueeze(1)

    # Cross-covariance and its SVD; torch.linalg.svd returns V already
    # transposed (Vh), so V = Vh^T and R = V U^T.
    C = A @ B.transpose(0, 1)
    U, _, Vh = torch.linalg.svd(C)
    R = Vh.transpose(0, 1) @ U.transpose(0, 1)
    if torch.linalg.det(R) < 0:
        # Resolve the reflection by flipping the last singular direction.
        L = torch.eye(3, dtype=dtype, device=C.device)
        L[2, 2] = -1
        R = Vh.transpose(0, 1) @ (L @ U.transpose(0, 1))
    t = (-R @ centroids_P) + centroids_Q

    # (R, t) maps self onto other; invert to map other into self.
    T_self_other = PoseTW.from_Rt(R, t).inverse().to(dtype)
    other_aligned = T_self_other @ other2
    # Per-point translational error: norm over xyz (last dim), not over the
    # points, followed by the mean over all points.
    error = torch.linalg.norm(other_aligned.t - self2.t, dim=-1)
    mean_error = error.mean(dim=-1)
    return T_self_other, mean_error
""" ts_list = list(timed_poses.keys()) ts = torch.from_numpy(np.array(ts_list)) interp_time = torch.from_numpy(np.array([time])) lower_ids, upper_ids, _, _ = interpolation_boundaries_alphas(ts, interp_time) t_lower = ts_list[lower_ids[0]] t_upper = ts_list[upper_ids[0]] poses_lower, poses_upper = timed_poses[t_lower], timed_poses[t_upper] poses_interp = None times = torch.from_numpy(np.array([t_lower, t_upper])).float() if isinstance(poses_lower, PoseTW): poses = PoseTW(smart_stack([poses_lower, poses_upper])) if poses.dim() == 3: times = times.unsqueeze(-1).repeat(1, poses.shape[1]) poses_interp = poses.interpolate(times, interp_time)[0].squeeze() elif isinstance(poses_lower, dict): keys_lower = set(poses_lower.keys()) keys_upper = set(poses_upper.keys()) keys = keys_lower & keys_upper poses_interp = {} for key in keys: poses = PoseTW(smart_stack([poses_lower[key], poses_upper[key]])) if poses.dim() == 3 and times.dim() == 1: times = times.unsqueeze(-1).repeat(1, poses.shape[1]) poses_interp[key] = poses.interpolate(times, interp_time)[0].squeeze() elif isinstance(poses_lower, list): assert len(poses_lower) == len(poses_upper) poses_interp = [] for i in range(len(poses_lower)): poses = PoseTW(smart_stack([poses_lower[i], poses_upper[i]])) if poses.dim() == 3 and times.dim() == 1: times = times.unsqueeze(-1).repeat(1, poses.shape[1]) poses_interp.append(poses.interpolate(times, interp_time)[0].squeeze()) return poses_interp def lower_timed_poses( timed_poses: Dict[ Union[float, int], Union[PoseTW, List[PoseTW], Dict[Union[float, int, str], PoseTW]], ], time: Union[float, int], ): """ interpolate timed poses given as a dict[time:container[PoseTW]] to given time. The poses container indexed by time can be given plain as poses, or a list or dict of poses. If a list or dict of poses is given, the output will also be a list or dict of the interpolated poses. This allows batched interpolation. 
def all_rot90():
    """Construct all proper rotations whose columns are signed coordinate axes.

    Every ordered pair (a, b) of the six signed unit axes is completed with
    c = a x b and assembled column-wise into a 3x3 matrix. Pairs that are
    parallel give a zero cross product (determinant 0) and are filtered out,
    which leaves the 24 axis-aligned rotation matrices.

    Returns:
        Tensor of shape (24, 3, 3).
    """
    # The six signed unit axes: +x, +y, +z, -x, -y, -z (one per row).
    dirs = torch.cat([torch.eye(3), -torch.eye(3)], dim=0)
    ids = torch.arange(0, 6).long()
    jds = torch.arange(0, 6).long()
    # Pin the indexing mode explicitly: relying on the implicit default
    # triggers a warning and is announced to become an error.
    ids, jds = torch.meshgrid(ids, jds, indexing="ij")
    ids, jds = ids.reshape(-1), jds.reshape(-1)
    a, b = dirs[ids, :], dirs[jds, :]
    c = torch.cross(a, b, dim=-1)
    Rs = torch.cat([a.unsqueeze(2), b.unsqueeze(2), c.unsqueeze(2)], dim=2)
    # filter to valid rotations
    det = torch.linalg.det(Rs)
    Rs = Rs[det > 0.99]
    return Rs
def stereographic_unproject(a, axis=None):
    """
    Inverse of stereographic projection: https://en.wikipedia.org/wiki/Stereographic_projection
    This is from the paper "On the Continuity of Rotation Representations in
    Neural Networks" https://arxiv.org/pdf/1812.07035.pdf, equation [8,9],
    used in rotation_from_ortho_5d.

    Takes a 2-D input with one sample per row and returns a tensor with one
    extra column. The extra component is inserted at column `axis`
    (defaults to the last column).
    """
    rows, cols = a.shape[0], a.shape[1]
    if axis is None:
        axis = cols

    sq_norm = a.pow(2).sum(1)
    denom = sq_norm + 1
    # Broadcasting the per-row denominator replaces the explicit repeat.
    scaled = 2 * a / denom.reshape(rows, 1)

    out = torch.zeros(rows, cols + 1).to(a)
    if axis > 0:
        out[:, :axis] = scaled[:, :axis]
    out[:, axis] = (sq_norm - 1) / denom
    out[:, axis + 1 :] = scaled[:, axis:]
    return out


def rotation_from_ortho_6d(ortho6d):
    """
    Convert a 6-d rotation representation to rotation matrix
    From the paper "On the Continuity of Rotation Representations in Neural Networks"
    https://arxiv.org/pdf/1812.07035.pdf

    The first three entries give the x direction and the last three a second
    direction; y and z are rebuilt with cross products. All leading
    dimensions are flattened, so the result has shape (-1, 3, 3) with
    x, y, z as columns.
    """
    first, second = ortho6d[..., 0:3], ortho6d[..., 3:6]

    x = F.normalize(first, dim=-1, eps=1e-6)
    y = F.normalize(second, dim=-1, eps=1e-6)
    z = F.normalize(torch.cross(x, y, -1), dim=-1, eps=1e-6)
    # Recompute y from z and x instead of keeping the normalized raw input.
    y = torch.cross(z, x, -1)

    columns = [axis_vec.reshape(-1, 3, 1) for axis_vec in (x, y, z)]
    return torch.cat(columns, 2)
def rotation_from_euler(euler):
    """
    Convert a 3-d Euler angle representation to rotation matrix

    Input:
      euler - tensor indexed as [batch, 3]
    Output:
      matrix - tensor of shape (batch, 3, 3)
    """
    batch = euler.shape[0]

    # Note the column order: the second angle pair (c2, s2) reads column 2
    # and the third pair (c3, s3) reads column 1.
    c1 = torch.cos(euler[:, 0]).reshape(batch, 1)
    s1 = torch.sin(euler[:, 0]).reshape(batch, 1)
    c2 = torch.cos(euler[:, 2]).reshape(batch, 1)
    s2 = torch.sin(euler[:, 2]).reshape(batch, 1)
    c3 = torch.cos(euler[:, 1]).reshape(batch, 1)
    s3 = torch.sin(euler[:, 1]).reshape(batch, 1)

    row1 = torch.cat((c2 * c3, -s2, c2 * s3), 1).reshape(-1, 1, 3)
    row2 = torch.cat(
        (c1 * s2 * c3 + s1 * s3, c1 * c2, c1 * s2 * s3 - s1 * c3), 1
    ).reshape(-1, 1, 3)
    row3 = torch.cat(
        (s1 * s2 * c3 - c1 * s3, s1 * c2, s1 * s2 * s3 + c1 * c3), 1
    ).reshape(-1, 1, 3)

    matrix = torch.cat((row1, row2, row3), 1)
    return matrix


def fit_to_SO3(R):
    """Project a nearly orthogonal 3x3 matrix onto the nearest rotation.

    Math used from quora post and this berkeley pdf.
    https://qr.ae/pKQaG5
    https://people.eecs.berkeley.edu/~wkahan/Math128/NearestQ.pdf

    With Y = B^T B - I, the orthogonal factor is Q = B (I + Y)^(-1/2).
    The inverse square root is expanded as a binomial series,
        (I + Y)^(-1/2) = I - Y/2 + 3Y^2/8 - 5Y^3/16 + 35Y^4/128 - ...
    which, after factoring out one Y, gives
        Q ~= B - B Y (I/2 - 3Y/8 + 5Y^2/16 - 35Y^3/128).

    Input:
      R - torch 3x3 rotation matrix that is not quite orthogonal
    Output:
      Q - torch 3x3 nearest valid rotation matrix
    """
    assert R.ndim == 2
    assert R.shape[0] == 3 and R.shape[1] == 3
    B = R
    # Build the identity with R's dtype and device so the function also
    # works for inputs that do not live on the CPU.
    I = torch.eye(3, dtype=R.dtype, device=R.device)
    Y = B.transpose(-2, -1) @ B - I
    Y2 = Y @ Y
    # The last term is cubic in Y (see the expansion in the docstring).
    Q = B - B @ Y @ (
        I / 2.0 - (3.0 * Y) / 8.0 + (5.0 * Y2) / 16.0 - (35.0 * Y2 @ Y) / 128.0
    )
    return Q
def sign_plus(x):
    """
    return +1 for positive and for 0.0 in x, and -1 for negative entries.
    This is important for our handling of z values that should never be 0.0
    (torch.sign would return 0.0 for them).
    """
    sgn = torch.ones_like(x)
    # The mask must come from the input; masking on `sgn` itself (all ones)
    # never selects anything and would make every output +1.
    sgn[x < 0.0] = -1.0
    return sgn
tangentialDistortion = [(2 x_r^2 + rd^2)*p_0 + 2*x_r*y_r*p_1] [(2 y_r^2 + rd^2)*p_1 + 2*x_r*y_r*p_0] where rd^2 = x_r^2 + y_r^2 thinPrismDistortion = [s0 * rd^2 + s1 rd^4] [s2 * rd^2 + s3 rd^4] """ assert (xyz.ndim == 3 and params.ndim == 2) or ( xyz.ndim == 4 and params.ndim == 3 ), f"point dim {xyz.shape} does not match cam parameter dim {params}" assert xyz.shape[-1] == 3 assert params.shape[-1] == 16 or params.shape[-1] == 15, ( "This model allows fx != fy" ) assert xyz.dtype == params.dtype, "data type must match" eps = 1e-9 T = -1 if xyz.ndim == 4: # has T dim T, N = xyz.shape[1], xyz.shape[2] xyz = xyz.reshape(-1, N, 3) # (BxT)xNx3 params = params.reshape(-1, params.shape[-1]) # (BxT)x16 B, N = xyz.shape[0], xyz.shape[1] # Radial correction. z = xyz[:, :, 2].reshape(B, N, 1) # Do not use torch.sign(z) it leads to 0.0 zs if z == 0.0 which leads to a # nan when we compute xy/z z = torch.where(torch.abs(z) < eps, eps * sign_plus(z), z) ab = xyz[:, :, :2] / z # make sure abs are not too small or 0 otherwise gradients are nan ab = torch.where(torch.abs(ab) < eps, eps * sign_plus(ab), ab) r = torch.norm(ab, dim=-1, p=2, keepdim=True) th = torch.atan(r) th_divr = torch.where(r < eps, torch.ones_like(ab), ab / r) th_k = th.reshape(B, N, 1).clone() for i in range(6): th_k = th_k + params[:, -12 + i].reshape(B, 1, 1) * torch.pow(th, 3 + i * 2) xr_yr = th_k * th_divr uv_dist = xr_yr # Tangential correction. p0 = params[:, -6].reshape(B, 1) p1 = params[:, -5].reshape(B, 1) xr = xr_yr[:, :, 0].reshape(B, N) yr = xr_yr[:, :, 1].reshape(B, N) xr_yr_sq = torch.square(xr_yr) xr_sq = xr_yr_sq[:, :, 0].reshape(B, N) yr_sq = xr_yr_sq[:, :, 1].reshape(B, N) rd_sq = xr_sq + yr_sq uv_dist_tu = uv_dist[:, :, 0] + ((2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1) uv_dist_tv = uv_dist[:, :, 1] + ((2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0) uv_dist = torch.stack( [uv_dist_tu, uv_dist_tv], dim=-1 ) # Avoids in-place complaint. # Thin Prism correction. 
s0 = params[:, -4].reshape(B, 1) s1 = params[:, -3].reshape(B, 1) s2 = params[:, -2].reshape(B, 1) s3 = params[:, -1].reshape(B, 1) rd_4 = torch.square(rd_sq) uv_dist[:, :, 0] = uv_dist[:, :, 0] + (s0 * rd_sq + s1 * rd_4) uv_dist[:, :, 1] = uv_dist[:, :, 1] + (s2 * rd_sq + s3 * rd_4) # Finally, apply standard terms: focal length and camera centers. if params.shape[-1] == 15: fx_fy = params[:, 0].reshape(B, 1, 1) cx_cy = params[:, 1:3].reshape(B, 1, 2) else: fx_fy = params[:, 0:2].reshape(B, 1, 2) cx_cy = params[:, 2:4].reshape(B, 1, 2) result = uv_dist * fx_fy + cx_cy if T > 0: result = result.reshape(B // T, T, N, 2) assert result.ndim == 4 or result.ndim == 3 assert result.shape[-1] == 2 return result @torch.jit.script def fisheye624_unproject(uv, params, max_iters: int = 5): """ Batched implementation of the FisheyeRadTanThinPrism (aka Fisheye624) camera model. There is no analytical solution for the inverse of the project() function so this solves an optimization problem using Newton's method to get the inverse. Inputs: uv: Bx(T)xNx2 tensor of 2D pixels to be projected params: Bx(T)x16 tensor of Fisheye624 parameters formatted like this: [f_u f_v c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}] or Bx(T)x15 tensor of Fisheye624 parameters formatted like this: [f c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}] Outputs: xyz: Bx(T)xNx3 tensor of 3D rays of uv points with z = 1. Model for fisheye cameras with radial, tangential, and thin-prism distortion. This model assumes fu=fv. 
This unproject function holds that: X = unproject(project(X)) [for X=(x,y,z) in R^3, z>0] and x = project(unproject(s*x)) [for s!=0 and x=(u,v) in R^2] """ assert uv.ndim == 3 or uv.ndim == 4, "Expected batched input shaped Bx(T)xNx2" assert uv.shape[-1] == 2 assert params.ndim == 2 or params.ndim == 3, ( "Expected batched input shaped Bx(T)x16 or Bx(T)x15" ) assert params.shape[-1] == 16 or params.shape[-1] == 15, ( "This model allows fx != fy" ) assert uv.dtype == params.dtype, "data type must match" eps = 1e-6 T = -1 if uv.ndim == 4: # has T dim T, N = uv.shape[1], uv.shape[2] uv = uv.reshape(-1, N, 2) # (BxT)xNx2 params = params.reshape(-1, params.shape[-1]) # (BxT)x16 B, N = uv.shape[0], uv.shape[1] if params.shape[-1] == 15: fx_fy = params[:, 0].reshape(B, 1, 1) cx_cy = params[:, 1:3].reshape(B, 1, 2) else: fx_fy = params[:, 0:2].reshape(B, 1, 2) cx_cy = params[:, 2:4].reshape(B, 1, 2) uv_dist = (uv - cx_cy) / fx_fy # Compute xr_yr using Newton's method. xr_yr = uv_dist.clone() # Initial guess. for _ in range(max_iters): uv_dist_est = xr_yr.clone() # Tangential terms. p0 = params[:, -6].reshape(B, 1) p1 = params[:, -5].reshape(B, 1) xr = xr_yr[:, :, 0].reshape(B, N) yr = xr_yr[:, :, 1].reshape(B, N) xr_yr_sq = torch.square(xr_yr) xr_sq = xr_yr_sq[:, :, 0].reshape(B, N) yr_sq = xr_yr_sq[:, :, 1].reshape(B, N) rd_sq = xr_sq + yr_sq uv_dist_est[:, :, 0] = uv_dist_est[:, :, 0] + ( (2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1 ) uv_dist_est[:, :, 1] = uv_dist_est[:, :, 1] + ( (2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0 ) # Thin Prism terms. s0 = params[:, -4].reshape(B, 1) s1 = params[:, -3].reshape(B, 1) s2 = params[:, -2].reshape(B, 1) s3 = params[:, -1].reshape(B, 1) rd_4 = torch.square(rd_sq) uv_dist_est[:, :, 0] = uv_dist_est[:, :, 0] + (s0 * rd_sq + s1 * rd_4) uv_dist_est[:, :, 1] = uv_dist_est[:, :, 1] + (s2 * rd_sq + s3 * rd_4) # Compute the derivative of uv_dist w.r.t. xr_yr. 
duv_dist_dxr_yr = uv.new_ones(B, N, 2, 2) duv_dist_dxr_yr[:, :, 0, 0] = ( 1.0 + 6.0 * xr_yr[:, :, 0] * p0 + 2.0 * xr_yr[:, :, 1] * p1 ) offdiag = 2.0 * (xr_yr[:, :, 0] * p1 + xr_yr[:, :, 1] * p0) duv_dist_dxr_yr[:, :, 0, 1] = offdiag duv_dist_dxr_yr[:, :, 1, 0] = offdiag duv_dist_dxr_yr[:, :, 1, 1] = ( 1.0 + 6.0 * xr_yr[:, :, 1] * p1 + 2.0 * xr_yr[:, :, 0] * p0 ) xr_yr_sq_norm = xr_yr_sq[:, :, 0] + xr_yr_sq[:, :, 1] temp1 = 2.0 * (s0 + 2.0 * s1 * xr_yr_sq_norm) duv_dist_dxr_yr[:, :, 0, 0] = duv_dist_dxr_yr[:, :, 0, 0] + ( xr_yr[:, :, 0] * temp1 ) duv_dist_dxr_yr[:, :, 0, 1] = duv_dist_dxr_yr[:, :, 0, 1] + ( xr_yr[:, :, 1] * temp1 ) temp2 = 2.0 * (s2 + 2.0 * s3 * xr_yr_sq_norm) duv_dist_dxr_yr[:, :, 1, 0] = duv_dist_dxr_yr[:, :, 1, 0] + ( xr_yr[:, :, 0] * temp2 ) duv_dist_dxr_yr[:, :, 1, 1] = duv_dist_dxr_yr[:, :, 1, 1] + ( xr_yr[:, :, 1] * temp2 ) # Compute 2x2 inverse manually here since torch.inverse() is very slow. # Because this is slow: inv = duv_dist_dxr_yr.inverse() # About a 10x reduction in speed with above line. mat = duv_dist_dxr_yr.reshape(-1, 2, 2) a = mat[:, 0, 0].reshape(-1, 1, 1) b = mat[:, 0, 1].reshape(-1, 1, 1) c = mat[:, 1, 0].reshape(-1, 1, 1) d = mat[:, 1, 1].reshape(-1, 1, 1) det = 1.0 / ((a * d) - (b * c)) top = torch.cat([d, -b], dim=2) bot = torch.cat([-c, a], dim=2) inv = det * torch.cat([top, bot], dim=1) inv = inv.reshape(B, N, 2, 2) # Manually compute 2x2 @ 2x1 matrix multiply. # Because this is slow: step = (inv @ (uv_dist - uv_dist_est)[..., None])[..., 0] diff = uv_dist - uv_dist_est a = inv[:, :, 0, 0] b = inv[:, :, 0, 1] c = inv[:, :, 1, 0] d = inv[:, :, 1, 1] e = diff[:, :, 0] f = diff[:, :, 1] step = torch.stack([a * e + b * f, c * e + d * f], dim=-1) # Newton step. xr_yr = xr_yr + step # Compute theta using Newton's method. 
def pinhole_project(xyz, params):
    """
    Batched implementation of the Pinhole (aka Linear) camera model project() function.

    Inputs:
        xyz: Bx(T)xNx3 tensor of 3D points to be projected
        params: Bx(T)x4 tensor of Pinhole parameters formatted like this:
                [f_u f_v c_u c_v]
    Outputs:
        uv: Bx(T)xNx2 tensor of 2D projections of xyz in image plane
    """
    assert (xyz.ndim == 3 and params.ndim == 2) or (xyz.ndim == 4 and params.ndim == 3)
    assert params.shape[-1] == 4
    eps = 1e-9

    # Focal length and principal point, shaped to broadcast over the N points.
    fx_fy = params[..., 0:2].reshape(*xyz.shape[:-2], 1, 2)
    cx_cy = params[..., 2:4].reshape(*xyz.shape[:-2], 1, 2)

    # Make sure depth is not too close to zero.
    z = xyz[..., 2:]
    # Do not use torch.sign(z) it leads to 0.0 zs if z == 0.0 which leads to a
    # nan when we compute xy/z
    z = torch.where(torch.abs(z) < eps, eps * sign_plus(z), z)
    uv = (xyz[..., :2] / z) * fx_fy + cx_cy
    return uv


def pinhole_unproject(uv, params, max_iters: int = 5):
    """
    Batched implementation of the Pinhole (aka Linear) camera model.

    Inputs:
        uv: Bx(T)xNx2 tensor of 2D pixels to be unprojected
        params: Bx(T)x4 tensor of Pinhole parameters formatted like this:
                [f_u f_v c_u c_v]
        max_iters: not used by this closed-form model; presumably kept so the
                signature matches the iterative unproject functions.
    Outputs:
        xyz: Bx(T)xNx3 tensor of 3D rays of uv points with z = 1.
    """
    assert uv.ndim == 3 or uv.ndim == 4, "Expected batched input shaped Bx(T)xNx2"
    assert params.ndim == 2 or params.ndim == 3
    assert params.shape[-1] == 4
    assert uv.shape[-1] == 2

    # Focal length and principal point, shaped to broadcast over the N pixels.
    fx_fy = params[..., 0:2].reshape(*uv.shape[:-2], 1, 2)
    cx_cy = params[..., 2:4].reshape(*uv.shape[:-2], 1, 2)

    uv_dist = (uv - cx_cy) / fx_fy

    # Append z = 1 to obtain the ray.
    ray = torch.cat([uv_dist, uv.new_ones(*uv.shape[:-1], 1)], dim=-1)
    return ray
def smart_stack(inp_arr, dim: int = 0):
    """Stack a list of tensors and/or TensorWrappers into one raw tensor.

    TensorWrapper entries contribute their underlying `_data`. The caller's
    list is left untouched.

    Raises:
        RuntimeError: if the entries live on more than one device.
    """
    tensors = [
        inp._data if isinstance(inp, TensorWrapper) else inp for inp in inp_arr
    ]
    devices = {tensor.device for tensor in tensors}
    if len(devices) > 1:
        raise RuntimeError(f"More than one device found! {devices}")
    return torch.stack(tensors, dim=dim)


def get_default_args(func):
    """Return {parameter name: default value} for every parameter of `func` that has a default."""
    signature = inspect.signature(func)
    return {
        k: v.default
        for k, v in signature.parameters.items()
        if v.default is not inspect.Parameter.empty
    }


def get_nonempty_arg_names(func):
    """Return, in declaration order, the names of the positional parameters of `func` that have a default."""
    spec = inspect.getfullargspec(func)
    signature = inspect.signature(func)
    return [
        k
        for k in spec.args
        if signature.parameters[k].default is not inspect.Parameter.empty
    ]


def autocast(func):
    """Cast the inputs of a TensorWrapper method to PyTorch tensors
    if they are numpy arrays. Use the device and dtype of the wrapper.

    The wrapped callable only accepts positional arguments. Its first
    argument must be a TensorWrapper instance or a TensorWrapper subclass;
    anything else raises ValueError.
    """

    @functools.wraps(func)
    def wrap(self, *args):
        # Defaults used when no wrapped tensor is available yet.
        device = torch.device("cpu")
        dtype = None
        if isinstance(self, TensorWrapper):
            if self._data is not None:
                device = self.device
                dtype = self.dtype
        elif not inspect.isclass(self) or not issubclass(self, TensorWrapper):
            raise ValueError(self)

        cast_args = []
        for arg in args:
            # Only numpy inputs are converted and moved; tensors pass through as-is.
            if isinstance(arg, np.ndarray):
                arg = torch.from_numpy(arg)
                arg = arg.to(device=device, dtype=dtype)
            cast_args.append(arg)
        return func(self, *cast_args)

    return wrap
default_args = get_default_args(func) extra_args = {} for arg_name in default_args: default_arg = default_args[arg_name] if not isinstance(default_arg, (TensorWrapper, torch.Tensor)): # If not TW or torch tensor, pass it through unperturbed. extra_args[arg_name] = all_args.pop(arg_name) else: if arg_name not in all_args or all_args[arg_name] is None: all_args[arg_name] = default_arg # Auto convert numpy,lists,floats to torch, check that shapes are good. for arg_name in all_args: arg = all_args[arg_name] if isinstance(arg, (torch.Tensor, TensorWrapper)): pass elif isinstance(arg, (int, float)): arg = torch.tensor(arg).reshape(1) elif isinstance(arg, List): arg = torch.tensor(arg) elif isinstance(arg, np.ndarray): arg = torch.from_numpy(arg) else: raise ValueError("Unsupported initialization type") assert isinstance(arg, (torch.Tensor, TensorWrapper)) default_arg = default_args[arg_name] if isinstance(default_arg, TensorWrapper): # Convert list of torch.Size to tuple of ints. default_dims = tuple([da[0] for da in default_arg.shape]) else: default_dims = (default_arg.shape[-1],) if arg.shape[-1] not in default_dims: # probably need a more general solution here to handle single dim inputs. if default_dims[0] == 1: arg = arg.unsqueeze(-1) if arg.shape[-1] not in default_dims: raise ValueError( "Bad shape of %d for %s, should be in %s" % (arg.shape[-1], arg_name, default_dims) ) all_args[arg_name] = arg # Shape of all inputs is determined by first arg. first_arg_name = arg_names[0] batch_shape = all_args[first_arg_name].shape[:-1] has_cuda_tensor = False for arg_name in all_args: arg = all_args[arg_name] # Try to trim any extra dimensions at the beginning of arg shape. 
def tensor_wrapper_collate(batch, *, collate_fn_map=None):
    """Collate TensorWrapper samples by stacking them along a new leading dim."""
    stacked = torch.stack(batch, 0)
    return stacked


def float_collate(batch, *, collate_fn_map=None):
    """Collate python floats into a float32 tensor."""
    as_float32 = torch.tensor(batch, dtype=torch.float32)
    return as_float32


def list_dict_collate(batch, *, collate_fn_map=None):
    """Collate list samples without merging them.

    When the first element of the first sample is a 2-tuple, every sample is
    taken to be a list of (key, value) pairs and is replaced in place by the
    corresponding dict. Otherwise the batch is returned untouched.
    """
    if len(batch) == 0 or len(batch[0]) == 0:
        return batch

    probe = batch[0][0]
    if isinstance(probe, tuple) and len(probe) == 2:
        # Each sample expresses a dict as (key, value) pairs.
        for idx, pairs in enumerate(batch):
            batch[idx] = {key: value for key, value in pairs}
    return batch


def tensor_wrapper_collate_cat(batch, *, collate_fn_map=None):
    """Collate TensorWrapper samples by concatenating them along dim 0."""
    joined = torch.cat(batch, 0)
    return joined
class TensorWrapper:
    """Base class for making "smart" tensor objects that behave like pytorch tensors.

    A single torch tensor is held in `_data` and the common tensor API is
    forwarded to it. Methods that produce a new tensor re-wrap the result in
    `self.__class__`, so subclasses keep their own type.

    Inspired by Paul-Edouard Sarlin's code here in pixloc:
    https://github.com/cvg/pixloc/blob/master/pixloc/pixlib/geometry/wrappers.py
    Adopted and modified by Daniel DeTone.
    """

    # Wrapped tensor; stays None until __init__ assigns it.
    _data = None

    @autocast
    def __init__(self, data: torch.Tensor):
        """Wrap `data`; numpy inputs are converted to tensors by @autocast."""
        self._data = data

    # ---- read-only views onto the wrapped tensor ----

    @property
    def shape(self):
        return self._data.shape

    @property
    def device(self):
        return self._data.device

    @property
    def dtype(self):
        return self._data.dtype

    @property
    def ndim(self):
        return self._data.ndim

    def dim(self):
        return self._data.dim()

    def nelement(self):
        return self._data.nelement()

    def numel(self):
        return self._data.numel()

    @property
    def collate_fn(self):
        """Collate function to use for batches that contain TensorWrappers."""
        return custom_collate_fn

    @property
    def is_cuda(self):
        return self._data.is_cuda

    @property
    def is_contiguous(self):
        # Returns the tensor's bound method (not its result), so callers
        # still write `obj.is_contiguous()`.
        return self._data.is_contiguous

    @property
    def requires_grad(self):
        return self._data.requires_grad

    @property
    def grad(self):
        return self._data.grad

    @property
    def grad_fn(self):
        return self._data.grad_fn

    def requires_grad_(self, requires_grad: bool = True):
        """Set requires_grad on the wrapped tensor in place; returns None."""
        self._data.requires_grad_(requires_grad)

    # ---- indexing ----

    def __getitem__(self, index):
        return self.__class__(self._data[index])

    def __setitem__(self, index, item):
        # `item` must itself expose `_data` (i.e. be a wrapper).
        self._data[index] = item._data

    # ---- tensor methods that return a re-wrapped result ----

    def to(self, *args, **kwargs):
        return self.__class__(self._data.to(*args, **kwargs))

    def reshape(self, *args, **kwargs):
        return self.__class__(self._data.reshape(*args, **kwargs))

    def repeat(self, *args, **kwargs):
        return self.__class__(self._data.repeat(*args, **kwargs))

    def expand(self, *args, **kwargs):
        return self.__class__(self._data.expand(*args, **kwargs))

    def clone(self):
        return self.__class__(self._data.clone())

    def cpu(self):
        return self.__class__(self._data.cpu())

    def cuda(self, gpu_id=0):
        return self.__class__(self._data.cuda(gpu_id))

    def contiguous(self):
        return self.__class__(self._data.contiguous())

    def pin_memory(self):
        return self.__class__(self._data.pin_memory())

    def float(self):
        return self.__class__(self._data.float())

    def double(self):
        return self.__class__(self._data.double())

    def detach(self):
        return self.__class__(self._data.detach())

    # ---- conversions that leave the wrapper ----

    def numpy(self):
        return self._data.numpy()

    def tensor(self):
        """Return the wrapped tensor itself (no copy)."""
        return self._data

    def tolist(self):
        return self._data.tolist()

    # ---- shape manipulation (an explicit `dim` may not target the last dim) ----

    def squeeze(self, dim=None):
        """Squeeze all size-1 dims (dim=None) or one given dim, which must not be the last."""
        assert dim != -1 and dim != self._data.dim() - 1
        if dim is None:
            return self.__class__(self._data.squeeze())
        return self.__class__(self._data.squeeze(dim=dim))

    def unsqueeze(self, dim=None):
        """Insert a new dim at `dim`; appending a new trailing dim is rejected."""
        assert dim != -1 and dim != self._data.dim()
        return self.__class__(self._data.unsqueeze(dim=dim))

    def view(self, *shape):
        """View with a new shape whose last entry is -1 or the current last dim size."""
        assert shape[-1] == -1 or shape[-1] == self._data.shape[-1]
        return self.__class__(self._data.view(*shape))

    def __len__(self):
        return self._data.shape[0]

    # ---- torch.* function overrides, dispatched from __torch_function__ ----

    @classmethod
    def stack(cls, objects: List, dim=0, *, out=None):
        data = torch.stack([obj._data for obj in objects], dim=dim, out=out)
        return cls(data)

    @classmethod
    def cat(cls, objects: List, dim=0, *, out=None):
        data = torch.cat([obj._data for obj in objects], dim=dim, out=out)
        return cls(data)

    @classmethod
    def allclose(
        cls,
        input: torch.Tensor,
        other: torch.Tensor,
        rtol=1e-5,
        atol=1e-8,
        equal_nan=False,
    ):
        # Both arguments are accessed through `_data`, i.e. they are wrappers.
        return torch.allclose(
            input._data, other._data, rtol=rtol, atol=atol, equal_nan=equal_nan
        )

    @classmethod
    def take_along_dim(cls, obj, indices, dim, *, out=None):
        data = torch.take_along_dim(obj._data, indices, dim, out=out)
        return cls(data)

    @classmethod
    def flatten(cls, obj, start_dim=0, end_dim=-1):
        data = torch.flatten(obj._data, start_dim=start_dim, end_dim=end_dim)
        return cls(data)

    @classmethod
    def __torch_function__(self, func, types, args=(), kwargs=None):
        """Route the supported torch.* functions to the classmethods above.

        Any other function returns NotImplemented.
        """
        if kwargs is None:
            kwargs = {}
        if func is torch.stack:
            return self.stack(*args, **kwargs)
        elif func is torch.cat:
            return self.cat(*args, **kwargs)
        elif func is torch.allclose:
            return self.allclose(*args, **kwargs)
        elif func is torch.take_along_dim:
            return self.take_along_dim(*args, **kwargs)
        elif func is torch.flatten:
            return self.flatten(*args, **kwargs)
        else:
            return NotImplemented
time_domain: "DEVICE_TIME" main_camera_target_freq_hz: 10.0 sample_length_in_num_frames: 20 stride_length_in_num_frames: 10 processors: rgb: selected: true sensor_label: "camera-rgb" time_domain: "DEVICE_TIME" tolerance_ns: 10_000_000 undistort_to_linear_cam: false # if set, undistort to a linear camera model target_camera_resolution: [240, 240] # if set, rescale to [image_width, image_height] rescale_antialias: false # to be consistent with cv2 rotate_image_cw90deg: false # if set, rotate image by 90 degrees clockwise slam_left: selected: true sensor_label: "camera-slam-left" tolerance_ns: 10_000_000 time_domain: "DEVICE_TIME" target_camera_resolution: [320, 240] # if set, rescale to [image_width, image_height] rescale_antialias: false # to be consistent with cv2 slam_right: selected: true sensor_label: "camera-slam-right" tolerance_ns: 10_000_000 time_domain: "DEVICE_TIME" target_camera_resolution: [320, 240] # if set, rescale to [image_width, image_height] rescale_antialias: false # to be consistent with cv2 mps_traj: selected: true tolerance_ns: 10_000_000 mps_semidense: selected: true tolerance_ns: 10_000_000 rgb_depth: selected: true sensor_stream_id: "345-1" # 345-1 for ADT data, 214-8 for ASE data tolerance_ns: 10_000_000 time_domain: "DEVICE_TIME" convert_zdepth_to_dist: false efm_gt: selected: true tolerance_ns : 10_000_000 category_mapping_field_name: category # {prototype_name, category} wds_writer: prefix_string: "" max_samples_per_shard: 8 ================================================ FILE: efm3d/config/evl_inf.yaml ================================================ _target_: efm3d.model.evl.EVL neck_hidden_dims: [128, 256, 512] head_hidden_dim: 256 head_layers: 2 taxonomy_file: efm3d/config/taxonomy/ase_sem_name_to_id.csv video_backbone: _target_: efm3d.model.video_backbone.VideoBackboneDinov2 freeze_encoder: true image_tokenizer: _target_: efm3d.model.image_tokenizer.ImageToDinoV2Tokens dinov2_name: vit_base_v25 freeze: true handle_rotated_data: 
true dim_out: 768 add_lin_layer: false multilayer_output: true ckpt_path: ckpt/dinov2_vitb14_reg4_pretrain.pth video_streams: [rgb] correct_vignette: false optimize_vignette: false video_backbone3d: _target_: efm3d.model.lifter.Lifter in_dim: 768 out_dim: 64 patch_size: 16 voxel_size: [96,96,96] voxel_extent: [-2.0, 2.0, 0.0, 4.0, -2.0, 2.0] head_type: dpt_ori streams: [rgb] joint_slam_streams: false joint_streams: false ================================================ FILE: efm3d/config/evl_inf_desktop.yaml ================================================ _target_: efm3d.model.evl.EVL neck_hidden_dims: [32, 64, 128] head_hidden_dim: 256 head_layers: 2 taxonomy_file: efm3d/config/taxonomy/ase_sem_name_to_id.csv video_backbone: _target_: efm3d.model.video_backbone.VideoBackboneDinov2 freeze_encoder: true image_tokenizer: _target_: efm3d.model.image_tokenizer.ImageToDinoV2Tokens dinov2_name: vit_base_v25 freeze: true handle_rotated_data: true dim_out: 768 add_lin_layer: false multilayer_output: true ckpt_path: ckpt/dinov2_vitb14_reg4_pretrain.pth video_streams: [rgb] correct_vignette: false optimize_vignette: false video_backbone3d: _target_: efm3d.model.lifter.Lifter in_dim: 768 out_dim: 32 patch_size: 16 voxel_size: [48,48,48] voxel_extent: [-2.0, 2.0, 0.0, 4.0, -2.0, 2.0] head_type: dpt_ori streams: [rgb] joint_slam_streams: false joint_streams: false ================================================ FILE: efm3d/config/evl_train.yaml ================================================ _target_: efm3d.model.evl_train.EVLTrain neck_hidden_dims: [128, 256, 512] head_hidden_dim: 256 head_layers: 2 taxonomy_file: efm3d/config/taxonomy/ase_sem_name_to_id.csv video_backbone: _target_: efm3d.model.video_backbone.VideoBackboneDinov2 freeze_encoder: true image_tokenizer: _target_: efm3d.model.image_tokenizer.ImageToDinoV2Tokens dinov2_name: vit_base_v25 freeze: true handle_rotated_data: true dim_out: 768 add_lin_layer: false multilayer_output: true ckpt_path: 
ckpt/dinov2_vitb14_reg4_pretrain.pth video_streams: [rgb] correct_vignette: false optimize_vignette: false video_backbone3d: _target_: efm3d.model.lifter.Lifter in_dim: 768 out_dim: 64 patch_size: 16 voxel_size: [96,96,96] voxel_extent: [-2.0, 2.0, 0.0, 4.0, -2.0, 2.0] head_type: dpt_ori streams: [rgb] joint_slam_streams: false joint_streams: false ================================================ FILE: efm3d/config/taxonomy/aeo_to_efm.csv ================================================ AEO Category Name,EFM Category Name,EFM Category Id Chair,chair,3 Couch,sofa,1 Table,table,0 Bed,bed,4 WallArt,picture_frame,21 Plant,flower_pot,13 Window,window,28 Mirror,mirror,22 Lamp,lamp,16 ================================================ FILE: efm3d/config/taxonomy/ase_sem_name_to_id.csv ================================================ sem_name,sem_id table,0 sofa,1 shelf,2 chair,3 bed,4 floor_mat,5 exercise_weight,6 cutlery,7 container,8 clock,9 cart,10 vase,11 tent,12 flower_pot,13 pillow,14 mount,15 lamp,16 ladder,17 fan,18 cabinet,19 jar,20 picture_frame,21 mirror,22 electronic_device,23 dresser,24 clothes_rack,25 battery_charger,26 air_conditioner,27 window,28 ================================================ FILE: efm3d/config/taxonomy/atek_to_efm.csv ================================================ ATEK Category Name,EFM version ASE Category Name,EFM version ASE Category Id table,table,0 sofa,sofa,1 shelves,shelf,2 chair,chair,3 bed,bed,4 floor mat,floor_mat,5 exercise_weight,exercise_weight,6 cutlery,cutlery,7 container,container,8 clock,clock,9 cart,cart,10 vase,vase,11 tent,tent,12 plant,flower_pot,13 pillow,pillow,14 mount,mount,15 lamp,lamp,16 ladder,ladder,17 fan,fan,18 cabinet,cabinet,19 jar,jar,20 picture,picture_frame,21 mirror,mirror,22 electronic_device,electronic_device,23 dresser,dresser,24 clothes_rack,clothes_rack,25 battery_charger,battery_charger,26 air_conditioner,air_conditioner,27 window,window,28 ================================================ FILE: 
efm3d/dataset/atek_vrs_dataset.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # pyre-strict import logging import os from typing import Dict, List, Optional from atek.data_loaders.atek_wds_dataloader import select_and_remap_dict_keys from atek.data_preprocess.atek_data_sample import AtekDataSample from atek.data_preprocess.sample_builders.atek_data_paths_provider import ( AtekDataPathsProvider, ) from atek.data_preprocess.sample_builders.efm_sample_builder import EfmSampleBuilder from atek.data_preprocess.subsampling_lib.temporal_subsampler import ( CameraTemporalSubsampler, ) from efm3d.dataset.efm_model_adaptor import EfmModelAdaptor from omegaconf.omegaconf import DictConfig, OmegaConf logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) class AtekRawDataloaderAsEfm: def __init__( self, vrs_file: str, mps_files: Dict[str, str], gt_files: Dict[str, str], conf: DictConfig, freq_hz: int, snippet_length_s: float, semidense_points_pad_to_num: int = 50000, max_snippets=9999, ) -> None: self.max_snippets = max_snippets # initialize the sample builder self.sample_builder = EfmSampleBuilder( conf=conf.processors, vrs_file=vrs_file, mps_files=mps_files, gt_files=gt_files, depth_vrs_file="", sequence_name=os.path.basename(vrs_file), ) self.subsampler = CameraTemporalSubsampler( vrs_file=vrs_file, conf=conf.camera_temporal_subsampler, ) # Create a EFM model adaptor 
self.model_adaptor = EfmModelAdaptor( freq=freq_hz, snippet_length_s=snippet_length_s, semidense_points_pad_to_num=semidense_points_pad_to_num, atek_to_efm_taxonomy_mapping_file=f"{os.path.dirname(__file__)}/../config/taxonomy/atek_to_efm.csv", ) def __len__(self): return min(self.subsampler.get_total_num_samples(), self.max_snippets) def get_timestamps_by_sample_index(self, index: int) -> List[int]: return self.subsampler.get_timestamps_by_sample_index(index) def get_atek_sample_at_timestamps_ns( self, timestamps_ns: List[int] ) -> Optional[AtekDataSample]: return self.sample_builder.get_sample_by_timestamps_ns(timestamps_ns) def get_model_specific_sample_at_timestamps_ns( self, timestamps_ns: List[int] ) -> Optional[Dict]: atek_sample = self.get_atek_sample_at_timestamps_ns(timestamps_ns) if atek_sample is None: logger.warning( f"Cannot retrieve valid atek sample at timestamp {timestamps_ns}" ) return None # Flatten to dict atek_sample_dict = atek_sample.to_flatten_dict() # key remapping remapped_data_dict = select_and_remap_dict_keys( sample_dict=atek_sample_dict, key_mapping=self.model_adaptor.get_dict_key_mapping_all(), ) # transform model_specific_sample_gen = self.model_adaptor.atek_to_efm([remapped_data_dict]) # Obtain a dict from a generator object model_specific_sample = next(model_specific_sample_gen) return model_specific_sample def __getitem__(self, index): if index >= self.max_snippets: raise StopIteration timestamps = self.get_timestamps_by_sample_index(index) maybe_sample = self.get_model_specific_sample_at_timestamps_ns(timestamps) return maybe_sample def create_atek_raw_data_loader_from_vrs_path( vrs_path: str, freq_hz: int, snippet_length_s, stride_length_s, skip_begin_seconds: float = 0.0, skip_end_seconds: float = 0.0, semidense_points_pad_to_num=50000, max_snippets=9999, ): vrs_dir = os.path.dirname(vrs_path) data_path_provider = AtekDataPathsProvider(data_root_path=vrs_dir) atek_data_paths = data_path_provider.get_data_paths() conf = 
OmegaConf.load("efm3d/config/efm_preprocessing_conf.yaml") # Update snippet / stride length conf.camera_temporal_subsampler.main_camera_target_freq_hz = float(freq_hz) conf.camera_temporal_subsampler.sample_length_in_num_frames = int( freq_hz * snippet_length_s ) conf.camera_temporal_subsampler.stride_length_in_num_frames = int( freq_hz * stride_length_s ) conf.camera_temporal_subsampler.update( { "skip_begin_seconds": skip_begin_seconds, "skip_end_seconds": skip_end_seconds, } ) data_loader = AtekRawDataloaderAsEfm( vrs_file=atek_data_paths["video_vrs_file"], mps_files={ "mps_closedloop_traj_file": atek_data_paths["mps_closedloop_traj_file"], "mps_semidense_points_file": atek_data_paths["mps_semidense_points_file"], "mps_semidense_observations_file": atek_data_paths[ "mps_semidense_observations_file" ], }, gt_files={ "obb3_file": atek_data_paths["gt_obb3_file"], "obb3_traj_file": atek_data_paths["gt_obb3_traj_file"], "obb2_file": atek_data_paths["gt_obb2_file"], "instance_json_file": atek_data_paths["gt_instance_json_file"], }, conf=conf, freq_hz=freq_hz, snippet_length_s=snippet_length_s, semidense_points_pad_to_num=semidense_points_pad_to_num, max_snippets=max_snippets, ) return data_loader ================================================ FILE: efm3d/dataset/atek_wds_dataset.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import glob import tarfile import numpy as np import torch import webdataset as wds from efm3d.aria import CameraTW, DEFAULT_CAM_DATA_SIZE, ObbTW, PoseTW, TensorWrapper from efm3d.aria.aria_constants import ( ARIA_CALIB, ARIA_IMG, ARIA_IMG_T_SNIPPET_RIG, ARIA_OBB_PADDED, ARIA_POINTS_VOL_MAX, ARIA_POINTS_VOL_MIN, ARIA_POINTS_WORLD, ARIA_POSE_T_SNIPPET_RIG, ARIA_POSE_T_WORLD_RIG, ARIA_SNIPPET_T_WORLD_SNIPPET, ) from efm3d.dataset.efm_model_adaptor import load_atek_wds_dataset_as_efm def batchify(datum, device=None): # Add batch dimension for key in datum: if isinstance(datum[key], (torch.Tensor, TensorWrapper)): datum[key] = datum[key][None, ...].to(device) if device is not None: datum[key] = datum[key].to(device) else: datum[key] = [datum[key]] return datum def unbatchify(datum): # Remove batch dimension for key in datum: if isinstance(datum[key], (torch.Tensor, TensorWrapper, list)): datum[key] = datum[key][0] return datum class AtekWdsStreamDataset: """Sample 2s/1s WDS dataset to specified snippet length and stride""" def __init__( self, data_path, atek_to_efm_taxonomy, snippet_length_s=1.0, stride_length_s=0.1, wds_length_s=2.0, fps=10, max_snip=99999999, ): self.snippet_length_s = snippet_length_s self.stride_length_s = stride_length_s self.wds_length_s = wds_length_s # wds snippets should always be generated half overlapped self.wds_stride_s = wds_length_s // 2 self.fps = fps self.max_snip = max_snip tar_list = sorted(glob.glob(f"{data_path}/*.tar")) sn = set() with tarfile.TarFile(tar_list[0], "r") as tar: for member in tar.getmembers(): sn.add(member.name.split(".")[0]) self.samples_per_tar = len(sn) self.num_tars = len(tar_list) self.dataset = load_atek_wds_dataset_as_efm( urls=tar_list, freq=fps, snippet_length_s=wds_length_s, # Need to use `wds_length` for model adaptor! 
atek_to_efm_taxonomy_mapping_file=atek_to_efm_taxonomy, ) self.dataloader = iter(self.dataset) self.frames_wds = int(self.fps * self.wds_length_s) self.frames_out = int(self.fps * self.snippet_length_s) self.frames_stride_wds = int(self.fps * self.wds_stride_s) self.frames_stride_out = int(self.fps * self.stride_length_s) self.num_rest = int( (self.wds_length_s - self.snippet_length_s) / self.stride_length_s ) self.num_first = int(1 + self.num_rest) self.num_snippets = ( self.num_first + (self.samples_per_tar * self.num_tars - 1) * self.num_rest ) # for iteration self.first = True self.wds_snippet = None self.snip_idx = 0 self.global_idx = 0 def __len__(self): return min(self.num_snippets, self.max_snip) def sample_snippet_(self, snippet, start, end): # time crop sample = snippet.copy() for k in sample: if isinstance(sample[k], (torch.Tensor, TensorWrapper)): if k not in [ ARIA_SNIPPET_T_WORLD_SNIPPET, ARIA_POINTS_VOL_MIN, ARIA_POINTS_VOL_MAX, ]: sample[k] = sample[k][start:end, ...] return sample def __iter__(self): return self def if_get_next_(self): if self.wds_snippet is None: return True if self.first: return self.snip_idx >= self.num_first else: return self.snip_idx >= self.num_rest def __next__(self): if self.global_idx >= self.max_snip: raise StopIteration if self.if_get_next_(): if self.first and self.wds_snippet is not None: self.first = False self.wds_snippet = next(self.dataloader) self.snip_idx = 0 if self.first: start = self.snip_idx * self.frames_stride_out else: start = (self.snip_idx + 1) * self.frames_stride_out end = start + self.frames_out sample = self.sample_snippet_(self.wds_snippet, start, end) self.snip_idx += 1 self.global_idx += 1 return sample ================================================ FILE: efm3d/dataset/augmentation.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging from functools import partial from typing import Dict, List, Optional, Tuple, Union import torch import torchvision from efm3d.aria.aria_constants import ( ARIA_IMG, ARIA_POINTS_DIST_STD, ARIA_POINTS_INV_DIST_STD, ARIA_POINTS_WORLD, ) from torchvision.transforms.v2._color import RandomAdjustSharpness from webdataset import WebDataset logging.basicConfig() logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) # logger.setLevel(logging.DEBUG) class ColorJitter: """ Applies photometric jitter to the images in the video sequence. 
""" def __init__( self, brightness: Union[Tuple[float], float] = 0.5, contrast: Union[Tuple[float], float] = 0.3, saturation: Union[Tuple[float], float] = 0.3, hue: Union[Tuple[float], float] = 0.05, sharpness: Union[Tuple[float], float] = 2.0, snippet_jitter: bool = False, ): """ Calls torchvision on the images independently in a video using: https://pytorch.org/vision/main/generated/torchvision.transforms.ColorJitter.html brightness: how much to jitter brightness in range [0,val] contrast: how much to jitter contrast in range [0,val] saturation: how much to jitter contrast in range [0,val] hue: how much to jitter hue in range [-val,val] snippet_jitter: if true, jitter equally across the snippet """ self.transform = torchvision.transforms.ColorJitter( brightness=brightness, contrast=contrast, saturation=saturation, hue=hue ) self.snippet_jitter = snippet_jitter self.sharpness = sharpness def rnd_sharpen(self, im): factor = float(self.sharpness * torch.rand(1)) sharp_fn = RandomAdjustSharpness(sharpness_factor=factor, p=1.0) return sharp_fn(im) def apply(self, im): im = self.transform(im) im = self.rnd_sharpen(im) return im def __call__(self, batch: Dict): for name in ARIA_IMG: if name in batch: batch[name] = batch[name].clone().detach() if self.snippet_jitter: batch[name] = self.apply(batch[name]) else: for t in range(len(batch[name])): batch[name][t] = self.apply(batch[name][t]) return batch class PointDrop: """ Applies point drop augmentation based on the standard deviations of the points. A standard deviation is sampled within the provided range, and points exceeding the sampled standard deviation are dropped. Attributes: dropout_all_rate (float): The rate at which all points are dropped. inv_dist_std (List[float]): Range [min, max] of inverse distance standard deviations. dist_std (List[float]): Range [min, max] of distance standard deviations. 
""" def __init__( self, dropout_all_rate: float = 0.2, inv_dist_std: Optional[List[float]] = None, dist_std: Optional[List[float]] = None, ): if inv_dist_std is None: inv_dist_std = [0.001, 0.03] if dist_std is None: dist_std = [0.01, 0.3] self.dropout_all_rate = dropout_all_rate self.inv_dist_std = inv_dist_std self.dist_std = dist_std assert inv_dist_std[1] >= inv_dist_std[0] assert dist_std[1] >= dist_std[0] def __call__(self, batch: Dict): if ARIA_POINTS_WORLD not in batch: return batch p_drop_all = torch.rand(1).item() if p_drop_all < self.dropout_all_rate: # drop all points batch[ARIA_POINTS_WORLD][:, :, :] = torch.nan else: # drop based on stds. p_w = batch[ARIA_POINTS_WORLD] T, N = p_w.shape[:2] # sample inv_dist_std rand_inv_dist_thres = torch.rand(1).item() rand_inv_dist_thres = ( rand_inv_dist_thres * (self.inv_dist_std[1] - self.inv_dist_std[0]) + self.inv_dist_std[0] ) # sample dist_std rand_dist_thres = torch.rand(1).item() rand_dist_thres = ( rand_dist_thres * (self.dist_std[1] - self.dist_std[0]) + self.dist_std[0] ) dropped = torch.zeros(T, N, dtype=torch.bool) if ARIA_POINTS_INV_DIST_STD in batch: drop_inv_dist_std = ( batch[ARIA_POINTS_INV_DIST_STD] > rand_inv_dist_thres ) dropped |= drop_inv_dist_std logger.debug(f"drop points with max inv_dist_std {rand_inv_dist_thres}") logger.debug(f"drop {dropped.sum()} points.") if ARIA_POINTS_DIST_STD in batch: drop_dist_std = batch[ARIA_POINTS_DIST_STD] > rand_dist_thres dropped |= drop_dist_std logger.debug(f"drop points with max dist_std {rand_dist_thres}") logger.debug(f"drop {dropped.sum()} points.") p_w[dropped, :] = torch.nan batch[ARIA_POINTS_WORLD] = p_w return batch class PointDropSimple: """ simple point drop augmentation. 
""" def __init__( self, max_dropout_rate: float = 0.8, ): self.max_dropout_rate = max_dropout_rate assert self.max_dropout_rate < 1.0 and self.max_dropout_rate > 0.0 def __call__(self, batch: Dict): if ARIA_POINTS_WORLD not in batch: return batch dropout_rate = torch.rand(1).item() if dropout_rate > self.max_dropout_rate: return batch else: p_w = batch[ARIA_POINTS_WORLD] # B, T, 3 T, N = p_w.shape[:2] mask = torch.rand((T, N)) < dropout_rate p_w[mask, :] = torch.nan batch[ARIA_POINTS_WORLD] = p_w return batch class PointJitter: """ Applies point jitter augmentation. """ def __init__( self, depth_std_scale_min: float = 1.0, depth_std_scale_max: float = 3.0, ): """ Args: depth_std_scale_min: min scale factor for depth jitter based on depth_std depth_std_scale_max: max scale factor for depth jitter based on depth_std """ self.depth_std_scale_max = depth_std_scale_max self.depth_std_scale_min = depth_std_scale_min def __call__(self, batch: Dict): if ARIA_POINTS_WORLD in batch and ARIA_POINTS_DIST_STD in batch: p_w = batch[ARIA_POINTS_WORLD] scale = ( torch.rand(1).item() * (self.depth_std_scale_max - self.depth_std_scale_min) + self.depth_std_scale_min ) std = batch[ARIA_POINTS_DIST_STD] * scale noise = torch.randn_like(p_w) * std.unsqueeze(-1) batch[ARIA_POINTS_WORLD] = p_w + noise return batch ================================================ FILE: efm3d/dataset/efm_model_adaptor.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.

import csv
import logging
from functools import partial
from typing import Callable, Dict, List, Optional

import torch
import webdataset as wds
from atek.data_loaders.atek_wds_dataloader import (
    load_atek_wds_dataset,
    process_wds_sample,
    select_and_remap_dict_keys,
)
from atek.util.tensor_utils import fill_or_trim_tensor
from efm3d.aria import CameraTW, ObbTW, PoseTW, TensorWrapper
from efm3d.aria.aria_constants import (
    ARIA_CALIB,
    ARIA_CALIB_SNIPPET_TIME_S,
    ARIA_OBB_PADDED,
    ARIA_POINTS_VOL_MAX,
    ARIA_POINTS_VOL_MIN,
    ARIA_POSE_SNIPPET_TIME_S,
    ARIA_POSE_T_SNIPPET_RIG,
    ARIA_SNIPPET_LENGTH_S,
    ARIA_SNIPPET_T_WORLD_SNIPPET,
    ARIA_SNIPPET_TIME_NS,
)
from efm3d.aria.obb import transform_obbs
from efm3d.aria.tensor_wrapper import smart_stack
from webdataset.filters import pipelinefilter

logger = logging.getLogger(__name__)


def get_local_pose_helper(snippet_origin_time_s, batch, local_coordinate):
    """
    get the local coordinate system of the snippet as the pose at the
    snippet_origin_time_s under the specified coordinate system conventions
    (rig, or cam_rgb)

    Args:
        snippet_origin_time_s: time at which the rig trajectory is evaluated.
        batch: dict that must contain the snippet-to-rig poses, their
            timestamps and the world-to-snippet pose.
        local_coordinate: "rig" or "cam_rgb".

    Returns:
        The world pose of the chosen local frame at `snippet_origin_time_s`.

    Raises:
        NotImplementedError: for any other `local_coordinate` value.
    """
    assert (
        ARIA_POSE_T_SNIPPET_RIG in batch.keys()
        and ARIA_POSE_SNIPPET_TIME_S in batch.keys()
        and ARIA_SNIPPET_T_WORLD_SNIPPET in batch.keys()
    ), f"keys not in batch keys {batch.keys()}"

    T_world_snippet = batch[ARIA_SNIPPET_T_WORLD_SNIPPET]
    Ts_world_rig = T_world_snippet @ batch[ARIA_POSE_T_SNIPPET_RIG]
    time_s = batch[ARIA_POSE_SNIPPET_TIME_S]
    assert Ts_world_rig.dim() in [2, 3], f"{Ts_world_rig.shape} should be (B)xTx12"
    if local_coordinate == "rig":
        T_world_local = get_snippet_cosy_from_rig(
            Ts_world_rig=Ts_world_rig,
            time=time_s,
            snippet_origin_time=snippet_origin_time_s,
        )
    elif local_coordinate == "cam_rgb":
        T_world_local = get_snippet_cosy_from_cam_rgb(
            Ts_world_rig=Ts_world_rig,
            time=time_s,
            snippet_origin_time=snippet_origin_time_s,
            cam_rgb=batch[ARIA_CALIB[0]],
            cam_rgb_time_s=batch[ARIA_CALIB_SNIPPET_TIME_S[0]],
        )
    else:
        raise NotImplementedError(
            f"{local_coordinate} is not a valid coordinate option"
        )
    return T_world_local


def run_local_cosy(
    batch,
    origin_ratio=0.5,
    local_coordinate="cam_rgb",
    align_to_gravity=False,
    snippet_origin_time_s=None,
):
    """
    Compute the batch entries that change when the snippet origin is moved.

    The new origin time is either `snippet_origin_time_s` or, when that is
    None, `origin_ratio` times the snippet length stored in the batch. The
    result holds the shifted snippet timestamp, every shifted
    `*/snippet_time_s` entry and, if the pose keys are present, the new
    world-to-snippet pose plus every `*t_snippet_rig` entry (and the padded
    OBBs) expressed in the new snippet frame.

    Args:
        batch: input sample dict; it is read but not modified.
        origin_ratio: fraction of the snippet length used as the new origin
            when `snippet_origin_time_s` is None.
        local_coordinate: "rig" or "cam_rgb", see `get_local_pose_helper`.
        align_to_gravity: accepted but not used by this function.
        snippet_origin_time_s: explicit new origin time; `.long()` is called
            on its product with 1e9, so it has to be tensor-like.

    Returns:
        A new dict that contains only the updated entries.
    """
    new_batch = {}
    if snippet_origin_time_s is None:
        assert ARIA_SNIPPET_LENGTH_S in batch.keys()
        # get new snippet time origin
        snippet_length_s = batch[ARIA_SNIPPET_LENGTH_S]
        snippet_origin_time_s = snippet_length_s * origin_ratio

    # New origin time in ns.
    snippet_origin_time_ns = (snippet_origin_time_s * 1e9).long()

    # modify all time stamps to account for snippet origin change
    new_batch[ARIA_SNIPPET_TIME_NS] = (
        batch[ARIA_SNIPPET_TIME_NS] + snippet_origin_time_ns
    )
    # modify all snippet_time_s timestamps to account for snippet origin change
    keys_time_s = [key for key in batch.keys() if key.endswith("/snippet_time_s")]
    for key in keys_time_s:
        new_batch[key] = batch[key] - snippet_origin_time_s

    # get new snippet pose origin
    if (
        ARIA_POSE_T_SNIPPET_RIG in batch
        and ARIA_POSE_SNIPPET_TIME_S in batch
        and ARIA_SNIPPET_TIME_NS in batch
        and ARIA_SNIPPET_T_WORLD_SNIPPET in batch
    ):
        T_world_snippet = get_local_pose_helper(
            snippet_origin_time_s,
            batch,
            local_coordinate,
        )

        # apply change of coordinates to snippet coordinate system
        T_snippet_new_old = (
            T_world_snippet.inverse() @ batch[ARIA_SNIPPET_T_WORLD_SNIPPET]
        )
        new_batch[ARIA_SNIPPET_T_WORLD_SNIPPET] = T_world_snippet

        # apply the coordinate change to t_snippet_rigs
        keys_t_snippet_rig = [
            key for key in batch.keys() if key.endswith("t_snippet_rig")
        ]
        for key in keys_t_snippet_rig:
            new_batch[key] = T_snippet_new_old @ batch[key]

        # transform obbs into the new snippet coordinate system as well
        if ARIA_OBB_PADDED in batch.keys():
            new_batch[ARIA_OBB_PADDED] = transform_obbs(
                batch[ARIA_OBB_PADDED], T_snippet_new_old
            )

    return new_batch


def get_snippet_cosy_from_rig(
    snippet_origin_time: torch.Tensor,
    Ts_world_rig: PoseTW,
    time: torch.Tensor,
):
    """
    simply interpolate the T_world_rig using the given time at the
    snippet_origin_time to get T_world_rig_origin

    A warning is logged when the interpolation flags some poses as not good.
    """
    T_world_rig_origin, good = Ts_world_rig.interpolate(time, snippet_origin_time)
    # NOTE(review): the warning is gated on `shape[-1]` of the interpolated
    # pose; confirm this is the intended dimension.
    T = T_world_rig_origin.shape[-1]
    if T > 1 and not good.all():
        # `Logger.warn` is a deprecated alias of `Logger.warning`.
        logger.warning(
            f"WARNING some interpolated poses were not good: {good} time_s {time} snippet_time {snippet_origin_time}"
        )
    return T_world_rig_origin


def get_snippet_cosy_from_cam_rgb(
    snippet_origin_time: torch.Tensor,
    Ts_world_rig: PoseTW,
    time: torch.Tensor,
    cam_rgb: torch.Tensor,
    cam_rgb_time_s: torch.Tensor,
):
    """
    interpolate T_world_rig and T_camera_rig using the given time_s at the
    snippet_origin_time and then compose the interpolated centers to get
    T_world_camera_origin

    `cam_rgb` must expose a `T_camera_rig` attribute (the camera extrinsics).
    """
    # interpolate T_camera_rig
    Ts_camera_rig = cam_rgb.T_camera_rig
    T_camera_rig_origin, good = Ts_camera_rig.interpolate(
        cam_rgb_time_s, snippet_origin_time
    )
    # NOTE(review): same `shape[-1]` gate as in get_snippet_cosy_from_rig;
    # confirm the intended dimension.
    T = Ts_camera_rig.shape[-1]
    if T > 1 and not good.all():
        # `Logger.warn` is a deprecated alias of `Logger.warning`.
        logger.warning("WARNING: some interpolated camera extrinsics were not good:")
        logger.debug(
            f"Good: {good}\n time_s {cam_rgb_time_s}\n snip_center {snippet_origin_time}"
        )

    T_world_rig_origin = get_snippet_cosy_from_rig(
        Ts_world_rig=Ts_world_rig, time=time, snippet_origin_time=snippet_origin_time
    )
    return T_world_rig_origin @ T_camera_rig_origin.inverse()


class EfmModelAdaptor:
    ATEK_CAM_LABEL_TO_EFM_CAM_LABEL: Dict[str, str] = {
        "camera-rgb": "rgb",
        "camera-slam-left": "slaml",
        "camera-slam-right": "slamr",
    }
    EFM_CAM_LABELS = ["rgb", "slaml", "slamr"]
    EFM_GRAVITY_IN_WORLD = [0, 0, -9.81]

    def __init__(
        self,
        freq: int,
        snippet_length_s: float = 2.0,
        semidense_points_pad_to_num: int = 50000,
        atek_to_efm_taxonomy_mapping_file: Optional[str] = None,
    ):
        self.freq = torch.tensor([freq], dtype=torch.int32)

        # EFM samples have fields padded to a fixed shape.
# Obtain the fixed shape dimentions self.fixed_num_frames = int(snippet_length_s * freq) self.fixed_semidense_num_points = semidense_points_pad_to_num # Load optional taxonomy mapping file if atek_to_efm_taxonomy_mapping_file is not None: self.atek_to_efm_category_mapping = self._load_taxonomy_mapping_file( atek_to_efm_taxonomy_mapping_file ) else: self.atek_to_efm_category_mapping = None @staticmethod def get_dict_key_mapping_for_camera(atek_camera_label: str, efm_camera_label: str): return { f"mfcd#{atek_camera_label}+images": f"{efm_camera_label}/img", f"mfcd#{atek_camera_label}+projection_params": f"{efm_camera_label}/calib/projection_params", f"mfcd#{atek_camera_label}+frame_ids": f"{efm_camera_label}/frame_id_in_sequence", f"mfcd#{atek_camera_label}+capture_timestamps_ns": f"{efm_camera_label}/img/time_ns", f"mfcd#{atek_camera_label}+camera_model_name": f"{efm_camera_label}/calib/camera_model_name", f"mfcd#{atek_camera_label}+camera_valid_radius": f"{efm_camera_label}/calib/valid_radius", f"mfcd#{atek_camera_label}+exposure_durations_s": f"{efm_camera_label}/calib/exposure", f"mfcd#{atek_camera_label}+gains": f"{efm_camera_label}/calib/gain", f"mfcd#{atek_camera_label}+t_device_camera": f"{efm_camera_label}/calib/t_device_camera", } @staticmethod def get_dict_key_mapping_all(): dict_key_mapping = { # mps data mappings "mtd#ts_world_device": "pose/t_world_rig", "mtd#capture_timestamps_ns": "pose/time_ns", "mtd#gravity_in_world": "pose/gravity_in_world", "msdpd#points_world": "points/p3s_world", "msdpd#points_inv_dist_std": "points/inv_dist_std", "msdpd#points_dist_std": "points/dist_std", "msdpd#capture_timestamps_ns": "points/time_ns", "msdpd#points_volumn_min": ARIA_POINTS_VOL_MIN, "msdpd#points_volumn_max": ARIA_POINTS_VOL_MAX, "msdpd#points": "points/time_ns", "mfcd#camera-rgb-depth+images": "rgb/distance_m", # gt mappings "gt_data": "gt_data", } # camera data related mappings for ( atek_cam_label, efm_cam_label, ) in 
EfmModelAdaptor.ATEK_CAM_LABEL_TO_EFM_CAM_LABEL.items(): dict_key_mapping.update( EfmModelAdaptor.get_dict_key_mapping_for_camera( atek_camera_label=atek_cam_label, efm_camera_label=efm_cam_label ) ) return dict_key_mapping def _get_pose_to_align_gravity(self, sample_dict: Dict) -> Optional[PoseTW]: """ A helper function to return a T_newWorld_oldWorld transformation to align world gravity to the EFM convention. This pose needs to be later applied to all poses that include world. """ efm_gravity_in_world = torch.tensor( self.EFM_GRAVITY_IN_WORLD, dtype=torch.float32 ) current_gravity_in_world = sample_dict["pose/gravity_in_world"] if torch.allclose(efm_gravity_in_world, current_gravity_in_world, atol=1e-3): # print("gravity convention is already aligned.") return None else: if torch.allclose(current_gravity_in_world, torch.tensor([0, -9.81, 0])): return PoseTW.from_Rt( torch.tensor( [[1, 0, 0], [0, 0, -1], [0, 1, 0]], dtype=torch.float32 ), torch.tensor([0, 0, 0], dtype=torch.float32), ) else: raise ValueError( f"unsupported gravity direction to align: {current_gravity_in_world}" ) def _load_taxonomy_mapping_file(self, filename: str) -> Dict: """ Load a taxonomy mapping csv file in the format of: ATEK_category_name, efm_category_name, efm_category_id returns a dict of {atek_cat_name -> (efm_cat_name, efm_cat_id)} """ atek_to_efm_category_mapping = {} with open(filename, "r") as f: csv_reader = csv.reader(f) next(csv_reader) for row in csv_reader: atek_name = row[0] value = (row[1], int(row[2])) # Convert category id to an integer atek_to_efm_category_mapping[atek_name] = value return atek_to_efm_category_mapping def _fill_dict_with_freq(self, sample_dict: Dict) -> Dict: fields_to_fill = [ "pose/hz", "points/hz", "rgb/img/hz", "slaml/img/hz", "slamr/img/hz", ] # only fill obb frequency if GT exists if self.gt_exists_flag: fields_to_fill += ["obbs/hz"] for field in fields_to_fill: sample_dict[field] = self.freq return sample_dict def _convert_to_batched_camera_tw( 
self, sample_dict: Dict, cam_label: str ) -> CameraTW: """ A helper function to convert ATEK camera calibration to EFM camera tensor wrapper, where calibration params are replicated x `num_frames`. """ # calibrations are replicated by `fixed_num_frames` batched_size = torch.Size((self.fixed_num_frames, 1)) camera_tw = CameraTW.from_surreal( width=torch.full( size=batched_size, fill_value=sample_dict[f"{cam_label}/img"].shape[3] ), height=torch.full( size=batched_size, fill_value=sample_dict[f"{cam_label}/img"].shape[2] ), type_str=sample_dict[f"{cam_label}/calib/camera_model_name"], params=sample_dict[f"{cam_label}/calib/projection_params"].unsqueeze( 0 ), # make tensor shape [1, 15], so that it can be expanded to [num_frames, 15] gain=fill_or_trim_tensor( tensor=sample_dict[f"{cam_label}/calib/gain"], dim_size=self.fixed_num_frames, dim=0, ), exposure_s=fill_or_trim_tensor( tensor=sample_dict[f"{cam_label}/calib/exposure"], dim_size=self.fixed_num_frames, dim=0, ), valid_radius=sample_dict[f"{cam_label}/calib/valid_radius"], T_camera_rig=PoseTW.from_matrix3x4( sample_dict[f"{cam_label}/calib/t_device_camera"] ).inverse(), ) return camera_tw.float() def _update_efm_obb_gt(self, atek_gt_dict: Dict) -> Dict: """ Helper function to convert ATEK obb gt to EFM obb gt. """ efm_sub_dict = {} # loop over all timestamps timestamp_list = [] efm_obb_all_timestamps = [] semantic_id_to_name = {} for timestamp, obb3_dict in atek_gt_dict["efm_gt"].items(): timestamp_list.append(int(timestamp)) # Create a hash map to query which instance is visible in which camera. # The resulting map will look like: { # "instance_1": { # "cam_0": index_in_cam0, # "cam_1": index_in_cam1, # ... # }, # "instance_2": { # ... # } # ... 
instance_visible_map = {} for camera_label, per_cam_dict in obb3_dict.items(): for i in range(len(per_cam_dict["instance_ids"])): instance_id = per_cam_dict["instance_ids"][i].item() if instance_id not in instance_visible_map: instance_visible_map[instance_id] = {} instance_visible_map[instance_id][camera_label] = i efm_obb_tw_list = [] # Loop over all instances from all cameras for instance_id, instance_mapping_info in instance_visible_map.items(): # Create a ObbTW for this instance # get obb3 info from any visible camera cam_label_0, cam_index_0 = next(iter(instance_mapping_info.items())) atek_single_bb3_dict = obb3_dict[cam_label_0] bb3_dim = atek_single_bb3_dict["object_dimensions"][ cam_index_0 ] # tensor [3] object_half_sizes = bb3_dim / 2.0 bb3_object = torch.tensor( [ -object_half_sizes[0], object_half_sizes[0], -object_half_sizes[1], object_half_sizes[1], -object_half_sizes[2], object_half_sizes[2], ], dtype=torch.float32, ) T_world_object = PoseTW.from_matrix3x4( atek_single_bb3_dict["ts_world_object"][cam_index_0] ) inst_id = atek_single_bb3_dict["instance_ids"][cam_index_0] # perform taxonomy remapping if needed, but skip "other" sem_id = atek_single_bb3_dict["category_ids"][cam_index_0].item() category_name = atek_single_bb3_dict["category_names"][cam_index_0] if category_name == "other": continue if self.atek_to_efm_category_mapping is not None: category_name, sem_id = self.atek_to_efm_category_mapping[ category_name ] # Also keep track of a sem_id_to_name mapping if sem_id not in semantic_id_to_name: semantic_id_to_name[sem_id] = category_name bb2_rgb = -1 * torch.ones(4) bb2_slaml = -1 * torch.ones(4) bb2_slamr = -1 * torch.ones(4) # Commenting off because obb2 are not needed """ if "camera-rgb" in instance_mapping_info: cam_label = "camera-rgb" cam_index = instance_mapping_info[cam_label] bb2_rgb = atek_gt_dict["obb2"][cam_label]["bbox_ranges"][cam_index] if "camera-slam-left" in instance_mapping_info: cam_label = "camera-slam-left" cam_index = 
instance_mapping_info[cam_label] bb2_slaml = atek_gt_dict["obb2"][cam_label]["bbox_ranges"][cam_index] if "camera-slam-right" in instance_mapping_info: cam_label = "camera-slam-right" cam_index = instance_mapping_info[cam_label] bb2_slamr = atek_gt_dict["obb2"][cam_label]["bbox_ranges"][cam_index] """ # Fill in padded obbs in EFM format efm_obb_tw_list.append( ObbTW.from_lmc( bb3_object=bb3_object, bb2_rgb=bb2_rgb, bb2_slaml=bb2_slaml, bb2_slamr=bb2_slamr, T_world_object=T_world_object, sem_id=torch.tensor([sem_id], dtype=torch.int64), inst_id=torch.tensor([inst_id], dtype=torch.int64), ) ) # end for instance_id if len(efm_obb_tw_list) == 0: efm_obb_tw = ObbTW() else: efm_obb_tw = ObbTW(smart_stack(efm_obb_tw_list, dim=0)) efm_obb_tw = efm_obb_tw.add_padding(max_elts=128) efm_obb_all_timestamps.append(efm_obb_tw) efm_sub_dict["obbs/padded_snippet"] = ObbTW( smart_stack(efm_obb_all_timestamps, dim=0) ) efm_sub_dict["obbs/time_ns"] = torch.tensor(timestamp_list, dtype=torch.int64) efm_sub_dict["obbs/sem_id_to_name"] = semantic_id_to_name return efm_sub_dict def _pad_semidense_data(self, sample_dict: Dict) -> Dict: """ A helper function to pad semidense data from List[Tensor, (K, 3 or 1)] to fixed shape of [numFrames, num_semidense_points, 3 or 1] """ result_dict = {} fields_to_pad = ["points/p3s_world", "points/dist_std", "points/inv_dist_std"] for field in fields_to_pad: tensor_list = sample_dict[field] for i in range(len(tensor_list)): # First, pad each tensor in the list to fixed num points tensor_list[i] = fill_or_trim_tensor( tensor=tensor_list[i], dim_size=self.fixed_semidense_num_points, dim=0, fill_value=float("nan"), ) # then stack stacked_tensor = torch.stack(tensor_list, dim=0) # then pad over frames result_dict[field] = fill_or_trim_tensor( tensor=stacked_tensor, dim_size=self.fixed_num_frames, dim=0 ) return result_dict def _pad_over_frames(self, sample_dict: Dict, fields_to_pad: List[str]) -> Dict: """ A helper function to pad data over frames, by 
repeating the last element over frames. """ result_dict = {} for field in fields_to_pad: result_dict[field] = fill_or_trim_tensor( tensor=sample_dict[field], dim_size=self.fixed_num_frames, dim=0, ) return result_dict def _split_pose_over_snippet(self, sample_dict: Dict) -> Dict: """ A helper function to split T_world_rig into T_world_snippet and T_snippet_rig. In the meantime, Align gravity to [0, 0, -9.81] """ result_dict = {} # check if world coordinates needs to be re-aligned maybe_T_newWorld_oldWorld = self._get_pose_to_align_gravity(sample_dict) # maybe_T_newWorld_oldWorld = None Ts_world_rig = PoseTW.from_matrix3x4(sample_dict["pose/t_world_rig"]) if maybe_T_newWorld_oldWorld: Ts_world_rig = maybe_T_newWorld_oldWorld @ Ts_world_rig result_dict["pose/t_world_rig"] = Ts_world_rig T_world_snippet = Ts_world_rig.clone()[0] T_world_snippet = T_world_snippet.unsqueeze(0) result_dict["snippet/t_world_snippet"] = T_world_snippet.clone() result_dict["pose/t_snippet_rig"] = Ts_world_rig[0].inverse() @ Ts_world_rig for camera_label in EfmModelAdaptor.EFM_CAM_LABELS: result_dict[f"{camera_label}/t_snippet_rig"] = result_dict[ "pose/t_snippet_rig" ].clone() # Transform obbs poses, from old_world -> new_world -> snippet if ARIA_OBB_PADDED in sample_dict: if maybe_T_newWorld_oldWorld: T_snippet_world = T_world_snippet.inverse() @ maybe_T_newWorld_oldWorld else: T_snippet_world = T_world_snippet.inverse() result_dict[ARIA_OBB_PADDED] = transform_obbs( sample_dict[ARIA_OBB_PADDED], T_snippet_world ) # Also transform semidense points if maybe_T_newWorld_oldWorld: result_dict["points/p3s_world"] = ( maybe_T_newWorld_oldWorld * sample_dict["points/p3s_world"] ) return result_dict def _split_timestamps_over_snippet(self, sample_dict: Dict) -> Dict: """ A helper function to split capture_timestamps_ns into snippet/time_ns and */snippet_time_s """ dict_keys_to_split_timestamps = [ "pose/", "points/", ] + [f"{label}/img/" for label in EfmModelAdaptor.EFM_CAM_LABELS] # Also split 
    def atek_to_efm(self, data, train=False):
        """
        Convert ATEK webdataset samples (built by EfmSampleBuilder) to EFM unbatched samples.

        This is a generator: it yields one unbatched sample at a time so that the
        collation and batching mechanism of the webdataset can be used properly.

        Args:
            data: iterable of ATEK sample dicts. Each dict is modified in place and is
                the very object that gets yielded.
            train: when True, every entry that is not a `torch.Tensor` or
                `TensorWrapper` is removed from the sample before it is yielded.

        Yields:
            The converted sample dict.

        Side effects:
            `self.gt_exists_flag` is overwritten for every sample; the helper methods
            called below read it to decide whether obb entries are handled.
        """
        for atek_wds_sample in data:
            # Alias, not a copy: all updates below land in the incoming dict.
            efm_sample = atek_wds_sample

            # Check if GT exists in the sample. If not, all obb related operations will be skipped
            self.gt_exists_flag = (
                "gt_data" in atek_wds_sample and len(atek_wds_sample["gt_data"]) > 0
            )

            # Fill frequency data from conf
            efm_sample = self._fill_dict_with_freq(efm_sample)

            # Pad semidense data, which requires 2-dim padding
            padded_dict = self._pad_semidense_data(efm_sample)
            efm_sample.update(padded_dict)

            # Convert ATEK calibration to EFM camera calibration, where calibration params are replicated x `num_frames`,
            # except gains and exposure_s which is per-frame.
            for cam_label in EfmModelAdaptor.EFM_CAM_LABELS:
                efm_sample[f"{cam_label}/calib"] = self._convert_to_batched_camera_tw(
                    efm_sample, cam_label
                )

            # Convert ATEK GT to EFM GT
            if self.gt_exists_flag:
                result_dict = self._update_efm_obb_gt(atek_wds_sample["gt_data"])
                efm_sample.update(result_dict)

            # split T_world_rig into T_world_snippet and T_snippet_rig
            result_dict = self._split_pose_over_snippet(efm_sample)
            efm_sample.update(result_dict)

            # split capture_timestamps_ns into snippet/time_ns and */snippet_time_s
            result_dict = self._split_timestamps_over_snippet(efm_sample)
            efm_sample.update(result_dict)

            # Pad some data over frames by repeating last element.
            # Only keys are collected inside the loop; the dict itself is updated
            # after the iteration has finished.
            fields_to_pad = []
            fields_to_skip_padding = ["snippet/t_world_snippet"]
            for key, value in efm_sample.items():
                if key in fields_to_skip_padding:
                    continue
                if isinstance(value, torch.Tensor) or isinstance(value, TensorWrapper):
                    if value.shape[0] < self.fixed_num_frames:
                        # pad timestamp tensors, but not other 1-dim tensors
                        if (
                            key.endswith("time_ns")
                            or key.endswith("time_s")
                            or value.ndim > 1
                        ):
                            fields_to_pad.append(key)
            result_dict = self._pad_over_frames(efm_sample, fields_to_pad=fields_to_pad)
            efm_sample.update(result_dict)

            # Duplicate `camera/img/time` to `camera/calib/time`
            for camera_name in EfmModelAdaptor.EFM_CAM_LABELS:
                efm_sample[f"{camera_name}/calib/time_ns"] = efm_sample[
                    f"{camera_name}/img/time_ns"
                ]
                efm_sample[f"{camera_name}/calib/snippet_time_s"] = efm_sample[
                    f"{camera_name}/img/snippet_time_s"
                ]

            # Convert data types from int to float32
            fields_to_conv2float32 = [
                f"{label}/img" for label in EfmModelAdaptor.EFM_CAM_LABELS
            ] + [
                f"{label}/frame_id_in_sequence"
                for label in EfmModelAdaptor.EFM_CAM_LABELS
            ]
            for field in fields_to_conv2float32:
                efm_sample[field] = efm_sample[field].to(torch.float32)
                if field.endswith("img"):
                    # normalize: images are divided by 255
                    efm_sample[field] = efm_sample[field] / 255.0
                    if field == "rgb/img":
                        # The RGB -> BGR channel swap is disabled, so this branch is a no-op.
                        # efm_sample[field] = efm_sample[field][:, [2, 1, 0], :, :]
                        pass

            # Run local cosy to shift the origin
            # For testing only: patch snippet lengths.
            # NOTE(review): the snippet length is fixed to 2.0 here; it is not derived
            # from any other value visible in this method.
            efm_sample[ARIA_SNIPPET_LENGTH_S] = torch.tensor([2.0], dtype=torch.float32)
            result = run_local_cosy(batch=efm_sample, origin_ratio=0.5)
            efm_sample.update(result)

            # delete useless data
            if train:
                # keep only tensors; keys are collected first so the dict is not
                # resized while it is being iterated
                remove_keys = []
                for key in efm_sample:
                    if not isinstance(efm_sample[key], (torch.Tensor, TensorWrapper)):
                        remove_keys.append(key)
                for k in remove_keys:
                    efm_sample.pop(k)
            yield efm_sample
pipelinefilter(efm_model_adaptor.atek_to_efm)(train=collation_fn is not None) ) wds_dataset = wds_dataset.batched(batch_size, collation_fn=collation_fn) return wds_dataset ================================================ FILE: efm3d/dataset/vrs_dataset.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math import os import random from typing import Callable, List, Optional, Union import numpy as np import pyvrs import torch import torch.nn.functional as F from efm3d.aria import CameraTW, ObbTW, PoseTW, smart_stack, transform_obbs from efm3d.aria.aria_constants import ( ARIA_CALIB, ARIA_CALIB_SNIPPET_TIME_S, ARIA_CALIB_TIME_NS, ARIA_CAM_INFO, ARIA_FRAME_ID, ARIA_IMG, ARIA_IMG_SNIPPET_TIME_S, ARIA_IMG_T_SNIPPET_RIG, ARIA_IMG_TIME_NS, ARIA_OBB_BB2, ARIA_OBB_PADDED, ARIA_OBB_SEM_ID_TO_NAME, ARIA_OBB_SNIPPET_TIME_S, ARIA_OBB_TIME_NS, ARIA_POINTS_SNIPPET_TIME_S, ARIA_POINTS_TIME_NS, ARIA_POINTS_VOL_MAX, ARIA_POINTS_VOL_MIN, ARIA_POINTS_WORLD, ARIA_POSE_SNIPPET_TIME_S, ARIA_POSE_T_SNIPPET_RIG, ARIA_POSE_T_WORLD_RIG, ARIA_POSE_TIME_NS, ARIA_SNIPPET_T_WORLD_SNIPPET, ) from efm3d.utils.file_utils import ( exists_nonzero_path, get_timestamp_list_ns, load_factory_calib, load_global_points_csv, load_obbs_gt, load_semidense_observations, load_trajectory, load_trajectory_adt, load_trajectory_aeo, read_image_snippet_from_vrs, sample_from_range, sample_times, ) from efm3d.utils.obb_io 
def get_transform_to_vio_gravity_convention(gravity_direction: np.array):
    """
    Get transformation to map gravity_direction to (0,0,-1) as per our (and
    VIO/Temple) convention.

    The rotation R_gravity_vio = (d1, d2, d3) is built so that
    R_gravity_vio @ (0,0,-1)^T equals the unit gravity direction, i.e.
    d3 = -gravity_direction / |gravity_direction|. The input does not have to
    be unit length; only its direction is used.

    Args:
        gravity_direction: array of 3 floats, the gravity direction in the
            source world frame.

    Returns:
        T_vio_gravity: a PoseTW with rotation R_gravity_vio^T and zero
        translation.

    Raises:
        ValueError: if gravity_direction has zero (or NaN) length.
    """
    gravity_direction = np.asarray(gravity_direction)
    gravity_norm = np.linalg.norm(gravity_direction)
    if not gravity_norm > 0:
        raise ValueError(
            f"gravity_direction must be a non-zero vector, got {gravity_direction}"
        )

    # gravity_direction = (d1, d2, d3) (0,0,-1)^T; d1, d2, d3 column vectors of rotation matrix R_gravity_vio
    # -d3 = gravity_direction
    d3 = -gravity_direction / gravity_norm

    # now construct an orthonormal basis for the rotation matrix
    # d1 is a vector thats orthogonal to gravity_direction by construction
    d1 = np.array(
        [
            gravity_direction[2] - gravity_direction[1],
            gravity_direction[0],
            -gravity_direction[0],
        ]
    )
    d1_norm = np.linalg.norm(d1)
    if d1_norm > 0:
        d1 = d1 / d1_norm
    else:
        # d1 vanishes exactly when g[0] == 0 and g[1] == g[2]. The x axis is then
        # orthogonal to gravity because g[0] == 0, so use it instead.
        d1 = np.array([1.0, 0.0, 0.0], dtype=d1.dtype)

    # get d2 via orthogonal direction vector to d3 and d1 (unit length, since
    # d3 and d1 are orthonormal)
    d2 = np.cross(d3, d1)

    # get rotation matrix
    R_gravity_vio = np.stack([d1, d2, d3], axis=1)

    # Two-sided checks: without abs() a reflection (det = -1) or a collapsed
    # basis (det = 0) would pass.
    assert abs(np.linalg.det(R_gravity_vio) - 1.0) < 1e-5
    assert (
        np.abs((R_gravity_vio @ R_gravity_vio.transpose()) - np.eye(3)) < 1e-5
    ).all()

    R_gravity_vio = torch.from_numpy(R_gravity_vio)
    R_vio_gravity = R_gravity_vio.transpose(1, 0)
    T_vio_gravity = PoseTW.from_Rt(R_vio_gravity, torch.zeros(3))
    return T_vio_gravity
def tensor_unify(tensor, dim_size: int, dim: int = 0):
    """Fill or trim a torch or numpy tensor to the given `dim_size`, along the given dim

    When the input is shorter than `dim_size` it is extended by repeating its
    last slice along `dim`; when it is longer, the leading `dim_size` slices
    are kept. An input that already has the requested size is returned as is.

    Inputs:
        tensor (torch, np.array or list): input tensor; a list is converted
            with `np.array` first
        dim_size (int): the size to fill or trim to (e.g. predefined batch size)
        dim (int): the dimension to fill or trim; negative values count from
            the end
    Returns:
        tensor2 (a torch or np.array): output tensor with the dim size = `dim_size`.
    Raises:
        AssertionError: if the input has no element along `dim`.
    """
    # Lists have no `.shape`, so convert before any shape-based check.
    if isinstance(tensor, list):
        tensor = np.array(tensor)
    assert tensor.shape[dim] > 0, "Input tensor must have at least 1 element"

    tensor_bs = tensor.shape[dim]
    if tensor_bs == dim_size:
        return tensor

    if isinstance(tensor, np.ndarray):
        if tensor_bs > dim_size:
            return np.take(tensor, indices=np.arange(dim_size), axis=dim)
        last = np.take(tensor, tensor_bs - 1, dim)
        fill = np.expand_dims(last, axis=dim)
        # fill with last element
        fill = np.repeat(fill, dim_size - tensor_bs, axis=dim)
        return np.concatenate([tensor, fill], axis=dim)

    # torch path
    ndim = tensor.ndim
    # The index shapes below are built positionally, so resolve negative dims.
    axis = dim + ndim if dim < 0 else dim
    # Create index tensors next to the data so non-CPU inputs work too.
    device = getattr(tensor, "device", None)

    if tensor_bs > dim_size:
        index_shape = [1] * ndim
        index_shape[axis] = dim_size
        indices = torch.arange(dim_size, device=device).reshape(index_shape)
        return torch.take_along_dim(tensor, indices, axis)

    # Pick the last slice along `axis` (the index broadcasts over the other dims) ...
    indices = torch.full([1] * ndim, tensor_bs - 1, dtype=torch.long, device=device)
    last = torch.take_along_dim(tensor, indices, axis)
    # ... and repeat it as often as needed to reach `dim_size`.
    repeats = [1] * ndim
    repeats[axis] = dim_size - tensor_bs
    return torch.cat([tensor, last.repeat(repeats)], dim=axis)
good.sum(dim=-1).squeeze() if num_notified > 0 and num_notified < max_notified: print( f"some interpolated poses were bad (fraction good per batch: {counts / good.shape[-1]}); likely because tried to interpolated past given input timed poses." ) return new_batch class VrsSequenceDataset(Dataset): def __init__( self, vrs_path, frame_rate, sdi, snippet_length_s, stride_length_s, max_snippets=9999, skip_snippets=0, preprocess=None, ): self.frame_rate = frame_rate self.vrs_path = vrs_path self.vrs_folder = os.path.split(vrs_path)[0] self.reader = pyvrs.SyncVRSReader( vrs_path, auto_read_configuration_records=True ) self.max_snippets = max_snippets self.sdi = sdi self.preprocess = preprocess self.cam_calib = load_factory_calib(self.reader) self.is_adt = is_adt(vrs_path) self.is_aeo = is_aeo(vrs_path) self.max_objects_per_frameset = 128 fps = self.cam_calib["fps"] self.fps = [fps["rgb"], fps["slaml"], fps["slamr"]] ts_lists = [] # Add images for idx in range(3): img_ts_list = get_timestamp_list_ns(self.reader, ARIA_CAM_INFO["id"][idx]) ts_lists.append(img_ts_list) # Add poses timed_Ts_world_rig = self.load_poses(self.vrs_folder, subsample=1) pose_times_ns = list(timed_Ts_world_rig.keys()) pose_freq = int(1.0 / (1e-9 * (pose_times_ns[1] - pose_times_ns[0]))) pose_subsample = int(pose_freq / frame_rate) pose_times_ns = pose_times_ns[::pose_subsample] self.T_world_rig_time_ns = pose_times_ns self.Ts_world_rig = torch.stack( [timed_Ts_world_rig[key] for key in pose_times_ns] ) ts_lists.append(pose_times_ns) # Add obbs GT if available self.obs = None self.obs = self.load_objects() if self.obs is not None: obb_freq = int(1.0 / (1e-9 * (self.obb_times[1] - self.obb_times[0]))) obb_subsample = max(1, int(obb_freq / frame_rate)) self.obb_times = self.obb_times[::obb_subsample] # Add points self.load_semidense(self.vrs_folder) # intersect all data modalities min_time, max_time = compute_time_intersection(ts_lists) play_times_ns = get_timestamp_list_ns(self.reader, 
ARIA_CAM_INFO["id"][idx]) play_times_ns = [ ts for ts in play_times_ns if (ts > min_time and ts < max_time) ] play_times_ns = np.unique(play_times_ns).tolist() # compute snippets start and end time seq_start_time = play_times_ns[0] seq_end_time = play_times_ns[-1] snip_start = seq_start_time snip_end = snip_start + snippet_length_s * 1e9 self.snippet_times = [] while snip_end < seq_end_time: self.snippet_times.append((snip_start, snip_end)) snip_start += stride_length_s * 1e9 snip_end = snip_start + snippet_length_s * 1e9 if skip_snippets > 0: self.snippet_times = self.snippet_times[skip_snippets:] def load_objects(self): self.obs = load_obbs_gt( self.vrs_folder, load_2d_bbs=True, filter_outside_2d_bbs=True, rgb_only=False, ) if len(self.obs) == 0: return None # inverse map from proto to a linear id and filter the interested objects if given. instance2proto = self.obs["inst2proto"] unique_proto_names = np.unique(list(instance2proto.values())).tolist() self.obs["proto2id"] = {name: i for i, name in enumerate(unique_proto_names)} if self.is_aeo: aeo_to_efm = ( f"{os.path.dirname(__file__)}/../config/taxonomy/aeo_to_efm.csv" ) self.global_name_to_id = {} with open(aeo_to_efm, "r") as f: lines = f.readlines() for li in lines[1:]: ori_name, class_name, class_id = li.strip().split(",") self.global_name_to_id[str(ori_name)] = (str(class_name), int(class_id)) filtered_proto_names = set(self.global_name_to_id.keys()).intersection( set(unique_proto_names) ) # remap the proto names and semantic ids given the taxonomy mapping self.obs["proto2id"] = { self.global_name_to_id[name][0]: self.global_name_to_id[name][1] for name in filtered_proto_names } self.obs["inst2proto"] = { inst: self.global_name_to_id[name][0] for inst, name in instance2proto.items() if name in filtered_proto_names } else: # use the class name to id mapping in the sequence self.obs["proto2id"] = { name: i for i, name in enumerate(unique_proto_names) } # compute inverse map self.obs["id2proto"] = {id: name 
for name, id in self.obs["proto2id"].items()} timedTs_world_object = self.obs["timedTs_world_object"] static_Ts_world_object = {} assert len(timedTs_world_object) != 0, ( "Warning: no observations found for entire sequence" ) # timedTs_world_object captures static object at the -1 timestamp if -1 in timedTs_world_object.keys(): static_Ts_world_object = timedTs_world_object[-1] self.obs["static_Ts_world_object"] = static_Ts_world_object self.obb_times = sorted(set(self.obs[ARIA_OBB_BB2[0]].keys())) if self.is_adt: T_vio_gravity = get_transform_to_vio_gravity_convention( GRAVITY_DIRECTION_ADT ) for time, idT_wo in self.obs["timedTs_world_object"].items(): for inst, T_wo in idT_wo.items(): # we go from gravity world coordinate system to the new one that follows vio conventions self.obs["timedTs_world_object"][time][inst] = ( T_vio_gravity @ T_wo.float() ) return self.obs def load_semidense(self, vrs_path, max_inv_depth_std=0.005, max_depth_std=0.05): possible_global_points_paths = [ os.path.join(vrs_path, "multi_global_points.csv.gz"), os.path.join(vrs_path, "multi_global_points.csv"), os.path.join(vrs_path, "global_points.csv.gz"), os.path.join(vrs_path, "global_points.csv"), os.path.join(vrs_path, "semidense_points.csv.gz"), os.path.join(vrs_path, "maps/maps_v1/globalcloud_GT.csv"), # ASE os.path.join(vrs_path, "mps/slam/semidense_points.csv.gz"), # ADT ] possible_obs_paths = [ os.path.join(vrs_path, "semidense_observations.csv.gz"), os.path.join(vrs_path, "semidense_observations.csv"), os.path.join(vrs_path, "maps/maps_v1/observations.csv"), # ASE os.path.join(vrs_path, "semidense_points.csv"), os.path.join(vrs_path, "mps/slam/semidense_observations.csv.gz"), # ADT ] global_points_path = exists_nonzero_path(possible_global_points_paths) self.uid_to_p3, self.uid_to_inv_dist_std, self.uid_to_dist_std = ( load_global_points_csv(global_points_path, max_inv_depth_std, max_depth_std) ) if self.is_adt: T_vio_gravity = get_transform_to_vio_gravity_convention( 
GRAVITY_DIRECTION_ADT ).double() for uid, p3 in self.uid_to_p3.items(): self.uid_to_p3[uid] = (T_vio_gravity * p3).reshape(-1) semidense_obs_path = exists_nonzero_path(possible_obs_paths) self.time_to_uids, self.uid_to_times = load_semidense_observations( semidense_obs_path ) if self.time_to_uids is not None: self.pts_times_ns = sorted(self.time_to_uids.keys()) ( self.time_to_pc, self.time_to_dist_std, self.time_to_inv_dist_std, no_points_times, ) = ({}, {}, {}, []) for time in self.pts_times_ns: uids = self.time_to_uids[time] p3s = [self.uid_to_p3[uid] for uid in uids if uid in self.uid_to_p3] if len(p3s) > 0: # sort by inv dist std to make any cropping later use the best points inv_dist_std = [ self.uid_to_inv_dist_std[uid] for uid in uids if uid in self.uid_to_inv_dist_std ] inv_dist_std = np.array(inv_dist_std) dist_std = [ self.uid_to_dist_std[uid] for uid in uids if uid in self.uid_to_dist_std ] dist_std = np.array(dist_std) ids = np.argsort(inv_dist_std) p3s = [p3s[i] for i in ids] p3s = torch.stack(p3s) inv_dist_std = torch.from_numpy(inv_dist_std[ids]) dist_std = torch.from_numpy(dist_std[ids]) else: no_points_times.append(time) p3s = torch.zeros((0, 3), dtype=torch.float32) inv_dist_std = torch.zeros((0), dtype=torch.float32) dist_std = torch.zeros((0), dtype=torch.float32) self.time_to_pc[time] = p3s self.time_to_dist_std[time] = dist_std self.time_to_inv_dist_std[time] = inv_dist_std print( f"Found {len(self.uid_to_p3)} semidense points; time range {min(self.pts_times_ns) / 1e9}s-{max(self.pts_times_ns) / 1e9}s" ) # aggregate all the points all_p3s = [self.uid_to_p3[uid] for uid in self.uid_to_p3] all_inv_dist_std = [ self.uid_to_inv_dist_std[uid] for uid in self.uid_to_inv_dist_std ] ids = np.argsort(all_inv_dist_std) # ranked by inverse depth std self.all_p3s = torch.stack([all_p3s[i] for i in ids]) # [N, 3] assert self.all_p3s.shape[0] > 0, "no points loaded" # compute a [q, 1-q] percentile as the global range q = 0.001 self.vol_min = 
def load_poses(self, vrs_path, subsample):
    """
    Load the timed world-from-rig trajectory for ``vrs_path``.

    Three loaders are tried in order (ADT, AEO, generic); the first one that
    does not return ``None`` is used.

    Args:
        vrs_path: path handed unchanged to the trajectory loaders.
        subsample: subsampling argument handed unchanged to the loaders.

    Returns:
        The loader's result. For ADT and AEO it is iterated with ``.items()``
        (presumably timestamp -> pose; verify against the loaders).
    """
    # ADT sequences
    timed_Ts_world_rig = load_trajectory_adt(vrs_path, subsample=subsample)
    if timed_Ts_world_rig is not None:
        # ADT trajectories are always rotated into the VIO gravity convention.
        T_vio_gravity = get_transform_to_vio_gravity_convention(
            GRAVITY_DIRECTION_ADT
        ).double()
        for k, T_wr in timed_Ts_world_rig.items():
            timed_Ts_world_rig[k] = T_vio_gravity @ T_wr
        return timed_Ts_world_rig

    # AEO sequences
    timed_Ts_world_rig = load_trajectory_aeo(
        vrs_path,
        time_in_secs=False,
        load_torch=True,
        subsample=subsample,
    )
    if timed_Ts_world_rig is not None:
        # Only rotate when the dataset was flagged as ADT.
        if self.is_adt:
            T_vio_gravity = get_transform_to_vio_gravity_convention(
                GRAVITY_DIRECTION_ADT
            ).double()
            for k, T_wr in timed_Ts_world_rig.items():
                timed_Ts_world_rig[k] = T_vio_gravity @ T_wr
        return timed_Ts_world_rig

    # Other sequences
    return load_trajectory(
        vrs_path,
        time_in_secs=False,
        load_torch=True,
        subsample=subsample,
    )


def load_snippet_pose(self, start, end):
    """
    Slice the trajectory to the snippet ``[start, end]`` (nanoseconds).

    Returns:
        ``(T_ws, Ts_wr, Ts_sr, pose_times_ns, pose_times_s)`` where ``T_ws`` is
        the first pose of the slice (used as the snippet frame), ``Ts_sr`` are
        the sliced poses expressed relative to it, and ``pose_times_s`` are the
        pose times in seconds relative to ``start``.
    """
    idx_i, idx_j = sample_times(self.T_world_rig_time_ns, start, end)
    Ts_wr = self.Ts_world_rig[idx_i:idx_j, :]
    pose_times_ns = torch.LongTensor(self.T_world_rig_time_ns[idx_i:idx_j])
    T_ws = Ts_wr[0].clone().unsqueeze(0)
    Ts_sr = T_ws.inverse() @ Ts_wr
    # Subtract in int64 first; only the small difference is cast to float.
    pose_times_s = (
        pose_times_ns - torch.tensor(start, dtype=torch.long)
    ).float() * 1e-9
    return T_ws, Ts_wr, Ts_sr, pose_times_ns, pose_times_s


def load_snippet_semidense(self, start, end, max_size=20000):
    """
    Collect the semi-dense point clouds whose timestamps fall in the snippet.

    Every cloud is truncated to ``max_size`` rows and NaN-padded up to exactly
    ``max_size`` rows so that the clouds can be stacked.

    Returns:
        ``(points_world, points_times_ns, points_times_s)``; the last entry is
        in seconds relative to ``start``.
    """
    idx_i, idx_j = sample_times(self.pts_times_ns, start, end)
    points_times_ns = self.pts_times_ns[idx_i:idx_j]
    points_world = [self.time_to_pc[time] for time in points_times_ns]
    for idx, ps in enumerate(points_world):
        ps = ps[:max_size, :]
        pad_num = max_size - ps.shape[0]
        assert pad_num >= 0, f"padding must be non-negative, but got {pad_num}"
        # Pad rows only (last two entries of the pad spec) with NaN.
        points_world[idx] = F.pad(
            ps,
            (0, 0, 0, pad_num),
            "constant",
            float("nan"),
        )
    points_world = torch.stack(points_world)
    points_times_ns = torch.LongTensor(points_times_ns)
    points_times_s = (
        points_times_ns - torch.tensor(start, dtype=torch.long)
    ).float() * 1e-9
    return points_world, points_times_ns, points_times_s


def load_snippet_objects(self, start, end):
    """
    Load padded object bounding boxes for every obb timestamp in ``(start, end]``.

    Returns:
        ``(obbs_padded, sem_id_to_name, obbs_time_ns, obbs_time_s)``. Framesets
        without visible instances contribute an all ``-1`` padding entry.
    """

    def get_obbs_for_time(t: int, inst_ids: List):
        """Build the obbs and the sem-id -> name mapping for one timestamp."""
        (
            bb2s_rgb,
            bb2s_slaml,
            bb2s_slamr,
            bb3s,
            Ts_world_object,
            sem_ids,
            inst_ids,
        ) = next_obb_observations(
            obs=self.obs,
            time=t,
            inst_ids=inst_ids,
            cam_names=["rgb", "slaml", "slamr"],
            load_dynamic_objects=True,
            interpolate_poses=True,
            dt_threshold_ns=10_000_000,
        )
        obbs = ObbTW.from_lmc(
            bb3s,
            bb2s_rgb,
            bb2s_slaml,
            bb2s_slamr,
            Ts_world_object,
            sem_ids,
            inst_ids,
        )
        # scale 2d bbs to image size
        obbs = rescale_obb_tw(
            obbs,
            cam_size_before_rgb=[1408, 1408, 3],  # Aria rgb size
            cam_size_before_slam=[480, 640, 1],  # Aria slam size
            down_scale=self.sdi,
            wh_multiple_of=16,
        )
        # center object bounding box in the object coordinate system
        # T_world_object so that origin is the center of the object
        obbs = obbs.center()
        # get object sem_id to name mapping
        sem_id_to_name = {
            self.obs["proto2id"][self.obs["inst2proto"][iid.item()]]: self.obs[
                "inst2proto"
            ][iid.item()]
            for iid in inst_ids
        }
        return obbs, sem_id_to_name

    obbs_snippet, sem_id_to_name, snippet_times = [], {}, []
    probably_snippet_times = [t for t in self.obb_times if start < t <= end]
    for t in probably_snippet_times:
        # we get only the instances that are visible as indicated by them
        # having 2d bb annotations
        inst_ids = get_instance_id_in_frameset(
            self.obs,
            t,
            load_dynamic_objects=True,
            interpolate_poses=True,
            dt_threshold_ns=10_000_000,
        )
        snippet_times.append(t)
        if len(inst_ids) == 0:
            obbs_snippet.append(
                ObbTW(-1 * torch.ones(self.max_objects_per_frameset, 34))
            )
            continue
        obbs, sem2names = get_obbs_for_time(t, inst_ids)
        obbs_snippet.append(obbs.add_padding(self.max_objects_per_frameset))
        sem_id_to_name.update(sem2names)

    if len(obbs_snippet) > 0:
        obbs_padded = ObbTW(smart_stack(obbs_snippet))
    else:
        obbs_padded = ObbTW(-1 * torch.ones((0, self.max_objects_per_frameset, 34)))
        print(f"could not find obbs for snippet times {snippet_times}")
    obbs_time_ns = torch.LongTensor(snippet_times)
    obbs_time_s = (
        obbs_time_ns - torch.tensor(start, dtype=torch.long)
    ).float() * 1e-9

    # subsample
    obj_idxs = sample_from_range(
        0, len(obbs_padded), sample_rate=1, add_random=False
    )
    obbs_padded = obbs_padded[obj_idxs].contiguous()
    obbs_time_ns = obbs_time_ns[obj_idxs].contiguous()
    obbs_time_s = obbs_time_s[obj_idxs].contiguous()
    return obbs_padded, sem_id_to_name, obbs_time_ns, obbs_time_s


def __len__(self):
    """Number of snippets, capped at ``self.max_snippets``."""
    return min(len(self.snippet_times), self.max_snippets)


def __getitem__(self, index):
    """
    Assemble the sample dict for snippet ``index``.

    The sample holds images, calibrations, poses, semi-dense points and, when
    ``self.obs`` is truthy, object boxes. Wrapper types are unwrapped to plain
    tensors, tensors are cast to float, and ``self.preprocess`` is applied last
    when set.

    Raises:
        StopIteration: when ``index >= self.max_snippets``.
    """
    if index >= self.max_snippets:
        raise StopIteration
    sample = {}
    start, end = self.snippet_times[index]
    rgb_calib = {key: self.cam_calib[key]["rgb"] for key in self.cam_calib}

    # img
    for i in range(3):
        subsample = int(self.fps[i] / self.frame_rate)
        imgs, img_times_ns, cam_tws, frame_ids = read_image_snippet_from_vrs(
            self.reader,
            ARIA_CAM_INFO["id"][i],
            start,
            end,
            rgb_calib,
            subsample=subsample,
            scale_down_images=self.sdi,
        )
        # Subtract the snippet start in int64 and only then cast to float,
        # exactly like the pose / point / obb loaders. Casting `start` to
        # float32 first rounds the absolute nanosecond timestamp.
        # NOTE(review): assumes img_times_ns is an integer torch tensor.
        img_times_s = (
            img_times_ns - torch.tensor(start, dtype=torch.long)
        ).float() * 1e-9
        sample.update(
            {
                ARIA_IMG[i]: imgs,
                ARIA_IMG_TIME_NS[i]: img_times_ns,
                ARIA_IMG_SNIPPET_TIME_S[i]: img_times_s,
                ARIA_FRAME_ID[i]: frame_ids,
                ARIA_CALIB[i]: cam_tws,
                ARIA_CALIB_TIME_NS[i]: img_times_ns,
                ARIA_CALIB_SNIPPET_TIME_S[i]: img_times_s,
            }
        )

    # pose
    T_ws, Ts_wr, Ts_sr, pose_times_ns, pose_times_s = self.load_snippet_pose(
        start, end
    )
    sample.update(
        {
            ARIA_SNIPPET_T_WORLD_SNIPPET: T_ws,
            ARIA_POSE_T_WORLD_RIG: Ts_wr,
            ARIA_POSE_T_SNIPPET_RIG: Ts_sr,
            ARIA_POSE_TIME_NS: pose_times_ns,
            ARIA_POSE_SNIPPET_TIME_S: pose_times_s,
        }
    )
    # interpolate slam poses to get img poses
    sample.update(run_sensor_poses(sample))

    # semidense points
    pts_world, pts_times_ns, pts_times_s = self.load_snippet_semidense(start, end)
    sample.update(
        {
            ARIA_POINTS_WORLD: pts_world,
            ARIA_POINTS_TIME_NS: pts_times_ns,
            ARIA_POINTS_SNIPPET_TIME_S: pts_times_s,
            ARIA_POINTS_VOL_MIN: self.vol_min,
            ARIA_POINTS_VOL_MAX: self.vol_max,
        }
    )

    # objects
    if self.obs:
        obbs_padded, sem_id_to_name, obbs_time_ns, obbs_time_s = (
            self.load_snippet_objects(start, end)
        )
        # transform obbs into snippet coordinate system
        obbs_padded = transform_obbs(obbs_padded, T_ws.float().inverse())
        sample.update(
            {
                ARIA_OBB_PADDED: obbs_padded,
                ARIA_OBB_SEM_ID_TO_NAME: sem_id_to_name,
                ARIA_OBB_TIME_NS: obbs_time_ns,
                ARIA_OBB_SNIPPET_TIME_S: obbs_time_s,
            }
        )

    # Keys that are per-snippet rather than per-frame skip tensor_unify.
    not_unified = [
        ARIA_SNIPPET_T_WORLD_SNIPPET,
        ARIA_POINTS_VOL_MIN,
        ARIA_POINTS_VOL_MAX,
        ARIA_OBB_SEM_ID_TO_NAME,
    ]
    for key in sample:
        if isinstance(sample[key], (PoseTW, CameraTW, ObbTW)):
            sample[key] = sample[key].tensor()
        if isinstance(sample[key], torch.Tensor):
            sample[key] = sample[key].float()
        if key not in not_unified:
            # Empty tensors are left untouched.
            if isinstance(sample[key], torch.Tensor) and sample[key].shape[0] == 0:
                continue
            sample[key] = tensor_unify(sample[key], self.frame_rate)

    if self.preprocess:
        sample = self.preprocess(sample)
    return sample
def convert_to_aria_multimodal_dataset(sample):
    """
    Convert a data sample from Aria multimodal data in webdataset format to
    training/validation sample format.

    Keys are dispatched on their suffix: ``.jpg`` frames are grouped into image
    snippets, ``.pyd`` payloads are wrapped into pose / camera / obb wrappers
    where the key calls for it, ``.txt`` and ``.cls`` payloads are copied, and
    everything else is ignored.
    """

    def to_mm_key(k, end_separator="."):
        k = k[: k.rfind(end_separator)]  # remove suffix
        # move keys back to the "/" convention from the "-" separator needed
        # for webdataset paths.
        return k.replace("-", "/")

    pose_keys = [
        ARIA_POSE_T_SNIPPET_RIG,
        ARIA_POSE_T_WORLD_RIG,
        ARIA_SNIPPET_T_WORLD_SNIPPET,
        ARIA_IMG_T_SNIPPET_RIG[0],
        ARIA_IMG_T_SNIPPET_RIG[1],
        ARIA_IMG_T_SNIPPET_RIG[2],
    ]

    image_snippets = {}
    mm_sample = {}
    for k, v in sample.items():
        if k.endswith(".jpg"):
            # Compose images to image snippet; frames of one stream share the
            # key prefix in front of the last "_".
            image_snippets.setdefault(to_mm_key(k, "_"), []).append(v)
        elif k.endswith(".pyd"):
            k = to_mm_key(k, ".")
            if k in pose_keys:
                mm_sample[k] = PoseTW.from_matrix3x4(v.float())
            elif k in ARIA_CALIB:
                assert v.shape[-1] == DEFAULT_CAM_DATA_SIZE, (
                    "only allow Fisheye624 cameras"
                )
                mm_sample[k] = CameraTW(v)
            elif k == ARIA_OBB_PADDED:
                mm_sample[ARIA_OBB_PADDED] = ObbTW(v)
            elif k == ARIA_POINTS_WORLD:
                # load as float32
                mm_sample[ARIA_POINTS_WORLD] = v.float()
            elif isinstance(v, dict):
                # store dicts as (key, datum) lists in order to be able to
                # collate them
                mm_sample[k] = [(kv, vv) for kv, vv in v.items()]
            else:
                mm_sample[k] = v
        elif k.endswith((".txt", ".cls")):
            # str / int payloads are passed through unchanged
            mm_sample[to_mm_key(k, ".")] = v
        # any other field is not used for training and silently ignored

    # images to image snippets: stack frames and move channels to dim 1
    for k, v in image_snippets.items():
        frames = np.transpose(np.stack(v, axis=0), (0, 3, 1, 2))
        # convert to one-channel for SLAM images
        if k == ARIA_IMG[1] or k == ARIA_IMG[2]:
            frames = frames[:, :1, :, :]
        mm_sample[k] = torch.from_numpy(frames)

    # Timestamps may arrive as numpy arrays; convert and check their dtype.
    for key in mm_sample:
        if "time_s" in key:
            if isinstance(mm_sample[key], np.ndarray):
                mm_sample[key] = torch.from_numpy(mm_sample[key])
            assert mm_sample[key].dtype == torch.float32
        if "time_ns" in key:
            if isinstance(mm_sample[key], np.ndarray):
                mm_sample[key] = torch.from_numpy(mm_sample[key])
            assert mm_sample[key].dtype == torch.int64
    return mm_sample


def batchify(datum, device=None):
    """
    Add a leading batch dimension to every entry of ``datum`` in place.

    Tensors and tensor wrappers get a new first dimension and are moved to
    ``device`` when one is given; every other value is wrapped in a
    one-element list. Returns the same dict.
    """
    for key in datum:
        if isinstance(datum[key], (torch.Tensor, TensorWrapper)):
            batched = datum[key][None, ...]
            # Move exactly once, and only when a target device was requested.
            if device is not None:
                batched = batched.to(device)
            datum[key] = batched
        else:
            datum[key] = [datum[key]]
    return datum


def unbatchify(datum):
    """Drop the leading batch dimension added by ``batchify`` (in place)."""
    for key in datum:
        if isinstance(datum[key], (torch.Tensor, TensorWrapper, list)):
            datum[key] = datum[key][0]
    return datum


def get_tar_sample_num(tar_file):
    """
    Count the distinct samples stored in ``tar_file``.

    Members belong to the same sample when their names share the text in
    front of the first ".".
    """
    # tarfile.open is the documented entry point for reading archives.
    with tarfile.open(tar_file, "r") as tar:
        return len({member.name.split(".")[0] for member in tar.getmembers()})


class WdsStreamDataset:
    """Sample 2s/1s WDS dataset to specified snippet length and stride"""

    def __init__(
        self,
        data_path,
        snippet_length_s=1.0,
        stride_length_s=0.1,
        wds_length_s=2.0,
        fps=10,
        max_snip=99999999,
    ):
        """
        Args:
            data_path: folder that is globbed for ``*.tar`` shards.
            snippet_length_s: length of the emitted snippets in seconds.
            stride_length_s: stride between emitted snippets in seconds.
            wds_length_s: length of one stored WDS snippet in seconds.
            fps: frames per second used to turn seconds into frame counts.
            max_snip: upper bound on the number of emitted snippets.
        """
        self.snippet_length_s = snippet_length_s
        self.stride_length_s = stride_length_s
        self.wds_length_s = wds_length_s
        # wds snippets should always be generated half overlapped. True
        # division keeps the stride at 0.5s for 1s shards (floor division
        # would give 0.0).
        self.wds_stride_s = wds_length_s / 2
        self.fps = fps
        self.max_snip = max_snip

        tar_list = sorted(glob.glob(f"{data_path}/*.tar"))
        # All shards are assumed to hold as many samples as the first one.
        self.samples_per_tar = get_tar_sample_num(tar_list[0])
        self.num_tars = len(tar_list)
        self.dataset = wds.DataPipeline(
            wds.SimpleShardList(tar_list),
            wds.tarfile_to_samples(),
            wds.decode("rgb"),
            wds.map(convert_to_aria_multimodal_dataset),
        )
        self.dataloader = iter(self.dataset)

        self.frames_wds = int(self.fps * self.wds_length_s)
        self.frames_out = int(self.fps * self.snippet_length_s)
        self.frames_stride_wds = int(self.fps * self.wds_stride_s)
        self.frames_stride_out = int(self.fps * self.stride_length_s)
        # The first WDS snippet yields one output snippet more than each of
        # the following ones.
        self.num_rest = int(
            (self.wds_length_s - self.snippet_length_s) / self.stride_length_s
        )
        self.num_first = int(1 + self.num_rest)
        self.num_snippets = (
            self.num_first + (self.samples_per_tar * self.num_tars - 1) * self.num_rest
        )

        # for iteration
        self.first = True
        self.wds_snippet = None
        self.snip_idx = 0
        self.global_idx = 0

    def __len__(self):
        """Number of output snippets, capped at ``max_snip``."""
        return min(self.num_snippets, self.max_snip)

    def sample_snippet_(self, snippet, start, end):
        """Return a shallow copy of ``snippet`` cropped to frames ``[start, end)``."""
        # Per-snippet entries have no time dimension and are not cropped.
        untimed = [
            ARIA_SNIPPET_T_WORLD_SNIPPET,
            ARIA_POINTS_VOL_MIN,
            ARIA_POINTS_VOL_MAX,
        ]
        sample = snippet.copy()
        for k in sample:
            if isinstance(sample[k], (torch.Tensor, TensorWrapper)) and k not in untimed:
                sample[k] = sample[k][start:end, ...]
        return sample

    def __iter__(self):
        return self

    def if_get_next_(self):
        """Tell whether the current WDS snippet is exhausted (or not loaded yet)."""
        if self.wds_snippet is None:
            return True
        limit = self.num_first if self.first else self.num_rest
        return self.snip_idx >= limit

    def __next__(self):
        """Return the next cropped snippet; ``StopIteration`` after ``max_snip``."""
        if self.global_idx >= self.max_snip:
            raise StopIteration

        if self.if_get_next_():
            # Leaving the very first WDS snippet switches to the "rest" regime.
            if self.first and self.wds_snippet is not None:
                self.first = False
            self.wds_snippet = next(self.dataloader)
            self.snip_idx = 0

        # Later WDS snippets skip offset 0 (see the +1 below).
        if self.first:
            start = self.snip_idx * self.frames_stride_out
        else:
            start = (self.snip_idx + 1) * self.frames_stride_out
        end = start + self.frames_out
        sample = self.sample_snippet_(self.wds_snippet, start, end)

        self.snip_idx += 1
        self.global_idx += 1
        return sample
def check_sem_id_conflict(ids_pred, ids_gt):
    """
    Check that predictions and GT agree on the name of every shared sem id.

    Ids present on only one side are reported with ``print``.

    Raises:
        AssertionError: when an id maps to different names in the two dicts.
    """
    for sem_id in set(ids_pred) | set(ids_gt):
        if sem_id in ids_pred and sem_id in ids_gt:
            # Raised explicitly so the check is not stripped under `python -O`.
            if ids_pred[sem_id] != ids_gt[sem_id]:
                raise AssertionError(
                    f"Mismatch id to name for sem id {sem_id}, {ids_pred[sem_id]} in pred but {ids_gt[sem_id]} in GT"
                )
        elif sem_id not in ids_pred:
            print(f"sem_id {sem_id} not found in pred")
        else:
            print(f"sem_id {sem_id} not found in GT")


def evaluate_obb_csv(
    pred_csv: str,
    gt_csv: str,
    iou: float = 0.2,
    pr_curve: bool = False,
):
    """
    Evaluate one prediction csv against one GT csv.

    Args:
        pred_csv: path of the predicted obbs csv.
        gt_csv: path of the ground-truth obbs csv.
        iou: IoU threshold for precision / recall.
        pr_curve: draw a precision-recall curve next to ``pred_csv``; only
            done when the csvs cover exactly one timestamp.

    Returns:
        Dict of metrics. The precision / recall entries are overwritten for
        every evaluated timestamp, so they reflect the last one.
    """
    pred_reader = ObbCsvReader(pred_csv)
    gt_reader = ObbCsvReader(gt_csv)
    pred_obbs = pred_reader.obbs
    gt_obbs = gt_reader.obbs

    sem_id_to_name = pred_reader.sem_ids_to_names.copy()
    sem_id_to_name_gt = gt_reader.sem_ids_to_names.copy()
    check_sem_id_conflict(sem_id_to_name, sem_id_to_name_gt)
    sem_id_to_name.update(sem_id_to_name_gt)

    result = {}
    mAP = ObbMetrics(
        cam_ids=[0],
        cam_names=["rgb"],
        class_metrics=True,
        eval_2d=False,
        eval_3d=True,
        global_name_to_id={
            name: int(sem_id) for sem_id, name in sem_id_to_name.items()
        },
    )

    ts = sorted(set(pred_obbs.keys()) | set(gt_obbs.keys()))
    gt_ts_miss = 0
    pred_ts_miss = 0
    for t in ts:
        if t not in pred_obbs:
            print(f"pred obbs not found for {t}")
            pred_ts_miss += 1
            continue
        if t not in gt_obbs:
            print(f"gt obbs not found for {t}")
            gt_ts_miss += 1
            continue

        # we should not have any paddings
        assert pred_obbs[t].shape[0] == pred_obbs[t].remove_padding().shape[0]
        assert gt_obbs[t].shape[0] == gt_obbs[t].remove_padding().shape[0]

        # always do precision recall calculation
        prec, rec, match_mat, ious, per_class_results = prec_recall_bb3(
            pred_obbs[t],
            gt_obbs[t],
            iou_thres=iou,
            return_ious=True,
            per_class=True,
        )
        tps = match_mat.any(-1)
        result[f"precision@IoU{iou}"] = float(prec)
        result[f"recall@IoU{iou}"] = float(rec)
        result[f"num_true_positives@IoU{iou}"] = int(tps.sum())
        result["num_dets"] = match_mat.shape[0]
        result["num_gts"] = match_mat.shape[1]
        for sem_id, per_class_result in per_class_results.items():
            class_name = sem_id_to_name[sem_id.item()]
            result[f"precision@IoU{iou}@Class_{class_name}"] = float(
                per_class_result["precision"]
            )
            result[f"recall@IoU{iou}@Class_{class_name}"] = float(
                per_class_result["recall"]
            )

        # check if the preds contain probabilities
        prob = pred_obbs[t].prob.squeeze()
        assert not torch.all(prob.eq(-1.0)), (
            "the obbs don't contain valid probabilities for mAP calculation."
        )
        # add pred/gt pair to mAP calculator.
        mAP.update(pred_obbs[t], gt_obbs[t])

        output_dir = os.path.dirname(pred_csv)
        if pr_curve and len(ts) == 1:
            precs, recalls, probs = prec_recall_curve([(pred_obbs[t], gt_obbs[t])])
            draw_prec_recall_curve(
                precs, recalls, save_folder=output_dir, iou_thres=iou
            )

    result["num_timestamps"] = len(ts)
    result["num_timestamp_miss_pred"] = pred_ts_miss
    result["num_timestamp_miss_gt"] = gt_ts_miss

    result_map = mAP.compute()
    # ignore average recall
    result_map = {
        k: v.item() for k, v in result_map.items() if not k.startswith("rgb/mar_")
    }
    result.update(result_map)
    return result


def obb_eval_dataset(input_folder: str, iou: float = 0.2):
    """
    Obb eval at dataset-level

    Every direct sub-folder of ``input_folder`` that contains both
    ``gt_scene_obbs.csv`` and ``tracked_scene_obbs.csv`` counts as one
    sequence. Precision / recall are averaged over the sequences and a
    precision-recall curve is written into ``input_folder``.

    Returns:
        Dict of metrics; only ``{"num_seqs": 0}`` when no sequence is found.
    """
    GT_OBB_FILENAME = "gt_scene_obbs.csv"
    PRED_OBB_FILENAME = "tracked_scene_obbs.csv"

    # get all the pred and gt csv files
    pred_csv_paths, gt_csv_paths = [], []
    dirs = [os.path.join(input_folder, f) for f in os.listdir(input_folder)]
    dirs = [d for d in dirs if os.path.isdir(d)]
    for d in dirs:
        pred_csv = os.path.join(d, PRED_OBB_FILENAME)
        gt_csv = os.path.join(d, GT_OBB_FILENAME)
        if os.path.exists(gt_csv) and os.path.exists(pred_csv):
            pred_csv_paths.append(pred_csv)
            gt_csv_paths.append(gt_csv)

    result = {}
    result["num_seqs"] = len(pred_csv_paths)
    if len(pred_csv_paths) == 0 or len(gt_csv_paths) == 0:
        return result

    pred_obbs, gt_obbs = [], []
    sem_id_to_name = {}
    for pred_csv, gt_csv in zip(pred_csv_paths, gt_csv_paths):
        pred_reader = ObbCsvReader(pred_csv)
        gt_reader = ObbCsvReader(gt_csv)
        # p_obbs, g_obbs are single-item dicts
        pred_obbs.append(next(iter(pred_reader.obbs.values())))
        gt_obbs.append(next(iter(gt_reader.obbs.values())))

        sem_id_to_name_pred = pred_reader.sem_ids_to_names.copy()
        sem_id_to_name_gt = gt_reader.sem_ids_to_names.copy()
        check_sem_id_conflict(sem_id_to_name_pred, sem_id_to_name_gt)
        # NOTE(review): only GT names enter the global mapping here, whereas
        # evaluate_obb_csv merges pred and GT names - confirm this is intended.
        sem_id_to_name.update(sem_id_to_name_gt)

    mAP = ObbMetrics(
        cam_ids=[0],
        cam_names=["rgb"],
        class_metrics=True,
        eval_2d=False,
        eval_3d=True,
        global_name_to_id={
            name: int(sem_id) for sem_id, name in sem_id_to_name.items()
        },
    )

    precs, recs = [], []
    for p_obbs, g_obbs in zip(pred_obbs, gt_obbs):
        prec, rec, match_mat, ious, per_class_results = prec_recall_bb3(
            p_obbs,
            g_obbs,
            iou_thres=iou,
            return_ious=True,
            per_class=True,
        )
        precs.append(prec)
        recs.append(rec)
        mAP.update(p_obbs, g_obbs)
    result[f"precision@IoU{iou}"] = np.mean(precs)
    result[f"recall@IoU{iou}"] = np.mean(recs)

    precs, recalls, probs = prec_recall_curve(list(zip(pred_obbs, gt_obbs)))
    # save precision-recall curve to png
    draw_prec_recall_curve(precs, recalls, save_folder=input_folder, iou_thres=iou)

    result_map = mAP.compute()
    # ignore average recall (e.g. "rgb/mar_220_3D")
    result_map = {
        k: v.item() for k, v in result_map.items() if not k.startswith("rgb/mar_")
    }
    result.update(result_map)
    return result


def main():
    """
    Command line entry point.

    With ``--input_folder`` the evaluation runs at dataset level and the
    metrics are printed. Otherwise ``--pred_csv`` and ``--gt_csv`` are
    required and the metrics are also written to ``metrics.json`` next to the
    prediction csv.
    """
    import argparse
    import json

    parser = argparse.ArgumentParser(description="Run EFM eval pipeline")
    parser.add_argument(
        "--input_folder",
        type=str,
        help="The input folder that contains the gt and pred obbs csv files. If this is provided, the eval will be done at dataset-level",
        default=None,
    )
    parser.add_argument(
        "--pred_csv",
        type=str,
        help="The prediction obbs csv file, can be snippet-level snippet_obbs.csv or scene-level tracked_scene_obbs.csv",
        default=None,
    )
    parser.add_argument(
        "--gt_csv",
        type=str,
        help="The ground truth obbs csv file, can be snippet-level gt_obbs.csv or scene-level gt_scene_obbs.csv",
        default=None,
    )
    parser.add_argument(
        "--iou",
        type=float,
        default=0.2,
    )
    parser.add_argument(
        "--pr_curve",
        action="store_true",
        help="Whether to draw precision recall curve",
    )
    args = parser.parse_args()

    if args.input_folder:
        # Forward the requested threshold; it used to be dropped here so the
        # dataset-level eval always ran at the default IoU.
        metrics = obb_eval_dataset(args.input_folder, iou=args.iou)
        print(json.dumps(metrics, indent=2, sort_keys=True))
    else:
        assert args.pred_csv is not None, "pred_csv is required"
        assert args.gt_csv is not None, "gt_csv is required"
        metrics = evaluate_obb_csv(args.pred_csv, args.gt_csv, args.iou, args.pr_curve)
        output_dir = os.path.dirname(args.pred_csv)
        print(json.dumps(metrics, indent=2, sort_keys=True))
        with open(os.path.join(output_dir, "metrics.json"), "w") as f:
            json.dump(metrics, f, indent=2, sort_keys=True)


if __name__ == "__main__":
    main()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# logger.setLevel(logging.DEBUG)


def set_boundary_value(x, val, thickness):
    """
    Overwrite a ``thickness`` wide shell of the last three dims of ``x``.

    ``x`` is modified in place and returned; ``thickness == 0`` is a no-op.
    """
    if thickness == 0:
        return x
    x[..., :thickness, :, :] = val
    x[..., -thickness:, :, :] = val
    x[..., :, :thickness, :] = val
    x[..., :, -thickness:, :] = val
    x[..., :, :, :thickness] = val
    x[..., :, :, -thickness:] = val
    return x


def load_tensor(fname, device):
    """Load a tensor with ``torch.load``; files with "_8b" in the name are dequantized."""
    data = torch.load(fname, map_location=device)
    if "_8b" in fname:
        data = data.dequantize()
    return data


class VolumeFusion:
    """Running weighted average of local occupancy volumes in one global voxel grid."""

    def __init__(
        self,
        voxel_size: List[float],
        voxel_extent: List[float],
        device: str = "cuda",
        dtype=torch.float32,
        w_min: float = 5.0,
        w_max: float = 100.0,
        init_value: float = 0.0,
        surface_thres: float = 0.99,
        boundary_thres: int = 1,
    ):
        """
        Args:
            voxel_size: grid resolution as D x H x W.
            voxel_extent: grid extent, forwarded to ``create_voxel_grid``.
            device: device of the global buffers.
            dtype: dtype of the global buffers.
            w_min: minimum accumulated weight for a voxel to count as observed.
            w_max: upper clamp of the accumulated weight.
            init_value: initial value of every global voxel.
            surface_thres: local voxels at or above this value are not fused.
            boundary_thres: width of the local boundary shell that is not fused.
        """
        self.voxel_size = voxel_size  # D x H x W
        self.voxel_extent = voxel_extent  # W x H x D
        self.vD, self.vH, self.vW = (int(s) for s in self.voxel_size)
        self.w_max = w_max
        self.w_min = w_min
        self.surface_thres = surface_thres
        self.boundary_thres = boundary_thres

        self.global_volume = (
            torch.ones(self.vD, self.vH, self.vW, device=device, dtype=dtype)
            * init_value
        )  # D H W
        self.global_volume_weights = torch.zeros_like(self.global_volume)  # D H W
        self.global_volume_points = create_voxel_grid(
            self.vW, self.vH, self.vD, self.voxel_extent, device
        ).to(dtype=dtype)  # W, H, D, 3

        # All global buffers are kept flattened in D, H, W order.
        self.global_volume_points = self.global_volume_points.permute(
            2, 1, 0, 3
        )  # D H W 3
        self.global_volume_points = self.global_volume_points.reshape(
            -1, 3
        )  # (D*H*W) x 3
        self.global_volume_weights = self.global_volume_weights.reshape(-1)  # D*H*W
        self.global_volume = self.global_volume.reshape(-1)  # D*H*W
        self.device = device

    def set_boundary_mask(self, mask):
        """
        Clear (in place) a ``boundary_thres`` wide shell of a 3D boolean mask.

        The caller passes a mask shaped like the local volume (D, H, W).
        """
        thickness = self.boundary_thres
        if thickness <= 0:
            # `mask[-0:]` would address the whole tensor and wipe the mask.
            return mask
        mask[:thickness] = False  # first layers along dim 0 (depth)
        mask[-thickness:] = False  # last layers along dim 0 (depth)
        mask[:, :thickness] = False  # first layers along dim 1 (height)
        mask[:, -thickness:] = False  # last layers along dim 1 (height)
        mask[:, :, :thickness] = False  # first layers along dim 2 (width)
        mask[:, :, -thickness:] = False  # last layers along dim 2 (width)
        return mask

    def fuse(
        self,
        local_volume: torch.Tensor,
        local_extent: List[float],
        T_l_w: "PoseTW",
        new_obs_w=1.5,
        visiblity_mask=None,
    ):
        """
        Fuse one local volume into the global volume.

        Args:
            local_volume: local occupancy volume, D x H x W.
            local_extent: extent of the local volume, forwarded to ``pc_to_vox``.
            T_l_w: pose applied to the global voxel centers to express them in
                the local frame.
            new_obs_w: amount added to the weight of every updated voxel.
            visiblity_mask: optional mask that is AND-ed with the surface mask.
        """
        local_volume = local_volume.to(self.global_volume.device)
        T_l_w = T_l_w.to(self.global_volume.device)
        vD, vH, vW = local_volume.shape

        # transform global_volume to local volume
        global_volume_l = T_l_w * self.global_volume_points
        global_volume_l_coord, valid_global_points = pc_to_vox(
            global_volume_l, vW, vH, vD, local_extent
        )
        sample_coords = global_volume_l_coord.view(1, -1, 3).float()
        local_samples, valid_samples = sample_voxels(
            local_volume.unsqueeze(0).unsqueeze(0).float(),
            sample_coords,
        )
        local_samples = (
            local_samples.squeeze(0).squeeze(0).to(dtype=self.global_volume.dtype)
        )
        valid_samples = valid_samples.squeeze(0)

        # making a mask
        surface_mask = local_volume < self.surface_thres
        if visiblity_mask is not None:
            surface_mask &= visiblity_mask.to(surface_mask)
        # we don't trust the boundary voxels from CNNS
        if self.boundary_thres > 0:
            surface_mask = self.set_boundary_mask(surface_mask)
        # Encode "masked out" as NaN so that sampling the mask marks every
        # sample touching a masked voxel as invalid.
        surface_mask_f = surface_mask.float()
        surface_mask_f[~surface_mask] = torch.nan
        # sample the mask
        surface_mask_samples, _ = sample_voxels(
            surface_mask_f.unsqueeze(0).unsqueeze(0).float(),
            sample_coords,
        )
        surface_mask = ~surface_mask_samples.isnan()

        valid_samples = valid_samples & surface_mask
        mask = (valid_samples & valid_global_points).squeeze()

        w = self.global_volume_weights[mask]
        # NOTE(review): the average uses a fixed observation weight of 2.0
        # while the weight buffer grows by `new_obs_w` - confirm this is intended.
        self.global_volume[mask] = (
            self.global_volume[mask] * w + local_samples[mask] * 2.0
        ) / (w + 2.0)
        # update weights
        self.global_volume_weights[mask] = (w + new_obs_w).clamp(max=self.w_max)

    def get_volume(self, reshape=True):
        """Return the fused volume, as D x H x W when ``reshape`` else flat."""
        if reshape:
            return self.global_volume.reshape(self.vD, self.vH, self.vW)
        return self.global_volume

    def get_weights(self, reshape=True):
        """Return the accumulated weights, as D x H x W when ``reshape`` else flat."""
        if reshape:
            return self.global_volume_weights.reshape(self.vD, self.vH, self.vW)
        # This branch used to fall through and return None.
        return self.global_volume_weights

    def get_mask(self, reshape=True):
        """Return the boolean mask of voxels whose weight reached ``w_min``."""
        mask = self.global_volume_weights >= self.w_min
        if reshape:
            return mask.reshape(self.vD, self.vH, self.vW)
        # This branch used to fall through and return None.
        return mask

    def get_trimesh(self, iso_level=0.5):
        """Run marching cubes on the observed part of the volume; returns a ``trimesh.Trimesh``."""
        global_vol = self.get_volume()
        mask = self.get_mask()
        verts_w, faces, _ = marching_cubes_scaled(
            global_vol.cpu().detach().float(),
            iso_level,
            self.voxel_extent,
            mask,
        )
        sem_rgb = None
        return trimesh.Trimesh(verts_w, faces, vertex_colors=sem_rgb)


class VolumetricFusion:
    """Fuse the per-snippet occupancy predictions stored under ``<input_folder>/per_snip``."""

    def __init__(
        self,
        input_folder,
        w_min=5.0,
        w_max=9999999.0,
        voxel_res=0.04,
        device="cuda",
    ):
        """
        Args:
            input_folder: inference output folder that has a ``per_snip`` sub-folder.
            w_min: forwarded to ``VolumeFusion``.
            w_max: forwarded to ``VolumeFusion``.
            voxel_res: edge length of a global voxel.
            device: device used for fusion.
        """
        self.input_folder = input_folder
        self.per_snip_folder = os.path.join(input_folder, "per_snip")
        f_vol_min = os.path.join(self.per_snip_folder, "scene_vol_min.pt")
        f_vol_max = os.path.join(self.per_snip_folder, "scene_vol_max.pt")
        assert os.path.exists(f_vol_min) and os.path.exists(f_vol_max), (
            "missing scene volume info"
        )
        self.vol_min = load_tensor(f_vol_min, "cpu").numpy()
        self.vol_max = load_tensor(f_vol_max, "cpu").numpy()
        self.w_min = w_min
        self.w_max = w_max
        self.voxel_res = voxel_res
        self.device = device
        self.vis_norm_grad_occ_thr = 0.2
        # we remove a 1 voxel wide boundary on the volumes to remove cnn artifacts
        self.boundary_thresh = 1

        self.f_occ_preds = sorted(
            glob.glob(os.path.join(self.per_snip_folder, "occ_pr*.pt"))
        )
        Ts_wv_pt = os.path.join(self.per_snip_folder, "Ts_wv.pt")
        self.Ts_wv = torch.load(Ts_wv_pt, map_location="cpu")  # need to be on cpu
        assert len(self.f_occ_preds) == self.Ts_wv.shape[0], (
            f"occ snippets {len(self.f_occ_preds)} should match with Ts_wv {self.Ts_wv.shape[0]}"
        )

        # load voxel extent for initialization
        ve_path = os.path.join(self.per_snip_folder, "voxel_extent.pt")
        self.local_extent = torch.load(ve_path).cpu()
        if self.local_extent.ndim == 2:
            self.local_extent = self.local_extent.squeeze(0)
        self.local_extent = self.local_extent.tolist()

        self.global_vol = None
        self.init_from_range(self.vol_min, self.vol_max)

    def reinit(self):
        """Drop the fused volume and start over with the same scene range."""
        # reinit with the same voxel extent
        if self.global_vol is not None:
            del self.global_vol
        self.init_from_range(self.vol_min, self.vol_max)

    def init_from_range(self, xyz_min, xyz_max):
        """Create a fresh global ``VolumeFusion`` covering ``[xyz_min, xyz_max]`` plus a margin."""
        # Add a little buffer around the bounds. Done out of place: the
        # arguments are usually self.vol_min / self.vol_max, and an in-place
        # update would enlarge the range on every reinit().
        xyz_min = xyz_min - 2 * self.voxel_res
        xyz_max = xyz_max + 2 * self.voxel_res
        if xyz_min.ndim == 2:
            xyz_min = xyz_min[0]
        if xyz_max.ndim == 2:
            xyz_max = xyz_max[0]
        global_extent = [
            xyz_min[0],
            xyz_max[0],
            xyz_min[1],
            xyz_max[1],
            xyz_min[2],
            xyz_max[2],
        ]
        voxel_size = np.ceil((xyz_max - xyz_min) / self.voxel_res).tolist()
        voxel_size.reverse()  # change to DxHxW
        self.global_vol = VolumeFusion(
            voxel_size,
            global_extent,
            device=self.device,
            w_min=self.w_min,
            w_max=self.w_max,
            init_value=1.0,
            surface_thres=0.99,
        )

    def get_trimesh(self):
        """Mesh of the current fused volume."""
        return self.global_vol.get_trimesh()

    def run_step(self, i):
        """Fuse the ``i``-th occupancy prediction; logs and returns when ``i`` is out of range."""
        # run one step of volume fusion
        if i >= len(self.f_occ_preds):
            logger.info(
                f"{i}-th snippet exceeding the number of snippets {len(self.f_occ_preds)}"
            )
            return
        T_wv = self.Ts_wv[i]
        occ_pred = load_tensor(self.f_occ_preds[i], self.device)  # [1, 1, D, H, W]
        occ_pred = occ_pred[0][0]  # [D, H, W]
        self.global_vol.fuse(
            local_volume=occ_pred,
            local_extent=self.local_extent,
            T_l_w=T_wv.inverse(),
        )

    def run(self):
        """Fuse all occupancy predictions in order."""
        logger.info("Fusing voxel occupancy using volume fusion...")
        for i in tqdm.tqdm(range(len(self.f_occ_preds)), total=len(self.f_occ_preds)):
            self.run_step(i)


logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# logger.setLevel(logging.DEBUG)


class EfmInference:
    """Run the model over a snippet streamer and write per-snippet results to disk."""

    def __init__(self, streamer, model, output_dir, device, zip, obb_only=False):
        """
        Args:
            streamer: iterable of snippet dicts that also supports ``len``.
            model: callable as ``model(batch, obb_only=...)``.
            output_dir: folder for csv files and the ``per_snip`` sub-folder.
            device: device the batches are moved to.
            zip: when truthy, the output folder is zipped in ``__del__``.
            obb_only: skip saving occupancy outputs.
        """
        self.streamer = streamer
        self.model = model
        self.output_dir = output_dir
        self.device = device
        self.zip = zip
        self.obb_only = obb_only
        self.metadata_saved = False
        self.Ts_wv = []  # all T_world_voxel as one tensor

        self.obb_csv_path = os.path.join(output_dir, "snippet_obbs.csv")
        self.obb_writer = None
        # The per-snippet folder is wiped on construction.
        self.per_snip_dir = os.path.join(output_dir, "per_snip")
        shutil.rmtree(self.per_snip_dir, ignore_errors=True)
        os.makedirs(self.per_snip_dir, exist_ok=True)

        # obb GT
        self.gt_obb_csv_path = os.path.join(output_dir, "gt_obbs.csv")
        self.gt_obb_writer = None
        self.scene_gt_obbs_w = []

    def __del__(self):
        if not self.zip:
            return
        # compress the output folder
        if os.path.exists(self.output_dir) and os.listdir(self.output_dir):
            logger.info(f"zipping file to {self.output_dir}.zip")
            shutil.make_archive(
                self.output_dir.rstrip("/"), "zip", self.output_dir, verbose=True
            )
            logger.info(f"zip file saved to {self.output_dir}.zip")

    def save_tensor(self, tensor, key, idx=None, output_dir=""):
        """Save ``tensor`` (moved to cpu) as ``<key>.pt`` or ``<key>_<idx:06>.pt``."""
        if idx is not None:
            pt_name = os.path.join(output_dir, f"{key}_{idx:06}.pt")
        else:
            pt_name = os.path.join(output_dir, f"{key}.pt")
        torch.save(tensor.cpu(), pt_name)

    def save_output(self, data, idx, output_dir):
        """
        Save per-snippet 3D obb output and occupancy tensor to disk.
        """
        # assuming single sample batch
        bid = 0

        has_pred = ARIA_OBB_PRED_VIZ in data
        has_gt = ARIA_OBB_PADDED in data and ARIA_OBB_SEM_ID_TO_NAME in data
        if has_pred or has_gt:
            # Shared by both branches below. These used to be bound only in the
            # prediction branch, so GT without predictions hit unbound names.
            T_ws = data[ARIA_SNIPPET_T_WORLD_SNIPPET][bid]
            first_rgb_time_ns = data[ARIA_IMG_TIME_NS[0]][bid, 0].item()

        # 3d obb predictions
        if has_pred:
            obb_preds_s = data[ARIA_OBB_PRED_VIZ][bid].remove_padding()
            obb_preds_w = obb_preds_s.transform(T_ws)
            if self.obb_writer is None:
                self.obb_writer = ObbCsvWriter(self.obb_csv_path)
            self.obb_writer.write(
                obb_preds_w, first_rgb_time_ns, data[ARIA_OBB_PRED_SEM_ID_TO_NAME]
            )

        # 3d obb ground truth
        if has_gt:
            gt_obbs_s = obb_time_union(data[ARIA_OBB_PADDED])[bid].remove_padding()
            gt_obbs_w = gt_obbs_s.transform(T_ws)
            self.scene_gt_obbs_w.append(gt_obbs_w.add_padding(128))
            if self.gt_obb_writer is None:
                self.gt_obb_writer = ObbCsvWriter(self.gt_obb_csv_path)
            gt_sem_id_to_name = {}
            gt_sem_id_to_name.update(data[ARIA_OBB_SEM_ID_TO_NAME][bid])
            self.gt_obb_writer.write(
                gt_obbs_w,
                first_rgb_time_ns,
                sem_id_to_name=gt_sem_id_to_name,
            )

        # occupancy predictions (skipped in obb_only mode)
        if (
            not self.obb_only
            and "occ_pr" in data
            and ARIA_POINTS_VOL_MIN in data
            and ARIA_POINTS_VOL_MAX in data
        ):
            # Scene-level metadata is written once, with the first snippet.
            if not self.metadata_saved:
                self.save_tensor(
                    data["voxel_extent"],
                    "voxel_extent",
                    idx=None,
                    output_dir=output_dir,
                )
                self.metadata_saved = True
                self.save_tensor(
                    data[ARIA_POINTS_VOL_MIN][0],  # tensor(3)
                    "scene_vol_min",
                    idx=None,
                    output_dir=output_dir,
                )
                self.save_tensor(
                    data[ARIA_POINTS_VOL_MAX][0],  # tensor(3)
                    "scene_vol_max",
                    idx=None,
                    output_dir=output_dir,
                )
            self.save_tensor(data["occ_pr"], "occ_pr", idx, output_dir)
            self.Ts_wv.append(data["voxel/T_world_voxel"][0])

    def run(self):
        """Run inference over the whole streamer and write scene-level outputs."""
        # feed the per-snippet data to the model
        gt_sem_id = {}
        idx = 0
        start = time.time()
        for batch in tqdm.tqdm(self.streamer, total=len(self.streamer)):
            # convert single sample to batch and move to GPU
            batchify(batch, device=self.device)
            with torch.no_grad():
                output = self.model(batch, obb_only=self.obb_only)
            batch.update(output)
            self.save_output(batch, idx, self.per_snip_dir)
            if ARIA_OBB_SEM_ID_TO_NAME in batch:
                gt_sem_id.update(batch[ARIA_OBB_SEM_ID_TO_NAME][0])
            idx += 1
        print(f"\ninference speed {idx / (time.time() - start):.02f} sample/s")

        # save all T_wv as one tensor to avoid writing small files
        if len(self.Ts_wv) > 0:
            Ts_wv = torch.stack(self.Ts_wv, dim=0)
            self.save_tensor(Ts_wv, "Ts_wv", None, self.per_snip_dir)

        # write scene-level obbs
        if len(self.scene_gt_obbs_w) > 0:
            max_obbs = 512
            merged_gts = torch.stack(self.scene_gt_obbs_w, dim=0)
            merged_gts = obb_time_union(merged_gts.unsqueeze(0), pad_size=max_obbs)
            merged_gts = merged_gts[0].remove_padding()
            gt_scene_obb_csv_path = os.path.join(self.output_dir, "gt_scene_obbs.csv")
            gt_scene_obb_writer = ObbCsvWriter(gt_scene_obb_csv_path)
            gt_scene_obb_writer.write(merged_gts, -1, gt_sem_id)
def _sequence_name(data_path):
    """Return the sequence name for a `.vrs` file path or a dataset folder path."""
    if data_path.endswith(".vrs"):
        # <...>/<seq_name>/<recording>.vrs: the parent folder names the sequence
        return os.path.basename(os.path.dirname(data_path))
    # folder input: strip slashes so a trailing "/" does not yield an empty name
    return os.path.basename(data_path.strip("/"))


def get_gt_mesh_ply(data_path):
    """
    Return ASE or ADT GT mesh path. If not exist, return empty str.

    The ADT location (`./data/adt_mesh/<seq>/gt_mesh.ply`) is checked before the
    ASE location (`./data/ase_mesh/scene_ply_<seq>.ply`); both are relative to
    the current working directory.
    """
    seq_name = _sequence_name(data_path)
    candidates = (
        f"./data/adt_mesh/{seq_name}/gt_mesh.ply",
        f"./data/ase_mesh/scene_ply_{seq_name}.ply",
    )
    for mesh_ply in candidates:
        if os.path.exists(mesh_ply):
            return mesh_ply
    return ""


def compute_avg_metrics(paths):
    """
    Given metrics path list, compute the average metrics.

    Each path is a JSON file holding a flat `{metric_name: value}` dict. A metric
    is averaged over the files that contain it. Returns `{}` for an empty list.
    Note that simply averaging is not a good way to compute mAP metrics.
    """
    collected = {}
    for path in paths:
        with open(path, "r") as f:
            metrics = json.load(f)
        for name, value in metrics.items():
            collected.setdefault(name, []).append(value)
    return {name: np.mean(values) for name, values in collected.items()}


def create_streamer(
    data_path, snippet_length_s, stride_length_s, max_snip, skip_snips=0
):
    """
    Create the snippet streamer matching the type of `data_path`.

    A folder containing `shards-0000.tar` is read as ATEK WDS data, a path ending
    in `.vrs` is read with `VrsSequenceDataset`. Any other input prints an error
    and terminates the program with exit status -1 (raises `SystemExit`).
    `skip_snips` is only forwarded to the `.vrs` reader.
    """

    # infer data type
    def is_atek_wds_input(data_path):
        ATEK_WDS_TAR = "shards-0000.tar"
        first_tar = os.path.join(data_path, ATEK_WDS_TAR)
        return os.path.exists(first_tar)

    if is_atek_wds_input(data_path):
        from efm3d.dataset.atek_wds_dataset import AtekWdsStreamDataset

        streamer = AtekWdsStreamDataset(
            data_path,
            atek_to_efm_taxonomy=f"{os.path.dirname(__file__)}/../config/taxonomy/atek_to_efm.csv",
            snippet_length_s=snippet_length_s,
            stride_length_s=stride_length_s,
            max_snip=max_snip,
        )
    elif data_path.endswith(".vrs"):
        # Use the native vrs sequence processor
        streamer = VrsSequenceDataset(
            data_path,
            frame_rate=10,
            sdi=2,
            snippet_length_s=snippet_length_s,
            stride_length_s=stride_length_s,
            max_snippets=max_snip,
            skip_snippets=skip_snips,
            preprocess=preprocess_inference,
        )

        # (optional) use the ATEK data loader if it is installed
        # from efm3d.dataset.atek_vrs_dataset import create_atek_raw_data_loader_from_vrs_path
        # streamer = create_atek_raw_data_loader_from_vrs_path(
        #     vrs_path=data_path,
        #     freq_hz=10,
        #     snippet_length_s=snippet_length_s,
        #     stride_length_s=stride_length_s,
        #     skip_begin_seconds=20.0,
        #     skip_end_seconds=5.0,
        #     max_snippets=max_snip,
        # )
    else:
        print(
            f"Input error {data_path}, expect the input to be a folder to WDS tars or a .vrs file"
        )
        # Same effect as the interactive-only `exit(-1)` helper, but does not
        # depend on the `site` module having installed it.
        raise SystemExit(-1)
    return streamer


def create_output_dir(output_dir, model_ckpt, data_path):
    """
    Create (if needed) and return `<output_dir>/<model_name>/<seq_name>`.

    `model_name` is the checkpoint file name without extension; `seq_name` is
    derived from `data_path` the same way as in `get_gt_mesh_ply`.
    """
    model_name = os.path.basename(os.path.splitext(model_ckpt)[0])
    seq_name = _sequence_name(data_path)
    output_dir = os.path.join(output_dir, model_name, seq_name)
    os.makedirs(output_dir, exist_ok=True)
    return output_dir


def run_one(
    data_path,
    model_ckpt,
    model_cfg,
    max_snip=9999,
    snip_stride=0.1,
    voxel_res=0.04,
    output_dir="./output",
    obb_only=False,
    skip_video=False,
    skip_snips=0,
):
    """
    Run the full inference pipeline on one sequence.

    Steps: load the model, run per-snippet inference, track obbs, evaluate obbs
    against GT (when both csv files exist), fuse and evaluate the mesh (unless
    `obb_only`), write `metrics.json`, render a video (unless `skip_video`) and
    finally delete the `per_snip` folder. Tracking and obb evaluation are
    best-effort: a failure is reported on stdout and the pipeline continues.

    Args:
        data_path: `.vrs` file or ATEK WDS folder, see `create_streamer`.
        model_ckpt: checkpoint path; must contain a "state_dict" entry.
        model_cfg: config file instantiated with hydra to build the model.
        max_snip: maximum number of snippets to process.
        snip_stride: stride between snippets in seconds.
        voxel_res: voxel resolution passed to `VolumetricFusion`.
        output_dir: root folder; results go to a model/sequence subfolder.
        obb_only: skip volumetric fusion and mesh evaluation.
        skip_video: skip rendering the visualization video.
        skip_snips: number of snippets to skip (forwarded to the streamer).
    """
    output_dir = create_output_dir(output_dir, model_ckpt, data_path)

    # create model
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    checkpoint = torch.load(model_ckpt, weights_only=True, map_location=device)
    model_config = omegaconf.OmegaConf.load(model_cfg)
    model = hydra.utils.instantiate(model_config)
    model.load_state_dict(checkpoint["state_dict"], strict=True)
    model.to(device)
    model.eval()
    print("model init done")

    # create dataset
    streamer = create_streamer(
        data_path,
        snippet_length_s=1.0,
        stride_length_s=snip_stride,
        max_snip=max_snip,
        skip_snips=skip_snips,
    )

    # per-snippet inference
    efm_inf = EfmInference(
        streamer, model, output_dir, device=device, zip=False, obb_only=obb_only
    )
    efm_inf.run()
    # drop the references before the later stages allocate their own memory
    del efm_inf
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # track obbs (best effort). ImportError is reported as a missing dependency;
    # any other failure is reported with its cause instead of being mislabeled.
    # KeyboardInterrupt / SystemExit are deliberately not caught.
    try:
        from efm3d.inference.track import track_obbs

        track_obbs(output_dir)
    except ImportError:
        print("Skip tracking obb due to missing dependency, please see INSTALL.md")
    except Exception as err:
        print(f"Skip tracking obb due to error: {err!r}")

    # eval obb (best effort, same policy as tracking)
    metrics = {}
    pred_csv = os.path.join(output_dir, "tracked_scene_obbs.csv")
    gt_csv = os.path.join(output_dir, "gt_scene_obbs.csv")
    if os.path.exists(pred_csv) and os.path.exists(gt_csv):
        try:
            from efm3d.inference.eval import evaluate_obb_csv

            obb_metrics = evaluate_obb_csv(pred_csv=pred_csv, gt_csv=gt_csv, iou=0.2)
            metrics.update(obb_metrics)
        except ImportError:
            print(
                "Skip obb evaluation due to missing dependency, please see INSTALL.md"
            )
        except Exception as err:
            print(f"Skip obb evaluation due to error: {err!r}")

    vol_fusion = None
    if not obb_only:
        # fuse mesh
        vol_fusion = VolumetricFusion(output_dir, voxel_res=voxel_res, device=device)
        vol_fusion.run()
        fused_mesh = vol_fusion.get_trimesh()
        pred_mesh_ply = os.path.join(output_dir, "fused_mesh.ply")
        # only export non-empty meshes
        if fused_mesh.vertices.shape[0] > 0 and fused_mesh.faces.shape[0] > 0:
            fused_mesh.export(pred_mesh_ply)

        # eval mesh
        gt_mesh_ply = get_gt_mesh_ply(data_path)
        if os.path.exists(pred_mesh_ply) and os.path.exists(gt_mesh_ply):
            pred_trimesh = trimesh.load(pred_mesh_ply)
            gt_trimesh = trimesh.load(gt_mesh_ply)
            if "adt" in gt_mesh_ply:
                gt_trimesh = correct_adt_mesh_gravity(gt_trimesh)
            mesh_metrics, _, _ = eval_mesh_to_mesh(
                pred=pred_trimesh,
                gt=gt_trimesh,
                sample_num=1000,
            )
            metrics.update(mesh_metrics)
    else:
        print("Skipping volume fusion (--obb_only)")

    # write metrics
    if len(metrics) > 0:
        with open(os.path.join(output_dir, "metrics.json"), "w") as f:
            json.dump(metrics, f, indent=2, sort_keys=True)
        print(json.dumps(metrics, indent=2, sort_keys=True))

    # viz
    if not skip_video:
        # the video streams non-overlapping 1s snippets, so the snippet counts
        # given in units of `snip_stride` are converted to seconds
        streamer = create_streamer(
            data_path=data_path,
            snippet_length_s=1.0,
            stride_length_s=1.0,
            max_snip=math.ceil((max_snip - 1) * snip_stride),
            skip_snips=int(skip_snips * snip_stride),
        )
        if vol_fusion is not None:
            vol_fusion.reinit()
        viz_path = generate_video(
            streamer, output_dir=output_dir, vol_fusion=vol_fusion, stride_s=snip_stride
        )
        print(f"output viz file to {os.path.abspath(viz_path)}")

    # rm per-snippet occupancy tensors
    per_snip_dir = os.path.join(output_dir, "per_snip")
    if os.path.exists(per_snip_dir):
        shutil.rmtree(per_snip_dir)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# logger.setLevel(logging.DEBUG)


def track_obbs(input_path, prob_inst_thr=0.3, prob_assoc_thr=0.25):
    """
    Run ObbTracker on input csv file.

    Writes two files next to the input csv: `tracked_obbs.csv` (tracked obbs per
    input timestamp) and `tracked_scene_obbs.csv` (final scene-level obbs,
    written with timestamp -1). If `input_path` does not exist, an error is
    logged and the function returns without writing anything.

    input_path: path to input folder or obbs csv file. if folder, will look for
        'snippet_obbs.csv' as the input obbs csv file.
    prob_inst_thr: minimum probability threshold for instantiating a new world obb
    prob_assoc_thr: minimum probability threshold for associating a new obb with
        existing world obbs
    """
    if not os.path.exists(input_path):
        logger.error(f"Input folder {input_path} does not exist")
        return

    if input_path.endswith(".csv"):
        obb_csv_path = input_path
        obb_folder = os.path.dirname(input_path)
    else:
        obb_csv_path = os.path.join(input_path, "snippet_obbs.csv")
        obb_folder = input_path
    assert os.path.exists(obb_csv_path), f"No obb csv file found {obb_csv_path}"

    tracked_obbs_path = os.path.join(obb_folder, "tracked_obbs.csv")
    reader = ObbCsvReader(obb_csv_path)
    writer = ObbCsvWriter(tracked_obbs_path)

    tracker = ObbTracker(
        track_best=False,
        track_running_average=True,
        max_assoc_dist=0.1,
        max_assoc_iou2=0.0,  # disabled
        max_assoc_iou3=0.2,
        prob_inst_thr=prob_inst_thr,
        prob_assoc_thr=prob_assoc_thr,
        nms_iou3_thr=0.1,
        nms_iou2_thr=0.0,  # disabled
        w_max=30,
        w_min=5,
        dt_max_inst=1.0,
        dt_max_occ=999999.0,  # never delete
    )

    # write snippet-level tracked obbs
    for t_ns, obbs in reader:
        tracked_obbs, unviz_obbs = tracker.track(obbs)
        # seq_obb_eval use both tracked and unviz obbs
        all_tracked_obbs = torch.cat([tracked_obbs, unviz_obbs], dim=-2)
        writer.write(all_tracked_obbs, t_ns, reader.sem_ids_to_names)

    # write scene-level tracked obbs
    tracked_scene_obbs_path = os.path.join(obb_folder, "tracked_scene_obbs.csv")
    scene_writer = ObbCsvWriter(tracked_scene_obbs_path)
    final_scene_obbs, unviz_obbs = tracker.obbs_world
    final_scene_obbs_all = torch.cat([final_scene_obbs, unviz_obbs], dim=-2)
    scene_writer.write(final_scene_obbs_all, -1, reader.sem_ids_to_names)
    logger.info(f"Wrote scene-level tracked obbs to {tracked_scene_obbs_path}")


def main():
    """Command-line entry point: parse the arguments and run `track_obbs`."""
    import argparse

    parser = argparse.ArgumentParser(description="Run Obb tracker on obbs csv file.")
    parser.add_argument(
        "--input",
        type=str,
        help="The input folder to look for the per-snippet obbs csv file",
        required=True,
    )
    parser.add_argument(
        "--prob_inst_thr",
        type=float,
        default=0.3,
        help="minimum probability threshold for instantiating a new world obb",
    )
    parser.add_argument(
        "--prob_assoc_thr",
        type=float,
        default=0.25,
        help="minimum probability threshold for associating a new obb with existing world obbs",
    )
    args = parser.parse_args()
    # forward the thresholds; previously they were parsed but silently ignored
    track_obbs(
        input_path=args.input,
        prob_inst_thr=args.prob_inst_thr,
        prob_assoc_thr=args.prob_assoc_thr,
    )


if __name__ == "__main__":
    main()
VIZ_RGB = "RGB/GT"
VIZ_SLAM = "SLAM"
VIZ_PRED_OBB = "Snippet Prediction"
VIZ_TRACKED_OBB = "Tracked Prediction"
VIZ_GT_OBB = "Ground Truth"


def find_nearest(array, value):
    """
    Find the index of the nearest value in a sorted array.

    `array` must be sorted in ascending order (it is searched with bisection).
    On a tie the earlier index is returned. For an empty array the result is -1,
    so callers are expected to handle the empty case before calling.
    """
    idx = bisect_left(array, value)
    if idx == len(array):
        return idx - 1
    if idx == 0:
        return 0
    before = array[idx - 1]
    after = array[idx]
    if after - value < value - before:
        return idx
    return idx - 1


def fill_obbs_to_snippet(obbs, rgb_ts, T_ws):
    """
    Pick one set of obbs for every RGB timestamp and move it to snippet frame.

    obbs: dict mapping timestamp (ns) to obbs; the result is transformed with
        `T_ws.inverse()`, i.e. the obbs are presumably given in world frame.
    rgb_ts: list of RGB frame timestamps in ns.
    T_ws: transform whose inverse is applied to the stacked obbs.

    A timestamp without an exact match uses the nearest entry if it is less than
    1s away, otherwise an empty set. Every entry is padded to 128 obbs so that
    the per-frame results can be stacked.
    """
    pad_size = 128
    obbs_out = []
    obbs_ts = sorted(obbs)
    for ts in rgb_ts:
        if ts in obbs:
            obbs_out.append(obbs[ts].add_padding(pad_size))
        elif len(obbs_ts) == 0:
            obbs_out.append(ObbTW().add_padding(pad_size))
        else:
            # find the nearest timestamp within 1s
            nidx = find_nearest(obbs_ts, ts)
            if abs(obbs_ts[nidx] - ts) / 1e9 < 1:
                obbs_out.append(obbs[obbs_ts[nidx]].add_padding(pad_size))
            else:
                obbs_out.append(ObbTW().add_padding(pad_size))
    obbs_w = torch.stack(obbs_out, dim=0)
    obbs_s = obbs_w.transform(T_ws.inverse())
    return obbs_s


def compose_views(view_dict, keys, vertical=True):
    """
    Stack snippet images into a single image, vertical or horizontal.

    view_dict: dict mapping view name to a list of per-frame images.
    keys: view names to stack, in order; names missing from `view_dict` are
        ignored.

    Returns None if no key is present, the original list (not a copy) if exactly
    one key is present, otherwise a new list with one stacked image per frame.
    """
    keys = [k for k in keys if k in view_dict]
    if len(keys) == 0:
        return None
    if len(keys) == 1:
        return view_dict[keys[0]]

    axis = 0 if vertical else 1
    num_frames = len(view_dict[keys[0]])
    output_imgs = []
    for i in range(num_frames):
        img_list = [view_dict[key][i] for key in keys]
        output_imgs.append(np.concatenate(img_list, axis=axis))
    return output_imgs


def draw_scene_with_mesh_and_obbs(
    snippet,
    w,
    h,
    scene,
    snip_obbs=None,
    tracked_obbs=None,
    gt_obbs=None,
    mesh=None,
    sem_ids_to_names=None,
):
    """
    Draw 3d scene view of a snippet, with optionally obbs and mesh.

    Note that the given obbs and mesh are written into `snippet` (the dict is
    modified in place) before it is passed to `draw_snippet_scene_3d`; the
    entries added here are later read by `render_views`.
    """
    # put pred obbs into the snippet
    rgb_ts = snippet[ARIA_IMG_TIME_NS[0]]
    rgb_ts = [ts.item() for ts in rgb_ts]
    T_ws = snippet[ARIA_SNIPPET_T_WORLD_SNIPPET]
    if snip_obbs is not None:
        snippet[ARIA_OBB_PRED_VIZ] = fill_obbs_to_snippet(snip_obbs, rgb_ts, T_ws)
    if tracked_obbs is not None:
        snippet[ARIA_OBB_TRACKED] = fill_obbs_to_snippet(tracked_obbs, rgb_ts, T_ws)
    if gt_obbs is not None:
        snippet[ARIA_OBB_PADDED] = fill_obbs_to_snippet(gt_obbs, rgb_ts, T_ws)

    # only attach non-empty meshes
    if mesh is not None and mesh.vertices.shape[0] > 0 and mesh.faces.shape[0] > 0:
        snippet[ARIA_MESH_VERTS_W] = torch.tensor(mesh.vertices)
        snippet[ARIA_MESH_FACES] = torch.tensor(mesh.faces)
        # normals for pred should be minus due to marching cube
        snippet[ARIA_MESH_VERT_NORMS_W] = -torch.tensor(mesh.vertex_normals)

    scene_imgs = draw_snippet_scene_3d(
        snippet, sem_ids_to_names=sem_ids_to_names, width=w, height=h, scene=scene
    )
    return scene_imgs


def render_views(snippet, h, w, pred_sem_ids_to_names, gt_sem_ids_to_names):
    """
    Render the 2d views of a snippet.

    Returns a dict mapping view title (one of the `VIZ_*` constants) to a list
    of per-frame images. Each image is resized to (h, w) and labeled with the
    view title and the frame time. The RGB view is always present; the SLAM and
    obb views depend on which entries exist in `snippet`. If GT obbs are present
    they are drawn into the RGB view instead of a separate view.
    """
    Ts_sr = snippet[ARIA_IMG_T_SNIPPET_RIG[0]]
    cams = snippet[ARIA_CALIB[0]]
    T_ws = snippet[ARIA_SNIPPET_T_WORLD_SNIPPET]
    Ts_wr = T_ws @ Ts_sr
    rgb_ts = snippet[ARIA_IMG_TIME_NS[0]]
    time_s = [f"{ts.item() * 1e-9:.02f}s" for ts in rgb_ts]

    def to_cv2_list(stream_key):
        # convert one image stream of the snippet to a list of cv2 images
        frames = snippet[stream_key].clone().numpy()
        return [
            torch2cv2(im, rotate=True, ensure_rgb=True, rgb2bgr=False)
            for im in frames
        ]

    imgs = {}

    # RGB and SLAM
    imgs[VIZ_RGB] = to_cv2_list(ARIA_IMG[0])
    if ARIA_IMG[1] in snippet and ARIA_IMG[2] in snippet:
        slaml_imgs = to_cv2_list(ARIA_IMG[1])
        slamr_imgs = to_cv2_list(ARIA_IMG[2])
        # left and right SLAM images side by side
        imgs[VIZ_SLAM] = [
            np.concatenate([iml, imr], axis=1)
            for iml, imr in zip(slaml_imgs, slamr_imgs)
        ]

    # drawing options shared by all obb overlays
    draw_opts = dict(
        rgb2bgr=False,
        draw_cosy=False,
        white_backing_line=False,
        draw_bb2=False,
        draw_label=True,
        draw_score=True,
    )

    if ARIA_OBB_PRED_VIZ in snippet:
        imgs[VIZ_PRED_OBB] = draw_obbs_snippet(
            snippet[ARIA_IMG[0]].clone(),
            snippet[ARIA_OBB_PRED_VIZ].transform(T_ws),
            Ts_wr,
            cams,
            sem_id_to_name_mapping=pred_sem_ids_to_names,
            prob_threshold=0.001,  # keep this very low, obbs are already thresholded.
            **draw_opts,
        )

    if ARIA_OBB_TRACKED in snippet:
        imgs[VIZ_TRACKED_OBB] = draw_obbs_snippet(
            snippet[ARIA_IMG[0]].clone(),
            snippet[ARIA_OBB_TRACKED].transform(T_ws),
            Ts_wr,
            cams,
            sem_id_to_name_mapping=pred_sem_ids_to_names,
            prob_threshold=0.001,  # keep this very low, obbs are already thresholded.
            **draw_opts,
        )

    if ARIA_OBB_PADDED in snippet:
        # if gt obb (VIZ_GT_OBB) is present, overlay it on top of the RGB view
        imgs[VIZ_RGB] = draw_obbs_snippet(
            snippet[ARIA_IMG[0]].clone(),
            snippet[ARIA_OBB_PADDED].transform(T_ws),
            Ts_wr,
            cams,
            sem_id_to_name_mapping=gt_sem_ids_to_names,
            draw_inst_id=True,
            **draw_opts,
        )

    # add text to the images
    for text, grid_imgs in imgs.items():
        for i, img in enumerate(grid_imgs):
            img = smart_resize(img, h, w, pad_image=True)
            img = put_text(img, text)
            imgs[text][i] = put_text(img, time_s[i], line=-1)
    return imgs


def generate_video(
    streamer,
    output_dir,
    fps=10,
    vol_fusion: "Optional[VolumetricFusion]" = None,
    stride_s: float = 0.1,
):
    """
    Render `<output_dir>/video.mp4` and return its path.

    streamer: the data iterator, assuming input snippets are 1s at 10 FPS.
    output_dir: the output folder for the video, will also load obbs and per_snip
        artifacts from the same folder
    fps: the output video fps
    vol_fusion: A volumetric fusion class instance. If not None, will use it to
        show the incremental mesh, updated as 1s frame rate.
    stride_s: stride in seconds between the per-snippet results used by
        `vol_fusion`; `int(1.0 / stride_s)` fusion steps are run for every
        streamed snippet.

    The video writer is released even if rendering fails part-way.
    """
    # read snippet obbs
    snip_obbs_csv = os.path.join(output_dir, "snippet_obbs.csv")
    snip_obbs = None
    sem_ids_to_names = None
    if os.path.exists(snip_obbs_csv):
        snip_obb_reader = ObbCsvReader(snip_obbs_csv)
        snip_obbs = snip_obb_reader.obbs
        sem_ids_to_names = snip_obb_reader.sem_ids_to_names

    # read tracked obbs
    tracked_obbs_csv = os.path.join(output_dir, "tracked_obbs.csv")
    tracked_obbs = None
    if os.path.exists(tracked_obbs_csv):
        tracked_obb_reader = ObbCsvReader(tracked_obbs_csv)
        tracked_obbs = tracked_obb_reader.obbs

    # read GT obbs
    gt_obbs_csv = os.path.join(output_dir, "gt_obbs.csv")
    gt_obbs = None
    gt_sem_ids_to_names = None
    if os.path.exists(gt_obbs_csv):
        gt_obb_reader = ObbCsvReader(gt_obbs_csv)
        gt_obbs = gt_obb_reader.obbs
        gt_sem_ids_to_names = gt_obb_reader.sem_ids_to_names

    # read fused mesh
    fused_mesh = os.path.join(output_dir, "fused_mesh.ply")
    pred_mesh = None
    if os.path.exists(fused_mesh):
        pred_mesh = trimesh.load(fused_mesh)

    # write video
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    output_path = os.path.join(output_dir, "video.mp4")

    # two columns for 2d views (RGB+SLAM, output), 1 column for 3d scene
    gW, gH = 360, 360  # 2d grid size
    sH = 2 * gH
    sW = sH
    W = sW + 2 * gW
    H = sH
    out = cv2.VideoWriter(output_path, fourcc, fps, (W, H))
    try:
        scene = SceneView(width=sW, height=sH)

        num_snip_per_s = int(1.0 / stride_s)
        for idx, snippet in tqdm.tqdm(enumerate(streamer), total=len(streamer)):
            # show incremental fusion if vol_fusion is given
            if vol_fusion is not None:
                for i in range(num_snip_per_s):
                    vol_fusion.run_step(idx * num_snip_per_s + i)
                pred_mesh = vol_fusion.get_trimesh()

            scene_imgs = draw_scene_with_mesh_and_obbs(
                snippet,
                w=sW,
                h=sH,
                scene=scene,
                snip_obbs=snip_obbs,
                tracked_obbs=tracked_obbs,
                gt_obbs=gt_obbs,
                mesh=pred_mesh,
                sem_ids_to_names=sem_ids_to_names,
            )
            view_imgs = render_views(
                snippet,
                gH,
                gW,
                pred_sem_ids_to_names=sem_ids_to_names,
                gt_sem_ids_to_names=gt_sem_ids_to_names,
            )
            input_col = compose_views(view_imgs, [VIZ_RGB, VIZ_SLAM])
            output_col = compose_views(view_imgs, [VIZ_PRED_OBB, VIZ_TRACKED_OBB])

            # frame layout: [input column | 3d scene | output column]
            for i, scene_img in enumerate(scene_imgs):
                final_img = np.zeros((H, W, 3), dtype=np.uint8)  # black background
                h, w = input_col[i].shape[:2]
                final_img[:h, :w] = input_col[i]
                final_img[:sH, gW : gW + sW, :] = scene_img
                if output_col is not None:
                    h, w = output_col[i].shape[:2]
                    final_img[:h, gW + sW : gW + sW + w] = output_col[i]
                out.write(final_img[:, :, ::-1])  # convert rgb to bgr before writing
    finally:
        # always finalize the file, also when a frame fails to render
        out.release()
    return output_path
def cnn_weight_initialization(modules):
    """
    Initialize the weights of the given modules in place.

    Conv2d and Linear weights get Kaiming-uniform initialization, their biases
    (if present) are set to 0; BatchNorm2d is reset to weight 1 / bias 0. Only
    these three module types are matched: any other module (including
    nn.Conv3d and nn.BatchNorm3d) keeps the initialization it already has.
    """
    for m in modules:
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_uniform_(m.weight.data, nonlinearity="relu")
            if m.bias is not None:
                nn.init.constant_(m.bias.data, 0)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight.data, 1)
            nn.init.constant_(m.bias.data, 0)
        elif isinstance(m, nn.Linear):
            nn.init.kaiming_uniform_(m.weight.data)
            # Linear layers created with bias=False have no bias to reset
            if m.bias is not None:
                nn.init.constant_(m.bias.data, 0)


class GELU(nn.Module):
    """Module wrapper around `F.gelu`."""

    def forward(self, x):
        return F.gelu(x)


class LayerNorm2d(nn.LayerNorm):
    """LayerNorm for channels of '2D' spatial NCHW tensors, taken from
    https://github.com/huggingface/pytorch-image-models/blob/d7b55a9429f3d56a991e604cbc2e9fdf1901612f/timm/models/layers/norm.py#L26
    """

    def __init__(self, num_channels, eps=1e-6, affine=True):
        super().__init__(num_channels, eps=eps, elementwise_affine=affine)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NCHW -> NHWC so that the channel dim is normalized, then back to NCHW
        return F.layer_norm(
            x.permute(0, 2, 3, 1),
            self.normalized_shape,
            self.weight,
            self.bias,
            self.eps,
        ).permute(0, 3, 1, 2)


class UpsampleCNN(nn.Module):
    def __init__(
        self,
        input_dim: int = 3,
        first_hidden_dim: int = 32,
        final_dim: int = 1,
        upsample_power: int = 4,
        fix_hidden_dim: bool = True,
    ):
        """
        Upsample a feature map by a given factor = 2^upsample_power

        Args:
            input_dim (int): number of input channels
            first_hidden_dim (int): the first hidden layer output dimension. If set
                to a non-positive value (e.g. -1), we use the input dimension.
            final_dim (int): number of output channels
            upsample_power (int): 2^upsample_power is the factor of image resolution
                upsampling
            fix_hidden_dim (bool): if True, all layers have the same hidden dims.
                Otherwise, hidden dims are subsequently halved by 2x starting from
                first_hidden_dim
        """
        super().__init__()
        assert upsample_power <= 4, "only upsampling power <= 4 is supported"

        # Resolve the "use the input dimension" sentinel for both modes; it used
        # to be honored only when fix_hidden_dim was False.
        first_hidden_dim = first_hidden_dim if first_hidden_dim > 0 else input_dim

        if fix_hidden_dim:
            # all layers have the same hidden dims
            c = [first_hidden_dim] * (upsample_power + 1)
        else:
            assert first_hidden_dim // 2 ** (upsample_power) >= 1, (
                f"first_hidden_dim must be at least {2 ** (upsample_power)}, but got {first_hidden_dim}."
            )
            # subsequently halve the hidden dim by 2x
            c = [first_hidden_dim] + [
                first_hidden_dim // 2 ** (i + 1) for i in range(upsample_power)
            ]

        self.conv1 = nn.Conv2d(input_dim, c[0], kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(c[0])
        # one conv + bn pair per upsampling stage
        if upsample_power >= 1:
            self.conv1u = nn.Conv2d(c[0], c[1], kernel_size=3, stride=1, padding=1)
            self.bn1u = nn.BatchNorm2d(c[1])
        if upsample_power >= 2:
            self.conv2u = nn.Conv2d(c[1], c[2], kernel_size=3, stride=1, padding=1)
            self.bn2u = nn.BatchNorm2d(c[2])
        if upsample_power >= 3:
            self.conv3u = nn.Conv2d(c[2], c[3], kernel_size=3, stride=1, padding=1)
            self.bn3u = nn.BatchNorm2d(c[3])
        if upsample_power >= 4:
            self.conv4u = nn.Conv2d(c[3], c[4], kernel_size=3, stride=1, padding=1)
            self.bn4u = nn.BatchNorm2d(c[4])
        self.conv_final = nn.Conv2d(
            c[-1], final_dim, kernel_size=1, stride=1, padding=0
        )
        self.relu = nn.ReLU(inplace=True)
        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.upsample_power = upsample_power
        cnn_weight_initialization(self.modules())
        print(f"==> [UpsampleCNN]: intialized with hidden layers: {c}")

    def forward(self, x, force_hw=None):
        """
        Inputs:
            x : torch.Tensor : Bx(T)xCxhxw tensor
            force_hw: (int, int) : tuple of ints of height and width to be forced
                upsampled to
        Returns:
            x : torch.Tensor: Upsampled to Bx(T)xCxHxW, where
                H = h * 2**upsample_power and W = w * 2**upsample_power
                (or H, W = force_hw if given)
        """
        ndim = x.ndim
        if ndim == 5:
            # fold the time dim into the batch dim for the 2d convolutions
            T = x.shape[1]
            x = einops.rearrange(x, "b t c h w -> (b t) c h w")

        x = self.relu(self.bn1(self.conv1(x)))
        if self.upsample_power >= 1:
            x = self.upsample(x)
            x = self.relu(self.bn1u(self.conv1u(x)))
        if self.upsample_power >= 2:
            x = self.upsample(x)
            x = self.relu(self.bn2u(self.conv2u(x)))
        if self.upsample_power >= 3:
            x = self.upsample(x)
            x = self.relu(self.bn3u(self.conv3u(x)))
        if self.upsample_power >= 4:
            x = self.upsample(x)
            x = self.relu(self.bn4u(self.conv4u(x)))

        # Force upsampling, useful for patch_size=14 ViTs for example.
        if force_hw is not None and (
            x.shape[-2] != force_hw[0] or x.shape[-1] != force_hw[1]
        ):
            x = torch.nn.functional.interpolate(x, size=force_hw, mode="bilinear")

        x = self.conv_final(x)

        if ndim == 5:
            x = einops.rearrange(x, "(b t) c h w -> b t c h w", t=T)
        return x


class LayerNorm3d(nn.LayerNorm):
    """LayerNorm for channels of '3D' spatial NCDHW tensors, taken from
    https://github.com/huggingface/pytorch-image-models/blob/d7b55a9429f3d56a991e604cbc2e9fdf1901612f/timm/models/layers/norm.py#L26
    """

    def __init__(self, num_channels, eps=1e-6, affine=True):
        super().__init__(num_channels, eps=eps, elementwise_affine=affine)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NCDHW -> NDHWC so that the channel dim is normalized, then back
        return F.layer_norm(
            x.permute(0, 2, 3, 4, 1),
            self.normalized_shape,
            self.weight,
            self.bias,
            self.eps,
        ).permute(0, 4, 1, 2, 3)


class UpConv3d(torch.nn.Module):
    """2x trilinear upsampling followed by a 3x3x3 conv and batch norm."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.upsample = torch.nn.Upsample(
            scale_factor=2, mode="trilinear", align_corners=True
        )
        self.cnn_up = torch.nn.Conv3d(dim_in, dim_out, 3, stride=1, padding=1)
        self.norm = torch.nn.BatchNorm3d(dim_out)
        cnn_weight_initialization(self.modules())

    def forward(self, x_up):
        assert x_up.shape[1] == self.dim_in, f"{x_up.shape}, {self.dim_in}"
        x_up = self.upsample(x_up)
        return self.norm(self.cnn_up(x_up))


class FpnUpConv3d(torch.nn.Module):
    """
    FPN-style merge: upsample `x_up` by 2x, add the 1x1x1-projected lateral
    feature `x_lat`, then apply a 3x3x3 conv and batch norm.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.upsample = torch.nn.Upsample(
            scale_factor=2, mode="trilinear", align_corners=True
        )
        self.cnn_up = torch.nn.Conv3d(dim_in, dim_out, 3, stride=1, padding=1)
        # lateral projection maps dim_out channels to dim_in so it can be added
        # to the upsampled feature before `cnn_up`
        self.cnn_lat = torch.nn.Conv3d(dim_out, dim_in, 1)
        self.norm = torch.nn.BatchNorm3d(dim_out)
        cnn_weight_initialization(self.modules())

    def forward(self, x_up, x_lat):
        assert x_up.shape[1] == self.dim_in, f"{x_up.shape}, {self.dim_in}"
        assert x_lat.shape[1] == self.dim_out, f"{x_lat.shape}, {self.dim_out}"
        x_up = self.upsample(x_up)
        x_lat = self.cnn_lat(x_lat)
        return self.norm(self.cnn_up(x_up + x_lat))


class InvBottleNeck3d(torch.nn.Module):
    """
    1x1x1 -> 3x3x3 -> 1x1x1 conv bottleneck with batch norm on the output.

    The hidden width is `floor(dim_in * expansion)`. The residual connection is
    only used when the block keeps the shape (stride 1 and dim_in == dim_out).
    """

    def __init__(self, dim_in, dim_out, stride: int = 1, expansion: float = 1.0):
        super().__init__()
        self.dim_hidden = int(math.floor(dim_in * expansion))
        self.stride = stride
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.relu = torch.nn.ReLU()
        self.norm = torch.nn.BatchNorm3d(self.dim_out)
        self.cnn1 = torch.nn.Conv3d(dim_in, self.dim_hidden, 1)
        self.cnn2 = torch.nn.Conv3d(
            self.dim_hidden, self.dim_hidden, 3, stride=stride, padding=1
        )
        self.cnn3 = torch.nn.Conv3d(self.dim_hidden, dim_out, 1)
        cnn_weight_initialization(self.modules())

    def forward(self, x):
        y = self.relu(self.cnn1(x))
        y = self.relu(self.cnn2(y))
        y = self.cnn3(y)
        if self.stride != 1 or self.dim_in != self.dim_out:
            # shapes differ, no residual possible
            return self.norm(y)
        return self.norm(y + x)


class InvResnetBlock3d(torch.nn.Module):
    """
    Sequence of `num_bottles` InvBottleNeck3d blocks. Only the first block
    changes the channel count and applies `in_stride`.
    """

    def __init__(
        self, dim_in, dim_out, num_bottles, in_stride: int = 1, expansion: float = 1.0
    ):
        super().__init__()
        self.inv_bottles = torch.nn.ModuleList(
            [InvBottleNeck3d(dim_in, dim_out, in_stride, expansion)]
        )
        for _ in range(1, num_bottles):
            self.inv_bottles.append(InvBottleNeck3d(dim_out, dim_out, 1, expansion))
        self.num_bottles = num_bottles

    def forward(self, x):
        for i in range(self.num_bottles):
            x = self.inv_bottles[i](x)
        return x


class InvResnetFpn3d(torch.nn.Module):
    """
    Four InvResnetBlock3d stages followed by three FpnUpConv3d merges.

    `dims` holds the 5 channel counts (input + one per stage); `num_bottles`,
    `strides` and `expansions` hold one entry per stage. The first stride must
    be 1 and the remaining ones 2. The output has `dims[1]` channels.
    With `freeze=True` all parameters stop requiring gradients and the module is
    put in eval mode at construction time.
    """

    def __init__(self, dims, num_bottles, strides, expansions, freeze=False):
        super().__init__()
        assert len(dims) == len(num_bottles) + 1
        assert len(dims) == len(strides) + 1
        assert len(dims) == len(expansions) + 1
        assert strides[0] == 1
        assert all([s == 2 for s in strides[1:]])

        self.block1 = InvResnetBlock3d(
            dims[0], dims[1], num_bottles[0], strides[0], expansions[0]
        )
        self.block2 = InvResnetBlock3d(
            dims[1], dims[2], num_bottles[1], strides[1], expansions[1]
        )
        self.block3 = InvResnetBlock3d(
            dims[2], dims[3], num_bottles[2], strides[2], expansions[2]
        )
        self.block4 = InvResnetBlock3d(
            dims[3], dims[4], num_bottles[3], strides[3], expansions[3]
        )
        self.fpn1 = FpnUpConv3d(dims[2], dims[1])
        self.fpn2 = FpnUpConv3d(dims[3], dims[2])
        self.fpn3 = FpnUpConv3d(dims[4], dims[3])

        if freeze:
            for param in self.parameters():
                param.requires_grad = False
            self.eval()

    def forward(self, x):
        x1 = self.block1(x)
        x2 = self.block2(x1)
        x3 = self.block3(x2)
        x = self.block4(x3)
        # drop each lateral feature as soon as it has been merged
        x = self.fpn3(x, x3)
        del x3
        x = self.fpn2(x, x2)
        del x2
        x = self.fpn1(x, x1)
        del x1
        return x


class VolumeCNN(nn.Module):
    """A 3d UNet structure which takes in a `hidden_dims` vector (e.g. [c0, c1, c2, c3],
    c0 <= c1 <= c2 <= c3). It outputs a shared feature layer with ReLU and BN applied.
    The shape on the channel dimension looks like c0->c1->c2->c3->c2->c1
    (the output has `out_dim` = c1 channels).
    """

    def __init__(self, hidden_dims, conv3=nn.Conv3d, freeze=False):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)
        self.upsample = nn.Upsample(
            scale_factor=2, mode="trilinear", align_corners=True
        )
        c0, c1, c2, c3 = tuple(hidden_dims)
        self.conv1 = conv3(c0, c1, kernel_size=3, stride=1, padding=1)
        self.conv2 = conv3(c1, c2, kernel_size=3, stride=1, padding=1)
        self.conv3 = conv3(c2, c3, kernel_size=3, stride=1, padding=1)
        # decoder convs take the upsampled feature concatenated with the skip
        self.conv2u = conv3(c2 + c3, c2, kernel_size=3, stride=1, padding=1)
        self.conv1u = conv3(c1 + c2, c1, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm3d(c1)
        self.bn2 = nn.BatchNorm3d(c2)
        self.bn3 = nn.BatchNorm3d(c3)
        self.bn2u = nn.BatchNorm3d(c2)
        self.bn1u = nn.BatchNorm3d(c1)
        cnn_weight_initialization(self.modules())
        self.out_dim = c1

        if freeze:
            for param in self.parameters():
                param.requires_grad = False
            self.eval()

    def forward(self, x):
        # Simple U-Net like structure.
        conv1 = self.relu(self.bn1(self.conv1(x)))
        x = self.pool(conv1)
        conv2 = self.relu(self.bn2(self.conv2(x)))
        x = self.pool(conv2)
        x = self.relu(self.bn3(self.conv3(x)))
        x = self.upsample(x)
        x = torch.cat([x, conv2], dim=1)
        x = self.relu(self.bn2u(self.conv2u(x)))
        x = self.upsample(x)
        x = torch.cat([x, conv1], dim=1)
        x = self.relu(self.bn1u(self.conv1u(x)))
        return x


class VolumeCNNHead(nn.Module):
    """
    Prediction head made of 2, 3 or 4 3d conv layers.

    All but the last layer are 3x3x3 convs with batch norm and ReLU; the last
    layer is a 1x1x1 conv producing `final_dim` channels without activation.
    """

    def __init__(
        self,
        input_dim,
        hidden_dim,
        final_dim,
        num_layers=2,
        name="",
        bias=None,
        freeze=False,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.relu = nn.ReLU(inplace=True)
        assert num_layers in [2, 3, 4], f"num_layers {num_layers} must be 2, 3, or 4"

        # first conv layer is the same for all num_layers = {2,3,4}
        self.conv1 = torch.nn.Conv3d(
            input_dim, hidden_dim, kernel_size=3, stride=1, padding=1
        )
        self.bn1 = nn.BatchNorm3d(hidden_dim)
        if num_layers == 2:
            self.conv2 = torch.nn.Conv3d(hidden_dim, final_dim, kernel_size=1)
        elif num_layers == 3:
            self.conv2 = torch.nn.Conv3d(
                hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1
            )
            self.conv3 = torch.nn.Conv3d(hidden_dim, final_dim, kernel_size=1)
            self.bn2 = nn.BatchNorm3d(hidden_dim)
        elif num_layers == 4:
            self.conv2 = torch.nn.Conv3d(
                hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1
            )
            self.conv3 = torch.nn.Conv3d(
                hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1
            )
            self.conv4 = torch.nn.Conv3d(hidden_dim, final_dim, kernel_size=1)
            self.bn2 = nn.BatchNorm3d(hidden_dim)
            self.bn3 = nn.BatchNorm3d(hidden_dim)

        cnn_weight_initialization(self.modules())

        model_msg = f"==> Init {num_layers}-layer 3DCNN with {hidden_dim} hidden dims and {final_dim} outputs"
        if name:
            model_msg += f", with name {name}"
        print(model_msg)

        if bias:
            print("overwriting bias to %f" % bias)
            # NOTE(review): conv2 is the output layer only for num_layers == 2;
            # for 3 or 4 layers this sets the bias of a hidden conv that is
            # followed by batch norm. Confirm whether the last conv was intended.
            self.conv2.bias.data.fill_(bias)

        if freeze:
            for param in self.parameters():
                param.requires_grad = False
            self.eval()

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        if self.num_layers == 2:
            x = self.conv2(x)
        elif self.num_layers == 3:
            x = self.relu(self.bn2(self.conv2(x)))
            x = self.conv3(x)
        elif self.num_layers == 4:
            x = self.relu(self.bn2(self.conv2(x)))
            x = self.relu(self.bn3(self.conv3(x)))
            x = self.conv4(x)
        return x


class ResidualConvUnit3d(nn.Module):
    # From "Vision Transformers for Dense Prediction": https://arxiv.org/abs/2103.13413
    # adapted from https://github.com/isl-org/DPT/blob/main/dpt/blocks.py
    def __init__(self, features, kernel_size):
        super().__init__()
        # With padding = kernel_size // 2 only odd kernels keep the spatial size,
        # which the residual addition in forward() requires. (The previous check
        # used `% 1` and therefore never fired.)
        assert kernel_size % 2 == 1, "Kernel size needs to be odd"
        padding = kernel_size // 2
        self.conv = nn.Sequential(
            nn.Conv3d(features, features, kernel_size, padding=padding),
            nn.ReLU(True),
            nn.Conv3d(features, features, kernel_size, padding=padding),
            nn.ReLU(True),
        )

    def forward(self, x):
        return self.conv(x) + x


class FeatureFusionBlock3d(nn.Module):
    # From "Vision Transformers for Dense Prediction": https://arxiv.org/abs/2103.13413
    # adapted from https://github.com/isl-org/DPT/blob/main/dpt/blocks.py
    def __init__(self, features, kernel_size, with_skip=True):
        super().__init__()
        self.with_skip = with_skip
        if self.with_skip:
            # only needed when a skip feature is fused in forward()
            self.resConfUnit1 = ResidualConvUnit3d(features, kernel_size)
        self.resConfUnit2 = ResidualConvUnit3d(features, kernel_size)

    def forward(self, x, skip_x=None):
        if skip_x is not None:
            assert self.with_skip and skip_x.shape == x.shape
            x = self.resConfUnit1(x) + skip_x
        x = self.resConfUnit2(x)
        return x


class VolumeResnet(nn.Module):
    """
    Same U-Net layout as VolumeCNN, with a ResidualConvUnit3d applied before each
    of the three encoder convolutions. `hidden_dims` is [c0, c1, c2, c3]; the
    output has `out_dim` = c1 channels.
    """

    def __init__(self, hidden_dims, conv3=nn.Conv3d, freeze=False):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)
        self.upsample = nn.Upsample(
            scale_factor=2, mode="trilinear", align_corners=True
        )
        c0, c1, c2, c3 = tuple(hidden_dims)
        self.resconv1 = ResidualConvUnit3d(c0, kernel_size=3)
        self.conv1 = conv3(c0, c1, kernel_size=3, stride=1, padding=1)
        self.resconv2 = ResidualConvUnit3d(c1, kernel_size=3)
        self.conv2 = conv3(c1, c2, kernel_size=3, stride=1, padding=1)
        self.resconv3 = ResidualConvUnit3d(c2, kernel_size=3)
        self.conv3 = conv3(c2, c3, kernel_size=3, stride=1, padding=1)
        self.conv2u = conv3(c2 + c3, c2, kernel_size=3, stride=1, padding=1)
        self.conv1u = conv3(c1 + c2, c1, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm3d(c1)
        self.bn2 = nn.BatchNorm3d(c2)
        self.bn3 = nn.BatchNorm3d(c3)
        self.bn2u = nn.BatchNorm3d(c2)
        self.bn1u = nn.BatchNorm3d(c1)
        cnn_weight_initialization(self.modules())
        self.out_dim = c1

        if freeze:
            for param in self.parameters():
                param.requires_grad = False
            self.eval()

    def forward(self, x):
        # Simple U-Net like structure.
        conv1 = self.relu(self.bn1(self.conv1(self.resconv1(x))))
        x = self.pool(conv1)
        conv2 = self.relu(self.bn2(self.conv2(self.resconv2(x))))
        x = self.pool(conv2)
        x = self.relu(self.bn3(self.conv3(self.resconv3(x))))
        x = self.upsample(x)
        x = torch.cat([x, conv2], dim=1)
        x = self.relu(self.bn2u(self.conv2u(x)))
        x = self.upsample(x)
        x = torch.cat([x, conv1], dim=1)
        x = self.relu(self.bn1u(self.conv1u(x)))
        return x
class Attention(nn.Module):
    """Multi-head self-attention with a fused query/key/value projection.

    Args:
        dim: channel size of the input and output tokens.
        num_heads: number of attention heads; each head gets
            ``dim // num_heads`` channels.
        qkv_bias: whether the fused qkv projection has a bias.
        proj_bias: whether the output projection has a bias.
        attn_drop: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the output projection.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor) -> Tensor:
        """Apply self-attention to a ``(B, N, C)`` tensor and return ``(B, N, C)``."""
        batch, tokens, channels = x.shape
        per_head = channels // self.num_heads

        # (B, N, 3*C) -> (3, B, heads, N, per_head)
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, per_head)
        packed = packed.permute(2, 0, 3, 1, 4)
        queries = packed[0] * self.scale
        keys = packed[1]
        values = packed[2]

        weights = (queries @ keys.transpose(-2, -1)).softmax(dim=-1)
        weights = self.attn_drop(weights)

        mixed = (weights @ values).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(mixed))
class CrossAttention(nn.Module):
    """Multi-head cross-attention: queries from one sequence, keys/values from another.

    Args:
        dim: channel size of both input sequences and of the output.
        num_heads: number of attention heads; each head gets
            ``dim // num_heads`` channels.
        qkv_bias: whether the query and key/value projections have a bias.
        proj_bias: whether the output projection has a bias.
        attn_drop: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the output projection.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x_kv: Tensor, x_q: Tensor) -> Tensor:
        """Attend from ``x_q`` (B, Nq, C) to ``x_kv`` (B, Nkv, C); returns (B, Nq, C)."""
        heads = self.num_heads

        kv_batch, kv_tokens, kv_channels = x_kv.shape
        # (B, Nkv, 2*C) -> (2, B, heads, Nkv, per_head)
        packed_kv = self.kv(x_kv).reshape(
            kv_batch, kv_tokens, 2, heads, kv_channels // heads
        )
        packed_kv = packed_kv.permute(2, 0, 3, 1, 4)
        keys, values = packed_kv[0], packed_kv[1]

        batch, tokens, channels = x_q.shape
        # (B, Nq, C) -> (B, heads, Nq, per_head)
        queries = self.q(x_q).reshape(batch, tokens, heads, channels // heads)
        queries = queries.permute(0, 2, 1, 3) * self.scale

        weights = (queries @ keys.transpose(-2, -1)).softmax(dim=-1)
        weights = self.attn_drop(weights)

        mixed = (weights @ values).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(mixed))
class Block(nn.Module):
    """Pre-norm transformer block: attention and FFN branches, each added residually.

    Args:
        dim: token channel size.
        num_heads: number of attention heads passed to ``attn_class``.
        mlp_ratio: FFN hidden size is ``int(dim * mlp_ratio)``.
        qkv_bias: bias flag forwarded to the attention qkv projection.
        proj_bias: bias flag forwarded to the attention output projection.
        ffn_bias: bias flag forwarded to the FFN.
        drop: dropout probability for the attention projection and the FFN.
        attn_drop: dropout probability for the attention weights.
        init_values: LayerScale init value; a falsy value disables LayerScale.
        drop_path: stochastic-depth rate used for both branches.
        act_layer: activation constructor forwarded to the FFN.
        norm_layer: constructor of the two normalization layers.
        attn_class: constructor of the attention module.
        ffn_layer: constructor of the FFN module.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        """Apply the attention branch, then the FFN branch, to ``x``."""

        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            # The FFN branch owns drop_path2; it previously reused drop_path1.
            # Both are built from the same rate, so the sampling distribution
            # is unchanged.
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x
def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    """Add a residual computed on a random subset of the batch.

    Args:
        x: 3D input tensor; the first dimension is the batch.
        residual_func: maps the selected rows of ``x`` to a residual with
            the same number of elements per row.
        sample_drop_ratio: fraction of batch rows skipped. At least one row
            is always kept.

    Returns:
        A tensor shaped like ``x`` where each kept row has its residual
        added, scaled by ``batch / kept`` to compensate for skipped rows.
    """
    batch, _, _ = x.shape

    # Pick the rows that receive a residual.
    kept = max(int(batch * (1 - sample_drop_ratio)), 1)
    kept_rows = torch.randperm(batch, device=x.device)[:kept]

    # Evaluate the branch only on those rows.
    branch = residual_func(x[kept_rows]).flatten(1).to(dtype=x.dtype)

    # Scatter-add the rescaled residual back into the flattened input.
    updated = torch.index_add(
        x.flatten(1), 0, kept_rows, branch, alpha=batch / kept
    )
    return updated.view_as(x)
class NestedTensorBlock(Block):
    """Block that also accepts a list of tensors, processed as one nested batch.

    A single ``Tensor`` input is handled by ``Block.forward``. A ``list``
    input is concatenated and run with a block-diagonal attention bias,
    which requires xFormers and a ``MemEffAttention`` attention module.
    """

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:
            # LayerScale is not applied inside the residual functions here;
            # its gamma is passed as `scaling_vector` instead.

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=(
                    self.ls1.gamma if isinstance(self.ls1, LayerScale) else None
                ),
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # Gate on ls2 itself; the check previously inspected ls1.
                scaling_vector=(
                    self.ls2.gamma if isinstance(self.ls2, LayerScale) else None
                ),
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        """Dispatch on input type: ``Tensor`` -> ``Block.forward``, ``list`` -> nested path.

        Raises:
            AssertionError: for a list input when xFormers is unavailable,
                or for any other input type.
        """
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            assert XFORMERS_AVAILABLE, (
                "Please install xFormers for nested tensors usage"
            )
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Args:
        drop_prob: probability of dropping a sample's path. ``None`` (the
            default) and ``0.0`` both disable dropping.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        """Return ``x`` unchanged when inactive, otherwise apply ``drop_path``."""
        # With the default drop_prob=None, drop_path would evaluate
        # `1 - None` in training mode and raise TypeError. Treat None like
        # 0.0, for which drop_path returns x unchanged anyway.
        if not self.training or not self.drop_prob:
            return x
        return drop_path(x, self.drop_prob, self.training)
def make_2tuple(x):
    """Return ``x`` as a pair: a 2-tuple is passed through, an int is duplicated.

    Raises:
        AssertionError: if ``x`` is a tuple whose length is not 2, or is
            neither a tuple nor an int.
    """
    if not isinstance(x, tuple):
        assert isinstance(x, int)
        return x, x

    assert len(x) == 2
    return x
class DinoVisionTransformer(nn.Module):
    """Vision transformer backbone with optional register tokens and block chunking."""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings

        Raises:
            NotImplementedError: if ``ffn_layer`` is not one of the names above.
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = (
            embed_dim  # num_features for consistency with other models
        )
        self.num_tokens = 1
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset

        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + self.num_tokens, embed_dim)
        )
        assert num_register_tokens >= 0
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim))
            if num_register_tokens
            else None
        )

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [
                x.item() for x in torch.linspace(0, drop_path_rate, depth)
            ]  # stochastic depth decay rule

        if ffn_layer == "mlp":
            logger.info("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            logger.info("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            logger.info("using Identity layer as FFN")

            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append(
                    [nn.Identity()] * i + blocks_list[i : i + chunksize]
                )
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    def init_weights(self):
        """Initialize positional/cls/register tokens, then all linear layers."""
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, w, h):
        """Return positional embeddings resized to the patch grid of a ``w`` x ``h`` input.

        ``x`` is the token tensor with the cls token already prepended. The
        stored embedding is returned as-is when the patch count matches and
        ``w == h``; otherwise the patch part is resampled bicubically.
        """
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset

        sqrt_N = math.sqrt(N)
        sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(
                0, 3, 1, 2
            ),
            scale_factor=(sx, sy),
            mode="bicubic",
            antialias=self.interpolate_antialias,
        )

        assert int(w0) == patch_pos_embed.shape[-2]
        assert int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
            previous_dtype
        )

    def prepare_tokens_with_masks(self, x, masks=None):
        """Embed an image batch into tokens: patches, cls token, positions, registers.

        Where ``masks`` is true the patch token is replaced by the mask
        token. Register tokens, if any, are inserted right after the cls
        token and receive no positional embedding.
        """
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            x = torch.where(
                masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x
            )

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self.interpolate_pos_encoding(x, w, h)

        if self.register_tokens is not None:
            x = torch.cat(
                (
                    x[:, :1],
                    self.register_tokens.expand(x.shape[0], -1, -1),
                    x[:, 1:],
                ),
                dim=1,
            )

        return x

    def forward_features_list(self, x_list, masks_list):
        """Run a list of image batches through the blocks; returns one feature dict per entry.

        The token list is passed to every block as a list, so the blocks
        must accept list inputs.
        """
        x = [
            self.prepare_tokens_with_masks(x, masks)
            for x, masks in zip(x_list, masks_list)
        ]
        for blk in self.blocks:
            x = blk(x)

        all_x = x
        output = []
        for x, masks in zip(all_x, masks_list):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
                    "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
                    "x_prenorm": x,
                    "masks": masks,
                }
            )
        return output

    def forward_features(self, x, masks=None):
        """Return normalized cls/register/patch tokens of the last block as a dict."""
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)

        for blk in self.blocks:
            x = blk(x)

        x_norm = self.norm(x)
        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
            "x_prenorm": x,
            "masks": masks,
        }

    def forward_features_multi(self, x, masks=None):
        """
        Extract multilayer features from the model for dense prediction.
        Fixing the number of layers to 4 following "Vision Transformers for Dense Prediction"
        https://arxiv.org/abs/2103.13413.
        """
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)

        feats = []
        # NOTE(review): when block_chunks > 0, self.blocks holds BlockChunk
        # containers, so len(self.blocks) counts chunks rather than blocks and
        # the indices below refer to chunks. Confirm callers build the model
        # with block_chunks=0 before relying on four outputs.
        num_layers = len(self.blocks)
        feat_layers = [
            num_layers // 4 - 1,
            num_layers // 2 - 1,
            num_layers // 4 * 3 - 1,
            num_layers - 1,
        ]
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in feat_layers:
                feats.append(self.norm(x))

        return {
            "x_norm_clstoken": [feat[:, 0] for feat in feats],
            "x_norm_regtokens": [
                feat[:, 1 : self.num_register_tokens + 1] for feat in feats
            ],
            "x_norm_patchtokens": [
                feat[:, self.num_register_tokens + 1 :] for feat in feats
            ],
            "masks": masks,
        }

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        """Collect un-normalized outputs of selected blocks (unchunked block list)."""
        x = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = (
            range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        )
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(blocks_to_take), (
            f"only {len(output)} / {len(blocks_to_take)} blocks found"
        )
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        """Collect un-normalized outputs of selected blocks (chunked block list)."""
        x = self.prepare_tokens_with_masks(x)
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = (
            range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        )
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(blocks_to_take), (
            f"only {len(output)} / {len(blocks_to_take)} blocks found"
        )
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        """Return patch tokens (and optionally cls tokens) of selected blocks.

        Args:
            x: input image batch.
            n: an int selects the last ``n`` blocks; a sequence selects the
                listed block indices.
            reshape: if True, reshape each patch-token tensor to
                ``(B, C, w // patch_size, h // patch_size)``.
            return_class_token: if True, return ``(patch_tokens, cls_token)``
                pairs instead of patch tokens only.
            norm: if True, apply the final norm layer to each output.
        """
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        # Register tokens sit between the cls token and the patch tokens
        # (see prepare_tokens_with_masks); skip them too, otherwise they leak
        # into the patch outputs and break the reshape below.
        outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
        if reshape:
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1)
                .permute(0, 3, 1, 2)
                .contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def forward(self, *args, is_training=False, **kwargs):
        """Return the full feature dict if ``is_training``, else ``head(cls token)``."""
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        else:
            return self.head(ret["x_norm_clstoken"])
def vit_small(patch_size, **kwargs):
    """Build a DinoVisionTransformer with embed_dim 384, depth 12 and 6 heads.

    Extra keyword arguments are forwarded to ``DinoVisionTransformer``.
    """
    # Blocks use the plain Attention class here, not MemEffAttention.
    small_block = partial(Block, attn_class=Attention)
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=small_block,
        **kwargs,
    )
patch_size=patch_size, embed_dim=1536, depth=40, num_heads=24, mlp_ratio=4, block_fn=partial(Block, attn_class=MemEffAttention), ffn_layer="swiglufused", **kwargs, ) return model dino_name_mappings = { "vit_small": { "weights": "dinov2_vits14_pretrain.pth", "feats": 384, "func": vit_small, }, "vit_base": { "weights": "dinov2_vitb14_pretrain.pth", "feats": 768, "func": vit_base, }, "vit_large": { "weights": "dinov2_vitl14_pretrain.pth", "feats": 1024, "func": vit_large, }, "vit_giant2": { "weights": "dinov2_vitg14_pretrain.pth", "feats": 1536, "func": vit_giant2, }, # v2.5 models "vit_small_v25": { "weights": "dinov2_vits14_reg4_pretrain.pth", "feats": 384, "func": vit_small_reg, }, "vit_base_v25": { "weights": "dinov2_vitb14_reg4_pretrain.pth", "feats": 768, "func": vit_base_reg, }, "vit_large_v25": { "weights": "dinov2_vitl14_reg4_pretrain.pth", "feats": 1024, "func": vit_large_reg, }, } class DinoV2Wrapper(torch.nn.Module): """ runs DinoV2 on input images """ def __init__( self, name: str = "vit_small", img_size: Optional[Union[int, Tuple[int, int]]] = None, multilayer_output: bool = False, ckpt_path: str = "", ): super().__init__() assert name in dino_name_mappings.keys(), ( f"Dino model name should be one of {dino_name_mappings.keys()}" ) assert os.path.exists(ckpt_path), f"Missing DinoV2 checkpoint path {ckpt_path}" print(f"Use the provided DinoV2 checkpoint path {ckpt_path}") # If no image size is provided, use the recommended image size from DinoV2 models. 
if img_size is None: img_size = 518 patch_size = 14 feat_dim = dino_name_mappings[name]["feats"] model_constructor = dino_name_mappings[name]["func"] # reference: https://github.com/facebookresearch/dinov2/blob/9a4564ce5ebfe66a37fd16c6a233fb04ffb0a752/dinov2/hub/backbones.py#L21 model = model_constructor( patch_size=patch_size, img_size=img_size, block_chunks=0, init_values=1.0, ) print(f"Contructed DinoV2 model {name}") checkpoint = torch.load(ckpt_path, weights_only=True) # Dino models should all be loaded with strict=True. model.load_state_dict(checkpoint, strict=True) normalize_fn = T.Normalize( mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) ) self.model = model self.feat_dim = feat_dim self.img_size = img_size self.patch_size = patch_size self.normalize_fn = normalize_fn if multilayer_output: self.forward_fn = self.model.forward_features_multi else: self.forward_fn = self.model.forward_features def forward(self, img): """ Input: img : torch.Tensor : batch of images shaped BxCxHxW of type float32 in range [0,1], where C can be 1 or 3, they get resized to a fixed size Returns: feats : torch.Tensor : shaped BxDxpHxpW where D is feature dim, pH & pW = img_size / patch_size, """ B, C, H, W = img.shape assert C in [1, 3], "must be either 1 or 3 channel input (BxCxHxW)" if C == 1: # Fake RGB by repeating gray channel. img = img.repeat(1, 3, 1, 1) img = self.normalize_fn(img) # Apply imagenet normalization. 
feats = self.forward_fn(img)["x_norm_patchtokens"] H, W = img.shape[-2:] assert H % self.patch_size == 0 and W % self.patch_size == 0, ( "Resize the images to a multiple of patch size" ) pH = H // self.patch_size pW = W // self.patch_size if isinstance(feats, List): feats = [f.reshape(-1, pH, pW, self.feat_dim) for f in feats] feats = [f.permute(0, 3, 1, 2) for f in feats] # BxHxWxD => BxDxHxW else: feats = feats.reshape(-1, pH, pW, self.feat_dim) feats = feats.permute(0, 3, 1, 2) # BxHxWxD => BxDxHxW return feats ================================================ FILE: efm3d/model/dpt.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
class ResidualConvUnit(nn.Module):
    """Two same-size convolutions (each followed by ReLU) plus a skip connection."""

    # From "Vision Transformers for Dense Prediction": https://arxiv.org/abs/2103.13413
    # adapted from https://github.com/isl-org/DPT/blob/main/dpt/blocks.py
    def __init__(self, features, kernel_size):
        """
        Args:
            features: number of input and output channels.
            kernel_size: odd kernel size; padding is ``kernel_size // 2`` so the
                spatial size is preserved and the residual sum is well defined.
        """
        super().__init__()
        # An even kernel with this padding changes the spatial size and makes
        # the residual addition in forward() fail, so reject it up front.
        assert kernel_size % 2 == 1, "Kernel size needs to be odd"
        padding = kernel_size // 2
        self.conv = nn.Sequential(
            nn.Conv2d(features, features, kernel_size, padding=padding),
            nn.ReLU(True),
            nn.Conv2d(features, features, kernel_size, padding=padding),
            nn.ReLU(True),
        )

    def forward(self, x):
        """Return ``conv(x) + x``."""
        return self.conv(x) + x


class FeatureFusionBlock(nn.Module):
    """Fuse a feature map with an optional skip feature map via residual units."""

    # From "Vision Transformers for Dense Prediction": https://arxiv.org/abs/2103.13413
    # adapted from https://github.com/isl-org/DPT/blob/main/dpt/blocks.py
    def __init__(self, features, kernel_size, with_skip=True):
        """
        Args:
            features: channel count of the feature maps.
            kernel_size: kernel size of the residual units.
            with_skip: whether to create the unit used to merge ``skip_x``.
        """
        super().__init__()
        self.with_skip = with_skip
        if self.with_skip:
            self.resConfUnit1 = ResidualConvUnit(features, kernel_size)
        self.resConfUnit2 = ResidualConvUnit(features, kernel_size)

    def forward(self, x, skip_x=None):
        """Merge ``skip_x`` into ``x`` (when given) and refine the result.

        When ``skip_x`` is None the first residual unit is bypassed entirely.
        """
        if skip_x is not None:
            assert self.with_skip, "Must init with with_skip=True"
            assert skip_x.shape == x.shape, (
                f"skip {skip_x.shape} and x {x.shape} shape mismatch"
            )
            x = self.resConfUnit1(x) + skip_x
        return self.resConfUnit2(x)


class Interpolate(nn.Module):
    """
    Interpolation module.
    https://github.com/isl-org/DPT/blob/main/dpt/blocks.py#L138
    """

    def __init__(self, scale_factor, mode, align_corners=False):
        """Init.
        Args:
            scale_factor (float): scaling
            mode (str): interpolation mode
            align_corners (bool): forwarded to ``nn.functional.interpolate``
        """
        super().__init__()
        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Forward pass.
        Args:
            x (tensor): input
        Returns:
            tensor: interpolated data
        """
        return self.interp(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
        )


class DPTOri(nn.Module):
    """
    Implementation of DPT according to the paper description https://arxiv.org/pdf/2103.13413
    """

    def __init__(self, input_dim, hidden_dim=256, output_dim=256, depth=False):
        """
        input_dim: dimension of the DinoV2 tokens (384/768/...)
        hidden_dim: dense feature dimension(=256, D^{hat} in the paper) in DPT
        output_dim: final output feature dimension
        depth: if True, the last output channel is an inverse-depth prediction
            and only ``output_dim - 1`` feature channels are produced.
        """
        super().__init__()
        self.depth = depth
        if self.depth:
            # DPT depth head https://github.com/isl-org/DPT/blob/main/dpt/models.py#L89
            self.depth_head = nn.Sequential(
                nn.Conv2d(
                    hidden_dim, hidden_dim // 2, kernel_size=3, stride=1, padding=1
                ),
                Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
                nn.Conv2d(hidden_dim // 2, 32, kernel_size=3, stride=1, padding=1),
                nn.ReLU(True),
                nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
                nn.ReLU(True),  # require depth to be non-negative
                nn.Identity(),
            )
            output_dim = output_dim - 1  # last dim is depth

        # 1x1 convs to map (H/p x W/p x D) -> (H/s x W/s x D^{hat})
        self.conv_0 = nn.Conv2d(input_dim, hidden_dim, 1, padding=0)
        self.conv_1 = nn.Conv2d(input_dim, hidden_dim, 1, padding=0)
        self.conv_2 = nn.Conv2d(input_dim, hidden_dim, 1, padding=0)
        self.conv_3 = nn.Conv2d(input_dim, hidden_dim, 1, padding=0)

        # (strided) convs for upsampling (feat_0/1/2) and downsample (feat_3)
        # image - WxW, padding - P, kernel - FxF, stride - S
        # conv size - (W-F+2P) / S + 1
        # transpose conv size - (H-1)*S+F-2P
        self.resample_conv0 = nn.ConvTranspose2d(
            hidden_dim, hidden_dim, 3, stride=4, padding=0
        )
        self.resample_conv1 = nn.ConvTranspose2d(
            hidden_dim, hidden_dim, 3, stride=2, padding=1
        )
        self.resample_conv2 = nn.Conv2d(hidden_dim, hidden_dim, 3, stride=1, padding=1)
        self.resample_conv3 = nn.Conv2d(hidden_dim, hidden_dim, 3, stride=2, padding=1)

        # fusion blocks
        self.ref_0 = FeatureFusionBlock(hidden_dim, 3)
        self.ref_1 = FeatureFusionBlock(hidden_dim, 3)
        self.ref_2 = FeatureFusionBlock(hidden_dim, 3)
        self.ref_3 = FeatureFusionBlock(hidden_dim, 3, with_skip=False)

        # final upsample head
        self.conv_up1 = nn.Conv2d(
            hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1
        )
        self.conv_final = nn.Conv2d(
            hidden_dim, output_dim, kernel_size=1, stride=1, padding=0
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, feats):
        """
        feats: tokens from multi-layers, for ViT-base these are [3,6,9,12] (starting from 1, not 0)
            Each entry is either ``B x C x h x w`` or ``B x T x C x h x w``.

        Returns a dense map at 16x the token-map resolution (``B x C' x 16h x 16w``,
        with the time axis restored for 5-D inputs). The caller's ``feats``
        sequence is left untouched.
        """
        assert len(feats) == 4, (
            "feats must be multi-level as a list of 4 tensors, probably set model.video_backbone.image_tokenizer.multilayer_output=True"
        )
        # Work on a shallow copy: the entries are rebound below and must not
        # overwrite the list owned by the caller.
        feats = list(feats)
        ndim = feats[0].ndim
        if ndim == 5:
            num_frames = feats[0].shape[1]
            feats = [einops.rearrange(f, "b t c h w -> (b t) c h w") for f in feats]

        # [T, D, H/p, W/p]
        feats[0] = self.conv_0(feats[0])
        feats[1] = self.conv_1(feats[1])
        feats[2] = self.conv_2(feats[2])
        feats[3] = self.conv_3(feats[3])

        # add single-side padding here after feat0 and feat1 to make upsampling 4x and 2x the token map size
        padding = (0, 1, 0, 1)  # left, right, top, bottom
        feats[0] = self.resample_conv0(feats[0])
        feats[0] = F.pad(feats[0], padding, mode="constant", value=0)
        feats[1] = self.resample_conv1(feats[1])
        feats[1] = F.pad(feats[1], padding, mode="constant", value=0)
        feats[2] = self.resample_conv2(feats[2])
        feats[3] = self.resample_conv3(feats[3])

        # Coarse-to-fine fusion: upsample the running map to the next level's
        # size and merge it as the skip input of that level's fusion block.
        out = self.ref_3(feats[3], None)
        out = interpolate(
            out, size=feats[2].shape[-2:], mode="bilinear", align_corners=True
        )
        out = self.ref_2(feats[2], out)
        out = interpolate(
            out, size=feats[1].shape[-2:], mode="bilinear", align_corners=True
        )
        out = self.ref_1(feats[1], out)
        out = interpolate(
            out, size=feats[0].shape[-2:], mode="bilinear", align_corners=True
        )
        out = self.ref_0(feats[0], out)

        h, w = feats[0].shape[-2:]
        feat = interpolate(
            out, size=(h * 2, w * 2), mode="bilinear", align_corners=True
        )  # upsample by 2x (In the paper DPT outputs 1/2 original size feature maps)
        out = self.relu(self.conv_up1(feat))
        h, w = out.shape[-2:]
        out = interpolate(out, size=(h * 2, w * 2), mode="bilinear", align_corners=True)
        out = self.conv_final(out)

        if self.depth:
            inv_depth = self.depth_head(feat) + 1e-3  # predict inv depth, add epsilon
            out = torch.cat([out, inv_depth], dim=1)

        if ndim == 5:
            out = einops.rearrange(out, "(b t) c h w -> b t c h w", t=num_frames)
        return out
class EVL(torch.nn.Module):
    """2D backbone -> 3D lifter -> 3D CNN neck -> occupancy and OBB detection heads."""

    def __init__(
        self,
        video_backbone: Union[VideoBackbone, DictConfig],
        video_backbone3d: Union[VideoBackbone3d, DictConfig],
        neck_hidden_dims: Optional[List] = None,
        head_hidden_dim: int = 128,
        head_layers: int = 2,
        taxonomy_file: Optional[str] = None,
        det_thresh: float = 0.2,
        yaw_max: float = 1.6,
    ):
        """
        Args:
            video_backbone: 2D backbone to extract features from images.
            video_backbone3d: 3D backbone to lift 2d to 3d voxels.
            neck_hidden_dims: hidden dims of the 3D CNN neck. Currently unused:
                the neck dims are fixed below; kept for interface compatibility.
            head_hidden_dim: hidden dim of the 3D CNN head.
            head_layers: number of layers of the occupancy head.
            taxonomy_file: csv file mapping semantic names to ids. Required,
                because the class head and the post-processing depend on it.

            # obb params
            det_thresh: Detection threshold for NMS.
            yaw_max: Maximum yaw angle for object orientation.

        Raises:
            ValueError: if ``taxonomy_file`` is None.
        """
        super().__init__()
        if taxonomy_file is None:
            # Without a taxonomy neither `num_class` (class head width) nor
            # `sem2name` (post_process) would exist; fail early and clearly.
            raise ValueError(
                "EVL requires `taxonomy_file` to size the class head and to name "
                "the predicted semantic ids."
            )

        self.backbone2d = video_backbone
        self.backbone3d = video_backbone3d
        self.head_layers = head_layers

        if isinstance(video_backbone, DictConfig):
            self.backbone2d = instantiate(video_backbone)
        if isinstance(video_backbone3d, DictConfig):
            self.backbone3d = instantiate(video_backbone3d)
        backbone3d_out_dim = self.backbone3d.output_dim()

        # 3d U-Net
        c = backbone3d_out_dim  # c = 66 (64 + 1 + 1)
        dims = [c, 64, 96, 128, 160]
        neck_final = dims[1]
        print(f"==> Init 3D InvResnetFpn3d neck with hidden layers: {dims}")
        self.neck = InvResnetFpn3d(
            dims=dims,
            num_bottles=[2, 2, 2, 2],
            strides=[1, 2, 2, 2],
            expansions=[2.0, 2.0, 2.0, 2.0],
        )

        print(
            f"==> Init 3D CNN Head with final dim = {neck_final}, hidden dim = {head_hidden_dim}"
        )
        # occupancy head
        self.occ_head = VolumeCNNHead(
            neck_final,
            head_hidden_dim,
            final_dim=1,
            num_layers=self.head_layers,
            name="Occupancy",
        )

        # obb part
        taxonomy = parse_global_name_to_id_csv(taxonomy_file)
        self.sem2name = {int(sem_id): name for name, sem_id in taxonomy.items()}
        self.num_class = len(self.sem2name)

        # Centerness head (center of the bounding box).
        self.cent_head = VolumeCNNHead(
            neck_final,
            head_hidden_dim,
            final_dim=1,
            name="Centerness",
            bias=-5,
        )
        # Box size head (height, width, depth, offset_h, offset_w, offset_d, yaw rotation of box).
        self.bbox_head = VolumeCNNHead(
            neck_final,
            head_hidden_dim,
            final_dim=7,
            name="BoundingBox",
        )
        self.clas_head = VolumeCNNHead(
            neck_final,
            head_hidden_dim,
            final_dim=self.num_class,
            name="Class",
        )

        self.det_thresh = det_thresh
        self.bbox_min = 0.1  # Min bbox dim
        self.bbox_max = 6.0  # Max bbox dim
        # Scale the bbox offset max based on voxel size in meters.
        self.offset_max = 2 * self.backbone3d.voxel_meters
        self.splat_sigma = max(1, int(0.12 / self.backbone3d.voxel_meters))
        self.iou_thres = 0.2

        self.ve = self.backbone3d.voxel_extent  # voxel extent
        self.yaw_max = yaw_max
        self.scene = None  # lazily created 3D scene used by visualisation code

    def post_process(self, batch, out):
        """Turn dense head outputs into sparse OBB predictions and store them in ``out``.

        Reads ``cent_pr``, ``bbox_pr``, ``clas_pr`` and ``voxel/T_world_voxel``
        from ``out`` and the snippet pose from ``batch``; returns ``out`` with
        the NMS results and the ``ARIA_OBB_PRED*`` entries added.
        """
        cent_pr = out["cent_pr"]
        bbox_pr = out["bbox_pr"]
        clas_pr = out["clas_pr"]

        # Run NMS + convert voxel outputs to ObbTW.
        # NOTE(review): the extent of the no_grad scope is reconstructed; it is
        # assumed to cover the NMS and the voxel->obb conversion - confirm.
        with torch.no_grad():
            # First NMS is a simple heatmap suppression.
            cent_pr_nms = simple_nms3d(cent_pr, nms_radius=self.splat_sigma + 1)
            # Convert dense predictions to sparse ObbTW predictions.
            obbs_pr_nms, _, clas_prob_nms = voxel2obb(
                cent_pr_nms,
                bbox_pr,
                clas_pr,
                self.ve,
                top_k=128,
                thresh=self.det_thresh,
                return_full_prob=True,
            )

        out["obbs_pr_nms"] = obbs_pr_nms
        out["cent_pr_nms"] = cent_pr_nms

        # obb tracker expects ARIA_OBB_PRED and ARIA_OBB_PRED_VIZ to be in snippet coord system
        obbs_pr_nms_s = obbs_pr_nms.clone()
        T_ws = batch[ARIA_SNIPPET_T_WORLD_SNIPPET]  # B x 1 x 12
        T_wv = out["voxel/T_world_voxel"].unsqueeze(1)  # B x 1 x 12
        T_sv = T_ws.inverse() @ T_wv
        obbs_pr_nms_s = obbs_pr_nms_s.transform(T_sv)  # transform to snippet coords

        out[ARIA_OBB_PRED_SEM_ID_TO_NAME] = self.sem2name
        out[ARIA_OBB_PRED] = obbs_pr_nms_s
        out[ARIA_OBB_PRED_VIZ] = obbs_pr_nms_s
        out[ARIA_OBB_PRED_PROBS_FULL] = list(clas_prob_nms)
        out[ARIA_OBB_PRED_PROBS_FULL_VIZ] = out[ARIA_OBB_PRED_PROBS_FULL]
        return out

    def forward(self, batch, obb_only=False):
        """Run the full model on ``batch`` and return a dict of outputs.

        Side effect: the per-stream 2D features are written into ``batch``
        under ``"<stream>/feat"`` so that the 3D lifter can read them.

        Args:
            batch: input dict consumed by the 2D and 3D backbones.
            obb_only: if True, the occupancy head is skipped.
        """
        out = {}

        # Run 2D backbone on images to get on 2D feature map per image.
        backbone2d_out_all = self.backbone2d(batch)
        for stream in ["rgb", "slaml", "slamr"]:
            if stream in backbone2d_out_all:
                # add to batch for lifter
                batch[f"{stream}/feat"] = backbone2d_out_all[stream]

        # Run explicit 3D backbone to lift 2D features to a 3D voxel grid.
        backbone3d_out = self.backbone3d(batch)
        voxel_feats = backbone3d_out["voxel/feat"]

        # Run 3D encoder-decoder CNN, acting as a "neck" to the heads.
        neck_feats1 = self.neck(voxel_feats)
        neck_feats2 = neck_feats1  # both head groups share the same neck features

        # ---------- Run the occ head ------------
        if not obb_only:
            occ_logits = self.occ_head(neck_feats1)
            occ_pr = torch.sigmoid(occ_logits)  # logits => prob.
            out["occ_pr"] = occ_pr
        # NOTE(review): assumed to be set regardless of `obb_only` - confirm.
        out["voxel_extent"] = torch.tensor(self.ve).to(neck_feats1)

        # ---------- Run the obb head ------------
        # Run the centerness head.
        cent_logits = self.cent_head(neck_feats2)
        cent_pr = torch.sigmoid(cent_logits)  # logits => prob.

        # Run the box size head. The raw outputs are squashed in place into
        # their valid ranges: sizes into [bbox_min, bbox_max], offsets into
        # [-offset_max, offset_max] and yaw into [-yaw_max, yaw_max].
        bbox_pr = self.bbox_head(neck_feats2)
        bbox_pr[:, 0:3] = (self.bbox_max - self.bbox_min) * torch.sigmoid(
            bbox_pr[:, :3]
        ) + self.bbox_min
        bbox_pr[:, 3:6] = self.offset_max * torch.tanh(bbox_pr[:, 3:6])
        bbox_pr[:, 6] = self.yaw_max * torch.tanh(bbox_pr[:, 6])

        # Run the classification head.
        clas_pr = self.clas_head(neck_feats2)
        clas_pr = torch.nn.functional.softmax(clas_pr, dim=1)

        out.update(backbone3d_out)
        out["neck/occ_feat"] = neck_feats1
        out["neck/obb_feat"] = neck_feats2

        # Copy data from head outputs.
        out["cent_pr"] = cent_pr
        out["bbox_pr"] = bbox_pr
        out["clas_pr"] = clas_pr

        out = self.post_process(batch, out)
        return out
class EVLTrain(EVL):
    """EVL with training-time extras: losses, metrics and visualisation renders."""

    def __init__(
        self,
        video_backbone: Union[VideoBackbone, DictConfig],
        video_backbone3d: Union[VideoBackbone3d, DictConfig],
        neck_hidden_dims: Optional[List] = None,
        head_hidden_dim: int = 128,
        head_layers: int = 2,
        taxonomy_file: Optional[str] = None,
        det_thresh: float = 0.2,
        yaw_max: float = 1.6,
    ):
        """Forward all arguments to ``EVL`` and set the loss weights."""
        super().__init__(
            video_backbone,
            video_backbone3d,
            neck_hidden_dims,
            head_hidden_dim,
            head_layers,
            taxonomy_file,
            det_thresh,
            yaw_max,
        )
        # Loss weights. They are defined here (not lazily in compute_losses)
        # because log_single reads `occ_weight` and may run before any loss
        # has been computed.
        self.occ_weight = 10.0
        self.tv_weight = 0.01
        self.cent_weight = 10.0
        self.bbox_weight = 0.0
        self.cham_weight = 0.0
        self.clas_weight = 0.1
        self.iou_weight = 0.5

    def compute_losses(self, outputs, batch):
        """Compute occupancy and OBB losses.

        Returns:
            ``(losses, total_loss)`` where ``losses`` is a dict keyed by stream
            name (only ``"rgb"``) holding the individual loss terms.
        """
        total_loss = 0
        losses = {"rgb": {}}

        occ_losses, occ_total_loss = compute_occ_losses(
            outputs,
            batch,
            self.ve,
            occ_weight=self.occ_weight,
            tv_weight=self.tv_weight,
        )
        for k in losses:  # for ['rgb', 'slaml', 'slamr']
            losses[k].update(occ_losses[k])
        total_loss += occ_total_loss

        obb_losses, obb_total_loss = compute_obb_losses(
            outputs,
            batch,
            self.ve,
            self.num_class,
            self.splat_sigma,
            cent_weight=self.cent_weight,
            clas_weight=self.clas_weight,
            iou_weight=self.iou_weight,
            bbox_weight=self.bbox_weight,
            cham_weight=self.cham_weight,
        )
        for k in losses:  # for ['rgb', 'slaml', 'slamr']
            losses[k].update(obb_losses[k])
        total_loss += obb_total_loss

        return losses, total_loss

    def render2d(self, imgs, obbs, Ts_wr, cams):
        """Render a 2D visualization overlaid on the RGB image of the given obbs."""
        # Draw the 3D bb overlaid on the image.
        obb_img = draw_obbs_snippet(
            imgs.clone(),
            obbs,
            Ts_wr,
            cams,
            rgb2bgr=False,
            draw_cosy=True,
            white_backing_line=False,
            draw_bb2=False,
            sem_id_to_name_mapping=self.sem2name,
            draw_label=True,
            draw_score=True,
            prob_threshold=0.001,  # keep this very low, obbs are already thresholded.
        )
        return np.array(obb_img)

    def log_single_obb(self, batch, outputs, batch_idx):
        """Log a single element from the batch based on "batch_idx".

        Returns a dict of image arrays for the OBB part (voxel input, 2D/3D
        predictions and, when ``"cent_gt"`` is in ``outputs``, ground truth).
        """
        log_ims = {}

        # Get stuff.
        rgb_img = batch[ARIA_IMG[0]][batch_idx].cpu().detach()
        cams = batch[ARIA_CALIB[0]][batch_idx].cpu().detach()
        Ts_sr = batch[ARIA_IMG_T_SNIPPET_RIG[0]][batch_idx].cpu().detach()
        T_ws = batch[ARIA_SNIPPET_T_WORLD_SNIPPET][batch_idx].cpu().detach()
        voxel_w = outputs["voxel/pts_world"][batch_idx].cpu().detach()
        T_wv = outputs["voxel/T_world_voxel"][batch_idx]
        obbs_gt = get_gt_obbs(batch, self.ve, T_wv)
        obbs_gt = obbs_gt[batch_idx].cpu()
        T_wv = T_wv.cpu().detach()
        cent_pr = outputs["cent_pr"][batch_idx].cpu().detach()
        obbs_pr = outputs["obbs_pr_nms"][batch_idx].cpu().detach()

        # Get some convenience transforms.
        Ts_wr = T_ws @ Ts_sr

        # Transform Objects to world coords.
        obbs_pr = obbs_pr.transform(T_wv)  # T_wo = T_wv @ T_vo

        # Transform lifter volume obb to world coordinates.
        extent = torch.tensor(self.ve).to(T_wv._data)
        voxel_obb = ObbTW()[0]
        voxel_obb.set_bb3_object(extent, use_mask=False)
        voxel_obb.set_T_world_object(T_wv)

        occ_input = outputs["voxel/occ_input"][batch_idx].cpu().detach().reshape(-1)
        mask = occ_input > 1e-4
        log_ims["voxel/occ_input"] = self.render3d_obb(
            occ_input[mask],
            obbs_pr,
            Ts_wr,
            T_ws,
            cams,
            voxel_w[mask],
            voxel_obb,
            view="follow",
            alpha_min=0.1,
        )

        # compute precision and recall and add text to the pred 2d
        log_ims["rgb_pred"] = self.render2d(rgb_img, obbs_pr, Ts_wr, cams)
        if ARIA_OBB_PADDED in batch:
            obbs_pr_nms = outputs[ARIA_OBB_PRED][batch_idx].cpu()
            prec, rec, match_mat = prec_recall_bb3(
                obbs_pr_nms.remove_padding(),
                obbs_gt.remove_padding(),
                iou_thres=self.iou_thres,
            )
            if match_mat is not None:
                num_tp = match_mat.any(-1).sum().item()
                num_pred = match_mat.shape[0]
                num_gt = match_mat.shape[1]
                precision = f"Prec@{self.iou_thres}: {prec:.2f} ({num_tp}/{num_pred})"
                recall = f"Recall@{self.iou_thres}: {rec:.2f} ({num_tp}/{num_gt})"
            else:
                precision = f"Prec@{self.iou_thres}: {prec:.2f}"
                recall = f"Recall@{self.iou_thres}: {rec:.2f}"
            imgs_pred = log_ims["rgb_pred"]
            imgs_pred = [put_text(img, precision, line=-2) for img in imgs_pred]
            imgs_pred = [put_text(img, recall, line=-1) for img in imgs_pred]
            log_ims["rgb_pred"] = np.array(imgs_pred)

        log_ims["3D_pred"] = self.render3d_obb(
            cent_pr,
            obbs_pr,
            Ts_wr,
            T_ws,
            cams,
            voxel_w,
            voxel_obb,
            view="follow",
            alpha_min=0.1,
        )

        if "cent_gt" in outputs:
            # NOTE(review): return value is discarded; presumably called for
            # its side effects on `outputs` - confirm.
            self.compute_losses(outputs, batch)
            obbs_gt = obbs_gt[~obbs_gt.get_padding_mask()]
            obbs_gt = obbs_gt.transform(T_ws)  # T_wo = T_ws @ T_so
            log_ims["rgb_gt"] = self.render2d(rgb_img, obbs_gt, Ts_wr, cams)
            cent_gt = outputs["cent_gt"][batch_idx].cpu().reshape(-1)
            log_ims["3D_gt"] = self.render3d_obb(
                cent_gt,
                obbs_gt,
                Ts_wr,
                T_ws,
                cams,
                voxel_w,
                voxel_obb,
                alpha_min=0.1,
            )
        return log_ims

    def log_single(self, batch, outputs, batch_idx):
        """Log a single element from the batch based on "batch_idx".

        Extends the OBB images from ``log_single_obb`` with occupancy renders
        (predicted mesh/points, ground-truth samples and ground-truth mesh).
        """
        log_ims = self.log_single_obb(batch, outputs, batch_idx)

        cams = batch[ARIA_CALIB[0]][batch_idx].cpu().detach()
        Ts_sr = batch[ARIA_IMG_T_SNIPPET_RIG[0]][batch_idx].cpu().detach()
        T_ws = batch[ARIA_SNIPPET_T_WORLD_SNIPPET][batch_idx].cpu().detach()
        voxel_w = outputs["voxel/pts_world"][batch_idx].cpu().detach()
        T_wv = outputs["voxel/T_world_voxel"][batch_idx].cpu().detach()

        occ = outputs["occ_pr"].squeeze(1)
        voxel_counts = outputs["voxel/counts"][batch_idx].cpu().detach()
        _, D, H, W = occ.shape
        Df, Hf, Wf = voxel_counts.shape
        if D != Df or H != Hf or W != Wf:
            # The occupancy grid and the lifter grid differ in resolution:
            # resample voxel positions and counts to the occupancy grid.
            resize = torch.nn.Upsample(size=(D, H, W))
            voxel_w = voxel_w.view(Df, Hf, Wf, 3).permute(3, 0, 1, 2)
            voxel_w = resize(voxel_w.unsqueeze(0)).squeeze(0)
            voxel_w = voxel_w.permute(1, 2, 3, 0).view(-1, 3)
            voxel_counts = resize(voxel_counts.unsqueeze(0).unsqueeze(0).float())
            voxel_counts = voxel_counts.squeeze(0).squeeze(0)
        visible = voxel_counts > 0
        visible = erode_voxel_mask(visible.unsqueeze(0)).squeeze(0)

        # Get some convenience transforms.
        Ts_wr = T_ws @ Ts_sr
        Ts_wc = Ts_wr @ cams.T_camera_rig.inverse()

        # Transform lifter volume obb to world coordinates.
        extent = torch.tensor(self.ve).to(T_wv._data)
        voxel_obb = ObbTW()[0]
        voxel_obb.set_bb3_object(extent, use_mask=False)
        voxel_obb.set_T_world_object(T_wv)

        # -------------------- draw occ -----------------------
        occ_pr = outputs["occ_pr"][batch_idx].cpu().detach().squeeze(0)
        alpha_min = 0.5 if self.occ_weight > 0.0 else 0.04

        log_ims["occ/mesh_pred"] = self.render3d_mesh(
            occ_pr,
            Ts_wr,
            T_ws,
            cams,
            voxel_obb,
            view="follow",
            alpha_min=alpha_min,
            T_wv=T_wv,
            voxel_mask=visible,
        )
        log_ims["occ/occ_pred"] = self.render3d_occ(
            occ_pr,
            Ts_wr,
            T_ws,
            cams,
            voxel_w,
            voxel_obb,
            view="follow",
            alpha_min=alpha_min,
            voxel_mask=visible,
        )

        vD, vH, vW = occ_pr.shape
        pc_w = get_points_world(batch, batch_idx)[0].cpu().detach()
        (
            p3s_occ_w,
            p3s_surf_w,
            p3s_free_w,
            valid,
        ) = pointcloud_occupancy_samples(
            pc_w.unsqueeze(0),
            Ts_wc.unsqueeze(0),
            cams.unsqueeze(0),
            vW,
            vH,
            vD,
            self.ve,
            S=1,
            T_wv=T_wv,
        )
        # Invalid samples are set to NaN before rendering.
        p3s_occ_w[~valid] = float("nan")
        p3s_surf_w[~valid] = float("nan")
        p3s_free_w[~valid] = float("nan")
        log_ims["occ/occ_gt_samples"] = self.render3d_points(
            p3s_surf_w.squeeze(0),
            Ts_wr,
            T_ws,
            cams,
            voxel_obb,
            view="follow",
            more_p3s_w=p3s_free_w.squeeze(0),
            more2_p3s_w=p3s_occ_w.squeeze(0),
        )

        # get occ gt
        occ_gt, mask = pointcloud_to_occupancy_snippet(
            pc_w,
            Ts_wc,
            cams,
            T_wv,
            vW,
            vH,
            vD,
            self.ve,
            S=1,
        )
        mask = torch.logical_and(mask.bool(), visible)
        log_ims["occ/mesh_gt"] = self.render3d_mesh(
            occ_gt,
            Ts_wr,
            T_ws,
            cams,
            voxel_obb,
            view="follow",
            alpha_min=alpha_min,
            T_wv=T_wv,
            voxel_mask=mask,
        )
        return log_ims

    def _get_scene(self):
        """Return the shared SceneView, creating it on first use."""
        if self.scene is None:
            self.scene = SceneView(width=320, height=320)
        return self.scene

    @staticmethod
    def _set_view(scene, view, T_wc, follow_zoom, bird_zoom):
        """Point the scene camera using the "follow" or "bird" view style."""
        if view == "follow":
            scene.set_follow_view(T_wc, zoom_factor=follow_zoom)
        elif view == "bird":
            scene.set_birds_eye_view(T_wc, zoom_factor=bird_zoom)
        else:
            raise ValueError("bad option for 3d view style")

    @staticmethod
    def _render_context(scene, voxel_obb, Ts_wr, T_wr, cam, T_ws):
        """Draw voxel volume, trajectory, camera frustum, snippet and world origin."""
        black = (0.0, 0.0, 0.0, 1.0)
        # draw voxel bounding volume
        render_obb_line(voxel_obb, scene.prog, scene.ctx, rgba=black, draw_cosy=True)
        # draw trajectory.
        render_linestrip(Ts_wr.t, prog=scene.prog, ctx=scene.ctx, rgba=black)
        render_frustum(T_wr, cam, prog=scene.prog, ctx=scene.ctx, rgba=black)
        # Render snippet origin.
        render_cosy(T_ws, ctx=scene.ctx, prog=scene.prog, scale=0.3)
        # Render world origin.
        render_cosy(PoseTW(), ctx=scene.ctx, prog=scene.prog, scale=1.0)

    @torch.no_grad()
    def render3d_mesh(
        self,
        voxel_vals,
        Ts_wr,
        T_ws,
        cams,
        voxel_obb,
        view="follow",
        alpha_min=0.5,
        T_wv=None,
        voxel_mask=None,
        volume_feat=None,
    ):
        """Render a 3D visualization of the mesh extracted from the voxel values.

        Returns the stacked per-frame images; when ``volume_feat`` is given,
        returns ``(images, feats, verts_v, faces, normals_v)`` instead.
        """
        scene = self._get_scene()
        lifter_imgs = []

        verts_v, faces, normals_v = marching_cubes_scaled(
            voxel_vals.cpu().detach().float(),
            alpha_min,
            self.ve,
            voxel_mask,
        )

        feats = torch.tensor([])
        if volume_feat is not None and len(verts_v) > 0:
            vD, vH, vW = voxel_vals.shape
            p3s_surf_vox, _ = pc_to_vox(verts_v, vW, vH, vD, self.ve)
            feats, _ = sample_voxels(
                volume_feat.unsqueeze(0), p3s_surf_vox.unsqueeze(0)
            )
            feats = feats.squeeze(0).permute(1, 0)
            print("[WARN] No PCA compressor provided. Take the first 3 channels.")
            # Per-vertex min/max normalization of the first three channels.
            rgb = feats[:, :3]
            maxs = rgb.max(dim=-1, keepdim=True)[0]
            mins = rgb.min(dim=-1, keepdim=True)[0]
            rgb = (rgb - mins) / (maxs - mins + 1e-4)

        Ts_wc = Ts_wr @ cams.T_camera_rig.inverse()
        for T_wr, T_wc, cam in zip(Ts_wr, Ts_wc, cams):
            scene.clear()
            self._set_view(scene, view, T_wc, follow_zoom=4, bird_zoom=8)

            if len(verts_v) > 0:
                verts_w = T_wv * verts_v.to(T_wv.device)
                normals_w = T_wv.rotate(normals_v.to(T_wv.device))
                if volume_feat is not None:
                    render_rgb_tri_mesh(
                        verts_w,
                        -normals_w,
                        faces,
                        rgb=rgb,
                        prog=scene.prog_mesh_rgb,
                        ctx=scene.ctx,
                    )
                else:
                    render_tri_mesh(
                        verts_w,
                        normals_w,
                        faces,
                        prog=scene.prog_mesh,
                        ctx=scene.ctx,
                    )

            self._render_context(scene, voxel_obb, Ts_wr, T_wr, cam, T_ws)
            img = scene.finish()
            lifter_imgs.append(np.array(img))
        lifter_imgs = np.array(lifter_imgs)

        if volume_feat is None:
            return lifter_imgs
        return lifter_imgs, feats, verts_v, faces, normals_v

    @torch.no_grad()
    def render3d_points(
        self,
        p3s_w,
        Ts_wr,
        T_ws,
        cams,
        voxel_obb,
        view="follow",
        values=None,
        alpha_min=0.01,
        mask=None,
        more_p3s_w=None,
        more2_p3s_w=None,
    ):
        """Render per-frame point sets, optionally colored by scalar ``values``.

        ``p3s_w`` (and ``values`` / ``mask`` / ``more*_p3s_w`` when given) are
        indexed by frame. Returns the stacked per-frame images.
        """
        scene = self._get_scene()
        lifter_imgs = []

        Ts_wc = Ts_wr @ cams.T_camera_rig.inverse()
        for t, (T_wr, T_wc, cam) in enumerate(zip(Ts_wr, Ts_wc, cams)):
            scene.clear()
            self._set_view(scene, view, T_wc, follow_zoom=8, bird_zoom=12)

            if values is not None:
                alphas = torch.ones_like(values[t])
                if alpha_min is not None:
                    alphas[values[t] < alpha_min] = 0
                else:
                    alpha_min = 0.0

            if mask is not None:
                p3s_wt = p3s_w[t][mask[t]]
                if values is not None:
                    values_t = values[t][mask[t]]
                    alphas_t = alphas[mask[t]]
            else:
                p3s_wt = p3s_w[t]
                if values is not None:
                    values_t = values[t]
                    alphas_t = alphas

            if values is not None:
                render_scalar_field_points(
                    p3s_wt,
                    values_t.float(),
                    prog=scene.prog_scalar_field,
                    ctx=scene.ctx,
                    point_size=1.0,
                    alphas=alphas_t,
                    val_min=alpha_min,
                )
            else:
                render_points(p3s_wt, (1.0, 0, 0, 1.0), scene.prog, scene.ctx, 1.0)

            if more_p3s_w is not None:
                render_points(
                    more_p3s_w[t], (0.0, 1.0, 0, 0.5), scene.prog, scene.ctx, 1.0
                )
            if more2_p3s_w is not None:
                render_points(
                    more2_p3s_w[t], (0.0, 0.0, 1.0, 1.0), scene.prog, scene.ctx, 1.0
                )

            self._render_context(scene, voxel_obb, Ts_wr, T_wr, cam, T_ws)
            img = scene.finish()
            lifter_imgs.append(np.array(img))
        lifter_imgs = np.array(lifter_imgs)
        return lifter_imgs

    @torch.no_grad()
    def render3d_obb(
        self,
        voxel_vals,
        obb,
        Ts_wr,
        T_ws,
        cams,
        voxel_w,
        voxel_obb,
        view="follow",
        alpha_min=None,
    ):
        """Render a 3D visualization of the given voxel values and obbs."""
        scene = self._get_scene()
        lifter_imgs = []

        # Points whose value is below alpha_min become fully transparent.
        alphas = torch.ones_like(voxel_vals)
        if alpha_min is not None:
            alphas[voxel_vals < alpha_min] = 0

        Ts_wc = Ts_wr @ cams.T_camera_rig.inverse()
        for T_wr, T_wc, cam in zip(Ts_wr, Ts_wc, cams):
            scene.clear()
            self._set_view(scene, view, T_wc, follow_zoom=8, bird_zoom=8)

            # draw obbs
            if obb:
                colors = get_colors_from_sem_map(self.sem2name, scale_to_255=False)
                render_obbs_line(
                    obb,
                    scene.prog,
                    scene.ctx,
                    line_width=2,
                    colors=colors,
                    draw_cosy=True,
                )

            self._render_context(scene, voxel_obb, Ts_wr, T_wr, cam, T_ws)

            # "scalar_field_points" supports colored point cloud, will rescale based on min/max.
            render_scalar_field_points(
                voxel_w,
                voxel_vals,
                prog=scene.prog_scalar_field,
                ctx=scene.ctx,
                point_size=3,
                alphas=alphas,
            )
            img = scene.finish()
            lifter_imgs.append(np.array(img))
        lifter_imgs = np.array(lifter_imgs)
        return lifter_imgs

    @torch.no_grad()
    def render3d_occ(
        self,
        voxel_vals,
        Ts_wr,
        T_ws,
        cams,
        voxel_w,
        voxel_obb,
        view="follow",
        alpha_min=None,
        voxel_mask=None,
    ):
        """Render voxel values as a colored point cloud.

        ``voxel_vals`` may be 4-D (indexed per frame) or shared by all frames;
        ``alpha_min`` may be a scalar or a per-frame tensor.
        """
        scene = self._get_scene()
        lifter_imgs = []

        Ts_wc = Ts_wr @ cams.T_camera_rig.inverse()
        for t, (T_wr, T_wc, cam) in enumerate(zip(Ts_wr, Ts_wc, cams)):
            v_vals = voxel_vals[t] if voxel_vals.ndim == 4 else voxel_vals

            alp = torch.ones_like(v_vals)
            if alpha_min is not None:
                if isinstance(alpha_min, torch.Tensor):
                    alp_min = alpha_min[t]
                else:
                    alp_min = alpha_min
                alp[v_vals < alp_min] = 0
            else:
                alp_min = 0.0

            scene.clear()
            self._set_view(scene, view, T_wc, follow_zoom=4, bird_zoom=8)

            # "scalar_field_points" supports colored point cloud, will rescale based on min/max.
            if voxel_mask is not None:
                render_scalar_field_points(
                    voxel_w[voxel_mask.view(-1)],
                    v_vals[voxel_mask].float(),
                    prog=scene.prog_scalar_field,
                    ctx=scene.ctx,
                    point_size=3,
                    alphas=alp[voxel_mask].float(),
                    val_min=alp_min,
                )
            else:
                render_scalar_field_points(
                    voxel_w,
                    v_vals.float(),
                    prog=scene.prog_scalar_field,
                    ctx=scene.ctx,
                    point_size=3,
                    alphas=alp.float(),
                    val_min=alp_min,
                )

            self._render_context(scene, voxel_obb, Ts_wr, T_wr, cam, T_ws)
            img = scene.finish()
            lifter_imgs.append(np.array(img))
        lifter_imgs = np.array(lifter_imgs)
        return lifter_imgs

    def get_log_images(self, batch, outputs):
        """Return visualisation images for one randomly chosen batch element."""
        B = len(batch["rgb/img"])
        with torch.no_grad():
            # Visualize one random element from the batch.
            batch_idx = torch.randint(low=0, high=B, size=(1,)).item()
            log_ims = self.log_single(batch, outputs, batch_idx=batch_idx)
        return log_ims

    def reset_metrics(self):
        """Reset the metric accumulators to empty lists."""
        self.metrics = {}
        # obb
        self.metrics[f"precision@{self.iou_thres}"] = []
        self.metrics[f"recall@{self.iou_thres}"] = []
        # occ
        self.metrics["mesh/acc"] = []
        self.metrics["mesh/comp"] = []
        self.metrics["mesh/prec"] = []
        self.metrics["mesh/recall"] = []

    def update_metrics(self, outputs, batch):
        """Accumulate per-sample OBB precision/recall (skipped in training mode).

        Requires ``reset_metrics`` to have been called before.
        """
        # don't compute metrics on training since it takes long to compute.
        if self.training:
            return

        obbs_pred = outputs[ARIA_OBB_PRED]
        T_wv = outputs["voxel/T_world_voxel"]
        obbs_gt = get_gt_obbs(batch, self.ve, T_wv)

        precs, recs = [], []
        for obbs_pred_s, obbs_gt_s in zip(obbs_pred, obbs_gt):
            prec, rec, _ = prec_recall_bb3(
                obbs_pred_s.remove_padding(),
                obbs_gt_s.remove_padding(),
                iou_thres=self.iou_thres,
            )
            # Samples reporting -1.0 for either value are left out.
            if prec != -1.0 and rec != -1.0:
                precs.append(prec)
                recs.append(rec)
        self.metrics[f"precision@{self.iou_thres}"].extend(precs)
        self.metrics[f"recall@{self.iou_thres}"].extend(recs)

    def compute_metrics(self):
        """Return the mean of every accumulated metric under ``["rgb"]["metrics"]``.

        Returns an empty dict in training mode.
        """
        metrics = {}
        if self.training:
            return metrics

        metrics["rgb"] = {}
        metrics["rgb"]["metrics"] = {}
        for key in self.metrics:
            val = torch.tensor(self.metrics[key]).mean()
            metrics["rgb"]["metrics"][key] = val
        return metrics
class ImageToDinoV2Tokens(torch.nn.Module):
    """
    Tokenize an image snippet using DinoV2.

    The input snippet is optionally rotated upright, resized to a multiple of
    the DinoV2 patch size, run through the DinoV2 wrapper, and the resulting
    feature maps are optionally resampled, rotated back and linearly
    projected to ``dim_out`` channels.
    """

    # Dino models have a fixed patch size of 14.
    DINO_PATCH_SIZE = 14

    def __init__(
        self,
        dinov2_name: str = "vit_small",
        freeze: bool = False,
        handle_rotated_data: bool = True,
        dim_out: int = 768,  # ignored if add_linear_layer = False
        add_lin_layer: bool = False,  # add a linear layer to get to any output dim
        out_patch_size: int = 14,  # 14 is default but can set to 16 to get resampled into a more compatible feature size
        multilayer_output: bool = False,  # if True, return a list of features
        ckpt_path: str = "",  # if not empty, load the pretrained weights from the given path
    ):
        """
        Args:
            dinov2_name: key into ``dino_name_mappings`` selecting the model.
            freeze: disable gradients of the DinoV2 model and put it in eval
                mode at construction time.
            handle_rotated_data: rotate the images by 90 degrees before the
                model and rotate the features back afterwards.
            dim_out: output feature dimension. Must equal the DinoV2 feature
                dimension unless ``add_lin_layer`` is set.
            add_lin_layer: add a linear layer projecting the DinoV2 features
                to ``dim_out``.
            out_patch_size: patch size of the returned feature map. Values
                other than 14 make ``forward`` resample the feature map.
            multilayer_output: forwarded to ``DinoV2Wrapper``.
            ckpt_path: forwarded to ``DinoV2Wrapper``.
        """
        super().__init__()
        assert dinov2_name in dino_name_mappings

        self.freeze = freeze
        self.handle_rotated_data = handle_rotated_data
        self.model = DinoV2Wrapper(
            dinov2_name, multilayer_output=multilayer_output, ckpt_path=ckpt_path
        )
        if self.freeze:
            for param in self.model.parameters():
                param.requires_grad_(False)
            # NOTE(review): a later ``.train()`` call on a parent module puts
            # the model back into train mode - confirm that is intended.
            self.model.eval()

        self.lin = None
        if not add_lin_layer:
            assert dim_out == self.model.feat_dim, (
                f"dim_out must match DinoV2 feature dim if not adding linear layer, but get dim_out: {dim_out} and feat_dim: {self.model.feat_dim}."
            )
        else:
            self.lin = torch.nn.Linear(self.model.feat_dim, dim_out)
            logger.info(
                "Add linear layer to project features from %s to %s",
                self.model.feat_dim,
                dim_out,
            )
        self.dim_out = dim_out
        logger.info(
            f"DinoV2 InputTokenizer {dinov2_name}, is frozen {freeze}, dim_out of {self.dim_out}"
        )
        self.out_patch_size = out_patch_size

    def feat_dim(self):
        """Return the channel dimension of the produced tokens."""
        return self.dim_out

    def patch_size(self):
        """Return the patch size of the produced token feature map."""
        return self.out_patch_size

    def post_process(self, feats, B, T, out_size=None):
        """
        Post processing to convert Dino features, e.g. feature interpolation
        to the desired size, handling Aria image rotation, the linear mapping
        to increase feature dimension.

        Args:
            feats: [(B*T) x C x H x W] feature maps with batch and time
                flattened into the first dimension.
            B: batch size
            T: number of frames
            out_size: (h, w) token feature map output size, if None, don't
                resize the feature map size.

        Returns:
            [B x T x H x W x C] token features (C is ``dim_out`` when the
            linear layer is present).
        """
        if out_size is not None:
            # resize to desired size
            feats = F.interpolate(feats, out_size, mode="bilinear")

        if self.handle_rotated_data:
            # undo the rotation applied to the images in ``forward``
            feats = torch.rot90(feats, 1, [-2, -1])

        # split batch and time again and move channels last: B x T x H x W x C
        feats = einops.rearrange(feats, "(b t) c h w -> b t h w c", b=B, t=T)

        if self.lin is not None:
            # map feature dimension to desired output dimension
            feats = self.lin(feats)
        return feats

    def forward_resize(self, img: torch.Tensor) -> torch.Tensor:
        """
        Resize the image so that its height and width are rounded up to a
        multiple of the DinoV2 patch size. The image is returned unchanged
        if both already are multiples.

        Args:
            img: [N x C x H x W] image tensor

        Returns:
            The (possibly bilinearly upsampled) image tensor.
        """
        patch = self.DINO_PATCH_SIZE if self is not None else 14
        H_ori, W_ori = img.shape[-2:]
        H_new = math.ceil(H_ori / patch) * patch
        W_new = math.ceil(W_ori / patch) * patch
        if H_new != H_ori or W_new != W_ori:
            img = F.interpolate(
                img, size=(H_new, W_new), mode="bilinear", align_corners=False
            )
        return img

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """
        Args:
            img: [B x T x C x H x W] A sequence / snippet of Image Frames
                (typically used for Pose Regression)

        Returns:
            [B x T x H' x W' x C'] token features, or a list of such tensors
            if the DinoV2 wrapper returns a list of feature maps.
        """
        assert img.dim() == 5, f"expecting BxTxCxHxW but got {img.shape}"
        B, T, C, H, W = img.shape

        if self.handle_rotated_data:
            # rotate image 90 degrees clockwise to give it the expected
            # upright orientation for the pretrained model
            img = torch.rot90(img, 1, [-1, -2])

        # flatten batch and time to get a batch of images for the model
        img = einops.rearrange(img, "b t c h w -> (b t) c h w")
        H_ori, W_ori = img.shape[-2:]
        img = self.forward_resize(img)
        feats = self.model.forward(img)

        out_size = None
        # if the output patch size is not the native one, the feature map is
        # resized to the size implied by the requested patch size
        if self.patch_size() != self.DINO_PATCH_SIZE:
            out_size = H_ori // self.patch_size(), W_ori // self.patch_size()
            if (
                out_size[0] * self.patch_size() != H_ori
                or out_size[1] * self.patch_size() != W_ori
            ):
                logger.warning(
                    f"Image size {(H_ori, W_ori)} not divisible by output patch size {self.patch_size()}"
                )

        if isinstance(feats, list):
            feats = [self.post_process(f, B, T, out_size) for f in feats]
        else:
            feats = self.post_process(feats, B, T, out_size)
        return feats
class VideoBackbone3d(torch.nn.Module, ABC):
    """
    Abstract Video Backbone that creates an explicit 3D feature volume from a video stream.
    """

    def __init__(
        self,
        feat_dim: int,
    ):
        """
        Args:
            feat_dim: number of channels in voxel grid, the C in BxCxDxHxW
        """
        super().__init__()
        self._feat_dim = feat_dim

    @property
    def feat_dim(self):
        return self._feat_dim

    def forward_impl(self, batch):
        """Compute the 3D outputs for ``batch``; must be provided by subclasses.

        Must return a dict containing at least the keys checked in
        ``forward``.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement forward_impl"
        )

    def forward(self, batch):
        """Run ``forward_impl`` and check that the required outputs exist."""
        out = {}
        assert "rgb/feat" in batch, "must run 2d backbone to get rgb feature maps first"
        out.update(self.forward_impl(batch))
        # Shaped B x C x D x H x W
        assert "voxel/feat" in out, "3d backbone must output voxel features"
        # Shaped B x N x 3 (where N=D*H*W)
        assert "voxel/pts_world" in out, "3d backbone must output voxel positions"
        # Shaped B x 12 (PoseTW object)
        assert "voxel/T_world_voxel" in out, "3d backbone must output voxel coord frame"
        return out


class Lifter(VideoBackbone3d):
    """
    Video Backbone that creates an explicit 3D feature volume from a set of 2D
    features by projecting the voxel centers into every image and sampling
    the (optionally upsampled) 2D feature maps.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        patch_size: int,
        voxel_size: List[float],
        voxel_extent: List[float],
        head_type: Literal["none", "dpt_ori", "cnn"] = "cnn",
        streams: Optional[List[str]] = None,  # default is just rgb stream
        joint_slam_streams: bool = False,
        joint_streams: bool = False,  # joint all streams
    ):
        """
        Args:
            in_dim: input feature dimension (the 2d image or feature image channel dim)
            out_dim: output feature dimension (in 3d volume - FPN in 2D is used to get to that dim)
            patch_size: size of the patch to use for upsampling
            voxel_size: size of the voxel grid (D H W)
            voxel_extent: extent of the voxel grid (x_min, x_max y_min, y_max, z_min, z_max)
            head_type: 2D head applied to the features before lifting:
                "none" (features are used as is), "cnn" (UpsampleCNN) or
                "dpt_ori" (DPTOri).
            streams: list of streams to use for the 2D features ("rgb", "slaml", "slamr").
                Lifting gets run per stream (unless joint_slam_streams is True) and then concatenated in 3d.
            joint_slam_streams: if True, use the slaml and slamr streams as a single
                stream (i.e. dont concatenate lifted volumes from the two slam
                streams - lift them as if they were one camera). Only takes
                effect if both slam streams are selected.
            joint_streams: if True, aggregate all streams as if they were one camera.

        Raises:
            ValueError: if ``head_type`` is not supported.
        """
        super().__init__(in_dim)
        self.streams = streams
        if streams is None:
            self.streams = ["rgb"]  # default is just rgb stream
        self.stream2id = {"rgb": 0, "slaml": 1, "slamr": 2}

        # feature map upsampling network
        final_dim = out_dim
        if head_type == "none":
            self.head = None
            self.out_dim = in_dim
        elif head_type == "cnn":
            assert patch_size > 0, f"{patch_size} should be > 0 for UpsampleCNN"
            # NOTE(review): sqrt(patch_size) agrees with log2(patch_size)
            # after rounding for patch sizes 14 and 16 - confirm which one
            # UpsampleCNN expects before using other patch sizes.
            upsample_power = np.sqrt(patch_size)
            logger.info("True upsample_power: %f", upsample_power)
            upsample_power = int(round(upsample_power))
            logger.info("Rounded upsample_power: %d", upsample_power)
            self.head = UpsampleCNN(
                input_dim=in_dim,
                first_hidden_dim=-1,
                final_dim=final_dim,
                upsample_power=upsample_power,
                fix_hidden_dim=False,
            )
            self.out_dim = out_dim
        elif head_type == "dpt_ori":
            self.head = DPTOri(
                input_dim=in_dim,
                output_dim=final_dim,
                depth=False,
            )
            self.out_dim = out_dim
        else:
            raise ValueError(f"{head_type} is not supported")

        self.voxel_size = voxel_size  # D x H x W
        self.voxel_extent = list(voxel_extent)  # W x H x D
        self.joint_streams = joint_streams
        self.joint_slam_streams = (
            joint_slam_streams and "slaml" in self.streams and "slamr" in self.streams
        )

        x_meters = (voxel_extent[1] - voxel_extent[0]) / self.voxel_size[2]
        y_meters = (voxel_extent[3] - voxel_extent[2]) / self.voxel_size[1]
        z_meters = (voxel_extent[5] - voxel_extent[4]) / self.voxel_size[0]
        assert abs(x_meters - y_meters) < 1e-5 and abs(x_meters - z_meters) < 1e-5, (
            f"Voxels should be cubes {x_meters}x{y_meters}x{z_meters}"
        )
        self.voxel_meters = x_meters
        # number of free-space samples taken per ray in get_freespace_counts
        self.num_free_samples = 16

    def output_dim(self):
        """Return the channel count of the ``voxel/feat`` output."""
        num_streams = len(self.streams)
        if self.joint_slam_streams:
            num_streams -= 1
        if self.joint_streams:
            num_streams = 1
        out_dim = self.out_dim * num_streams
        out_dim += 1  # point mask
        out_dim += 1  # freespace token
        return out_dim

    @staticmethod
    def _count_valid_views(*vox_valids):
        """Sum the valid-projection masks of all given streams over dim 1 (time)."""
        count = None
        for vox_valid in vox_valids:
            stream_count = torch.sum(vox_valid, dim=1, keepdim=True)
            count = stream_count if count is None else count + stream_count
        return count

    @staticmethod
    def _normalize_counts(point_counts, max_count, return_mask):
        """Clamp counts to [0, max_count], scale to [0, 1], optionally binarize."""
        point_counts = point_counts.clamp(0, max_count) / max_count
        if return_mask:
            # Only use as a mask. Comment out if want to use real point counts.
            point_counts[point_counts > 1e-4] = 1.0
        return point_counts

    def get_freespace_world(self, batch, batch_idx, T_wv, vW, vH, vD, S=1):
        """
        Sample free-space points (in world coordinates) for one snippet of the
        batch: for every observed point a ray is cast from the camera center
        towards it, and ``S`` depths per ray are sampled inside the voxel grid
        in front of the point.
        """
        cams = batch[ARIA_CALIB[0]][batch_idx]
        # T_world_rig (one per snippet)
        T_ws = batch[ARIA_SNIPPET_T_WORLD_SNIPPET][batch_idx]
        # Ts_snippet_rig (T per snippet)
        Ts_sr = batch[ARIA_IMG_T_SNIPPET_RIG[0]][batch_idx]
        Ts_wr = T_ws @ Ts_sr
        Ts_wc = Ts_wr @ cams.T_camera_rig.inverse()  # Ts_world_cam

        # compute rays and max depths
        p_w = batch[ARIA_POINTS_WORLD][batch_idx]  # TxNx3
        N = p_w.shape[1]
        p0_w = Ts_wc.t.unsqueeze(1)  # Tx1x3
        diff_w = p_w - p0_w
        ds = torch.norm(diff_w, 2.0, dim=-1)
        dir_w = F.normalize(diff_w, 2.0, dim=-1)
        # filter out nans
        good = ~p_w.isnan().any(dim=-1)
        p0_w = p0_w.repeat(1, N, 1)[good]
        ds = ds[good]
        dir_w = dir_w[good]
        rays_w = torch.cat([p0_w, dir_w], dim=-1)
        rays_v = transform_rays(rays_w, T_wv.inverse())

        x_min, x_max, y_min, y_max, z_min, z_max = self.voxel_extent
        dW = (x_max - x_min) / vW
        dH = (y_max - y_min) / vH
        dD = (z_max - z_min) / vD
        diag = math.sqrt(dW**2 + dH**2 + dD**2)
        # subtract diagonal of voxel size to not label the occupied voxel as free
        ds = ds - diag

        # sample depths that lie within the feature volume grid (same function as used for nerf3d!)
        depths, _, _ = sample_depths_in_grid(
            rays_v.view(1, 1, -1, 6),
            ds.view(1, 1, -1),
            self.voxel_extent,
            vW,
            vH,
            vD,
            S,
        )
        depths = depths.view(-1, S)
        rays_v = rays_v.view(-1, 1, 6)
        pts_v = rays_v[..., :3] + depths.unsqueeze(-1) * rays_v[..., 3:]
        pts_v = pts_v.view(-1, 3)
        return T_wv * pts_v

    def get_points_world(self, batch, batch_idx, keep_T=False):
        """
        Get points (semi-dense or GT points) of a snippet in the batch.

        Returns a single Nx3 tensor with NaN points and duplicates removed,
        or - if ``keep_T`` is set - a list with one such tensor per frame.
        """

        def filter_points(p_w):
            p_w = p_w.reshape(-1, 3)
            # filter out nans
            bad = p_w.isnan().any(dim=-1)
            p_w = p_w[~bad]
            # filter out duplicates from the collapsing of the time dimension
            p_w = torch.unique(p_w, dim=0)
            return p_w

        p_w = batch[ARIA_POINTS_WORLD][batch_idx]
        if not keep_T:
            return filter_points(p_w)
        return [filter_points(p_w[t, ...]) for t in range(p_w.shape[0])]

    def get_freespace_counts(
        self,
        batch,
        T_wv,
        vW,
        vH,
        vD,
        MAX_NUM_POINTS_VOXEL=50,
        return_mask=False,
    ):
        """
        Get free-space samples as voxel grid where each voxel is assigned a
        (clamped and normalized) count of how many samples are inside it. If
        return_mask is true the function returns the binary occupancy instead
        of the counts.
        """
        B = batch[ARIA_IMG[0]].shape[0]
        point_counts = []
        for b in range(B):
            p_w = self.get_freespace_world(
                batch, b, T_wv[b], vW, vH, vD, self.num_free_samples
            )
            # transform points into voxel coordinate.
            p_v = T_wv[b].inverse() * p_w
            point_count = pointcloud_to_voxel_counts(p_v, self.voxel_extent, vW, vH, vD)
            point_counts.append(point_count)
        point_counts = torch.stack(point_counts, dim=0)  # B x 1 x vD, vH, vW
        return self._normalize_counts(point_counts, MAX_NUM_POINTS_VOXEL, return_mask)

    def get_points_counts(
        self,
        batch,
        T_wv,
        vW,
        vH,
        vD,
        MAX_NUM_POINTS_VOXEL=50,
        return_mask=False,
        keep_T=False,
    ):
        """
        Get points as voxel grid where each voxel is assigned a (clamped and
        normalized) count of how many points are inside it. If return_mask is
        true the function returns the binary occupancy instead of point
        counts. With keep_T the counts are computed per frame and
        concatenated along the first per-snippet dimension.
        """
        B = batch[ARIA_IMG[0]].shape[0]
        point_counts = []
        for b in range(B):
            p_w = self.get_points_world(batch, b, keep_T)
            T_vw = T_wv[b].inverse()
            if not keep_T:
                assert isinstance(p_w, torch.Tensor)
                # transform points into voxel coordinate.
                point_count = pointcloud_to_voxel_counts(
                    T_vw * p_w, self.voxel_extent, vW, vH, vD
                )
            else:
                assert isinstance(p_w, list)
                point_count = torch.cat(
                    [
                        pointcloud_to_voxel_counts(
                            T_vw * p_w_t, self.voxel_extent, vW, vH, vD
                        )
                        for p_w_t in p_w
                    ],
                    dim=0,
                )
            point_counts.append(point_count)
        point_counts = torch.stack(point_counts, dim=0)  # B x 1 x vD, vH, vW
        return self._normalize_counts(point_counts, MAX_NUM_POINTS_VOXEL, return_mask)

    def get_voxelgrid_pose(self, cams, T_ws, Ts_sr):
        """
        Compute the pose of the voxel grid by gravity-aligning the camera pose
        of the last frame of every snippet.

        Returns:
            (T_wv, selectT): the world-from-voxel pose per snippet and the
            index of the frame it is anchored to.
        """
        B, T = cams.shape[:2]
        Ts_wr = T_ws @ Ts_sr
        Ts_wc = Ts_wr @ cams.T_camera_rig.inverse()  # Ts_world_cam

        # Select last frame in snippet.
        selectT = torch.tensor(T - 1).repeat(B).long()
        # Create the voxel grid by aligning selected frame with gravity.
        T_wc_select = Ts_wc[torch.arange(B), selectT, :]
        T_wv = gravity_align_T_world_cam(
            T_wc_select, gravity_w=GRAVITY_DIRECTION_VIO, z_grav=True
        )
        # T_wv should only have yaw value
        rpy = T_wv.to_euler()
        assert torch.allclose(torch.tensor(0.0), rpy[:, :2], atol=1e-4)
        return T_wv, selectT

    def lift(self, feats2d, vox_w, cam, Ts_wr, vD, vH, vW):
        """
        Project the voxel centers into every frame of one stream and sample
        the 2D features there.

        Returns:
            (vox_feats, vox_valid) shaped B x T x C x vD x vH x vW and
            B x T x 1 x vD x vH x vW.
        """
        B, T = cam.shape[:2]
        num_channels = feats2d.shape[2]
        Ts_wc = Ts_wr @ cam.T_camera_rig.inverse()  # Ts_world_cam

        # flatten batch and time so that every frame is sampled independently
        vox_w = torch.flatten(vox_w, 0, 1)
        cam = torch.flatten(cam, 0, 1)
        Ts_wc = torch.flatten(Ts_wc, 0, 1)
        feats2d = torch.flatten(feats2d, 0, 1)

        vox_cam = Ts_wc.inverse() * vox_w
        vox_feats, vox_valid = sample_images(
            feats2d, vox_cam, cam, n_by_c=False, warn=False, single_channel_mask=True
        )
        vox_feats = vox_feats.reshape(B, T, num_channels, vD, vH, vW)
        vox_valid = vox_valid.reshape(B, T, 1, vD, vH, vW)
        return vox_feats, vox_valid

    def aggregate(self, vox_feats, vox_valid):
        """
        Average the lifted features over dim 1 (time), ignoring invalid
        projections.

        NOTE: ``vox_feats`` is modified in place (invalid entries are zeroed).

        Returns:
            (mean_feats, count): the mean features with dim 1 removed and the
            per-voxel count of valid projections (first channel only). Voxels
            without any valid projection get a zero mean and a count of 1.
        """

        def basic_mean(x, dim, valid, keepdim=False):
            count = torch.sum(valid, dim=dim, keepdim=True)  # B 1 C D H W
            invalid = (~valid).expand_as(x)
            x[invalid] = 0.0
            x_sum = torch.sum(x, dim=dim, keepdim=True)
            count[count == 0] = 1.0  # just so we dont divide by zero
            mean = x_sum / count
            del x_sum
            mean[count.expand_as(mean) < 1] = 0.0
            if not keepdim:
                return mean.squeeze(dim), count.squeeze(dim)
            return mean, count

        vox_feats, count_feats_m = basic_mean(
            vox_feats, 1, valid=vox_valid, keepdim=False
        )
        return vox_feats, count_feats_m[:, [0]]

    def lift_aggregate_centers(self, batch, feats2d, vox_w, Ts_wr, T_wv=None):
        """
        Lift the 2D features of all selected streams into the voxel grid and
        aggregate them over time.

        Args:
            batch: batch dict providing images and calibrations per stream.
            feats2d: dict mapping stream name to its 2D feature maps.
            vox_w: voxel centers in world coordinates, B x N x 3.
            Ts_wr: world-from-rig poses per frame.
            T_wv: unused; kept for interface compatibility.

        Returns:
            (vox_feats, count_feats, count_feats_m): the aggregated voxel
            features, the number of valid projections per voxel and the count
            returned by ``aggregate`` (both B x 1 x vD x vH x vW).
        """
        vD, vH, vW = self.voxel_size
        B, T = batch[ARIA_IMG[0]].shape[:2]

        # Lift to 3D. Project 3D voxel centers into each image and sample.
        vox_w = vox_w.reshape(B, 1, -1, 3).repeat(1, T, 1, 1)
        vox_feats, vox_valid, stream2pos = [], [], {}
        for stream in self.streams:
            stream_id = self.stream2id[stream]
            cam = batch[ARIA_CALIB[stream_id]]
            _vox_feats, _vox_valid = self.lift(
                feats2d[stream], vox_w, cam, Ts_wr, vD, vH, vW
            )
            stream2pos[stream] = len(vox_feats)
            vox_feats.append(_vox_feats)
            vox_valid.append(_vox_valid)

        if self.joint_slam_streams:
            vox_feats_rgb, vox_valid_rgb = None, None
            if "rgb" in stream2pos:
                i = stream2pos["rgb"]
                vox_feats_rgb, vox_valid_rgb = vox_feats[i], vox_valid[i]
            # treat both slam cameras as additional frames of one camera
            vox_feats_slam = torch.cat(
                [vox_feats[stream2pos[stream]] for stream in ["slaml", "slamr"]], 1
            )
            vox_valid_slam = torch.cat(
                [vox_valid[stream2pos[stream]] for stream in ["slaml", "slamr"]], 1
            )
            # Sum up number of valid projections into each camera for each
            # voxel; the rgb stream contributes its own valid mask.
            if vox_valid_rgb is not None:
                count_feats = self._count_valid_views(vox_valid_slam, vox_valid_rgb)
            else:
                count_feats = self._count_valid_views(vox_valid_slam)  # B 1 C D H W
            vox_feats, count_feats_m = self.aggregate(vox_feats_slam, vox_valid_slam)
            if vox_valid_rgb is not None:
                vox_feats_rgb, count_feats_rgb_m = self.aggregate(
                    vox_feats_rgb, vox_valid_rgb
                )
                vox_feats = torch.cat([vox_feats, vox_feats_rgb], 1)
                count_feats_m = count_feats_m + count_feats_rgb_m
        elif self.joint_streams:
            # treat all streams as additional frames of one camera
            vox_feats = torch.cat(vox_feats, 1)
            vox_valid = torch.cat(vox_valid, 1)
            # Sum up number of valid projections into each camera for each voxel.
            count_feats = self._count_valid_views(vox_valid)  # B 1 C D H W
            vox_feats, count_feats_m = self.aggregate(vox_feats, vox_valid)
        else:
            # concat lifted volumes for all selected video streams
            vox_feats = torch.cat(vox_feats, 2)
            vox_valid = torch.cat(vox_valid, 2)  # B T C D H W
            # Sum up number of valid projections into each camera for each voxel.
            count_feats = self._count_valid_views(vox_valid)  # B 1 C D H W
            vox_feats, count_feats_m = self.aggregate(vox_feats, vox_valid)

        count_feats = count_feats[:, :, 0]
        assert count_feats.shape == (B, 1, vD, vH, vW), f"{count_feats.shape}"
        assert count_feats_m.shape == (B, 1, vD, vH, vW), f"{count_feats_m.shape}"
        return vox_feats, count_feats, count_feats_m

    def forward(self, batch):
        """
        Lift the 2D features in ``batch`` (keys ``"<stream>/feat"``) into a
        gravity-aligned voxel grid anchored at the last frame of the snippet.

        Returns a dict with the voxel features ("voxel/feat", which also
        carries the point mask and free-space mask channels), voxel positions,
        voxel pose, projection counts, the point occupancy input, and the
        per-stream 2D tokens / upsampled features.
        """
        B, T, _, H, W = batch[ARIA_IMG[0]].shape

        # Run CNN on EFM features to features back up to full resolution.
        feats2d = {}
        tokens2d = {}
        for stream in self.streams:
            feats2d[stream] = batch[f"{stream}/feat"]
            # for visualizations
            if not isinstance(feats2d[stream], list):
                tokens2d[stream] = feats2d[stream].detach().cpu()
            else:
                # multi-layer 2d features. Needed by DPT head in Lifter
                tokens2d[stream] = [f.detach().cpu() for f in feats2d[stream]]
            if self.head is not None:
                feats2d[stream] = self.head.forward(feats2d[stream])

        # Compute voxel grid pose.
        cams = batch[ARIA_CALIB[0]]
        device = cams.device
        T_ws = batch[ARIA_SNIPPET_T_WORLD_SNIPPET]  # T_world_rig (one per snippet)
        Ts_sr = batch[ARIA_IMG_T_SNIPPET_RIG[0]]  # Ts_snippet_rig (T per snippet)
        Ts_wr = T_ws @ Ts_sr
        T_wv, selectT = self.get_voxelgrid_pose(cams, T_ws, Ts_sr)

        # Generate voxel grid.
        vD, vH, vW = self.voxel_size
        point_masks = self.get_points_counts(batch, T_wv, vW, vH, vD, return_mask=True)
        free_masks = self.get_freespace_counts(
            batch, T_wv, vW, vH, vD, return_mask=True
        )
        point_info = [point_masks, free_masks]

        vox_v_orig = create_voxel_grid(vW, vH, vD, self.voxel_extent, device)
        vox_v_orig = vox_v_orig.permute(2, 1, 0, 3)  # D H W 3
        vox_v = vox_v_orig.reshape(-1, 3)
        vox_v = vox_v.unsqueeze(0).repeat(B, 1, 1)
        vox_w = T_wv * vox_v
        vox_w = vox_w.reshape(B, vD * vH * vW, 3)  # B DHW 3

        if len(feats2d) > 0:
            # Lift image features to 3D. Project 3D voxel centers into each
            # image and sample.
            vox_feats, count_feats, count_feats_m = self.lift_aggregate_centers(
                batch,
                feats2d,
                vox_w,
                Ts_wr,
                T_wv,
            )
            vox_feats = torch.cat([vox_feats] + point_info, dim=1)
        else:
            vox_feats = torch.cat(point_info, dim=1)
            count_feats = torch.ones(B, 1, vD, vH, vW, device=device)
            count_feats_m = torch.ones(B, 1, vD, vH, vW, device=device)

        out = {}
        # Don't use the masked out versions (_m) because loss functions later on need these.
        for stream, feat2d in feats2d.items():
            out[f"{stream}/feat2d_upsampled"] = feat2d
        for stream, token2d in tokens2d.items():
            out[f"{stream}/token2d"] = token2d
        out["voxel/feat"] = vox_feats  # B x F x D x H x W
        out["voxel/counts"] = count_feats[:, 0]  # B x D x H x W
        # Pass the masked version of counts for debugging.
        out["voxel/counts_m"] = count_feats_m[:, 0]  # B x D x H x W
        out["voxel/pts_world"] = vox_w  # B x N x 3 (N=D*H*W)
        out["voxel/T_world_voxel"] = T_wv  # B x 12
        out["voxel/selectT"] = selectT  # B x 1 (frame that voxel grid is anchored to)
        out["voxel/occ_input"] = point_info[0]
        return out
class VideoBackbone(torch.nn.Module, ABC):
    """
    Snippet Feature Backbone runs image feature extractors for video snippets.
    This lets us easily try out various different backbones.
    """

    def __init__(
        self,
        video_streams: Optional[List[str]] = None,
        pass_batch: bool = True,
        feat_dim: Optional[int] = None,
        correct_vignette: bool = False,
        optimize_vignette: bool = False,
        ensure_rgb: bool = False,
    ):
        """
        Args:
            video_streams: a list of video streams to extract features for. Supported is "rgb", "slaml", "slamr".
                Defaults to ["rgb"].
            pass_batch: pass whole batch dict to the forward_impl if set to true. Otherwise passing the image tensors associated with the stream, instead of passing a dictionary of batch.
            feat_dim: feature dimension of the backbone output, if already known.
            correct_vignette: correct vignette for the image streams.
                NOTE(review): not used in this class; no vignette correction
                module is registered here.
            optimize_vignette: optimize vignette correction for the image streams. This enables backpropagating into the vignettes.
                NOTE(review): not used in this class either.
            ensure_rgb: if set to true, will ensure that the output streams are all 3 channels.
        """
        super().__init__()
        self.ensure_rgb = ensure_rgb
        self._feat_dim = -1
        if feat_dim is not None:
            # Note that FPN will be constructed if feat_dim is passed in by construction (and fpn_levels > 0).
            self.feat_dim = feat_dim
        self.video_streams = video_streams
        if self.video_streams is None:
            self.video_streams = ["rgb"]
        self.pass_batch = pass_batch
        self.stream_to_id = {"rgb": 0, "slaml": 1, "slamr": 2}
        assert set(self.video_streams).issubset(set(self.stream_to_id.keys())), (
            f"{self.video_streams} are not all valid (need to be a subset of {self.stream_to_id.keys()})"
        )
        # per-stream vignette correction modules (empty by default)
        self.vignette_correction = nn.ModuleDict()

    @property
    def feat_dim(self):
        return self._feat_dim

    @feat_dim.setter
    def feat_dim(self, _feat_dim: int):
        self._feat_dim = _feat_dim

    @abstractproperty
    def patch_size(self):
        pass

    def forward_impl(self, img, stream) -> Dict[str, torch.Tensor]:
        """
        forward_impl should return a dict with keys of the desired streams
        mapping to the extracted feature images. Other additional outputs can
        be added as well as needed. A suggested way to return additional
        outputs is to nest their keys under the corresponding streams such as:
        "rgb/feature_scale2" for additional feature outputs for the rgb stream.

        With ``pass_batch`` set, ``img`` is the whole batch and ``stream`` the
        list of all video streams; otherwise they are the image tensor and
        name of a single stream.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement forward_impl"
        )

    def forward(self, batch):
        """
        Extract features for all configured video streams.

        Args:
            batch: the batch dict, or - when only a single stream is
                configured and ``pass_batch`` is False - the image tensor of
                that stream.

        Returns:
            dict with (at least) one entry per configured video stream.

        Raises:
            ValueError: if the images of a stream cannot be retrieved from
                ``batch``.
        """
        out = {}
        if self.pass_batch:
            out = self.forward_impl(batch, self.video_streams)
        else:
            for stream in self.video_streams:
                # if we have a batch dictionary retrieve the corresponding
                # video. If not assume that we are just passed the image
                # tensor of the single configured stream.
                key = ARIA_IMG[self.stream_to_id[stream]]
                if isinstance(batch, dict) and key in batch:
                    im = batch[key]
                elif isinstance(batch, torch.Tensor) and len(self.video_streams) == 1:
                    im = batch
                else:
                    raise ValueError(
                        f"batch not passed correctly {type(batch)} for video streams {self.video_streams}, {key}"
                    )
                if self.ensure_rgb and stream in ["slaml", "slamr"]:
                    # greyscale -> rgb
                    im = torch.cat([im, im, im], 2)
                # correct vignette if desired
                if stream in self.vignette_correction:
                    im = self.vignette_correction[stream](im)
                # accumulate updates into one flat dict
                out.update(self.forward_impl(im, stream))

        assert isinstance(out, dict), (
            f"Output of forward must be of type dict, got {type(out)}"
        )
        assert set(self.video_streams).issubset(set(out.keys()))
        return out


class VideoBackboneDinov2(VideoBackbone):
    """
    Get a snippet feature extractor from Dino v2.
    """

    def __init__(
        self,
        image_tokenizer: DictConfig,
        video_streams: Optional[List[str]] = None,
        freeze_encoder: bool = False,
        correct_vignette: bool = False,
        optimize_vignette: bool = False,
    ):
        """
        Args:
            image_tokenizer: the tokenizer module, or a ``DictConfig`` that is
                instantiated into one.
            video_streams: video streams to extract features for.
            freeze_encoder: NOTE(review): currently unused; freezing is
                controlled by the image tokenizer itself.
            correct_vignette: forwarded to ``VideoBackbone``.
            optimize_vignette: forwarded to ``VideoBackbone``.
        """
        super().__init__(
            video_streams=video_streams,
            pass_batch=False,
            correct_vignette=correct_vignette,
            optimize_vignette=optimize_vignette,
        )
        self.image_tokenizer = image_tokenizer
        if isinstance(image_tokenizer, DictConfig):
            self.image_tokenizer = instantiate(self.image_tokenizer)
        # assert freeze_encoder == self.image_tokenizer.freeze
        # get feature dimension
        self.feat_dim = self.image_tokenizer.feat_dim()
        self._patch_size = self.image_tokenizer.patch_size()
        logger.info("feature dim is %d", self.feat_dim)
        logger.info("down_scale factor is %d", self.patch_size)

    @property
    def patch_size(self):
        return self._patch_size

    def forward_impl(self, img, stream):
        """Tokenize ``img`` and return the features channels-first under ``stream``."""
        # Run tokenizer. handles SLAM images internally.
        img_tokens = self.image_tokenizer.forward(img)
        # BxTxHxWxC -> B, T, C, H, W
        if isinstance(img_tokens, list):
            return {
                stream: [
                    einops.rearrange(t, "b t h w c -> b t c h w") for t in img_tokens
                ]
            }
        return {stream: einops.rearrange(img_tokens, "b t h w c -> b t c h w")}
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright 2018-2019 Open-MMLab. 
#define TOTAL_THREADS 512

// Returns the thread count to use for `work_size` items: the largest power of
// two that does not exceed `work_size`, clamped to the range
// [1, TOTAL_THREADS].
//
// Computed with integer arithmetic only. A log(work_size) / log(2) ratio in
// floating point can land just below the exact exponent for some inputs and
// is undefined for work_size <= 0, so it is avoided here. Non-positive
// `work_size` yields 1.
inline int opt_n_thread(int work_size) {
  if (work_size <= 1) {
    return 1;
  }
  int n_thread = 1;
  // Compare against work_size / 2 rather than doubling first, so the loop
  // cannot overflow for very large work sizes.
  while (n_thread < TOTAL_THREADS && n_thread <= work_size / 2) {
    n_thread *= 2;
  }
  return n_thread < TOTAL_THREADS ? n_thread : TOTAL_THREADS;
}
https://github.com/open-mmlab/OpenPCDet/blob/master/LICENSE #include #include #include #include #include #define CHECK_CUDA(x) \ TORCH_CHECK(x.device().is_cuda(), #x, " must be a CUDAtensor ") #define CHECK_CONTIGUOUS(x) \ TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ") #define CHECK_INPUT(x) \ CHECK_CUDA(x); \ CHECK_CONTIGUOUS(x) #define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) #define CHECK_ERROR(ans) \ { gpuAssert((ans), __FILE__, __LINE__); } inline void gpuAssert(cudaError_t code, const char* file, int line, bool abort = true) { if (code != cudaSuccess) { fprintf( stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line); if (abort) exit(code); } } const int THREADS_PER_BLOCK_NMS = sizeof(unsigned long long) * 8; void boxesoverlapLauncher( const int num_a, const float* boxes_a, const int num_b, const float* boxes_b, float* ans_overlap); void boxesioubevLauncher( const int num_a, const float* boxes_a, const int num_b, const float* boxes_b, float* ans_iou); void nmsLauncher( const float* boxes, unsigned long long* mask, int boxes_num, float nms_overlap_thresh); void nmsNormalLauncher( const float* boxes, unsigned long long* mask, int boxes_num, float nms_overlap_thresh); int boxes_overlap_bev_gpu( at::Tensor boxes_a, at::Tensor boxes_b, at::Tensor ans_overlap) { // params boxes_a: (N, 5) [x1, y1, x2, y2, ry] // params boxes_b: (M, 5) // params ans_overlap: (N, M) CHECK_INPUT(boxes_a); CHECK_INPUT(boxes_b); CHECK_INPUT(ans_overlap); int num_a = boxes_a.size(0); int num_b = boxes_b.size(0); const float* boxes_a_data = boxes_a.data_ptr(); const float* boxes_b_data = boxes_b.data_ptr(); float* ans_overlap_data = ans_overlap.data_ptr(); boxesoverlapLauncher( num_a, boxes_a_data, num_b, boxes_b_data, ans_overlap_data); return 1; } int boxes_iou_bev_gpu( at::Tensor boxes_a, at::Tensor boxes_b, at::Tensor ans_iou) { // params boxes_a: (N, 5) [x1, y1, x2, y2, ry] // params boxes_b: (M, 5) // params ans_overlap: (N, M) CHECK_INPUT(boxes_a); 
CHECK_INPUT(boxes_b); CHECK_INPUT(ans_iou); int num_a = boxes_a.size(0); int num_b = boxes_b.size(0); const float* boxes_a_data = boxes_a.data_ptr(); const float* boxes_b_data = boxes_b.data_ptr(); float* ans_iou_data = ans_iou.data_ptr(); boxesioubevLauncher(num_a, boxes_a_data, num_b, boxes_b_data, ans_iou_data); return 1; } int nms_gpu( at::Tensor boxes, at::Tensor keep, float nms_overlap_thresh, int device_id) { // params boxes: (N, 5) [x1, y1, x2, y2, ry] // params keep: (N) CHECK_INPUT(boxes); CHECK_CONTIGUOUS(keep); cudaSetDevice(device_id); int boxes_num = boxes.size(0); const float* boxes_data = boxes.data_ptr(); long* keep_data = keep.data_ptr(); const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS); unsigned long long* mask_data = NULL; CHECK_ERROR(cudaMalloc( (void**)&mask_data, boxes_num * col_blocks * sizeof(unsigned long long))); nmsLauncher(boxes_data, mask_data, boxes_num, nms_overlap_thresh); // unsigned long long mask_cpu[boxes_num * col_blocks]; // unsigned long long *mask_cpu = new unsigned long long [boxes_num * // col_blocks]; std::vector mask_cpu(boxes_num * col_blocks); // printf("boxes_num=%d, col_blocks=%d\n", boxes_num, col_blocks); CHECK_ERROR(cudaMemcpy( &mask_cpu[0], mask_data, boxes_num * col_blocks * sizeof(unsigned long long), cudaMemcpyDeviceToHost)); cudaFree(mask_data); unsigned long long remv_cpu[col_blocks]; memset(remv_cpu, 0, col_blocks * sizeof(unsigned long long)); int num_to_keep = 0; for (int i = 0; i < boxes_num; i++) { int nblock = i / THREADS_PER_BLOCK_NMS; int inblock = i % THREADS_PER_BLOCK_NMS; if (!(remv_cpu[nblock] & (1ULL << inblock))) { keep_data[num_to_keep++] = i; unsigned long long* p = &mask_cpu[0] + i * col_blocks; for (int j = nblock; j < col_blocks; j++) { remv_cpu[j] |= p[j]; } } } if (cudaSuccess != cudaGetLastError()) printf("Error!\n"); return num_to_keep; } int nms_normal_gpu( at::Tensor boxes, at::Tensor keep, float nms_overlap_thresh, int device_id) { // params boxes: (N, 5) [x1, y1, x2, 
y2, ry] // params keep: (N) CHECK_INPUT(boxes); CHECK_CONTIGUOUS(keep); cudaSetDevice(device_id); int boxes_num = boxes.size(0); const float* boxes_data = boxes.data_ptr(); long* keep_data = keep.data_ptr(); const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS); unsigned long long* mask_data = NULL; CHECK_ERROR(cudaMalloc( (void**)&mask_data, boxes_num * col_blocks * sizeof(unsigned long long))); nmsNormalLauncher(boxes_data, mask_data, boxes_num, nms_overlap_thresh); // unsigned long long mask_cpu[boxes_num * col_blocks]; // unsigned long long *mask_cpu = new unsigned long long [boxes_num * // col_blocks]; std::vector mask_cpu(boxes_num * col_blocks); // printf("boxes_num=%d, col_blocks=%d\n", boxes_num, col_blocks); CHECK_ERROR(cudaMemcpy( &mask_cpu[0], mask_data, boxes_num * col_blocks * sizeof(unsigned long long), cudaMemcpyDeviceToHost)); cudaFree(mask_data); unsigned long long remv_cpu[col_blocks]; memset(remv_cpu, 0, col_blocks * sizeof(unsigned long long)); int num_to_keep = 0; for (int i = 0; i < boxes_num; i++) { int nblock = i / THREADS_PER_BLOCK_NMS; int inblock = i % THREADS_PER_BLOCK_NMS; if (!(remv_cpu[nblock] & (1ULL << inblock))) { keep_data[num_to_keep++] = i; unsigned long long* p = &mask_cpu[0] + i * col_blocks; for (int j = nblock; j < col_blocks; j++) { remv_cpu[j] |= p[j]; } } } if (cudaSuccess != cudaGetLastError()) printf("Error!\n"); return num_to_keep; } ================================================ FILE: efm3d/thirdparty/mmdetection3d/cuda/iou3d.h ================================================ // @lint-ignore-every LICENSELINT // Modified from // https://github.com/open-mmlab/OpenPCDet/blob/master/pcdet/ops/iou3d_nms/src/iou3d_nms.h // License under Apache 2.0 // https://github.com/open-mmlab/OpenPCDet/blob/master/LICENSE #pragma once #include // @manual=//caffe2:torch-cpp int boxes_overlap_bev_gpu( at::Tensor boxes_a, at::Tensor boxes_b, at::Tensor ans_overlap); int boxes_iou_bev_gpu( at::Tensor boxes_a, at::Tensor 
// Minimal 2D point / vector with single-precision coordinates, usable from
// device code. The default constructor leaves both coordinates uninitialized.
struct Point {
  float x, y;

  __device__ Point() {}

  // Coordinates are taken as double and narrowed to float on storage.
  __device__ Point(double _x, double _y) : x(_x), y(_y) {}

  // Overwrites both coordinates.
  __device__ void set(float _x, float _y) {
    x = _x;
    y = _y;
  }

  // Component-wise sum.
  __device__ Point operator+(const Point& b) const {
    Point sum;
    sum.set(x + b.x, y + b.y);
    return sum;
  }

  // Component-wise difference.
  __device__ Point operator-(const Point& b) const {
    Point diff;
    diff.set(x - b.x, y - b.y);
    return diff;
  }
};
// Computes the intersection point of segment p0-p1 with segment q0-q1.
//
// Returns 1 and stores the point in `ans` when the two segments properly
// cross, i.e. the endpoints of each segment lie strictly on opposite sides of
// the other segment's supporting line. Returns 0 otherwise and leaves `ans`
// untouched; because the side tests are strict, segments that only touch or
// are collinear are reported as non-intersecting.
__device__ inline int intersection(
    const Point& p1,
    const Point& p0,
    const Point& q1,
    const Point& q0,
    Point& ans) {
  // fast exclusion: the axis-aligned bounding boxes of the two segments must
  // overlap
  if (check_rect_cross(p0, p1, q0, q1) == 0)
    return 0;

  // check cross standing
  // s1, s2: signed areas of q0 and q1 relative to segment p0-p1
  // s3, s4: signed areas of p0 and p1 relative to segment q0-q1
  float s1 = cross(q0, p1, p0);
  float s2 = cross(p1, q1, p0);
  float s3 = cross(p0, q1, q0);
  float s4 = cross(q1, p1, q0);

  // both products must be strictly positive for a proper crossing
  if (!(s1 * s2 > 0 && s3 * s4 > 0))
    return 0;

  // calculate intersection of two lines
  float s5 = cross(q1, p1, p0);
  if (fabs(s5 - s1) > EPS) {
    // interpolate between q0 and q1, weighted by the signed areas
    ans.x = (s5 * q0.x - s1 * q1.x) / (s5 - s1);
    ans.y = (s5 * q0.y - s1 * q1.y) / (s5 - s1);

  } else {
    // fallback when the denominator above is too small: write both lines in
    // implicit form a*x + b*y + c = 0 and solve the 2x2 system
    float a0 = p0.y - p1.y, b0 = p1.x - p0.x, c0 = p0.x * p1.y - p1.x * p0.y;
    float a1 = q0.y - q1.y, b1 = q1.x - q0.x, c1 = q0.x * q1.y - q1.x * q0.y;
    // NOTE(review): D is not checked against zero here; presumably the
    // crossing test above rules out parallel segments -- confirm.
    float D = a0 * b1 - a1 * b0;

    ans.x = (b0 * c1 - b1 * c0) / D;
    ans.y = (a1 * c0 - a0 * c1) / D;
  }

  return 1;
}
[x1, y1, x2, y2, angle] // params: box_b (5) [x1, y1, x2, y2, angle] float a_x1 = box_a[0], a_y1 = box_a[1], a_x2 = box_a[2], a_y2 = box_a[3], a_angle = box_a[4]; float b_x1 = box_b[0], b_y1 = box_b[1], b_x2 = box_b[2], b_y2 = box_b[3], b_angle = box_b[4]; Point center_a((a_x1 + a_x2) / 2, (a_y1 + a_y2) / 2); Point center_b((b_x1 + b_x2) / 2, (b_y1 + b_y2) / 2); #ifdef DEBUG printf( "a: (%.3f, %.3f, %.3f, %.3f, %.3f), b: (%.3f, %.3f, %.3f, %.3f, %.3f)\n", a_x1, a_y1, a_x2, a_y2, a_angle, b_x1, b_y1, b_x2, b_y2, b_angle); printf( "center a: (%.3f, %.3f), b: (%.3f, %.3f)\n", center_a.x, center_a.y, center_b.x, center_b.y); #endif Point box_a_corners[5]; box_a_corners[0].set(a_x1, a_y1); box_a_corners[1].set(a_x2, a_y1); box_a_corners[2].set(a_x2, a_y2); box_a_corners[3].set(a_x1, a_y2); Point box_b_corners[5]; box_b_corners[0].set(b_x1, b_y1); box_b_corners[1].set(b_x2, b_y1); box_b_corners[2].set(b_x2, b_y2); box_b_corners[3].set(b_x1, b_y2); // get oriented corners float a_angle_cos = cos(a_angle), a_angle_sin = sin(a_angle); float b_angle_cos = cos(b_angle), b_angle_sin = sin(b_angle); for (int k = 0; k < 4; k++) { #ifdef DEBUG printf( "before corner %d: a(%.3f, %.3f), b(%.3f, %.3f) \n", k, box_a_corners[k].x, box_a_corners[k].y, box_b_corners[k].x, box_b_corners[k].y); #endif rotate_around_center(center_a, a_angle_cos, a_angle_sin, box_a_corners[k]); rotate_around_center(center_b, b_angle_cos, b_angle_sin, box_b_corners[k]); #ifdef DEBUG printf( "corner %d: a(%.3f, %.3f), b(%.3f, %.3f) \n", k, box_a_corners[k].x, box_a_corners[k].y, box_b_corners[k].x, box_b_corners[k].y); #endif } box_a_corners[4] = box_a_corners[0]; box_b_corners[4] = box_b_corners[0]; // get intersection of lines Point cross_points[16]; Point poly_center; int cnt = 0, flag = 0; poly_center.set(0, 0); for (int i = 0; i < 4; i++) { for (int j = 0; j < 4; j++) { flag = intersection( box_a_corners[i + 1], box_a_corners[i], box_b_corners[j + 1], box_b_corners[j], cross_points[cnt]); if (flag) { 
// Intersection-over-union of two rotated boxes in bird's-eye view.
//
// params: box_a (5) [x1, y1, x2, y2, angle]
// params: box_b (5) [x1, y1, x2, y2, angle]
// The union in the denominator is clamped to at least EPS.
__device__ inline float iou_bev(const float* box_a, const float* box_b) {
  // Box areas from the unrotated extents.
  const float area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1]);
  const float area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1]);
  const float inter = box_overlap(box_a, box_b);
  return inter / fmaxf(area_a + area_b - inter, EPS);
}
// Builds the pairwise suppression mask used by rotated-box NMS.
//
// params: boxes (N, 5) [x1, y1, x2, y2, ry]
// params: mask (N, N/THREADS_PER_BLOCK_NMS)
//
// Each block handles one (row tile, column tile) pair of
// THREADS_PER_BLOCK_NMS boxes. Thread r of the block compares row box r with
// every box of the column tile and writes one 64-bit word whose bit c is set
// when the rotated IoU with column box c exceeds `nms_overlap_thresh`.
__global__ void nms_kernel(
    const int boxes_num,
    const float nms_overlap_thresh,
    const float* boxes,
    unsigned long long* mask) {
  const int row_start = blockIdx.y;
  const int col_start = blockIdx.x;

  // Number of valid boxes in this row / column tile (the last tile may be
  // partially filled).
  const int row_size = fminf(
      boxes_num - row_start * THREADS_PER_BLOCK_NMS, THREADS_PER_BLOCK_NMS);
  const int col_size = fminf(
      boxes_num - col_start * THREADS_PER_BLOCK_NMS, THREADS_PER_BLOCK_NMS);

  // Stage the column tile in shared memory: one box (5 floats) per thread.
  __shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 5];
  if (threadIdx.x < col_size) {
    const float* src =
        boxes + (THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5;
    float* dst = block_boxes + threadIdx.x * 5;
    for (int c = 0; c < 5; c++) {
      dst[c] = src[c];
    }
  }
  // Every thread must reach the barrier, so no early exit above this line.
  __syncthreads();

  if (threadIdx.x >= row_size) {
    return;
  }

  const int cur_box_idx = THREADS_PER_BLOCK_NMS * row_start + threadIdx.x;
  const float* cur_box = boxes + cur_box_idx * 5;

  // On the diagonal tile only compare against later boxes, so a box never
  // suppresses itself or an earlier box of the same tile.
  const int first = (row_start == col_start) ? threadIdx.x + 1 : 0;

  unsigned long long bits = 0;
  for (int c = first; c < col_size; c++) {
    if (iou_bev(cur_box, block_boxes + c * 5) > nms_overlap_thresh) {
      bits |= 1ULL << c;
    }
  }

  const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
  mask[cur_box_idx * col_blocks + col_start] = bits;
}
// Launches boxes_overlap_kernel on a 2D grid of THREADS_PER_BLOCK x
// THREADS_PER_BLOCK tiles covering the (num_a, num_b) result matrix.
void boxesoverlapLauncher(
    const int num_a,
    const float* boxes_a,
    const int num_b,
    const float* boxes_b,
    float* ans_overlap) {
  dim3 blocks(
      DIVUP(num_b, THREADS_PER_BLOCK),
      DIVUP(num_a, THREADS_PER_BLOCK)); // blockIdx.x(col), blockIdx.y(row)
  dim3 threads(THREADS_PER_BLOCK, THREADS_PER_BLOCK);

  boxes_overlap_kernel<<<blocks, threads>>>(
      num_a, boxes_a, num_b, boxes_b, ans_overlap);
#ifdef DEBUG
  cudaDeviceSynchronize(); // for using printf in kernel function
#endif
}

// Launches boxes_iou_bev_kernel with the same tiling as boxesoverlapLauncher.
void boxesioubevLauncher(
    const int num_a,
    const float* boxes_a,
    const int num_b,
    const float* boxes_b,
    float* ans_iou) {
  dim3 blocks(
      DIVUP(num_b, THREADS_PER_BLOCK),
      DIVUP(num_a, THREADS_PER_BLOCK)); // blockIdx.x(col), blockIdx.y(row)
  dim3 threads(THREADS_PER_BLOCK, THREADS_PER_BLOCK);

  boxes_iou_bev_kernel<<<blocks, threads>>>(
      num_a, boxes_a, num_b, boxes_b, ans_iou);
}

// Launches nms_kernel (rotated IoU): one block per pair of
// THREADS_PER_BLOCK_NMS-sized box tiles, one thread per box of the row tile.
void nmsLauncher(
    const float* boxes,
    unsigned long long* mask,
    int boxes_num,
    float nms_overlap_thresh) {
  dim3 blocks(
      DIVUP(boxes_num, THREADS_PER_BLOCK_NMS),
      DIVUP(boxes_num, THREADS_PER_BLOCK_NMS));
  dim3 threads(THREADS_PER_BLOCK_NMS);
  nms_kernel<<<blocks, threads>>>(boxes_num, nms_overlap_thresh, boxes, mask);
}

// Launches nms_normal_kernel (axis-aligned IoU) with the same configuration
// as nmsLauncher.
void nmsNormalLauncher(
    const float* boxes,
    unsigned long long* mask,
    int boxes_num,
    float nms_overlap_thresh) {
  dim3 blocks(
      DIVUP(boxes_num, THREADS_PER_BLOCK_NMS),
      DIVUP(boxes_num, THREADS_PER_BLOCK_NMS));
  dim3 threads(THREADS_PER_BLOCK_NMS);
  nms_normal_kernel<<<blocks, threads>>>(
      boxes_num, nms_overlap_thresh, boxes, mask);
}
"iou3d_kernel.cu", "iou3d.cpp", "sort_vert_kernel.cu", "sort_vert.cpp", ], ) ], headers=[ "iou3d.h", "sort_vert.h", "cuda_utils.h", "utils.h", ], cmdclass={"build_ext": BuildExtension}, ) ================================================ FILE: efm3d/thirdparty/mmdetection3d/cuda/sort_vert.cpp ================================================ // @lint-ignore-every LICENSELINT // Modified from // https://github.com/lilanxiao/Rotated_IoU/blob/master/cuda_op/sort_vert.cpp // License https://github.com/lilanxiao/Rotated_IoU/blob/master/LICENSE #include "sort_vert.h" #include "iou3d.h" #include "utils.h" void sort_vertices_wrapper( int b, int n, int m, const float* vertices, const bool* mask, const int* num_valid, int* idx); at::Tensor sort_vertices(at::Tensor vertices, at::Tensor mask, at::Tensor num_valid) { CHECK_CONTIGUOUS(vertices); CHECK_CONTIGUOUS(mask); CHECK_CONTIGUOUS(num_valid); CHECK_CUDA(vertices); CHECK_CUDA(mask); CHECK_CUDA(num_valid); CHECK_IS_FLOAT(vertices); CHECK_IS_BOOL(mask); CHECK_IS_INT(num_valid); int b = vertices.size(0); int n = vertices.size(1); int m = vertices.size(2); at::Tensor idx = torch::zeros( {b, n, MAX_NUM_VERT_IDX}, at::device(vertices.device()).dtype(at::ScalarType::Int)); // fix issue with multi-gpu (kernel only works for cuda:0) const at::cuda::OptionalCUDAGuard device_guard(device_of(idx)); sort_vertices_wrapper( b, n, m, vertices.data_ptr(), mask.data_ptr(), num_valid.data_ptr(), idx.data_ptr()); return idx; } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def( "sort_vertices_forward", &sort_vertices, "sort vertices of a convex polygon. 
forward only"); m.def( "boxes_overlap_bev_gpu", &boxes_overlap_bev_gpu, "oriented boxes overlap"); m.def("boxes_iou_bev_gpu", &boxes_iou_bev_gpu, "oriented boxes iou"); m.def("nms_gpu", &nms_gpu, "oriented nms gpu"); m.def("nms_normal_gpu", &nms_normal_gpu, "nms gpu"); } ================================================ FILE: efm3d/thirdparty/mmdetection3d/cuda/sort_vert.h ================================================ // @lint-ignore-every LICENSELINT // Modified from // https://github.com/lilanxiao/Rotated_IoU/blob/master/cuda_op/sort_vert.h // License https://github.com/lilanxiao/Rotated_IoU/blob/master/LICENSE #pragma once #include // @manual=//caffe2:torch-cpp #define MAX_NUM_VERT_IDX 9 at::Tensor sort_vertices(at::Tensor vertices, at::Tensor mask, at::Tensor num_valid); ================================================ FILE: efm3d/thirdparty/mmdetection3d/cuda/sort_vert_kernel.cu ================================================ // @lint-ignore-every LICENSELINT // Modified from // https://github.com/lilanxiao/Rotated_IoU/blob/master/cuda_op/sort_vert_kernel.cu // License https://github.com/lilanxiao/Rotated_IoU/blob/master/LICENSE #include #include #include #include "cuda_utils.h" #define MAX_NUM_VERT_IDX 9 #define INTERSECTION_OFFSET 8 #define EPSILON 1e-8 /* compare normalized vertices (vertices around (0,0)) if vertex1 < vertex2 return ture. 
// Strict "less than" for vertices that are normalized around (0, 0).
//
// Vertices are ordered by their angle, starting at the positive x-axis and
// growing anti-clockwise. Returns true when vertex 1 comes before vertex 2
// and false otherwise, including when both vertices coincide (within
// EPSILON).
//
// Vertices with y == 0 are assigned to the upper half plane. Without that,
// a vertex lying exactly on the x-axis matched none of the branches and
// control reached the end of the function without returning a value.
__device__ bool compare_vertices(float x1, float y1, float x2, float y2) {
  if (fabs(x1 - x2) < EPSILON && fabs(y2 - y1) < EPSILON)
    return false; // if equal, return false

  // The upper half plane (angles in [0, pi]) sorts before the lower one.
  const bool upper1 = y1 >= 0;
  const bool upper2 = y2 >= 0;
  if (upper1 != upper2)
    return upper1;

  // Signed squared cosine of each vertex angle; EPSILON keeps the division
  // defined for a vertex at the origin.
  float n1 = x1 * x1 + y1 * y1 + EPSILON;
  float n2 = x2 * x2 + y2 * y2 + EPSILON;
  const auto diff = fabs(x1) * x1 / n1 - fabs(x2) * x2 / n2;

  // In the upper half the cosine decreases with the angle, in the lower half
  // it increases, hence the opposite comparisons.
  return upper1 ? diff > EPSILON : diff < EPSILON;
}
for (int j = INTERSECTION_OFFSET; j < m; ++j) { if (!mask[i * m + j]) { pad = j; break; } } if (num_valid[i] < 3) { // not enough vertices, take an invalid intersection point // (zero padding) for (int j = 0; j < MAX_NUM_VERT_IDX; ++j) { idx[i * MAX_NUM_VERT_IDX + j] = pad; } } else { // sort the valid vertices // note the number of valid vertices is known for (int j = 0; j < num_valid[i]; ++j) { // initilize with a "big" value float x_min = 1; float y_min = -EPSILON; int i_take = 0; for (int k = 0; k < m; ++k) { float x = vertices[i * m * 2 + k * 2 + 0]; float y = vertices[i * m * 2 + k * 2 + 1]; if (j == 0) { if (mask[i * m + k] && compare_vertices(x, y, x_min, y_min)) { x_min = x; y_min = y; i_take = k; } } else { int i2 = idx[i * MAX_NUM_VERT_IDX + j - 1]; float x2 = vertices[i * m * 2 + i2 * 2 + 0]; float y2 = vertices[i * m * 2 + i2 * 2 + 1]; if (mask[i * m + k] && compare_vertices(x, y, x_min, y_min) && compare_vertices(x2, y2, x, y)) { x_min = x; y_min = y; i_take = k; } } idx[i * MAX_NUM_VERT_IDX + j] = i_take; } } // duplicate the first idx idx[i * MAX_NUM_VERT_IDX + num_valid[i]] = idx[i * MAX_NUM_VERT_IDX + 0]; // pad zeros for (int j = num_valid[i] + 1; j < MAX_NUM_VERT_IDX; ++j) { idx[i * MAX_NUM_VERT_IDX + j] = pad; } // for corner case: the two boxes are exactly the same. 
// in this case, idx would have duplicate elements, which makes the // shoelace formula broken because of the definition, the duplicate // elements only appear in the first 8 positions (they are "corners in // box", not "intersection of edges") if (num_valid[i] == 8) { int counter = 0; for (int j = 0; j < 4; ++j) { int check = idx[i * MAX_NUM_VERT_IDX + j]; for (int k = 4; k < INTERSECTION_OFFSET; ++k) { if (idx[i * MAX_NUM_VERT_IDX + k] == check) counter++; } } if (counter == 4) { idx[i * MAX_NUM_VERT_IDX + 4] = idx[i * MAX_NUM_VERT_IDX + 0]; for (int j = 5; j < MAX_NUM_VERT_IDX; ++j) { idx[i * MAX_NUM_VERT_IDX + j] = pad; } } } } } } void sort_vertices_wrapper( int b, int n, int m, const float* vertices, const bool* mask, const int* num_valid, int* idx) { cudaStream_t stream = at::cuda::getCurrentCUDAStream(); sort_vertices_kernel<<>>( b, n, m, vertices, mask, num_valid, idx); CUDA_CHECK_ERRORS(); } ================================================ FILE: efm3d/thirdparty/mmdetection3d/cuda/utils.h ================================================ // @lint-ignore-every LICENSELINT // Modified from // https://github.com/lilanxiao/Rotated_IoU/blob/master/cuda_op/utils.h // License https://github.com/lilanxiao/Rotated_IoU/blob/master/LICENSE #pragma once #include #include #include // @manual=//caffe2:torch-cpp #define CHECK_CUDA(x) \ do { \ TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor"); \ } while (0) #define CHECK_CONTIGUOUS(x) \ do { \ TORCH_CHECK(x.is_contiguous(), #x " must ne a contiguous tensor"); \ } while (0) #define CHECK_IS_INT(x) \ do { \ TORCH_CHECK( \ x.scalar_type() == at::ScalarType::Int, #x " must be a int tensor"); \ } while (0) #define CHECK_IS_FLOAT(x) \ do { \ TORCH_CHECK( \ x.scalar_type() == at::ScalarType::Float, \ #x " must be a float tensor"); \ } while (0) #define CHECK_IS_BOOL(x) \ do { \ TORCH_CHECK( \ x.scalar_type() == at::ScalarType::Bool, #x " must be a bool tensor"); \ } while (0) ================================================ 
EPSILON = 1e-8


class SortVertices(Function):
    """Autograd wrapper around the ``sort_vertices_forward`` extension op.

    The returned indices are marked non-differentiable, so no gradient is
    propagated through this function.
    """

    @staticmethod
    def forward(ctx, vertices, mask, num_valid):
        sorted_idx = mmdet_iou3d.sort_vertices_forward(vertices, mask, num_valid)
        ctx.mark_non_differentiable(sorted_idx)
        return sorted_idx

    @staticmethod
    def backward(ctx, gradout):
        return tuple()


def box_intersection(corners1: Tensor, corners2: Tensor) -> Tuple[Tensor, Tensor]:
    """Intersect every edge of box 1 with every edge of box 2.

    Convention: collinear / parallel edges yield no intersection point.

    Args:
        corners1 (Tensor): (B, N, 4, 2) First batch of boxes.
        corners2 (Tensor): (B, N, 4, 2) Second batch of boxes.

    Returns:
        Tuple:
         - Tensor: intersection coordinates for the 4 x 4 edge pairs; entries
           without a valid intersection are zeroed.
         - Tensor: boolean mask of valid intersections for the 4 x 4 pairs.
    """
    # Each edge is (x_start, y_start, x_end, y_end): corner i joined to corner i+1.
    edges1 = torch.cat([corners1, corners1[:, :, [1, 2, 3, 0], :]], dim=3)
    edges2 = torch.cat([corners2, corners2[:, :, [1, 2, 3, 0], :]], dim=3)

    # Broadcast so that every edge of box 1 meets every edge of box 2.
    x1, y1, x2, y2 = edges1.unsqueeze(3).split([1, 1, 1, 1], dim=-1)
    x3, y3, x4, y4 = edges2.unsqueeze(2).split([1, 1, 1, 1], dim=-1)

    # math: https://en.wikipedia.org/wiki/Line%E2%80%93line_intersection
    denom = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4)
    parallel = denom == 0.0

    t_num = (x1 - x3) * (y3 - y4) - (y1 - y3) * (x3 - x4)
    t = t_num / denom
    t[parallel] = -1.0
    on_segment1 = (t > 0) & (t < 1)

    u_num = (x1 - x2) * (y1 - y3) - (y1 - y2) * (x1 - x3)
    u = -u_num / denom
    u[parallel] = -1.0
    on_segment2 = (u > 0) & (u < 1)

    mask = on_segment1 * on_segment2

    # Recompute t with EPSILON in the denominator; otherwise numerically unstable.
    t = t_num / (denom + EPSILON)
    points = torch.stack([x1 + t * (x2 - x1), y1 + t * (y2 - y1)], dim=-1)
    points = points * mask.float().unsqueeze(-1)
    return points, mask


def box1_in_box2(corners1: Tensor, corners2: Tensor) -> Tensor:
    """Test which corners of box 1 lie inside box 2.

    Convention: a corner exactly on an edge of the other box counts as inside
    (a tolerance of 1e-6 is applied on both sides).

    Args:
        corners1 (Tensor): (B, N, 4, 2) First batch of boxes.
        corners2 (Tensor): (B, N, 4, 2) Second batch of boxes.

    Returns:
        Tensor: (B, N, 4) True where the corner of box 1 is inside box 2.
    """
    # Box 2 is described by corner a and the two edge vectors a->b and a->d.
    origin = corners2[:, :, 0:1, :]  # (B, N, 1, 2)
    edge_ab = corners2[:, :, 1:2, :] - origin  # (B, N, 1, 2)
    edge_ad = corners2[:, :, 3:4, :] - origin  # (B, N, 1, 2)
    rel = corners1 - origin  # (B, N, 4, 2)

    # Normalised projection of each corner onto both edges; inside means both
    # projections fall in [0, 1]. Dividing by the squared edge length keeps
    # the tolerance independent of the box scale.
    frac_ab = torch.sum(edge_ab * rel, dim=-1) / torch.sum(edge_ab * edge_ab, dim=-1)
    frac_ad = torch.sum(edge_ad * rel, dim=-1) / torch.sum(edge_ad * edge_ad, dim=-1)

    within_ab = (frac_ab > -1e-6) * (frac_ab < 1 + 1e-6)  # (B, N, 4)
    within_ad = (frac_ad > -1e-6) * (frac_ad < 1 + 1e-6)  # (B, N, 4)
    return within_ab * within_ad


def box_in_box(corners1: Tensor, corners2: Tensor) -> Tuple[Tensor, Tensor]:
    """Run the corner-containment test in both directions.

    Args:
        corners1 (Tensor): (B, N, 4, 2) First batch of boxes.
        corners2 (Tensor): (B, N, 4, 2) Second batch of boxes.

    Returns:
        Tuple:
         - Tensor: (B, N, 4) True if i-th corner of box1 is in box2.
         - Tensor: (B, N, 4) True if i-th corner of box2 is in box1.
    """
    return box1_in_box2(corners1, corners2), box1_in_box2(corners2, corners1)
def build_vertices(
    corners1: Tensor,
    corners2: Tensor,
    c1_in_2: Tensor,
    c2_in_1: Tensor,
    intersections: Tensor,
    valid_mask: Tensor,
) -> Tuple[Tensor, Tensor]:
    """Collect all candidate vertices of the intersection polygon.

    Args:
        corners1 (Tensor): (B, N, 4, 2) First batch of boxes.
        corners2 (Tensor): (B, N, 4, 2) Second batch of boxes.
        c1_in_2 (Tensor): (B, N, 4) True if i-th corner of box1 is in box2.
        c2_in_1 (Tensor): (B, N, 4) True if i-th corner of box2 is in box1.
        intersections (Tensor): (B, N, 4, 4, 2) Intersections.
        valid_mask (Tensor): (B, N, 4, 4) Valid intersections mask.

    Returns:
        Tuple:
         - Tensor: (B, N, 24, 2) Candidate vertices (4 + 4 corners followed by
           16 edge intersections); only some elements are valid.
         - Tensor: (B, N, 24) Mask of valid elements in vertices.
    """
    # NOTE: invalid intersections are exactly zero and carry zero gradient
    # (they were masked by multiplication with 0); sort_indices relies on it.
    batch, num = corners1.shape[:2]
    flat_points = intersections.view([batch, num, -1, 2])
    flat_valid = valid_mask.view([batch, num, -1])
    vertices = torch.cat([corners1, corners2, flat_points], dim=2)
    mask = torch.cat([c1_in_2, c2_in_1, flat_valid], dim=2)
    return vertices, mask


def sort_indices(vertices: Tensor, mask: Tensor) -> Tensor:
    """Order the valid vertices of each polygon via ``SortVertices``.

    Note:
        Why 9? The polygon has at most 8 vertices, +1 to repeat the first
        one. The index has the structure (A, B, C, ... , A, X, X, X) where X
        is the index of an arbitrary element among the last 16 entries
        (intersections, not corners) with value 0 and mask False, so it
        contributes zero value and zero gradient.

    Args:
        vertices (Tensor): (B, N, 24, 2) Box vertices.
        mask (Tensor): (B, N, 24) Mask.

    Returns:
        Tensor: (B, N, 9) Sorted indices.
    """
    num_valid = torch.sum(mask.int(), dim=2).int()  # (B, N)
    masked_sum = torch.sum(vertices * mask.float().unsqueeze(-1), dim=2, keepdim=True)
    # Centre on the mean of the valid vertices so angular sorting is
    # well-defined. With zero valid vertices this is 0 / 0.
    centre = masked_sum / num_valid.unsqueeze(-1).unsqueeze(-1)
    return SortVertices.apply(vertices - centre, mask, num_valid).long()


def calculate_area(idx_sorted: Tensor, vertices: Tensor) -> Tuple[Tensor, Tensor]:
    """Compute polygon areas with the shoelace formula.

    Args:
        idx_sorted (Tensor): (B, N, 9) Sorted vertex ids.
        vertices (Tensor): (B, N, 24, 2) Vertices.

    Returns:
        Tuple:
         - Tensor (B, N): Area of intersection.
         - Tensor: (B, N, 9, 2) Vertices of polygon with zero padding.
    """
    gather_idx = idx_sorted.unsqueeze(-1).repeat([1, 1, 1, 2])
    selected = torch.gather(vertices, 2, gather_idx)
    current = selected[:, :, 0:-1]
    following = selected[:, :, 1:]
    cross = current[..., 0] * following[..., 1] - current[..., 1] * following[..., 0]
    area = torch.abs(torch.sum(cross, dim=2)) / 2
    return area, selected


def oriented_box_intersection_2d(
    corners1: Tensor, corners2: Tensor
) -> Tuple[Tensor, Tensor]:
    """Calculate intersection area of 2d rotated boxes.

    Args:
        corners1 (Tensor): (B, N, 4, 2) First batch of boxes.
        corners2 (Tensor): (B, N, 4, 2) Second batch of boxes.

    Returns:
        Tuple:
         - Tensor (B, N): Area of intersection.
         - Tensor (B, N, 9, 2): Vertices of polygon with zero padding.
    """
    points, points_valid = box_intersection(corners1, corners2)
    c1_in_2, c2_in_1 = box_in_box(corners1, corners2)
    vertices, mask = build_vertices(
        corners1, corners2, c1_in_2, c2_in_1, points, points_valid
    )
    return calculate_area(sort_indices(vertices, mask), vertices)


def box2corners(box: Tensor) -> Tensor:
    """Convert rotated 2d box coordinate to corners.

    The corner template is created in the dtype of ``box`` (float32 for
    non-floating input) and directly on its device. The previous hard-coded
    ``torch.FloatTensor`` made half / bfloat16 boxes fail in ``torch.bmm``
    with a dtype mismatch and always allocated on the CPU first.

    Args:
        box (Tensor): (B, N, 5) with x, y, w, h, alpha.

    Returns:
        Tensor: (B, N, 4, 2) Corners.
    """
    B = box.size()[0]
    x, y, w, h, alpha = box.split([1, 1, 1, 1, 1], dim=-1)

    dtype = box.dtype if box.is_floating_point() else torch.float32
    unit_x = torch.tensor([0.5, -0.5, -0.5, 0.5], dtype=dtype, device=box.device)
    unit_y = torch.tensor([0.5, 0.5, -0.5, -0.5], dtype=dtype, device=box.device)
    x4 = unit_x * w  # (B, N, 4)
    y4 = unit_y * h  # (B, N, 4)
    corners = torch.stack([x4, y4], dim=-1)  # (B, N, 4, 2)

    sin = torch.sin(alpha)
    cos = torch.cos(alpha)
    row1 = torch.cat([cos, sin], dim=-1)
    row2 = torch.cat([-sin, cos], dim=-1)  # (B, N, 2)
    rot_T = torch.stack([row1, row2], dim=-2)  # (B, N, 2, 2)

    rotated = torch.bmm(corners.view([-1, 4, 2]), rot_T.view([-1, 2, 2]))
    rotated = rotated.view([B, -1, 4, 2])  # (B * N, 4, 2) -> (B, N, 4, 2)
    # Translate to the box centre (x, y broadcast over the 4 corners).
    rotated[..., 0] += x
    rotated[..., 1] += y
    return rotated
def diff_iou_rotated_2d(box1: Tensor, box2: Tensor) -> Tensor:
    """Differentiable IoU of paired rotated 2d boxes.

    Args:
        box1 (Tensor): (B, N, 5) First box (x, y, w, h, alpha).
        box2 (Tensor): (B, N, 5) Second box (x, y, w, h, alpha).

    Returns:
        Tensor: (B, N) IoU.
    """
    overlap, _ = oriented_box_intersection_2d(box2corners(box1), box2corners(box2))
    area1 = box1[:, :, 2] * box1[:, :, 3]
    area2 = box2[:, :, 2] * box2[:, :, 3]
    return overlap / (area1 + area2 - overlap)


def diff_iou_rotated_3d(box3d1: Tensor, box3d2: Tensor) -> Tensor:
    """Differentiable IoU of paired rotated 3d boxes.

    The boxes are rotated about the z axis only: columns 0, 1, 3, 4, 6 form
    the 2d box used for the in-plane overlap, columns 2 and 5 are the centre
    and extent along z.

    Args:
        box3d1 (Tensor): (B, N, 3+3+1) First box (x,y,z,w,h,l,alpha).
        box3d2 (Tensor): (B, N, 3+3+1) Second box (x,y,z,w,h,l,alpha).

    Returns:
        Tensor: (B, N) IoU.
    """
    planar_cols = [0, 1, 3, 4, 6]
    corners1 = box2corners(box3d1[..., planar_cols])
    corners2 = box2corners(box3d2[..., planar_cols])
    planar_overlap, _ = oriented_box_intersection_2d(corners1, corners2)

    z_top = torch.min(
        box3d1[..., 2] + box3d1[..., 5] * 0.5, box3d2[..., 2] + box3d2[..., 5] * 0.5
    )
    z_bottom = torch.max(
        box3d1[..., 2] - box3d1[..., 5] * 0.5, box3d2[..., 2] - box3d2[..., 5] * 0.5
    )
    z_overlap = (z_top - z_bottom).clamp_(min=0.0)

    overlap_3d = planar_overlap * z_overlap
    volume1 = box3d1[..., 3] * box3d1[..., 4] * box3d1[..., 5]
    volume2 = box3d2[..., 3] * box3d2[..., 4] * box3d2[..., 5]
    return overlap_3d / (volume1 + volume2 - overlap_3d)


def rotated_iou_3d_loss(pred, target):
    """Calculate the IoU loss (1-IoU) of two sets of rotated bounding boxes.

    Predictions and targets correspond one-to-one.

    Args:
        pred (torch.Tensor): Bbox predictions with shape [N, 7]
            (x, y, z, w, l, h, alpha).
        target (torch.Tensor): Bbox targets (gt) with shape [N, 7]
            (x, y, z, w, l, h, alpha).

    Returns:
        torch.Tensor: IoU loss between predictions and targets.
    """
    # Add a batch dimension of 1 for diff_iou_rotated_3d, then strip it again.
    iou = diff_iou_rotated_3d(pred.unsqueeze(0), target.unsqueeze(0))[0]
    return 1 - iou


class RotatedIoU3DLoss(torch.nn.Module):
    """Calculate the IoU loss (1-IoU) of rotated bounding boxes.

    Args:
        loss_weight (float, optional): Weight of loss. Defaults to 1.0.
    """

    def __init__(self, loss_weight=1.0):
        super().__init__()
        self.loss_weight = loss_weight

    def forward(
        self,
        pred,
        target,
    ):
        """Forward function of loss calculation.

        Args:
            pred (torch.Tensor): Bbox predictions with shape [..., 7]
                (x, y, z, w, l, h, alpha).
            target (torch.Tensor): Bbox targets (gt) with shape [..., 7]
                (x, y, z, w, l, h, alpha).

        Returns:
            torch.Tensor: IoU loss between predictions and targets.
        """
        nothing_to_compare = pred.shape[0] == 0 or target.shape[0] == 0
        if nothing_to_compare:
            # Zero that is still connected to pred in the autograd graph.
            return 0.0 * pred.sum()
        return self.loss_weight * rotated_iou_3d_loss(pred, target)


def boxes_iou_bev(boxes_a, boxes_b):
    """Calculate boxes IoU in the bird view.

    Args:
        boxes_a (torch.Tensor): Input boxes a with shape (M, 5).
        boxes_b (torch.Tensor): Input boxes b with shape (N, 5).

    Returns:
        ans_iou (torch.Tensor): IoU result with shape (M, N).
    """
    out_shape = torch.Size((boxes_a.shape[0], boxes_b.shape[0]))
    ans_iou = boxes_a.new_zeros(out_shape)
    # The extension writes its result into ans_iou in place.
    mmdet_iou3d.boxes_iou_bev_gpu(boxes_a.contiguous(), boxes_b.contiguous(), ans_iou)
    return ans_iou


def nms_gpu(boxes, scores, thresh, pre_maxsize=None, post_max_size=None):
    """Nms function with gpu implementation.

    Args:
        boxes (torch.Tensor): Input boxes with the shape of [N, 5]
            ([x1, y1, x2, y2, ry]).
        scores (torch.Tensor): Scores of boxes with the shape of [N].
        thresh (int): Threshold.
        pre_maxsize (int): Max size of boxes before nms. Default: None.
        post_max_size (int): Max size of boxes after nms. Default: None.

    Returns:
        torch.Tensor: Indexes after nms.
    """
    _, order = scores.sort(0, descending=True)
    if pre_maxsize is not None:
        order = order[:pre_maxsize]
    ranked = boxes[order].contiguous()

    # The extension fills `keep` (a CPU long buffer) and returns how many
    # entries are valid.
    keep = torch.zeros(ranked.size(0), dtype=torch.long)
    num_out = mmdet_iou3d.nms_gpu(ranked, keep, thresh, ranked.device.index)
    kept = order[keep[:num_out].cuda(ranked.device)].contiguous()
    if post_max_size is not None:
        kept = kept[:post_max_size]
    return kept


def nms_normal_gpu(boxes, scores, thresh):
    """Normal non maximum suppression on GPU.

    Args:
        boxes (torch.Tensor): Input boxes with shape (N, 5).
        scores (torch.Tensor): Scores of predicted boxes with shape (N).
        thresh (torch.Tensor): Threshold of non maximum suppression.

    Returns:
        torch.Tensor: Remaining indices with scores in descending order.
    """
    _, order = scores.sort(0, descending=True)
    ranked = boxes[order].contiguous()
    keep = torch.zeros(ranked.size(0), dtype=torch.long)
    num_out = mmdet_iou3d.nms_normal_gpu(ranked, keep, thresh, ranked.device.index)
    return order[keep[:num_out].cuda(ranked.device)].contiguous()
def sample_nearest(value_a, value_b, array_b):
    """Pick, for each entry of ``value_a``, the row of ``array_b`` whose
    ``value_b`` is closest.

    Args:
        value_a: iterable of query values.
        value_b: reference values, one per entry of ``array_b``.
        array_b: indexable collection of tensors aligned with ``value_b``.

    Returns:
        torch.Tensor: the selected entries of ``array_b`` stacked along dim 0.
    """
    picked = [
        array_b[find_nearest(value_b, query, return_index=True)] for query in value_a
    ]
    return torch.stack(picked)


def find_nearest(array, value, return_index=False):
    """Return the element of ``array`` closest to ``value`` (or its index).

    Ties resolve to the first occurrence, as ``argmin`` does.
    """
    candidates = np.asarray(array)
    nearest = np.abs(candidates - value).argmin()
    return nearest if return_index else candidates[nearest]
def dist_im_to_point_cloud_im(dist_m, cams):
    """Turn per-pixel distance images into per-pixel 3D point images.

    ``cams`` may carry zero, one (B) or two (B, T) leading batch dimensions;
    ``dist_m`` is expected to carry the same leading dimensions followed by
    (H, W). The outputs keep those leading dimensions.

    Args:
        dist_m: distances along the pixel rays; values <= 0 are invalid.
        cams: camera object(s) accepted by ``ray_grid``.
            NOTE(review): presumably a CameraTW, confirm against callers.

    Returns:
        Tuple of the points after applying ``cams.T_camera_rig``, shaped
        (..., H, W, 3), and a validity mask shaped (..., H, W).
    """
    B, T = None, None
    num_cam_dims = cams.ndim
    if num_cam_dims == 3:
        # (B, T) snippets: fold time into the batch dimension.
        B, T, _ = cams.shape
        cams = cams.view(B * T, -1)
        dist_m = dist_m.flatten(0, 1)
    elif num_cam_dims == 2:
        B, _ = cams.shape
    elif num_cam_dims == 1:
        # Single camera / single image: add a batch dimension of 1.
        cams = cams.view(1, -1)
        H, W = dist_m.shape
        dist_m = dist_m.view(1, H, W)
    BT, H, W = dist_m.shape

    # Rays are stored as (origin xyz, direction xyz) per pixel.
    rays_rig, ray_ok = ray_grid(cams)
    origins = rays_rig[..., :3]
    directions = rays_rig[..., 3:]
    p3s_rig = origins + directions * dist_m.unsqueeze(-1)
    p3s_c = cams.T_camera_rig * p3s_rig.view(BT, -1, 3)

    # distances > 0.0 are valid
    valids = torch.logical_and(ray_ok, dist_m > 0.0)

    # Restore the caller's leading dimensions.
    if T is not None:
        lead = (B, T)
    elif B is not None:
        lead = (B,)
    else:
        lead = ()
    p3s_c = p3s_c.view(*lead, H, W, 3)
    valids = valids.view(*lead, H, W)
    return p3s_c, valids
def norm2ind(norm_xyz, vD, vH, vW):
    """Converts normalized xyz coords [-1,1] to DxHxW indices.

    Accepts either a ``torch.Tensor`` or a ``numpy.ndarray``. The ndarray
    path previously raised because ``torch.ceil`` was applied to the array.

    Args:
        norm_xyz: (..., 3) normalized coordinates in x, y, z order.
        vD, vH, vW: voxel grid size along depth (z), height (y), width (x).

    Returns:
        Tuple:
         - integer indices shaped like ``norm_xyz`` in (d, h, w) order.
         - boolean mask, False where any index is <= 0 or >= size - 1, i.e.
           on or beyond the border voxels of the grid.
    """
    is_numpy = isinstance(norm_xyz, np.ndarray)
    ceil = np.ceil if is_numpy else torch.ceil
    inds_dhw = norm_xyz.copy() if is_numpy else norm_xyz.clone()

    # Note the axis swap: index 0 (depth) comes from z, index 2 (width) from x.
    inds_dhw[..., 0] = ceil((norm_xyz[..., 2] + 1.0) * vD / 2.0) - 1
    inds_dhw[..., 1] = ceil((norm_xyz[..., 1] + 1.0) * vH / 2.0) - 1
    inds_dhw[..., 2] = ceil((norm_xyz[..., 0] + 1.0) * vW / 2.0) - 1
    inds_dhw = inds_dhw.round()

    outside = (
        (inds_dhw[..., 0] <= 0)
        | (inds_dhw[..., 0] >= (vD - 1))
        | (inds_dhw[..., 1] <= 0)
        | (inds_dhw[..., 1] >= (vH - 1))
        | (inds_dhw[..., 2] <= 0)
        | (inds_dhw[..., 2] >= (vW - 1))
    )
    inside = ~outside

    inds_dhw = inds_dhw.astype(int) if is_numpy else inds_dhw.int()
    return inds_dhw, inside


def ind2norm(inds_dhw, vD, vH, vW):
    """Converts DxHxW indices to normalized xyz coords [-1,1].

    Each index maps to the centre of its voxel (hence the + 0.5).
    """
    if isinstance(inds_dhw, np.ndarray):
        norm_xyz = inds_dhw.copy().astype(float)
    else:
        norm_xyz = inds_dhw.clone().float()
    # (source index column, grid size) for x, y, z respectively.
    for axis, (src, size) in enumerate(((2, vW), (1, vH), (0, vD))):
        norm_xyz[..., axis] = 2.0 * (inds_dhw[..., src] + 0.5) / size - 1.0
    return norm_xyz


def normalize_coord3d(xyz, extent):
    """Map coordinates inside ``extent`` linearly to [-1, 1] per axis.

    Args:
        xyz: (..., 3) tensor or ndarray of points.
        extent: (x_min, x_max, y_min, y_max, z_min, z_max).
    """
    xyz_n = xyz.copy() if isinstance(xyz, np.ndarray) else xyz.clone()
    x_min, x_max, y_min, y_max, z_min, z_max = extent
    bounds = ((x_min, x_max), (y_min, y_max), (z_min, z_max))
    for axis, (lo, hi) in enumerate(bounds):
        xyz_n[..., axis] = ((xyz[..., axis] - lo) / ((hi - lo) / 2.0)) - 1.0
    return xyz_n


def unnormalize_coord3d(xyz_n, extent):
    """Inverse of ``normalize_coord3d``: map [-1, 1] back into ``extent``.

    Args:
        xyz_n: (..., 3) tensor or ndarray of normalized points.
        extent: (x_min, x_max, y_min, y_max, z_min, z_max).
    """
    xyz = xyz_n.copy() if isinstance(xyz_n, np.ndarray) else xyz_n.clone()
    x_min, x_max, y_min, y_max, z_min, z_max = extent
    bounds = ((x_min, x_max), (y_min, y_max), (z_min, z_max))
    for axis, (lo, hi) in enumerate(bounds):
        xyz[..., axis] = ((xyz_n[..., axis] + 1.0) * ((hi - lo) / 2.0)) + lo
    return xyz
def splat_points_2d(mu_xy, H, W, valid=None):
    """Rasterize 2D points into a 0/1 hit map (no blurring).

    Args:
        mu_xy : torch.Tensor : BxNx2 pixel locations as (x, y), with x in
            [0, W-1] and y in [0, H-1].
        H : image height
        W : image width
        valid : torch.Tensor : optional boolean mask, combined with the
            in-image test via ``&`` (so it must broadcast against BxN).

    Returns:
        torch.Tensor : BxHxW integer tensor, 1 where at least one usable
        point rounds to that pixel, 0 elsewhere.
    """
    B = mu_xy.shape[0]
    x = mu_xy[..., 0]
    y = mu_xy[..., 1]
    # x indexes columns (bounded by W), y indexes rows (bounded by H); this
    # matches the row-major flattening y * W + x below.
    inside = (x >= 0) & (x <= (W - 1)) & (y >= 0) & (y <= (H - 1))
    if valid is not None:
        # if we have additional valid signal, use it
        inside = inside & valid

    inds_xy = mu_xy.round().long().reshape(B, -1, 2)
    # flatten matrix index into vector index
    inds = inds_xy[:, :, 1] * W + inds_xy[:, :, 0]
    # Keep indices of unusable points in range; the last valid flat index is
    # H * W - 1.
    inds = torch.clip(inds, min=0, max=H * W - 1)

    inside = inside.reshape(B, -1).to(inds)
    hits = torch.zeros((B, H * W)).to(inds)
    # Accumulate instead of overwrite so an unusable point (contributing 0)
    # that was clipped onto an occupied pixel cannot erase a real hit.
    hits.scatter_add_(1, inds, inside)
    return hits.clamp_(max=1).reshape(B, H, W)


def create_heatmap_gt(mu_xy, H, W, valid=None):
    """
    Inputs:
        mu_xy : torch.Tensor : shaped BxNx2 of (x, y) pixel locations with x
                in range [0,W-1] and y in range [0,H-1]
        H : image height
        W : image width:
        valid : torch.Tensor : optional boolean mask (broadcastable to BxN)
                of whether to use this point or not
    returns:
        heat_gt : torch.Tensor : BxHxW tensor of splatted 2D points
    """
    heat_gt = splat_points_2d(mu_xy, H, W, valid).float()

    blur = torchvision.transforms.functional.gaussian_blur
    kernel = 25
    heat_gt = blur(heat_gt, kernel)
    if heat_gt.sum() > 0:
        # Normalize such that peak is ~1.
        heat_gt = heat_gt * 100
        heat_gt = torch.clip(heat_gt, min=0, max=1)
    return heat_gt


def _suppress_non_maxima(scores, max_pool):
    """Shared body of the 2D / 3D NMS: keep local maxima, zero the rest."""
    zeros = torch.zeros_like(scores)
    max_mask = scores == max_pool(scores)
    # Two refinement rounds: blank out everything near an accepted maximum
    # and accept maxima of what is left, so peaks just outside a suppressed
    # neighbourhood are recovered.
    for _ in range(2):
        supp_mask = max_pool(max_mask.float()) > 0
        supp_scores = torch.where(supp_mask, zeros, scores)
        new_max_mask = supp_scores == max_pool(supp_scores)
        max_mask = max_mask | (new_max_mask & (~supp_mask))
    return torch.where(max_mask, scores, zeros)


def simple_nms(scores, nms_radius: int):
    """Approximate + Fast Non-maximum suppression to remove nearby points,
    works by running max pool twice on GPU."""
    assert nms_radius >= 0

    def max_pool(x):
        return torch.nn.functional.max_pool2d(
            x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
        )

    return _suppress_non_maxima(scores, max_pool)


def simple_nms3d(scores, nms_radius: int):
    """Approximate + Fast Non-maximum suppression on 3D heatmap to remove
    nearby points, works by running max pool twice on GPU."""
    assert nms_radius >= 0

    def max_pool(x):
        return torch.nn.functional.max_pool3d(
            x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
        )

    return _suppress_non_maxima(scores, max_pool)
def heatmap2obb(scores, threshold=0.3, size=20, max_elts=1000):
    """Turn 2D heatmaps into ObbTW detections.

    Every pixel whose score exceeds ``threshold`` becomes one detection with
    a fake square 2D bounding box of side ``size`` (20x20 by default)
    centred on it and the pixel score as probability. Each image's
    detections are padded to ``max_elts`` and the results are stacked.
    """
    half = int(round(size / 2))
    device = scores.device
    per_image = []
    for score in scores:
        # Rows of `peaks` are (row, col) positions above the threshold.
        peaks = torch.nonzero(score > threshold)
        rows = peaks[:, 0]
        cols = peaks[:, 1]
        # 2D box layout: xmin, xmax, ymin, ymax.
        bb2 = torch.stack(
            [cols - half, cols + half, rows - half, rows + half], dim=1
        ).float()
        obb = ObbTW().repeat(bb2.shape[0], 1).clone().to(device)
        # cam_id=0 selects the RGB camera.
        obb.set_bb2(cam_id=0, bb2d=bb2, use_mask=False)
        obb.set_prob(score[tuple(peaks.t())])
        per_image.append(obb.add_padding(max_elts=max_elts))
    return torch.stack(per_image, dim=0)


# Centerness loss, binary cross entropy (evaluated densely per voxel position).
def compute_focal_loss(pred, gt, focal_gamma=2, focal_alpha=0.25):
    """focal loss for imbalanced classification
    https://pytorch.org/vision/stable/_modules/torchvision/ops/focal_loss.html

    The computation runs in float64 and the result is cast back to float32.

    Args:
        pred (torch.tensor): predicted probabilities (sigmoid already applied)
        gt (torch.tensor): GT probabilities, same shape as ``pred``
        focal_gamma: focusing exponent; values <= 0 disable the modulation
        focal_alpha: class balancing weight; negative values disable it
    Returns:
        torch.tensor: element-wise loss, same shape as ``pred``
    """
    assert pred.shape == gt.shape
    gt = gt.double()
    pred = pred.double()
    eps = 1e-9

    # Plain binary cross-entropy; eps guards log(0).
    log_pos = torch.log(pred + eps)
    log_neg = torch.log((1.0 - pred) + eps)
    loss = -(log_pos * gt + log_neg * (1.0 - gt))

    if focal_gamma > 0:
        # Down-weight well-classified entries.
        p_t = pred * gt + (1 - pred) * (1 - gt)
        loss = loss * ((1 - p_t) ** focal_gamma)

    # class-wise balancing
    if focal_alpha >= 0:
        alpha_t = focal_alpha * gt + (1 - focal_alpha) * (1.0 - gt)
        loss = alpha_t * loss
    return loss.float()


def compute_chamfer_loss(vals, target):
    """Asymmetric L1 chamfer distance between two sets of 8 3D points.

    Args:
        vals: predicted points, viewable as (B, 8, 3).
        target: ground truth points, viewable as (B, 8, 3).

    Returns:
        (B,) tensor: gt->pred term plus 0.1 times the pred->gt term.
    """
    B = vals.shape[0]
    # Pairwise L1 distances, (B, 8 pred, 8 gt).
    pairwise = (vals.view(B, 8, 1, 3) - target.view(B, 1, 8, 3)).abs().sum(-1)
    gt_to_pred = pairwise.min(1).values.mean(-1)
    pred_to_gt = pairwise.min(2).values.mean(-1)
    return 0.1 * pred_to_gt + gt_to_pred
def obb2voxel(obb_v, vD, vH, vW, voxel_extent, num_class, splat_sigma=2):
    """
    Rasterize oriented bounding boxes into dense per-voxel training targets.

    Each box centre is converted to a voxel index and splatted onto a cube of
    (2 * splat_sigma + 1)^3 neighbouring voxels. Boxes whose centre index is
    flagged outside by norm2ind contribute zeros.

    Inputs:
        obb_v : ObbTW : shaped BxNx34 of obbs in voxel coordinates.
        vD : voxel depth
        vH : voxel height
        vW : voxel width:
        voxel_extent: size of voxel grid in meters, with order W, H, D
            (forwarded to normalize_coord3d / unnormalize_coord3d, which
            unpack it as x_min, x_max, y_min, y_max, z_min, z_max)
        num_class: number of classes to detect
        splat_sigma: how big to splat the Obbs
    returns:
        cent_gt : torch.Tensor : Bx1xDxHxW tensor of splatted 2D points
        bbox_gt : torch.Tensor : Bx7xDxHxW tensor of bounding box params
            (x/y/z size, x/y/z offset to the voxel centre, yaw)
        clas_gt : torch.Tensor : Bxnum_classxDxHxW one hot tensor of class
        valid_gt : torch.Tensor : Bx1xDxHxW bool tensor of where splatting is valid
    """
    B = obb_v.shape[0]
    device = obb_v.device
    cent_gt = torch.zeros((B, 1, vD, vH, vW), device=device)
    bbox_gt = torch.zeros((B, 7, vD, vH, vW), device=device)
    clas_gt = torch.zeros((B, num_class, vD, vH, vW), device=device)
    # Where to apply non-centerness losses.
    valid_gt = torch.zeros((B, 1, vD, vH, vW), device=device, dtype=torch.bool)

    # Gaussian kernel for splatting.
    size = 2 * splat_sigma + 1
    rng = torch.arange(0, size, 1).to(device)
    xx, yy, zz = torch.meshgrid(rng, rng, rng, indexing="ij")
    x0 = y0 = z0 = size // 2
    # eps keeps the denominator non-zero when splat_sigma == 0.
    eps = 1e-6
    gauss = torch.exp(
        -((xx - x0) ** 2 + (yy - y0) ** 2 + (zz - z0) ** 2)
        / (2 * splat_sigma**2 + eps)
    )

    # Convert obb centers to voxel indices.
    cent_v = obb_v.bb3_center_world
    cent_vn = normalize_coord3d(cent_v, voxel_extent)
    inds, inside = norm2ind(cent_vn, vD, vH, vW)

    # Get index offsets for splatting.
    if splat_sigma == 0:
        dd = torch.tensor([0]).reshape(1, 1, 1).to(device)
        hh = torch.tensor([0]).reshape(1, 1, 1).to(device)
        ww = torch.tensor([0]).reshape(1, 1, 1).to(device)
    elif splat_sigma > 0:
        rng_d = torch.arange(start=-splat_sigma, end=splat_sigma + 1).to(device)
        rng_h = torch.arange(start=-splat_sigma, end=splat_sigma + 1).to(device)
        rng_w = torch.arange(start=-splat_sigma, end=splat_sigma + 1).to(device)
        dd, hh, ww = torch.meshgrid(rng_d, rng_h, rng_w, indexing="ij")
    else:
        raise ValueError("splat sigma most be non-negative")
    offsets_dhw = torch.stack((dd.reshape(-1), hh.reshape(-1), ww.reshape(-1)), dim=-1)
    offsets_dhw = offsets_dhw.unsqueeze(0).repeat(B, 1, 1)

    # Use broadcasting to apply the offset indices to the voxel indices.
    # O = number of splat offsets per box, N = number of boxes per batch item.
    O = offsets_dhw.shape[1]
    N = inds.shape[1]
    inds_dhw = inds.reshape(B, N, 1, 3) + offsets_dhw.reshape(B, 1, O, 3)
    inds_dhw = inds_dhw.reshape(B, N * O, 3)
    # Every splatted copy inherits the inside flag of its box centre; used
    # below as a 0/1 multiplier.
    inside = inside.reshape(B, N, 1).repeat(1, 1, O).reshape(B, N * O).float()

    # Avoid accessing OOB.
    ones = torch.ones_like(inds_dhw[:, :, 0])
    inds_dhw[:, :, 0] = torch.maximum(inds_dhw[:, :, 0], 0 * ones)
    inds_dhw[:, :, 1] = torch.maximum(inds_dhw[:, :, 1], 0 * ones)
    inds_dhw[:, :, 2] = torch.maximum(inds_dhw[:, :, 2], 0 * ones)
    inds_dhw[:, :, 0] = torch.minimum(inds_dhw[:, :, 0], (vD - 1) * ones)
    inds_dhw[:, :, 1] = torch.minimum(inds_dhw[:, :, 1], (vH - 1) * ones)
    inds_dhw[:, :, 2] = torch.minimum(inds_dhw[:, :, 2], (vW - 1) * ones)

    # keep the (d, h, w) indices before flattening
    inds_dhw_3d = inds_dhw.clone()

    # Convert D,H,W indices into flat array indices.
    inds_d = inds_dhw[:, :, 0]
    inds_h = inds_dhw[:, :, 1]
    inds_w = inds_dhw[:, :, 2]
    inds_dhw = inds_d * (vH * vW) + inds_h * vW + inds_w
    b_inds = torch.arange(B).reshape(-1, 1).repeat(1, N * O)

    # Set centerness GT.
    # NOTE(review): several boxes can map to the same voxel; the assignment
    # below then keeps whichever write lands last.
    cent_gt = cent_gt.reshape(B, -1)
    gauss = gauss.reshape(1, 1, -1).repeat(B, N, 1).reshape(B, N * O)
    cent_gt[b_inds, inds_dhw] = gauss * inside
    cent_gt = cent_gt.reshape(B, 1, vD, vH, vW)

    # Semantic class.
    CL = num_class
    # Class ids are clipped into [0, CL - 1] before one-hot encoding.
    sem_id = torch.clip(obb_v.sem_id, 0, CL - 1).long()
    one_hot = torch.nn.functional.one_hot(sem_id, num_classes=CL)
    one_hot = one_hot.reshape(B, N, CL).permute(0, 2, 1)
    one_hot = one_hot.reshape(B, CL, N, 1).repeat(1, 1, 1, O)
    one_hot = one_hot.reshape(B, CL, N * O)
    val = one_hot * inside.reshape(B, 1, -1)
    clas_gt = clas_gt.reshape(B, -1, vD * vH * vW)
    b_inds_rep = b_inds.reshape(B, 1, -1).repeat(1, CL, 1)
    cl_inds_rep = torch.arange(CL).reshape(1, CL, 1).repeat(B, 1, O * N)
    inds_dhw_rep = inds_dhw.reshape(B, 1, -1).repeat(1, CL, 1)
    clas_gt[b_inds_rep, cl_inds_rep, inds_dhw_rep] = val
    clas_gt = clas_gt.reshape(B, -1, vD, vH, vW)

    # Get gravity aligned rotation from obb
    T_voxel_object = obb_v.T_world_object.clone()
    # HACK to avoid gimbal lock for padded entries.
    mask = obb_v.get_padding_mask()
    T_voxel_object.R[mask] = PAD_VAL
    rpy = T_voxel_object.to_euler()
    yaw = rpy[:, :, 2]

    # BBox size (in voxel coordinates.)
    bb3 = obb_v.bb3_object
    xsize = bb3[:, :, 1] - bb3[:, :, 0]
    ysize = bb3[:, :, 3] - bb3[:, :, 2]
    zsize = bb3[:, :, 5] - bb3[:, :, 4]

    # Discretized centers.
    centd_vn = ind2norm(inds_dhw_3d, vD, vH, vW)
    centd_v = unnormalize_coord3d(centd_vn, voxel_extent)

    # Compute offset between discretized centers and obb centers.
    cent_v_rep = cent_v.reshape(B, -1, 1, 3).repeat(1, 1, O, 1).reshape(B, N * O, 3)
    offsets = centd_v - cent_v_rep
    xoff = offsets[:, :, 0]
    yoff = offsets[:, :, 1]
    zoff = offsets[:, :, 2]

    # Splat via repeat.
    xsize = xsize.reshape(B, -1, 1).repeat(1, 1, O).reshape(B, N * O)
    ysize = ysize.reshape(B, -1, 1).repeat(1, 1, O).reshape(B, N * O)
    zsize = zsize.reshape(B, -1, 1).repeat(1, 1, O).reshape(B, N * O)
    yaw = yaw.reshape(B, -1, 1).repeat(1, 1, O).reshape(B, N * O)

    # Assign bbox parameters into voxel GT.
    # Channel layout: 0-2 sizes (x, y, z), 3-5 offsets (x, y, z), 6 yaw.
    bbox_gt = bbox_gt.reshape(B, 7, -1)
    BB = bbox_gt.shape[1]
    bb_inds = torch.arange(BB).reshape(-1, 1).repeat(1, N * O)
    bbox_gt[b_inds, bb_inds[0, :], inds_dhw] = xsize * inside
    bbox_gt[b_inds, bb_inds[1, :], inds_dhw] = ysize * inside
    bbox_gt[b_inds, bb_inds[2, :], inds_dhw] = zsize * inside
    bbox_gt[b_inds, bb_inds[3, :], inds_dhw] = xoff * inside
    bbox_gt[b_inds, bb_inds[4, :], inds_dhw] = yoff * inside
    bbox_gt[b_inds, bb_inds[5, :], inds_dhw] = zoff * inside
    bbox_gt[b_inds, bb_inds[6, :], inds_dhw] = yaw * inside
    bbox_gt = bbox_gt.reshape(B, 7, vD, vH, vW)

    # Set valid mask.
    valid_gt = valid_gt.reshape(B, -1)
    valid_gt[b_inds, inds_dhw] = inside.bool()
    valid_gt = valid_gt.reshape(B, 1, vD, vH, vW)

    return cent_gt, bbox_gt, clas_gt, valid_gt
bbox_gt = bbox_gt.reshape(B, 7, -1) BB = bbox_gt.shape[1] bb_inds = torch.arange(BB).reshape(-1, 1).repeat(1, N * O) bbox_gt[b_inds, bb_inds[0, :], inds_dhw] = xsize * inside bbox_gt[b_inds, bb_inds[1, :], inds_dhw] = ysize * inside bbox_gt[b_inds, bb_inds[2, :], inds_dhw] = zsize * inside bbox_gt[b_inds, bb_inds[3, :], inds_dhw] = xoff * inside bbox_gt[b_inds, bb_inds[4, :], inds_dhw] = yoff * inside bbox_gt[b_inds, bb_inds[5, :], inds_dhw] = zoff * inside bbox_gt[b_inds, bb_inds[6, :], inds_dhw] = yaw * inside bbox_gt = bbox_gt.reshape(B, 7, vD, vH, vW) # Set valid mask. valid_gt = valid_gt.reshape(B, -1) valid_gt[b_inds, inds_dhw] = inside.bool() valid_gt = valid_gt.reshape(B, 1, vD, vH, vW) return cent_gt, bbox_gt, clas_gt, valid_gt def voxel2obb( cent_pr, bbox_pr, clas_pr, voxel_extent, top_k=None, thresh=None, return_full_prob=False, ): """Convert 3D centerness, size, rotation voxel grids to ObbTW objects, returning objects in the voxel coordinate frame. Can optionally threshold based on a topK predictions. """ device = cent_pr.device assert cent_pr.ndim == 5 B, _, vD, vH, vW = cent_pr.shape device = cent_pr.device # Get extent. xhalf = bbox_pr[:, 0] / 2.0 yhalf = bbox_pr[:, 1] / 2.0 zhalf = bbox_pr[:, 2] / 2.0 bb3 = torch.stack( [ -xhalf, +xhalf, -yhalf, +yhalf, -zhalf, +zhalf, ], dim=-1, ) # Get rotation to set T_world_object. 
yaw = bbox_pr[:, 6] zeros = torch.zeros_like(yaw) e_angles = torch.stack([zeros, zeros, yaw], dim=-1) R = rotation_from_euler(e_angles.reshape(-1, 3)) R = R.reshape(B, vD, vH, vW, 3, 3) t_zero = torch.zeros(B, vD, vH, vW, 3).to(device) T_voxel_object = PoseTW.from_Rt(R, t_zero) rngd = torch.arange(vD).to(device) rngh = torch.arange(vH).to(device) rngw = torch.arange(vW).to(device) xx, yy, zz = torch.meshgrid(rngd, rngh, rngw, indexing="ij") inds = torch.stack([xx.reshape(-1), yy.reshape(-1), zz.reshape(-1)], dim=-1) norm_centers = ind2norm(inds, vD, vH, vW) centers_v = unnormalize_coord3d(norm_centers, voxel_extent) centers_v = centers_v.reshape(1, vD, vH, vW, 3).repeat(B, 1, 1, 1, 1) # The center is defined as the voxel center + the offset. xoff = bbox_pr[:, 3] yoff = bbox_pr[:, 4] zoff = bbox_pr[:, 5] t_off = torch.stack([xoff, yoff, zoff], dim=-1) T_voxel_object.t[:] = centers_v - t_off # Get prob. prob = cent_pr.reshape(B, vD, vH, vW, 1) N = inds.shape[0] # Get instance id, use voxel location for this. inst_id = torch.arange(N).reshape(1, vD, vH, vW, 1).repeat(B, 1, 1, 1, 1) # Get semantic id sem_id = torch.argmax(clas_pr, dim=1).unsqueeze(-1) # Construct ObbTW object. obbs = ObbTW.from_lmc( bb3_object=bb3, T_world_object=T_voxel_object, prob=prob, inst_id=inst_id, sem_id=sem_id, ) # Optionally remove detections below threshold. if thresh is not None: below = (obbs.prob < thresh).squeeze(-1) obbs._data[below, :] = PAD_VAL # Optionally subselect top K. 
def get_gt_obbs(batch, voxel_extent, T_wv=None):
    """Fetch the ground-truth ObbTW boxes for a batch.

    Args:
        batch: mapping from ARIA_* keys to batched data.
        voxel_extent: forwarded to ``obb_filter_outside_volume`` as the
            extent of the voxel grid.
        T_wv: pose of the voxel grid in the world frame (presumably
            T_world_voxel -- confirm against the caller). When not None, the
            boxes are passed through ``obb_filter_outside_volume``.

    Returns:
        A clone of ``batch[ARIA_OBB_PADDED]`` after the optional filters, or,
        when that key is missing, a default ``ObbTW()`` repeated once per
        batch element.
    """
    if ARIA_OBB_PADDED not in batch:
        B = batch[ARIA_SNIPPET_T_WORLD_SNIPPET].shape[0]
        return ObbTW().view(1, -1).repeat(B, 1)

    obbs_gt = batch[ARIA_OBB_PADDED].clone()

    # A 4-D tensor is merged with obb_time_union first (assumes the extra
    # dimension is time -- TODO confirm against the dataset).
    if batch[ARIA_OBB_PADDED].ndim == 4:
        obbs_gt = obb_time_union(obbs_gt)

    if T_wv is not None:
        # Filter outside of voxel grid.
        T_ws = batch[ARIA_SNIPPET_T_WORLD_SNIPPET].squeeze(1)
        obbs_gt = obb_filter_outside_volume(
            obbs_gt, T_ws, T_wv, voxel_extent=voxel_extent
        )
    return obbs_gt


def obbs_to_7d(obbs):
    """Pack boxes as [center (3), size (3), yaw (1)] along the last dim.

    The size is ``bb3_max_object - bb3_min_object`` and the yaw is the third
    Euler angle of ``T_world_object``.
    """
    obbs_cent = obbs.bb3_center_world  # center in voxel coords
    wlh = obbs.bb3_max_object - obbs.bb3_min_object

    # Get gravity aligned rotation from obb.
    T_voxel_object = obbs.T_world_object.clone()
    # HACK to avoid gimbal lock for padded entries.
    mask = obbs.get_padding_mask()
    T_voxel_object.R[mask] = PAD_VAL
    rpy = T_voxel_object.to_euler()
    yaw = rpy[..., 2].unsqueeze(-1)

    return torch.concat([obbs_cent, wlh, yaw], dim=-1)


def iou_3d_loss(obbs_pr, obbs_gt, cent_pr, cent_gt, valid_gt):
    """Rotated 3D IoU loss over voxels that are valid and have GT centerness.

    Args:
        obbs_pr: N x 34 predicted boxes.
        obbs_gt: N x 34 ground-truth boxes.
        cent_pr: unused; kept for interface compatibility.
        cent_gt: GT centerness, flattened to N weights.
        valid_gt: GT validity mask, flattened to N weights.

    Returns:
        Scalar tensor: mean of the per-box IoU loss weighted by
        ``cent_gt * valid_gt``. When no voxel has a positive weight the
        result is a zero that is still connected to ``obbs_pr``.
    """
    assert obbs_pr.ndim == 2 and obbs_gt.ndim == 2, "obbs dimension should be Nx34"
    obbs_pr_7d = obbs_to_7d(obbs_pr)
    obbs_gt_7d = obbs_to_7d(obbs_gt)
    iou_loss = RotatedIoU3DLoss(loss_weight=1.0)

    # weighted by validness and GT centerness
    obbs_weight = cent_gt.reshape(-1) * valid_gt.reshape(-1)
    # reshape(-1) keeps the index 1-D even for a single hit; a bare squeeze()
    # would give a 0-dim index and strip the batch dimension below.
    valid_idx = torch.nonzero(obbs_weight > 0).reshape(-1)
    obbs_weight = obbs_weight[valid_idx]
    obbs_pr_7d = obbs_pr_7d[valid_idx, :]
    obbs_gt_7d = obbs_gt_7d[valid_idx, :]

    if valid_idx.numel() == 0:
        # Nothing to supervise: the mean of an empty tensor would be NaN.
        # Summing the empty selection yields 0 while staying differentiable.
        return obbs_pr_7d.sum()

    loss = obbs_weight * iou_loss.forward(obbs_pr_7d, obbs_gt_7d)
    return loss.mean()
frame. T_vs = T_wv.inverse() @ T_ws obb_gt_v = obb_gt_s.transform(T_vs.unsqueeze(1)) # Create 3D GT tensors. cent_gt, bbox_gt, clas_gt, valid_gt = obb2voxel( obb_gt_v, vD, vH, vW, ve, num_class, splat_sigma ) outputs["cent_gt"] = cent_gt outputs["bbox_gt"] = bbox_gt outputs["clas_gt"] = clas_gt outputs["valid_gt"] = valid_gt # Get Obbs from densified predictions + GT. cent_pr = outputs["cent_pr"] bbox_pr = outputs["bbox_pr"] clas_pr = outputs["clas_pr"] obbs_pr_dense = voxel2obb(cent_pr, bbox_pr, clas_pr, ve, top_k=None, thresh=None) obbs_gt_dense = voxel2obb(cent_gt, bbox_gt, clas_gt, ve, top_k=None, thresh=None) obbs_pr_dense = obbs_pr_dense.reshape(B * N, -1) obbs_gt_dense = obbs_gt_dense.reshape(B * N, -1) outputs["obbs_gt_dense"] = obbs_gt_dense losses = {"rgb": {}} total_loss = 0.0 # Centerness loss. if cent_weight > 0: cent_pr = outputs["cent_pr"] cent_loss = compute_focal_loss(cent_pr, cent_gt) cent_loss = cent_loss.reshape(B, -1) cent_loss = cent_loss.mean() cent_loss = cent_loss * cent_weight losses["rgb"]["cent"] = cent_loss total_loss += cent_loss # Classification loss. if clas_weight > 0: clas_pr = outputs["clas_pr"] clas_loss = compute_focal_loss(clas_pr, clas_gt) clas_loss = clas_loss.sum(dim=1).reshape(-1) clas_loss[~valid_gt.reshape(-1)] = 0.0 clas_loss = torch.sum(clas_loss) / (valid_gt.sum() + 1) clas_loss = clas_loss * clas_weight losses["rgb"]["clas"] = clas_loss total_loss += clas_loss # 3D IoU loss (gravity aligned 7 DoF loss). if iou_weight > 0: iou_loss = iou_3d_loss(obbs_pr_dense, obbs_gt_dense, cent_pr, cent_gt, valid_gt) iou_loss = iou_loss * iou_weight losses["rgb"]["iou"] = iou_loss total_loss += iou_loss # Supervise directly on D, H, W dimensions with L1 loss. 
def compute_occ_losses(
    outputs,
    batch,
    voxel_extent,
    occ_weight,
    tv_weight,
):
    """Compute the occupancy and total-variation losses for a batch.

    Args:
        outputs: model outputs; reads "occ_pr", "voxel/T_world_voxel" and
            "voxel/counts".
        batch: input batch; reads the first calibration, the first
            snippet-from-rig pose entry and the world-from-snippet pose, plus
            whatever ``get_points_world`` consumes.
        voxel_extent: forwarded to ``compute_occupancy_loss_subvoxel``.
        occ_weight: weight of the occupancy loss; skipped unless > 0.
        tv_weight: weight of the total-variation loss; skipped unless > 0.

    Returns:
        ``(losses, total_loss)`` where ``losses["rgb"]`` holds detached CPU
        copies of the weighted terms ("occ", "tv") and ``total_loss`` is
        their differentiable sum (the float 0.0 if both are disabled).
    """
    # Unpacking exactly five dims doubles as a check that occ_pr is 5-D.
    _b, _t, _vd, _vh, _vw = outputs["occ_pr"].shape
    T_wv = outputs["voxel/T_world_voxel"]

    losses = {"rgb": {}}
    total_loss = 0.0

    p3s_w, _dist_stds = get_points_world(batch)

    # Camera poses in the world frame, needed by the occupancy loss.
    cams = batch[ARIA_CALIB[0]]
    Ts_sr = batch[ARIA_IMG_T_SNIPPET_RIG[0]]
    T_ws = batch[ARIA_SNIPPET_T_WORLD_SNIPPET]
    Ts_wr = T_ws @ Ts_sr
    Ts_cw = cams.T_camera_rig @ Ts_wr.inverse()
    Ts_wc = Ts_cw.inverse()

    occ = outputs["occ_pr"].squeeze(1)
    voxel_counts = outputs["voxel/counts"]
    _, D, H, W = occ.shape
    _, Df, Hf, Wf = voxel_counts.shape
    if D != Df or H != Hf or W != Wf:
        # Bring the observation counts to the occupancy grid resolution.
        resize = torch.nn.Upsample(size=(D, H, W))
        voxel_counts = resize(voxel_counts.unsqueeze(1).float()).squeeze(1)
    visible = voxel_counts > 0

    if occ_weight > 0:
        occ_loss = compute_occupancy_loss_subvoxel(
            occ,
            visible,
            p3s_w,
            Ts_wc,
            cams,
            T_wv,
            voxel_extent,
            loss_type="l2",
        )
        occ_loss = occ_loss * occ_weight
        total_loss += occ_loss
        losses["rgb"]["occ"] = occ_loss.cpu().detach()

    if tv_weight > 0.0:
        tv_loss = compute_tv_loss(occ)
        tv_loss = tv_loss * tv_weight
        total_loss += tv_loss
        losses["rgb"]["tv"] = tv_loss.cpu().detach()

    return losses, total_loss
def load_gt_calibration(
    calib_path: Union[str, dict], load_torch=False, timestamps=None
):
    """Load a ground-truth calibration (simulation JSON) into dictionaries.

    Args:
        calib_path: path/URL of the calibration JSON, or the already parsed
            dict. A dict argument is only read, never modified.
        load_torch: keep poses as PoseTW and intrinsics as torch tensors
            instead of converting them to numpy.
        timestamps: optional iterable; when given, the result maps every
            timestamp to the same calibration dict.

    Returns:
        A dict with "T_rig_views", "intr_type" and "intr_params", each keyed
        by camera name, or ``{timestamp: that dict}`` if ``timestamps`` is
        not None.

    Raises:
        IOError: if ``calib_path`` is neither a str nor a dict.
    """
    if isinstance(calib_path, str):
        with fsspec.open(calib_path, "r") as f:
            calib = json.load(f)
    elif isinstance(calib_path, dict):
        calib = calib_path
    else:
        raise IOError("calib_path must be str or dict")

    gt_calib = {"T_rig_views": {}, "intr_type": {}, "intr_params": {}}

    cam_names = ARIA_CAM_INFO["name"]
    # Maps names from the gt_calib.json file to the ARIA_CAM_INFO convention.
    name_map = {
        "camera-rgb": cam_names[0],
        "camera-slam-left": cam_names[1],
        "camera-slam-right": cam_names[2],
    }

    for camera in calib["CameraCalibrations"]:
        cn = camera["Label"]
        if cn not in name_map:
            # Ignore other cameras like eye tracking.
            continue
        cam_name = name_map[cn]

        [tx, ty, tz] = camera["T_Device_Camera"]["Translation"]
        [qw, [qx, qy, qz]] = camera["T_Device_Camera"]["UnitQuaternion"]
        rot_mat = Quaternion(qw, qx, qy, qz).rotation_matrix
        translation = torch.tensor([tx, ty, tz]).view(3, 1)
        T_rig_view = torch.concat([torch.tensor(rot_mat), translation], dim=1)
        T_rig_view = PoseTW.from_matrix3x4(T_rig_view)
        T_rig_view = T_rig_view.fit_to_SO3()
        if not load_torch:
            T_rig_view = T_rig_view.numpy()
        gt_calib["T_rig_views"][cam_name] = T_rig_view

        intr_type = camera["Projection"]["Name"]
        # Work on a copy so the caller's calibration dict is left untouched.
        params = list(camera["Projection"]["Params"])
        if intr_type == "Fisheye624":
            # Fisheye62 (6+2+3=11 params) is labelled Fisheye624 here; pad
            # with zeros up to the 15 params of a real Fisheye624 model.
            params.extend([0] * (15 - len(params)))
        intr_params = np.array(params)
        if load_torch:
            intr_params = torch.from_numpy(intr_params)
        gt_calib["intr_type"][cam_name] = intr_type
        gt_calib["intr_params"][cam_name] = intr_params

    if timestamps is not None:
        # Every timestamp shares the same calibration object.
        return {timestamp: gt_calib for timestamp in timestamps}

    return gt_calib


def get_image_info(image_reader: SyncVRSReader) -> Tuple[Dict, Dict]:
    """Read per-camera image size and frame rate from VRS configuration records.

    These fields are not part of the calibration, so they are queried through
    the VRS reader. Streams whose id is not in ``ARIA_CAM_INFO["id_to_name"]``
    are ignored.

    Returns:
        ``(image_sizes, fps)``: ``image_sizes[name]`` is ``(height, width)``
        and ``fps[name]`` is the stream's "nominal_rate".
    """
    image_sizes = {}
    fps = {}
    image_config_reader = image_reader.filtered_by_fields(
        record_types=["configuration"]
    )
    for image_config in image_config_reader:
        assert image_config.record_type == "configuration"
        stream_id = image_config.stream_id
        if stream_id not in ARIA_CAM_INFO["id_to_name"]:
            continue
        name = ARIA_CAM_INFO["id_to_name"][stream_id]
        metadata = image_config.metadata_blocks[0]
        image_sizes[name] = metadata["image_height"], metadata["image_width"]
        fps[name] = metadata["nominal_rate"]
    return image_sizes, fps
params=cam_calib["intr_params"][cam_name], T_camera_rig=cam_calib["T_rig_views"][cam_name].inverse(), ) return cam_calib def load_2d_bounding_boxes(bb2d_path, time_in_secs=False): bb2ds = {} try: with fsspec.open(bb2d_path).open() as f: # genfromtxt handles missing values and lets us specify dtypes. # #Object_UID, timestamp [nanoseconds], x_min [pixel], x_max [pixel], y_min [pixel], y_max [pixel] lines = np.genfromtxt( f, dtype=[int] * 2 + [float] * 4, names=True, delimiter=",", usecols=range(6), ) except Exception: try: # sometimes the last row is bad for some reason so we just skip it with fsspec.open(bb2d_path).open() as f: # genfromtxt handles missing values and lets us specify dtypes. # #Object_UID, timestamp [nanoseconds], x_min [pixel], x_max [pixel], y_min [pixel], y_max [pixel] lines = np.genfromtxt( f, dtype=[int] * 2 + [float] * 4, names=True, delimiter=",", usecols=range(6), skip_footer=1, ) except Exception as e: print(f"could not load {bb2d_path}; error {e}") return bb2ds count = 0 for line in lines: object_id = line[0] timestamp_ns = line[1] if time_in_secs: timestamp = timestamp_ns / 1e9 else: timestamp = timestamp_ns x_min = max(0, line[2]) x_max = max(0, line[3]) y_min = max(0, line[4]) y_max = max(0, line[5]) # invalid entries will have nan as fill value; we skip them. 
def load_2d_bounding_boxes_adt(bb2d_path):
    """Load ADT 2D bounding boxes, which pack all three cameras in one CSV.

    Expected header:
    stream_id,object_uid,timestamp[ns],x_min[pixel],x_max[pixel],y_min[pixel],y_max[pixel],visibility_ratio[%]

    Rows with a NaN coordinate are skipped; the remaining coordinates are
    clamped to be non-negative.

    Returns:
        ``(bb2ds_rgb, bb2ds_slaml, bb2ds_slamr)`` for stream ids "214-1",
        "1201-1" and "1201-2". Each maps a timestamp in ns to a list of
        ``(object_id, x_min, x_max, y_min, y_max)`` tuples.

    Raises:
        IOError: if a row carries any other stream id.
    """
    bb2ds_rgb = {}
    bb2ds_slaml = {}
    bb2ds_slamr = {}
    by_stream = {"214-1": bb2ds_rgb, "1201-1": bb2ds_slaml, "1201-2": bb2ds_slamr}

    with fsspec.open(bb2d_path).open() as f:
        lines = f.readlines()

    count = 0
    for raw in lines[1:]:  # skip header
        text = raw.decode("utf-8").rstrip()
        if not text:
            continue  # tolerate blank lines, e.g. a trailing newline
        fields = text.split(",")
        device_id = str(fields[0])
        object_id = int(fields[1])
        timestamp = int(fields[2])  # ns
        coords = [float(v) for v in fields[3:7]]
        # Invalid entries have nan as fill value; we skip them. This must run
        # before clamping because max(0, nan) evaluates to 0.
        if any(c != c for c in coords):
            continue
        x_min, x_max, y_min, y_max = (max(0, c) for c in coords)

        if device_id not in by_stream:
            raise IOError(f"unexpected device id {device_id} in 2d observations")
        by_stream[device_id].setdefault(timestamp, []).append(
            (object_id, x_min, x_max, y_min, y_max)
        )
        count += 1

    print(
        f"loaded {count} 2d bbs for {len(bb2ds_rgb)}[rgb] {len(bb2ds_slaml)}[slaml] {len(bb2ds_slamr)}[slamr] timestamps from {bb2d_path}"
    )
    return bb2ds_rgb, bb2ds_slaml, bb2ds_slamr


def remove_invalid_2d_bbs(timed_bb2s, filter_bb2_area=-1):
    """Drop 2D boxes whose x-range or y-range is entirely <= 0.

    In some datasets (DlrSim) such boxes indicate the object is not visible.

    Args:
        timed_bb2s: mapping of timestamp to a list of
            ``(object_id, x_min, x_max, y_min, y_max)`` tuples.
        filter_bb2_area: when > 0, also drop boxes whose area is below it.

    Returns:
        A ``defaultdict(list)`` with the surviving boxes per timestamp;
        timestamps without survivors are absent.
    """
    bb2s_filtered = defaultdict(list)
    for time, bb2s in timed_bb2s.items():
        for bb2 in bb2s:
            x_collapsed = bb2[1] <= 0 and bb2[2] <= 0
            y_collapsed = bb2[3] <= 0 and bb2[4] <= 0
            if x_collapsed or y_collapsed:
                continue
            if filter_bb2_area > 0:
                bb2_area = (bb2[2] - bb2[1]) * (bb2[4] - bb2[3])
                if bb2_area < filter_bb2_area:
                    continue
            bb2s_filtered[time].append(bb2)
    return bb2s_filtered


def load_instances(instances_path):
    """Load the instance-uid to prototype-uid mapping from a local CSV.

    The first line is treated as a header; column 0 is the integer instance
    uid and column 1 the prototype uid (surrounding whitespace stripped).
    """
    instance2proto = {}
    assert os.path.exists(instances_path), (
        f"instances path {instances_path} does not exist"
    )
    with open(instances_path, "r") as f:
        lines = f.readlines()
    for line in lines[1:]:  # skip first line
        fields = line.rstrip().split(",")
        if fields == [""]:
            continue  # tolerate blank lines
        instance2proto[int(fields[0])] = str(fields[1]).strip()
    return instance2proto


def load_instances_adt(instances_path):
    """Load the instance-id to category mapping from an ADT instances JSON."""
    with fsspec.open(instances_path).open() as f:
        content = json.load(f)
    # lot of other info available per instance, for example:
    # {'instance_id': 5691266090916432, 'instance_name': 'Hook_4',
    #  'prototype_name': 'Hook', 'category': 'hook', 'category_uid': 643,
    #  'motion_type': 'static', 'instance_type': 'object', 'rigidity': 'rigid',
    #  'rotational_symmetry': {'is_annotated': False},
    #  'canonical_pose': {'up_vector': [0, 1, 0], 'front_vector': [0, 0, 1]}}
    return {int(inst_id): info["category"] for inst_id, info in content.items()}
def load_3d_bounding_box_local_extents(bb3d_path, load_torch=False):
    """Read per-object local 3D box extents from a CSV file.

    Expected columns: Object UID, Timestamp (ns), then
    p_local_obj.xmin, xmax, ymin, ymax, zmin, zmax. The timestamp column is
    parsed but not used. Rows with a NaN extent are skipped.

    Returns:
        Dict mapping object id to ``[xmin, xmax, ymin, ymax, zmin, zmax]`` as
        a numpy array, or a torch tensor when ``load_torch`` is True.
    """
    extents_by_object = {}
    with fsspec.open(bb3d_path).open() as f:
        # genfromtxt fills missing values with nan for the float columns.
        rows = np.genfromtxt(
            f,
            dtype=[int] * 2 + [float] * 6,
            names=True,
            delimiter=",",
            usecols=range(8),
        )
    # A single data row comes back 0-dimensional; make it iterable.
    if rows.size == 1:
        rows = rows[np.newaxis]

    for row in rows:
        object_id = row[0]
        extent = [row[col] for col in range(2, 8)]
        # nan != nan, so this detects the fill value of invalid entries.
        if any(value != value for value in extent):
            continue
        local = np.array(extent)
        if load_torch:
            local = torch.from_numpy(local)
        extents_by_object[object_id] = local
    return extents_by_object
bb2s_path_rgb = exists_nonzero_path( [ os.path.join(input_dir, "2d_bounding_box.csv"), os.path.join(input_dir, "2d_bounding_box_rgb.csv"), os.path.join(input_dir, "sensor_0_2d_bounding_box.csv"), ] ) bb2s_path_slaml, bb2s_path_slamr = False, False if not rgb_only: bb2s_path_slaml = exists_nonzero_path( [ os.path.join(input_dir, "2d_bounding_box_2.csv"), os.path.join(input_dir, "2d_bounding_box_left_slam.csv"), os.path.join(input_dir, "sensor_1_2d_bounding_box.csv"), ] ) bb2s_path_slamr = exists_nonzero_path( [ os.path.join(input_dir, "2d_bounding_box_3.csv"), os.path.join(input_dir, "2d_bounding_box_right_slam.csv"), os.path.join(input_dir, "sensor_2_2d_bounding_box.csv"), ] ) bb2_loaded = False if bb2s_path_rgb: # ADT dataset packs all three bb2 observations into one file with fsspec.open(bb2s_path_rgb).open() as f: header = f.readline() header = str(header).split(",") if len(header) == 8: ( obs[ARIA_OBB_BB2[0]], obs[ARIA_OBB_BB2[1]], obs[ARIA_OBB_BB2[2]], ) = load_2d_bounding_boxes_adt(bb2s_path_rgb) bb2_loaded = True if not bb2_loaded and bb2s_path_rgb and bb2s_path_slaml and bb2s_path_slamr: # Load 2d bounding boxes separately for three cameras obs[ARIA_OBB_BB2[0]] = load_2d_bounding_boxes( bb2s_path_rgb, time_in_secs=False ) obs[ARIA_OBB_BB2[1]] = load_2d_bounding_boxes( bb2s_path_slaml, time_in_secs=False ) obs[ARIA_OBB_BB2[2]] = load_2d_bounding_boxes( bb2s_path_slamr, time_in_secs=False ) bb2_loaded = True elif not bb2_loaded and bb2s_path_rgb: # sometimes we only have RGB 2d bounding boxes. obs[ARIA_OBB_BB2[0]] = load_2d_bounding_boxes( bb2s_path_rgb, time_in_secs=False ) obs[ARIA_OBB_BB2[1]] = {} obs[ARIA_OBB_BB2[2]] = {} bb2_loaded = True elif not bb2_loaded: print("Warning: could not find 2d bbs") return {} else: obs[ARIA_OBB_BB2[0]] = {} obs[ARIA_OBB_BB2[1]] = {} obs[ARIA_OBB_BB2[2]] = {} print("not loading 2d bb information") # most of the time bbs with x, y <= 0 indicate object is visible but we dont # know where. 
def load_trajectory_adt(
    traj_path,
    subsample: Union[float, int] = 1,
    load_first_n=99999999999,
):
    """Load an ADT ground-truth trajectory CSV as world-from-rig poses.

    Args:
        traj_path: the CSV file, or a directory containing
            "aria_trajectory.csv".
        subsample: sampling rate over the file's lines, passed to
            ``sample_from_range``.
        load_first_n: only the first ``load_first_n`` lines (header included)
            are considered.

    Returns:
        Dict mapping timestamp in ns (column 1 is microseconds, times 1000)
        to a PoseTW built from columns 3..9 (tx, ty, tz, qx, qy, qz, qw), or
        None when the path is missing or the header does not have 20 columns.
    """
    print("checking " + traj_path)
    fs = fsspec.get_mapper(traj_path).fs
    if not fs.exists(traj_path):
        return None
    if not fs.isfile(traj_path):
        candidates = [
            os.path.join(traj_path, "aria_trajectory.csv"),  # ADT ground truth
        ]
        traj_path = exists_nonzero_path(candidates)
        if traj_path is None:
            return None
    print("loading " + traj_path)

    # check for number of columns first
    with fsspec.open(traj_path, "r").open() as f:
        header = f.readline()
    if len(header.split(",")) != 20:
        return None

    with fsspec.open(traj_path, "rb").open() as f:
        raw_lines = f.readlines()

    poses_by_time = {}
    num_lines = min(len(raw_lines), load_first_n)
    for idx in sample_from_range(0, num_lines, subsample):
        if idx == 0:
            continue  # skip header
        # str() of the bytes line adds a b'...' wrapper, which only touches
        # the first and last columns; neither is read below.
        cols = str(raw_lines[idx]).split(",")
        timestamp_ns = int(cols[1]) * 1000
        tx, ty, tz, qx, qy, qz, qw = [float(value) for value in cols[3:10]]
        rotation = torch.tensor(Quaternion(qw, qx, qy, qz).rotation_matrix)
        translation = torch.tensor([tx, ty, tz]).view(3, 1)
        pose = PoseTW.from_matrix3x4(torch.concat([rotation, translation], dim=1))
        poses_by_time[timestamp_ns] = pose.fit_to_SO3()
    return poses_by_time
: start_index + 11] ] rot_mat = Quaternion(w=qw, x=qx, y=qy, z=qz).rotation_matrix translation = torch.tensor([tx, ty, tz]).view(3, 1) T_world_rig = torch.concat([torch.tensor(rot_mat), translation], dim=1) T_world_rig = PoseTW.from_matrix3x4(T_world_rig) # T_world_rig = T_world_rig.fit_to_SO3() if not load_torch: T_world_rig = T_world_rig.numpy() T_world_rigs[timestamp_ns] = T_world_rig return T_world_rigs def load_trajectory( traj_path, time_in_secs=False, load_torch=False, subsample: Union[float, int] = 1, load_quaternion=False, load_first_n=99999999999, ): print("checking " + traj_path) fs = fsspec.get_mapper(traj_path).fs if not fs.exists(traj_path): return None if not fs.isfile(traj_path): traj_path = exists_nonzero_path( [ os.path.join(traj_path, "trajectory.csv"), # default os.path.join(traj_path, "traj000.csv"), # ASE ] ) if traj_path is None: return None print("loading " + traj_path) T_world_rigs = {} # check for number of columns first with fsspec.open(traj_path, "r").open() as f: header = f.readline() num_cols = len(header.split(",")) if num_cols not in [8, 14, 17]: return None # load data without header with fsspec.open(traj_path, "rb").open() as f: lines = np.loadtxt(f, delimiter=",", skiprows=1) N = min(len(lines), load_first_n) idxs = sample_from_range(0, N, subsample) for ii in idxs: line = lines[ii] timestamp_ns = int(line[0]) if time_in_secs: timestamp = timestamp_ns / 1e9 else: timestamp = timestamp_ns sub_line = line[1:8] tx, ty, tz, qw, qx, qy, qz = [float(e) for e in sub_line] rot_mat = Quaternion(w=qw, x=qx, y=qy, z=qz).rotation_matrix if load_quaternion: # allow "raw" loading T_world_rig = np.array([tx, ty, tz, qw, qx, qy, qz]) else: translation = torch.tensor([tx, ty, tz]).view(3, 1) T_world_rig = torch.concat([torch.tensor(rot_mat), translation], dim=1) T_world_rig = PoseTW.from_matrix3x4(T_world_rig) T_world_rig = T_world_rig.fit_to_SO3() if not load_torch: T_world_rig = T_world_rig.numpy() T_world_rigs[timestamp] = T_world_rig return 
def parse_global_name_to_id_csv(
    csv_path: str, verbose: bool = True
) -> Optional[Dict[str, int]]:
    """Load a two-column taxonomy CSV (old_sem_name, sem_id) as a dict.

    The first row is skipped as a header.

    Returns:
        ``{old_sem_name: sem_id}``, or None when ``csv_path`` is empty.
    """
    global_name_to_id = None
    if len(csv_path) > 0:
        if verbose:
            print(f"trying to load taxonomy from csv at {csv_path}")
        with fsspec.open(csv_path) as f:
            global_name_to_id = dict(
                np.loadtxt(
                    f,
                    delimiter=",",
                    skiprows=1,
                    dtype={
                        "names": ("Object Name", "Object cls ID"),
                        "formats": ("U30", int),
                    },
                    ndmin=1,
                )
            )
        if verbose:
            print(f"loaded {len(global_name_to_id)} name-to-id mappings from csv.")
    return global_name_to_id


def exists_nonzero_path(path: Union[str, list]) -> Optional[str]:
    """Return the first path that exists, or None if none of them does.

    Args:
        path: a single path or a list of candidate paths, tried in order.

    A candidate whose filesystem cannot be resolved is reported and skipped.
    Only existence is checked here, not the file size.
    """
    candidates = [path] if isinstance(path, str) else path

    for candidate in candidates:
        try:
            fs = fsspec.core.url_to_fs(candidate)[0]
        except Exception as e:
            print(f"skipping {candidate}: {e}")
            continue
        if fs.exists(candidate):
            return candidate
    return None


def get_timestamp_list_ns(reader, stream_id=None):
    """Return the reader's record timestamps as integer nanoseconds.

    When ``stream_id`` is given, only data records of that stream are used.
    """
    if stream_id is None:
        filtered_reader = reader
    else:
        filtered_reader = reader.filtered_by_fields(
            stream_ids=[stream_id], record_types=["data"]
        )
    time_list = filtered_reader.get_timestamp_list()
    # go from vrs output of float times in seconds to long times in nanoseconds
    return [int(t * 1e9) for t in time_list]


def sample_times(time_list: List, start_time: int, end_time: int) -> Tuple[int, int]:
    """Find the index range of ``time_list`` covering [start_time, end_time).

    Uses binary search (``bisect_left``) and guarantees that the returned
    range holds at least one index.

    Args:
        time_list: sorted list of times.
        start_time: start of the range to sample.
        end_time: end of the range to sample.

    Returns:
        ``(idx_i, idx_j)`` such that ``time_list[idx_i:idx_j]`` are the
        samples with ``start_time <= t < end_time``. If no sample falls in
        the range, ``idx_j`` is forced to ``idx_i + 1``; ``idx_i`` may then
        equal ``len(time_list)``.
    """
    idx_i = bisect_left(time_list, start_time)
    idx_j = bisect_left(time_list, end_time)
    if idx_j > idx_i:
        # Make sure the sampled data lies between the start and end time.
        # The message reads idx_j - 1 because idx_j itself may be
        # len(time_list).
        assert start_time <= time_list[idx_i] <= time_list[idx_j - 1] < end_time, (
            f"start {start_time} end {end_time}, time_list[idx_i], {time_list[idx_i]}, time_list[idx_j - 1], {time_list[idx_j - 1]}"
        )
    else:
        # make sure idx_j is greater than idx_i
        idx_j = max(idx_j, idx_i + 1)
    return idx_i, idx_j


def sample_from_range(
    start: int,
    end: int,
    sample_rate: Union[float, int],
    add_random: bool = True,
) -> List[int]:
    """Sample integers from ``[start, end)`` at the given sampling rate.

    Args:
        start: start of the range (inclusive).
        end: end of the range (exclusive).
        sample_rate: target sampling rate; must be positive.
        add_random: for non-integer rates, fill up with random picks (True)
            or use evenly spaced, rounded positions (False).

    Returns:
        The sampled integers in increasing order.

    An integer-valued rate gives ``range(start, end, sample_rate)``. For a
    fractional rate the rate is rounded up to get an evenly spaced base set,
    which is then extended to ``int((end - start) / sample_rate)`` samples.
    Example: [0..9] at rate 1.25 needs 8 samples; ``range(0, 10, 2)`` gives
    5 and 3 more are drawn from [1, 3, 5, 7, 9]. The sample count is capped
    at the size of the range, so a rate below 1 returns the whole range.
    """
    assert end >= start, "the end of the range must be greater than the start."
    assert sample_rate > 0, "sample rate must be positive."

    if end == start:
        print(f"[Warn] end equals start ({start}, {end}), return emply list")
        return []

    # float() makes the check work for ints, floats and numpy scalars alike.
    if float(sample_rate).is_integer():
        return list(range(start, end, int(sample_rate)))

    # Otherwise, we do sampling with non-integer sampling rate.
    if (end - start) % sample_rate != 0:
        print(
            f"[WARN] sample_rate not divisible by for the range: got sample_rate {sample_rate}, start {start}, end {end}. Can not achieve the desired sampling rate in the end."
        )

    step = int(np.ceil(sample_rate))  # round-up the sampling rate
    # Number of final samples; there are only end - start distinct integers.
    num = min(int((end - start) / sample_rate), end - start)

    integers = list(range(start, end, step))
    if len(integers) < num:
        if add_random:
            taken = set(integers)
            candidates = [i for i in range(start, end) if i not in taken]
            integers.extend(random.sample(candidates, num - len(integers)))
        else:
            integers = (
                np.linspace(start, end, num, endpoint=False)
                .round()
                .astype(int)
                .tolist()
            )
    return sorted(integers)
def read_image_snippet_from_vrs(
    image_reader: SyncVRSReader,
    cam_id: str,
    start_time_ns: int,
    end_time_ns: int,
    cam_calib,
    subsample: Union[float, int] = 1,
    scale_down_images: int = 0,
    valid_radius: Optional[torch.Tensor] = None,
    wh_multiple_of: int = 16,
):
    """
    Read all frames of one camera stream that fall into a time window.

    `start_time_ns` and `end_time_ns` are compared against the timestamp list
    returned by `get_timestamp_list_ns` for the filtered stream, so they need
    to be in that same (capture) time domain.

    Args:
        image_reader: VRS reader; its image conversion is set to NORMALIZE.
        cam_id: stream id of the camera to read.
        start_time_ns: start of the window, in nanoseconds.
        end_time_ns: end of the window, in nanoseconds.
        cam_calib: mapping providing "intr_type", "intr_params" and
            "T_rig_views" for this camera.
        subsample: sampling rate over the frame indices inside the window
            (sampled deterministically, without randomness).
        scale_down_images, valid_radius, wh_multiple_of: forwarded unchanged
            to `read_image_from_vrs`.

    Returns:
        (images, times_ns, cam_tws, frame_ids): stacked images, capture
        timestamps as a LongTensor (to keep ns precision), stacked cameras
        and the indices of the frames into the stream's timestamp list.
        Frames that could not be read are dropped.
    """
    image_reader.set_image_conversion(conversion=ImageConversion.NORMALIZE)
    stream_reader = image_reader.filtered_by_fields(
        stream_ids=[cam_id], record_types=["data"]
    )
    stamps_ns = get_timestamp_list_ns(stream_reader)
    first_idx, last_idx = sample_times(stamps_ns, start_time_ns, end_time_ns)
    frame_indices = sample_from_range(
        first_idx, last_idx, sample_rate=subsample, add_random=False
    )

    # Collect (frame index, image, camera, capture time) for every good read.
    good_frames = []
    for frame_idx in frame_indices:
        img, cam, capture_ns = read_image_from_vrs(
            reader=stream_reader,
            cam_id=cam_id,
            image_ts_ns=stamps_ns[frame_idx],
            intr_type=cam_calib["intr_type"],
            intr_params=cam_calib["intr_params"],
            T_rig_camera=cam_calib["T_rig_views"],
            scale_down_images=scale_down_images,
            valid_radius=valid_radius,
            wh_multiple_of=wh_multiple_of,
        )
        # A failed read is signalled by None entries; skip those frames.
        if img is None or capture_ns is None or cam is None:
            continue
        good_frames.append((frame_idx, img, cam, capture_ns))

    images = torch.stack([entry[1] for entry in good_frames])
    # Long to hold timestamp in ns to not lose accuracy
    times_ns = torch.LongTensor([entry[3] for entry in good_frames])
    cam_tws = torch.stack([entry[2] for entry in good_frames])
    frame_ids = torch.LongTensor([entry[0] for entry in good_frames])
    return images, times_ns, cam_tws, frame_ids
"px_world", "py_world", "pz_world"] ] for row in tqdm.tqdm(uid_pts.values): uid = int(row[0]) inv_dist_std = float(row[1]) dist_std = float(row[2]) p3 = row[3:] uid_to_p3[uid] = p3 uid_to_inv_dist_std[uid] = inv_dist_std uid_to_dist_std[uid] = dist_std try: # cache points with gzip.open(cache_path, "wb") as f: pickle.dump(uid_to_p3, f, protocol=pickle.HIGHEST_PROTOCOL) pickle.dump(uid_to_inv_dist_std, f, protocol=pickle.HIGHEST_PROTOCOL) pickle.dump(uid_to_dist_std, f, protocol=pickle.HIGHEST_PROTOCOL) print(f"Cached global points to {cache_path}") except: print("Failed to cache the semidense points, like a write permission issue") else: # load from the cached file with gzip.open(cache_path, "rb") as f: uid_to_p3 = pickle.load(f) uid_to_inv_dist_std = pickle.load(f) uid_to_dist_std = pickle.load(f) print(f"Loaded global points from cached file {cache_path}") uid_to_p3 = {uid: torch.from_numpy(p3) for uid, p3 in uid_to_p3.items()} return uid_to_p3, uid_to_inv_dist_std, uid_to_dist_std def load_semidense_observations(path: str): print(f"loading semidense observations from {path}") time_to_uids = defaultdict(list) uid_to_times = defaultdict(list) if path.split(".")[-1] == "gz" or "maps/maps_v1" in path: compression = "gzip" else: compression = None cache_path = path + ".pickle.gz" if not os.path.exists(cache_path): with fsspec.open(path, "rb") as f: csv = pd.read_csv(f, compression=compression) csv = csv[["uid", "frame_tracking_timestamp_us"]] for row in tqdm.tqdm(csv.values): uid = int(row[0]) time_ns = int(row[1]) * 1000 time_to_uids[time_ns].append(uid) uid_to_times[uid].append(time_ns) try: with gzip.open(cache_path, "wb") as f: pickle.dump(time_to_uids, f, protocol=pickle.HIGHEST_PROTOCOL) pickle.dump(uid_to_times, f, protocol=pickle.HIGHEST_PROTOCOL) print(f"Cached semidense observations to {cache_path}") except: print( "Failed to cache the semidense observations, like a write permission issue" ) else: with gzip.open(cache_path, "rb") as f: time_to_uids = 
def get_transform_to_vio_gravity_convention(gravity_direction: np.array):
    """
    Get transformation to map gravity_direction to (0,0,-1) as per our (and
    VIO/Temple) convention.

    The rotation R_gravity_vio is built from three orthonormal columns
    (d1, d2, d3) with d3 = -gravity, so that R_gravity_vio @ (0,0,-1)^T equals
    the (unit) gravity direction. Its transpose is returned as a pose with
    zero translation.

    Args:
        gravity_direction: 3-vector pointing along gravity. It does not need
            to be unit length; it is normalized internally.

    Returns:
        T_vio_gravity: PoseTW built from the rotation R_vio_gravity and a
        zero translation.

    Raises:
        ValueError: if gravity_direction has zero length.
    """
    g = np.asarray(gravity_direction)
    if not np.issubdtype(g.dtype, np.floating):
        g = g.astype(np.float32)
    g_norm = np.linalg.norm(g)
    if not g_norm > 0:
        raise ValueError(
            f"gravity_direction must be a non-zero vector, got {gravity_direction}"
        )
    g = g / g_norm

    # gravity_direction = (d1, d2, d3) (0,0,-1)^T; d1, d2, d3 column vectors of rotation matrix R_gravity_vio
    # -d3 = gravity_direction
    d3 = -g
    # now construct an orthonormal basis for the rotation matrix
    # d1 is a vector that is orthogonal to gravity_direction by construction
    d1 = np.array([g[2] - g[1], g[0], -g[0]])
    d1_norm = np.linalg.norm(d1)
    if d1_norm < 1e-6:
        # The construction above degenerates to the zero vector when gravity
        # is proportional to (0, 1, 1). In that case gravity is orthogonal to
        # the x axis, so crossing with it gives a valid orthogonal direction.
        d1 = np.cross(d3, np.array([1.0, 0.0, 0.0], dtype=d3.dtype))
        d1_norm = np.linalg.norm(d1)
    # Normalize before the sanity checks so they test an actual rotation.
    d1 = d1 / d1_norm
    # get d2 via orthogonal direction vector to d3 and d1
    d2 = np.cross(d3, d1)
    # get rotation matrix
    R_gravity_vio = np.concatenate(
        [d1[:, np.newaxis], d2[:, np.newaxis], d3[:, np.newaxis]], 1
    )
    # Two-sided checks: a one-sided comparison would accept e.g. det = -1.
    assert abs(np.linalg.det(R_gravity_vio) - 1.0) < 1e-5
    assert (
        np.abs((R_gravity_vio @ R_gravity_vio.transpose()) - np.eye(3)) < 1e-5
    ).all()

    R_gravity_vio = torch.from_numpy(R_gravity_vio)
    # normalize to unit length
    R_gravity_vio = F.normalize(R_gravity_vio, p=2, dim=-2)
    R_vio_gravity = R_gravity_vio.transpose(1, 0)
    T_vio_gravity = PoseTW.from_Rt(R_vio_gravity, torch.zeros(3))
    return T_vio_gravity
""" assert T_world_cam.dim() > 1, f"{T_world_cam} has wrong dimension; expected >1" dim = T_world_cam.dim() device = T_world_cam.device R_wc = T_world_cam.R dir_shape = [1] * (dim - 1) + [3] g_w = torch.from_numpy(gravity_w.copy()).view(dir_shape).to(R_wc) g_w = g_w.expand_as(R_wc[..., 1]) # forward vector (z) that is orthogonal to gravity direction d3 = reject_vector_a_from_b(a=R_wc[..., 2], b=g_w) # optionally add a tiny offset to avoid cross product two identical vectors. d3_is_zeros = (d3 == 0.0).all(dim=-1).unsqueeze(-1).expand_as(d3) d3_offset = torch.zeros(*d3.shape).to(T_world_cam._data.device) d3_offset[..., 1] += 0.001 d3 = torch.where(d3_is_zeros, d3 + d3_offset, d3) d2 = torch.linalg.cross(d3, g_w, dim=-1) # camera down vector is x direction since Aria cameras are rotated by 90 degree CW # hence the new x direction is gravity R_wcg = torch.cat([g_w.unsqueeze(-1), d2.unsqueeze(-1), d3.unsqueeze(-1)], -1) # normalize to unit length R_world_cg = torch.nn.functional.normalize(R_wcg, p=2, dim=-2) if z_grav: # add extra rotation to make z gravity direction, not x. R_cg_cgz = rotation_from_euler( torch.tensor([[-np.pi / 2.0, 0.0, np.pi / 2.0]]) ).to(device) R_world_cgz = R_world_cg @ R_cg_cgz.inverse() T_world_cgz = PoseTW.from_Rt(R_world_cgz, T_world_cam.t) return T_world_cgz else: R_world_cg = R_world_cg T_world_cg = PoseTW.from_Rt(R_world_cg, T_world_cam.t) return T_world_cg ================================================ FILE: efm3d/utils/image.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
def normalize(img, robust=0.0, eps=1e-6):
    """
    Rescale an image to the range [0, 1].

    Args:
        img: torch.Tensor or np.ndarray of any shape.
        robust: if > 0, use the `robust` and `1 - robust` quantiles of the
            values instead of the min and max; values outside that range are
            clamped to 0 / 1.
        eps: lower bound on the value range used as divisor, to avoid a
            division by zero for constant images.

    Returns:
        Image of the same container type and shape as `img`, with values in
        [0, 1].

    Raises:
        TypeError: if `img` is neither a torch.Tensor nor a np.ndarray.
    """
    if isinstance(img, torch.Tensor):
        # detach + reshape (instead of view) so tensors that require grad or
        # are non-contiguous can be handled too.
        vals = img.detach().reshape(-1).cpu().numpy()
    elif isinstance(img, np.ndarray):
        vals = img.flatten()
    else:
        raise TypeError(
            f"img must be a torch.Tensor or np.ndarray, got {type(img).__name__}"
        )

    if robust > 0.0:
        v_min = np.quantile(vals, robust)
        v_max = np.quantile(vals, 1.0 - robust)
    else:
        v_min = vals.min()
        v_max = vals.max()
    # make sure we are not dividing by 0
    dv = max(eps, v_max - v_min)
    # normalize to 0-1
    img = (img - v_min) / dv
    if isinstance(img, torch.Tensor):
        img = img.clamp(0, 1)
    elif isinstance(img, np.ndarray):
        img = img.clip(0, 1)
    return img
def rotate_image90(image: np.ndarray, k: int = 3):
    """Rotate an image in steps of 90 degrees and return a contiguous copy.

    `np.rot90` only returns a strided view; the result is copied into
    contiguous memory to avoid problems with opencv.

    Input:
        image: numpy image, HxW or HxWxC
        k: number of times to rotate by 90 degrees counter clockwise
    Returns
        rotated image: numpy image, HxW or HxWxC
    """
    rotated_view = np.rot90(image, k=k)
    return np.ascontiguousarray(rotated_view)
def torch2cv2(
    img: Union[np.ndarray, torch.Tensor],
    rotate: bool = False,
    rgb2bgr: bool = True,
    ensure_rgb: bool = False,
    apply_colormap: Optional[str] = None,
    robust_quant: float = 0.0,
):
    """
    Converts numpy/torch float32 image [0,1] CxHxW to numpy uint8 [0,255] HxWxC

    Args:
        img: image CxHxW float32 image. A 2D HxW input is treated as a single
            channel image. A 4D torch tensor is treated as a video: a leading
            dimension of size 1 is squeezed, otherwise every frame is
            converted and the results are stacked.
        rotate: if True, rotate image 90 degrees
        rgb2bgr: convert image to BGR (reverses the channel order)
        ensure_rgb: ensure RGB if True (i.e. replicate the single color channel 3 times)
        apply_colormap: apply colormap if specified (matplotlib color map names
            i.e. "jet") to a single channel image. Overwrites ensure_rgb. This
            lets you display single channel images outside the 0-1 range.
            (image is normalized to [0,1] before applying the colormap.)
        robust_quant: quantile to robustly compute min and max for normalization of the image.

    Returns:
        uint8 numpy image, HxWxC (or TxHxWxC for multi-frame tensor input).
    """
    if isinstance(img, torch.Tensor):
        if img.dim() == 4:
            if img.shape[0] == 1:
                # preserve old way of just squeezing 0th dim
                img = img[0]
            else:
                # run torch2cv2 on all frames of the video
                return np.stack(
                    [
                        torch2cv2(
                            im,
                            rotate,
                            rgb2bgr,
                            ensure_rgb,
                            apply_colormap,
                            robust_quant,
                        )
                        for im in img
                    ]
                )
        img = img.data.cpu().float().numpy()
    if img.ndim == 2:
        img = img[np.newaxis, :, :]
    # CxHxW -> HxWxC
    img = img.transpose(1, 2, 0)
    if img.shape[2] == 1 and apply_colormap is not None:
        # make sure to normalize so min is 0 and max is 1.
        img = normalize(img, robust=robust_quant)
        # `plt.cm.get_cmap` was removed in Matplotlib 3.9; `plt.get_cmap`
        # is available in old and new versions alike.
        cm = plt.get_cmap(apply_colormap)
        # colormap returns RGBA; drop the alpha channel.
        img = cm(img[:, :, 0])[:, :, :3]
    img_cv2 = (img * 255.0).astype(np.uint8)
    if rgb2bgr:
        img_cv2 = img_cv2[:, :, ::-1]
    if rotate:
        img_cv2 = rotate_image90(img_cv2)
    else:
        # the channel flip above yields a negative-stride view; opencv needs
        # contiguous memory.
        img_cv2 = np.ascontiguousarray(img_cv2)
    if ensure_rgb and img_cv2.shape[2] == 1:
        img_cv2 = img_cv2[:, :, 0]
    if ensure_rgb and img_cv2.ndim == 2:
        img_cv2 = np.stack([img_cv2, img_cv2, img_cv2], -1)
    return img_cv2
def normalize_keypoints(kpts, height, width):
    """
    Map pixel coordinates to the [-1, 1] range of a (height, width) image.

    The last dimension of `kpts` holds (x, y, ...); x is scaled by the width
    and y by the height. Note that `kpts` is modified in place and the same
    object is returned, so pass a copy if the input must be preserved.
    """
    # per-axis conversion factor: x uses the width, y uses the height
    axis_factors = (compute_factor(width), compute_factor(height))
    for axis, factor in enumerate(axis_factors):
        kpts[..., axis] = convert_pixel_to_coordinates(kpts[..., axis], factor)
    return kpts
if camH != featH or camW != featW: cams_resize = cams.scale_to(feat2d) else: cams_resize = cams assert round(cams_resize[0].size[0].item()) == featW, ( f"height of cam and feature image do not match. {cams_resize[0].size[0]}!= {feat2d.shape}" ) assert round(cams_resize[0].size[1].item()) == featH, ( f"width of cam and feature image do not match. {cams_resize[0].size[1]}!= {feat2d.shape}" ) samp_pts, valid = cams_resize.project(query_pts_cam) if warn: frac_valid = valid.count_nonzero() / valid.numel() if frac_valid < 0.05: print( f"[Warning] not many valids! {frac_valid} {valid.count_nonzero()} {valid.shape}" ) samp_pts[~valid] = 0.0 samp_pts = samp_pts.float() # Sample into the 2D feature maps. norm_samp_pts = normalize_keypoints( samp_pts.clone(), height=cams_resize[0].size[1], width=cams_resize[0].size[0] ) device = feat2d.device padding_mode = "zeros" if "mps" in str(device) else padding_mode samp_feats = torch.nn.functional.grid_sample( feat2d, norm_samp_pts.unsqueeze(-2), align_corners=False, padding_mode=padding_mode, mode=interp_mode, # bilinear allows differentiating. ) # squeeze back down the dimension of 1 we unsqueezed for norm_samp_pts to comply with interface samp_feats = samp_feats.squeeze(-1) # Overwrite invalid projections with zeros. 
def marching_cubes_scaled(values, isolevel, voxel_extent, voxel_mask):
    """
    Runs marching cubes on a values tensor (D H W) at the specified isolevel.
    Voxel_mask is used to tell marching cubes where to run in the voxel grid.
    Uses scikit implementation which runs only on CPU.

    Args:
        values: 3D torch tensor (D H W) of scalar values.
        isolevel: level at which to extract the surface; it is clamped to the
            [min, max] range of `values`.
        voxel_extent: (x_min, x_max, y_min, y_max, z_min, z_max) of the grid.
        voxel_mask: optional boolean torch tensor shaped like `values`
            selecting where marching cubes runs, or None to run everywhere.

    Returns:
        vertices, face ids, and normals in the voxel coordinate system scaled
        to the given voxel_extent, on the device of `values`. If marching
        cubes fails, three empty tensors are returned instead.
    """
    from skimage.measure import marching_cubes as mc_scikit

    device = values.device
    values = values.detach().cpu()  # CPU only
    assert values.ndim == 3, f"skicit can only do non-batched inputs, {values.shape}"
    # Clamp into the value range and hand skimage a plain float rather than
    # a 0-d tensor.
    isolevel = float(max(values.min(), min(isolevel, values.max())))
    logging.info(f"mc min {values.min()}, max {values.max()}, isolevel {isolevel}")
    voxel_mask = voxel_mask.cpu().numpy() if voxel_mask is not None else None
    # Precompute for the error message; voxel_mask may legitimately be None.
    mask_shape = None if voxel_mask is None else voxel_mask.shape
    try:
        # mask=None is skimage's default, so one call covers both cases.
        verts, faces, normals, _ = mc_scikit(
            values.contiguous().numpy(), isolevel, mask=voxel_mask
        )
        logging.info(f"{verts.shape}, {faces.shape}")
    except Exception as e:
        # Best effort: log and return an empty mesh instead of raising.
        logging.error(f"{e} {values.shape}, {mask_shape}")
        return torch.tensor([]), torch.tensor([]), torch.tensor([])

    # copy to get around negative stride
    # go back to x, y, z ordering
    verts, faces, normals = (
        torch.from_numpy(verts.copy()),
        torch.from_numpy(faces.copy()),
        torch.from_numpy(normals.copy()),
    )
    verts = verts[:, [2, 1, 0]]
    normals = normals[:, [2, 1, 0]]
    verts, faces, normals = verts.to(device), faces.to(device), normals.to(device)
    logging.info(f"{verts.shape}, {faces.shape}, {normals.shape}")

    vD, vH, vW = values.shape
    logging.info(f"{vD}, {vH}, {vW}, {voxel_extent}")
    x_min, x_max, y_min, y_max, z_min, z_max = voxel_extent
    # size of one voxel along each axis
    dW = (x_max - x_min) / vW
    dH = (y_max - y_min) / vH
    dD = (z_max - z_min) / vD
    dVox = torch.tensor([dW, dH, dD]).view(1, 3).to(device)
    vox_min = torch.tensor([x_min, y_min, z_min]).view(1, 3).to(device)
    logging.info(f"{verts.shape}")
    # index coordinates -> metric coordinates, offset by half a voxel
    verts = verts * dVox + vox_min + dVox * 0.5
    return verts, faces, normals
def point_to_closest_tri_dist(pts, verts, tris):
    """
    Compute the min distance of points to triangles. Only triangles whose
    plane-projection of the point falls inside the triangle are considered.
    If a point doesn't project inside any triangle, a big number (1e6) is
    returned for that point.

    Args:
        pts: N x 3 float tensor of query points.
        verts: M x 3 float tensor of mesh vertices.
        tris: O x 3 integer tensor of vertex indices, one row per triangle.

    Returns:
        dist_tri: N tensor, distance to the closest triangle plane among the
            triangles the point projects into (1e6 if there is none).
        num_proj: N tensor, number of triangles each point projects into.
    """
    assert verts.ndim == 2, f"{verts.shape}"
    assert tris.ndim == 2, f"{tris.shape}"
    assert pts.ndim == 2, f"{pts.shape}"

    def dot(a, b):
        return (a * b).sum(-1, keepdim=True)

    # triangle corners, each 1 x O x 3
    v0s = verts[None, tris[:, 0], :]
    v1s = verts[None, tris[:, 1], :]
    v2s = verts[None, tris[:, 2], :]
    pts = pts.unsqueeze(1)  # N x 1 x 3, broadcasts against the triangles

    # compute if point projects inside triangle (barycentric coordinates)
    u = v1s - v0s
    v = v2s - v0s
    # The cross product must be taken over the last (xyz) dimension. Without
    # an explicit dim, torch.cross picks the first dimension of size 3, which
    # is the triangle dimension whenever exactly 3 triangles are passed.
    n = torch.linalg.cross(u, v, dim=-1)
    w = pts - v0s
    nSq = dot(n, n)
    gamma = dot(torch.linalg.cross(u, w, dim=-1), n) / nSq
    beta = dot(torch.linalg.cross(w, v, dim=-1), n) / nSq
    alpha = 1.0 - gamma - beta
    valid_alpha = torch.logical_and(0.0 <= alpha, alpha <= 1.0)
    valid_beta = torch.logical_and(0.0 <= beta, beta <= 1.0)
    valid_gamma = torch.logical_and(0.0 <= gamma, gamma <= 1.0)
    projs_to_tri = torch.logical_and(valid_alpha, valid_beta)
    projs_to_tri = torch.logical_and(projs_to_tri, valid_gamma)
    num_proj = projs_to_tri.count_nonzero(1)
    projs_to_tri = projs_to_tri.squeeze(-1)

    # compute distance to triangle plane
    n = n / torch.sqrt(nSq)
    dist_tri = dot(n, w).squeeze(-1).abs()
    # set distance to large for point-triangle combinations that do not project
    dist_tri[~projs_to_tri] = 1e6
    dist_tri = torch.min(dist_tri, 1)[0]  # N
    num_proj = num_proj.squeeze(-1)
    return dist_tri, num_proj
def eval_mesh_to_mesh(
    pred: Union[str, trimesh.Trimesh],
    gt: Union[str, trimesh.Trimesh],
    threshold=0.05,
    sample_num=10000,
    step=50000,
    cut_height=None,
):
    """
    Eval point to faces distance using `point_to_closest_tri_dist`.

    Points are sampled on the surface of one mesh and their distance to the
    other mesh is computed (via `compute_pts_to_mesh_dist`), in both
    directions: pred -> gt gives accuracy, gt -> pred gives completeness.

    Args:
        pred: predicted mesh, or a path to load it from.
        gt: ground-truth mesh, or a path to load it from.
        threshold: currently unused; metrics are reported at the fixed 0.05
            threshold and visualizations at 0.01 and 0.05.
        sample_num: number of surface points sampled per mesh.
        step: number of faces processed per chunk.
        cut_height: if given, both meshes are sliced with a plane at this
            height (origin [0, 0, cut_height], normal [0, 0, -1]) before
            evaluation.

    Returns:
        metrics: dict with mean accuracy / completeness and precision, recall
            and F-score at 0.05. The F-score is 0.0 when precision and recall
            are both zero.
        viz: dict of trimesh objects for debugging (error heat maps and
            inlier/outlier point clouds) plus the gt mesh.
        raw_data: dict with the per-point "acc" and "comp" distances.
    """
    rnd_seed = 0
    random.seed(rnd_seed)
    np.random.seed(rnd_seed)

    if isinstance(gt, str):
        print(f"load gt mesh {gt}")
        gt_mesh = trimesh.load_mesh(gt)
    else:
        gt_mesh = gt
    if isinstance(pred, str):
        print(f"load pred mesh {pred}")
        pred_mesh = trimesh.load_mesh(pred)
    else:
        pred_mesh = pred

    if cut_height is not None:
        cutting_plane = [[0, 0, -1], [0, 0, cut_height]]
        gt_mesh = gt_mesh.slice_plane(
            plane_origin=cutting_plane[1], plane_normal=cutting_plane[0]
        )
        pred_mesh = pred_mesh.slice_plane(
            plane_origin=cutting_plane[1], plane_normal=cutting_plane[0]
        )

    if torch.cuda.is_available():
        dev = "cuda:0"
    elif torch.backends.mps.is_available():
        dev = "mps"
    else:
        dev = "cpu"
    print(f"==> [eval_mesh_to_mesh] use device {dev}")

    pred_vertices = torch.from_numpy(pred_mesh.vertices.view(np.ndarray)).to(dev)
    gt_vertices = torch.from_numpy(gt_mesh.vertices.view(np.ndarray)).to(dev)
    pred_faces = torch.from_numpy(pred_mesh.faces.view(np.ndarray)).to(dev)
    gt_faces = torch.from_numpy(gt_mesh.faces.view(np.ndarray)).to(dev)
    print(f"gt vertices and faces {gt_vertices.shape}, {gt_faces.shape}")
    print(f"pred vertices and faces {pred_vertices.shape}, {pred_faces.shape}")

    # accuracy (from sampled point in pred to GT)
    pred_pts, _ = trimesh.sample.sample_surface(pred_mesh, sample_num, seed=rnd_seed)
    pred_pts = torch.from_numpy(pred_pts.view(np.ndarray)).to(dev)
    acc = compute_pts_to_mesh_dist(pred_pts, gt_faces, gt_vertices, step)

    # completeness (from sampled point in GT to pred)
    gt_pts, _ = trimesh.sample.sample_surface(gt_mesh, sample_num, seed=rnd_seed)
    gt_pts = torch.from_numpy(gt_pts.view(np.ndarray)).to(dev)
    comp = compute_pts_to_mesh_dist(gt_pts, pred_faces, pred_vertices, step)

    precision5 = np.mean((acc < 0.05).astype("float"))
    recal5 = np.mean((comp < 0.05).astype("float"))
    # Guard against 0/0, which would otherwise give NaN and poison any
    # aggregation over several evaluations.
    if precision5 + recal5 > 0:
        fscore5 = 2 * precision5 * recal5 / (precision5 + recal5)
    else:
        fscore5 = 0.0

    metrics = {
        "acc_mean": np.mean(acc),
        "comp_mean": np.mean(comp),
        "prec@0.05": precision5,
        "recal@0.05": recal5,
        "fscore@0.05": fscore5,
    }

    # Create some visualizations for debugging.
    cmap = plt.cm.jet
    # accuracy heatmap (as a pointcloud) on predicted mesh
    norm = plt.Normalize(acc.min(), acc.max())
    colors = cmap(norm(acc))
    acc_pc = trimesh.points.PointCloud(pred_pts.detach().cpu().numpy())
    acc_pc.colors = colors
    # completeness heatmap (as a pointcloud) on gt mesh
    norm = plt.Normalize(comp.min(), comp.max())
    colors = cmap(norm(comp))
    com_pc = trimesh.points.PointCloud(gt_pts.detach().cpu().numpy())
    com_pc.colors = colors

    viz = {
        "acc_pc": acc_pc,
        "comp_pc": com_pc,
        "gt_mesh": gt_mesh,
    }

    # Distinct loop variable so the `threshold` parameter is not shadowed.
    for viz_threshold in [0.01, 0.05]:
        prec_inliers = acc < viz_threshold
        recall_inliers = comp < viz_threshold
        # create visualizations for precision and recall
        prec_pc = copy.deepcopy(acc_pc)
        recal_pc = copy.deepcopy(com_pc)
        prec_pc.colors[prec_inliers] = [0, 255, 0, 255]  # green
        prec_pc.colors[~prec_inliers] = [255, 0, 0, 255]  # red
        recal_pc.colors[recall_inliers] = [0, 255, 0, 255]  # green
        recal_pc.colors[~recall_inliers] = [255, 0, 0, 255]  # red
        viz[f"prec@{viz_threshold:.2}_pc"] = prec_pc
        viz[f"recal@{viz_threshold:.2}_pc"] = recal_pc

    raw_data = {"acc": acc, "comp": comp}
    return metrics, viz, raw_data
class ObbCsvReader:
    """
    Iterate over an obb CSV file, yielding one stacked set of obbs per timestamp.

    Rows are read through `csv.DictReader`; consecutive rows that share the
    same "time_ns" value are grouped into one `(time_ns, obbs)` item.
    The underlying file is closed once all rows have been consumed, or
    explicitly via `close()`.
    """

    def __init__(self, file_name):
        self.file_name = file_name
        self.file_reader = fsspec.open(self.file_name, "r").open()
        self.csv_reader = csv.DictReader(self.file_reader)
        # Keep one row of look-ahead so that `__next__` can detect where a
        # timestamp group ends. An empty file simply yields nothing.
        self.next_row = next(self.csv_reader, None)
        self.all_obbs = None
        # sem_id -> name, collected while parsing rows.
        self.sem_ids_to_names = {}
        if self.next_row is None:
            self.close()

    def close(self):
        """Close the underlying file; safe to call more than once."""
        file_reader = getattr(self, "file_reader", None)
        if file_reader is not None:
            file_reader.close()

    def __del__(self):
        self.close()

    def parse_row(self, row):
        """
        Convert one CSV row (a dict keyed by column name) into an obb.

        Returns:
            (t_ns, obb_tw): the integer timestamp of the row and the obb built
            with `ObbTW.from_lmc(...).float()`.
        """
        t_ns = int(row["time_ns"])
        tx_wo = float(row["tx_world_object"])
        ty_wo = float(row["ty_world_object"])
        tz_wo = float(row["tz_world_object"])
        qw_wo = float(row["qw_world_object"])
        qx_wo = float(row["qx_world_object"])
        qy_wo = float(row["qy_world_object"])
        qz_wo = float(row["qz_world_object"])
        sx = float(row["scale_x"])
        sy = float(row["scale_y"])
        sz = float(row["scale_z"])
        inst_id = int(row["instance"]) if "instance" in row else -1
        sem_id = int(row["sem_id"])
        name = row["name"]

        # Every sem_id must map to a single name throughout the file.
        if sem_id not in self.sem_ids_to_names:
            self.sem_ids_to_names[sem_id] = name
        else:
            assert name == self.sem_ids_to_names[sem_id]

        # methods like ObjectMapper may not have probabilities
        prob = float(row["prob"]) if "prob" in row else -1.0

        # create obbs: extents are centered on the object origin
        bb3s = torch.tensor(
            [-sx / 2.0, sx / 2.0, -sy / 2.0, sy / 2.0, -sz / 2.0, sz / 2.0]
        )

        # create poses
        rot_mat = Quaternion(w=qw_wo, x=qx_wo, y=qy_wo, z=qz_wo).rotation_matrix
        translation = torch.tensor([tx_wo, ty_wo, tz_wo]).view(3, 1)
        T_wo = torch.concat([torch.tensor(rot_mat), translation], dim=1)
        T_wo = PoseTW.from_matrix3x4(T_wo)
        T_wo = T_wo.fit_to_SO3()
        T_world_object = T_wo._data

        # bb2s also decide the visibility of the 3D obbs in the corresponding camera.
        # we now just assume the obbs are visible in all the cameras
        bb2_rgbs = torch.ones(4)
        bb2_slamls = torch.ones(4)
        bb2_slamrs = torch.ones(4)

        # assume everything is static for now.
        moveables = torch.zeros(1)

        obb_tw = ObbTW.from_lmc(
            bb3_object=bb3s,
            bb2_rgb=bb2_rgbs,
            bb2_slaml=bb2_slamls,
            bb2_slamr=bb2_slamrs,
            T_world_object=T_world_object,
            sem_id=sem_id,
            inst_id=inst_id,
            prob=prob,
            moveable=moveables,
        ).float()
        return t_ns, obb_tw

    def __iter__(self):
        return self

    def __next__(self):
        """
        Get the next obbs set with the same timestamp.

        Returns:
            (t0_ns, obbs): the timestamp of the group and the stacked obbs.

        Raises:
            StopIteration: when no rows are left.
        """
        if self.next_row is None:
            raise StopIteration

        t0_ns, obb = self.parse_row(self.next_row)
        obbs = [obb]
        for row in self.csv_reader:
            if int(row["time_ns"]) != t0_ns:
                # First row of the next group: remember it for the next call.
                self.next_row = row
                break
            obbs.append(self.parse_row(row)[1])
        else:
            # Ran out of rows: this was the last group.
            self.next_row = None
            self.close()
        return t0_ns, torch.stack(obbs)

    @property
    def obbs(self):
        """
        All remaining obbs as a dict {time_ns: stacked obbs}.

        The result is cached after the first access. Items that were already
        consumed by iterating the reader beforehand are not included.
        """
        if self.all_obbs is not None:
            return self.all_obbs
        self.all_obbs = {t_ns: obbs for t_ns, obbs in self}
        return self.all_obbs
class ObbCsvWriter:
    """
    Write obbs to a CSV file, one row per obb.

    The header row is written on construction. The column order is
    time_ns, translation (x, y, z), quaternion (w, x, y, z), scale (x, y, z),
    name, instance, sem_id, prob.
    """

    def __init__(self, file_name=""):
        if not file_name:
            file_name = "/tmp/obbs.csv"

        print(f"starting obb writer to {file_name}")

        self.file_name = file_name
        self.file_writer = fsspec.open(self.file_name, "w").open()
        headers_str = "time_ns,tx_world_object,ty_world_object,tz_world_object,qw_world_object,qx_world_object,qy_world_object,qz_world_object,scale_x,scale_y,scale_z,name,instance,sem_id,prob"
        headers = headers_str.split(",")
        self.num_cols = len(headers)
        header_row = ",".join(headers)
        self.file_writer.write(header_row + "\n")
        # Pre-formatted rows that `write_rows` will emit.
        self.rows = []

    def write_rows(self):
        """Write all buffered `self.rows`, flush, and clear the buffer."""
        for row in self.rows:
            self.file_writer.write(row + "\n")
        self.file_writer.flush()
        self.rows = []

    def write(
        self,
        obb_padded: ObbTW,
        timestamps_ns: int = -1,
        sem_id_to_name: Optional[Dict[int, str]] = None,
        flush_at_end: bool = True,
    ):
        """
        Append one row per obb for a single timestamp.

        Args:
            obb_padded: obbs to write; padding is removed before writing.
            timestamps_ns: timestamp written into the "time_ns" column.
            sem_id_to_name: optional mapping used for the "name" column. Obbs
                whose sem_id is not in the mapping use the integer sem_id as name.
            flush_at_end: flush the file after the rows were written.
        """
        obb = obb_padded.remove_padding().clone().cpu()
        time_ns = str(int(timestamps_ns))

        N = obb.shape[0]
        if N == 0:
            # Nothing is written for timestamps without obbs.
            return

        obbs_poses = obb.T_world_object
        obbs_dims = obb.bb3_diagonal.numpy()
        obb_sems = obb.sem_id.squeeze(-1).numpy()
        obb_inst = obb.inst_id.squeeze(-1).numpy()
        obb_prob = obb.prob.squeeze(-1).numpy()

        # csv.writer only adds quoting when a field needs it (e.g. a class name
        # containing a comma), so rows without special characters are written
        # exactly as "a,b,c\n" and stay readable by csv.DictReader either way.
        row_writer = csv.writer(self.file_writer, lineterminator="\n")
        for i in range(N):
            sem_id = obb_sems[i]
            if sem_id_to_name and sem_id in sem_id_to_name:
                name = sem_id_to_name[sem_id]
            else:
                name = str(int(sem_id))

            qwxyz = obbs_poses[i].q.numpy().astype(str).tolist()  # 4 values
            txyz = obbs_poses[i].t.numpy().astype(str).tolist()  # 3 values
            sxyz = obbs_dims[i].astype(str).tolist()  # 3 values
            row_writer.writerow(
                [
                    time_ns,
                    *txyz,
                    *qwxyz,
                    *sxyz,
                    name,
                    f"{obb_inst[i]}",
                    f"{obb_sems[i]}",
                    f"{obb_prob[i]}",
                ]
            )

        if flush_at_end:
            self.file_writer.flush()

    def flush(self):
        """Flush the underlying file."""
        self.file_writer.flush()

    def close(self):
        """Close the underlying file; safe to call more than once."""
        if hasattr(self, "file_writer"):
            self.file_writer.close()

    def __del__(self):
        # `file_writer` is missing when opening the file failed in __init__.
        self.close()
def bb2extent(bb):
    """
    Compute the axis-aligned extent of a set of 3D points.

    Args:
        bb: points with x, y, z in the last dimension, either a single point
            (3,) or a set of points (N, 3).

    Returns:
        numpy array [x_min, x_max, y_min, y_max, z_min, z_max].
    """
    if bb.ndim == 1:
        bb = bb.reshape(1, -1)
    bounds = []
    for axis in range(3):
        column = bb[:, axis]
        bounds.extend((column.min(), column.max()))
    return np.stack(bounds, axis=0)


def extent2bb(extent):
    """
    Expand extents [x_min, x_max, y_min, y_max, z_min, z_max] into 8 box corners.

    The corners are ordered as the 4 corners of the z_min face followed by the
    4 corners of the z_max face.

    Args:
        extent: torch tensor or numpy array of shape (6,) or (N, 6).

    Returns:
        Corners of shape (N, 8, 3) with singleton dimensions squeezed, i.e.
        (8, 3) for a single extent. The container type matches the input.

    Raises:
        TypeError: if `extent` is neither a torch tensor nor a numpy array.
    """
    if extent.ndim == 1:
        extent = extent.reshape(1, -1)

    x_min, x_max = extent[:, 0], extent[:, 1]
    y_min, y_max = extent[:, 2], extent[:, 3]
    z_min, z_max = extent[:, 4], extent[:, 5]
    # A flat list of the 24 corner coordinates; each entry has shape (N,).
    # Stacking along the last axis gives (N, 24), which reshapes per instance
    # into (N, 8, 3).
    corners = [
        x_min, y_min, z_min,
        x_max, y_min, z_min,
        x_max, y_max, z_min,
        x_min, y_max, z_min,
        x_min, y_min, z_max,
        x_max, y_min, z_max,
        x_max, y_max, z_max,
        x_min, y_max, z_max,
    ]  # fmt: skip
    if torch.is_tensor(extent):
        bb3d = torch.stack(corners, dim=-1).reshape(-1, 8, 3)
    elif isinstance(extent, np.ndarray):
        bb3d = np.stack(corners, axis=-1).reshape(-1, 8, 3)
    else:
        raise TypeError("Unknown type")

    return bb3d.squeeze()


def get_all_Ts_world_object_for_time(
    obs,
    time,
    load_dynamic_objects=True,
    interpolate_poses=True,
    dt_threshold_ns: int = 10_000_000,
):
    """
    Collect the object poses valid at `time`: static poses plus dynamic ones.

    Args:
        obs: dict with "static_Ts_world_object" (instance -> pose) and
            "timedTs_world_object" (time -> {instance -> pose}).
        time: query timestamp; must be comparable with the keys of
            obs["timedTs_world_object"].
        load_dynamic_objects: if False, only static poses are returned.
        interpolate_poses: when `time` is not a key of the timed pose map,
            interpolate poses with `interpolate_timed_poses`; otherwise take
            the result of `closest_timed_poses`.
        dt_threshold_ns: largest accepted |dt| for the closest-pose fallback;
            beyond it no dynamic poses are used.

    Returns:
        dict instance -> pose. Dynamic poses overwrite static ones for
        instances present in both.
    """
    # concat static obb poses and dynamic ones at the current time
    static_Ts_world_object = obs["static_Ts_world_object"]
    timed_poses = obs["timedTs_world_object"]
    # NOTE(review): a pose map with a single timestamp is treated as "no
    # dynamic objects" (the check is > 1, not > 0) - presumably because
    # interpolation needs two timestamps; confirm.
    have_dynamic_objects = len(timed_poses) > 1

    dynamic_Ts_world_object = {}
    if load_dynamic_objects and have_dynamic_objects:
        if time in timed_poses.keys():
            dynamic_Ts_world_object = timed_poses[time]
        elif interpolate_poses:
            dynamic_Ts_world_object = interpolate_timed_poses(timed_poses, time)
            print(
                f"Warning: did not find time {time} in dynamic objects pose map - so interpolated poses"
            )
        else:
            dynamic_Ts_world_object, dt = closest_timed_poses(timed_poses, time)
            if abs(dt) > dt_threshold_ns:
                dynamic_Ts_world_object = {}
            else:
                print(
                    f"Warning: no time {time} in dynamic objects pose map - picked closest pose before in time {dt}"
                )

    all_Ts_world_object = {}
    all_Ts_world_object.update(static_Ts_world_object)
    all_Ts_world_object.update(dynamic_Ts_world_object)
    static_inst = set(static_Ts_world_object.keys())
    dynamic_inst = set(dynamic_Ts_world_object.keys())
    if static_inst.intersection(dynamic_inst):
        print(
            "Warning: static and dynamic instances overlap overwriting static poses with dynamic ones! "
        )
    return all_Ts_world_object
def get_inst_id_in_camera(
    bb2s_camera,
    time: int,
    camera_name: str,
):
    """
    Get the instance ids that have a 2D bounding box in one camera at `time`.

    Args:
        bb2s_camera: mapping time -> list of lines whose first element is the
            instance id. May be empty or None.
        time: query timestamp in nanoseconds.
        camera_name: only used in the printed messages.

    Returns:
        List of instance ids. If `time` is not a key, the entry with the
        nearest timestamp is used when it is less than 1ms away. An empty list
        is returned when no entry is close enough or the mapping is empty.
    """
    if not bb2s_camera:
        # No observations at all for this camera; there is no nearest
        # timestamp to fall back to.
        return []

    if time in bb2s_camera.keys():
        return [line[0] for line in bb2s_camera[time]]

    bb2_times = list(bb2s_camera.keys())
    nearest_idx = find_nearest(bb2_times, float(time), return_index=True)
    nearest_time = bb2_times[nearest_idx]
    if abs(time - nearest_time) >= 1_000_000:
        print(
            f"Error: {camera_name}: target time {time}ns has too large gap from the found nearest time {nearest_time}ns with gap {abs(time - nearest_time)}ns, skip this frame."
        )
        return []
    print(
        f"{camera_name}:",
        time,
        nearest_time,
        time - nearest_time,
    )
    return [line[0] for line in bb2s_camera[nearest_time]]


def get_instance_id_in_frameset(
    obs,
    time: int,
    load_dynamic_objects: bool,
    interpolate_poses: bool = True,
    dt_threshold_ns: int = 10_000_000,
):
    """
    Get the ids of all instances visible in any camera of the frameset at `time`.

    Only instances that also have a 3D pose at `time`, a prototype in
    obs["inst2proto"] and a local extent in obs[ARIA_OBB_BB3] are returned.

    Args:
        obs: observation dict holding the per-camera 2D boxes (keys
            ARIA_OBB_BB2), poses, "inst2proto" and local extents.
        time: query timestamp in nanoseconds.
        load_dynamic_objects, interpolate_poses, dt_threshold_ns: forwarded to
            `get_all_Ts_world_object_for_time`.

    Returns:
        numpy array of unique instance ids (sorted by `np.unique`).
    """
    # Get 3D object transforms that are visible in this frameset.
    bb2s_rgb = obs[ARIA_OBB_BB2[0]]
    bb2s_slaml = obs[ARIA_OBB_BB2[1]]
    bb2s_slamr = obs[ARIA_OBB_BB2[2]]
    bb2_time_rgb = time
    all_Ts_world_object = get_all_Ts_world_object_for_time(
        obs,
        bb2_time_rgb,
        load_dynamic_objects,
        interpolate_poses=interpolate_poses,
        dt_threshold_ns=dt_threshold_ns,
    )
    instance2proto = obs["inst2proto"]
    local_extents = obs[ARIA_OBB_BB3]

    inst_ids_rgb = get_inst_id_in_camera(bb2s_rgb, bb2_time_rgb, "rgb")
    # Support having visibility for only RGB.
    if len(bb2s_slaml) == 0:
        inst_ids_slaml = []
    else:
        inst_ids_slaml = get_inst_id_in_camera(bb2s_slaml, bb2_time_rgb, "slaml")
    if len(bb2s_slamr) == 0:
        inst_ids_slamr = []
    else:
        inst_ids_slamr = get_inst_id_in_camera(bb2s_slamr, bb2_time_rgb, "slamr")

    # Get union of all instance ids.
    visible_ids = set(inst_ids_rgb).union(inst_ids_slaml).union(inst_ids_slamr)

    # Make sure that all 2D BB instance ids have a 3D pose, prototype and local extent.
    inst_ids = [
        inst_id
        for inst_id in visible_ids
        if inst_id in all_Ts_world_object
        and inst_id in instance2proto
        and inst_id in local_extents
    ]
    return np.unique(inst_ids)
def get_bb2s_for_instances(obs, time, inst_ids, cam_names, cam_scales=None):
    """
    Look up the 2D bounding boxes of the given instances in every camera.

    Args:
        obs (dict): observation dict from Hive table
        time (int): nanoseconds timestamp of observation
        inst_ids (list): list of instance ids to get 2D BBs for
        cam_names (list): list of camera names, paired in order with the
            entries of ARIA_OBB_BB2
        cam_scales (dict): dict of camera scale for each camera (via cam_name)
            {cam_name:[x_scale, y_scale]}

    Returns:
        dict cam_name -> float tensor with one row of 4 values per instance.
        Instances without a box in a camera get [-1, -1, -1, -1].
    """
    # visible bounding boxes are >=0; invisible ones are < 0
    no_bb2 = [-1, -1, -1, -1]
    bb2_time_rgb = time
    bb2s = {cam_name: [] for cam_name in cam_names}
    for bb2_name, cam_name in zip(ARIA_OBB_BB2, cam_names):
        if bb2_time_rgb not in obs[bb2_name].keys():
            bb2_insts = [no_bb2] * len(inst_ids)
        else:
            bb2_obs_at_time = obs[bb2_name][bb2_time_rgb]
            bb2_insts = []
            for iid in inst_ids:
                bb2 = None
                for line in bb2_obs_at_time:
                    if line[0] == iid:
                        bb2 = line[1:]
                        break
                # Explicit emptiness test so that array-valued boxes work too.
                if bb2 is not None and len(bb2) > 0:
                    bb2_insts.append(bb2)
                else:
                    bb2_insts.append(no_bb2)
        bb2_insts = torch.from_numpy(np.array(bb2_insts)).float()
        # Scale per coordinate, not per instance: the first two columns are
        # scaled by x_scale and the last two by y_scale. The ndim guard skips
        # the empty (no instances) case, where there are no columns to index.
        if cam_scales and bb2_insts.ndim == 2:
            x_scale, y_scale = cam_scales[cam_name][0], cam_scales[cam_name][1]
            bb2_insts[:, :2] = bb2_insts[:, :2] * x_scale
            bb2_insts[:, 2:] = bb2_insts[:, 2:] * y_scale
        bb2s[cam_name] = bb2_insts
    return bb2s


def next_obb_observations(
    obs,
    time,
    inst_ids,
    cam_names,
    cam_scales=None,
    load_dynamic_objects: bool = True,
    interpolate_poses: bool = True,
    dt_threshold_ns: int = 10_000_000,
):
    """
    Assemble all obb observations for the given instances at `time`.

    Args:
        obs (dict): observation dict from Hive table
        time (float): timestamp of observation
        inst_ids (list): list of instance ids to get 2D BBs for
        cam_names (list): list of camera names; must contain "rgb", "slaml"
            and "slamr" since those keys are read from the 2D box result
        cam_scales (dict): dict of camera scale for each camera (via cam_name)
            {cam_name:[x_scale, y_scale]}
        load_dynamic_objects, interpolate_poses, dt_threshold_ns: forwarded to
            `get_all_Ts_world_object_for_time`.

    Returns:
        Tuple (bb2_rgb, bb2_slaml, bb2_slamr, bb3, Ts_world_object, sem_ids,
        inst_ids), restricted to instances that have both a pose at `time`
        and a local extent.
    """
    all_Ts_world_object = get_all_Ts_world_object_for_time(
        obs,
        time,
        load_dynamic_objects=load_dynamic_objects,
        interpolate_poses=interpolate_poses,
        dt_threshold_ns=dt_threshold_ns,
    )
    # make sure we have a pose for all instances at this time.
    inst_ids = list(set(inst_ids).intersection(set(all_Ts_world_object.keys())))
    # make sure we have instances for all obb extents
    inst_ids = list(set(inst_ids).intersection(set(obs[ARIA_OBB_BB3].keys())))

    # get data
    Ts_wo = [all_Ts_world_object[iid] for iid in inst_ids]
    proto_names = [obs["inst2proto"][iid] for iid in inst_ids]
    proto_ids = [obs["proto2id"][name] for name in proto_names]
    exs = [obs[ARIA_OBB_BB3][iid] for iid in inst_ids]
    bbs_object = np.array([extent2bb(ex) for ex in exs])
    bbs_object = torch.tensor(bbs_object).float()

    # handle no obbs case.
    if Ts_wo:
        Ts_world_object = torch.stack(Ts_wo).float()
    else:
        Ts_world_object = PoseTW(torch.zeros(0, 12))
    inst_ids = torch.tensor(inst_ids)
    sem_ids = torch.tensor(proto_ids)
    bb3 = torch.from_numpy(np.array([bb2extent(bb) for bb in bbs_object]))

    # get 2D BBs for this frame
    bb2s = get_bb2s_for_instances(obs, time, inst_ids, cam_names, cam_scales)

    return (
        bb2s["rgb"],
        bb2s["slaml"],
        bb2s["slamr"],
        bb3,
        Ts_world_object,
        sem_ids,
        inst_ids,
    )
class HungarianMatcher2d3d(torch.nn.Module):
    """This class computes an assignment between the targets and the predictions of the network

    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(
        self,
        cost_class: float = 1,
        cost_bbox2: float = 1,
        cost_giou2: float = 1,
        cost_bbox3: float = 1,
        cost_iou3: float = 1,
    ):
        """Creates the matcher

        Params:
            cost_class: relative weight of the classification error in the matching cost
            cost_bbox2: relative weight of the L1 error of the 2D bounding box coordinates
            cost_giou2: relative weight of the (negated) generalized IoU of the 2D bounding boxes
            cost_bbox3: relative weight of the L1 distance between the 3D box centers
            cost_iou3: relative weight of the (negated) IoU of the 3D bounding boxes;
                the 3D IoU is only computed when this weight is > 0
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox2 = cost_bbox2
        self.cost_bbox3 = cost_bbox3
        self.cost_giou2 = cost_giou2
        self.cost_iou3 = cost_iou3
        assert (
            cost_class != 0
            or cost_bbox2 != 0
            or cost_bbox3 != 0
            or cost_giou2 != 0
            or cost_iou3 != 0
        ), "all costs cant be 0"

    @torch.no_grad()
    def forward_obbs(
        self,
        prd: ObbTW,
        tgt: ObbTW,
        prd_logits=None,
        logits_is_prob: bool = False,
    ):
        """
        Match predicted obbs to target obbs.

        Args:
            prd: predicted obbs, either a single set (ndim == 2) or a batch of
                sets (ndim == 3).
            tgt: target obbs. For batched predictions this is either an ObbTW
                (padding is removed first) or an iterable with one ObbTW per
                batch element.
            prd_logits: optional class logits of the predictions, with the
                same leading dimensions as `prd`. If None the classification
                cost is left out.
            logits_is_prob: set if `prd_logits` already holds probabilities.

        Returns:
            For a single set the (prediction indices, target indices) tuple,
            for a batch the list of such tuples (see `forward`).

        Raises:
            ValueError: if `prd` has neither 2 nor 3 dimensions.
        """
        if prd.ndim == 2:
            # Add a batch dimension of one; logits are optional in `forward`,
            # so None has to be passed through untouched.
            logits = None if prd_logits is None else prd_logits.unsqueeze(0)
            return self.forward(
                pred_logits=logits,
                pred_bb2s=prd.bb2_rgb.unsqueeze(0),
                pred_bb3s=prd.bb3corners_world.unsqueeze(0),
                pred_center_world=prd.bb3_center_world.unsqueeze(0),
                tgt_labels=[tgt.sem_id.squeeze(-1)],
                tgt_bb2s=[tgt.bb2_rgb],
                tgt_bb3s=[tgt.bb3corners_world],
                tgt_center_world=[tgt.bb3_center_world],
                logits_is_prob=logits_is_prob,
            )[0]
        elif prd.ndim == 3:
            if isinstance(tgt, ObbTW):
                tgt = tgt.remove_padding()
            return self.forward(
                pred_logits=prd_logits,
                pred_bb2s=prd.bb2_rgb,
                pred_bb3s=prd.bb3corners_world,
                pred_center_world=prd.bb3_center_world,
                tgt_labels=[tt.sem_id.squeeze(-1) for tt in tgt],
                tgt_bb2s=[tt.bb2_rgb for tt in tgt],
                tgt_bb3s=[tt.bb3corners_world for tt in tgt],
                tgt_center_world=[tt.bb3_center_world for tt in tgt],
                logits_is_prob=logits_is_prob,
            )
        else:
            raise ValueError(f"Unsupported shape {prd.shape}")

    @staticmethod
    def _replace_nan(cost, label):
        """Warn about NaN entries of a cost matrix and replace them by 1e6."""
        if cost.isnan().any():
            logger.warning(f"have {cost.isnan().sum()} nan values in {label}")
            cost = torch.nan_to_num(cost, nan=1e6)
        return cost

    @torch.no_grad()
    def forward(
        self,
        pred_logits=None,
        pred_bb2s=None,
        pred_center_world=None,
        pred_bb3s=None,
        tgt_labels=None,
        tgt_bb2s=None,
        tgt_center_world=None,
        tgt_bb3s=None,
        logits_is_prob: bool = False,
    ):
        """Performs the matching

        Params:
            pred_logits: optional tensor [batch_size, num_queries, num_semcls] with the
                classification logits (or probabilities if `logits_is_prob`)
            pred_bb2s: tensor [batch_size, num_queries, 4] with the predicted 2d boxes
            pred_center_world: tensor [batch_size, num_queries, D] with the predicted 3d centers
            pred_bb3s: optional tensor of predicted 3d box corners, one leading
                batch and one query dimension
            tgt_labels: list of batch_size 1-dim tensors with the target class labels
            tgt_bb2s: list of batch_size tensors [num_target_boxes, 4] with the target 2d boxes
            tgt_center_world: list of batch_size tensors with the target 3d centers
            tgt_bb3s: optional list of batch_size tensors with the target 3d box corners
            logits_is_prob: skip the softmax on `pred_logits`

        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
        assert pred_bb2s.dim() == 3, f"{pred_bb2s.shape}"
        assert pred_center_world.dim() == 3, f"{pred_center_world.shape}"
        B, N = pred_bb2s.shape[:2]
        assert len(tgt_bb2s) == B, "number of targets should be equal to batch size"
        assert len(tgt_center_world) == B, (
            "number of targets should be equal to batch size"
        )

        cost_class = None
        if pred_logits is not None:
            assert pred_logits.dim() == 3, f"{pred_logits.shape}"
            assert len(tgt_labels) == B, (
                "number of targets should be equal to batch size"
            )
            # We flatten to compute the cost matrices in a batch
            # [batch_size * num_queries, num_semcls]
            out_prob = pred_logits.flatten(0, 1)
            if not logits_is_prob:
                out_prob = out_prob.softmax(-1)
            tgt_ids = torch.cat(tgt_labels)
            assert tgt_ids.ndim == 1, f"{tgt_ids.shape} is not right"
            # Compute the classification cost. Contrary to the loss, we don't use the NLL,
            # but approximate it in 1 - proba[target class].
            # The 1 is a constant that doesn't change the matching, it can be omitted.
            cost_class = self._replace_nan(-out_prob[:, tgt_ids], "cost_class")

        # [batch_size * num_queries, 4]
        pred_bb2s = pred_bb2s.flatten(0, 1)
        pred_center_world = pred_center_world.flatten(0, 1)

        # remember the number of targets per batch element to split the cost
        # matrix again after the targets were concatenated
        sizes = [len(v) for v in tgt_bb2s]
        tgt_bb2s = torch.cat(tgt_bb2s)
        tgt_center_world = torch.cat(tgt_center_world)

        # Compute the L1 cost between 2d boxes
        cost_bbox2 = self._replace_nan(
            torch.cdist(pred_bb2s, tgt_bb2s, p=1), "cost_bbox"
        )

        # L1 cost between 3d box centers
        cost_bbox3 = self._replace_nan(
            torch.cdist(pred_center_world, tgt_center_world, p=1), "cost_bbox"
        )

        # 3d bbs iou
        cost_iou3 = None
        if pred_bb3s is not None and tgt_bb3s is not None and self.cost_iou3 > 0.0:
            pred_bb3s = pred_bb3s.flatten(0, 1)
            tgt_bb3s = torch.cat(tgt_bb3s)
            cost_iou3 = self._replace_nan(
                -box3d_overlap_wrapper(pred_bb3s, tgt_bb3s).iou, "cost_iou3"
            )

        # Compute the giou cost between boxes
        pred_xyxy = bb2_xxyy_to_xyxy(pred_bb2s)
        cost_giou = -generalized_box_iou(pred_xyxy, bb2_xxyy_to_xyxy(tgt_bb2s))
        # set invalid costs to high value so they are not chosen in linear assignment
        # invalid predictions have size 0.0
        pred_invalid = box_area(pred_xyxy) <= 0.0
        cost_giou[pred_invalid, :] = 1e6
        cost_giou = self._replace_nan(cost_giou, "cost_giou")

        # Final cost matrix
        C = (
            self.cost_bbox2 * cost_bbox2
            + self.cost_bbox3 * cost_bbox3
            + self.cost_giou2 * cost_giou
        )
        if cost_class is not None:
            C = C + self.cost_class * cost_class
        if cost_iou3 is not None:
            C = C + self.cost_iou3 * cost_iou3
        C = C.view(B, N, -1).cpu()

        # Solve one assignment problem per batch element on its own block of
        # target columns.
        indices = [
            linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))
        ]
        return [
            (
                torch.as_tensor(row_id, dtype=torch.int64),
                torch.as_tensor(col_id, dtype=torch.int64),
            )
            for row_id, col_id in indices
        ]
class ObbMetrics(torch.nn.Module):
    """
    Metrics that directly work with our ObbTW class
    It is a torch.nn.Module to be able to behave like a torchmetrics object
    """

    def __init__(
        self,
        cam_ids=ARIA_CAM_IDS,
        cam_names=ARIA_CAM_NAMES,
        class_metrics: bool = False,
        volume_range_metrics: bool = False,
        eval_2d: bool = True,
        eval_3d: bool = False,
        ignore_bb2d_visibility: bool = False,
        global_name_to_id_file: Optional[str] = None,
        global_name_to_id: Optional[Dict] = None,
        ret_all_prec_rec: Optional[bool] = False,
        max_detection_thresholds: Optional[List[float]] = None,
    ) -> None:
        """
        Args:
            cam_ids (list): list of camera ids to evaluate
            cam_names (list): list of camera names to evaluate, paired in order with cam_ids
            class_metrics (bool): if True, computes per-class metrics
            volume_range_metrics (bool): if True, computes volume range metrics
            eval_2d (bool): if True, evaluate 2d detections
            eval_3d (bool): if True, evaluate 3d detections
            ignore_bb2d_visibility (bool): if True, the 3d evaluation uses all
                obbs instead of only those visible in the respective camera
            global_name_to_id_file (str): optional csv file parsed with
                `parse_global_name_to_id_csv`; takes precedence over global_name_to_id
            global_name_to_id (dict): optional mapping class name -> semantic id
            ret_all_prec_rec (bool): forwarded to MeanAveragePrecision3D
            max_detection_thresholds (list): forwarded to MeanAveragePrecision3D;
                defaults to [220]
        """
        from torchmetrics.detection.mean_ap import MeanAveragePrecision

        super().__init__()

        assert eval_2d or eval_3d, (
            "At least eval_2d or eval_3d needs to be set to True."
        )
        self.eval_2d = eval_2d
        self.eval_3d = eval_3d
        self.ignore_bb2d_visibility = ignore_bb2d_visibility

        # One metric object per camera; the dicts stay empty when the
        # respective evaluation is disabled.
        self.metric_2d = torch.nn.ModuleDict(
            {
                cam_name: MeanAveragePrecision(class_metrics=class_metrics)
                for cam_name in cam_names
            }
            if eval_2d
            else {}
        )
        bbox_area_ranges = None
        if volume_range_metrics:
            # Using category statistics from SUN_3D dataset: D42985037.
            bbox_area_ranges = {
                "all": (0, 1e5),
                "small": (0, 1e-2),  # pen, remote, toilet paper, etc.
                "medium": (1e-2, 1),  # chair, bin, monitor, etc.
                "large": (1, 1e5),  # bed, sofa, etc.
            }
        if max_detection_thresholds is None:
            # max number of detections to evaluate - 220 is sufficient for ASE scenes
            max_detection_thresholds = [220]
        self.metric_3d = torch.nn.ModuleDict(
            {
                cam_name: MeanAveragePrecision3D(
                    class_metrics=class_metrics,
                    bbox_area_ranges=bbox_area_ranges,
                    max_detection_thresholds=max_detection_thresholds,
                    ret_all_prec_rec=ret_all_prec_rec,
                )
                for cam_name in cam_names
            }
            if eval_3d
            else {}
        )
        self.cam_ids = cam_ids
        self.cam_names = cam_names
        self.cam_id_to_name = {id: name for id, name in zip(cam_ids, cam_names)}
        self.sem_id_to_name = None
        if global_name_to_id_file is not None:
            global_name_to_id = parse_global_name_to_id_csv(global_name_to_id_file)
        if global_name_to_id is not None:
            self.sem_id_to_name = {
                int(sem_id): name for name, sem_id in global_name_to_id.items()
            }

    def update(self, prediction: ObbTW, target: ObbTW, cam: Optional[CameraTW] = None):
        """
        Add one set of predicted and target obbs to the metrics of every camera.

        Args:
            prediction: predicted obbs.
            target: ground-truth obbs.
            cam: currently unused.
        """
        for cam_id in self.cam_ids:
            if self.eval_2d:
                # reshape(-1) instead of squeeze(): squeeze() would turn a
                # single detection into a 0-dim tensor, which update_2d rejects.
                self.update_2d(
                    prediction.bb2(cam_id),
                    prediction.prob.reshape(-1),
                    prediction.sem_id.reshape(-1),
                    target.bb2(cam_id),
                    target.sem_id.reshape(-1),
                    cam_id,
                )
            if self.eval_3d:
                visible_predictions_ind = prediction.visible_bb3_ind(cam_id)
                visible_targets_ind = target.visible_bb3_ind(cam_id)
                if self.ignore_bb2d_visibility:
                    visible_predictions_ind[:] = True
                    visible_targets_ind[:] = True
                if not visible_predictions_ind.any():
                    print("WARNING: no predictions are visible")
                if not visible_targets_ind.any():
                    print("WARNING: no targets are visible")
                # Use visible boxes in the camera for evaluation
                self.update_3d(
                    prediction.bb3corners_world[visible_predictions_ind],
                    prediction.prob[visible_predictions_ind].view(-1),
                    prediction.sem_id[visible_predictions_ind].view(-1),
                    target.bb3corners_world[visible_targets_ind],
                    target.sem_id[visible_targets_ind].view(-1),
                    cam_id,
                )

    def forward(self, prediction: ObbTW, target: ObbTW):
        """Update the metrics with one sample and return the computed metrics."""
        self.update(prediction, target)
        return self.compute()

    def update_3d(
        self,
        pred_bb3corners: torch.Tensor,
        pred_scores: torch.Tensor,
        pred_labels: torch.Tensor,
        tgt_bb3corners: torch.Tensor,
        tgt_labels: torch.Tensor,
        cam_id: int = 0,
    ):
        """
        Add one sample to the 3d metric of camera `cam_id`.

        Box corners must be 3-dimensional tensors; scores and labels must be
        1-dimensional.
        """
        assert pred_bb3corners.dim() == 3
        assert tgt_bb3corners.dim() == 3
        assert pred_scores.dim() == 1
        assert pred_labels.dim() == 1
        assert tgt_labels.dim() == 1
        p = [
            {
                "boxes": pred_bb3corners,
                "scores": pred_scores,
                "labels": pred_labels,
            }
        ]
        t = [
            {
                "boxes": tgt_bb3corners,
                "labels": tgt_labels,
            }
        ]
        self.metric_3d[self.cam_id_to_name[cam_id]].update(p, t)

    def update_2d(
        self,
        pred_bb2: torch.Tensor,
        pred_scores: torch.Tensor,
        pred_labels: torch.Tensor,
        tgt_bb2: torch.Tensor,
        tgt_labels: torch.Tensor,
        cam_id: int = 0,
    ):
        """
        Add one sample to the 2d metric of camera `cam_id`.

        Scores and labels must be 1-dimensional tensors.
        """
        assert pred_scores.dim() == 1
        assert pred_labels.dim() == 1
        assert tgt_labels.dim() == 1
        p = [
            {
                "boxes": pred_bb2,
                "scores": pred_scores,
                "labels": pred_labels,
            }
        ]
        t = [
            {
                "boxes": tgt_bb2,
                "labels": tgt_labels,
            }
        ]
        self.metric_2d[self.cam_id_to_name[cam_id]].update(p, t)

    def update_2d_instances(
        self,
        preds,  #: List[Instances],
        tgts,  #: List[Instances],
        cam_id: int = 0,
    ):
        """Add paired prediction/target instances to the 2d metric of one camera."""
        for pred, tgt in zip(preds, tgts):
            self.update_2d(
                pred_bb2=pred.pred_boxes.tensor,
                pred_scores=pred.scores,
                pred_labels=pred.pred_classes,
                tgt_bb2=tgt.gt_boxes.tensor,
                tgt_labels=tgt.gt_classes,
                cam_id=cam_id,
            )

    def compute(self):
        """
        Compute all metrics.

        Returns:
            dict "{cam_name}/{metric_name}_2D" / "{cam_name}/{metric_name}_3D"
            -> value. 2d metrics whose name contains "small", "medium" or
            "large" are left out.
        """
        metrics = {}
        size_tags = ("small", "medium", "large")
        for cam_name in self.cam_names:
            if self.eval_2d:
                m2d = self.metric_2d[cam_name].compute()
                for metric_name, val in m2d.items():
                    if not any(tag in metric_name for tag in size_tags):
                        metrics[f"{cam_name}/{metric_name}_2D"] = val
            if self.eval_3d:
                logger.info(f"Computing metric {self.metric_3d[cam_name]}")
                t0 = time()
                m3d = self.metric_3d[cam_name].compute(self.sem_id_to_name)
                t1 = time()
                logger.info(
                    f"DONE Computing metric {self.metric_3d[cam_name]} in {t1 - t0} seconds"
                )
                for metric_name, val in m3d.items():
                    metrics[f"{cam_name}/{metric_name}_3D"] = val
        return metrics

    def reset(self):
        """Reset the accumulated state of all 2d and 3d metrics."""
        for metric in self.metric_2d.values():
            metric.reset()
        for metric in self.metric_3d.values():
            metric.reset()
def nms_3d(
    obbs,
    nms_iou3_thr: float = 0.1,
    nms_iou3_max_thr: float = 0.15,
    verbose: bool = False,
    mark_in_place: bool = False,
):
    """
    NMS based on 3D bbs.
    When a duplicate is found the obb with higher probability is retained;
    for equal probabilities the obb with the lower index is retained.

    Two obbs are duplicates if their 3D IoU exceeds `nms_iou3_thr` and they
    have the same semantic id, or if their 3D IoU exceeds `nms_iou3_max_thr`
    regardless of the semantic id.

    Args:
        obbs: obbs to filter (may be None or empty).
        nms_iou3_thr: IoU threshold for obbs of the same class.
        nms_iou3_max_thr: IoU threshold applied across classes.
        verbose: log details about the suppressed obbs at debug level.
        mark_in_place: mark invalid and suppressed obbs inside `obbs` instead
            of returning filtered copies.

    Returns:
        If `mark_in_place`: (ids_keep, ids_bad).
        Otherwise: (kept obbs, suppressed obbs or None, (ids_keep, ids_bad)).
        The ids are lists of indices into the obbs that took part in the NMS,
        i.e. after invalid boxes were removed when not marking in place.
    """

    def _result(kept, suppressed, ids_keep, ids_bad):
        # Both return conventions share the index tuple.
        if mark_in_place:
            return (ids_keep, ids_bad)
        return kept, suppressed, (ids_keep, ids_bad)

    if obbs is None:
        return _result(None, None, [], [])
    if obbs.shape[0] == 0:
        return _result(obbs, None, [], [])

    if mark_in_place:
        remove_invalid_box3d(obbs, mark_in_place)
    else:
        obbs, _ = remove_invalid_box3d(obbs, mark_in_place)
    # Computed after the invalid boxes were dropped so that the indices refer
    # to the obbs that are returned.
    ids_keep = list(range(obbs.shape[0]))
    if obbs.shape[0] == 0:
        return _result(obbs, None, ids_keep, [])

    bb3s = obbs.bb3corners_world
    N = bb3s.shape[0]
    # iou and make the diagonal negative (since we dont want to check overlap with self.)
    iou3 = box3d_overlap_wrapper(bb3s, bb3s).iou - 2.0 * torch.eye(
        N, device=bb3s.device
    )
    # we want bb3s to overlap and be in the same class
    same_ids = obbs.sem_id == obbs.sem_id.view(1, -1)
    overlap = iou3 > nms_iou3_thr
    overlap_overwrite = iou3 > nms_iou3_max_thr
    nms = torch.logical_or(torch.logical_and(overlap, same_ids), overlap_overwrite)
    if verbose and overlap.count_nonzero() > 0:
        logger.debug("overlap %s", iou3[overlap])
    if verbose and overlap_overwrite.count_nonzero() > 0:
        logger.debug("overlap_overwrite %s", iou3[overlap_overwrite])

    # ids where we want to NMS; every duplicate pair shows up as (i, j) and (j, i)
    first, second = torch.nonzero(nms, as_tuple=True)
    prob_first = obbs[first].prob.squeeze(-1)
    prob_second = obbs[second].prob.squeeze(-1)
    # ids of the obbs we want to remove (lower probability). Ties are broken
    # by index so that both orderings of a pair agree on the loser; otherwise
    # two equally probable duplicates would suppress each other.
    first_loses = torch.logical_or(
        prob_first < prob_second,
        torch.logical_and(prob_first == prob_second, first > second),
    )
    ids_bad = set(torch.where(first_loses, first, second).tolist())

    if not ids_bad:
        return _result(obbs, None, ids_keep, [])

    # have some obbs to remove; compute which to keep and return those.
    if verbose:
        logger.debug(
            f"NMS3d: found {len(ids_bad)} non-maxima to suppress. {ids_bad}/ {bb3s.shape[0]}"
        )
    ids_keep = sorted(set(range(N)) - ids_bad)
    ids_bad = sorted(ids_bad)
    if mark_in_place:
        obbs._mark_invalid_ids(torch.tensor(ids_bad, dtype=torch.long))
        return (ids_keep, ids_bad)
    return obbs[ids_keep], obbs[ids_bad], (ids_keep, ids_bad)
def nms_2d(obbs, nms_iou2_thr: float, verbose: bool = False):
    """
    NMS based on 2D bbs.
    When a duplicate is found the obb with higher probability is retained;
    for equal probabilities the obb with the lower index is retained.

    Two obbs are duplicates if the IoU of their RGB 2D boxes exceeds
    `nms_iou2_thr` and they have the same semantic id.

    Args:
        obbs: obbs to filter (may be None or empty).
        nms_iou2_thr: 2D IoU threshold above which obbs count as duplicates.
        verbose: log details about the suppressed obbs at debug level.

    Returns:
        (kept obbs, suppressed obbs); the second entry is None when nothing
        was suppressed.
    """
    if obbs is None or obbs.shape[0] == 0:
        return obbs, None

    bb2s = bb2_xxyy_to_xyxy(obbs.bb2_rgb)
    N = bb2s.shape[0]
    # iou and make the diagonal negative (since we dont want to check overlap with self.)
    iou2 = box_iou(bb2s, bb2s) - 2.0 * torch.eye(N, device=bb2s.device)
    # we want bb2s to overlap and be in the same class
    same_ids = obbs.sem_id == obbs.sem_id.view(1, -1)
    overlap = torch.logical_and(iou2 > nms_iou2_thr, same_ids)

    # ids where we have overlap; every duplicate pair shows up as (i, j) and (j, i)
    first, second = torch.nonzero(overlap, as_tuple=True)
    prob_first = obbs[first].prob.squeeze(-1)
    prob_second = obbs[second].prob.squeeze(-1)
    # ids of the obbs we want to remove (lower probability). Ties are broken
    # by index so that both orderings of a pair agree on the loser; otherwise
    # two equally probable duplicates would suppress each other.
    first_loses = torch.logical_or(
        prob_first < prob_second,
        torch.logical_and(prob_first == prob_second, first > second),
    )
    ids_bad = set(torch.where(first_loses, first, second).tolist())

    if not ids_bad:
        return obbs, None

    # have some obbs to remove; compute which to keep and return those.
    if verbose:
        logger.debug(
            f"NMS2d: found {len(ids_bad)} non-maxima to suppress. {ids_bad}/ {bb2s.shape[0]}"
        )
    ids_keep = sorted(set(range(N)) - ids_bad)
    return obbs[ids_keep], obbs[sorted(ids_bad)]
def __init__(
    self,
    track_best: bool = False,
    track_running_average: bool = True,
    max_assoc_dist: float = 0.1,
    max_assoc_iou2: float = 0.2,
    max_assoc_iou3: float = 0.2,
    prob_inst_thr: float = 0.3,
    prob_assoc_thr: float = 0.25,
    nms_iou3_thr: float = 0.1,
    nms_iou2_thr: float = 0.5,
    w_max: int = 30,
    w_min: int = 5,
    dt_max_inst: float = 1.0,
    dt_max_occ: float = 5.0,
):
    """
    Store the tracker configuration and start with an empty world state.

    Args:
        track_best: for associated obbs keep the one with the highest
            probability (the most basic fusion strategy).
        track_running_average: keep a running average of associated obbs.
            Denoises obb parameters but struggles with detections that are
            not consistently in the same canonical orientation.
        max_assoc_dist: center-distance threshold for association.
        max_assoc_iou2: 2D IoU threshold used for association.
        max_assoc_iou3: 3D IoU threshold used for association.
        prob_inst_thr: minimum probability for instantiating a new world obb.
        prob_assoc_thr: minimum probability for associating a detection with
            the existing world obbs.
        nms_iou3_thr: 3D IoU threshold for suppressing duplicate world obbs.
        nms_iou2_thr: 2D IoU threshold for suppressing duplicate world obbs.
        w_max: maximum weight accumulated in the running average.
        w_min: minimum weight needed to return a scene obb.
        dt_max_inst: how long an object may take to be instantiated (seconds).
        dt_max_occ: how long an instantiated object may stay occluded (seconds).
    """
    # weights of the individual matching costs
    matcher_costs = dict(
        cost_class=8.0,
        cost_bbox2=0.0,
        cost_bbox3=1.0,
        cost_giou2=4.0,
        cost_iou3=0.0,
    )
    self.matcher = HungarianMatcher2d3d(**matcher_costs)

    # --- world state (filled once the first detections arrive) ---
    self.scene_obbs_w = None  # the set of scene obbs
    self.w = None  # weight (count) of each scene obb
    self.scene_probs_full = None
    self.num_semcls = 128
    self.last_obs_time = None  # when an observation was last associated
    self.last_possible_obs_time = None  # when last observable (2d bb in frame)

    # --- tracker clock: pseudo time advanced by 1/hz per track() call ---
    self.time = 0
    self.hz = 10.0
    self.dt_max_inst = dt_max_inst
    self.dt_max_occ = dt_max_occ

    # --- fusion configuration ---
    self.w_max = w_max
    self.w_min = w_min
    self.track_best = track_best
    self.track_running_average = track_running_average

    # --- association / instantiation thresholds ---
    self.max_assoc_dist = max_assoc_dist
    self.max_assoc_iou2 = max_assoc_iou2
    self.max_assoc_iou3 = max_assoc_iou3
    self.prob_inst_thr = prob_inst_thr
    self.prob_assoc_thr = prob_assoc_thr

    # --- duplicate suppression thresholds ---
    self.nms_iou3_thr = nms_iou3_thr
    self.nms_iou3_max_thr = 0.15
    self.nms_iou2_thr = nms_iou2_thr

    self.R90s = all_rot90()
    self.counts_as_prob = False
    self.device = torch.device("cpu")
    self.num_instances_so_far = 0
@property
def obbs_world(self):
    """
    The main accessor for the tracked obbs that pass a set of gates.

    Besides filtering, this accessor also updates the tracker state: it
    refreshes the semantic ids from the averaged class distribution and drops
    world obbs that were not instantiated in time.

    Returns:
        Tuple (obbs_w, obbs_invis_w): the obbs with enough observations that
        were seen recently (after 3D / 2D NMS), and the remaining stale or
        not-yet-instantiated obbs (useful for debugging).
    """
    if self.scene_obbs_w is None:
        return ObbTW().to(self.device), ObbTW().to(self.device)

    sem_ids = self.scene_probs_full.argmax(dim=1)
    change = sem_ids != self.scene_obbs_w.sem_id.squeeze(-1)
    if change.any():
        # %-style placeholders are required: logging formats `msg % args`.
        logger.debug(
            "semantic id has changed because of probs_full averaging %s %s",
            sem_ids[change].tolist(),
            self.scene_obbs_w.sem_id.squeeze(-1)[change].tolist(),
        )
    # writing unchanged ids back leaves the values as they were
    self.scene_obbs_w.set_sem_id(sem_ids)

    # which obbs have we seen recently?
    dt = self.last_possible_obs_time - self.last_obs_time
    seen_uninst = dt < self.dt_max_inst
    seen_occlusion = dt < self.dt_max_occ
    # obbs need more than w_min observations to count as instantiated
    enough_observations = self.w > self.w_min

    # categories of obbs
    good_visible = torch.logical_and(enough_observations, seen_occlusion)
    good_invisible = torch.logical_and(enough_observations, ~seen_occlusion)
    uninst_visible = torch.logical_and(~enough_observations, seen_uninst)
    uninst_delete = torch.logical_and(~enough_observations, ~seen_uninst)

    # return all good ones
    obbs_w = self.scene_obbs_w[good_visible]
    # return the stale visible ones for debugging
    obbs_invis_w = self.scene_obbs_w[
        torch.logical_or(uninst_visible, good_invisible)
    ]

    # delete obbs that were not instantiated in time from every state tensor
    if uninst_delete.count_nonzero() > 0:
        logger.debug(
            f"removing un-instantiated obbs {uninst_delete.count_nonzero()}"
        )
        keep = ~uninst_delete
        self.scene_obbs_w = self.scene_obbs_w[keep]
        self.scene_probs_full = self.scene_probs_full[keep]
        self.last_obs_time = self.last_obs_time[keep]
        self.last_possible_obs_time = self.last_possible_obs_time[keep]
        self.w = self.w[keep]

    # NMS based on 3D IoU
    if self.nms_iou3_thr > 0.0:
        obbs_w, _ = self.nms_3d(obbs_w)
    # NMS based on 2D IoU
    if self.nms_iou2_thr > 0.0:
        obbs_w, _ = self.nms_2d(obbs_w)
    return obbs_w, obbs_invis_w
def track(
    self,
    obbs_w: ObbTW,
    probs_full: Optional[torch.Tensor] = None,
    cam: Optional[CameraTW] = None,
    T_world_rig: Optional[PoseTW] = None,
):
    """
    Fuse a new set of obb detections into the tracked world state.

    Args:
        obbs_w: new obb detections to track. shape: Nx34
        probs_full: full probability distribution over the classes of each
            detection. Rows beyond N are dropped as padding. When None, a
            one-hot distribution is built from the semantic ids.
        cam: camera used to compute 2D boxes and visibility (optional).
        T_world_rig: rig pose that goes with `cam` (optional).

    Returns:
        The value of `self.obbs_world` after the update.
    """
    self.device = obbs_w.device
    assert obbs_w.ndim == 2, f"{obbs_w.shape}"
    # if we dont have any good new obbs return
    if obbs_w.shape[0] == 0:
        return self.obbs_world
    # set 2d bbs
    obbs_w = self.set_2d_bbs(obbs_w, cam, T_world_rig)
    # filter out obbs that are too low probability to be associated
    assoc = obbs_w.prob.squeeze(-1) > self.prob_assoc_thr
    if probs_full is not None:
        # remove probs_full padding
        probs_full = probs_full[: obbs_w.shape[0], :]
    else:
        # create one-hot probability encoding based on semantic id
        probs_full = F.one_hot(
            obbs_w.sem_id.squeeze(-1).long(), num_classes=self.num_semcls
        ).float()
    obbs_w = obbs_w[assoc]
    probs_full = probs_full[assoc] if probs_full is not None else None
    # if we dont have any good new obbs return
    if obbs_w.shape[0] == 0:
        return self.obbs_world

    # if we dont have any scene obbs yet (at the beginning) initialize the
    # tracker state and return it.
    if self.scene_obbs_w is None:
        self.add_new_obbs(obbs_w, probs_full)
        return self.obbs_world

    # find matches
    indices = self.matcher.forward_obbs(
        prd=obbs_w,
        tgt=self.scene_obbs_w,
        prd_logits=probs_full,
        logits_is_prob=True,
    )
    # if we have no matches we return
    if len(indices[0]) == 0:
        return self.obbs_world

    # get matched obbs
    pids, tids = indices[0], indices[1]
    pobbs, tobbs = obbs_w[pids], self.scene_obbs_w[tids]
    pprobs_full = probs_full[pids] if probs_full is not None else None

    # find good associations based on the 2d and 3d iou
    dist = torch.linalg.norm(
        pobbs.bb3_center_world - tobbs.bb3_center_world, 2, dim=-1
    ).cpu()
    if self.max_assoc_iou2 > 0:
        iou2 = (
            box_iou(
                bb2_xxyy_to_xyxy(pobbs.bb2_rgb), bb2_xxyy_to_xyxy(tobbs.bb2_rgb)
            )
            .cpu()
            .diagonal()
        )
    else:
        iou2 = None

    # filter out invalid bboxes, if we can't compute iou3 we return
    pobbs, valid_ind = remove_invalid_box3d(pobbs)
    # valid_ind indexes the *matched* obbs, so it must be applied to the
    # matched distributions (pprobs_full), not to the unmatched probs_full.
    pprobs_full = pprobs_full[valid_ind] if pprobs_full is not None else None
    if pobbs.shape[0] == 0:
        return self.obbs_world
    if iou2 is not None:
        iou2 = iou2[valid_ind]
    tobbs = tobbs[valid_ind]
    dist = dist[valid_ind]
    pids = pids[valid_ind]
    tids = tids[valid_ind]

    # this function could fail due to thin object (ValueError: Planes have zero areas).
    # if we can't compute iou3 we return
    try:
        iou3 = (
            box3d_overlap_wrapper(pobbs.bb3corners_world, tobbs.bb3corners_world)
            .iou.cpu()
            .diagonal()
        )
    except Exception:
        logger.exception("could not compute 3D IoU of matched obbs; skipping update")
        return self.obbs_world

    assoc = iou3 > self.max_assoc_iou3
    if self.max_assoc_iou2 > 0.0:
        assoc = torch.logical_or(assoc, iou2 > self.max_assoc_iou2)

    # detections that were not matched at all
    new_ids = list(set(range(obbs_w.shape[0])) - set(pids.tolist()))

    if assoc.count_nonzero() > 0:
        logger.debug(
            "%d associated dist %s iou2 %s iou3 %s",
            int(assoc.count_nonzero()),
            dist[assoc],
            iou2[assoc] if iou2 is not None else None,
            iou3[assoc],
        )
    if (~assoc).count_nonzero() > 0:
        logger.debug(
            "%d not associated dist %s iou2 %s iou3 %s",
            int((~assoc).count_nonzero()),
            dist[~assoc],
            iou2[~assoc] if iou2 is not None else None,
            iou3[~assoc],
        )

    # candidates for new world obbs: matched-but-not-associated + unmatched
    new_obbs = torch.cat([pobbs[~assoc].clone(), obbs_w[new_ids]])
    new_insts = new_obbs.prob.squeeze(-1) > self.prob_inst_thr
    new_obbs = new_obbs[new_insts]
    new_probs_full = None
    if pprobs_full is not None:
        new_probs_full = torch.cat(
            [pprobs_full[~assoc].clone(), probs_full[new_ids]]
        )
        new_probs_full = new_probs_full[new_insts]

    # associated obbs
    pids, tids = pids[assoc], tids[assoc]
    pobbs, tobbs = pobbs[assoc], tobbs[assoc]
    pprobs_full = pprobs_full[assoc] if pprobs_full is not None else None

    # deal with associations
    if self.track_best and tids.shape[0] > 0:
        better_pred = (pobbs.prob > tobbs.prob).squeeze(-1).cpu()
        # update better obbs
        better_tids = tids[better_pred]
        self.scene_obbs_w._data[better_tids] = pobbs._data[better_pred]
        # increment weights
        self.w[tids] = self.w[tids] + 1.0
        # update times
        self.last_obs_time[tids] = self.time
        # update counts as probabilities
        if self.counts_as_prob:
            scene_obbs = self.scene_obbs_w[tids].clone()
            scene_obbs.set_prob(self.w[tids])
            self.scene_obbs_w._data[tids] = scene_obbs._data
    elif self.track_running_average and tids.shape[0] > 0:
        wpp = (self.w[tids] + 1.0).unsqueeze(-1)
        pdiag = pobbs.bb3_diagonal
        # running average T_world_object
        dT_tobj_pobj = tobbs.T_world_object.inverse() @ pobbs.T_world_object
        xi_tobj_pobj = dT_tobj_pobj.log()
        # check if any relative pose is further than 45 degree which
        # indicates that there is a 90 deg rotation that is closer.
        dr = xi_tobj_pobj[..., 3:]
        dr_norm = torch.linalg.norm(dr, 2, dim=-1)
        too_big = dr_norm > 3.14 * 0.25
        if too_big.any():
            # find closest 90 degree rotation
            pT_wo, R90min = find_r90(
                tobbs[too_big].T_world_object,
                pobbs[too_big].T_world_object,
                self.R90s.to(tobbs.device),
            )
            # update xi with the 90 deg closest rotation
            dT_tobj_pobj = tobbs[too_big].T_world_object.inverse() @ pT_wo
            xi_tobj_pobj[too_big] = dT_tobj_pobj.log()
            # also permute the diagonal according to the 90 deg rotation
            pdiag[too_big] = (
                (R90min @ pdiag[too_big].unsqueeze(-1)).squeeze(-1).abs()
            )
        # apply updates
        ppT_world_object = tobbs.T_world_object @ PoseTW.exp(xi_tobj_pobj / wpp)
        # running average over scale / diagonal of obb
        ppdiag = (tobbs.bb3_diagonal * self.w[tids].unsqueeze(-1) + pdiag) / wpp
        ppbb3 = bb3_xyzxyz_to_xxyyzz(
            torch.cat([-ppdiag * 0.5, ppdiag * 0.5], dim=-1)
        )
        # running average over prob
        ppprob = (tobbs.prob * self.w[tids].unsqueeze(-1) + pobbs.prob) / wpp
        if pprobs_full is not None:
            # running average over the full probability distribution
            pprobs_full = (
                self.scene_probs_full[tids] * self.w[tids].unsqueeze(-1)
                + pprobs_full
            ) / wpp
            self.scene_probs_full[tids] = pprobs_full
        # update target parameters
        tobbs.set_T_world_object(ppT_world_object)
        tobbs.set_bb3_object(ppbb3)
        tobbs.set_prob(ppprob.squeeze(-1))
        self.scene_obbs_w._data[tids] = tobbs._data
        # update weights
        self.w[tids] = wpp.clamp(max=self.w_max).squeeze(-1)
        # update times
        self.last_obs_time[tids] = self.time
        # update counts as probabilities
        if self.counts_as_prob:
            scene_obbs = self.scene_obbs_w[tids].clone()
            scene_obbs.set_prob(self.w[tids])
            self.scene_obbs_w._data[tids] = scene_obbs._data

    # add new obbs
    if new_obbs.shape[0] > 0:
        self.add_new_obbs(new_obbs, new_probs_full)
    # update last possible obs time based on 2d visibility
    self.update_last_obs_time(cam, T_world_rig)
    # update time
    self.time += 1.0 / self.hz
    return self.obbs_world
def update_last_obs_time(self, cam, T_world_rig):
    """
    Stamp the scene obbs that could have been observed with the current time.

    Without a camera or rig pose every scene obb is treated as observable.
    Otherwise an obb is observable when more than 50% of it is inside the
    pseudo 2D bb and both extents of that 2D bb are larger than 10 pixels.
    """
    if cam is not None and T_world_rig is not None:
        bb2s, _, frac = self.scene_obbs_w.get_pseudo_bb2(
            cam.unsqueeze(0), T_world_rig.unsqueeze(0), 10, return_frac_valids=True
        )
        bb2s = bb2s.squeeze(0)
        frac = frac.squeeze(0)
        # extent along the first / second pair of bb2 coordinates
        extent_a = bb2s[..., 1] - bb2s[..., 0]
        extent_b = bb2s[..., 3] - bb2s[..., 2]
        visible = torch.logical_and(frac > 0.5, extent_a > 10)
        visible = torch.logical_and(visible, extent_b > 10)
        # update last possible times
        self.last_possible_obs_time[visible] = self.time
    else:
        # mark all visible
        self.last_possible_obs_time[:] = self.time
def set_2d_bbs(self, obbs_w: ObbTW, cam: CameraTW, T_world_rig: PoseTW):
    """
    Write pseudo 2D bbs for the given camera view into the new detections
    and into the tracked scene obbs.

    Entries reported as not valid by `get_pseudo_bb2` are set to -1.
    Without a camera or rig pose nothing is changed.

    Returns:
        `obbs_w` (updated in place when a camera and pose are given).
    """
    if cam is None or T_world_rig is None:
        return obbs_w

    def _refresh(obbs):
        # compute the pseudo 2d bbs and blank out the invalid ones
        bb2s, valids, _ = obbs.get_pseudo_bb2(
            cam.unsqueeze(0), T_world_rig.unsqueeze(0), 10, return_frac_valids=True
        )
        bb2s[~valids] = -1.0
        obbs.set_bb2(0, bb2s.squeeze(0))

    if obbs_w.shape[0] > 0:
        _refresh(obbs_w)
    scene = self.scene_obbs_w
    if scene is not None and scene.shape[0] > 0:
        _refresh(scene)
    return obbs_w
def input_validator_box3d(  # noqa
    preds: Sequence[Dict[str, Tensor]], targets: Sequence[Dict[str, Tensor]]
) -> None:
    """Validate the format of `preds` and `targets` for the 3D mAP metric.

    Checks, in this order: both arguments are sequences of equal length; every
    dict has the required keys; every value is exactly a `Tensor`; per-sample
    target boxes and labels agree in length and the boxes end in shape (8, 3);
    per-sample prediction boxes, labels and scores agree in length.

    Raises:
        ValueError: on the first violated check.
    """
    if not isinstance(preds, Sequence):
        raise ValueError("Expected argument `preds` to be of type Sequence")
    if not isinstance(targets, Sequence):
        raise ValueError("Expected argument `target` to be of type Sequence")
    if len(preds) != len(targets):
        raise ValueError(
            "Expected argument `preds` and `target` to have the same length"
        )

    # required keys
    for key in ("boxes", "scores", "labels"):
        for pred in preds:
            if key not in pred:
                raise ValueError(
                    f"Expected all dicts in `preds` to contain the `{key}` key"
                )
    for key in ("boxes", "labels"):
        for tgt in targets:
            if key not in tgt:
                raise ValueError(
                    f"Expected all dicts in `target` to contain the `{key}` key"
                )

    # value types (exact type match, subclasses are rejected)
    type_checks = (
        ("preds", preds, ("boxes", "scores", "labels")),
        ("target", targets, ("boxes", "labels")),
    )
    for arg_name, samples, keys in type_checks:
        for key in keys:
            for sample in samples:
                if type(sample[key]) is not Tensor:
                    raise ValueError(
                        f"Expected all {key} in `{arg_name}` to be of type Tensor"
                    )

    # per-sample consistency of the targets
    for i, item in enumerate(targets):
        num_boxes = item["boxes"].size(0)
        num_labels = item["labels"].size(0)
        if num_boxes != num_labels:
            raise ValueError(
                f"Input boxes and labels of sample {i} in targets have a"
                f" different length (expected {num_boxes} labels, got {num_labels})"
            )
        if item["boxes"].shape[-2:] != (8, 3):
            raise ValueError(
                f"Input boxes of sample {i} in targets have a"
                f" wrong shape (expected (...,8, 3), got {item['boxes'].shape})"
            )

    # per-sample consistency of the predictions
    for i, item in enumerate(preds):
        num_boxes = item["boxes"].size(0)
        num_labels = item["labels"].size(0)
        num_scores = item["scores"].size(0)
        if not (num_boxes == num_labels == num_scores):
            raise ValueError(
                f"Input boxes, labels and scores of sample {i} in predictions have a"
                f" different length (expected {num_boxes} labels and scores,"
                f" got {num_labels} labels and {num_scores})"
            )
def box3d_volume(boxes: Tensor) -> Tensor:
    """
    Computes the volume of a set of 3d bounding boxes.

    The volume is the absolute scalar triple product of the three edges that
    leave corner 0 (towards corners 1, 3 and 4).

    Args:
        boxes (Tensor[N, 8, 3]): 3d boxes for which the volume will be computed.

    Returns:
        Tensor[N]: the volume for each box
    """
    if boxes.numel() == 0:
        return torch.zeros(0).to(boxes)
    # edges leaving corner 0
    a = boxes[..., 1, :] - boxes[..., 0, :]
    b = boxes[..., 3, :] - boxes[..., 0, :]
    c = boxes[..., 4, :] - boxes[..., 0, :]
    # Per-box triple product |(a x b) . c|. A matrix product with c.T would
    # build an N x N table mixing the base of one box with the height of
    # another, so the dot product has to be taken row by row.
    return torch.abs((torch.cross(a, b, dim=-1) * c).sum(dim=-1))
""" assert in_fmt == "xyz8" assert out_fmt == "xyz8" return boxes class MeanAveragePrecision3D(MeanAveragePrecision): def __init__( self, box_format: str = "xyz8", bbox_area_ranges: Optional[Dict[str, Tuple[float, float]]] = None, iou_thresholds: Optional[List[float]] = None, rec_thresholds: Optional[List[float]] = None, max_detection_thresholds: Optional[List[int]] = None, class_metrics: bool = False, # compute per class metrics compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ret_all_prec_rec: bool = False, ) -> None: # type: ignore # Use Omni3D iOU thresholds by default iou_thresholds = ( iou_thresholds or torch.linspace( 0.05, 0.5, round((0.5 - 0.05) / 0.05) + 1, dtype=torch.float64 ).tolist() ) rec_thresholds = ( rec_thresholds or torch.linspace( 0.0, 1.00, round(1.00 / 0.01) + 1, dtype=torch.float64 ).tolist() ) super().__init__( iou_thresholds=iou_thresholds, rec_thresholds=rec_thresholds, compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group, dist_sync_fn=dist_sync_fn, ) if not _TORCHVISION_GREATER_EQUAL_0_8: raise ModuleNotFoundError( "`MeanAveragePrecision` metric requires that `torchvision` version 0.8.0 or newer is installed." " Please install with `pip install torchvision>=0.8` or `pip install torchmetrics[detection]`." 
def __init__(
    self,
    box_format: str = "xyz8",
    bbox_area_ranges: Optional[Dict[str, Tuple[float, float]]] = None,
    iou_thresholds: Optional[List[float]] = None,
    rec_thresholds: Optional[List[float]] = None,
    max_detection_thresholds: Optional[List[int]] = None,
    class_metrics: bool = False,  # compute per class metrics
    compute_on_step: bool = True,
    dist_sync_on_step: bool = False,
    process_group: Optional[Any] = None,
    dist_sync_fn: Callable = None,
    ret_all_prec_rec: bool = False,
) -> None:  # type: ignore
    """Configure the 3D mean-average-precision metric.

    Missing IoU thresholds default to 0.05..0.5 in steps of 0.05 (Omni3D),
    missing recall thresholds to 0..1 in steps of 0.01, missing max-detection
    thresholds to [1, 10, 100] and a missing area range to a single "all"
    range of (0, 1e5). Only the "xyz8" box format is supported.

    Raises:
        ModuleNotFoundError: if the torchvision version flag is not set.
        ValueError: for an unsupported `box_format` or a non-bool
            `class_metrics`.
    """
    # Use Omni3D iOU thresholds by default
    if not iou_thresholds:
        num_iou = round((0.5 - 0.05) / 0.05) + 1
        iou_thresholds = torch.linspace(
            0.05, 0.5, num_iou, dtype=torch.float64
        ).tolist()
    if not rec_thresholds:
        num_rec = round(1.00 / 0.01) + 1
        rec_thresholds = torch.linspace(
            0.0, 1.00, num_rec, dtype=torch.float64
        ).tolist()

    super().__init__(
        iou_thresholds=iou_thresholds,
        rec_thresholds=rec_thresholds,
        compute_on_step=compute_on_step,
        dist_sync_on_step=dist_sync_on_step,
        process_group=process_group,
        dist_sync_fn=dist_sync_fn,
    )

    if not _TORCHVISION_GREATER_EQUAL_0_8:
        raise ModuleNotFoundError(
            "`MeanAveragePrecision` metric requires that `torchvision` version 0.8.0 or newer is installed."
            " Please install with `pip install torchvision>=0.8` or `pip install torchmetrics[detection]`."
        )

    allowed_box_formats = ["xyz8"]
    if box_format not in allowed_box_formats:
        raise ValueError(
            f"Expected argument `box_format` to be one of {allowed_box_formats} but got {box_format}"
        )
    self.box_format = box_format

    sorted_max_dets, _ = torch.sort(
        IntTensor(max_detection_thresholds or [1, 10, 100])
    )
    self.max_detection_thresholds = sorted_max_dets.tolist()

    if not isinstance(class_metrics, bool):
        raise ValueError("Expected argument `class_metrics` to be a boolean")
    self.class_metrics = class_metrics

    # important to overwrite after the __init__() call since they are otherwise overwritten by super().__init__()
    self.bbox_area_ranges = bbox_area_ranges
    if bbox_area_ranges is None:
        self.bbox_area_ranges = {"all": (0, 1e5)}

    self.ret_all_prec_rec = ret_all_prec_rec
    self.eval_imgs = [] if self.ret_all_prec_rec else None
target: A list consisting of dictionaries each containing the key-values (each dictionary corresponds to a single image): - ``boxes``: ``torch.FloatTensor`` of shape [num_boxes, 8, 3] containing `num_boxes` ground truth boxes of the format specified in the constructor. - ``labels``: ``torch.IntTensor`` of shape [num_boxes] containing 1-indexed ground truth classes for the boxes. Raises: ValueError: If ``preds`` is not of type List[Dict[str, Tensor]] ValueError: If ``target`` is not of type List[Dict[str, Tensor]] ValueError: If ``preds`` and ``target`` are not of the same length ValueError: If any of ``preds.boxes``, ``preds.scores`` and ``preds.labels`` are not of the same length ValueError: If any of ``target.boxes`` and ``target.labels`` are not of the same length ValueError: If any box is not type float and of length 4 ValueError: If any class is not type int and of length 1 ValueError: If any score is not type float and of length 1 """ input_validator_box3d(preds, target) for item in preds: boxes = _fix_empty_tensors(item["boxes"]) boxes = box3d_convert(boxes, in_fmt=self.box_format, out_fmt="xyz8") if hasattr(self, "detection_boxes"): self.detection_boxes.append(boxes) else: self.detections.append(boxes) self.detection_labels.append(item["labels"]) self.detection_scores.append(item["scores"]) for item in target: boxes = _fix_empty_tensors(item["boxes"]) boxes = box3d_convert(boxes, in_fmt=self.box_format, out_fmt="xyz8") if hasattr(self, "groundtruth_boxes"): self.groundtruth_boxes.append(boxes) else: self.groundtruths.append(boxes) self.groundtruth_labels.append(item["labels"]) def _compute_iou(self, id: int, class_id: int, max_det: int) -> Tensor: """Computes the Intersection over Union (IoU) for ground truth and detection bounding boxes for the given image and class. 
def _evaluate_image(
    self,
    id: int,
    class_id: int,
    area_range: Tuple[int, int],
    max_det: int,
    ious: dict,
) -> Optional[dict]:
    """Perform evaluation for single class and image.

    Args:
        id: Image Id, equivalent to the index of supplied samples.
        class_id: Class Id of the supplied ground truth and detection labels.
        area_range: List of lower and upper bounding box area threshold.
            Compared against the value returned by `box3d_volume`.
        max_det: Maximum number of evaluated detection bounding boxes.
        ious: IoU results for image and class, indexed by `(id, class_id)`.

    Returns:
        None when the sample has no labels at all or no boxes of this class on
        both sides; otherwise a dict with the CPU tensors `dtMatches`,
        `gtMatches`, `dtScores`, `gtIgnore` and `dtIgnore`.
    """
    # the attribute names of the box states differ between base-class versions
    if hasattr(self, "detection_boxes"):
        gt = self.groundtruth_boxes[id]
        det = self.detection_boxes[id]
    else:
        gt = self.groundtruths[id]
        det = self.detections[id]
    gt_label_mask = self.groundtruth_labels[id] == class_id
    det_label_mask = self.detection_labels[id] == class_id
    if len(gt_label_mask) == 0 or len(det_label_mask) == 0:
        return None
    gt = gt[gt_label_mask]
    det = det[det_label_mask]
    if len(gt) == 0 and len(det) == 0:
        return None

    # ground truth boxes outside of the requested range are ignored
    areas = box3d_volume(gt)
    ignore_area = (areas < area_range[0]) | (areas > area_range[1])

    # sort dt highest score first, sort gt ignore last
    ignore_area_sorted, gtind = torch.sort(ignore_area.to(torch.uint8))
    # Convert to uint8 temporarily and back to bool, because "Sort currently does not support bool dtype on CUDA"
    ignore_area_sorted = ignore_area_sorted.to(torch.bool)
    gt = gt[gtind]
    scores = self.detection_scores[id]
    scores_filtered = scores[det_label_mask]
    scores_sorted, dtind = torch.sort(scores_filtered, descending=True)
    det = det[dtind]
    if len(det) > max_det:
        det = det[:max_det]
    # load computed ious; reorder the gt columns to the sorted gt order
    ious = (
        ious[id, class_id][:, gtind]
        if len(ious[id, class_id]) > 0
        else ious[id, class_id]
    )

    nb_iou_thrs = len(self.iou_thresholds)
    nb_gt = len(gt)
    nb_det = len(det)
    # one row per IoU threshold, one column per gt / detection
    gt_matches = torch.zeros(
        (nb_iou_thrs, nb_gt), dtype=torch.bool, device=det.device
    )
    det_matches = torch.zeros(
        (nb_iou_thrs, nb_det), dtype=torch.bool, device=det.device
    )
    gt_ignore = ignore_area_sorted
    det_ignore = torch.zeros(
        (nb_iou_thrs, nb_det), dtype=torch.bool, device=det.device
    )

    if torch.numel(ious) > 0:
        # detections are visited in descending score order; gt_matches is
        # updated as we go, so the iteration order matters here.
        for idx_iou, t in enumerate(self.iou_thresholds):
            for idx_det, _ in enumerate(det):
                m = MeanAveragePrecision._find_best_gt_match(
                    t, gt_matches, idx_iou, gt_ignore, ious, idx_det
                )
                if m != -1:
                    # a detection matched to an ignored gt is ignored as well
                    det_ignore[idx_iou, idx_det] = gt_ignore[m]
                    det_matches[idx_iou, idx_det] = 1
                    gt_matches[idx_iou, m] = 1

    # set unmatched detections outside of area range to ignore
    det_areas = box3d_volume(det)
    det_ignore_area = (det_areas < area_range[0]) | (det_areas > area_range[1])
    ar = det_ignore_area.reshape((1, nb_det))
    det_ignore = torch.logical_or(
        det_ignore,
        torch.logical_and(
            det_matches == 0, torch.repeat_interleave(ar, nb_iou_thrs, 0)
        ),
    )

    # results are collected on the CPU
    det_matches = det_matches.cpu()
    gt_matches = gt_matches.cpu()
    scores_sorted = scores_sorted.cpu()
    gt_ignore = gt_ignore.cpu()
    det_ignore = det_ignore.cpu()

    ret = {
        "dtMatches": det_matches,
        "gtMatches": gt_matches,
        "dtScores": scores_sorted,
        "gtIgnore": gt_ignore,
        "dtIgnore": det_ignore,
    }

    # optionally keep every per-image result on the metric object
    if self.ret_all_prec_rec:
        self.eval_imgs.append(ret)

    return ret
def _summarize_results(
    self, precisions: Tensor, recalls: Tensor
) -> Tuple[MAPMetricResults3D, MARMetricResults]:
    """Summarizes the precision and recall values to calculate mAP/mAR.

    Args:
        precisions: Precision values for different thresholds
        recalls: Recall values for different thresholds

    Returns:
        Tuple (mAP results, mAR results). `map_25` / `map_50` are only set
        when that IoU threshold is configured, and the small / medium / large
        entries only when that area range is configured.
    """
    results = dict(precision=precisions, recall=recalls)
    # the summary numbers use the largest max-detection threshold
    top_max_det = self.max_detection_thresholds[-1]

    map_metrics = MAPMetricResults3D()
    map_metrics.map = self._summarize(results, True, max_dets=top_max_det)
    for iou_thr, field in ((0.25, "map_25"), (0.5, "map_50")):
        if iou_thr in self.iou_thresholds:
            setattr(
                map_metrics,
                field,
                self._summarize(
                    results, True, iou_threshold=iou_thr, max_dets=top_max_det
                ),
            )

    mar_metrics = MARMetricResults()
    for max_det in self.max_detection_thresholds:
        mar_metrics[f"mar_{max_det}"] = self._summarize(
            results, False, max_dets=max_det
        )

    for size in ("small", "medium", "large"):
        if size not in self.bbox_area_ranges:
            continue
        setattr(
            map_metrics,
            f"map_{size}",
            self._summarize(results, True, area_range=size, max_dets=top_max_det),
        )
        setattr(
            mar_metrics,
            f"mar_{size}",
            self._summarize(results, False, area_range=size, max_dets=top_max_det),
        )

    return map_metrics, mar_metrics
def compute(self, sem_id_to_name_mapping: Optional[Dict[int, str]] = None) -> dict:
    """Compute the metric and expand per-class entries into named keys.

    Args:
        sem_id_to_name_mapping: maps a semantic id to the class name used in
            the result keys. Without it the id itself is used as the name.

    Returns:
        The base-class result when per-class metrics are disabled. Otherwise a
        dict in which every `*per_class*` entry is replaced by one
        `<key>@<class name>` entry per mapped class.
    """
    metrics = MeanAveragePrecision.compute(self)
    if not self.class_metrics:
        return metrics

    seen_classes = self._get_classes()
    if sem_id_to_name_mapping is None:
        logger.warning("No sem_id to name mapping. Falling back on id=name")
        sem_id_to_name_mapping = {sem_id: str(sem_id) for sem_id in seen_classes}

    # resemble class-based results.
    final_results = {}
    for key, value in metrics.items():
        if "per_class" not in key:
            final_results[key] = value
            continue
        # populate per class numbers
        mapped, unmapped = set(), set()
        for idx, per_class_result in enumerate(value):
            sem_id = seen_classes[idx]
            if sem_id in sem_id_to_name_mapping:
                mapped.add(sem_id)
                final_results[
                    f"{key}@{sem_id_to_name_mapping[sem_id]}"
                ] = per_class_result
            else:
                unmapped.add(sem_id)
        if len(unmapped) > 0:
            logger.warning(
                f"Mapped sem_ids {mapped} but DID NOT MAP sem_ids {unmapped}"
            )
    return final_results
def box3d_overlap_wrapper(
    boxes1: torch.Tensor, boxes2: torch.Tensor, eps: float = 1e-3
) -> IouOutputs:
    """
    Pairwise intersection volume and IoU that tolerates bad boxes.

    Only boxes that pass `bb3_valid` are handed to the overlap computation;
    every pair involving a bad box keeps a volume and IoU of 0. This costs
    more compute but still works when only a subset of the boxes is bad.
    If the overlap computation raises, the exception is logged and the
    outputs are returned as they are at that point.

    Raises:
        ValueError: if either input does not have trailing shape (8, 3).
    """
    for boxes in (boxes1, boxes2):
        if boxes.shape[1:] != (8, 3):
            raise ValueError("Each box in the batch must be of shape (8, 3)")

    valid1 = bb3_valid(boxes1, eps)
    valid2 = bb3_valid(boxes2, eps)
    good1 = boxes1[valid1]
    good2 = boxes2[valid2]

    vol = torch.zeros(boxes1.shape[0], boxes2.shape[0], device=boxes1.device)
    iou = torch.zeros_like(vol)

    if good1.shape[0] == 0 or good2.shape[0] == 0:
        logger.info("no valid bbs returning 0 volumes and ious")
        return IouOutputs(vol=vol, iou=iou)

    try:
        vol_good, iou_good = _box3d_overlap.apply(good1, good2)
        # scatter the results of the good pairs back into the full matrices
        pair_is_good = valid1.unsqueeze(-1) & valid2.unsqueeze(0)
        vol[pair_is_good] = vol_good.view(-1)
        iou[pair_is_good] = iou_good.view(-1)
    except Exception:
        logger.exception("returning 0 volumes and ious because of an exception")
    return IouOutputs(vol=vol, iou=iou)
# _check_coplanar(boxes[b : b + 1, :, :]) _check_nonzero(boxes[b : b + 1, :, :]) valid_ind.append(b) except Exception: invalid_ind.append(b) if mark_in_place: obbs._mark_invalid_ids(torch.tensor(invalid_ind, dtype=torch.long)) return valid_ind return obbs[valid_ind], valid_ind def prec_recall_bb3( padded_pred: ObbTW, padded_target: ObbTW, iou_thres=0.2, return_ious=False, per_class=False, ): """Compute precision and recall based on 3D IoU.""" assert padded_pred.ndim == 2 and padded_target.ndim == 2, ( f"input ObbTWs must be Nx34, but got {padded_pred.shape} and {padded_target.shape}" ) pred = padded_pred.remove_padding() target = padded_target.remove_padding() pred_shape = pred.shape target_shape = target.shape pred, _ = remove_invalid_box3d(pred) target, _ = remove_invalid_box3d(target) if pred.shape != pred_shape: logging.warning( f"Warning: predicted obbs filtered from {pred_shape[0]} to {pred.shape[0]}" ) if target.shape != target_shape: logging.warning( f"Warning: target obbs filtered from {target_shape[0]} to {target.shape[0]}" ) prec_recall = (-1.0, -1.0, None) # deal with edge cases first if pred.shape[0] == 0: # invalid precision and 0 recall prec_recall = (-1.0, 0.0, None) return prec_recall elif target.shape[0] == 0: # invalid recall and 0 precision prec_recall = (0.0, -1.0, None) return prec_recall pred_sems = pred.sem_id target_sems = target.sem_id.squeeze(-1).unsqueeze(0) # 1. Match classes sem_id_match = pred_sems == target_sems # 2. Match IoUs ious = box3d_overlap_wrapper(pred.bb3corners_world, target.bb3corners_world).iou iou_match = ious > iou_thres # 3. Match both sem_iou_match = torch.logical_and(sem_id_match, iou_match) # make final matching matrix final_sem_iou_match = torch.zeros_like(sem_iou_match).bool() num_pred = sem_iou_match.shape[0] # TP + FP num_target = sem_iou_match.shape[1] # TP + FN # 4. Deal with the case where one prediction correspond to multiple GTs. # In this case, only the GT with highest IoU is considered the match. 
for pred_idx in range(int(num_pred)): if sem_iou_match[pred_idx, :].sum() <= 1: final_sem_iou_match[pred_idx, :] = sem_iou_match[pred_idx, :].clone() else: tgt_ious = ious[pred_idx, :].clone() tgt_ious[~sem_iou_match[pred_idx, :]] = -1.0 sorted_ids = torch.argsort(tgt_ious, descending=True) tp_id = sorted_ids[0] # Set the pred with highest iou final_sem_iou_match[pred_idx, :] = False final_sem_iou_match[pred_idx, tp_id] = True # 5. Deal with the case where one GT correspond to multiple predictions. # In this case, if the predictions contain probabilities, we take the one with the highest score, otherwise, we take the one with the highest iou. for gt_idx in range(int(num_target)): if final_sem_iou_match[:, gt_idx].sum() <= 1: continue else: pred_scores = pred.prob.squeeze(-1).clone() if torch.all(pred_scores.eq(-1.0)): # go with highest iou pred_ious = ious[:, gt_idx].clone() pred_ious[~final_sem_iou_match[:, gt_idx]] = -1.0 sorted_ids = torch.argsort(pred_ious, descending=True) tp_id = sorted_ids[0] # Set the pred with highest iou final_sem_iou_match[:, gt_idx] = False final_sem_iou_match[tp_id, gt_idx] = True else: # go with the highest score pred_scores[~final_sem_iou_match[:, gt_idx]] = -1.0 sorted_ids = torch.argsort(pred_scores, descending=True) tp_id = sorted_ids[0] final_sem_iou_match[:, gt_idx] = False final_sem_iou_match[tp_id, gt_idx] = True TPs = final_sem_iou_match.any(-1) # precision = TP / (TP + FP) = TP / #Preds num_tp = TPs.sum().item() prec = num_tp / num_pred # recall = TP / (TP + FN) = TP / #GTs rec = num_tp / num_target ret = (prec, rec, final_sem_iou_match) if return_ious: ret = ret + (ious,) if per_class: # per class prec and recalls per_class_results = {} all_sems = torch.cat([pred_sems.squeeze(-1), target_sems.squeeze(0)], dim=0) unique_classes = torch.unique(all_sems.squeeze(-1)) for sem_id in unique_classes: pred_obbs_sem = pred_sems.squeeze(-1) == sem_id TPs_sem = (TPs & pred_obbs_sem).sum().item() num_pred_sem = 
def draw_prec_recall_curve(
    prec: List,
    recall: List,
    save_folder: str,
    name: str = "pr_curve.png",
    iou_thres: Optional[float] = None,
):
    """
    Plot a precision/recall curve and save it as an image.

    If the last recall value is not 1, a closing point with precision 0 at
    that same recall is added to the plotted curve. The input lists are not
    modified.

    Args:
        prec: precision values, one per curve point.
        recall: recall values, same length as ``prec``.
        save_folder: directory the image is written to.
        name: file name of the image inside ``save_folder``.
        iou_thres: optional IoU threshold, only shown in the figure title.

    Returns:
        Path of the written image.
    """
    import matplotlib.pyplot as plt

    fig_title = "Prec-Recall Curve"
    if iou_thres is not None:
        fig_title += f" @IoU={iou_thres:.2f}"
    figure_path = os.path.join(save_folder, name)

    # Work on copies so the closing point does not leak into the caller's lists.
    prec_pts = list(prec)
    recall_pts = list(recall)
    # append prec recall if the last recall is not 1 (nothing to close on an empty curve)
    if recall_pts and recall_pts[-1] != 1:
        prec_pts.append(0)
        recall_pts.append(recall_pts[-1])

    fig = plt.figure(figsize=(4, 4))
    try:
        plt.title(fig_title)
        plt.xlim([0, 1.1])
        plt.ylim([0, 1.1])
        plt.xlabel("recall")
        plt.ylabel("precision")
        plt.plot(recall_pts, prec_pts)
        plt.savefig(figure_path)
    finally:
        # pyplot keeps figures alive until closed; release it even if saving fails
        plt.close(fig)

    print(f"Save precision recall curve to {figure_path}")
    return figure_path
def get_freespace_world(
    batch,
    batch_idx,
    T_wv,
    vW,
    vH,
    vD,
    voxel_extent,
    S=1,
    prefer_points=False,
    dropout_points=False,
    drop_points_rate_max=0.5,
):
    """
    Sample free-space points along the viewing rays of one snippet in the batch.

    Rays are built either from the distance images (when present and
    ``prefer_points`` is False) or from the world point cloud. Each ray is
    shortened by one voxel diagonal so that the voxel containing the observed
    surface is not labelled free, then ``S`` depths per ray are sampled inside
    the voxel grid with ``sample_depths_in_grid``.

    Args:
        batch: batch dictionary keyed by the ARIA_* constants.
        batch_idx: index of the snippet inside the batch.
        T_wv: pose applied to the sampled voxel-frame points before returning.
        vW, vH, vD: number of voxels along x, y and z.
        voxel_extent: (x_min, x_max, y_min, y_max, z_min, z_max) of the grid.
        S: number of depth samples per ray.
        prefer_points: use the point cloud even if distance images exist.
        dropout_points: randomly keep only a subset of the rays.
        drop_points_rate_max: upper bound on the fraction of rays dropped.

    Returns:
        Sampled points after applying ``T_wv``, flattened to (-1, 3) before
        the transform.
    """
    cams = batch[ARIA_CALIB[0]][batch_idx]
    T_ws = batch[ARIA_SNIPPET_T_WORLD_SNIPPET][
        batch_idx
    ]  # T_world_rig (one per snippet)
    Ts_sr = batch[ARIA_IMG_T_SNIPPET_RIG[0]][
        batch_idx
    ]  # Ts_snippet_rig (T per snippet)
    Ts_wr = T_ws @ Ts_sr
    Ts_wc = Ts_wr @ cams.T_camera_rig.inverse()  # Ts_world_cam

    # compute rays and max depths
    if ARIA_DISTANCE_M[0] in batch and not prefer_points:
        # get gt distances into world points
        gt_dist = batch[ARIA_DISTANCE_M[0]][batch_idx]
        # invalid depth has values 0 or NaN (padding used by semidense stream).
        valid_depths = gt_dist.squeeze(1) > 1e-4
        p3cs, valids = dist_im_to_point_cloud_im(
            gt_dist.squeeze(1),
            cams,
        )
        valids = torch.logical_and(valids, valid_depths)
        p3cs = p3cs.reshape(p3cs.shape[0], -1, 3)
        ds = torch.norm(p3cs, 2.0, dim=-1)
        dirs_c = F.normalize(p3cs, 2.0, dim=-1)
        # rays start at the camera origin (zeros) in the camera frame
        rays_c = torch.cat([torch.zeros_like(dirs_c), dirs_c], dim=-1)
        T_vc = T_wv.inverse() @ Ts_wc
        rays_v = transform_rays(rays_c, T_vc)
        rays_v = rays_v.view(-1, 6)
        ds = ds.view(-1)
        valids = valids.reshape(-1)
        rays_v = rays_v[valids]
        ds = ds[valids]
    else:
        p_w = batch[ARIA_POINTS_WORLD][batch_idx]  # TxNx3
        N = p_w.shape[1]
        p0_w = Ts_wc.t.unsqueeze(1)  # Tx1x3
        diff_w = p_w - p0_w
        ds = torch.norm(diff_w, 2.0, dim=-1)
        dir_w = F.normalize(diff_w, 2.0, dim=-1)
        # filter out nans
        good = ~p_w.isnan().any(dim=-1)
        p0_w = p0_w.repeat(1, N, 1)[good]
        ds = ds[good]
        dir_w = dir_w[good]
        rays_w = torch.cat([p0_w, dir_w], dim=-1)
        rays_v = transform_rays(rays_w, T_wv.inverse())

    # dropout rays if desired
    if dropout_points:
        num_rays = rays_v.shape[0]
        p = drop_points_rate_max
        # keep a random fraction in [1 - p, 1] of the rays
        num_keep = int(num_rays * (torch.rand(1).item() * p + (1.0 - p)))
        print(f"dropout {num_rays - num_keep}/{num_rays} points")
        # take the device from rays_v: it exists in both branches above,
        # whereas p_w is only defined when the point cloud is used
        keep = torch.randperm(num_rays, device=rays_v.device)[:num_keep]
        rays_v = rays_v[keep, :]
        ds = ds[keep]

    x_min, x_max, y_min, y_max, z_min, z_max = voxel_extent
    dW = (x_max - x_min) / vW
    dH = (y_max - y_min) / vH
    dD = (z_max - z_min) / vD
    diag = math.sqrt(dW**2 + dH**2 + dD**2)
    # subtract diagonal of voxel size to not label the occupied voxel as free
    ds = ds - diag
    # sample depths that lie within the feature volume grid (same function as used for nerf3d!)
    depths, _, _ = sample_depths_in_grid(
        rays_v.view(1, 1, -1, 6),
        ds.view(1, 1, -1),
        voxel_extent,
        vW,
        vH,
        vD,
        S,
    )
    depths = depths.view(-1, S)
    rays_v = rays_v.view(-1, 1, 6)
    # origin + depth * direction for each of the S samples of every ray
    pts_v = rays_v[..., :3] + depths.unsqueeze(-1) * rays_v[..., 3:]
    pts_v = pts_v.view(-1, 3)
    return T_wv * pts_v
def pointcloud_to_occupancy_snippet(
    pcs_w, Ts_wc, cams, T_wv, vW, vH, vD, voxel_extent, S=1
):
    """
    Convert the point cloud of one snippet into an occupancy grid plus a mask
    of the voxels that received a label.

    Labels are written in three passes, and the order matters because later
    passes overwrite earlier ones:
      1. voxels containing a camera origin are marked free (occ = 0),
      2. S random samples per ray, drawn between the camera and the point
         minus one voxel diagonal, are marked free (occ = 0),
      3. voxels containing a point are marked occupied (occ = 1).
    Every voxel touched by any pass gets mask = 1.

    Args:
        pcs_w: points, asserted to be 3-dimensional (T N 3).
        Ts_wc: poses, asserted to be 2-dimensional (T C).
        cams: cameras, asserted to be 2-dimensional (T C).
        T_wv: pose, asserted to be 1- or 2-dimensional.
        vW, vH, vD: number of voxels along x, y and z.
        voxel_extent: (x_min, x_max, y_min, y_max, z_min, z_max); wrapped with
            tensor_wrap_voxel_extent before use.
        S: number of free-space samples per ray.

    Returns:
        (occ, mask), both float tensors of shape (vD, vH, vW).
    """
    assert pcs_w.ndim == 3, f"{pcs_w.shape}"  # T N 3
    assert Ts_wc.ndim == 2, f"{Ts_wc.shape}"  # T C
    assert cams.ndim == 2, f"{cams.shape}"  # T C
    assert T_wv.ndim in [1, 2], f"{T_wv.shape}"  # 1 C
    voxel_extent = tensor_wrap_voxel_extent(voxel_extent)
    device = pcs_w.device
    occ = torch.zeros((vD, vH, vW), device=device)
    mask = torch.zeros_like(occ)
    # get invalid mask as the points that are nan and do not project into the
    # camera.
    Ts_vc = T_wv.inverse() @ Ts_wc
    pc_c = Ts_wc.inverse() * pcs_w
    invalid = pc_c.isnan().any(-1)  # T N
    # only the validity flag of the projection is used; pc_im is discarded
    pc_im, valid = cams.project(pc_c)
    invalid = torch.logical_or(invalid, ~valid)
    depth = torch.sqrt((pc_c**2).sum(-1))
    ray_c = pc_c / depth.unsqueeze(-1)
    # camera origins are not occupied
    rayP_c = torch.zeros_like(Ts_wc.t)
    rayP_v = Ts_vc * rayP_c
    pc_ids, valid = pointcloud_to_voxel_ids(rayP_v, vW, vH, vD, voxel_extent)
    pc_ids = pc_ids[valid]
    pc_ids = pc_ids.view(-1, 3)
    if pc_ids.numel() > 0:
        occ[pc_ids[:, 0], pc_ids[:, 1], pc_ids[:, 2]] = 0.0
        mask[pc_ids[:, 0], pc_ids[:, 1], pc_ids[:, 2]] = 1.0
    # sample along the ray
    x_min, x_max, y_min, y_max, z_min, z_max = voxel_extent
    dW = (x_max - x_min) / vW
    dH = (y_max - y_min) / vH
    dD = (z_max - z_min) / vD
    diag = math.sqrt(dW**2 + dH**2 + dD**2)
    T, N = ray_c.shape[:2]
    rayP_c = rayP_c.view(T, 1, 3).repeat(1, N, 1)
    # sample depths conservatively up to the depth - diagonal of a voxel
    ds = depth.unsqueeze(-1) - diag
    # uniform random fraction of the shortened depth, S samples per ray
    ds = torch.rand((T, N, S), device=device) * ds
    samples_c = rayP_c.unsqueeze(2) + ds.unsqueeze(3) * ray_c.unsqueeze(2)
    samples_c = samples_c.view(T, -1, 3)
    samples_v = Ts_vc * samples_c
    pc_ids, valid = pointcloud_to_voxel_ids(samples_v, vW, vH, vD, voxel_extent)
    # repeat the per-point invalid flag for each of its S samples
    invalid_ = invalid.unsqueeze(-1).repeat(1, 1, S).view(T, -1)
    valid = torch.logical_and(valid, ~invalid_)
    pc_ids = pc_ids[valid]
    pc_ids = pc_ids.view(-1, 3)
    if pc_ids.numel() > 0:
        occ[pc_ids[:, 0], pc_ids[:, 1], pc_ids[:, 2]] = 0.0
        mask[pc_ids[:, 0], pc_ids[:, 1], pc_ids[:, 2]] = 1.0
    # add points as occupied
    # (written last so that occupied wins over any free label in the same voxel)
    pc_v = T_wv.inverse() * pcs_w
    pc_ids, valid = pointcloud_to_voxel_ids(pc_v, vW, vH, vD, voxel_extent)
    valid = torch.logical_and(valid, ~invalid)
    pc_ids = pc_ids[valid]
    pc_ids = pc_ids.view(-1, 3)
    if pc_ids.numel() > 0:
        occ[pc_ids[:, 0], pc_ids[:, 1], pc_ids[:, 2]] = 1.0
        mask[pc_ids[:, 0], pc_ids[:, 1], pc_ids[:, 2]] = 1.0
    return occ, mask
def pointcloud_to_occupancy(
    pc_w, T_wc, cam, T_wv, vW, vH, vD, voxel_extent, S=1, occ=None, mask=None
):
    """
    Accumulate the occupancy labels of a single-camera point cloud into a
    voxel grid.

    Labels are written in order: camera origin free, S random samples per ray
    between camera and point free, then the points themselves occupied (so
    occupied overrides free). ``occ`` and ``mask`` are modified in place when
    passed in, which allows accumulating over several calls.

    NOTE(review): pointcloud_to_voxel_ids in this file asserts a 3-dimensional
    point input and a tensor ``voxel_extent``, while the inputs built here are
    reshaped to (-1, 3) or masked to a flat list of points - confirm this
    function is still usable with the current helper.

    Args:
        pc_w: points; NaN rows are dropped.
        T_wc: camera pose; its translation is viewed as (1, 3) below.
        cam: camera; a leading dimension is added before projecting.
        T_wv: pose whose inverse maps the points into the voxel frame.
        vW, vH, vD: number of voxels along x, y and z.
        voxel_extent: forwarded unchanged to pointcloud_to_voxel_ids.
        S: number of free-space samples per ray.
        occ: optional existing occupancy grid (vD, vH, vW) to update.
        mask: optional existing mask grid to update.

    Returns:
        (occ, mask).
    """
    device = pc_w.device
    if occ is None:
        occ = torch.zeros((vD, vH, vW), device=device)
    if mask is None:
        mask = torch.zeros_like(occ)
    T_vc = T_wv.inverse() @ T_wc
    pc_c = T_wc.inverse() * pc_w
    invalid = pc_c.isnan().any(-1)
    pc_c = pc_c[~invalid]
    # project with a temporary leading dimension; only `valid` is used afterwards
    pc_im, valid = cam.unsqueeze(0).project(pc_c.unsqueeze(0))
    pc_im, valid = pc_im.squeeze(0), valid.squeeze(0)
    depth = torch.sqrt((pc_c**2).sum(-1))
    ray_c = pc_c / depth.unsqueeze(-1)
    ray_c = ray_c[valid]
    depth = depth[valid]
    # camera origins are not occupied
    rayP_c = torch.zeros_like(T_wc.t)
    rayP_v = T_vc * rayP_c
    pc_ids, valid = pointcloud_to_voxel_ids(rayP_v, vW, vH, vD, voxel_extent)
    pc_ids = pc_ids[valid]
    occ[pc_ids[:, 0], pc_ids[:, 1], pc_ids[:, 2]] = 0.0
    mask[pc_ids[:, 0], pc_ids[:, 1], pc_ids[:, 2]] = 1.0
    # sample along the ray
    N = ray_c.shape[0]
    rayP_c = rayP_c.view(1, 3).repeat(N, 1)
    # uniform random fraction of the full depth (no voxel-diagonal margin here)
    ds = torch.rand((N, S), device=device) * depth.unsqueeze(1)
    samples_c = rayP_c.unsqueeze(1) + ds.unsqueeze(2) * ray_c.unsqueeze(1)
    samples_c = samples_c.view(-1, 3)
    samples_v = T_vc * samples_c
    pc_ids, valid = pointcloud_to_voxel_ids(samples_v, vW, vH, vD, voxel_extent)
    pc_ids = pc_ids[valid]
    occ[pc_ids[:, 0], pc_ids[:, 1], pc_ids[:, 2]] = 0.0
    mask[pc_ids[:, 0], pc_ids[:, 1], pc_ids[:, 2]] = 1.0
    # add points as occupied
    # (uses all non-NaN points, not only those that project into the camera)
    pc_v = T_wv.inverse() * pc_w
    invalid = pc_v.isnan().any(-1)
    pc_v = pc_v[~invalid]
    pc_ids, valid = pointcloud_to_voxel_ids(pc_v, vW, vH, vD, voxel_extent)
    pc_ids = pc_ids[valid]
    occ[pc_ids[:, 0], pc_ids[:, 1], pc_ids[:, 2]] = 1.0
    mask[pc_ids[:, 0], pc_ids[:, 1], pc_ids[:, 2]] = 1.0
    return occ, mask
def get_points_counts(
    batch,
    T_wv,
    vW,
    vH,
    vD,
    voxel_extent,
    prefer_points=True,
    MAX_NUM_POINTS_VOXEL=50,
    return_mask=False,
    dropout_points=False,
    dropout_points_rate_max=0.0,
):
    """
    Build a per-voxel point-count volume for every snippet in the batch.

    Counts are clamped to ``MAX_NUM_POINTS_VOXEL`` and divided by it. With
    ``return_mask`` set, every voxel with a non-zero normalised count is set
    to 1.0 so the result acts as a binary occupancy mask.

    Args:
        batch: batch dictionary keyed by the ARIA_* constants.
        T_wv: per-snippet poses; the inverse of ``T_wv[b]`` is applied to the
            points before counting.
        vW, vH, vD: number of voxels along x, y and z.
        voxel_extent: a list shared by all snippets, or an indexable
            per-snippet container.
        prefer_points: forwarded to ``get_points_world``.
        MAX_NUM_POINTS_VOXEL: clamp / normalisation constant.
        return_mask: return a binary mask instead of normalised counts.
        dropout_points: randomly keep only a subset of the points.
        dropout_points_rate_max: upper bound on the dropped fraction.

    Returns:
        Tensor stacked over snippets (B x 1 x vD x vH x vW).
    """
    # the image tensor has to be 5-dimensional; only the batch size is needed
    num_snippets, _, _, _, _ = batch[ARIA_IMG[0]].shape
    shared_extent = isinstance(voxel_extent, list)

    volumes = []
    for snippet in range(num_snippets):
        cloud_w = get_points_world(batch, snippet, prefer_points=prefer_points)[0]
        cloud_w = collapse_pointcloud_time(cloud_w)

        if dropout_points:
            print("drop points ", cloud_w.shape)
            N = cloud_w.shape[0]
            max_rate = dropout_points_rate_max
            keep_fraction = torch.rand(1).item() * max_rate + (1.0 - max_rate)
            Ndrop = int(N * keep_fraction)
            print(f"dropout {N - Ndrop}/{N} points")
            chosen = torch.randperm(N, device=cloud_w.device)[:Ndrop]
            cloud_w = cloud_w[chosen, :]

        # bring the points into the voxel frame of this snippet
        cloud_v = T_wv[snippet].inverse() * cloud_w
        extent = voxel_extent if shared_extent else voxel_extent[snippet].tolist()
        volumes.append(pointcloud_to_voxel_counts(cloud_v, extent, vW, vH, vD))

    # B x 1 x vD x vH x vW, normalised to [0, 1]
    counts = torch.stack(volumes, dim=0)
    counts = counts.clamp(0, MAX_NUM_POINTS_VOXEL) / MAX_NUM_POINTS_VOXEL

    if return_mask:
        # binarise: any voxel that saw a point becomes 1.0
        counts[counts > 1e-4] = 1.0
    return counts
def grid_ray(pixel_grid, camera):
    """
    Unproject a 2D pixel grid through a batch of cameras into rays expressed
    in each camera's rig coordinate system.

    Args:
        pixel_grid: pixel coordinates; the first two dimensions are taken as
            (grid_height, grid_width) and the data is reshaped to (-1, 2).
        camera: batch of cameras; ``camera.shape[0]`` is the batch size.

    Returns:
        rays: [B x grid_height x grid_width x 6] tensor of ray origin followed
            by unit direction, in rig coordinates. Invalid rays are all zeros.
        valid: [B x grid_height x grid_width] mask of valid rays.
    """
    eps = 1e-6
    grid_height = pixel_grid.shape[0]
    grid_width = pixel_grid.shape[1]
    batch_size = camera.shape[0]

    # one copy of the flattened grid per camera
    flat_px = pixel_grid.reshape(-1, 2)
    flat_px = einops.repeat(flat_px, "n c -> b n c", b=batch_size)

    # unproject in double precision, then continue in float
    dirs, valid = camera.double().unproject(flat_px.double())
    dirs = dirs.float()
    assert not torch.isnan(dirs).any(), (
        f"have {torch.isnan(dirs).count_nonzero().item()} nans in rays. Camera params: {camera.params}"
    )

    valid_col = valid.unsqueeze(-1)
    dirs = F.normalize(dirs, p=2, dim=-1, eps=eps)
    dirs = torch.where(valid_col, dirs, torch.zeros_like(dirs))

    # camera -> rig: rotate the directions, origins are the camera centres
    T_rig_camera = camera.T_camera_rig.inverse().to(dtype=dirs.dtype)
    dirs = T_rig_camera.rotate(dirs)
    origins = einops.repeat(
        T_rig_camera.t, "b c -> b n c", n=grid_width * grid_height
    )

    # re-normalise after the rotation and zero out everything that is invalid
    dirs = F.normalize(dirs, p=2, dim=-1, eps=eps)
    dirs = torch.where(valid_col, dirs, torch.zeros_like(dirs))
    origins = torch.where(valid_col, origins, torch.zeros_like(origins))

    rays = torch.cat([origins, dirs], dim=-1)
    rays = rays.view([batch_size, grid_height, grid_width, -1])
    valid = valid.view([batch_size, grid_height, grid_width])
    return rays, valid
def ray_obb_intersection(
    rays_v, voxel_extent, t_min=-1e9, t_max=1e9, return_points=False
):
    """
    Intersect rays with the axis-aligned box described by ``voxel_extent``.

    Each ray is tested against the six face planes of the box. A hit counts
    only if its ray parameter lies strictly inside (t_min, t_max) and the hit
    point lies on the face (with a 1e-3 slack along the face normal).

    Args:
        rays_v: (B, N, 6) rays as [origin, direction] in the box frame.
            Directions are assumed to be normalised.
        voxel_extent: (x_min, x_max, y_min, y_max, z_min, z_max).
        t_min: lower bound on the accepted ray parameter.
        t_max: upper bound on the accepted ray parameter.
        return_points: additionally return the entry and exit points.

    Returns:
        ts: (B, N, 2) smallest and largest accepted ray parameter. For a ray
            without any hit this is (t_max, t_min) when ``return_points`` is
            False.
        If ``return_points`` is True, ``ts`` is post-processed and returned
        together with the two points: a ray with a single hit (e.g. starting
        inside the box) gets ``t_min`` as its entry parameter, and a ray with
        no hit gets ``t_min`` for both entries.
    """
    assert rays_v.ndim == 3, f"{rays_v.shape}"
    assert rays_v.shape[-1] == 6, f"{rays_v.shape}"
    device = rays_v.device
    B, N = rays_v.shape[:2]
    x_min, x_max, y_min, y_max, z_min, z_max = voxel_extent
    raysP_v = rays_v[..., :3]
    raysD_v = rays_v[..., 3:]  # assume normalized!

    # outward normal of each face ...
    ns_bb = [
        [1.0, 0.0, 0.0],
        [-1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, -1.0, 0.0],
        [0.0, 0.0, 1.0],
        [0.0, 0.0, -1.0],
    ]
    # ... a point on each face plane ...
    ps_bb = [
        [x_max, 0.0, 0.0],
        [x_min, 0.0, 0.0],
        [0.0, y_max, 0.0],
        [0.0, y_min, 0.0],
        [0.0, 0.0, z_max],
        [0.0, 0.0, z_min],
    ]
    # ... and the bounds of each face, thickened by eps along its normal so
    # that the strict comparisons below accept points lying on the plane.
    eps = 1e-3
    minmaxs_bb = [
        [x_max - eps, y_min, z_min, x_max + eps, y_max, z_max],
        [x_min - eps, y_min, z_min, x_min + eps, y_max, z_max],
        [x_min, y_max - eps, z_min, x_max, y_max + eps, z_max],
        [x_min, y_min - eps, z_min, x_max, y_min + eps, z_max],
        [x_min, y_min, z_max - eps, x_max, y_max, z_max + eps],
        [x_min, y_min, z_min - eps, x_max, y_max, z_min + eps],
    ]

    t_upper = torch.ones((B, N), device=device) * t_max
    t_lower = torch.ones((B, N), device=device) * t_min
    # start with an empty interval (min = t_max, max = t_min)
    ts = torch.stack([t_upper, t_lower], dim=-1)
    for n_bb, p_bb, minmax_bb in zip(ns_bb, ps_bb, minmaxs_bb):
        n_bb = torch.tensor(n_bb).view(1, 1, 3).to(device)
        p_bb = torch.tensor(p_bb).view(1, 1, 3).to(device)
        min_bb = torch.tensor(minmax_bb[:3]).view(1, 1, 3).to(device)
        max_bb = torch.tensor(minmax_bb[3:]).view(1, 1, 3).to(device)
        # ray/plane intersection: t = (p - o) . n / (d . n)
        denom = (raysD_v * n_bb).sum(-1)
        valid = denom.abs() > 1e-6  # skip rays parallel to the plane
        dp = p_bb - raysP_v
        t = (dp * n_bb).sum(-1) / denom
        valid = torch.logical_and(valid, t > t_min)
        valid = torch.logical_and(valid, t < t_max)
        # points on surface
        ps_v = raysP_v + raysD_v * t.unsqueeze(-1)
        valid = torch.logical_and(valid, (ps_v > min_bb).all(-1))
        valid = torch.logical_and(valid, (ps_v < max_bb).all(-1))
        # invalid hits fall back to the neutral element of min / max
        ts_min = torch.where(valid, t, t_upper)
        ts_max = torch.where(valid, t, t_lower)
        ts[..., 0] = torch.minimum(ts_min, ts[..., 0])
        ts[..., 1] = torch.maximum(ts_max, ts[..., 1])

    if return_points:
        # A single hit means the ray only leaves the box inside (t_min, t_max):
        # its entry parameter is t_min itself (not t_min squared, which with a
        # negative t_min would exceed the exit and flag the ray as a miss).
        one_int = ts[..., 0] == ts[..., 1]
        ts[..., 0] = torch.where(
            one_int, torch.full_like(ts[..., 0], t_min), ts[..., 0]
        )
        no_int = ts[..., 0] > ts[..., 1]
        ts[no_int] = t_min
        ps_min_v = raysP_v + raysD_v * ts[..., 0].unsqueeze(-1)
        ps_max_v = raysP_v + raysD_v * ts[..., 1].unsqueeze(-1)
        return ts, ps_min_v, ps_max_v
    return ts
torch.ones_like(ts[..., 0]) * d_near, ts[..., 0] ) # BxTxN depths_max = ts[..., 1] depths_min[no_int] = torch.nan depths_max[no_int] = torch.nan if ds_max is not None: depths_max = torch.minimum(ds_max, depths_max) if ds_min is not None: depths_min = torch.maximum(ds_min, depths_min) ddepths = depths_max - depths_min ddepths[ddepths < 1e-3] = torch.nan # go to d_min to d_max per ray depths = torch.linspace(0.0, 1.0, num_samples).to(rays_v.device) depths = depths.view(1, 1, 1, num_samples).repeat(B, T, N, 1) depths = depths_min.unsqueeze(-1) + ddepths.unsqueeze(-1) * depths if sample_mode == "uniform": return depths, depths_max, ~no_int.view(B, T, N) elif sample_mode == "random": # add noise noise = torch.rand((B, T, N, num_samples), device=rays_v.device) noise = noise * (ddepths.unsqueeze(-1) / num_samples) if num_samples > 1: noise[..., -1] = 0.0 depths = depths + noise return depths, depths_max, ~no_int.view(B, T, N) else: raise ValueError(f"Unknown sample mode {sample_mode}") ================================================ FILE: efm3d/utils/reconstruction.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
def build_gt_occupancy(occ, visible, p3s_w, Ts_wc, cams, T_wv, voxel_extent):
    """
    build GT occupancy from GT point cloud, return batched occupancy with masks.

    Args:
        occ: predicted occupancy; only its shape (B, vD, vH, vW) is used.
        visible: per-batch visibility mask that is AND-ed with the mask
            returned by `pointcloud_to_occupancy_snippet`.
        p3s_w, Ts_wc, cams, T_wv: per-batch inputs forwarded (indexed by batch)
            to `pointcloud_to_occupancy_snippet`.
        voxel_extent: voxel grid extent, forwarded unchanged.

    Returns:
        (occ_gts, masks): the per-batch results stacked along a new first dim.
    """
    B, vD, vH, vW = occ.shape
    occ_gts, masks = [], []
    for b in range(B):
        occ_gt, mask = pointcloud_to_occupancy_snippet(
            p3s_w[b],
            Ts_wc[b],
            cams[b],
            T_wv[b],
            vW,
            vH,
            vD,
            voxel_extent,
            S=1,
        )
        mask = torch.logical_and(mask.bool(), visible[b])
        occ_gts.append(occ_gt)
        masks.append(mask)
    occ_gts = torch.stack(occ_gts)
    masks = torch.stack(masks)
    return occ_gts, masks


def get_fused_gt_feat(
    visible,
    p3s_w,
    Ts_wc,
    cams,
    T_wv,
    voxel_extent,
    img_feat_gt,
    feat_pred,
    gt_dists,
    vD,
    vH,
    vW,
):
    """
    Fuse per-pixel GT image features into a voxel feature volume by averaging
    the features of all GT depth points that fall into each voxel.

    Args:
        visible, p3s_w: not used by this function (kept for the call signature).
        Ts_wc, T_wv: poses used to bring camera points into the voxel frame
            (`T_wv.inverse() @ Ts_wc`).
        cams: cameras; `cams.shape[:2]` is taken as (B, T).
        voxel_extent: voxel grid extent forwarded to `pointcloud_to_voxel_ids`.
        img_feat_gt: GT image features, indexed as [b, t] -> CxHxW
            (feature dim is `shape[2]`).
        feat_pred: BxCxDxHxW predicted feature volume; only used as a template
            for shape, dtype and device.
        gt_dists: GT distance images; dim 2 is squeezed before unprojection.
        vD, vH, vW: voxel grid size.

    Returns:
        (gt_feat_volume, surface_mask): BxDxHxWxC mean feature per voxel and a
        BxDxHxW boolean mask of voxels that received at least one point.
    """
    feat_dim = img_feat_gt.shape[2]
    gt_feat_volume = torch.zeros_like(feat_pred).detach()  # BxCxDxHxW
    gt_feat_volume = gt_feat_volume.permute(
        0, 2, 3, 4, 1
    )  # BxDxHxWxC for easier indexing
    gt_feat_volume_counts = (
        torch.zeros(*gt_feat_volume.shape[:4]).to(feat_pred).detach()
    )  # BxDxHxW

    dists = gt_dists.squeeze(2)
    B, T = cams.shape[:2]
    pc_c, valids = dist_im_to_point_cloud_im(dists, cams)
    pc_c = pc_c.reshape(B, T, -1, 3)  # BxTxNx3
    T_vc = T_wv.inverse() @ Ts_wc
    pc_v = T_vc * pc_c

    for b in range(B):
        pc_ids, valid_v = pointcloud_to_voxel_ids(pc_v[b], vW, vH, vD, voxel_extent)
        for t in range(T):
            valid = torch.logical_and(valid_v[t], valids[b, t].reshape(-1))
            pc_ids_t = pc_ids[t][valid]
            feat_gt_t = img_feat_gt[b, t].permute(1, 2, 0).reshape(-1, feat_dim)
            feat_gt_t = feat_gt_t[valid]

            # Many points usually land in the same voxel. `x[idx] += v` writes
            # each duplicated index only once, so use index_put_ with
            # accumulate=True to really sum counts and features.
            voxel_idx = tuple(pc_ids_t.long().unbind(-1))
            ones = torch.ones(
                pc_ids_t.shape[0],
                dtype=gt_feat_volume_counts.dtype,
                device=gt_feat_volume_counts.device,
            )
            gt_feat_volume_counts[b].index_put_(voxel_idx, ones, accumulate=True)
            gt_feat_volume[b].index_put_(
                voxel_idx, feat_gt_t.to(gt_feat_volume), accumulate=True
            )

        # turn the feature sums into means for all voxels that were hit
        hit = gt_feat_volume_counts[b] > 1e-4
        gt_feat_volume[b][hit] /= gt_feat_volume_counts[b][hit].unsqueeze(-1)

    surface_mask = gt_feat_volume_counts > 1e-4  # BxDxHxW
    return gt_feat_volume, surface_mask
def get_feats_world(batch, tgt_feats):
    """
    Unproject the distance images of `batch` at the resolution of `tgt_feats`
    and return the resulting world points together with the matching features.

    Args:
        batch: dict holding the distance images (ARIA_DISTANCE_M[0]), the
            camera calibration (ARIA_CALIB[0]), the rig poses in the snippet
            frame (ARIA_IMG_T_SNIPPET_RIG[0]) and the snippet-to-world pose
            (ARIA_SNIPPET_T_WORLD_SNIPPET).
        tgt_feats: target features; indexed here as B x T x C x H x W.

    Returns:
        (pc_w, tgt_feats): B x T x M x 3 world points (NaN where the
        unprojection is invalid) and B x T x M x C features, where M counts
        the pixels that are valid in at least one frame of one batch element.
    """
    dists_full = batch[ARIA_DISTANCE_M[0]]
    cams_full = batch[ARIA_CALIB[0]]
    num_batch = tgt_feats.shape[0]
    feat_h, feat_w = tgt_feats.shape[-2], tgt_feats.shape[-1]

    # bring distance images and cameras to the feature resolution
    dists = F.interpolate(
        rearrange(dists_full, "b t c h w -> (b t) c h w"),
        [feat_h, feat_w],
        mode="nearest",
    )
    dists = rearrange(dists, "(b t) c h w -> b t c h w", b=num_batch).squeeze(2)
    cams = cams_full.scale_to_size((feat_w, feat_h))

    # camera poses in the world frame
    Ts_world_rig = batch[ARIA_SNIPPET_T_WORLD_SNIPPET] @ batch[ARIA_IMG_T_SNIPPET_RIG[0]]
    Ts_cam_world = cams.T_camera_rig @ Ts_world_rig.inverse()
    Ts_world_cam = Ts_cam_world.inverse()

    pts_cam, valids = dist_im_to_point_cloud_im(dists, cams)
    B, T, H, W = pts_cam.shape[:4]
    pts_world = (Ts_world_cam * pts_cam.view(B, T, -1, 3)).view(B, T, H, W, 3)
    pts_world[~valids] = float("nan")

    # drop pixels that are invalid in every frame of every batch element
    never_valid = (~valids).all(0).all(0)
    keep = (~never_valid).view(1, 1, H, W).repeat(B, T, 1, 1)
    pts_world = pts_world[keep].view(B, T, -1, 3)

    feat_dim = tgt_feats.shape[2]
    feats = tgt_feats.permute(0, 1, 3, 4, 2)[keep].view(B, T, -1, feat_dim)
    return pts_world, feats
def compute_tv_loss(occ):
    """
    Total-variation loss of an occupancy volume.

    The mean absolute difference between neighbouring voxels is computed along
    each of the last three dimensions (D, H, W) and the three means are summed.
    Any number of leading dimensions is accepted, so both B x D x H x W and
    B x 1 x D x H x W inputs work. A spatial dimension of size 1 has no
    neighbours and yields NaN for that term (mean over an empty tensor).
    """
    tv_d = (occ[..., 1:, :, :] - occ[..., :-1, :, :]).abs().mean()
    tv_h = (occ[..., :, 1:, :] - occ[..., :, :-1, :]).abs().mean()
    tv_w = (occ[..., :, :, 1:] - occ[..., :, :, :-1]).abs().mean()
    return tv_d + tv_h + tv_w


def _sample_occ_at_points(occ, p3s_v, valid, B, T, vW, vH, vD, voxel_extent):
    """Tri-linearly sample `occ` at voxel-frame points; return the valid samples clamped to [0, 1]."""
    p3s_vox, valid_vox = pc_to_vox(p3s_v, vW, vH, vD, voxel_extent)
    valid_vox = torch.logical_and(valid_vox, valid)
    samples, valid_samples = sample_voxels(occ.unsqueeze(1), p3s_vox.view(B, -1, 3))
    samples = samples.view(B, T, -1)
    valid_samples = valid_samples.view(B, T, -1)
    valid_all = torch.logical_and(valid_samples, valid_vox)
    return samples[valid_all].clamp(0.0, 1.0)


def compute_occupancy_loss_subvoxel(
    occ,
    visible,
    p3s_w_all,
    Ts_wc,
    cams,
    T_wv,
    voxel_extent,
    S=1,
    sample_beyond=False,
    surf_val=0.5,
    subsample=1,
    free_surf_occ_weights=None,
    loss_type: Literal["l2", "l1", "logl1", "ce", "focal"] = "focal",
):
    """
    sample occupied, surface and freespace GT points
    obtain predictions at those sample points by sampling into the occ voxel
    grid via tri-linear interpolation.

    Args:
        occ: B x D x H x W predicted occupancy.
        visible: B x D x H x W visibility volume (only its rank is checked here).
        p3s_w_all: B x T x N x 3 GT points in world coordinates.
        Ts_wc, cams: camera poses and cameras forwarded to
            `pointcloud_occupancy_samples`.
        T_wv: pose of the voxel grid in the world; its inverse maps the
            sampled points into the voxel frame.
        voxel_extent: voxel grid extent.
        S: forwarded to `pointcloud_occupancy_samples`.
        sample_beyond: must be False (not supported).
        surf_val: target occupancy value for surface samples.
        subsample: if > 1, a random subset of N // subsample points is used.
        free_surf_occ_weights: optional sequence of three weights for the
            free / surface / occupied focal losses. If None, all samples are
            pooled into a single loss of type `loss_type`.
        loss_type: loss used when `free_surf_occ_weights` is None. With
            weights only "focal" is supported.

    Returns:
        Scalar loss, normalized by the number of valid samples.

    Raises:
        ValueError: if `loss_type` is unknown and no weights are given.
    """
    assert p3s_w_all.ndim == 4, f"{p3s_w_all.shape}"  # B T N 3
    assert occ.ndim == 4, f"{occ.shape}"  # B D H W
    assert visible.ndim == 4, f"{visible.shape}"  # B D H W
    assert not sample_beyond, "not supported"
    device = occ.device
    B, vD, vH, vW = occ.shape

    if subsample > 1:
        # randomly keep N // subsample of the points
        N = p3s_w_all.shape[2]
        ids = torch.randperm(N)[: N // subsample].to(device)
        p3s_w = p3s_w_all[:, :, ids]
    else:
        p3s_w = p3s_w_all
    B, T, N = p3s_w.shape[:3]

    p3s_occ_w, p3s_surf_w, p3s_free_w, valid = pointcloud_occupancy_samples(
        p3s_w,
        Ts_wc,
        cams,
        vD,
        vH,
        vW,
        voxel_extent,
        S=S,
        sample_beyond=sample_beyond,
        vox_diag_scale=1.0,
        T_wv=T_wv,
    )
    # move the samples from the world into the voxel frame
    Ts_vw = T_wv.inverse().view(B, 1, -1).repeat(1, T, 1)
    p3s_occ_v = Ts_vw * p3s_occ_w
    p3s_surf_v = Ts_vw * p3s_surf_w
    p3s_free_v = Ts_vw * p3s_free_w

    # free points -> target 0
    free_samples = _sample_occ_at_points(
        occ, p3s_free_v, valid, B, T, vW, vH, vD, voxel_extent
    )
    free_gt = torch.zeros_like(free_samples)
    # surface points -> target surf_val
    surf_samples = _sample_occ_at_points(
        occ, p3s_surf_v, valid, B, T, vW, vH, vD, voxel_extent
    )
    surf_gt = surf_val * torch.ones_like(surf_samples)
    # occupied points -> target 1
    occ_samples = _sample_occ_at_points(
        occ, p3s_occ_v, valid, B, T, vW, vH, vD, voxel_extent
    )
    occ_gt = torch.ones_like(occ_samples)

    if free_surf_occ_weights is None:
        num = free_samples.numel() + surf_samples.numel() + occ_samples.numel()
        if loss_type == "ce":
            # CE on free and occ and L1 on surf
            pred = torch.cat([free_samples, occ_samples], -1)
            gt = torch.cat([free_gt, occ_gt], -1)
            loss = F.binary_cross_entropy(pred, gt, reduction="sum")
            loss = loss + (surf_samples - surf_gt).abs().sum()
        elif loss_type in ("l2", "l1", "logl1", "focal"):
            pred = torch.cat([free_samples, surf_samples, occ_samples], -1)
            gt = torch.cat([free_gt, surf_gt, occ_gt], -1)
            if loss_type == "l2":
                loss = ((pred - gt) ** 2).sum()
            elif loss_type == "l1":
                loss = (pred - gt).abs().sum()
            elif loss_type == "logl1":
                loss = (torch.log(pred + 1e-5) - torch.log(gt + 1e-5)).abs().sum()
            else:
                loss = compute_focal_loss(pred, gt)
        else:
            # previously an unknown type ended in an unbound `loss` variable
            raise ValueError(f"{loss_type} not supported")
        assert not loss.isnan().any(), (
            f"have nans in loss {loss.isnan().count_nonzero()}"
        )
        # handle no samples case in mean
        num = max(1.0, num)
        loss = loss.sum() / num
        return loss

    assert loss_type == "focal", f"{loss_type} not supported"
    # Normalize every term by its number of samples. The element count has to
    # come from the samples: the summed loss is a scalar with numel() == 1.
    loss_free = compute_focal_loss(free_samples, free_gt).sum()
    loss_surf = compute_focal_loss(surf_samples, surf_gt).sum()
    loss_occ = compute_focal_loss(occ_samples, occ_gt).sum()
    loss_free = loss_free / max(1.0, free_samples.numel())
    loss_surf = loss_surf / max(1.0, surf_samples.numel())
    loss_occ = loss_occ / max(1.0, occ_samples.numel())
    # all three weights are applied; the surface term used to be computed but
    # left out of the sum
    loss = (
        loss_free * free_surf_occ_weights[0]
        + loss_surf * free_surf_occ_weights[1]
        + loss_occ * free_surf_occ_weights[2]
    )
    return loss
# Edge ids whose lines are drawn in a fixed color to visualize the xyz axes
# (red, green, blue respectively).
AXIS_COLORS_RGB = {
    0: (255, 0, 0),  # red
    3: (0, 255, 0),  # green
    8: (0, 0, 255),  # blue
}


def get_colors(num_colors: int, scale_to_255: bool = False):
    """
    Return `num_colors` fully saturated RGB colors with evenly spaced hues.

    Colors are tuples of floats in [0, 1], or tuples of ints in [0, 255]
    (truncated) when `scale_to_255` is set.
    """
    assert num_colors > 0, f"Number of colors {num_colors} has to be positive."
    palette = []
    for idx in range(num_colors):
        # full saturation and value, only the hue changes between colors
        red, green, blue = colorsys.hsv_to_rgb(idx / num_colors, 1.0, 1.0)
        if scale_to_255:
            palette.append((int(red * 255), int(green * 255), int(blue * 255)))
        else:
            palette.append((red, green, blue))
    return palette


# RGB values in [0, 1] used in Static Structure Index
SSI_SEM_COLORS = {
    "floor": (1, 0.75, 0.75),
    "mirror": (0.5, 0.5, 0.5),
    "ceiling": (1, 1, 0.75),
    "chair": (0.2, 0.6, 1),
    "bench": (0.2, 0.6, 1),
    "ottoman": (0.2, 0.6, 1),
    "table": (1, 1, 0),
    "desk": (1, 1, 0),
    "storage": (0.7, 0.4, 0.05),
    "plant": (0, 1, 0),
    "plant_or_flower_pot": (0, 1, 0),
    "vase": (0, 1, 0),
    "screen": (1, 0, 0),
    "wallart": (0.6, 0.3, 0.95),
    "picture_frame_or_painting": (0.6, 0.3, 0.95),
    "bed": (0.55, 0.9, 0),
    # SSI uses (0, 1, 1) for couch and sofa; dark green is used here instead
    "couch": (0.1, 0.5, 0.1),  # dark green
    "sofa": (0.1, 0.5, 0.1),  # dark green
    "wall": (1, 1, 1),
    "lamp": (1, 0.8, 0.25),
    "door": (0.95, 0.25, 0.85),
    "window": (0.5, 1, 1),
    "unknown": (0.4, 0.4, 0.8),
    "other": (0.6, 0.6, 0.6),
    # hard code 'floor_mat' to dark red
    "floor_mat": (0.8, 0.15, 0.15),  # dark red
}


def get_colors_from_sem_map(
    sem_ids_to_names: Dict[int, str],
    scale_to_255: bool = True,
    match_with_ssi: bool = True,
):
    """
    Build a color list that can be indexed by semantic id.

    sem_ids_to_names: taxonomy map from semantic id to semantic name.
    scale_to_255: whether to scale the colors to [0, 255].
    match_with_ssi: whether to match the colors with the Static Structure Index
        taxonomy for the overlapped classes (names are compared lower-cased).

    The list has max(sem id) + 1 entries (one entry for an empty map).
    """
    num_sem_ids = max(sem_ids_to_names.keys()) + 1 if len(sem_ids_to_names) else 1
    colors = get_colors(num_sem_ids, scale_to_255=scale_to_255)
    if not match_with_ssi:
        return colors

    for sem_id, sem_name in sem_ids_to_names.items():
        ssi_rgb = SSI_SEM_COLORS.get(sem_name.lower())
        if ssi_rgb is None:
            continue
        if scale_to_255:
            ssi_rgb = tuple(int(round(channel * 255)) for channel in ssi_rgb)
        colors[sem_id] = ssi_rgb
    return colors
def draw_bb2s(
    viz,
    bb2s,
    line_type=cv2.LINE_AA,
    bb2s_center=None,
    labels=None,
    rotate_text=True,
    color=None,
    text_size=0.6,
):
    """
    Draw 2D bounding boxes (and optionally their centers and labels) into an image.

    Args:
        viz: numpy array image
        bb2s: a list of bounding boxes as numpy array Nx 4 where (x_min, x_max, y_min, y_max) per row
        line_type: OpenCV line type used for the box edges and center circles.
        bb2s_center: optional per-box centers, indexed as [i, 0] / [i, 1].
        labels: optional per-box text drawn near the box center.
        rotate_text: rotate the image 90 degrees clockwise while writing the
            text and back afterwards.
        color: either a 3-tuple/list or a list 3-tuples, or an np.array shaped Nx3
        text_size: font scale of the labels.

    Returns:
        The image with the drawings; a new array when text rotation was used.
    """
    img_height = viz.shape[0]
    line_width = 2 if img_height > 320 else 1
    if color is None:
        color = (255, 100, 100)  # brighter red
    if bb2s.shape[0] == 0:
        return viz

    # expand `color` into one color per box
    first = color[0]
    if isinstance(first, (list, tuple, np.ndarray)):
        assert len(color) == len(bb2s), (
            "need either single color or same # of colors as bb2s"
        )
        if isinstance(first, np.ndarray):
            box_colors = [c.tolist() for c in color]
        else:
            box_colors = color
    elif isinstance(first, (int, float)):
        box_colors = [color] * len(bb2s)
    else:
        raise TypeError("Unknown type for 'color' argument of draw_bb2s()")

    for idx, (box, box_color) in enumerate(zip(bb2s, box_colors)):
        x_min, y_min = int(round(box[0].item())), int(round(box[2].item()))  # min pt
        x_max, y_max = int(round(box[1].item())), int(round(box[3].item()))  # max pt
        # the four edges, in the order left, bottom, right, top
        corners = [(x_min, y_min), (x_min, y_max), (x_max, y_max), (x_max, y_min)]
        for start, end in zip(corners, corners[1:] + corners[:1]):
            cv2.line(viz, start, end, box_color, line_width, lineType=line_type)

        if bb2s_center is not None:
            center = (
                int(round(float(bb2s_center[idx, 0]))),
                int(round(float(bb2s_center[idx, 1]))),
            )
            cv2.circle(viz, center, 1, box_color, 1, lineType=line_type)

        if labels is None:
            continue
        text = labels[idx]
        x = int(round((x_min + x_max) / 2.0))
        y = int(round((y_min + y_max) / 2.0))
        if rotate_text:
            viz = cv2.rotate(viz, cv2.ROTATE_90_CLOCKWISE)
            # box center in the rotated image
            x, y = img_height - y, x
        (txt_w, txt_h), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_DUPLEX, text_size, 1)
        x -= int(round(txt_w / 4))
        y += int(round(txt_h / 4))
        put_text(viz, text, scale=text_size, font_pt=(x, y))
        if rotate_text:
            viz = cv2.rotate(viz, cv2.ROTATE_90_COUNTERCLOCKWISE)
    return viz
def draw_bb3_lines(
    viz,
    T_world_cam: PoseTW,
    cam: CameraTW,
    obbs: ObbTW,
    draw_cosy: bool,
    T: int,
    line_type=cv2.LINE_AA,
    colors=None,
    thickness=1,
):
    """
    Project the edges of 3D bounding boxes into the camera and draw them into `viz`.

    Each box edge is sampled at `T` points (`obbs.bb3edge_pts_object(T)`), the
    result is viewed as 12 lines per box, and consecutive samples are connected
    when both project validly.

    Args:
        viz: image that is drawn into in place.
        T_world_cam: camera pose in the world.
        cam: camera used for the projection.
        obbs: boxes to draw.
        draw_cosy: draw the edges listed in AXIS_COLORS_RGB in their axis color.
        T: number of sample points per edge.
        line_type: OpenCV line type.
        colors: optional colors indexed by semantic id; white is used when it
            is None or the id is out of range.
        thickness: line thickness.
    """
    edge_pts_world = obbs.T_world_object * obbs.bb3edge_pts_object(T)
    edge_pts_cam = T_world_cam.inverse() * edge_pts_world
    num_obbs = edge_pts_cam.shape[0]
    px, ok = cam.project(edge_pts_cam.view(num_obbs, -1, 3))
    sem_ids = obbs.sem_id.int()

    # reshape to lines each composed of T segments
    px = px.round().int().view(num_obbs * 12, T, 2)
    ok = ok.view(num_obbs * 12, T)

    for line_idx in range(px.shape[0]):
        obb_idx, edge_idx = divmod(line_idx, 12)
        sem_id = sem_ids[obb_idx]
        if colors is None or sem_id >= len(colors):
            line_color = (255, 255, 255)
        else:
            line_color = colors[sem_id]
        # axis edges get their special color
        if draw_cosy and edge_idx in AXIS_COLORS_RGB:
            line_color = AXIS_COLORS_RGB[edge_idx]

        for seg in range(T - 1):
            nxt = seg + 1
            if not (ok[line_idx, seg] and ok[line_idx, nxt]):
                continue
            start = (
                int(round(float(px[line_idx, seg, 0]))),
                int(round(float(px[line_idx, seg, 1]))),
            )
            end = (
                int(round(float(px[line_idx, nxt, 0]))),
                int(round(float(px[line_idx, nxt, 1]))),
            )
            cv2.line(viz, start, end, line_color, thickness, lineType=line_type)
def draw_bb3s(
    viz,
    T_world_rig: PoseTW,
    cam: CameraTW,
    obbs: ObbTW,
    draw_bb3_center=False,
    draw_bb3=True,
    draw_label=False,
    draw_cosy=True,
    draw_score=True,
    render_obb_corner_steps=10,
    line_type=cv2.LINE_AA,
    sem_id_to_name_mapping: Dict[int, str] = None,
    rotate_label=True,
    colors=None,
    white_backing_line=True,
    draw_inst_id=False,
):
    """
    Draw 3D bounding boxes, and optionally their centers, labels and scores,
    into an image.

    Args:
        viz: image to draw into.
        T_world_rig: rig pose in the world; combined with `cam.T_camera_rig`
            to get the camera pose.
        cam: camera used for the projection.
        obbs: boxes to draw.
        draw_bb3_center: draw a circle at the projected box center.
        draw_bb3: draw the box edges.
        draw_label: write the semantic name (or the semantic id if there is no
            name for it) above the box.
        draw_cosy: color the axis edges of every box.
        draw_score: write `obbs.prob` below the label (requires `draw_label`).
        render_obb_corner_steps: number of samples per box edge.
        line_type: OpenCV line type for edges and center circles.
        sem_id_to_name_mapping: optional semantic id -> name map.
        rotate_label: rotate the image 90 degrees clockwise while writing text.
        colors: optional colors indexed by semantic id.
        white_backing_line: first draw thicker lines underneath the
            semantic-colored ones.
        draw_inst_id: prefix the label with the instance id.

    Returns:
        The image with the drawings; a new array when text rotation was used.
    """
    # Get pose of camera.
    T_world_cam = T_world_rig.float() @ cam.T_camera_rig.inverse()

    if draw_bb3:
        thickness = 1
        # `line_type` is honoured for the edges as well (it used to be
        # hard-coded to cv2.LINE_AA, which is still the default).
        if white_backing_line:
            # draw white background lines
            draw_bb3_lines(
                viz,
                T_world_cam,
                cam,
                obbs,
                draw_cosy=draw_cosy,
                T=render_obb_corner_steps,
                line_type=line_type,
                colors=None,
                thickness=thickness + 1,
            )
        # draw semantic colors
        draw_bb3_lines(
            viz,
            T_world_cam,
            cam,
            obbs,
            draw_cosy=draw_cosy,
            T=render_obb_corner_steps,
            line_type=line_type,
            colors=colors,
            thickness=thickness,
        )

    if draw_label or draw_bb3_center:
        # Project the 3D BB center into the image.
        bb3center_cam = T_world_cam.inverse() * obbs.bb3_center_world
        bb2center_im, valids = cam.unsqueeze(0).project(bb3center_cam.unsqueeze(0))
        bb2center_im, valids = bb2center_im.squeeze(0), valids.squeeze(0)
        for idx, (pt2, valid) in enumerate(zip(bb2center_im, valids)):
            if not valid:
                continue
            center = (int(pt2[0]), int(pt2[1]))
            if draw_bb3_center:
                cv2.circle(viz, center, 3, (255, 0, 0), 1, lineType=line_type)
            if not draw_label:
                continue

            height = viz.shape[0]
            sem_id = int(obbs.sem_id.squeeze(-1)[idx])
            if sem_id_to_name_mapping and sem_id in sem_id_to_name_mapping:
                text = sem_id_to_name_mapping[sem_id]
            else:
                # display sem_id if no mapping is provided.
                text = str(sem_id)
            if draw_inst_id:
                inst_id = int(obbs.inst_id.squeeze(-1)[idx])
                text = f"{inst_id}: {text}"

            # rot 90 degree before drawing the text
            if rotate_label:
                viz = cv2.rotate(viz, cv2.ROTATE_90_CLOCKWISE)
                x, y = height - center[1], center[0]
            else:
                x, y = center
            (txt_w, txt_h), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_DUPLEX, 0.4, 1)
            x = x - txt_w // 4
            y = y + txt_h // 4

            # Show text on top of the 3d boxes: shift it up by half the height
            # of the 2D box in the rgb image.
            bb2_ymin = obbs.bb2_rgb[idx][2]
            bb2_ymax = obbs.bb2_rgb[idx][3]
            up = int((bb2_ymax - bb2_ymin) / 2.0)
            if y - up > 0:
                put_text(viz, text, scale=0.8, font_pt=(x, y - up))
                if draw_score and obbs.prob is not None:
                    score = float(obbs.prob.squeeze(-1)[idx])
                    score_text = f"{score:.2f}"
                    score_pos = (x, y + int(txt_h + 0.5) - up)
                    put_text(
                        viz,
                        score_text,
                        scale=0.5,
                        font_pt=score_pos,
                        color=(200, 200, 200),
                    )
            if rotate_label:
                viz = cv2.rotate(viz, cv2.ROTATE_90_COUNTERCLOCKWISE)
    return viz
def draw_obbs_image(
    img: torch.Tensor,
    obbs_padded: ObbTW,
    T_world_rig: PoseTW = None,
    cam: CameraTW = None,
    aria_cam_id: Literal[0, 1, 2] = 0,
    draw_bb2=False,
    draw_bb3=True,
    draw_bb3_center=False,
    draw_label=False,
    draw_cosy=True,
    draw_score=False,
    render_obb_corner_steps=10,
    post_rotate_viz=True,  # whether to rotate the image 90 degrees before (pre) or after (post) rendering 3d bbs; only for debugging. resulting image should be the same!
    rgb2bgr=True,
    rotate_viz=True,
    background_sem_id: int = None,
    prob_threshold: float = 0.5,
    sem_id_to_name_mapping: Dict[int, str] = None,
    draw_label_2d: bool = False,  # Draw label on 2D viz also.
    white_backing_line: bool = True,
    draw_inst_id: bool = False,
    draw_conic: bool = False,
):
    """
    Render the 2D and/or 3D bounding boxes of one frame into its image.

    Args:
        img: 3D image tensor, converted with `torch2cv2`.
        obbs_padded: one (padded) set of obbs for this frame (2D wrapper).
        T_world_rig, cam: rig pose and camera; 3D boxes are only drawn when
            both are given.
        aria_cam_id: selects which camera's 2D boxes are drawn.
        draw_bb2, draw_bb3, draw_bb3_center, draw_label, draw_cosy, draw_score,
        render_obb_corner_steps, white_backing_line, draw_inst_id: drawing
            options, see `draw_bb2s` / `draw_bb3s`.
        post_rotate_viz: rotate after (True) or before (False) drawing.
        rgb2bgr: forwarded to `torch2cv2`.
        rotate_viz: rotate the visualization 90 degrees clockwise.
        background_sem_id: obbs with this semantic id are not drawn.
        prob_threshold: obbs with a lower probability are dropped, but only if
            at least one obb reaches the threshold.
        sem_id_to_name_mapping: semantic id -> name map, used for colors and labels.
        draw_label_2d: also label the 2D boxes.
        draw_conic: currently has no effect.

    Returns:
        The rendered image as produced by OpenCV.
    """
    assert img.dim() == 3, f"image input must be 3D tensor {img.shape}"
    assert obbs_padded.dim() == 2, (
        f"assuming one set of obbs per frame {obbs_padded.shape}"
    )
    viz = torch2cv2(img, rotate=False, ensure_rgb=True, rgb2bgr=rgb2bgr)
    if not post_rotate_viz and rotate_viz:
        viz = cv2.rotate(viz, cv2.ROTATE_90_CLOCKWISE)

    # get valid obbs
    obbs = obbs_padded.remove_padding()
    if obbs.shape[0] == 0:
        # Handle no valid OBBs.
        if post_rotate_viz and rotate_viz:
            viz = cv2.rotate(viz, cv2.ROTATE_90_CLOCKWISE)
        return viz

    # filter out low probability obbs
    # NOTE(review): if no obb reaches the threshold, nothing is filtered and
    # all obbs are drawn. Kept as is - confirm this is intended.
    good = obbs.prob >= prob_threshold
    colors = None
    if sem_id_to_name_mapping is not None:
        colors = get_colors_from_sem_map(sem_id_to_name_mapping)
    if obbs.shape[0] > 0 and good.any():
        obbs = obbs[good.squeeze(-1), :]
    # if we have background id given, then filter out background obbs
    if background_sem_id is not None:
        background = obbs.sem_id == background_sem_id
        obbs = obbs[~background.squeeze(-1), :]

    if obbs.shape[0] > 0:
        # Labels for the 2D boxes. Ids missing from the mapping fall back to
        # the numeric id (same as draw_bb3s) instead of raising a KeyError.
        labels = None
        if draw_label_2d and sem_id_to_name_mapping is not None:
            labels = [
                sem_id_to_name_mapping.get(int(si), str(int(si)))
                for si in obbs.sem_id.squeeze(-1)
            ]
            if draw_inst_id:
                # int() so the label reads "3:chair" rather than "tensor(3):chair"
                inst_ids = obbs.inst_id.squeeze(-1)
                labels = [f"{int(inst)}:{n}" for inst, n in zip(inst_ids, labels)]

        # Draw 2D bounding box.
        if draw_bb2:
            viz = draw_bb2s(
                viz,
                obbs.bb2(aria_cam_id),
                bb2s_center=obbs.get_bb2_centers(aria_cam_id),
                labels=labels,
            )
        if draw_conic and cam and T_world_rig:
            pass
        # Draw 3D bounding box (requires poses from VIO).
        if cam and T_world_rig and (draw_bb3 or draw_bb3_center):
            if not post_rotate_viz and rotate_viz:
                cam = cam.rotate_90_cw()
            viz = draw_bb3s(
                viz,
                T_world_rig,
                cam,
                obbs,
                draw_bb3_center=draw_bb3_center,
                draw_bb3=draw_bb3,
                draw_label=draw_label,
                draw_cosy=draw_cosy,
                draw_score=draw_score,
                render_obb_corner_steps=render_obb_corner_steps,
                sem_id_to_name_mapping=sem_id_to_name_mapping,
                rotate_label=rotate_viz,
                colors=colors,
                white_backing_line=white_backing_line,
                draw_inst_id=draw_inst_id,
            )

    # Rotate everything before displaying.
    if post_rotate_viz and rotate_viz:
        viz = cv2.rotate(viz, cv2.ROTATE_90_CLOCKWISE)
    return viz
def draw_obbs_snippet(
    imgs: torch.Tensor,
    obbs_padded: ObbTW,
    Ts_world_rig: PoseTW = None,
    cams: CameraTW = None,
    aria_cam_id: Literal[0, 1, 2] = 0,
    draw_bb2=True,
    draw_bb3=True,
    draw_bb3_center=False,
    render_obb_corner_steps=10,
    post_rotate_viz=True,  # whether to rotate the image 90 degrees before (pre) or after (post) rendering 3d bbs; only for debugging. resulting image should be the same!
    rgb2bgr=True,
    rotate_viz=True,
    background_sem_id: int = None,
    prob_threshold: float = 0.5,
    sem_id_to_name_mapping: Dict[int, str] = None,
    draw_label: bool = False,
    draw_label_2d: bool = False,  # Draw label on 2D viz also.
    white_backing_line: bool = True,
    draw_cosy: bool = True,
    draw_score: bool = False,
    draw_inst_id: bool = False,
    draw_conic: bool = False,
):
    """
    Render obbs into every frame of a snippet via `draw_obbs_image`.

    Args:
        imgs: 4D image tensor; the first dimension is iterated as time.
        obbs_padded: either one set of obbs shared by all frames (2D) or one
            set per frame (3D).
        Ts_world_rig, cams: optional per-frame rig poses and cameras. If one
            of them is None, None is passed on for every frame (no 3D boxes
            are drawn then).
        All remaining arguments are forwarded unchanged to `draw_obbs_image`.

    Returns:
        List with one rendered image per frame.

    Raises:
        ValueError: if `obbs_padded` has neither 2 nor 3 dimensions.
    """
    assert imgs.dim() == 4, f"snippet input must be 4D tensor {imgs.shape}"
    T = imgs.shape[0]
    viz = []
    for t in range(T):
        if obbs_padded.dim() == 2:
            cur_obbs_padded = obbs_padded
        elif obbs_padded.dim() == 3:
            cur_obbs_padded = obbs_padded[t]
        else:
            raise ValueError(
                f"obbs_padded must have 2 or 3 dimensions {obbs_padded.shape}"
            )
        # poses and cameras default to None, so do not index them blindly
        cur_T_world_rig = Ts_world_rig[t] if Ts_world_rig is not None else None
        cur_cam = cams[t] if cams is not None else None
        viz.append(
            draw_obbs_image(
                img=imgs[t],
                obbs_padded=cur_obbs_padded,
                T_world_rig=cur_T_world_rig,
                cam=cur_cam,
                aria_cam_id=aria_cam_id,
                draw_bb2=draw_bb2,
                draw_bb3=draw_bb3,
                draw_bb3_center=draw_bb3_center,
                render_obb_corner_steps=render_obb_corner_steps,
                post_rotate_viz=post_rotate_viz,
                rgb2bgr=rgb2bgr,
                rotate_viz=rotate_viz,
                background_sem_id=background_sem_id,
                prob_threshold=prob_threshold,
                sem_id_to_name_mapping=sem_id_to_name_mapping,
                draw_label=draw_label,
                draw_label_2d=draw_label_2d,
                white_backing_line=white_backing_line,
                draw_cosy=draw_cosy,
                draw_score=draw_score,
                draw_inst_id=draw_inst_id,
                draw_conic=draw_conic,
            )
        )
    return viz
def discretize_values(values: torch.Tensor, precision: int):
    """
    Snap the rows of `values` to a grid with `precision` steps per unit and
    drop duplicated rows.

    Values are multiplied by `precision`, truncated to integers, made unique
    along dim 0 and scaled back. The lower the precision, the coarser (and
    smaller) the output; this is meant to thin out dense point clouds before
    rendering.
    """
    quantized = (values * precision).int()
    unique_rows = torch.unique(quantized, dim=0)
    return (unique_rows / precision).float()
def get_crops_scale(
    W: int,
    H: int,
    cam_name: Literal["rgb", "slaml", "slamr"],
    down_scale: Literal[0, 1, 2, 4, 5, 6, 7, 8, 9] = 0,
    wh_multiple_of: int = 16,
):
    """
    Compute the pre-crop, down-scale factor and post-crop for an Aria image.

    Args:
        W, H: image width and height before any processing.
        cam_name: "rgb", "slaml" or "slamr".
        down_scale: 0, 1, 2 or 4 select an integer down-scaling followed by a
            center crop to a multiple of `wh_multiple_of`; other values have to
            be keys of RESOLUTION_MAP and select a fixed target resolution.
        wh_multiple_of: the output size is cropped to (or, for RESOLUTION_MAP
            targets, required to be) a multiple of this value.

    Returns:
        (pre_crop, factor, post_crop). The crops are [x, y] offsets removed on
        each side; each of the three entries is None when that step is not
        needed.

    Raises:
        ValueError: for an unsupported `down_scale` or `cam_name`, or if a
            RESOLUTION_MAP target is not a multiple of `wh_multiple_of`.
    """
    # Pre-cropping is universal to all down_scaling.
    # Handle RGB properly with binning
    pre_crop = None
    if cam_name == "rgb" and W == 2880 and H == 2880:
        # crop image to 2816x2816
        pre_crop = [32, 32]
        H, W = H - 64, W - 64

    if down_scale in [0, 1, 2, 4]:
        factor = 1
        if cam_name == "rgb":
            if W == 2816 and H == 2816:
                # downsample to 1408x1408
                factor = 2
            if down_scale > 0:
                factor = 2 * down_scale * factor
        else:
            factor = down_scale
        if factor <= 1:
            factor = None
        if factor:
            # W, H after scaling down
            W, H = W // factor, H // factor
        # post-crop to reach size divisible by wh_multiple_of
        w_crop = (W % wh_multiple_of) // 2
        h_crop = (H % wh_multiple_of) // 2
        post_crop = [w_crop, h_crop]
        # set outputs none if they are not needed
        if w_crop == 0 and h_crop == 0:
            post_crop = None
    elif down_scale in RESOLUTION_MAP:
        if cam_name == "rgb":
            # the rgb target is square: the same entry is used for both sides
            target_h = RESOLUTION_MAP[down_scale][0]
            target_w = RESOLUTION_MAP[down_scale][0]
        elif cam_name in ["slaml", "slamr"]:
            target_w = RESOLUTION_MAP[down_scale][1]
            target_h = RESOLUTION_MAP[down_scale][2]
        else:
            # report the offending camera name (the message used to show
            # down_scale instead)
            raise ValueError("Specified cam_name of %s is not supported" % cam_name)
        if target_h % wh_multiple_of != 0 or target_w % wh_multiple_of != 0:
            raise ValueError(
                f"only wh_multiple_of 16 is guaranteed when using scale_down == [5,6,7,8,9] {target_h} % {wh_multiple_of}"
            )
        # This rescale factor can be non-integer.
        factor_w = W / target_w
        factor_h = H / target_h
        assert factor_w == factor_h, (
            "rescale factor must maintain original aspect ratio"
        )
        factor = factor_w
        post_crop = None
    else:
        raise ValueError("Specified down_scale of %d is not supported" % down_scale)
    return pre_crop, factor, post_crop
def rescale_camera_tw(
    cam: CameraTW,
    cam_size_before,  # tuple of (height, width, ...)
    cam_name: Literal["rgb", "slaml", "slamr"],
    down_scale: Literal[0, 1, 2, 4, 5, 6, 7, 8, 9] = 0,
    wh_multiple_of: int = 16,
):
    """
    Rescale CameraTW tensors by passing the camera size, camera name, and a
    down scale factor. cam shape should be [..., N] where N is the valid camera
    calibration dimension (25 or 33)

    The camera is cropped, scaled and cropped again exactly as
    `get_crops_scale` prescribes for an image of that size. The valid radius
    of `cam` may be updated in place (see below).
    """
    height, width = cam_size_before[:2]
    if (cam.c > 1000.0).any():
        # it can happen that the calibration was stored with respect to the full
        # 2880 x 2880 resolution although the rgb video stream is binned to 1408 x
        # 1408. We catch this by looking at the principal point which should be
        # about [704, 704] and fix the calibration.
        height, width = 2880, 2880
        if cam.valid_radius[0].item() < 1000.0:
            # it is likely that the valid_radius was set on the wrong cam
            # size (2x too small) so we fix it here.
            cam.set_valid_radius(cam.valid_radius * 2.0)

    pre_crop, factor, post_crop = get_crops_scale(
        width, height, cam_name, down_scale, wh_multiple_of
    )
    if pre_crop:
        # size after the center crop
        width = width - 2 * pre_crop[0]
        height = height - 2 * pre_crop[1]
        cam = cam.crop(pre_crop, (width, height))
    if factor:
        cam = cam.scale(1.0 / factor)
        # size after scaling
        width, height = width // factor, height // factor
    if post_crop:
        # size after the second center crop
        width = width - 2 * post_crop[0]
        height = height - 2 * post_crop[1]
        cam = cam.crop(post_crop, (width, height))
    return cam


def rescale_calib(
    calib,
    cam_size_before,  # tuple of (height, width, ...)
    cam_name: Literal["rgb", "slaml", "slamr"],
    down_scale: Literal[0, 1, 2, 4, 5, 6, 7, 8, 9] = 0,
    wh_multiple_of: int = 16,
):
    """
    rescale raw camera parameters

    `calib` must have 15 entries in its last dimension (fisheye264); entry 0 is
    divided by the scale factor and entries 1:3 are treated as the principal
    point. `calib` is modified in place and also returned.
    """
    # fisheye264
    assert calib.shape[-1] == 15
    height, width = cam_size_before[:2]
    # it can happen that the calibration was stored with respect to the full
    # 2880 x 2880 resolution although the rgb video stream is binned to 1408 x
    # 1408. We catch this by looking at the principal point which should be
    # about [704, 704] and fix the calibration.
    stored_at_full_res = (calib[1:3] > 1000.0).any()
    if stored_at_full_res:
        height, width = 2880, 2880

    pre_crop, factor, post_crop = get_crops_scale(
        width, height, cam_name, down_scale, wh_multiple_of
    )
    if pre_crop:
        calib[1:3] = calib[1:3] - np.array(pre_crop)
    if factor:
        calib[0] = calib[0] / factor
        # scale the principal point about the pixel centers
        calib[1:3] = (calib[1:3] + 0.5) / factor - 0.5
    if post_crop:
        calib[1:3] = calib[1:3] - np.array(post_crop)
    return calib
def rescale_image(
    img,
    cam_name: Literal["rgb", "slaml", "slamr"],
    down_scale: Literal[0, 1, 2, 4, 5, 6, 7, 8, 9] = 0,
    wh_multiple_of: int = 16,
):
    """
    Crop / down-scale / crop a numpy image (H x W or H x W x C) as prescribed
    by `get_crops_scale` for its size, camera name and `down_scale`.
    """
    height, width = img.shape[:2]
    pre_crop, factor, post_crop = get_crops_scale(
        width, height, cam_name, down_scale, wh_multiple_of
    )

    if pre_crop:
        left, top = pre_crop[0], pre_crop[1]
        img = img[top : height - top, left : width - left, ...]

    if factor:
        # When factor is integer, then cv2.INTER_AREA behaves identically
        # to skimage.downscale_local_mean, as described in the blog post:
        # https://medium.com/@wenrudong/what-is-opencvs-inter-area-actually-doing-282a626a09b3
        height, width = img.shape[:2]
        had_channel_dim = img.ndim == 3
        new_wh = int(round(width / factor)), int(round(height / factor))
        img = cv2.resize(img, new_wh, interpolation=cv2.INTER_AREA)
        if had_channel_dim and img.ndim == 2:
            # Preserve HxWx1 vs HxW to match input.
            img = np.expand_dims(img, axis=2)

    if post_crop:
        height, width = img.shape[:2]
        left, top = post_crop[0], post_crop[1]
        img = img[top : height - top, left : width - left, ...]
    return img
def rescale_image_tensor(
    img,
    cam_name: Literal["rgb", "slaml", "slamr"],
    down_scale: Literal[0, 1, 2, 4, 5, 6, 7, 8, 9] = 0,
    wh_multiple_of: int = 16,
    interpolate_mode: str = "bilinear",
):
    """
    Rescale the Aria image tensor. `img` is a torch Tensor, which is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions. `down_scale` specifies the degree of down-sampling.

    `interpolate_mode` is one of "nearest", "nearest-exact", "bilinear",
    "bicubic", "box", "hamming" or "lanczos"; resizing is done with
    torchvision's `resize` with antialiasing enabled.
    """
    from torchvision.transforms.functional import InterpolationMode, resize

    modes = {
        "nearest": InterpolationMode.NEAREST,
        "nearest-exact": InterpolationMode.NEAREST_EXACT,
        "bilinear": InterpolationMode.BILINEAR,
        "bicubic": InterpolationMode.BICUBIC,
        "box": InterpolationMode.BOX,
        "hamming": InterpolationMode.HAMMING,
        "lanczos": InterpolationMode.LANCZOS,
    }

    height, width = img.shape[-2:]
    pre_crop, factor, post_crop = get_crops_scale(
        width, height, cam_name, down_scale, wh_multiple_of
    )

    if pre_crop:
        left, top = pre_crop[0], pre_crop[1]
        img = img[..., top : height - top, left : width - left]

    if factor:
        height, width = img.shape[-2:]
        new_hw = int(round(height / factor)), int(round(width / factor))
        img = resize(
            img,
            new_hw,
            interpolation=modes[interpolate_mode],
            antialias=True,
        )

    if post_crop:
        height, width = img.shape[-2:]
        left, top = post_crop[0], post_crop[1]
        img = img[..., top : height - top, left : width - left]
    return img


def rescale_depth_img(
    depth_img, scale_down, filter_boundary=True, valid=None, wh_multiple_of: int = 16
):
    """
    Rescale a depth image like an rgb image, using nearest interpolation.

    Args:
        depth_img: depth image; copied into a tensor and reshaped to 1xHxW.
        scale_down: `down_scale` value handed to `rescale_image_tensor`.
        filter_boundary: if set, depth is zeroed wherever the rescaled validity
            mask is not 1.
        valid: optional validity mask; defaults to `depth_img > 0`.
        wh_multiple_of: forwarded to `rescale_image_tensor`.

    Returns:
        The rescaled 1 x H' x W' depth tensor.
    """
    # Rescale with torch rather than opencv, and make sure it's 1xHxW
    depth_img = torch.tensor(depth_img).squeeze().unsqueeze(0)
    depth_small = rescale_image_tensor(
        depth_img,
        "rgb",
        scale_down,
        wh_multiple_of=wh_multiple_of,
        interpolate_mode="nearest",
    )
    if not filter_boundary:
        return depth_small

    # Change the mask to float to capture the boundaries of invalid area.
    if valid is not None:
        mask = torch.tensor(valid).float().unsqueeze(0)
    else:
        mask = (depth_img > 0).float()
    mask_small = rescale_image_tensor(
        mask,
        "rgb",
        scale_down,
        wh_multiple_of=wh_multiple_of,
        interpolate_mode="nearest",
    )
    # only the mask pixels which are close to 1.0 are the valid ones.
    invalid = (mask_small - 1.0).abs() > 1e-5
    depth_small[invalid] = 0.0
    return depth_small
def rescale_obb_tw(
    obbs: ObbTW,
    cam_size_before_rgb,
    cam_size_before_slam,
    down_scale: Literal[0, 1, 2, 4, 5, 6, 7, 8, 9] = 0,
    wh_multiple_of: int = 16,
):
    """
    Rescale ObbTW 2d bb tensors by passing the camera size, camera name, and a
    down scale factor.

    The rgb and slam 2D boxes get the same crop / scale / crop sequence that
    `get_crops_scale` prescribes for the rgb and "slaml" images. A step that is
    only needed for one of the two streams is applied as a no-op (zero crop,
    scale 1.0) to the other.

    Args:
        obbs: obbs whose 2D boxes are adapted.
        cam_size_before_rgb: (height, width, ...) of the rgb image before rescaling.
        cam_size_before_slam: (height, width, ...) of the slam images before rescaling.
        down_scale, wh_multiple_of: see `get_crops_scale`.

    Returns:
        The obbs with cropped / scaled 2D boxes.
    """
    H_rgb, W_rgb = cam_size_before_rgb[:2]
    H_slam, W_slam = cam_size_before_slam[:2]
    pre_crop_rgb, factor_rgb, post_crop_rgb = get_crops_scale(
        W_rgb, H_rgb, "rgb", down_scale, wh_multiple_of
    )
    pre_crop_slam, factor_slam, post_crop_slam = get_crops_scale(
        W_slam, H_slam, "slaml", down_scale, wh_multiple_of
    )

    if pre_crop_rgb or pre_crop_slam:
        obbs = obbs.crop_bb2(
            left_top_rgb=pre_crop_rgb or [0, 0],
            left_top_slam=pre_crop_slam or [0, 0],
        )

    if factor_rgb or factor_slam:
        # A missing factor means "no scaling" for that stream. The rgb default
        # used to be assigned to factor_slam by mistake, which left factor_rgb
        # as None and made the division below fail.
        if not factor_slam:
            factor_slam = 1.0
        if not factor_rgb:
            factor_rgb = 1.0
        obbs = obbs.scale_bb2(scale_rgb=1.0 / factor_rgb, scale_slam=1.0 / factor_slam)

    if post_crop_rgb or post_crop_slam:
        obbs = obbs.crop_bb2(
            left_top_rgb=post_crop_rgb or [0, 0],
            left_top_slam=post_crop_slam or [0, 0],
        )
    return obbs
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import platform from typing import Optional, Tuple, Union import moderngl import numpy as np import torch from efm3d.aria.aria_constants import ( ARIA_CALIB, ARIA_CALIB_TIME_NS, ARIA_DISTANCE_M, ARIA_DISTANCE_M_PRED, ARIA_IMG_T_SNIPPET_RIG, ARIA_MESH_FACES, ARIA_MESH_VERT_NORMS_W, ARIA_MESH_VERTS_W, ARIA_OBB_PADDED, ARIA_OBB_PRED_VIZ, ARIA_OBB_TRACKED, ARIA_OBB_UNINST, ARIA_POINTS_WORLD, ARIA_POSE_T_SNIPPET_RIG, ARIA_POSE_TIME_NS, ARIA_SNIPPET_T_WORLD_SNIPPET, ) from efm3d.aria.camera import CameraTW from efm3d.aria.obb import BB3D_LINE_ORDERS, OBB_LINE_INDS, OBB_MESH_TRI_INDS, ObbTW from efm3d.aria.pose import PoseTW from efm3d.utils.common import sample_nearest from efm3d.utils.depth import dist_im_to_point_cloud_im from efm3d.utils.gravity import gravity_align_T_world_cam, GRAVITY_DIRECTION_VIO from efm3d.utils.render import discretize_values, get_colors_from_sem_map from PIL import Image from torch.nn import functional as F # mapping from edge ids to colors for visualizing the xyz axes AXIS_COLORS_GL = { 0: (1.0, 0.0, 0.0, 1.0), # red 3: (0.0, 1.0, 0.0, 1.0), # green 8: (0.0, 0.0, 1.0, 1.0), # blue } # use RGB for xyz axes respectively def render_points(pts, rgba, prog=None, ctx=None, point_size=1.0, scene=None): if isinstance(pts, torch.Tensor): pts = pts.detach().cpu().float().numpy() if pts.shape[0] == 0: return if scene is not None: prog, ctx = scene.prog, scene.ctx prog["global_color"].value = rgba prog["point_size"].value = point_size vbo = ctx.buffer(pts.astype("float32").tobytes()) vao = ctx.vertex_array(prog, [(vbo, "3f", "in_vert")]) 
def render_cubes(centers, bb3_halfdiag, prog, ctx, rgb=None):
    """
    Render one axis-aligned box per center as a triangle mesh.

    Args:
        centers: tensor of box centers; reshaped to (-1, 3).
        bb3_halfdiag: per-axis half extent added to / subtracted from every
            center to obtain the 8 corners (presumably of shape (3,) - TODO
            confirm against callers).
        prog: shader program forwarded to the mesh renderer.
        ctx: moderngl context forwarded to the mesh renderer.
        rgb: optional per-vertex colors; when given the mesh is drawn with
            `render_rgb_tri_mesh`, otherwise with `render_tri_mesh`.
    """
    cube_centers = centers.reshape(-1, 3)
    device = cube_centers.device

    # Sign pattern of the 8 box corners relative to the box center.
    corner_signs = torch.tensor(
        [
            [-1.0, -1.0, -1.0],
            [1.0, -1.0, -1.0],
            [1.0, 1.0, -1.0],
            [-1.0, 1.0, -1.0],
            [-1.0, -1.0, 1.0],
            [1.0, -1.0, 1.0],
            [1.0, 1.0, 1.0],
            [-1.0, 1.0, 1.0],
        ],
        device=device,
    )
    corner_offsets = corner_signs * bb3_halfdiag.unsqueeze(0)
    corners = cube_centers.unsqueeze(1) + corner_offsets.unsqueeze(0)
    num_cubes = corners.shape[0]

    # Every cube owns 8 consecutive vertices, so the triangle indices of cube
    # i are the template indices shifted by 8 * i.
    template_tris = (
        torch.tensor(OBB_MESH_TRI_INDS, dtype=torch.int32, device=device)
        .transpose(1, 0)
        .unsqueeze(0)
    )
    first_vertex = 8 * torch.arange(
        0, num_cubes, dtype=torch.int32, device=device
    ).view(-1, 1, 1)
    tris = template_tris + first_vertex

    # The vertex normals point from the box center towards each corner.
    corner_normals = F.normalize(corner_offsets, p=2.0, dim=-1)
    normals = corner_normals.unsqueeze(0).repeat(num_cubes, 1, 1)

    if rgb is None:
        render_tri_mesh(corners, normals, tris, prog, ctx)
    else:
        render_rgb_tri_mesh(corners, normals, tris, rgb, prog, ctx)
def render_scalar_field_points(
    pts,
    values,
    prog,
    ctx,
    val_min=0.0,
    val_max=1.0,
    point_size=1.0,
    alphas=None,
):
    """
    Render 3D points together with one scalar value and one alpha per point.

    Args:
        pts: torch Tensor or numpy array of 3D points; the last dimension must
            be 3.
        values: torch Tensor or numpy array holding one scalar per point.
        prog: shader program exposing the `min_value`, `max_value` and
            `point_size` uniforms and the `in_vert`, `in_value` and `in_alpha`
            vertex inputs.
        ctx: moderngl context used to allocate the vertex buffers.
        val_min: value written to the `min_value` uniform.
        val_max: value written to the `max_value` uniform.
        point_size: value written to the `point_size` uniform.
        alphas: optional per-point alpha (tensor or array-like); defaults to
            1.0 for every point.

    Raises:
        AssertionError: if the points are not 3D or the number of values does
            not match the number of points.
    """
    # Convert to numpy before validating so that tensors and arrays go through
    # the same checks (numpy arrays have no `.numel()`).
    if isinstance(pts, torch.Tensor):
        pts = pts.detach().cpu().float().numpy()
    if isinstance(values, torch.Tensor):
        values = values.detach().cpu().float().numpy()
    pts = np.asarray(pts)
    values = np.asarray(values)

    assert pts.shape[-1] == 3, f"only support 3d points {pts.shape}"
    assert pts.size == 3 * values.size, (
        f"pts and values must have same numel {pts.size} {values.size}, {pts.shape} and {values.shape}"
    )
    if pts.shape[0] == 0:
        return

    if alphas is None:
        alphas = np.ones_like(values)
    elif isinstance(alphas, torch.Tensor):
        alphas = alphas.detach().cpu().float().numpy()
    else:
        alphas = np.asarray(alphas)

    prog["max_value"].value = val_max
    prog["min_value"].value = val_min
    prog["point_size"].value = point_size
    vbo = ctx.buffer(pts.astype("float32").tobytes())
    vbv = ctx.buffer(values.astype("float32").tobytes())
    vba = ctx.buffer(alphas.astype("float32").tobytes())
    vao = ctx.vertex_array(
        prog, [(vbo, "3f", "in_vert"), (vbv, "1f", "in_value"), (vba, "1f", "in_alpha")]
    )
    vao.render(moderngl.POINTS)
    # Release in reverse order of creation.
    vao.release()
    vba.release()
    vbv.release()
    vbo.release()
def render_linestrip(pts, rgba, prog=None, ctx=None, scene=None):
    """
    Render a sequence of 3D points as one connected line strip.

    Args:
        pts: torch Tensor or numpy array of points; nothing is drawn when the
            first dimension is empty.
        rgba: color written to the `global_color` uniform; None means opaque
            black.
        prog: shader program with a `global_color` uniform and an `in_vert`
            vertex input.
        ctx: moderngl context used to allocate the vertex buffer.
        scene: optional scene; when given its `prog` and `ctx` are used instead
            of the `prog` and `ctx` arguments.
    """
    if isinstance(pts, torch.Tensor):
        pts = pts.detach().cpu().float().numpy()
    if pts.shape[0] == 0:
        return

    color = (0.0, 0.0, 0.0, 1.0) if rgba is None else rgba
    if scene is not None:
        prog = scene.prog
        ctx = scene.ctx

    prog["global_color"].value = color
    vertex_bytes = pts.astype("float32").tobytes()
    vertex_buffer = ctx.buffer(vertex_bytes)
    strip = ctx.vertex_array(prog, vertex_buffer, "in_vert")
    strip.render(moderngl.LINE_STRIP)
    strip.release()
    vertex_buffer.release()
def render_obbs_line(
    obbs: ObbTW,
    prog=None,
    ctx=None,
    rgba=(0.0, 0.0, 0.0, 1.0),
    colors=None,
    color_alpha=1.0,
    line_width=3.0,
    draw_cosy=False,
    scene=None,
):
    """
    Draw multiple oriented bounding boxes (obbs) each as a set of lines.

    Obbs with a negative semantic id are skipped.

    Args:
        obbs: obbs of shape N x C.
        prog: shader program forwarded to `render_obb_line`.
        ctx: moderngl context; its `line_width` is changed while drawing and
            restored afterwards.
        rgba: color for every obb that has no entry in `colors`.
        colors: optional sequence of RGB colors indexed by semantic id.
        color_alpha: alpha combined with the RGB values taken from `colors`.
        line_width: line width used while drawing the obbs.
        draw_cosy: forwarded to `render_obb_line`.
        scene: optional scene; when given its `prog` and `ctx` are used instead
            of the `prog` and `ctx` arguments.
    """
    assert obbs.dim() == 2, f"{obbs.shape}"
    if scene is not None:
        prog, ctx = scene.prog, scene.ctx

    old_line_width = ctx.line_width
    ctx.line_width = line_width
    try:
        for obb in obbs:
            sem_id = int(obb.sem_id)
            if obb.sem_id.item() < 0:
                continue
            # Pick the color per obb. Writing the class color back into `rgba`
            # would make later obbs without a `colors` entry inherit the color
            # of the previous obb instead of the default.
            obb_rgba = rgba
            if colors is not None and sem_id < len(colors):
                rgb = colors[sem_id]
                obb_rgba = (rgb[0], rgb[1], rgb[2], color_alpha)
            render_obb_line(
                obb,
                prog,
                ctx,
                rgba=obb_rgba,
                draw_cosy=draw_cosy,
            )
    finally:
        # Restore the line width even if drawing fails, since a context may be
        # reused across calls.
        ctx.line_width = old_line_width
# finish the rendering and obtain the rendered image img = sceneView.finish() # display or save image Here is a simple example to render a coordinate system at the origin: ``` scene = SceneView(width=320, height=320) scene.clear() T_wc = PoseTW() scene.set_default_view(PoseTW(), zoom_factor=6) render_cosy(PoseTW(), ctx=scene.ctx, prog=scene.prog, scale=1.0) img = np.array(scene.finish()) ``` """ def __init__( self, width: int, height: int, z_near: float = 0.1, z_far: float = 1000.0, bg_color: Tuple[float, float, float] = (1.0, 1.0, 1.0), ): """ Args: width (int): width of rendered image. height (int): height of rendered image. z_near (float): near clipping plane. z_far (float): far clipping plane. bg_color (Tuple[float, float, float]): background color (0-1 range) """ self.width = width self.height = height self.z_near = z_near self.z_far = z_far self.bg_color = bg_color self.ctx = init_egl_context() if self.ctx is not None: self.prog = simple_shader_program(self.ctx) self.prog_scalar_field = scalar_field_shader_program(self.ctx) self.prog_rgb_point_cloud = rgb_point_cloud_shader_program(self.ctx) self.prog_mesh = mesh_normal_shader_program(self.ctx) self.prog_mesh_rgb = mesh_rgb_shader_program(self.ctx) # attach frame and depth buffer. Depth buffer is important to be able to # do z-buffering! self.fbo1 = self.ctx.framebuffer( self.ctx.renderbuffer((width, height), samples=4), self.ctx.depth_renderbuffer((width, height), samples=4), ) self.fbo2 = self.ctx.framebuffer( self.ctx.renderbuffer((width, height)), self.ctx.depth_renderbuffer((width, height)), ) # setup camera projection for rendering fu, fv = self.width / 0.5, self.height / 0.5 self.f = min(fu, fv) self.P = projection_matrix_rdf_top_left( self.width, self.height, self.f, self.f, (self.width - 1.0) / 2, (self.height - 1.0) / 2, self.z_near, self.z_far, ) def valid(self): return self.ctx is not None def clear(self, bg_color: Optional[Tuple[float, float, float]] = None): """ clear the scene rendering buffer. 
(call before any rendering!) if bg_color is specified then this is used over the one specified during construction. """ self.fbo1.use() if bg_color is not None: self.ctx.clear( red=bg_color[0], green=bg_color[1], blue=bg_color[2], depth=1e4 ) else: self.ctx.clear( red=self.bg_color[0], green=self.bg_color[1], blue=self.bg_color[2], depth=1e4, ) # enable depth test, point size, blending, and cull backfacing mesh triangles self.ctx.enable( moderngl.DEPTH_TEST | moderngl.PROGRAM_POINT_SIZE | moderngl.BLEND | moderngl.CULL_FACE ) def set_default_view(self, T_world_camera: PoseTW, zoom_factor: float = 4.0): """ set view to follow given T_world_camera behind and to the right of the T_wc. """ mv = get_mv(T_world_camera, zoom_factor=zoom_factor, position=[-1, -1, -2]) self.set_view(mv) def set_follow_view(self, T_world_camera: PoseTW, zoom_factor: float = 4.0): """ set view to follow given T_world_camera. """ mv = get_mv(T_world_camera, zoom_factor=zoom_factor, position=[-1, 0, -2]) self.set_view(mv) def set_birds_eye_view(self, T_world_camera: PoseTW, zoom_factor: float = 6.0): """ set view to a birds eye view given T_world_camera. 
""" mv = get_mv(T_world_camera, zoom_factor=zoom_factor, position=[-2, 0, -0.0001]) T_ahead = PoseTW.from_Rt(torch.eye(3), torch.tensor([0, -2, 0])) mv = T_ahead.matrix.numpy() @ mv self.set_view(mv) def set_side_view(self, T_world_camera: PoseTW, zoom_factor: float = 6.0): """ set view to the left side of T_world_camera """ mv = get_mv(T_world_camera, zoom_factor=zoom_factor, position=[-0, 2, -0.0001]) T_ahead = PoseTW.from_Rt(torch.eye(3), torch.tensor([-2.5, 0, 0])) mv = T_ahead.matrix.numpy() @ mv self.set_view(mv) def set_birds_eye_view_from_bb( self, bb_scene_xyzxyz: torch.Tensor, zoom_factor: float = 6.0 ): """ set view to a birds eye view given bounding volume of scene assumes gravity aligned coordinate system with z=up """ bb_min = bb_scene_xyzxyz[:3] bb_max = bb_scene_xyzxyz[3:] bb_diag = bb_max - bb_min bb_center = (bb_max + bb_min) * 0.5 up = torch.tensor([0, 0, 1]) dz = bb_diag[0] * self.f / self.width dz = max(dz, bb_diag[1] * self.f / self.height) dz += bb_diag[2] * 0.5 dir_max = bb_diag / F.normalize(bb_diag, p=2, dim=0) eye = bb_center + up * zoom_factor * dz eye = bb_center + dir_max * zoom_factor * dz eye = bb_max + up * zoom_factor * dz mv = model_view_look_at_rdf(eye.numpy(), bb_center.numpy(), -up.numpy()) self.set_view(mv) def set_view(self, mv: Union[PoseTW, np.array]): """ set view to model view matrix. """ if isinstance(mv, PoseTW): mv = mv.matrix.numpy() MVP = self.P @ mv # important to transpose MVP since opengl is column-major! 
self.prog["mvp"].write(MVP.transpose().astype("float32").tobytes()) self.prog_scalar_field["mvp"].write(MVP.transpose().astype("float32").tobytes()) self.prog_rgb_point_cloud["mvp"].write( MVP.transpose().astype("float32").tobytes() ) self.prog_mesh["mvp"].write(MVP.transpose().astype("float32").tobytes()) self.prog_mesh["mv"].write(mv.transpose().astype("float32").tobytes()) self.prog_mesh_rgb["mvp"].write(MVP.transpose().astype("float32").tobytes()) self.prog_mesh_rgb["mv"].write(mv.transpose().astype("float32").tobytes()) def finish(self): """ finish the scene rendering and return the rendered image as a PIL image. (call after all rendering!) """ self.ctx.copy_framebuffer(self.fbo2, self.fbo1) data = self.fbo2.read(components=3, alignment=1) img = Image.frombytes("RGB", self.fbo2.size, data) img = img.transpose(Image.FLIP_LEFT_RIGHT) return img def draw_obb_scene_3d( tgt, T_ws, Ts_sr, cams, frame_id=0, tgt_removed=None, sem_ids_to_names=None, prd=None, width=512, height=512, draw_origin=False, draw_trajectory=True, draw_frustum=True, matcher=None, prd_logits=None, p3s_world=None, p3s_pred_world=None, depth_pred=None, # optional scene object - if you want constant GPU memory allocate them once # outside and pass them in to reuse. scene: Optional[SceneView] = None, z_height_clip=None, render_raw_pred=True, render_removed_pred=True, cams_slaml=None, cams_slamr=None, zoom_factor=4.0, bird_eye_view=False, scene_mesh_v=None, scene_mesh_f=None, scene_mesh_n=None, scene_mesh_T_wv=None, ): """ Draw a 3D scene of obbs, camera trajectory and camera frustum. The scene is selected via the frame_id which indexes into the snippet variables Ts_wr, cams which are TxC. Args: tgt: target obbs whose bounding boxes are to be drawn Ts_wr: camera trajectory cams: camera calibrations frame_id: frame index to select from Ts_wr and cams sem_ids_to_names: a dict mapping sem ids to names prd: predicted obbs (if any) who are to be drawn. 
These are optional and meant to allow comparing two sets of bounding boxes. width: width of figure (only needed if scene is not provided) height: height of figure (only needed if scene is not provided) draw_origin: if True, draw the origin of the scene draw_trajectory: if True, draw the camera trajectory draw_frustum: if True, draw the camera frustum matcher: a function that takes tgt ObbTW, prd ObbTW and prd_logits and returns a list of matching ids to draw (HungarianMatcher) prd_logits: a list of logits matching the prd ObbTWs for the matcher bg_color: background color for the rendered scene scene: optional scene to draw into (if not provided instantiated internally) z_height_clip: z clip to limit the points of the scene to below this height (remove ceilings for better viz) cams_slaml: camera calibrations of slam left camera cams_slamr: camera calibrations of slam right camera Returns: fig: plotly figure with all the drawings """ if scene is None: scene = SceneView(width=width, height=height) cam = cams[frame_id].cpu() cam_slaml = cams_slaml[frame_id].cpu() if cams_slaml is not None else None cam_slamr = cams_slamr[frame_id].cpu() if cams_slamr is not None else None if p3s_world is not None: p3s_world = p3s_world[frame_id].cpu() if p3s_world.ndim == 3 else p3s_world if p3s_pred_world is not None: p3s_pred_world = ( p3s_pred_world[frame_id].cpu() if isinstance(p3s_pred_world, list) else p3s_pred_world ) if depth_pred is not None: depth_pred = ( depth_pred[frame_id].cpu() if isinstance(depth_pred, list) else depth_pred ) pose_id = frame_id Ts_wr = T_ws @ Ts_sr Ts_wr = Ts_wr.cpu() if cams.shape[0] != Ts_wr.shape[0]: pose_id = round(Ts_wr.shape[0] * (float(frame_id) / float(cams.shape[0]))) Ts_wc = Ts_wr @ cam.T_camera_rig.inverse() T_wr = Ts_wr[pose_id] T_wc = Ts_wc[pose_id] if tgt is not None and tgt.ndim == 3: tgt = tgt[frame_id] tgt = tgt.cpu() if tgt is not None else None colors = None if sem_ids_to_names is not None: # needs to color to be in scale [0,1] colors 
= get_colors_from_sem_map(sem_ids_to_names, scale_to_255=False) # setup framebuffer for rendering scene.clear() if not bird_eye_view: scene.set_follow_view(T_wc, zoom_factor=zoom_factor) else: scene.set_birds_eye_view(T_wc, zoom_factor=zoom_factor) # draw target obbs if tgt is not None and tgt.shape[0] > 0: render_obbs_line( tgt, scene.prog, scene.ctx, rgba=(1.0, 0.0, 0.0, 1.0), colors=colors, ) if render_removed_pred and tgt_removed is not None: if tgt_removed.ndim == 3: tgt_removed = tgt_removed[frame_id] render_obbs_line( tgt_removed.cpu(), scene.prog, scene.ctx, rgba=(0.75, 0.75, 0.75, 0.3), ) if render_raw_pred and prd is not None: if prd.ndim == 3: prd = prd[frame_id] # change the alpha value of the predictions when we have target obbs. if tgt is not None and tgt.shape[0] > 0: color_alpha = 0.3 else: color_alpha = 1.0 # draw predicted obbs render_obbs_line( prd.cpu(), scene.prog, scene.ctx, colors=colors, color_alpha=color_alpha, ) if draw_trajectory: # draw rig trajectory render_linestrip( Ts_wr.t, rgba=(0.0, 0.0, 0.0, 1.0), prog=scene.prog, ctx=scene.ctx ) # draw the current rig pose render_cosy(T_wr, ctx=scene.ctx, prog=scene.prog, scale=0.3) # draw the snippet origin # n the case of frames coming from different snippets, e.g. T_ws is [10, 12] # take the first T_ws as the snippet origin. 
if T_ws.shape[0] > 0: T_ws = T_ws[0:1] render_cosy(T_ws, ctx=scene.ctx, prog=scene.prog, scale=0.3) if p3s_world is not None: if z_height_clip is not None: keep = p3s_world[:, 2] < z_height_clip p3s_world = p3s_world[keep] # draw 3d points render_points( p3s_world, (0.1, 0.1, 0.1, 1.0), prog=scene.prog, ctx=scene.ctx, point_size=1.2, ) if scene_mesh_v is not None: verts_w = scene_mesh_T_wv * scene_mesh_v.to(scene_mesh_T_wv.device) normals_w = scene_mesh_T_wv.rotate(scene_mesh_n.to(scene_mesh_T_wv.device)) render_tri_mesh( verts_w, normals_w, scene_mesh_f, prog=scene.prog_mesh, ctx=scene.ctx, ) if p3s_pred_world is not None: if depth_pred is None: # draw 3d points render_points( p3s_pred_world, (0.0, 1.0, 0.0, 1.0), prog=scene.prog, ctx=scene.ctx, point_size=2.0, ) else: # draw 3d points colored by depth render_scalar_field_points( p3s_pred_world, depth_pred, prog=scene.prog_scalar_field, ctx=scene.ctx, val_min=0.0, val_max=3.0, point_size=2.0, ) if draw_frustum: # draw the current frustum render_frustum( T_wr, cam, prog=scene.prog, ctx=scene.ctx, rgba=(0.0, 0.0, 0.0, 1.0) ) render_line( T_wr.t, T_wc.t, rgba=(0.0, 0.0, 1.0, 1.0), prog=scene.prog, ctx=scene.ctx ) if draw_trajectory: # Show smaller frustums along trajectory. 
def draw_snippet_scene_3d(
    snippet,
    sem_ids_to_names=None,
    width=512,
    height=512,
    draw_origin=False,
    frame_id: Optional[int] = None,
    batch_id: int = 0,
    # optional scene object - if you want constant GPU memory allocate them once
    # outside and pass them in to reuse.
    scene: Optional[SceneView] = None,
    clean_viz: bool = False,
    viz_gt_points: bool = True,
):
    """
    Render the obbs, points, mesh and camera trajectory of a snippet in 3D.

    Args:
        snippet: a snippet dict containing all relevant information for drawing
        sem_ids_to_names: a dict mapping sem ids to names
        width: width of the image (only needed if scene is not provided)
        height: height of the image (only needed if scene is not provided)
        draw_origin: if True, draw the origin of the scene
        frame_id: if not None, only render this frame (0 is a valid frame);
            otherwise all frames of the snippet are rendered.
        batch_id: if we are passing batched inputs, select the batch with this
            id for rendering.
        scene: optional scene to draw into (if not provided instantiated
            internally)
        clean_viz: if True, removed obbs are not drawn and raw predictions are
            only drawn when the snippet contains no tracked obbs.
        viz_gt_points: if there is ground truth depth in the snippet, build the
            point cloud from the GT depth.

    Returns:
        imgs: list with one rendered image (numpy array) per rendered frame.

    Raises:
        KeyError: if the snippet contains no snippet-to-rig poses.
    """
    if scene is None:
        scene = SceneView(width=width, height=height)

    has_slaml = ARIA_CALIB[1] in snippet
    has_slamr = ARIA_CALIB[2] in snippet
    cams = snippet[ARIA_CALIB[0]].cpu()
    cams_slaml = snippet[ARIA_CALIB[1]].cpu() if has_slaml else None
    cams_slamr = snippet[ARIA_CALIB[2]].cpu() if has_slamr else None

    T_ws = snippet[ARIA_SNIPPET_T_WORLD_SNIPPET].cpu()
    if ARIA_IMG_T_SNIPPET_RIG[0] in snippet:
        Ts_sr = snippet[ARIA_IMG_T_SNIPPET_RIG[0]].cpu()
    elif ARIA_POSE_T_SNIPPET_RIG in snippet:
        Ts_sr = snippet[ARIA_POSE_T_SNIPPET_RIG].cpu()
        if Ts_sr.shape[0] != cams.shape[0]:
            # pick, for every camera timestamp, the pose closest in time.
            cam_times_ns = snippet[ARIA_CALIB_TIME_NS[0]].tolist()
            pose_times_ns = snippet[ARIA_POSE_TIME_NS].tolist()
            Ts_sr = sample_nearest(cam_times_ns, pose_times_ns, Ts_sr)
    else:
        raise KeyError(
            f"snippet has neither {ARIA_IMG_T_SNIPPET_RIG[0]} nor {ARIA_POSE_T_SNIPPET_RIG}"
        )

    if T_ws.ndim == 3:
        # batched input: select the requested batch element.
        T_ws = T_ws[batch_id]
        Ts_sr = Ts_sr[batch_id]
        cams = cams[batch_id]
        if has_slaml:
            cams_slaml = cams_slaml[batch_id]
        if has_slamr:
            cams_slamr = cams_slamr[batch_id]

    obbs, prd, uninst = None, None, None
    if ARIA_OBB_PADDED in snippet:
        obbs = snippet[ARIA_OBB_PADDED].cpu()
        obbs = obbs[batch_id] if obbs.ndim == 4 else obbs
    have_tracked = ARIA_OBB_TRACKED in snippet
    if have_tracked:
        # tracked obbs take precedence over the padded obbs.
        obbs = snippet[ARIA_OBB_TRACKED].cpu()
        obbs = obbs[batch_id] if obbs.ndim == 4 else obbs
    if ARIA_OBB_PRED_VIZ in snippet:
        prd = snippet[ARIA_OBB_PRED_VIZ].cpu()
        prd = prd[batch_id] if prd.ndim == 4 else prd
    if ARIA_OBB_UNINST in snippet:
        uninst = snippet[ARIA_OBB_UNINST].cpu()
        uninst = uninst[batch_id] if uninst.ndim == 4 else uninst

    p3s_world = None
    # If GT depth exists, build a world point cloud from the GT depth.
    if viz_gt_points and ARIA_DISTANCE_M[0] in snippet:
        # Note: we only visualize GT depth map of RGB images now.
        valid_depths = snippet[ARIA_DISTANCE_M[0]].squeeze(1) > 1e-4
        p3cs, valids = dist_im_to_point_cloud_im(
            snippet[ARIA_DISTANCE_M[0]].squeeze(1),
            snippet[ARIA_CALIB[0]],
        )
        valids = torch.logical_and(valids, valid_depths)
        p3cs = p3cs.reshape(p3cs.shape[0], -1, 3)
        T_s_c = (
            snippet[ARIA_IMG_T_SNIPPET_RIG[0]]
            @ snippet[ARIA_CALIB[0]].T_camera_rig.inverse()
        )
        T_w_c = snippet[ARIA_SNIPPET_T_WORLD_SNIPPET] @ T_s_c
        p3ws = T_w_c * p3cs
        p3ws = p3ws.reshape(-1, 3)
        valids = valids.reshape(-1)
        p3ws = p3ws[valids]
        p3s_world = discretize_values(p3ws, precision=70)
    # NOTE(review): world points stored in the snippet replace the GT depth
    # point cloud computed above - confirm this precedence is intended.
    if ARIA_POINTS_WORLD in snippet:
        p3s_world = snippet[ARIA_POINTS_WORLD]
        p3s_world = p3s_world[batch_id] if p3s_world.ndim == 4 else p3s_world

    p3s_pred_world, depth_pred = None, None
    if ARIA_DISTANCE_M_PRED[0] in snippet:
        dist_m = snippet[ARIA_DISTANCE_M_PRED[0]].cpu()
        dist_m = dist_m[batch_id] if dist_m.ndim == 4 else dist_m
        # scale camera to fit the depth image (in case depth image is at a lower res)
        cams_depth = cams.scale_to(dist_m)
        Ts_wc = T_ws @ Ts_sr @ cams.T_camera_rig.inverse()
        p3s_pred_world, depth_pred = [], []
        for t in range(dist_m.shape[0]):
            p3s_c, valids = dist_im_to_point_cloud_im(dist_m[t], cams_depth[t])
            p3s_pred_world.append(Ts_wc[t] * p3s_c[valids])
            depth_pred.append(dist_m[t][valids])

    Ts_wr = T_ws @ Ts_sr
    obbs = obbs.remove_padding() if obbs is not None else None
    prd = prd.remove_padding() if prd is not None else None
    uninst = uninst.remove_padding() if uninst is not None else None

    # clip the point cloud 1m above the rig coordinates
    z_height_clip = Ts_wr.t[..., 2].max() + 1.0

    num_frames = Ts_wr.shape[0]
    assert Ts_wr.shape[0] == cams.shape[0], (
        f"poses and cameras must have the same length but got {Ts_wr.shape[0]} and {cams.shape[0]}"
    )
    if obbs is not None:
        assert Ts_wr.shape[0] == len(obbs), (
            f"poses and obbs must have the same length {len(obbs)} but got {Ts_wr.shape}"
        )

    # Compare against None: a truthiness test would treat frame 0 as "not set"
    # and render every frame instead.
    if frame_id is not None:
        assert frame_id >= 0 and frame_id < num_frames
        frame_ids = [frame_id]
    else:
        frame_ids = range(num_frames)

    scene_mesh_v = None
    scene_mesh_f = None
    scene_mesh_n = None
    scene_mesh_T_wv = None
    if ARIA_MESH_VERTS_W in snippet:
        scene_mesh_v = snippet[ARIA_MESH_VERTS_W].squeeze().cpu().detach().float()
        scene_mesh_f = snippet[ARIA_MESH_FACES].squeeze().cpu().detach().float()
        # flip normals to visualize better.
        scene_mesh_n = -snippet[ARIA_MESH_VERT_NORMS_W].squeeze().cpu().detach().float()
        scene_mesh_T_wv = PoseTW()

    imgs = []
    for t in frame_ids:
        # transform obbs into world coordinates too
        tgt_w = obbs[t].transform(T_ws) if obbs is not None else None
        prd_w = prd[t].transform(T_ws) if prd is not None else None
        uninst_w = uninst[t].transform(T_ws) if uninst is not None else None
        img = draw_obb_scene_3d(
            tgt=tgt_w,
            prd=prd_w,
            tgt_removed=uninst_w,
            T_ws=T_ws,
            Ts_sr=Ts_sr,
            cams=cams,
            cams_slaml=cams_slaml,
            cams_slamr=cams_slamr,
            frame_id=t,
            p3s_world=p3s_world,
            p3s_pred_world=p3s_pred_world,
            depth_pred=depth_pred,
            sem_ids_to_names=sem_ids_to_names,
            width=width,
            height=height,
            draw_origin=draw_origin,
            scene=scene,
            z_height_clip=z_height_clip,
            # raw predictions are hidden only for a clean viz that has tracked obbs.
            render_raw_pred=(not clean_viz) or (not have_tracked),
            render_removed_pred=not clean_viz,
            scene_mesh_v=scene_mesh_v,
            scene_mesh_f=scene_mesh_f,
            scene_mesh_n=scene_mesh_n,
            scene_mesh_T_wv=scene_mesh_T_wv,
        )
        imgs.append(np.array(img))
    return imgs
# https://github.com/stevenlovegrove/Pangolin/blob/7776a567f5c7b074668b8abb2316aba3f4b8b568/components/pango_opengl/src/opengl_render_state.cpp#L621
def model_view_look_at_rdf(e, look_at, u):
    """
    Build a 4x4 model-view matrix for a camera at `e` looking at `look_at`.

    The rows of the rotation part are the camera x, y and z axes, where z
    points from the eye towards `look_at`, x = z cross u and y = z cross x.

    Args:
        e: eye location in world coordinates (camera position), numpy array.
        look_at: point that is looked at (projects to the image center), numpy
            array.
        u: up direction, numpy array.

    Returns:
        M: 4x4 numpy model-view matrix.
    """
    z = normalize(look_at - e)
    # The cross product with `u` vanishes when the viewing direction is
    # parallel OR anti-parallel to the up direction, which would zero out the
    # x and y rows of the matrix.
    if np.allclose(u - z, 0.0, atol=1e-5) or np.allclose(u + z, 0.0, atol=1e-5):
        # Add some tiny offset so that cross product is non-zero.
        z[1] = z[1] + 0.001
        if np.allclose(np.cross(z, u), 0.0, atol=1e-9):
            # z and u both lie along the y axis, so the y offset did not
            # change the direction; offset x instead.
            z[0] = z[0] + 0.001
    x = normalize(np.cross(z, u))
    y = normalize(np.cross(z, x))

    rot = np.stack([x, y, z])
    M = np.zeros((4, 4))
    M[:3, :3] = rot
    # translation that maps the eye location to the camera origin
    M[:3, 3] = -(rot @ np.asarray(e)[:3])
    M[3, 3] = 1.0
    return M
# https://github.com/stevenlovegrove/Pangolin/blob/7776a567f5c7b074668b8abb2316aba3f4b8b568/components/pango_opengl/src/opengl_render_state.cpp#L462
# Camera Axis:
#   X - Right, Y - Down, Z - Forward
# Image Origin:
#   Top Left
# Principal point specified with image origin (0,0) at top left of top-left pixel (not center)
def projection_matrix_rdf_top_left(w, h, fu, fv, u0, v0, zNear, zFar):
    """
    Build a 4x4 projection matrix from pinhole camera intrinsics.

    Args:
        w: image width.
        h: image height.
        fu: focal length along the image x axis.
        fv: focal length along the image y axis.
        u0: principal point x coordinate.
        v0: principal point y coordinate.
        zNear: near clipping plane.
        zFar: far clipping plane.

    Returns:
        P: 4x4 numpy projection matrix.
    """
    # http://www.songho.ca/opengl/gl_projectionmatrix.html
    # extents of the view frustum on the near plane
    left = -(u0) * zNear / fu
    right = +(w - u0) * zNear / fu
    top = -(v0) * zNear / fv
    bottom = +(h - v0) * zNear / fv

    x_scale = 2 * zNear / (right - left)
    y_scale = 2 * zNear / (top - bottom)
    x_shift = (right + left) / (left - right)
    y_shift = (top + bottom) / (bottom - top)
    depth_scale = (zFar + zNear) / (zFar - zNear)
    depth_shift = (2 * zFar * zNear) / (zNear - zFar)

    P = np.zeros((4, 4))
    P[0, 0], P[0, 2] = x_scale, x_shift
    P[1, 1], P[1, 2] = y_scale, y_shift
    P[2, 2], P[2, 3] = depth_scale, depth_shift
    P[3, 2] = 1.0
    return P
= transpose(inverse(mat3(mv))) * in_normal; }""" fragment_shader_source = """#version 330 in vec3 n_c; out vec4 f_color; void main() { f_color = vec4((normalize(n_c) + vec3(1.0, 1.0, 1.0)) / 2.0, 1.0f); } """ prog = ctx.program( vertex_shader=vertex_shader_source, fragment_shader=fragment_shader_source ) return prog def mesh_rgb_shader_program(ctx): vertex_shader_source = """#version 330 uniform mat4 mvp; uniform mat4 mv; uniform float point_size; in vec3 in_vert; in vec3 in_normal; in vec3 in_rgb; out vec3 n_c; out vec3 rgb; void main() { gl_Position = mvp * vec4(in_vert, 1.0); gl_PointSize = point_size; n_c = transpose(inverse(mat3(mv))) * in_normal; rgb = in_rgb; }""" fragment_shader_source = """#version 330 in vec3 n_c; in vec3 rgb; out vec4 f_color; void main() { f_color = vec4(rgb * max(n_c.z, 0.0), 1.0); } """ prog = ctx.program( vertex_shader=vertex_shader_source, fragment_shader=fragment_shader_source ) return prog def rgb_point_cloud_shader_program(ctx): vertex_shader_source = """#version 330 uniform mat4 mvp; uniform mat4 mv; uniform float point_size; in vec3 in_vert; in vec3 in_rgb; out vec3 rgb; void main() { gl_Position = mvp * vec4(in_vert, 1.0); gl_PointSize = point_size; rgb = in_rgb; }""" fragment_shader_source = """#version 330 in vec3 rgb; out vec4 f_color; void main() { f_color = vec4(rgb, 1.0f); } """ prog = ctx.program( vertex_shader=vertex_shader_source, fragment_shader=fragment_shader_source ) return prog def scalar_field_shader_program(ctx): vertex_shader_source = """#version 330 uniform mat4 mvp; uniform float point_size; uniform float max_value; uniform float min_value; in vec3 in_vert; in float in_value; in float in_alpha; out vec3 frag_rgb; out float frag_a; // https://thebookofshaders.com/06/ vec3 hsb2rgb( in vec3 c ){ vec3 rgb = clamp(abs(mod(c.x*6.0+vec3(0.0,4.0,2.0),6.0)-3.0)-1.0, 0.0, 1.0); rgb = rgb*rgb*(3.0-2.0*rgb); return c.z * mix( vec3(1.0), rgb, c.y); } vec3 hsv(float v) { return hsb2rgb(vec3(v, 1.0, 1.0)); } // 
https://github.com/kbinani/colormap-shaders/tree/master // The MIT License (MIT) // Copyright (c) 2015 kbinani // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // The above copyright notice and this permission notice shall be included in all // copies or substantial portions of the Software. // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. 
def semantic_color_shader_program(ctx):
    """
    Shader program coloring each vertex by its integer semantic id.

    The semantic id is mapped to a hue in [0, 1] via `sem_id / sem_max` and
    converted to RGB at full saturation and brightness.

    Uniforms: `mvp` (mat4), `sem_max` (int, the largest semantic id).
    Vertex inputs: `in_vert` (vec3), `in_sem_id` (int).

    Args:
        ctx: rendering context exposing `program(vertex_shader=..., fragment_shader=...)`.

    Returns:
        Whatever `ctx.program` returns for the two shader sources.
    """
    # Integer varyings cannot be interpolated, so GLSL requires them to be
    # declared `flat` on both sides of the vertex/fragment interface.
    vertex_shader_source = """#version 330
    uniform mat4 mvp;
    in int in_sem_id;
    in vec3 in_vert;
    flat out int sem_id;
    out vec3 v_vert;
    void main() {
        v_vert = in_vert;
        gl_Position = mvp * vec4(v_vert, 1.0);
        sem_id = in_sem_id;
    }"""
    # `sem_hue` has to be declared, and the ratio has to be computed in floating
    # point: an int / int division would truncate to 0 for every id < sem_max.
    # max(sem_max, 1) avoids a division by zero when the uniform is left unset.
    fragment_shader_source = """#version 330
    uniform int sem_max;
    flat in int sem_id;
    in vec3 v_vert;
    out vec4 f_color;

    // https://thebookofshaders.com/06/
    vec3 hsb2rgb( in vec3 c ){
        vec3 rgb = clamp(abs(mod(c.x*6.0+vec3(0.0,4.0,2.0),6.0)-3.0)-1.0, 0.0, 1.0);
        rgb = rgb*rgb*(3.0-2.0*rgb);
        return c.z * mix( vec3(1.0), rgb, c.y);
    }

    void main() {
        float sem_hue = float(sem_id) / float(max(sem_max, 1));
        f_color = vec4(hsb2rgb(vec3(sem_hue, 1.0, 1.0)), 1.0);
    }
    """
    prog = ctx.program(
        vertex_shader=vertex_shader_source, fragment_shader=fragment_shader_source
    )
    return prog
def tensor_wrap_voxel_extent(voxel_extent, B=None, device="cpu"):
    """
    Return the voxel extent as a tensor, optionally batched.

    Args:
        voxel_extent: either a tensor (returned unchanged; it is NOT moved to
            `device`) or a list/tuple of 6 numbers
            [x_min, x_max, y_min, y_max, z_min, z_max].
        B: optional batch size. For a tensor input its first dimension is
            asserted to equal B. For a list/tuple input the extent is repeated
            to shape (B, 6).
        device: device for the tensor created from a list/tuple input.

    Returns:
        The input tensor, or a new tensor of shape (6,) (B is None) or (B, 6).

    Raises:
        NotImplementedError: if `voxel_extent` is neither tensor, list nor tuple.
    """
    if isinstance(voxel_extent, torch.Tensor):
        if B is not None:
            assert voxel_extent.shape[0] == B, (
                f"expected batch size {B}, got {voxel_extent.shape}"
            )
        return voxel_extent
    if isinstance(voxel_extent, (list, tuple)):
        extent = torch.tensor(voxel_extent, device=device)
        if B is None:
            return extent
        return extent.view(1, 6).repeat(B, 1)
    # report the offending *type*; the value itself is not informative here
    raise NotImplementedError(f"type {type(voxel_extent)} not supported")


def create_voxel_grid(vW, vH, vD, voxel_extent, device="cpu"):
    """
    Given a bounding box range [x_min, x_max, y_min, y_max, z_min, z_max], and the
    number of voxels in each dimension [vW, vH, vD], return the voxel center
    positions. Note that the min and max center coordinates are not
    [x_min, y_min, z_min] and [x_max, y_max, z_max], since those are the bounding
    range and not center positions; centers are inset by half a voxel.

    Args:
        vW: the number of voxels for x-dim
        vH: the number of voxels for y-dim
        vD: the number of voxels for z-dim
        voxel_extent: the bounding box range in
            [x_min, x_max, y_min, y_max, z_min, z_max]
        device: device on which the grid is created

    Returns:
        Tensor of shape (vW, vH, vD, 3) holding the (x, y, z) center of each voxel.
    """
    x_min, x_max, y_min, y_max, z_min, z_max = voxel_extent
    # edge length of one voxel along each axis
    dW = (x_max - x_min) / vW
    dH = (y_max - y_min) / vH
    dD = (z_max - z_min) / vD
    # take the center position of each voxel
    rng_x = torch.linspace(x_min + dW / 2, x_max - dW / 2, steps=vW, device=device)
    rng_y = torch.linspace(y_min + dH / 2, y_max - dH / 2, steps=vH, device=device)
    rng_z = torch.linspace(z_min + dD / 2, z_max - dD / 2, steps=vD, device=device)
    xx, yy, zz = torch.meshgrid(rng_x, rng_y, rng_z, indexing="ij")
    vox_v = torch.stack([xx, yy, zz], dim=-1)
    return vox_v


def erode_voxel_mask(mask):
    """
    Erode a given mask by one voxel, i.e. (shown in 2D for illustration)

        0 0 0 0 0      0 0 0 0 0
        0 1 1 1 0      0 0 0 0 0
        0 1 1 1 0  ->  0 0 1 0 0
        0 1 1 1 0      0 0 0 0 0
        0 0 0 0 0      0 0 0 0 0

    A voxel stays set only if none of the voxels in its 3x3x3 neighborhood is
    unset. The convolution zero-pads the inverted mask, so the volume border
    itself does not erode the mask.

    Args:
        mask: 5D (B, 1, D, H, W) or 4D (1, D, H, W) mask. The 1x1x3x3x3 kernel
            requires the channel dimension to be 1.

    Returns:
        Boolean tensor with the same shape as `mask`.
    """
    assert mask.ndim in [4, 5], f"mask dim needs to be 4 or 5 got {mask.shape}"
    kernel = torch.ones((1, 1, 3, 3, 3), device=mask.device)
    # Count unset voxels in every 3x3x3 neighborhood; any count > 0 clears the voxel.
    unset_neighbors = torch.nn.functional.conv3d(
        1.0 - mask.float(), kernel, padding="same"
    )
    mask = (1.0 - torch.clamp(unset_neighbors, 0, 1)).bool()
    return mask
def pc_to_vox(pc_v, vW, vH, vD, voxel_extent):
    """
    Convert points given in the voxel-grid frame to continuous voxel coordinates.

    Args:
        pc_v: points with xyz in the last dimension. With a tensor
            `voxel_extent` the first dimension must be the batch dimension.
        vW, vH, vD: number of voxels along x, y and z.
        voxel_extent: either a list/tuple
            [x_min, x_max, y_min, y_max, z_min, z_max] shared by all points, or
            a tensor holding one such extent per batch element.

    Returns:
        pc_id: voxel coordinates `(pc_v - min) / voxel_size`, same shape as `pc_v`.
        valid: boolean mask (shape of `pc_v` without the last dim) that is True
            for points strictly inside the extent. NaN points are never valid.
    """
    device = pc_v.device
    if isinstance(voxel_extent, (list, tuple)):
        x_min, x_max, y_min, y_max, z_min, z_max = voxel_extent
        x, y, z = pc_v[..., 0], pc_v[..., 1], pc_v[..., 2]
        # comparisons with NaN are False, so NaN points end up invalid
        valid = (
            (x > x_min)
            & (x < x_max)
            & (y > y_min)
            & (y < y_max)
            & (z > z_min)
            & (z < z_max)
        )
        dW = (x_max - x_min) / vW
        dH = (y_max - y_min) / vH
        dD = (z_max - z_min) / vD
        dVox = torch.tensor([dW, dH, dD]).view(1, 3).to(device)
        vox_min = torch.tensor([x_min, y_min, z_min]).view(1, 3).to(device)
        pc_id = (pc_v - vox_min) / dVox
    else:
        s = pc_v.shape[:-1]
        B = s[0]
        # even entries are the minima, odd entries the maxima
        vox_min = voxel_extent[..., 0::2].view(B, 1, 3)
        vox_max = voxel_extent[..., 1::2].view(B, 1, 3)
        dim = torch.tensor([vW, vH, vD], device=voxel_extent.device).view(1, 1, 3)
        dVox = (vox_max - vox_min) / dim
        pc_v = pc_v.reshape(B, -1, 3)
        valid = torch.logical_not(pc_v.isnan().any(-1))
        valid = torch.logical_and(valid, (pc_v > vox_min).all(-1))
        valid = torch.logical_and(valid, (pc_v < vox_max).all(-1))
        pc_id = (pc_v - vox_min) / dVox
        valid = valid.view(s)
        pc_id = pc_id.view(list(s) + [3])
    return pc_id, valid


def compute_factor(size):
    """Return half of `size` as a float (scale between voxel and [-1, 1] coordinates)."""
    return 1.0 * size / 2


def convert_coordinates_to_voxel(coordinates, factor):
    """Map normalized coordinates in [-1, 1] to voxel coordinates in [0, 2 * factor]."""
    return factor * (coordinates + 1.0)


def convert_voxel_to_coordinates(coordinates, factor):
    """Map voxel coordinates in [0, 2 * factor] to normalized coordinates in [-1, 1]."""
    return (coordinates / factor) - 1.0


def _keypoint_factors(kpts, depth, height, width):
    """Per-axis (x, y, z) half-sizes, broadcastable against `kpts` (torch or numpy)."""
    xyz = [compute_factor(width), compute_factor(height), compute_factor(depth)]
    if isinstance(kpts, torch.Tensor):
        shape = [1] * (kpts.ndim - 1) + [3]
        return torch.tensor(xyz, device=kpts.device).view(shape)
    if isinstance(kpts, np.ndarray):
        shape = [1] * (kpts.ndim - 1) + [3]
        # keep the float precision of the input; integer inputs become float64
        dtype = kpts.dtype if np.issubdtype(kpts.dtype, np.floating) else np.float64
        return np.asarray(xyz, dtype=dtype).reshape(shape)
    raise TypeError("must be torch or numpy")


def normalize_keypoints(kpts, depth, height, width):
    """
    Convert voxel coordinates (x in [0, width], y in [0, height], z in
    [0, depth]) to normalized coordinates in [-1, 1].

    Args:
        kpts: torch tensor or numpy array with xyz in the last dimension.
        depth, height, width: size of the voxel grid (D, H, W).

    Returns:
        New tensor/array of normalized coordinates; the input is not modified.

    Raises:
        TypeError: if `kpts` is neither a torch tensor nor a numpy array.
    """
    factors = _keypoint_factors(kpts, depth, height, width)
    return convert_voxel_to_coordinates(kpts, factors)


def denormalize_keypoints(kpts, depth, height, width):
    """
    Convert normalized coordinates in [-1, 1] back to voxel coordinates;
    inverse of `normalize_keypoints`.

    Args:
        kpts: torch tensor or numpy array with xyz in the last dimension.
        depth, height, width: size of the voxel grid (D, H, W).

    Returns:
        New tensor/array of voxel coordinates; the input is not modified.

    Raises:
        TypeError: if `kpts` is neither a torch tensor nor a numpy array.
    """
    factors = _keypoint_factors(kpts, depth, height, width)
    return convert_coordinates_to_voxel(kpts, factors)


def in_grid(pt_vox, depth, height, width):
    """
    Return a boolean mask of the points lying between the first and last voxel
    centers along every axis, i.e. x in [0.5, width - 0.5] and likewise for
    y/height and z/depth. `pt_vox` holds voxel coordinates (xyz last).
    """
    x, y, z = pt_vox[..., 0], pt_vox[..., 1], pt_vox[..., 2]
    return (
        (x >= 0.5)
        & (x <= width - 0.5)
        & (y >= 0.5)
        & (y <= height - 0.5)
        & (z >= 0.5)
        & (z <= depth - 0.5)
    )


def sample_voxels(feat3d, pts_v, differentiable=False, interp_mode="bilinear"):
    """
    Sample voxel grid of features at pts_v locations.

    Args:
        feat3d: feature volume batches B C D H W
        pts_v: 3d points in voxel coordinates (x in [0, W], y in [0, H],
            z in [0, D]) of shape B N 3. They are normalized to the -1 to 1
            range internally.
        differentiable: set if the result needs to be differentiable wrt to
            pts_v; uses `diff_grid_sample` instead of torch's grid_sample.
        interp_mode: interpolation mode passed to torch's grid_sample. Only
            used when `differentiable` is False.

    Returns:
        samp_feats: voxel grid samples in shape B C N
        valid: boolean mask in shape B N, True where the point lies inside the
            grid (see `in_grid`). Points outside are still sampled (clamped to
            the border).
    """
    assert feat3d.ndim == 5, f"{feat3d.shape}"
    assert pts_v.ndim == 3, f"{pts_v.shape}"
    B, C, D, H, W = feat3d.shape

    valid = in_grid(pts_v, height=H, width=W, depth=D)
    # normalize_keypoints returns a new tensor, so pts_v is left untouched
    norm_samp_pts = normalize_keypoints(pts_v, height=H, width=W, depth=D)
    grid = norm_samp_pts.reshape(B, 1, 1, -1, 3)  # B 1 1 N 3
    if differentiable:
        # use differentiable implementation of 3d trilinear interpolation.
        samp_feats = diff_grid_sample(feat3d, grid, align_corners=False)
    else:
        # if we dont need differentiability wrt to sample points then we can use
        # the default implementation.
        samp_feats = torch.nn.functional.grid_sample(
            feat3d,
            grid,
            align_corners=False,
            padding_mode="border",
            mode=interp_mode,
        )
    # squeeze back down the dimensions of 1 we added for the grid to comply
    # with the grid_sample interface
    samp_feats = samp_feats.reshape(B, C, -1)
    return samp_feats, valid


def diff_grid_sample(feature_3d, pts_norm, align_corners=False):
    """
    Trilinear sampling of a feature volume that is differentiable with respect
    to the sample positions.

    Args:
        feature_3d: feature volume of shape N C iD iH iW.
        pts_norm: sample positions of shape N D H W 3 with (x, y, z) in the
            normalized -1 to 1 range. Must not contain NaNs.
        align_corners: same meaning as in torch.nn.functional.grid_sample.

    Returns:
        Sampled features of shape N C D H W. Corner indices that fall outside
        the volume are clamped to the border voxel.
    """
    N, C, iD, iH, iW = feature_3d.shape
    _, D, H, W, _ = pts_norm.shape
    assert not pts_norm.isnan().any(), "have nan values in pts_norm! not supported"

    ix = pts_norm[..., 0]
    iy = pts_norm[..., 1]
    iz = pts_norm[..., 2]
    # un-normalize to continuous voxel indices
    if align_corners:
        ix = ((ix + 1.0) * 0.5) * (iW - 1)
        iy = ((iy + 1.0) * 0.5) * (iH - 1)
        iz = ((iz + 1.0) * 0.5) * (iD - 1)
    else:
        ix = ((ix + 1.0) * 0.5) * iW - 0.5
        iy = ((iy + 1.0) * 0.5) * iH - 0.5
        iz = ((iz + 1.0) * 0.5) * iD - 0.5

    # Lower corner of the enclosing cell; carries no gradient.
    with torch.no_grad():
        x0 = torch.floor(ix)
        y0 = torch.floor(iy)
        z0 = torch.floor(iz)

    # Per-axis interpolation weights for the lower (index 0) and upper (index 1)
    # corner. These depend on ix/iy/iz and carry the gradient wrt the positions.
    # They are computed from the unclamped corners on purpose.
    wx = ((x0 + 1) - ix, ix - x0)
    wy = ((y0 + 1) - iy, iy - y0)
    wz = ((z0 + 1) - iz, iz - z0)

    flat_feat = feature_3d.reshape(N, C, -1)  # D H W, z y x

    # (dx, dy, dz) offsets of the 8 cell corners: z-upper corners first, then
    # z-lower corners.
    corners = (
        (0, 0, 1),
        (1, 0, 1),
        (0, 1, 1),
        (1, 1, 1),
        (0, 0, 0),
        (1, 0, 0),
        (0, 1, 0),
        (1, 1, 0),
    )
    out_val = None
    for dx, dy, dz in corners:
        weight = wx[dx] * wy[dy] * wz[dz]
        with torch.no_grad():
            # Clamp to the volume and linearize in int64: doing this arithmetic
            # in float32 loses exactness once the volume exceeds 2**24 voxels.
            cx = (x0 + dx).clamp(0, iW - 1).long()
            cy = (y0 + dy).clamp(0, iH - 1).long()
            cz = (z0 + dz).clamp(0, iD - 1).long()
            flat_idx = (cz * iH + cy) * iW + cx
        corner_val = torch.gather(
            flat_feat, 2, flat_idx.reshape(N, 1, D * H * W).repeat(1, C, 1)
        )
        term = corner_val.view(N, C, D, H, W) * weight.reshape(N, 1, D, H, W)
        out_val = term if out_val is None else out_val + term
    return out_val
# Root folders of the evaluation datasets, relative to the repository root.
ASE_DATA_PATH = "./data/ase_eval"
ADT_DATA_PATH = "./data/adt"
AEO_DATA_PATH = "./data/aeo"

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run EFM3D evaluation benchmark")
    parser.add_argument(
        "--num_seqs",
        type=int,
        default=9999,
        help="number of sequences to evaluate, by default evaluate all sequences",
    )
    parser.add_argument(
        "--num_snips",
        type=int,
        default=9999,
        help="number of snippets per sequence, by default evaluate the full sequence",
    )
    parser.add_argument(
        "--snip_stride",
        type=float,
        default=0.1,
        help="overlap between snippets in second, default to 0.1 (recommend to set it between 0.1-0.5), set it larger will make performance worse but run faster",
    )
    parser.add_argument(
        "--voxel_res",
        type=float,
        default=0.04,
        help="voxel resolution in meter for volumetric fusion",
    )
    parser.add_argument(
        "--model_ckpt",
        type=str,
        default="./ckpt/model_release.pth",
        help="model checkpoint path",
    )
    parser.add_argument(
        "--model_cfg",
        type=str,
        default="./efm3d/config/evl_inf.yaml",
        help="model config file",
    )
    parser.add_argument("--output_dir", type=str, default="./output", help="output dir")
    parser.add_argument(
        "--ase", action="store_true", help="Evaluate the model on ASE dataset"
    )
    parser.add_argument(
        "--adt", action="store_true", help="Evaluate the model on ADT dataset"
    )
    parser.add_argument(
        "--aeo", action="store_true", help="Evaluate the model on AEO dataset"
    )
    args = parser.parse_args()

    # Validate explicitly instead of with `assert`: asserts are stripped under
    # `python -O`, which would let the script continue with `seq_ids` undefined.
    if not (args.ase or args.adt or args.aeo):
        parser.error("Specify eval dataset, for example, --ase")

    # Exactly one dataset is evaluated per run; precedence is ASE, ADT, AEO.
    if args.ase:
        with open("./data/ase_splits.json", "r") as f:
            seq_ids = [seq.strip() for seq in json.load(f)["test_sequences"]]
        input_paths = [
            os.path.join(ASE_DATA_PATH, seq) for seq in seq_ids[: args.num_seqs]
        ]
    elif args.adt:
        with open("./data/adt_sequences.txt", "r") as f:
            seq_ids = [seq.strip() for seq in f.readlines()]
        input_paths = [
            os.path.join(ADT_DATA_PATH, seq, "video.vrs")
            for seq in seq_ids[: args.num_seqs]
        ]
    else:  # args.aeo
        with open("./data/aeo_sequences.txt", "r") as f:
            seq_ids = [seq.strip() for seq in f.readlines()]
        input_paths = [
            os.path.join(AEO_DATA_PATH, seq, "main.vrs")
            for seq in seq_ids[: args.num_seqs]
        ]

    for input_path in input_paths:
        run_one(
            input_path,
            args.model_ckpt,
            model_cfg=args.model_cfg,
            max_snip=args.num_snips,
            snip_stride=args.snip_stride,
            voxel_res=args.voxel_res,
            output_dir=args.output_dir,
        )

    # aggregate results
    if len(seq_ids) > 1:
        # per-sequence results are looked up under <output_dir>/<checkpoint name>/<seq_id>
        model_name = os.path.splitext(os.path.basename(args.model_ckpt))[0]
        output_dir = os.path.join(args.output_dir, model_name)
        metrics_paths = [
            os.path.join(output_dir, seq_id, "metrics.json") for seq_id in seq_ids
        ]
        # sequences without results (e.g. skipped through --num_seqs) are ignored
        metrics_paths = [p for p in metrics_paths if os.path.exists(p)]
        if len(metrics_paths) > 0:
            avg_ret = compute_avg_metrics(metrics_paths)
            print("==> mean results")
            print(json.dumps(avg_ret, indent=2, sort_keys=True))
            with open(os.path.join(output_dir, "mean_metrics.json"), "w") as f:
                json.dump(avg_ret, f, indent=2, sort_keys=True)

        # aggregate mAP for 3D object detection
        if args.ase or args.aeo:
            joint_map = obb_eval_dataset(output_dir)
            print("==> joint mAP")
            print(json.dumps(joint_map, indent=2, sort_keys=True))
            # NOTE(review): written to args.output_dir while mean_metrics.json goes
            # to the per-model output_dir - confirm this is intended.
            with open(
                os.path.join(args.output_dir, "joint_metrics.json"), "w"
            ) as f:
                json.dump(joint_map, f, indent=2, sort_keys=True)
set -e

# The script uses relative paths, so it has to be started from the repository root.
if [ ! -f "infer.py" ]; then
    echo "Error: Can't find infer.py under the current directory. Make sure to run this script under the EFM3D root directory"
    exit 1
fi

# Check for the manually downloaded model archive before starting the (large)
# DinoV2 download, so a missing archive fails fast.
if [ ! -f "ckpt/evl_model_ckpt.zip" ]; then
    echo "Error: File evl_model_ckpt.zip does not exist. Make sure it's put under EFM3D_DIR/ckpt"
    exit 1
fi

# Without an existing data/ directory the `mv ... ../data` below would rename
# the sample archive to a file called "data" instead of moving it into a folder.
mkdir -p data

# download DinoV2 weights
wget -O ckpt/dinov2_vitb14_reg4_pretrain.pth https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth

# model
cd ckpt
unzip evl_model_ckpt.zip
mv evl_model_ckpt/*.pth .
mv evl_model_ckpt/seq136_sample.zip ../data
rmdir evl_model_ckpt

# data
cd ../data
unzip seq136_sample.zip
rm seq136_sample.zip

echo "Done preparing for inference"
#SBATCH --job-name=efm3d_multinode #SBATCH --nodes=2 #SBATCH --ntasks-per-node=1 #SBATCH --gres=gpu:8 #SBATCH --cpus-per-task=96 export NCCL_DEBUG=INFO export PYTHONFAULTHANDLER=1 nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) ) nodes_array=($nodes) head_node=${nodes_array[0]} head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) echo Node IP: $head_node_ip export LOGLEVEL=INFO srun torchrun \ --nnodes 2 \ --nproc_per_node 8 \ --rdzv_id $RANDOM \ --rdzv_backend c10d \ --rdzv_endpoint $head_node_ip:29600 \ train.py ================================================ FILE: train.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
""" # train with a single gpu python train.py # train with 8 gpus torchrun --standalone --nproc_per_node=8 train.py # train with multi-node multi-gpu, run sbatch sbatch_run.sh """ import math import os import random import shutil import time from datetime import datetime import hydra import omegaconf import torch import torch.distributed as dist import tqdm import webdataset as wds import yaml from efm3d.aria.tensor_wrapper import custom_collate_fn from efm3d.dataset.augmentation import ColorJitter, PointDropSimple, PointJitter from efm3d.dataset.efm_model_adaptor import load_atek_wds_dataset_as_efm_train from efm3d.dataset.vrs_dataset import preprocess from efm3d.dataset.wds_dataset import get_tar_sample_num from torch.distributed import destroy_process_group, init_process_group from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter DATA_PATH = "./data/ase_train" MAX_LR = 2e-4 MIN_LR = MAX_LR * 0.1 BATCH_SIZE = 2 MAX_EPOCHS = 40 MAX_SAMPLES_PER_EPOCH = 100000 SAVE_EVERY_EPOCHS = 5 # save the model every LOG_STEP = 5 # print error every def get_lr(it, warmup_its, max_its, max_lr, min_lr): """ cosine learning rate scheduler, `it` can be either step or epoch. 
""" # learning rate scheduler # linear warmup for warmup_epochs if it < warmup_its: return max_lr * (it + 1) / warmup_its # return min_lr if epoch > max_epochs if it > max_its: return min_lr # cosine annealing decay_ratio = (it - warmup_its) / (max_its - warmup_its) assert 0 <= decay_ratio <= 1 coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # 1.0 -> 0.0 return min_lr + coeff * (max_lr - min_lr) def get_dataloader( data_path, batch_size, world_size, max_samples_per_epoch, epoch_sample_ratio=1.0, tar_yaml="train_tars.yaml", ): assert epoch_sample_ratio > 0 and epoch_sample_ratio <= 1.0, ( f"{epoch_sample_ratio} is the ratio ([0, 1]) of samples used in each epoch" ) tar_yaml = os.path.join(data_path, tar_yaml) with open(tar_yaml, "r") as f: tar_list = yaml.safe_load(f)["tars"] tar_list = [os.path.join(data_path, tar_name) for tar_name in tar_list] # check existence for tar in tar_list: assert os.path.exists(tar), f"{tar} not exists" random.shuffle(tar_list) dataset = load_atek_wds_dataset_as_efm_train( urls=tar_list, atek_to_efm_taxonomy_mapping_file=f"{os.path.dirname(__file__)}/efm3d/config/taxonomy/atek_to_efm.csv", batch_size=batch_size, collation_fn=custom_collate_fn, ) samples_per_tar = get_tar_sample_num(tar_list[0]) dataset_size = len(tar_list) * samples_per_tar dataset_size = min(dataset_size, max_samples_per_epoch) dataset_size = int(dataset_size * epoch_sample_ratio) batches_per_epoch = int(dataset_size // (batch_size * world_size)) dataloader = wds.WebLoader( dataset, num_workers=batch_size, pin_memory=True, prefetch_factor=2, batch_size=None, shuffle=False, ) dataloader = dataloader.with_epoch(batches_per_epoch) dataloader = dataloader.with_length(batches_per_epoch) return dataloader ddp = int(os.environ.get("RANK", -1)) != -1 if ddp: assert torch.cuda.is_available() init_process_group("nccl") DDP_RANK = int(os.environ["RANK"]) DDP_LOCAL_RANK = int(os.environ["LOCAL_RANK"]) DDP_WORLD_SIZE = int(os.environ["WORLD_SIZE"]) device = 
f"cuda:{DDP_LOCAL_RANK}" print(f"==> setting device to {device}") torch.cuda.set_device(device) master_process = DDP_RANK == 0 else: DDP_RANK = 0 DDP_LOCAL_RANK = 0 DDP_WORLD_SIZE = 1 master_process = True device = "cuda" if torch.cuda.is_available() else "cpu" torch.manual_seed(42) if torch.cuda.is_available(): torch.cuda.manual_seed(42) device = "cuda" if torch.cuda.is_available() else "cpu" model_config = omegaconf.OmegaConf.load("efm3d/config/evl_train.yaml") model = hydra.utils.instantiate(model_config) model = model model.to(device) if ddp: model = DDP(model, device_ids=[DDP_LOCAL_RANK]) raw_model = model.module if ddp else model train_dataloader = get_dataloader( DATA_PATH, BATCH_SIZE, DDP_WORLD_SIZE, max_samples_per_epoch=MAX_SAMPLES_PER_EPOCH, tar_yaml="train_tars.yaml", ) val_dataloader = get_dataloader( DATA_PATH, BATCH_SIZE, DDP_WORLD_SIZE, max_samples_per_epoch=MAX_SAMPLES_PER_EPOCH, tar_yaml="val_tars.yaml", ) optimizer = torch.optim.AdamW(model.parameters(), lr=MAX_LR) if master_process: exp_name = f"efm3d_train_b{BATCH_SIZE}g{DDP_WORLD_SIZE}e{MAX_EPOCHS}lr{str(MAX_LR)}_{datetime.fromtimestamp(time.time()).strftime('%y-%m-%d-%H-%M-%S')}" log_dir = os.path.join("tb_logs", exp_name) writer = SummaryWriter(log_dir=log_dir) color_jitter = ColorJitter( brightness=0.5, contrast=0.3, saturation=0.3, hue=0.05, sharpness=2.0, snippet_jitter=True, ) point_drop = PointDropSimple(max_dropout_rate=0.8) point_jitter = PointJitter(depth_std_scale_min=1.0, depth_std_scale_max=6.0) augmentations = [color_jitter, point_drop, point_jitter] step = 0 val_step = 0 # main loop for epoch in range(MAX_EPOCHS): # train model.train() for batch in tqdm.tqdm(train_dataloader): start = time.time() optimizer.zero_grad() batch = preprocess(batch, device, aug_funcs=augmentations) output = model(batch) losses, total_loss = raw_model.compute_losses(output, batch) total_loss.backward() # epoch-based lr scheduler lr = get_lr( epoch, warmup_its=5, max_its=MAX_EPOCHS, max_lr=MAX_LR, 
min_lr=MIN_LR ) for param_group in optimizer.param_groups: param_group["lr"] = lr if ddp: dist.all_reduce(total_loss, op=dist.ReduceOp.AVG) max_norm = 1.0 norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) optimizer.step() time_per_it = time.time() - start if master_process and step % LOG_STEP == 0: print( f"E:s-{epoch}:{step} | loss {total_loss.item():.03f} | lr {lr:.06f} | norm {norm} | time {time_per_it:.02f}s/it" ) # log training writer.add_scalar("train/loss", total_loss.item(), step) for stream in losses: for loss_name in losses[stream]: writer.add_scalar( f"train/loss/{stream}/{loss_name}", losses[stream][loss_name].item(), step, ) writer.add_scalar("train/lr", lr, step) writer.add_scalar("train/iter_sec", time_per_it, step) # log images (log every `10xlog_step` since writing video is slow) if step % (10 * LOG_STEP) == 0: imgs = raw_model.log_single(batch, output, batch_idx=0) for k, v in imgs.items(): vid = torch.tensor(v.transpose((0, 3, 1, 2))).unsqueeze(0) writer.add_video(f"train/{k}", vid, global_step=step, fps=10) step += 1 # val model.eval() for batch in tqdm.tqdm(val_dataloader): with torch.no_grad(): start = time.time() batch = preprocess(batch, device, aug_funcs=augmentations) output = model(batch) losses, total_loss = raw_model.compute_losses(output, batch) if ddp: dist.all_reduce(total_loss, op=dist.ReduceOp.AVG) time_per_it = time.time() - start if master_process and val_step % LOG_STEP == 0: print( f"E:s-{epoch}:{val_step} | loss {total_loss.item():.03f} | time {time_per_it:.02f}s/it" ) # log val if val_step % LOG_STEP == 0: writer.add_scalar("val/loss", total_loss.item(), val_step) for stream in losses: for loss_name in losses[stream]: writer.add_scalar( f"val/loss/{stream}/{loss_name}", losses[stream][loss_name].item(), val_step, ) writer.add_scalar("val/iter_sec", time_per_it, val_step) # log images if val_step % (10 * LOG_STEP) == 0: imgs = raw_model.log_single(batch, output, batch_idx=0) for k, v in imgs.items(): vid = 
torch.tensor(v.transpose((0, 3, 1, 2))).unsqueeze(0) writer.add_video(f"val/{k}", vid, global_step=val_step, fps=10) val_step += 1 # save model if master_process and (epoch + 1) % SAVE_EVERY_EPOCHS == 0: ckpt_path = os.path.join( log_dir, f"model_e{epoch}s{step}_l{total_loss.item():.02f}.pth" ) last_ckpt_path = os.path.join(log_dir, "last.pth") torch.save( {"state_dict": raw_model.state_dict(), "optimizer": optimizer.state_dict()}, ckpt_path, ) shutil.copy(ckpt_path, last_ckpt_path) if master_process: writer.close() if ddp: destroy_process_group()