[
  {
    "path": ".github/CODE_OF_CONDUCT.md",
    "content": "# Code of Conduct\n\n## Our Pledge\n\nIn the interest of fostering an open and welcoming environment, we as\ncontributors and maintainers pledge to make participation in our project and\nour community a harassment-free experience for everyone, regardless of age, body\nsize, disability, ethnicity, sex characteristics, gender identity and expression,\nlevel of experience, education, socio-economic status, nationality, personal\nappearance, race, religion, or sexual identity and orientation.\n\n## Our Standards\n\nExamples of behavior that contributes to creating a positive environment\ninclude:\n\n* Using welcoming and inclusive language\n* Being respectful of differing viewpoints and experiences\n* Gracefully accepting constructive criticism\n* Focusing on what is best for the community\n* Showing empathy towards other community members\n\nExamples of unacceptable behavior by participants include:\n\n* The use of sexualized language or imagery and unwelcome sexual attention or\n  advances\n* Trolling, insulting/derogatory comments, and personal or political attacks\n* Public or private harassment\n* Publishing others' private information, such as a physical or electronic\n  address, without explicit permission\n* Other conduct which could reasonably be considered inappropriate in a\n  professional setting\n\n## Our Responsibilities\n\nProject maintainers are responsible for clarifying the standards of acceptable\nbehavior and are expected to take appropriate and fair corrective action in\nresponse to any instances of unacceptable behavior.\n\nProject maintainers have the right and responsibility to remove, edit, or\nreject comments, commits, code, wiki edits, issues, and other contributions\nthat are not aligned to this Code of Conduct, or to ban temporarily or\npermanently any contributor for other behaviors that they deem inappropriate,\nthreatening, offensive, or harmful.\n\n## Scope\n\nThis Code of Conduct applies within all project spaces, and 
it also applies when\nan individual is representing the project or its community in public spaces.\nExamples of representing a project or community include using an official\nproject e-mail address, posting via an official social media account, or acting\nas an appointed representative at an online or offline event. Representation of\na project may be further defined and clarified by project maintainers.\n\nThis Code of Conduct also applies outside the project spaces when there is a\nreasonable belief that an individual's behavior may have a negative impact on\nthe project or its community.\n\n## Enforcement\n\nInstances of abusive, harassing, or otherwise unacceptable behavior may be\nreported by contacting the project team at <opensource-conduct@meta.com>. All\ncomplaints will be reviewed and investigated and will result in a response that\nis deemed necessary and appropriate to the circumstances. The project team is\nobligated to maintain confidentiality with regard to the reporter of an incident.\nFurther details of specific enforcement policies may be posted separately.\n\nProject maintainers who do not follow or enforce the Code of Conduct in good\nfaith may face temporary or permanent repercussions as determined by other\nmembers of the project's leadership.\n\n## Attribution\n\nThis Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,\navailable at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html\n\n[homepage]: https://www.contributor-covenant.org\n\nFor answers to common questions about this code of conduct, see\nhttps://www.contributor-covenant.org/faq\n"
  },
  {
    "path": ".github/CONTRIBUTING.md",
    "content": "# Contributing to \"efm3d\"\n\nWe want to make contributing to this project as easy and transparent as\npossible.\n\n## Pull Requests\n\nWe welcome pull requests.\n\n1. Fork the repo and create your branch from `main`.\n2. If you've added code that should be tested, add tests.\n3. If you've changed APIs, update the documentation in the code.\n4. Ensure the test suite passes.\n5. If you haven't already, complete the Contributor License Agreement (\"CLA\").\n\n## Contributor License Agreement (\"CLA\")\n\nIn order to accept your pull request, we need you to submit a CLA. You only need\nto do this once to work on any of Facebook's open source projects.\n\nComplete your CLA here: <https://code.facebook.com/cla>\n\n## Issues\n\nWe use GitHub issues to track public bugs. Please ensure your description is\nclear and has sufficient instructions to be able to reproduce the issue.\n\nFacebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe\ndisclosure of security bugs. In those cases, please go through the process\noutlined on that page and do not file a public issue.\n\n## License\n\nBy contributing to \"efm3d\", you agree that your contributions will be licensed under\nthe [LICENSE](../LICENSE) file in the root directory of this source tree.\n"
  },
  {
    "path": ".github/workflows/conda-env.yaml",
    "content": "name: Conda Environment CI\n\non:\n  push:\n    branches:\n      - main\n  pull_request:\n    branches:\n      - main\n\njobs:\n  test:\n    name: Test conda env\n    runs-on: \"ubuntu-latest\"\n    defaults:\n      run:\n        shell: bash -el {0}\n    steps:\n      - uses: actions/checkout@v4\n      - uses: conda-incubator/setup-miniconda@v3\n        with:\n          activate-environment: efm3d\n          environment-file: environment-mac.yml\n          python-version: 3.9\n          auto-activate-base: false\n      - run: |\n          conda info\n          conda list\n          conda activate efm3d\n          pip install -r requirements.txt\n"
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\ndist/\nbuild/\neggs/\n.eggs/\n*.egg-info/\nlib/\nlib64/\n\n# PyTorch specific\n*.pt\n*.pth\n*.ckpt\n*.tfevents\n.ipynb_checkpoints/\n\n# Environment\n.env\nvenv/\nENV/\n\n# IDEs\n.vscode/\n.idea/\n\n# Miscellaneous\n.DS_Store\nThumbs.db\n\n# artifacts\n*.mp4\n\n# data\n*.ply\ndata/\ntb_logs/\n\n# model weights\nckpt/\n\n# output dir\n*.out\noutput/\n# tensoboard output\nruns/\n"
  },
  {
    "path": "INSTALL.md",
    "content": "# Installation\n\nWe provide two ways to install the dependencies of EFM3D. We recommend using miniconda to manage the dependencies, which\nalso provide a easy setup to for all the additional dependencies listed in `requirements.txt` and `requirements-extra.txt`.\n\n## Install using conda (recommended)\n\nFirst install [miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install),\nthen run the following commands under the `<EFM3D_DIR>` root directory\n\n```\nconda env create --file=environment.yml\nconda activate efm3d\n\ncd efm3d/thirdparty/mmdetection3d/cuda/\npython setup.py install\n```\n\nThe commands will first create a conda environment named `efm3d`, and then build the\nthird-party CUDA kernel required for training.\n\n## Install via pip\n\nMake sure you have\nPython>=3.9, then install the dependencies using `pip`.\nThe packages in `requirements.txt` are needed for the basic functionalities of\nEFM3D, such as running the example model inference to see 3D object detection\nand surface reconstruction on a [vrs](https://facebookresearch.github.io/vrs/)\nsequence.\n\n```\npip install -r requirements.txt\n```\n\nAdditional dependencies in `requirements-extra.txt` are needed for training and eval.\n\n```\npip install -r requirements-extra.txt\n```\n\n**Important**: For training, we also need to built a CUDA kernel from\n[mmdetection3d](https://github.com/open-mmlab/mmdetection3d). Compile the CUDA\nkernel of the IoU3d loss by running the following commands, which requires the\ninstallation of\n[CUDA dev toolkit](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/).\n\n```\ncd efm3d/thirdparty/mmdetection3d/cuda/\npython setup.py install\n```\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "# EFM3D: A Benchmark for Measuring Progress Towards 3D Egocentric Foundation Models\n\n[[paper](https://arxiv.org/abs/2406.10224)]\n[[website](https://www.projectaria.com/research/efm3D/)]\n\n## Intro\n\nThis is the official release for the paper EFM3D: A Benchmark for Measuring\nProgress Towards 3D Egocentric Foundation Models\n(https://arxiv.org/abs/2406.10224). To measure progress on what we term\nEgocentric Foundation Models (EFMs) we establish EFM3D, a benchmark with two\ncore 3D egocentric perception tasks. EFM3D is the first benchmark for 3D object\ndetection and surface regression on high quality annotated egocentric data of\n[Project Aria](https://www.projectaria.com/). We also propose Egocentric Voxel\nLifting (EVL), a baseline for 3D EFMs.\n\n<img src=\"assets/efm3d.png\">\n\nWe provide the following code and assets\n\n- The pretrained EVL model weights for surface reconstruction and 3D object\n  detection on Aria sequences\n- The datasets included in the EFM3D benchmark, including the training and\n  evaluation data for Aria Synthetic Datasets (ASE), Aria Everyday Objects (AEO)\n  for 3D object detection, and the eval mesh models for surface reconstruction\n  evaluation.\n- Distributed training code to train EVL.\n- Native integration with\n  [Aria Training and Evaluation Kit (ATEK)](https://github.com/facebookresearch/atek).\n\nThe following serves as a minimal example to run the model inference, including\ninstallation guide, data downloading instructions and how to run the inference\ncode.\n\n## Installation\n\n**Option 1**: First navigate to the root folder. The core library is written in\nPyTorch, with additional dependencies listed in `requirements.txt`. 
This needs\nPython>=3.9.\n\n```\npip install -r requirements.txt\n```\n\n**Option 2**: You can choose to use conda to manage the dependencies.\nWe recommend using [miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install) for its fast dependency solver.\nThe runtime dependencies can be installed by running the following (replace `environment.yml` with `environment-mac.yml` if running on macOS):\n\n```\nconda env create --file=environment.yml\nconda activate efm3d\n```\n\nThis should be sufficient to run EVL model inference with\nthe pretrained model weights. Please refer to [INSTALL.md](INSTALL.md) for a\nfull installation, which is required for training and eval.\n\n## Inference\n\n### Pretrained models\n\nDownload the pretrained model weights and sample data on the\n[EFM3D](https://www.projectaria.com/research/efm3D/#download-dataset) page\n(email required). We provide two model checkpoints, one for server-side GPU\n(>20GB GPU memory) and one for desktop GPU. There is a sample sequence attached\nto the model weights to facilitate using the model. Check out the\n[README.md](ckpt/README.md) for detailed instructions on how to download the\nmodel weights.\n\n### Run on the sample data\n\nAfter downloading the model weights `evl_model_ckpt.zip`, put it under\n`${EFM3D_DIR}/ckpt/`, then run the command under `${EFM3D_DIR}`\n\n```\nsh prepare_inference.sh\n```\n\nThis will unzip the file and make sure the model weights and sample data are put\nunder the right paths. To run inference on the sample sequence\n\n```\npython infer.py --input ./data/seq136_sample/video.vrs\n```\n\n**Note**: the pretrained model requires ~20GB GPU memory. Use the following\ncommand to run the model on a desktop GPU with ~10GB memory (tested on\nRTX-3080). 
The performance is slightly degraded.\n\n```\npython infer.py --input ./data/seq136_sample/video.vrs --model_ckpt ./ckpt/model_lite.pth --model_cfg ./efm3d/config/evl_inf_desktop.yaml --voxel_res 0.08\n```\n\n### Run on macOS\n\nThe inference demo works on macOS too. Use the following command (tested on\nApple M1 MAX 64GB memory)\n\n```\nPYTORCH_ENABLE_MPS_FALLBACK=1 python infer.py --input ./data/seq136_sample/video.vrs --model_ckpt ./ckpt/model_lite.pth --model_cfg ./efm3d/config/evl_inf_desktop.yaml --voxel_res 0.08\n```\n\nThis wraps up the basic usage of the EVL model. To train the model from scratch and\nuse the EFM3D benchmark, complete a full installation following\n[INSTALL.md](INSTALL.md), then read below.\n\n### Inference with ATEK\n\nThe inference also supports taking\n[ATEK-format](https://github.com/facebookresearch/atek) WDS sequences. First\ndownload a test ASE sequence following the `ASE eval data` section in\n[README.md](data/README.md), then run\n\n```\npython infer.py --input ./data/ase_eval/81022\n```\n\n## Datasets\n\nSee [README.md](data/README.md) for instructions to work with all datasets\nincluded in the EFM3D benchmark. There are three datasets in the EFM3D benchmark:\n\n- [Aria Synthetic Environments (ASE)](https://www.projectaria.com/datasets/ase/):\n  for training and eval on 3D object detection and surface reconstruction\n- [Aria Digital Twin (ADT)](https://www.projectaria.com/datasets/adt/): for eval\n  on surface reconstruction\n- [Aria Everyday Objects (AEO)](https://www.projectaria.com/datasets/aeo/): for\n  eval on 3D object detection.\n\n## Train EVL\n\nFirst make sure you have a full installation (see [INSTALL.md](INSTALL.md)).\nTraining the EVL model from scratch requires downloading the full ASE training data.\nYou can download a small subset of ASE sequences (>10 sequences) to test the\ntraining script. Check out the `ASE training data` section in\n[data/README.md](data/README.md). 
After following the instructions to prepare\nthe data, run the following command.\n\n- train the EVL model from scratch on a single GPU\n\n```\npython train.py\n```\n\n- train with 8 GPUs\n\n```\ntorchrun --standalone --nproc_per_node=8 train.py\n```\n\nWe also provide a script to train on multi-node multi-gpu environment via\n[slurm](https://slurm.schedmd.com/documentation.html). The pretrained model is\ntrained on 2 nodes with 8xH100.\n\n- train with multi-node multi-gpu using slurm\n\n```\nsbatch sbatch_run.sh\n```\n\nBy default the tensorboard log is saved to `${EFM3D_DIR}/tb_logs`.\n\n## EFM3D benchmark\n\nPlease see [benchmark.md](benchmark.md) for details.\n\n## Citing EFM3D\n\nIf you find EFM3D useful, please consider citing\n\n```\n@article{straub2024efm3d,\n  title={EFM3D: A Benchmark for Measuring Progress Towards 3D Egocentric Foundation Models},\n  author={Straub, Julian and DeTone, Daniel and Shen, Tianwei and Yang, Nan and Sweeney, Chris and Newcombe, Richard},\n  journal={arXiv preprint arXiv:2406.10224},\n  year={2024}\n}\n```\n\nIf you use Aria Digital Twin (ADT) dataset in the EFM3D benchmark, please\nconsider citing\n\n```\n@inproceedings{pan2023aria,\n  title={Aria digital twin: A new benchmark dataset for egocentric 3d machine perception},\n  author={Pan, Xiaqing and Charron, Nicholas and Yang, Yongqian and Peters, Scott and Whelan, Thomas and Kong, Chen and Parkhi, Omkar and Newcombe, Richard and Ren, Yuheng Carl},\n  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},\n  pages={20133--20143},\n  year={2023}\n}\n```\n\nIf you use the Aria Synthetic Environments (ASE) dataset in the EFM3D benchmark,\nplease consider citing\n\n```\n@article{avetisyan2024scenescript,\n  title={SceneScript: Reconstructing Scenes With An Autoregressive Structured Language Model},\n  author={Avetisyan, Armen and Xie, Christopher and Howard-Jenkins, Henry and Yang, Tsun-Yi and Aroudj, Samir and Patra, Suvam and Zhang, Fuyang and 
Frost, Duncan and Holland, Luke and Orme, Campbell and others},\n  journal={arXiv preprint arXiv:2403.13064},\n  year={2024}\n}\n```\n\n## How to Contribute\n\nWe welcome contributions! Go to [CONTRIBUTING](./.github/CONTRIBUTING.md) and\nour [CODE OF CONDUCT](./.github/CODE_OF_CONDUCT.md) for how to get started.\n\n## License\n\nEFM3D is released by Meta under the [Apache 2.0 license](LICENSE).\n"
  },
  {
    "path": "benchmark.md",
    "content": "## EFM3D Benchmark\n\nWe provide three evaluation datasets for the EFM3D benchmarks. For more details on the benchmark see the [EFM3D](https://arxiv.org/abs/2406.10224) paper.\n\n### ASE - 3D object detection and mesh reconstruction\nAria Synthetic Environments (ASE) is a synthetic dataset, created from procedurally-generated interior layouts filled with 3D objects, simulated with the sensor characteristics of Aria glasses. We use ASE for both surface reconstruction and 3D object detection evaluation.\n\nFirst follow instructions in the dataset [README.md](data/README.md) to download the ASE eval set and eval meshes, then run the following\n\n```\npython eval.py --ase\n```\n\nRunning the full evaluation on 100 eval sequences takes a long time (>10 hrs on a single GPU). To see the eval results on a reduced set, use `--num_seqs` and `--num_snips` to specify number of sequences and number of snippets per sequence to speed up evaluation, for example\n\n```\n# run eval on the first 10 sequences of ASE, each running for 100 snippets (10s)\npython eval.py --ase --num_seqs 10 --num_snips 100\n```\n\n### ADT - mesh reconstruction\nADT is the benchmark data for surface reconstruction, containing 6 sequences.\nDownload the ADT data and mesh files following the data instruction. Then run\n\n```\npython eval.py --adt\n```\n\nThe provided script provides an end-to-end solution to run EVL model with the default checkpoint,\nfinding the right GT mesh path for ASE and ADT dataset, then run the evaluation metrics for mesh-to-mesh distance.\nIf you have your own model that generates a ply file, check [eval_mesh_to_mesh](efm3d/utils/mesh_utils.py) for how to evaluate against surface GT directly.\n\n### AEO - 3D object detection\nAEO is the benchmark data for 3D object detection, with 25 sequences.\nDownload the AEO dataset following the data instruction. 
Then run\n\n```\npython eval.py --aeo\n```\n\nThis will run the EVL model inference using the default model checkpoint path.\nIf you have your own model for inference, check [eval.py](efm3d/inference/eval.py) for how to evaluate against 3D object GT directly.\n"
  },
  {
    "path": "efm3d/__init__.py",
    "content": ""
  },
  {
    "path": "efm3d/aria/__init__.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .camera import CameraTW, DEFAULT_CAM_DATA_SIZE\nfrom .obb import ObbTW, transform_obbs\nfrom .pose import PoseTW\nfrom .tensor_wrapper import smart_cat, smart_stack, TensorWrapper\n"
  },
  {
    "path": "efm3d/aria/aria_constants.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# High level organization of the constants:\n# - */time_ns is timestamp with respect to the aria clock in nanoseconds stored as torch.long()\n# - */snippet_time_s is the timestamp with respect to the start of the snippet in seconds stored as torch.float32()\n# - */t_A_B is a pose transformation from coordinate system B to A\n# - path-like key strings designate hierarchical relationships of data. I.e.\n#   rgb/img/... is all data relating to the rgb image information. rgb/calib/...\n#   is all about the calibration data. And all rgb/... 
is data relating to the\n#   rgb video stream.\n\n# ---------------------------------------------------------------------\n# sequence level information\n# ---------------------------------------------------------------------\nARIA_SEQ_ID = \"sequence/id\"\n# start of the sequence in ns relative to global Aria timestamp\nARIA_SEQ_TIME_NS = \"sequence/time_ns\"\n\n# ---------------------------------------------------------------------\n# snippet level information\n# ---------------------------------------------------------------------\nARIA_SNIPPET_ID = \"snippet/id_in_sequence\"\nARIA_SNIPPET_LENGTH_S = \"snippet/length_s\"\n# start of sequence in ns relative to global Aria timestamp (sometimes unix 0)\nARIA_SNIPPET_TIME_NS = \"snippet/time_ns\"\n# offset of snippet coordinate system to sequence coordinate system\nARIA_SNIPPET_T_WORLD_SNIPPET = \"snippet/t_world_snippet\"\n# Ratio of where in the snippet is the origin of cosy relative to the\n# snippet length. E.g. 0.5 for a 10 sec snippet would mean that 5 sec is origin,\n# was previously known as \"frame_selection\" in LocalCosyPreprocessor.\nARIA_SNIPPET_ORIGIN_RATIO = \"snippet/origin_ratio\"\n\n# ---------------------------------------------------------------------\n# streamer playback time information\n# ---------------------------------------------------------------------\nARIA_PLAY_TIME_NS = \"play/time_ns\"\nARIA_PLAY_SEQUENCE_TIME_S = \"play/sequence_time_s\"\nARIA_PLAY_SNIPPET_TIME_S = \"play/snippet_time_s\"\nARIA_PLAY_FREQUENCY_HZ = \"play/hz\"\n\n# ---------------------------------------------------------------------\n# aria video stream information\n# ---------------------------------------------------------------------\n# frame id in the sequence\nARIA_FRAME_ID = [\n    \"rgb/frame_id_in_sequence\",\n    \"slaml/frame_id_in_sequence\",\n    \"slamr/frame_id_in_sequence\",\n]\n# timestamp within snippet\nARIA_IMG_SNIPPET_TIME_S = [\n    \"rgb/img/snippet_time_s\",\n    \"slaml/img/snippet_time_s\",\n  
  \"slamr/img/snippet_time_s\",\n]\n# timestamp within sequence\nARIA_IMG_TIME_NS = [\n    \"rgb/img/time_ns\",\n    \"slaml/img/time_ns\",\n    \"slamr/img/time_ns\",\n]\n# poses of the rig at the time of the respective frame capture\n# T x 12\nARIA_IMG_T_SNIPPET_RIG = [\n    \"rgb/t_snippet_rig\",\n    \"slaml/t_snippet_rig\",\n    \"slamr/t_snippet_rig\",\n]\n# image tensors\nARIA_IMG = [\"rgb/img\", \"slaml/img\", \"slamr/img\"]\nARIA_IMG_FREQUENCY_HZ = [\n    \"rgb/img/hz\",\n    \"slaml/img/hz\",\n    \"slamr/img/hz\",\n]\n\n# ---------------------------------------------------------------------\n# calibration information\n# ---------------------------------------------------------------------\nARIA_CALIB = [\n    \"rgb/calib\",\n    \"slaml/calib\",\n    \"slamr/calib\",\n]\n# timestamp within the snippet\nARIA_CALIB_SNIPPET_TIME_S = [\n    \"rgb/calib/snippet_time_s\",\n    \"slaml/calib/snippet_time_s\",\n    \"slamr/calib/snippet_time_s\",\n]\n# timestamp within the sequence\nARIA_CALIB_TIME_NS = [\n    \"rgb/calib/time_ns\",\n    \"slaml/calib/time_ns\",\n    \"slamr/calib/time_ns\",\n]\n\n# ---------------------------------------------------------------------\n# pose information\n# ---------------------------------------------------------------------\n# pose timestamp within snippet\nARIA_POSE_SNIPPET_TIME_S = \"pose/snippet_time_s\"\n# pose timestamp within sequence\nARIA_POSE_TIME_NS = \"pose/time_ns\"\n# transformation from rig to snippet coordinate system\nARIA_POSE_T_SNIPPET_RIG = \"pose/t_snippet_rig\"\n# transformation from rig to world coordinate system\nARIA_POSE_T_WORLD_RIG = \"pose/t_world_rig\"\n# frequency of poses\nARIA_POSE_FREQUENCY_HZ = \"pose/hz\"\n\n# ---------------------------------------------------------------------\n# semidense points information\n# ---------------------------------------------------------------------\nARIA_POINTS_WORLD = \"points/p3s_world\"\nARIA_POINTS_TIME_NS = \"points/time_ns\"\nARIA_POINTS_SNIPPET_TIME_S = 
\"points/snippet_time_s\"\nARIA_POINTS_FREQUENCY_HZ = \"points/hz\"\nARIA_POINTS_INV_DIST_STD = \"points/inv_dist_std\"\nARIA_POINTS_DIST_STD = \"points/dist_std\"\n\n# ---------------------------------------------------------------------\n# imu information\n# ---------------------------------------------------------------------\nARIA_IMU = [\"imur\", \"imul\"]\nARIA_IMU_CHANNELS = [\n    [\"imur/lin_acc_ms2\", \"imur/rot_vel_rads\"],\n    [\"imul/lin_acc_ms2\", \"imul/rot_vel_rads\"],\n]\nARIA_IMU_SNIPPET_TIME_S = [\"imur/snippet_time_s\", \"imul/snippet_time_s\"]\nARIA_IMU_TIME_NS = [\"imur/time_ns\", \"imul/time_ns\"]\nARIA_IMU_FACTORY_CALIB = [\"imur/factory_calib\", \"imul/factory_calib\"]\nARIA_IMU_FREQUENCY_HZ = [\"imur/hz\", \"imul/hz\"]\n\n# ---------------------------------------------------------------------\n# audio data\n# ---------------------------------------------------------------------\nARIA_AUDIO = \"audio\"\n# snippet time within snippet of audio sample\nARIA_AUDIO_SNIPPET_TIME_S = \"audio/snippet_time_s\"\n# timestamp of audio sample in sequence\nARIA_AUDIO_TIME_NS = \"audio/time_ns\"\n# frequency of audio sample\nARIA_AUDIO_FREQUENCY_HZ = \"audio/hz\"\n\n# ---------------------------------------------------------------------\n# OBB\n# ---------------------------------------------------------------------\n# padded ObbTW tensor for oriented object bounding boxes given in *snippet coordinate system*\nARIA_OBB_PADDED = \"obbs/padded_snippet\"\n# mapping of semantic id of the obb to a string name\nARIA_OBB_SEM_ID_TO_NAME = \"obbs/sem_id_to_name\"\n# snippet time within the sequence\nARIA_OBB_SNIPPET_TIME_S = \"obbs/snippet_time_s\"\n# timestamp within the sequence\nARIA_OBB_TIME_NS = \"obbs/time_ns\"\n# frequency of object detection information\nARIA_OBB_FREQUENCY_HZ = \"obbs/hz\"\n\n# predicted ObbTW tensor for oriented object bounding boxes\nARIA_OBB_PRED = \"obbs/pred\"  # raw predictions from the networks.\nARIA_OBB_PRED_VIZ = 
\"obbs/pred_viz\"  # predictions for visualization (e.g. raw predictions filtered by some criteria.)\nARIA_OBB_PRED_SEM_ID_TO_NAME = \"obbs/pred/sem_id_to_name\"\nARIA_OBB_PRED_PROBS_FULL = \"obbs/pred/probs_full\"\nARIA_OBB_PRED_PROBS_FULL_VIZ = \"obbs/pred/probs_ful_viz\"\n# tracked ObbTW tensor for oriented object bounding boxes\nARIA_OBB_TRACKED = \"obbs/tracked\"\nARIA_OBB_TRACKED_PROBS_FULL = \"obbs/tracked/probs_full\"\n# tracked but not instantiated ObbTW tensor for oriented object bounding boxes\nARIA_OBB_UNINST = \"obbs/uninst\"\n\nARIA_OBB_BB2 = [\"bb2s_rgb\", \"bb2s_slaml\", \"bb2s_slamr\"]\nARIA_OBB_BB3 = \"bb3s_object\"\n\n# ---------------------------------------------------------------------\n# depth information\n# ---------------------------------------------------------------------\n# for depth images (z-depth) in meters\nARIA_DEPTH_M = [\"rgb/depth_m\", \"slaml/depth_m\", \"slamr/depth_m\"]\n# for distance images (distance along ray) in meters\nARIA_DISTANCE_M = [\"rgb/distance_m\", \"slaml/distance_m\", \"slamr/distance_m\"]\nARIA_DEPTH_TIME_NS = [\n    \"rgb/depth/time_ns\",\n    \"slaml/depth/time_ns\",\n    \"slamr/depth/time_ns\",\n]\nARIA_DEPTH_SNIPPET_TIME_S = [\n    \"rgb/depth/snippet_time_s\",\n    \"slaml/depth/snippet_time_s\",\n    \"slamr/depth/snippet_time_s\",\n]\n\nARIA_DEPTH_M_PRED = [\"rgb/pred/depth_m\", \"slaml/pred/depth_m\", \"slamr/pred/depth_m\"]\n# for distance images (distance along ray) in meters\nARIA_DISTANCE_M_PRED = [\n    \"rgb/pred/distance_m\",\n    \"slaml/pred/distance_m\",\n    \"slamr/pred/distance_m\",\n]\n\n# ---------------------------------------------------------------------\n# SDF information\n# ---------------------------------------------------------------------\nARIA_SDF = \"snippet/sdf/sdf\"\nARIA_SDF_EXT = \"snippet/sdf/extent\"\nARIA_SDF_COSY_TIME_NS = \"snippet/sdf/cosy_time_ns\"\nARIA_SDF_MASK = \"snippet/sdf/mask\"\nARIA_SDF_T_WORLD_VOXEL = \"snippet/sdf/T_world_voxel\"\n\n# 
---------------------------------------------------------------------\n# GT Mesh information\n# ---------------------------------------------------------------------\nARIA_MESH_VERTS_W = \"snippet/mesh/verts_w\"\nARIA_MESH_FACES = \"snippet/mesh/faces\"\nARIA_MESH_VERT_NORMS_W = \"snippet/mesh/v_norms_w\"\nARIA_SCENE_MESH_VERTS_W = \"scene/mesh/verts_w\"\nARIA_SCENE_MESH_FACES = \"scene/mesh/faces\"\nARIA_SCENE_MESH_VERT_NORMS_W = \"scene/mesh/v_norms_w\"\n\n# ---------------------------------------------------------------------\n# Scene volume information (can be acquired from mesh or semidense points)\n# --------------------------------------------------------------------\nARIA_MESH_VOL_MIN = \"scene/mesh/vol_min\"\nARIA_MESH_VOL_MAX = \"scene/mesh/vol_max\"\nARIA_POINTS_VOL_MIN = \"scene/points/vol_min\"\nARIA_POINTS_VOL_MAX = \"scene/points/vol_max\"\n\n# ---------------------------------------------------------------------\n# additional image constants\n# ---------------------------------------------------------------------\n\n# Fixed mapping of resolutions, tuple has three numbers: (RGB_HW, SLAM_W, SLAM_H)\nRESOLUTION_MAP = {\n    0: (1408, 640, 480),\n    1: (704, 640, 480),\n    2: (352, 320, 240),\n    # 3: there is none\n    4: (176, 160, 112),  # there is some cropping in SLAM image height\n    5: (480, 640, 480),\n    6: (336, 448, 336),  # match typical internet image FOV (assume 70 deg)\n    7: (240, 320, 240),  # match typical internet pixels e.g. 
ImageNet\n    8: (192, 256, 192),\n    9: (144, 192, 144),\n    # divisible by 14 for ViTs that use patch size 14\n    10: (\n        1400,\n        560,\n        420,\n    ),  # similar to 0  560x420 instead of 616x462 so that we can also get half the resolution for equivalent to 7\n    11: (700, 560, 420),  # similar to 1\n    12: (420, 560, 420),  # similar to 5\n    13: (210, 280, 210),  # similar to 7\n}\n# Fixed mapping of corresponding wh_multiple_of, for each resolution\nWH_MULTIPLE_OF_MAP = {\n    0: 16,\n    1: 16,\n    2: 16,\n    # 3: there is none\n    4: 16,\n    5: 16,\n    6: 16,\n    7: 16,\n    8: 16,\n    9: 16,\n    10: 14,\n    11: 14,\n    12: 14,\n    13: 14,\n}\n\n# Helper constants for managing valid radius of the fisheye images, valid radius\n# defines a circle from the center of projection where project/unproject is valid\nRGB_RADIUS_FACTOR = 760.0 / 1408.0\nSLAM_RADIUS_FACTOR = 320.0 / 640.0\n\nARIA_RGB_WIDTH_TO_RADIUS = {\n    RESOLUTION_MAP[key][0]: RESOLUTION_MAP[key][0] * RGB_RADIUS_FACTOR\n    for key in RESOLUTION_MAP\n}\nARIA_SLAM_WIDTH_TO_RADIUS = {\n    RESOLUTION_MAP[key][1]: RESOLUTION_MAP[key][1] * SLAM_RADIUS_FACTOR\n    for key in RESOLUTION_MAP\n}\n\nARIA_RGB_SCALE_TO_WH = {\n    key: [RESOLUTION_MAP[key][0], RESOLUTION_MAP[key][0]] for key in RESOLUTION_MAP\n}\nARIA_SLAM_SCALE_TO_WH = {\n    key: [RESOLUTION_MAP[key][1], RESOLUTION_MAP[key][2]] for key in RESOLUTION_MAP\n}\n\nARIA_IMG_MIN_LUX = 30.0\nARIA_IMG_MAX_LUX = 150000.0\nARIA_IMG_MAX_PERC_OVEREXPOSED = 0.02\nARIA_IMG_MAX_PERC_UNDEREXPOSED = 0.0001\n\n# ---------------------------------------------------------------------\n# EFM Constants\n# ---------------------------------------------------------------------\nARIA_EFM_OUTPUT = \"efm/output\"\n\nARIA_CAM_INFO = {\n    \"name\": [\"rgb\", \"slaml\", \"slamr\"],\n    \"stream_id\": [0, 1, 2],\n    \"name_to_stream_id\": {\n        \"rgb\": 0,\n        \"slaml\": 1,\n        \"slamr\": 2,\n    },\n    
\"width_height\": {\n        \"rgb\": (1408, 1408),\n        \"slaml\": (640, 480),\n        \"slamr\": (640, 480),\n    },\n    # vrs id\n    \"id\": [\"214-1\", \"1201-1\", \"1201-2\"],\n    \"id_to_name\": {\n        \"214-1\": \"rgb\",\n        \"1201-1\": \"slaml\",\n        \"1201-2\": \"slamr\",\n    },\n    # display names\n    \"display\": [\n        \"RGB\",\n        \"SLAM Left\",\n        \"SLAM Right\",\n    ],\n    # Physical position on glasses from left to right.\n    \"spatial_order\": [1, 0, 2],\n}\n"
  },
  {
    "path": "efm3d/aria/camera.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nfrom typing import Tuple, Union\n\nimport numpy as np\nimport torch\n\nfrom .pose import get_T_rot_z, IdentityPose, PoseTW\nfrom .projection_utils import (\n    fisheye624_project,\n    fisheye624_unproject,\n    pinhole_project,\n    pinhole_unproject,\n)\nfrom .tensor_wrapper import autocast, autoinit, smart_cat, TensorWrapper\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n\n\nclass DefaultCameraTWData(TensorWrapper):\n    \"\"\"Allows multiple input sizes.\"\"\"\n\n    def __init__(self):\n        self._data = -1 * torch.ones(33)\n\n    @property\n    def shape(self):\n        return (torch.Size([34]), torch.Size([26]), torch.Size([22]))\n\n\nclass DefaultCameraTWParam(TensorWrapper):\n    \"\"\"Allows multiple input sizes.\"\"\"\n\n    def __init__(self):\n        self._data = -1 * torch.ones(15)\n\n    @property\n    def shape(self):\n        return (torch.Size([16]), torch.Size([15]), torch.Size([8]), torch.Size([4]))\n\n\nclass DefaultCameraTWDistParam(TensorWrapper):\n    \"\"\"Allows multiple input sizes.\"\"\"\n\n    def __init__(self):\n        self._data = -1 * torch.ones(12)\n\n    @property\n    def shape(self):\n        return (torch.Size([12]), torch.Size([4]), torch.Size([0]))\n\n\nDEFAULT_CAM_DATA = DefaultCameraTWData()\nDEFAULT_CAM_PARAM = 
DefaultCameraTWParam()\nDEFAULT_CAM_DIST_PARAM = DefaultCameraTWDistParam()\nDEFAULT_CAM_DATA_SIZE = 34\n\nRGB_PARAMS = np.float32(\n    [2 * 600.0, 2 * 352.0, 2 * 352.0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n)\nSLAM_PARAMS = np.float32([500.0, 320.0, 240.0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])\n\nFISHEYE624_TYPE_STR = (\n    \"FisheyeRadTanThinPrism:f,u0,v0,k0,k1,k2,k3,k5,k5,p1,p2,s1,s2,s3,s4\"\n)\nFISHEYE624_DF_TYPE_STR = (\n    \"FisheyeRadTanThinPrism:fu,fv,u0,v0,k0,k1,k2,k3,k5,k5,p1,p2,s1,s2,s3,s4\"\n)\nPINHOLE_TYPE_STR = \"Pinhole\"\n\n\ndef is_fisheye624(inp):\n    names = [\n        \"Fisheye624\",\n        \"f624\",\n        FISHEYE624_TYPE_STR,\n        FISHEYE624_DF_TYPE_STR,\n        \"FisheyeRadTanThinPrism\",\n        \"CameraModelType.FISHEYE624\",\n    ]\n    names += [name.lower() for name in names]\n    return inp in names\n\n\ndef is_kb3(inp):\n    names = [\"KB:fu,fv,u0,v0,k0,k1,k2,k3\", \"KannalaBrandtK3\", \"KB3\"]\n    names += [name.lower() for name in names]\n    return inp in names\n\n\ndef is_pinhole(inp):\n    names = [\"Pinhole\", \"Linear\", \"CameraModelType.LINEAR\"]\n    names += [name.lower() for name in names]\n    return inp in names\n\n\ndef get_aria_camera(params=SLAM_PARAMS, width=640, height=480, valid_radius=None, B=1):\n    type_str = FISHEYE624_TYPE_STR if params.shape[-1] == 15 else FISHEYE624_DF_TYPE_STR\n    if valid_radius is None:\n        cam = CameraTW.from_surreal(width, height, type_str, params)\n    else:\n        cam = CameraTW.from_surreal(\n            width,\n            height,\n            type_str,\n            params,\n            valid_radius=valid_radius,\n        )\n    if B > 1:\n        cam = cam.unsqueeze(0).repeat(B, 1)\n    return cam\n\n\ndef get_pinhole_camera(params, width=640, height=480, valid_radius=None, B=1):\n    type_str = PINHOLE_TYPE_STR\n    if valid_radius is None:\n        cam = CameraTW.from_surreal(width, height, type_str, params)\n    else:\n        cam = CameraTW.from_surreal(\n  
          width,\n            height,\n            type_str,\n            params,\n            valid_radius=valid_radius,\n        )\n    if B > 1:\n        cam = cam.unsqueeze(0).repeat(B, 1)\n    return cam\n\n\ndef get_base_aria_rgb_camera_full_res():\n    params = RGB_PARAMS * 2\n    params[1:3] += 32\n    return get_aria_camera(params, 2880, 2880)\n\n\ndef get_base_aria_rgb_camera():\n    return get_aria_camera(RGB_PARAMS, 1408, 1408)\n\n\ndef get_base_aria_slam_camera():\n    return get_aria_camera(SLAM_PARAMS, 640, 480)\n\n\nclass CameraTW(TensorWrapper):\n    \"\"\"\n    Class to represent a batch of camera calibrations of the same camera type.\n    \"\"\"\n\n    SIZE_IND = slice(0, 2)\n    F_IND = slice(2, 4)\n    C_IND = slice(4, 6)\n    GAIN_IND = 6\n    EXPOSURE_S_IND = 7\n    VALID_RADIUS_IND = slice(8, 10)\n    T_CAM_RIG_IND = slice(10, 22)\n    DIST_IND = slice(22, None)\n\n    @autocast\n    @autoinit\n    def __init__(\n        self, data: Union[torch.Tensor, DefaultCameraTWData] = DEFAULT_CAM_DATA\n    ):\n        assert isinstance(data, torch.Tensor)\n        assert data.shape[-1] in {22, 26, 34}\n        super().__init__(data)\n\n    @classmethod\n    @autoinit\n    def from_parameters(\n        cls,\n        width: torch.Tensor = -1 * torch.ones(1),\n        height: torch.Tensor = -1 * torch.ones(1),\n        fx: torch.Tensor = -1 * torch.ones(1),\n        fy: torch.Tensor = -1 * torch.ones(1),\n        cx: torch.Tensor = -1 * torch.ones(1),\n        cy: torch.Tensor = -1 * torch.ones(1),\n        gain: torch.Tensor = -1 * torch.ones(1),\n        exposure_s: torch.Tensor = 1e-3 * torch.ones(1),\n        valid_radiusx: torch.Tensor = 99999.0 * torch.ones(1),\n        valid_radiusy: torch.Tensor = 99999.0 * torch.ones(1),\n        T_camera_rig: Union[torch.Tensor, PoseTW] = IdentityPose,  # 1x12.\n        dist_params: Union[\n            torch.Tensor, DefaultCameraTWDistParam\n        ] = DEFAULT_CAM_DIST_PARAM,\n    ):\n        # Concatenate 
into one big data tensor, handles TensorWrapper objects.\n        data = smart_cat(\n            [\n                width,\n                height,\n                fx,\n                fy,\n                cx,\n                cy,\n                gain,\n                exposure_s,\n                valid_radiusx,\n                valid_radiusy,\n                T_camera_rig,\n                dist_params,\n            ],\n            dim=-1,\n        )\n        return cls(data)\n\n    @classmethod\n    @autoinit\n    def from_surreal(\n        cls,\n        width: torch.Tensor = -1 * torch.ones(1),\n        height: torch.Tensor = -1 * torch.ones(1),\n        type_str: str = \"Fisheye624\",\n        params: Union[torch.Tensor, DefaultCameraTWParam] = DEFAULT_CAM_PARAM,\n        gain: torch.Tensor = 1 * torch.ones(1),\n        exposure_s: torch.Tensor = 1e-3 * torch.ones(1),\n        valid_radius: torch.Tensor = 99999.0 * torch.ones(1),\n        T_camera_rig: Union[torch.Tensor, PoseTW] = IdentityPose,  # 1x12.\n    ):\n        # Try to auto-determine the camera model.\n        if (\n            is_fisheye624(type_str) and params.shape[-1] == 16\n        ):  # Fisheye624 double focals\n            fx = params[..., 0].unsqueeze(-1)\n            fy = params[..., 1].unsqueeze(-1)\n            cx = params[..., 2].unsqueeze(-1)\n            cy = params[..., 3].unsqueeze(-1)\n            dist_params = params[..., 4:]\n        elif (\n            is_fisheye624(type_str) and params.shape[-1] == 15\n        ):  # Fisheye624 single focal\n            f = params[..., 0].unsqueeze(-1)\n            cx = params[..., 1].unsqueeze(-1)\n            cy = params[..., 2].unsqueeze(-1)\n            dist_params = params[..., 3:]\n            fx = fy = f\n        elif is_kb3(type_str) and params.shape[-1] == 8:  # KB3.\n            fx = params[..., 0].unsqueeze(-1)\n            fy = params[..., 1].unsqueeze(-1)\n            cx = params[..., 2].unsqueeze(-1)\n            cy = params[..., 
3].unsqueeze(-1)\n            dist_params = params[..., 4:]\n        elif is_pinhole(type_str) and params.shape[-1] == 4:  # Pinhole.\n            fx = params[..., 0].unsqueeze(-1)\n            fy = params[..., 1].unsqueeze(-1)\n            cx = params[..., 2].unsqueeze(-1)\n            cy = params[..., 3].unsqueeze(-1)\n            dist_params = params[..., 4:]\n        else:\n            raise NotImplementedError(\n                \"Unknown number of params entered for camera model\"\n            )\n\n        if torch.any(torch.logical_or(valid_radius > height, valid_radius > width)):\n            if not is_pinhole(type_str):\n                # Try to auto-determine the valid radius for fisheye cameras.\n                default_radius = 99999.0\n                hw_ratio = height / width\n                eyevideo_camera_hw_ratio = torch.tensor(240.0 / 640.0).to(hw_ratio)\n                slam_camera_hw_ratio = torch.tensor(480.0 / 640.0).to(hw_ratio)\n                rgb_camera_hw_ratio = torch.tensor(2880.0 / 2880.0).to(hw_ratio)\n                guess_rgb = hw_ratio == rgb_camera_hw_ratio\n                guess_slam = hw_ratio == slam_camera_hw_ratio\n                guess_eyevideo = hw_ratio == eyevideo_camera_hw_ratio\n                valid_radius = default_radius * torch.ones_like(hw_ratio)\n                valid_radius = torch.where(\n                    guess_rgb, 1415 * (height / 2880), valid_radius\n                )\n                valid_radius = torch.where(\n                    guess_slam, 330 * (height / 480), valid_radius\n                )\n                # This is for Eye Video Camera\n                valid_radius = torch.where(\n                    guess_eyevideo, 330 * (height / 480), valid_radius\n                )\n                if torch.any(valid_radius == default_radius):\n                    raise ValueError(\n                        f\"Failed to auto-determine valid radius based on aspect ratios (valid_radius {valid_radius}, width 
{width}, height {height})\"\n                    )\n            else:\n                # Note that the valid_radius for pinhole camera is not well-defined.\n                # We heuristically set the valid radius to be the half of the image diagonal.\n                # Add one pixel to be sure that all pixels in the image are valid.\n                valid_radius = (\n                    torch.sqrt((width / 2.0) ** 2 + (height / 2.0) ** 2) + 1.0\n                )\n\n        return cls.from_parameters(\n            width=width,\n            height=height,\n            fx=fx,\n            fy=fy,\n            cx=cx,\n            cy=cy,\n            gain=gain,\n            exposure_s=exposure_s,\n            valid_radiusx=valid_radius,\n            valid_radiusy=valid_radius,\n            T_camera_rig=T_camera_rig,\n            dist_params=dist_params,\n        )\n\n    @property\n    def size(self) -> torch.Tensor:\n        \"\"\"Size (width height) of the images, with shape (..., 2).\"\"\"\n        return self._data[..., self.SIZE_IND]\n\n    @property\n    def f(self) -> torch.Tensor:\n        \"\"\"Focal lengths (fx, fy) with shape (..., 2).\"\"\"\n        return self._data[..., self.F_IND]\n\n    @property\n    def c(self) -> torch.Tensor:\n        \"\"\"Principal points (cx, cy) with shape (..., 2).\"\"\"\n        return self._data[..., self.C_IND]\n\n    @property\n    def K(self) -> torch.Tensor:\n        \"\"\"Intrinsic matrix with shape (..., 3, 3)\"\"\"\n        K = torch.eye(3, device=self.device, dtype=self.dtype)\n        # Make proper size of K to take care of B and T dims.\n        K_view = [1] * (self.f.ndim - 1) + [3, 3]\n        K_repeat = list(self.f.shape[:-1]) + [1, 1]\n        K = K.view(K_view)\n        K = K.repeat(K_repeat)\n        K[..., 0, 0] = self.f[..., 0]\n        K[..., 1, 1] = self.f[..., 1]\n        K[..., 0, 2] = self.c[..., 0]\n        K[..., 1, 2] = self.c[..., 1]\n        return K\n\n    @property\n    def K44(self) -> 
torch.Tensor:\n        \"\"\"Intrinsic matrix with shape (..., 4, 4)\"\"\"\n        K = torch.eye(4, device=self.device, dtype=self.dtype)\n        # Make proper size of K to take care of B and T dims.\n        K_view = [1] * (self.f.ndim - 1) + [4, 4]\n        K_repeat = list(self.f.shape[:-1]) + [1, 1]\n        K = K.view(K_view)\n        K = K.repeat(K_repeat)\n        K[..., 0, 0] = self.f[..., 0]\n        K[..., 1, 1] = self.f[..., 1]\n        K[..., 0, 2] = self.c[..., 0]\n        K[..., 1, 2] = self.c[..., 1]\n        return K\n\n    @property\n    def gain(self) -> torch.Tensor:\n        \"\"\"Gain of the camera, with shape (..., 1).\"\"\"\n        return self._data[..., self.GAIN_IND].unsqueeze(-1)\n\n    @property\n    def exposure_s(self) -> torch.Tensor:\n        \"\"\"Exposure of the camera in seconds, with shape (..., 1).\"\"\"\n        return self._data[..., self.EXPOSURE_S_IND].unsqueeze(-1)\n\n    @property\n    def valid_radius(self) -> torch.Tensor:\n        \"\"\"Radii (x, y) from the camera center for valid projections, with shape (..., 2).\"\"\"\n        return self._data[..., self.VALID_RADIUS_IND]\n\n    @property\n    def T_camera_rig(self) -> PoseTW:\n        \"\"\"Transform from rig to camera as a PoseTW wrapping data of shape (..., 12).\"\"\"\n        return PoseTW(self._data[..., self.T_CAM_RIG_IND])\n\n    @property\n    def dist(self) -> torch.Tensor:\n        \"\"\"Distortion parameters, with shape (..., {0, D}), where D is the number of distortion params.\"\"\"\n        return self._data[..., self.DIST_IND]\n\n    @property\n    def params(self) -> torch.Tensor:\n        \"\"\"Get the camera \"params\", which are defined as fx,fy,cx,cy,dist\"\"\"\n        return torch.cat([self.f, self.c, self.dist], dim=-1)\n\n    @property\n    def is_fisheye624(self):\n        return self.dist.shape[-1] == 12\n\n    @property\n    def is_kb3(self):\n        return self.dist.shape[-1] == 4\n\n    @property\n    def is_linear(self):\n        return self.dist.shape[-1] == 0\n\n    def 
set_valid_radius(self, valid_radius: torch.Tensor):\n        self._data[..., self.VALID_RADIUS_IND] = valid_radius\n\n    def set_T_camera_rig(self, T_camera_rig: PoseTW):\n        self._data[..., self.T_CAM_RIG_IND] = T_camera_rig._data.clone()\n\n    def scale(self, scales: Union[float, int, Tuple[Union[float, int]]]):\n        \"\"\"Update the camera parameters after resizing an image.\"\"\"\n        if isinstance(scales, (int, float)):\n            scales = (scales, scales)\n        s = self._data.new_tensor(scales)\n        data = torch.cat(\n            [\n                self.size * s,\n                self.f * s,\n                (self.c + 0.5) * s - 0.5,\n                self.gain,\n                self.exposure_s,\n                self.valid_radius * s,\n                self.T_camera_rig._data,\n                self.dist,\n            ],\n            dim=-1,\n        )\n        return self.__class__(data)\n\n    def scale_to_size(self, size_wh: Union[int, Tuple[int]]):\n        \"\"\"Scale the camera parameters to a given image size\"\"\"\n        if torch.unique(self.size).numel() > 2:\n            raise ValueError(f\"cannot handle multiple sizes {self.size}\")\n        if isinstance(size_wh, int):\n            size_wh = (size_wh, size_wh)\n        i0w = tuple([0] * self.ndim)\n        i0h = tuple([0] * (self.ndim - 1) + [1])\n        scale = (\n            float(size_wh[0]) / float(self.size[i0w]),\n            float(size_wh[1]) / float(self.size[i0h]),\n        )\n        return self.scale(scale)\n\n    def scale_to(self, im: torch.Tensor):\n        \"\"\"\n        Scale the camera parameters to match the size of the given image assumes\n        ...xHxW image tensor convention of pytorch\n        \"\"\"\n        H, W = im.shape[-2:]\n        return self.scale_to_size((W, H))\n\n    def crop(self, left_top: Tuple[float], size: Tuple[int]):\n        \"\"\"Update the camera parameters after cropping an image.\"\"\"\n        left_top = 
self._data.new_tensor(left_top)\n        size = self._data.new_tensor(size)\n\n        # Expand the dimension if self._data is a tensor of CameraTW\n        if len(self._data.shape) > 1:\n            expand_dim = list(self._data.shape[:-1]) + [1]\n            size = size.repeat(expand_dim)\n            left_top = left_top.repeat(expand_dim)\n\n        data = torch.cat(\n            [\n                size,\n                self.f,\n                self.c - left_top,\n                self.gain,\n                self.exposure_s,\n                self.valid_radius,\n                self.T_camera_rig._data,\n                self.dist,\n            ],\n            dim=-1,\n        )\n        return self.__class__(data)\n\n    @autocast\n    def in_image(self, p2d: torch.Tensor):\n        \"\"\"Check if 2D points are within the image boundaries.\"\"\"\n        assert p2d.shape[-1] == 2, f\"p2d shape needs to be 2d {p2d.shape}\"\n        # assert p2d.shape[:-2] == self.shape  # allow broadcasting\n        size = self.size.unsqueeze(-2)\n        valid = torch.all((p2d >= 0) & (p2d <= (size - 1)), dim=-1)\n        return valid\n\n    @autocast\n    def in_radius(self, p2d: torch.Tensor):\n        \"\"\"Check if 2D points are within the valid fisheye radius region.\"\"\"\n        assert p2d.shape[-1] == 2, f\"p2d shape needs to be 2d {p2d.shape}\"\n        dists = torch.linalg.norm(\n            (p2d - self.c.unsqueeze(-2)) / self.valid_radius.unsqueeze(-2),\n            dim=-1,\n            ord=2,\n        )\n        valid = dists < 1.0\n        return valid\n\n    @autocast\n    def in_radius_mask(self):\n        \"\"\"\n        Return a mask that is True where 2D points are within the valid fisheye\n        radius region.  Returned mask is of shape ... x 1 x H x W, where ... 
is\n        the shape of the camera (BxT or B for example).\n        \"\"\"\n        s = self.shape[:-1]\n        C = self.shape[-1]\n        px = pixel_grid(self.view(-1, C)[0])\n        H, W, _ = px.shape\n        valids = self.in_radius(px.view(-1, 2))\n        s = s + (1, H, W)\n        valids = valids.view(s)\n        return valids\n\n    @autocast\n    def project(self, p3d: torch.Tensor) -> Tuple[torch.Tensor]:\n        \"\"\"Transform 3D points into 2D pixel coordinates.\"\"\"\n\n        # Explicitly promote the data types.\n        promoted_type = torch.promote_types(self._data.dtype, p3d.dtype)\n        self._data = self._data.to(promoted_type)\n        p3d = p3d.to(promoted_type)\n\n        # Try to auto-determine the camera model.\n        if self.is_fisheye624:  # Fisheye624.\n            params = torch.cat([self.f, self.c, self.dist], dim=-1)\n            if params.ndim == 1:\n                B = p3d.shape[0]\n                params = params.unsqueeze(0).repeat(B, 1)\n            p2d = fisheye624_project(p3d, params)\n        elif self.is_linear:  # Pinhole.\n            params = self.params\n            if params.ndim == 1:\n                B = p3d.shape[0]\n                params = params.unsqueeze(0).repeat(B, 1)\n            p2d = pinhole_project(p3d, params)\n        else:\n            raise ValueError(\n                \"only fisheye624 and pinhole implemented, kb3 not yet implemented\"\n            )\n\n        in_image = self.in_image(p2d)\n        in_radius = self.in_radius(p2d)\n        in_front = p3d[..., -1] > 0\n        valid = in_image & in_radius & in_front\n        return p2d, valid\n\n    @autocast\n    def unproject(self, p2d: torch.Tensor) -> Tuple[torch.Tensor]:\n        \"\"\"Transform 2D points into 3D rays.\"\"\"\n\n        # Explicitly promote the data types.\n        promoted_type = torch.promote_types(self._data.dtype, p2d.dtype)\n        self._data = self._data.to(promoted_type)\n        p2d = p2d.to(promoted_type)\n\n       
 # Try to auto-determine the camera model.\n        if self.is_fisheye624:  # Fisheye624.\n            params = torch.cat([self.f, self.c, self.dist], dim=-1)\n            if params.ndim == 1:\n                B = p2d.shape[0]\n                params = params.unsqueeze(0).repeat(B, 1)\n            rays = fisheye624_unproject(p2d, params)\n        elif self.is_linear:  # Pinhole.\n            params = self.params\n            if params.ndim == 1:\n                B = p2d.shape[0]\n                params = params.unsqueeze(0).repeat(B, 1)\n            rays = pinhole_unproject(p2d, params)\n        else:\n            raise ValueError(\n                \"only fisheye624 and pinhole implemented, kb3 not yet implemented\"\n            )\n\n        in_image = self.in_image(p2d)\n        in_radius = self.in_radius(p2d)\n        valid = in_image & in_radius\n        return rays, valid\n\n    def rotate_90_cw(self):\n        return self.rotate_90(clock_wise=True)\n\n    def rotate_90_ccw(self):\n        return self.rotate_90(clock_wise=False)\n\n    def rotate_90(self, clock_wise: bool):\n        dist_params = self.dist.clone()\n        if self.is_fisheye624:\n            # swap thin prism and tangential distortion parameters\n            # {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3} to\n            # {k_0 ... 
k_5} {p_1 p_0} {s_2 s_3 s_0 s_1}\n            dist_p = self.dist[..., 6:8]\n            dist_s = self.dist[..., 8:12]\n            dist_params[..., 6] = dist_p[..., 1]\n            dist_params[..., 7] = dist_p[..., 0]\n            dist_params[..., 8:10] = dist_s[..., 2:]\n            dist_params[..., 10:12] = dist_s[..., :2]\n        elif self.is_linear:\n            # no need to rotate distortion parameters since there are none\n            pass\n        elif self.is_kb3:\n            raise NotImplementedError(\"kb3 model rotation not implemented yet\")\n        else:\n            raise NotImplementedError(f\"camera model not recognized {self}\")\n\n        # rotation direction: +1 for clockwise, -1 for counter-clockwise\n        DIR = 1 if clock_wise else -1\n        # rotate camera extrinsics by 90 degrees about z in the chosen direction\n        T_rot_z = PoseTW.from_matrix3x4(get_T_rot_z(DIR * np.pi * 0.5)).to(self.device)\n        if clock_wise:\n            # rotate x, y of principal point\n            # x_rotated = height - 1 - y_before\n            # y_rotated = x_before\n            rot_cx = self.size[..., 1] - self.c[..., 1] - 1\n            rot_cy = self.c[..., 0].clone()\n        else:\n            # x_rotated = y_before\n            # y_rotated = width - 1 - x_before\n            rot_cx = self.c[..., 1].clone()\n            rot_cy = self.size[..., 0] - self.c[..., 0] - 1\n\n        return CameraTW.from_parameters(\n            # swap width and height\n            self.size[..., 1].clone().unsqueeze(-1),\n            self.size[..., 0].clone().unsqueeze(-1),\n            # swap x, y of focal lengths\n            self.f[..., 1].clone().unsqueeze(-1),\n            self.f[..., 0].clone().unsqueeze(-1),\n            rot_cx.unsqueeze(-1),\n            rot_cy.unsqueeze(-1),\n            self.gain.clone(),\n            self.exposure_s.clone(),\n            # swap valid radius x, y\n            self.valid_radius[..., 1].clone().unsqueeze(-1),\n            self.valid_radius[..., 0].clone().unsqueeze(-1),\n            # rotate camera extrinsics\n            T_rot_z @ self.T_camera_rig,\n            
dist_params,\n        )\n\n    def __repr__(self):\n        return f\"CameraTW {self.shape} {self.dtype} {self.device}\"\n\n\ndef grid_2d(\n    width: int,\n    height: int,\n    output_range=(-1.0, 1.0, -1.0, 1.0),\n    device=\"cpu\",\n    dtype=torch.float32,\n):\n    x = torch.linspace(\n        output_range[0], output_range[1], width + 1, device=device, dtype=dtype\n    )[:-1]\n    y = torch.linspace(\n        output_range[2], output_range[3], height + 1, device=device, dtype=dtype\n    )[:-1]\n    xx, yy = torch.meshgrid(x, y, indexing=\"xy\")\n    grid = torch.stack([xx, yy], dim=-1)\n    return grid\n\n\ndef pixel_grid(cam: CameraTW):\n    assert cam.ndim == 1, f\"Camera must be 1 dimensional {cam.shape}\"\n    W, H = int(cam.size[0]), int(cam.size[1])\n    return grid_2d(W, H, output_range=[0, W, 0, H], device=cam.device, dtype=cam.dtype)\n\n\ndef scale_image_to_cam(cams: CameraTW, ims: torch.Tensor) -> torch.Tensor:\n    \"\"\"Scale an image to a camera.\"\"\"\n\n    from torchvision.transforms import InterpolationMode, Resize\n\n    T = None\n    if ims.ndim == 5:\n        B, T, C, H, W = ims.shape\n        ims = ims.view(-1, C, H, W)\n        Wo, Ho = cams[0, 0].size.int().tolist()\n    elif ims.ndim == 4:\n        B, C, H, W = ims.shape\n        Wo, Ho = cams[0].size.int().tolist()\n    else:\n        raise ValueError(f\"unusable image shape {ims.shape}, {cams.shape}\")\n    ims = Resize((Ho, Wo), interpolation=InterpolationMode.BILINEAR, antialias=True)(\n        ims\n    )\n    if T is not None:\n        return ims.view(B, T, C, Ho, Wo)\n    return ims.view(B, C, Ho, Wo)\n"
  },
  {
    "path": "efm3d/aria/obb.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nfrom typing import List, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\n\nfrom .camera import CameraTW\nfrom .pose import IdentityPose, PAD_VAL, PoseTW, rotation_from_euler\nfrom .tensor_wrapper import autocast, autoinit, smart_cat, smart_stack, TensorWrapper\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n# logger.setLevel(logging.DEBUG)\n\n\n# OBB corner numbering diagram for this implementation (the same as pytorch3d\n# https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/ops/iou_box3d.py#L111)\n#\n# (4) +---------+. (5)\n#     | ` .     |  ` .\n#     | (0) +---+-----+ (1)\n#     |     |   |     |\n# (7) +-----+---+. (6)|\n#     ` .   |     ` . |\n#     (3) ` +---------+ (2)\n#\n# NOTE: Throughout this implementation, we assume that boxes\n# are defined by their 8 corners exactly in the order specified in the\n# diagram above for the function to give correct results. 
In addition\n# the vertices on each plane must be coplanar.\n# As an alternative to the diagram, this is a unit bounding\n# box which has the correct vertex ordering:\n# box_corner_vertices = [\n#     [0, 0, 0],  #   (0)\n#     [1, 0, 0],  #   (1)\n#     [1, 1, 0],  #   (2)\n#     [0, 1, 0],  #   (3)\n#     [0, 0, 1],  #   (4)\n#     [1, 0, 1],  #   (5)\n#     [1, 1, 1],  #   (6)\n#     [0, 1, 1],  #   (7)\n# ]\n\n# triangle indices to draw an OBB mesh from bb3corners_*\nOBB_MESH_TRI_INDS = [\n    [7, 0, 0, 0, 4, 4, 6, 6, 4, 0, 3, 2],\n    [3, 4, 1, 2, 5, 6, 5, 2, 0, 1, 6, 3],\n    [0, 7, 2, 3, 6, 7, 1, 1, 5, 5, 7, 6],\n]\n\n# line indices to draw an OBB line strip frame from bb3corners_*\nOBB_LINE_INDS = [0, 1, 2, 3, 0, 3, 7, 4, 0, 1, 5, 6, 5, 4, 7, 6, 2, 1, 5]\n\n# corner indices to construct all edge lines\nBB3D_LINE_ORDERS = [\n    [0, 1],\n    [1, 2],\n    [2, 3],\n    [3, 0],\n    [4, 5],\n    [5, 6],\n    [6, 7],\n    [7, 4],\n    [0, 4],\n    [1, 5],\n    [2, 6],\n    [3, 7],\n]\n\n_box_planes = [\n    [0, 1, 2, 3],\n    [3, 2, 6, 7],\n    [0, 1, 5, 4],\n    [0, 3, 7, 4],\n    [1, 2, 6, 5],\n    [4, 5, 6, 7],\n]\n\nDOT_EPS = 1e-3\nAREA_EPS = 1e-4\n\n\nclass ObbTW(TensorWrapper):\n    \"\"\"\n    Oriented 3D Bounding Box observation in world coordinates (via\n    T_world_object) for Aria headsets.\n    \"\"\"\n\n    @autocast\n    @autoinit\n    def __init__(self, data: torch.Tensor = PAD_VAL * torch.ones((1, 34))):\n        assert isinstance(data, torch.Tensor)\n        assert data.shape[-1] == 34\n        super().__init__(data)\n\n    @classmethod\n    @autoinit\n    def from_lmc(\n        cls,\n        bb3_object: torch.Tensor = PAD_VAL * torch.ones(6),\n        bb2_rgb: torch.Tensor = PAD_VAL * torch.ones(4),\n        bb2_slaml: torch.Tensor = PAD_VAL * torch.ones(4),\n        bb2_slamr: torch.Tensor = PAD_VAL * torch.ones(4),\n        T_world_object: Union[torch.Tensor, PoseTW] = IdentityPose,  # 1x12.\n        sem_id: torch.Tensor = PAD_VAL * 
torch.ones(1),\n        inst_id: torch.Tensor = PAD_VAL * torch.ones(1),\n        prob: torch.Tensor = 1 * torch.ones(1),\n        moveable: torch.Tensor = 0 * torch.ones(1),\n    ):\n        # Concatenate into one big data tensor, handles TensorWrapper objects.\n        # make sure that its on the same device (fails if IdentityPose is used)\n        device = bb3_object.device\n        data = smart_cat(\n            [\n                bb3_object,\n                bb2_rgb.to(device),\n                bb2_slaml.to(device),\n                bb2_slamr.to(device),\n                T_world_object.to(device),\n                sem_id.to(device),\n                inst_id.to(device),\n                prob.to(device),\n                moveable.to(device),\n            ],\n            dim=-1,\n        )\n        return cls(data)\n\n    @property\n    def bb3_object(self) -> torch.Tensor:\n        \"\"\"3D bounding box [xmin,xmax,ymin,ymax,zmin,zmax] in object coord frame, with shape (..., 6).\"\"\"\n        return self._data[..., :6]\n\n    @property\n    def bb3_min_object(self) -> torch.Tensor:\n        \"\"\"3D bounding box minimum corner [xmin,ymin,zmin] in object coord frame, with shape (..., 3).\"\"\"\n        return self._data[..., 0:6:2]\n\n    @property\n    def bb3_max_object(self) -> torch.Tensor:\n        \"\"\"3D bounding box maximum corner [xmax,ymax,zmax] in object coord frame, with shape (..., 3).\"\"\"\n        return self._data[..., 1:6:2]\n\n    @property\n    def bb3_center_object(self) -> torch.Tensor:\n        \"\"\"3D bounding box center in object coord frame, with shape (..., 3).\"\"\"\n        return 0.5 * (self.bb3_min_object + self.bb3_max_object)\n\n    @property\n    def bb3_center_world(self) -> torch.Tensor:\n        \"\"\"3D bounding box center in world coord frame, with shape (..., 3).\"\"\"\n        s = self.bb3_center_object.shape\n        _bb3_center_world = self.T_world_object.view(-1, 12).batch_transform(\n            
self.bb3_center_object.view(-1, 3)\n        )\n        return _bb3_center_world.view(s)\n\n    @property\n    def bb3_diagonal(self) -> torch.Tensor:\n        \"\"\"3D bounding box diagonal, with shape (..., 3).\"\"\"\n        return self.bb3_max_object - self.bb3_min_object\n\n    @property\n    def bb3_volumes(self) -> torch.Tensor:\n        \"\"\"3D bounding box volumes, with shape (..., 1).\"\"\"\n        diags = self.bb3_diagonal\n        return diags.prod(dim=-1, keepdim=True)\n\n    @property\n    def bb2_rgb(self) -> torch.Tensor:\n        \"\"\"2D bounding box [xmin,xmax,ymin,ymax] as visible in RGB image, -1's if not visible, with shape (..., 4).\"\"\"\n        return self._data[..., 6:10]\n\n    def visible_bb3_ind(self, cam_id) -> torch.Tensor:\n        \"\"\"Indices of visible 3D bounding boxes in camera cam_id\"\"\"\n        bb2_cam = self.bb2(cam_id)\n        vis_ind = torch.all(bb2_cam > 0, dim=-1)\n        return vis_ind\n\n    @property\n    def bb2_slaml(self) -> torch.Tensor:\n        \"\"\"2D bounding box [xmin,xmax,ymin,ymax] as visible in SLAM Left image, -1's if not visible, with shape (..., 4).\"\"\"\n        return self._data[..., 10:14]\n\n    @property\n    def bb2_slamr(self) -> torch.Tensor:\n        \"\"\"2D bounding box [xmin,xmax,ymin,ymax] as visible in SLAM Right image, -1's if not visible, with shape (..., 4).\"\"\"\n        return self._data[..., 14:18]\n\n    def bb2(self, cam_id) -> torch.Tensor:\n        \"\"\"\n        2D bounding box [xmin,xmax,ymin,ymax] as visible in camera with given\n        cam_id, -1's if not visible, with shape (..., 4).\n        cam_id == 0 for rgb\n        cam_id == 1 for slam left\n        cam_id == 2 for slam right\n        \"\"\"\n        return self._data[..., 6 + cam_id * 4 : 10 + cam_id * 4]\n\n    def set_bb2(self, cam_id, bb2d, use_mask=True):\n        \"\"\"\n        Set 2D bounding box [xmin,xmax,ymin,ymax] in camera with given\n        cam_id == 0 for rgb\n        cam_id == 1 for slam 
left\n        cam_id == 2 for slam right\n        \"\"\"\n        padding_mask = self.get_padding_mask()\n        self._data[..., 6 + cam_id * 4 : 10 + cam_id * 4] = bb2d\n        if use_mask:\n            self._data[padding_mask] = PAD_VAL\n\n    def set_bb3_object(self, bb3_object, use_mask=True) -> torch.Tensor:\n        \"\"\"set 3D bounding box [xmin,xmax,ymin,ymax,zmin,zmax] in object coord frame, with shape (..., 6).\"\"\"\n        padding_mask = self.get_padding_mask()\n        self._data[..., :6] = bb3_object\n        if use_mask:\n            self._data[padding_mask] = PAD_VAL\n\n    def set_prob(self, prob, use_mask=True):\n        \"\"\"Set probability score\"\"\"\n        padding_mask = self.get_padding_mask()\n        self._data[..., 32] = prob\n        if use_mask:\n            self._data[padding_mask] = PAD_VAL\n\n    @property\n    def T_world_object(self) -> torch.Tensor:\n        \"\"\"3D SE3 transform from object to world coords, with shape (..., 12).\"\"\"\n        return PoseTW(self._data[..., 18:30])\n\n    def get_padding_mask(self) -> torch.Tensor:\n        \"\"\"get boolean mask indicating which Obbs are valid/non-padded.\"\"\"\n        return (self._data == PAD_VAL).all(dim=-1, keepdim=False)\n\n    def set_T_world_object(self, T_world_object: PoseTW):\n        \"\"\"set 3D SE3 transform from object to world coords.\"\"\"\n        invalid_mask = self.get_padding_mask()\n        self._data[..., 18:30] = T_world_object._data\n        self._data[invalid_mask] = PAD_VAL\n\n    @property\n    def sem_id(self) -> torch.Tensor:\n        \"\"\"semantic id, with shape (..., 1).\"\"\"\n        return self._data[..., 30].unsqueeze(-1).int()\n\n    def set_sem_id(self, sem_id: torch.Tensor):\n        \"\"\"set semantic id to sem_id\"\"\"\n        self._data[..., 30] = sem_id.squeeze()\n\n    @property\n    def inst_id(self) -> torch.Tensor:\n        \"\"\"instance id, with shape (..., 1).\"\"\"\n        return self._data[..., 
31].unsqueeze(-1).int()\n\n    def set_inst_id(self, inst_id: torch.Tensor):\n        \"\"\"set instance id to inst_id\"\"\"\n        self._data[..., 31] = inst_id.squeeze()\n\n    @property\n    def prob(self) -> torch.Tensor:\n        \"\"\"probability of detection, with shape (..., 1).\"\"\"\n        return self._data[..., 32].unsqueeze(-1)\n\n    @property\n    def moveable(self) -> torch.Tensor:\n        \"\"\"boolean if moveable, with shape (..., 1).\"\"\"\n        return self._data[..., 33].unsqueeze(-1)\n\n    @property\n    def bb3corners_world(self) -> torch.Tensor:\n        return self.T_world_object * self.bb3corners_object\n\n    @property\n    def bb3corners_object(self) -> torch.Tensor:\n        \"\"\"return the 8 corners of the 3D BB in object coord frame (..., 8, 3).\"\"\"\n        ids = [0, 2, 4, 1, 2, 4, 1, 3, 4, 0, 3, 4, 0, 2, 5, 1, 2, 5, 1, 3, 5, 0, 3, 5]\n        b3o = self.bb3_object\n        c3o = b3o[..., ids]\n        c3o = c3o.reshape(*c3o.shape[:-1], 8, 3)\n        return c3o\n\n    def bb3edge_pts_object(self, num_samples_per_edge: int = 10) -> torch.Tensor:\n        \"\"\"\n        return the num_samples_per_edge points per 3D BB edge in object coord\n        frame (..., num_samples_per_edge * 12, 3).\n\n        num_samples_per_edge == 1 will result in a list of corners (with some duplicates)\n        num_samples_per_edge == 2 will result in a list of corners (with some more duplicates)\n        num_samples_per_edge == 3 will result in a list of corners and edge midpoints\n        ...\n        \"\"\"\n        bb3corners = self.bb3corners_object\n        shape = bb3corners.shape\n        alphas = torch.linspace(0, 1, num_samples_per_edge, device=bb3corners.device)\n        alphas = alphas.view([1] * len(shape[:-2]) + [num_samples_per_edge, 1])\n        alphas = alphas.repeat(list(shape[:-2]) + [1, 3])\n        betas = torch.ones_like(alphas) - alphas\n        bb3edge_pts = []\n        for edge_ids in BB3D_LINE_ORDERS:\n            
bb3edge_pts.append(\n                bb3corners[..., edge_ids[0], :].unsqueeze(-2) * betas\n                + bb3corners[..., edge_ids[1], :].unsqueeze(-2) * alphas\n            )\n        return torch.cat(bb3edge_pts, dim=-2)\n\n    def center(self):\n        \"\"\"\n        Returns a ObbTW object where the 3D OBBs are centered in their local coordinate system.\n        I.e. bb3_min_object == - bb3_max_object.\n        \"\"\"\n\n        T_wo = self.T_world_object\n        center_o = self.bb3_center_object\n        # compute centered bb3_object and obb pose T_world_object\n        centered_T_wo = PoseTW.from_Rt(T_wo.R, T_wo.batch_transform(center_o))\n        centered_bb3_min_o = self.bb3_min_object - center_o\n        centered_bb3_max_o = self.bb3_max_object - center_o\n        centered_bb3_o = torch.stack(\n            [\n                centered_bb3_min_o[..., 0],\n                centered_bb3_max_o[..., 0],\n                centered_bb3_min_o[..., 1],\n                centered_bb3_max_o[..., 1],\n                centered_bb3_min_o[..., 2],\n                centered_bb3_max_o[..., 2],\n            ],\n            dim=-1,\n        )\n        return ObbTW.from_lmc(\n            bb3_object=centered_bb3_o,\n            bb2_rgb=self.bb2_rgb,\n            bb2_slaml=self.bb2_slaml,\n            bb2_slamr=self.bb2_slamr,\n            T_world_object=centered_T_wo,\n            sem_id=self.sem_id,\n            inst_id=self.inst_id,\n            prob=self.prob,\n            moveable=self.moveable,\n        )\n\n    def add_padding(self, max_elts: int = 1000) -> \"ObbTW\":\n        \"\"\"\n        Adds padding to Obbs, useful for returning batches with a varying number\n        of Obbs. E.g. 
if in one batch we have 4 Obbs and another one we have 2,\n        setting max_elts=4 will add 2 pads (consisting of all -1s) to the second\n        element in the batch.\n        \"\"\"\n        assert self._data.ndim <= 2, \"higher than order 2 add_padding not supported yet\"\n        elts = self._data\n        num_to_pad = max_elts - len(elts)\n        # All -1's denotes a pad element.\n        pad_elt = PAD_VAL * self._data.new_ones(self._data.shape[-1])\n        if num_to_pad > 0:\n            rep_elts = torch.stack([pad_elt for _ in range(num_to_pad)], dim=0)\n            elts = torch.cat([elts, rep_elts], dim=0)\n        elif num_to_pad < 0:\n            elts = elts[:max_elts]\n            logger.warning(\n                f\"Warning: some obbs have been clipped (actual/max {len(elts)}/{max_elts}) in ObbTW.add_padding()\"\n            )\n        return self.__class__(elts)\n\n    def remove_padding(self) -> List[\"ObbTW\"]:\n        \"\"\"\n        Removes any padding by finding Obbs with all -1s. 
Returns a list.\n        \"\"\"\n        assert self.ndim <= 4, \"higher than order 4 remove_padding not supported yet\"\n\n        if self.ndim == 1:\n            return self  # Nothing to be done in this case.\n\n        # All -1's denotes a pad element.\n        pad_elt = (PAD_VAL * self._data.new_ones(self._data.shape[-1])).unsqueeze(-2)\n        is_not_pad = ~torch.all(self._data == pad_elt, dim=-1)\n\n        if self.ndim == 2:\n            new_data = self.__class__(self._data[is_not_pad])\n        elif self.ndim == 3:\n            B = self._data.shape[0]\n            new_data = []\n            for b in range(B):\n                new_data.append(self.__class__(self._data[b][is_not_pad[b]]))\n        else:  # self.ndim == 4:\n            B, T = self._data.shape[:2]\n            new_data = []\n            for b in range(B):\n                new_data.append([])\n                for t in range(T):\n                    new_data[-1].append(\n                        self.__class__(self._data[b, t][is_not_pad[b, t]])\n                    )\n        return new_data\n\n    def _mark_invalid(self, invalid_mask: torch.Tensor) -> \"ObbTW\":\n        \"\"\"\n        in place mark obbs in this ObbTW as invalid via mask\n        \"\"\"\n        assert invalid_mask.ndim == self.ndim - 1, \"invalid_mask must match ObbTW\"\n        assert invalid_mask.shape[:-1] == self.shape[:-1], (\n            \"invalid_mask must match ObbTW\"\n        )\n        self._data[invalid_mask] = PAD_VAL\n\n    def _mark_invalid_ids(self, invalid_ids: torch.Tensor) -> \"ObbTW\":\n        \"\"\"\n        in place mark obbs in this ObbTW as invalid via mask\n        \"\"\"\n        assert self.ndim == 2, \"invalid_ids only supported for 2d ObbTW\"\n        assert invalid_ids.ndim == 1, \"invalid_ids must be 1d\"\n        assert invalid_ids.dtype == torch.int64, \"invalid_ids must be int64\"\n        self._data[invalid_ids] = PAD_VAL\n\n    def num_valid(self) -> int:\n        \"\"\"\n        Returns 
the number of valid Obbs in this collection.\n        \"\"\"\n        if self.ndim == 1:\n            is_pad = torch.all(self._data == PAD_VAL, dim=-1)\n            return 0 if is_pad.item() else 1\n        elif self.ndim == 2:\n            is_pad = torch.all(self._data == PAD_VAL, dim=-1)\n            return self.shape[0] - is_pad.sum()\n        elif self.ndim == 3:\n            is_pad = torch.all(self._data == PAD_VAL, dim=-1)\n            return self.shape[0] * self.shape[1] - is_pad.sum()\n        elif self.ndim == 4:\n            is_pad = torch.all(self._data == PAD_VAL, dim=-1)\n            return self.shape[0] * self.shape[1] * self.shape[2] - is_pad.sum()\n        else:\n            raise NotImplementedError(f\"{self.shape}\")\n\n    def scale_bb2(self, scale_rgb: float, scale_slam: float):\n        \"\"\"Update the 2d bb parameters after resizing the underlying images.\n        All 2d bbs are scaled by the same scale specified for the frame of the\n        2d bb (RGB vs SLAM).\"\"\"\n\n        # Check for padded values and leave those unchanged.\n        pad_rgb = (\n            torch.all(self.bb2_rgb == PAD_VAL, dim=-1)\n            .unsqueeze(-1)\n            .expand(*self.bb2_rgb.shape)\n        )\n        pad_slamr = (\n            torch.all(self.bb2_slamr == PAD_VAL, dim=-1)\n            .unsqueeze(-1)\n            .expand(*self.bb2_slamr.shape)\n        )\n        pad_slaml = (\n            torch.all(self.bb2_slaml == PAD_VAL, dim=-1)\n            .unsqueeze(-1)\n            .expand(*self.bb2_slaml.shape)\n        )\n        sc_rgb = scale_rgb * torch.ones_like(self.bb2_rgb)\n        sc_slamr = scale_slam * torch.ones_like(self.bb2_slamr)\n        sc_slaml = scale_slam * torch.ones_like(self.bb2_slaml)\n        # If False, multiply by scale, if True multiply by 1.\n        sc_rgb = torch.where(pad_rgb, torch.ones_like(sc_rgb), sc_rgb)\n        sc_slamr = torch.where(pad_slamr, torch.ones_like(sc_slamr), sc_slamr)\n        sc_slaml = 
torch.where(pad_slaml, torch.ones_like(sc_slaml), sc_slaml)\n\n        data = smart_cat(\n            [\n                self.bb3_object,\n                self.bb2_rgb * sc_rgb,\n                self.bb2_slaml * sc_slaml,\n                self.bb2_slamr * sc_slamr,\n                self.T_world_object,\n                self.sem_id,\n                self.inst_id,\n                self.prob,\n                self.moveable,\n            ],\n            dim=-1,\n        )\n        return self.__class__(data)\n\n    def crop_bb2(self, left_top_rgb: Tuple[float], left_top_slam: Tuple[float]):\n        \"\"\"Update the 2d bb parameters after cropping the underlying images.\n        All 2d bbs are cropped by the same crop specified for the frame of the\n        2d bb (RGB vs SLAM).\n        left_top_* is assumed to be a 2D tuple of the left top corner of the crop.\n        \"\"\"\n        # replicate the crop offset to match the 2d bb format (xmin, xmax, ymin, ymax)\n        left_top_rgb = self._data.new_tensor(\n            (left_top_rgb[0], left_top_rgb[0], left_top_rgb[1], left_top_rgb[1])\n        )\n        left_top_slam = self._data.new_tensor(\n            (left_top_slam[0], left_top_slam[0], left_top_slam[1], left_top_slam[1])\n        )\n\n        # Expand the offsets to the leading dimensions if self._data is batched\n        if len(self._data.shape) > 1:\n            expand_dim = list(self._data.shape[:-1]) + [1]\n            left_top_rgb = left_top_rgb.repeat(expand_dim)\n            left_top_slam = left_top_slam.repeat(expand_dim)\n\n        data = smart_cat(\n            [\n                self.bb3_object,\n                self.bb2_rgb - left_top_rgb,\n                self.bb2_slaml - left_top_slam,\n                self.bb2_slamr - left_top_slam,\n                self.T_world_object,\n                self.sem_id,\n                self.inst_id,\n                self.prob,\n                self.moveable,\n            ],\n            dim=-1,\n        )\n        return self.__class__(data)\n\n  
  def rotate_bb2_cw(self, image_sizes: List[Tuple[int]]):\n        \"\"\"Update the 2d bb parameters after rotating the underlying images.\n        Args:\n          image_sizes: List of original image sizes before the rotation.\n                       The order of the image sizes should be [(w_rgb, h_rgb), (w_slaml, h_slaml), (w_slamr, h_slamr)].\n        \"\"\"\n        # Check the input sizes early\n        assert len(image_sizes) == 3, (\n            f\"the image sizes of 3 video streams should be given, but got {len(image_sizes)}\"\n        )\n        for s in image_sizes:\n            assert len(s) == 2\n\n        # rotate the obbs stream by stream\n        bb2_rgb_cw = rot_obb2_cw(self.bb2_rgb.clone(), image_sizes[0])\n        bb2_slaml_cw = rot_obb2_cw(self.bb2_slaml.clone(), image_sizes[1])\n        bb2_slamr_cw = rot_obb2_cw(self.bb2_slamr.clone(), image_sizes[2])\n\n        data = smart_cat(\n            [\n                self.bb3_object,\n                bb2_rgb_cw,\n                bb2_slaml_cw,\n                bb2_slamr_cw,\n                self.T_world_object,\n                self.sem_id,\n                self.inst_id,\n                self.prob,\n                self.moveable,\n            ],\n            dim=-1,\n        )\n        return self.__class__(data)\n\n    def rectify_obb2(self, fisheye_cams: List[CameraTW], pinhole_cams: List[CameraTW]):\n        rect_bb2s = []\n        for idx, (fisheye_cam, pinhole_cam) in enumerate(\n            zip(fisheye_cams, pinhole_cams)\n        ):\n            if idx == 0:  # rgb\n                bb2 = self.bb2_rgb\n            elif idx == 1:  # slaml\n                bb2 = self.bb2_slaml\n            else:  # slamr\n                bb2 = self.bb2_slamr\n\n            tl_points = bb2[..., [0, 2]].clone()  # top-left\n            bl_points = bb2[..., [0, 3]].clone()  # bottom-left\n            br_points = bb2[..., [1, 3]].clone()  # bottom-right\n            tr_points = bb2[..., [1, 2]].clone()  # 
top-right\n            visible_points = self.visible_bb3_ind(idx)\n\n            tl_rays, _ = fisheye_cam.unproject(tl_points)\n            br_rays, _ = fisheye_cam.unproject(br_points)\n            bl_rays, _ = fisheye_cam.unproject(bl_points)\n            tr_rays, _ = fisheye_cam.unproject(tr_points)\n\n            rect_tl_pts, valid = pinhole_cam.project(tl_rays)\n            rect_br_pts, valid = pinhole_cam.project(br_rays)\n            rect_bl_pts, valid = pinhole_cam.project(bl_rays)\n            rect_tr_pts, valid = pinhole_cam.project(tr_rays)\n            rect_concat = torch.cat(\n                [rect_tl_pts, rect_br_pts, rect_bl_pts, rect_tr_pts], dim=-1\n            )\n            xmin, _ = torch.min(rect_concat[..., 0::2], dim=-1, keepdim=True)\n            xmax, _ = torch.max(rect_concat[..., 0::2], dim=-1, keepdim=True)\n            ymin, _ = torch.min(rect_concat[..., 1::2], dim=-1, keepdim=True)\n            ymax, _ = torch.max(rect_concat[..., 1::2], dim=-1, keepdim=True)\n\n            # clamp to the pinhole image bounds\n            width = pinhole_cam.size.reshape(-1, 2)[0][0]\n            height = pinhole_cam.size.reshape(-1, 2)[0][1]\n            xmin = torch.clamp(xmin, min=0, max=width - 1)\n            xmax = torch.clamp(xmax, min=0, max=width - 1)\n            ymin = torch.clamp(ymin, min=0, max=height - 1)\n            ymax = torch.clamp(ymax, min=0, max=height - 1)\n\n            rect_bb2 = torch.cat([xmin, xmax, ymin, ymax], dim=-1)\n\n            # remove the ones without any area\n            areas = (rect_bb2[..., 1] - rect_bb2[..., 0]) * (\n                rect_bb2[..., 3] - rect_bb2[..., 2]\n            )\n            areas = areas.unsqueeze(-1)\n            areas = areas.repeat(*([1] * (areas.ndim - 1)), 4)\n            rect_bb2[areas <= 0] = PAD_VAL\n            rect_bb2[~visible_points] = PAD_VAL\n            rect_bb2s.append(rect_bb2)\n\n        data = smart_cat(\n            [\n                self.bb3_object,\n                rect_bb2s[0],\n             
   rect_bb2s[1],\n                rect_bb2s[2],\n                self.T_world_object,\n                self.sem_id,\n                self.inst_id,\n                self.prob,\n                self.moveable,\n            ],\n            dim=-1,\n        )\n        return self.__class__(data)\n\n    def get_pseudo_bb2(\n        self,\n        cam: CameraTW,\n        T_world_rig: PoseTW,\n        num_samples_per_edge: int = 1,\n        return_frac_valids: bool = False,\n    ):\n        \"\"\"\n        get the 2d bbs of the projection of the 3d bbs into all given camera view points.\n        This is done by sampling points on the 3d bb edges (see\n        bb3edge_pts_object), projecting them and then computing the 2d bbs from\n        the valid projected points. The caller has to make sure the ObbTW has valid\n        3d bbs data\n\n        num_samples_per_edge == 1 and num_samples_per_edge == 2 are equivalent\n        (in both cases we project the obb corners into the frames to compute 2d bbs)\n        \"\"\"\n        assert self._data.shape[-2] > 0, \"No valid 3d bbs data found!\"\n        return bb2d_from_project_bb3d(\n            self, cam, T_world_rig, num_samples_per_edge, return_frac_valids\n        )\n\n    def get_bb2_heights(self, cam_id):\n        bb2s = self.bb2(cam_id)\n        valid_bb2s = self.visible_bb3_ind(cam_id)\n        heights = bb2s[..., 3] - bb2s[..., 2]\n        heights[~valid_bb2s] = -1\n        return heights\n\n    def get_bb2_widths(self, cam_id):\n        bb2s = self.bb2(cam_id)\n        valid_bb2s = self.visible_bb3_ind(cam_id)\n        widths = bb2s[..., 1] - bb2s[..., 0]\n        widths[~valid_bb2s] = -1\n        return widths\n\n    def get_bb2_areas(self, cam_id):\n        bb2s = self.bb2(cam_id)\n        valid_bb2s = self.visible_bb3_ind(cam_id)\n        areas = (bb2s[..., 1] - bb2s[..., 0]) * (bb2s[..., 3] - bb2s[..., 2])\n        areas[~valid_bb2s] = -1\n        return areas\n\n    def get_bb2_centers(self, cam_id):\n        bb2s 
= self.bb2(cam_id)\n        valid_bb2s = self.visible_bb3_ind(cam_id)\n        center_x = (bb2s[..., 0:1] + bb2s[..., 1:2]) / 2.0\n        center_y = (bb2s[..., 2:3] + bb2s[..., 3:4]) / 2.0\n        center_2d = torch.cat([center_x, center_y], -1)\n        center_2d[~valid_bb2s] = -1\n        return center_2d\n\n    def batch_points_inside_bb3(self, pts_world: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        checks if a set of points is inside the 3d bounding box\n        expected input shape is N x 3 where N is the number of points and the\n        number of obbs in self.\n        \"\"\"\n        assert pts_world.shape == self.T_world_object.t.shape\n        pts_object = self.T_world_object.inverse().batch_transform(pts_world)\n        inside_min = (pts_object > self.bb3_min_object).all(-1)\n        inside_max = (pts_object < self.bb3_max_object).all(-1)\n        return torch.logical_and(inside_min, inside_max)\n\n    def points_inside_bb3(\n        self, pts_world: torch.Tensor, scale_obb: float = 1.0\n    ) -> torch.Tensor:\n        \"\"\"\n        checks if a set of points is inside the 3d bounding box\n        \"\"\"\n        assert self.ndim == 1 and pts_world.ndim == 2\n        pts_object = self.T_world_object.inverse().transform(pts_world)\n        inside_min = (pts_object > self.bb3_min_object * scale_obb).all(-1)\n        inside_max = (pts_object < self.bb3_max_object * scale_obb).all(-1)\n        return torch.logical_and(inside_min, inside_max)\n\n    def _transform(self, T_new_world):\n        \"\"\"\n        in place transform T_world_object as T_new_object = T_new_world @ T_world_object\n        \"\"\"\n        T_world_object = self.T_world_object\n        T_new_object = T_new_world @ T_world_object\n        self.set_T_world_object(T_new_object)\n\n    def transform(self, T_new_world):\n        \"\"\"\n        transform T_world_object as T_new_object = T_new_world @ T_world_object\n        \"\"\"\n        obb_new = self.clone()\n        
obb_new._transform(T_new_world)\n        return obb_new\n\n    def _transform_object(self, T_object_new):\n        \"\"\"\n        in place transform T_world_object as T_world_new = T_world_object @ T_object_new\n        \"\"\"\n        T_world_object = self.T_world_object\n        T_world_new = T_world_object @ T_object_new\n        self.set_T_world_object(T_world_new)\n\n    def filter_by_sem_id(self, keep_sem_ids):\n        valid = self._data.new_zeros(self.shape[:-1]).bool()\n        for si in keep_sem_ids:\n            valid = valid | (self.sem_id == si)[..., 0]\n        self._data[~valid] = PAD_VAL\n        return self\n\n    def filter_by_prob(self, prob_thr: float):\n        # since PAD_VAL is -1 this will work fine with padded entries\n        invalid = self.prob.squeeze(-1) < prob_thr\n        self._data[invalid] = PAD_VAL\n        return self\n\n    def filter_bb2_center_by_radius(self, calib, cam_id):\n        \"\"\"\n        Inputs\n            calib: CameraTW : shaped ... x 34, matching leading dims with self\n            cam_id : int : integer corresponding to which bb2ds to use (0: rgb, 1: slaml, 2: slamr)\n        \"\"\"\n        # Remove detections centers outside of valid_radius.\n        centers = self.get_bb2_centers(cam_id)\n        inside = calib.in_radius(centers)\n        self._data[~inside, :] = PAD_VAL\n        return self\n\n    def voxel_grid(self, vD: int, vH: int, vW: int):\n        \"\"\"\n        Input: Works on obbs shaped (B) x 34\n        Output: world points sampled uniformly in a voxel grid (B) x vW*vH*vD x 3\n        \"\"\"\n        x_min, x_max, y_min, y_max, z_min, z_max = self.bb3_object.unbind(-1)\n        dW = (x_max - x_min) / vW\n        dH = (y_max - y_min) / vH\n        dD = (z_max - z_min) / vD\n        # take the center position of each voxel\n        rng_x = tensor_linspace(\n            x_min + dW / 2, x_max - dW / 2, steps=vW, device=self.device\n        )\n        rng_y = tensor_linspace(\n            y_min + dH 
 / 2, y_max - dH / 2, steps=vH, device=self.device\n        )\n        rng_z = tensor_linspace(\n            z_min + dD / 2, z_max - dD / 2, steps=vD, device=self.device\n        )\n        if self.ndim > 1:\n            if self.ndim > 2:\n                raise NotImplementedError\n            B = self.shape[0]\n            xs, ys, zs = [], [], []\n            for b in range(B):\n                xx, yy, zz = torch.meshgrid(rng_x[b], rng_y[b], rng_z[b], indexing=\"ij\")\n                xs.append(xx)\n                ys.append(yy)\n                zs.append(zz)\n            xx = torch.stack(xs)\n            yy = torch.stack(ys)\n            zz = torch.stack(zs)\n        else:\n            xx, yy, zz = torch.meshgrid(rng_x, rng_y, rng_z, indexing=\"ij\")\n        vox_v = torch.stack([xx, yy, zz], axis=-1)\n        # keep the leading batch dims (if any) so the unbatched case does not\n        # reference an undefined B\n        vox_v = vox_v.reshape(*self.shape[:-1], -1, 3)\n        T_wv = self.T_world_object\n        vox_w = T_wv * vox_v\n        return vox_w\n\n    def __repr__(self):\n        return f\"ObbTW {self.shape} {self.dtype} {self.device}\"\n\n\ndef _single_transform_obbs(obbs_padded, Ts_other_world):\n    assert obbs_padded.ndim == 3  # T x N x C\n    assert Ts_other_world.ndim == 2 and Ts_other_world.shape[0] == 1  # 1 x C\n    T, N, C = obbs_padded.shape\n    if T == 0:\n        # Directly return the input since T=0 and there are no obbs to transform.\n        return obbs_padded\n    obbs_transformed = []\n    for t in range(T):\n        # clone so that we get a new transformed obbs object.\n        obbs = obbs_padded[t, ...].remove_padding().clone()\n        obbs._transform(Ts_other_world)\n        obbs_transformed.append(obbs.add_padding(N))\n    obbs_transformed = ObbTW(smart_stack(obbs_transformed))\n    return obbs_transformed\n\n\ndef _batched_transform_obbs(obbs_padded, Ts_other_world):\n    assert obbs_padded.ndim == 4  # B x T x N x C\n    assert Ts_other_world.ndim == 3  # B x 1 x C\n    B, T, N, C = 
obbs_padded.shape\n    obbs_transformed = []\n    for b in range(B):\n        obbs_transformed.append(\n            _single_transform_obbs(obbs_padded[b], Ts_other_world[b])\n        )\n    obbs_transformed = ObbTW(smart_stack(obbs_transformed))\n    return obbs_transformed\n\n\ndef transform_obbs(obbs_padded, Ts_other_world):\n    \"\"\"\n    Transform padded obbs from the world coordinate system to an \"other\"\n    coordinate system.\n    \"\"\"\n    if obbs_padded.ndim == 4:\n        return _batched_transform_obbs(obbs_padded, Ts_other_world)\n    return _single_transform_obbs(obbs_padded, Ts_other_world)\n\n\ndef rot_obb2_cw(bb2: torch.Tensor, size: Tuple[int]):\n    bb2_ori = bb2.clone()\n    # exchange (xmin, xmax, ymin, ymax) -> (ymax, ymin, xmin, xmax)\n    bb2 = bb2[..., [3, 2, 0, 1]]\n    # x_new = height - x_old - 1\n    bb2[..., 0:2] = size[1] - bb2[..., 0:2] - 1\n    # bring back the invalid entries.\n    bb2[bb2_ori < 0] = bb2_ori[bb2_ori < 0]\n    return bb2\n\n\ndef project_bb3d_onto_image(\n    obbs: ObbTW, cam: CameraTW, T_world_rig: PoseTW, num_samples_per_edge: int = 1\n):\n    \"\"\"\n    project 3d bb edge points into snippet images defined by T_world_rig and\n    camera cam. 
The assumption is that obbs are in the \"world\" coordinate system\n    of T_world_rig.\n    Supports batched operation.\n\n    Args:\n        obbs (ObbTW): obbs to project; shape is (Bx)(Tx)Nx34\n        cam (CameraTW): camera to project to; shape is (Bx)TxC where T is the snippet dimension;\n        T_world_rig (PoseTW): T_world_rig defining where the camera rig is; shape is (Bx)Tx12\n        num_samples_per_edge (int): how many points to sample per edge to\n            compute 2d bb (1, and 2 means only corners)\n    Returns:\n        bb3_corners_im (Tensor): bb3 corners in the image coordinate system; shape is (Bx)TxNx8x2\n        bb3_valids (Tensor): valid indices of bb3_corners_im (indicates which\n            corners lie within the images); shape is (Bx)TxNx8\n    \"\"\"\n    obb_dim = obbs.dim()\n    # support 3 sets of input shapes\n    if obb_dim == 2:  # Nx34\n        # cam: TxC, T_world_rig: Tx12\n        assert (\n            cam.dim() == 2\n            and T_world_rig.dim() == 2\n            and cam.shape[0]\n            == T_world_rig.shape[0]  # T dim should be the same for cam and T_world_rig\n        ), (\n            f\"Unsupported input shapes: obb: {obbs.shape}, cam: {cam.shape}, T_world_rig: {T_world_rig.shape}.\"\n        )\n\n        # To the consistent shapes\n        obbs = obbs.unsqueeze(0).unsqueeze(0)  # expand to B(1)xT(1)xNx34\n        cam = cam[None, ...]  # expand to B(1)xTxC\n        T_world_rig = T_world_rig[None, ...]  
# expand to B(1)xTx12\n        B, T = cam.shape[0:2]\n        N = obbs.shape[-2]\n        obbs = obbs.expand(B, T, *obbs.shape[-2:])  # repeat to real T: B(1)xTxNx34\n\n    elif obb_dim == 3:  # BxNx34\n        # cam: BxTxC, T_world_rig: BxTx12\n        assert cam.dim() == 3 and T_world_rig.dim() == 3\n        # B dim should be the same\n        assert obbs.shape[0] == cam.shape[0] and obbs.shape[0] == T_world_rig.shape[0]\n        # T dim of cam and pose should be the same\n        assert cam.shape[1] == T_world_rig.shape[1]\n\n        # To the consistent shapes\n        obbs = obbs.unsqueeze(1)  # expand to BxT(1)xNx34\n        B, T = cam.shape[0:2]\n        obbs = obbs.expand(B, T, *obbs.shape[-2:])\n\n    elif obb_dim == 4:  # BxTxNx34\n        pass\n    else:\n        raise ValueError(\n            f\"Unsupported input shapes: obb: {obbs.shape}, cam: {cam.shape}, T_world_rig: {T_world_rig.shape}.\"\n        )\n\n    # check if all tensors are of correct shapes.\n    assert obbs.dim() == 4 and cam.dim() == 3 and T_world_rig.dim() == 3, (\n        f\"The shapes of obbs, cam and T_world_rig should be BxTxNx34, BxTxC, and BxTx12, respectively. 
However, we got obbs: {obbs.shape}, cam: {cam.shape}, T_world_rig: {T_world_rig.shape}\"\n    )\n    assert (\n        obbs.shape[0:2] == cam.shape[0:2] and obbs.shape[0:2] == T_world_rig.shape[0:2]\n    ), (\n        f\"The BxT dims should be the same for all tensors, but got obbs: {obbs.shape}, cam: {cam.shape}, T_world_rig: {T_world_rig.shape}\"\n    )\n\n    B, T = cam.shape[0:2]\n    N = obbs.shape[-2]\n    assert N > 0, \"obbs have to exist for this frame\"\n    # Get pose of camera.\n    T_world_cam = T_world_rig @ cam.T_camera_rig.inverse()\n    # Project the 3D BB corners into the image.\n    # BxTxNx8x3 -> BxTxN*8x3\n    if num_samples_per_edge <= 2:\n        bb3pts_world = obbs.bb3corners_world.view(B, T, -1, 3)\n    else:\n        bb3pts_object = obbs.bb3edge_pts_object(num_samples_per_edge)\n        bb3pts_world = obbs.T_world_object * bb3pts_object\n        bb3pts_world = bb3pts_world.view(B, T, -1, 3)\n    Npt = bb3pts_world.shape[2]\n    T_world_cam = T_world_cam.unsqueeze(2).repeat(1, 1, Npt, 1)\n    bb3pts_cam = (\n        T_world_cam.inverse()\n        .view(-1, 12)\n        .batch_transform(bb3pts_world.view(-1, 3))\n        .view(B, T, -1, 3)\n    )\n    bb3pts_im, bb3pts_valids = cam.project(bb3pts_cam)\n    bb3pts_im = bb3pts_im.view(B, T, N, -1, 2)\n    bb3pts_valids = bb3pts_valids.detach().view(B, T, N, -1)\n\n    if obb_dim == 2:\n        # remove B dim if it didn't exist before.\n        bb3pts_im = bb3pts_im.squeeze(0)\n        bb3pts_valids = bb3pts_valids.squeeze(0)\n    return bb3pts_im, bb3pts_valids\n\n\ndef bb2d_from_project_bb3d(\n    obbs: ObbTW,\n    cam: CameraTW,\n    T_world_rig: PoseTW,\n    num_samples_per_edge: int = 1,\n    return_frac_valids: bool = False,\n):\n    \"\"\"\n    get 2d bbs around the 3d bb corners of obbs projected into the image coordinate system\n    defined by T_world_rig and camera cam. 
The assumption is that obbs are in the\n    \"world\" coordinate system of T_world_rig.\n\n    This is done by sampling points on the 3d bb edges (see bb3edge_pts_object),\n    projecting them and then computing the 2d bbs from the valid projected\n    points.\n\n    Supports batched operation.\n\n    Args:\n        obbs (ObbTW): obbs to project; shape is (Bx)Nx34\n        cam (CameraTW): camera to project to; shape is (Bx)TxC where T is the snippet dimension;\n        T_world_rig (PoseTW): T_world_rig defining where the camera rig is; shape is (Bx)Tx12\n    Returns:\n        bb2s (Tensor): 2d bounding boxes in the image coordinate system; shape is (Bx)TxNx4\n        bb2s_valid (Tensor): valid indices of bb2s; shape is (Bx)TxN\n    \"\"\"\n    from torchvision.ops.boxes import box_iou\n\n    bb3corners_im, bb3corners_valids = project_bb3d_onto_image(\n        obbs, cam, T_world_rig, num_samples_per_edge\n    )\n    # get image points that will min and max reduce correctly given the valid masks\n    bb3corners_im_min = torch.where(\n        bb3corners_valids.unsqueeze(-1).expand_as(bb3corners_im),\n        bb3corners_im,\n        999999 * torch.ones_like(bb3corners_im),\n    )\n    bb3corners_im_max = torch.where(\n        bb3corners_valids.unsqueeze(-1).expand_as(bb3corners_im),\n        bb3corners_im,\n        -999999 * torch.ones_like(bb3corners_im),\n    )\n    # compute 2d bounding boxes\n    bb2s_min = torch.min(bb3corners_im_min, dim=-2)[0]\n    bb2s_max = torch.max(bb3corners_im_max, dim=-2)[0]\n    bb2s = torch.stack(\n        [bb2s_min[..., 0], bb2s_max[..., 0], bb2s_min[..., 1], bb2s_max[..., 1]], dim=-1\n    )\n    # min < max so that it's a valid box.\n    non_empty_boxes = (bb2s[..., 0] < bb2s[..., 1]) & (bb2s[..., 2] < bb2s[..., 3])\n    if cam.is_linear:\n        bb2s_full = bb2s.clone()\n        # Clamp based on the camera size for linear cameras.\n        # Note that this could generate very big/loose bounding boxes if the object is badly truncated 
due to out of view.\n        bb2s[..., 0:2] = torch.clamp(\n            bb2s[..., 0:2], min=0, max=cam.size.view(-1, 2)[0, 0] - 1\n        )\n        bb2s[..., 2:4] = torch.clamp(\n            bb2s[..., 2:4], min=0, max=cam.size.view(-1, 2)[0, 1] - 1\n        )\n        # filter out empty boxes.\n        bb2s_valid = torch.logical_and(non_empty_boxes, bb3corners_valids.any(-1))\n        if return_frac_valids:\n            frac_valid = torch.zeros_like(bb2s_valid).float()\n            frac_valid[non_empty_boxes] = box_iou(\n                bb2_xxyy_to_xyxy(bb2s[non_empty_boxes]),\n                bb2_xxyy_to_xyxy(bb2s_full[non_empty_boxes]),\n            ).diagonal()\n    else:\n        # count number of valid points\n        num_points = bb3corners_valids.count_nonzero(-1)\n        # valid 2d bbs are non-empty and have at least 1/6 of the edge sample\n        # points in the valid image region\n        bb2s_valid = torch.logical_and(\n            non_empty_boxes, num_points > num_samples_per_edge * 2\n        )\n        if return_frac_valids:\n            frac_valid = num_points / bb3corners_valids.shape[-1]\n            frac_valid[~non_empty_boxes] = 0.0\n    if return_frac_valids:\n        return bb2s, bb2s_valid, frac_valid\n    return bb2s, bb2s_valid\n\n\ndef bb2_xxyy_to_xyxy(bb2s):\n    # check if the input is xxyy\n    is_xxyy = torch.logical_and(\n        bb2s[..., 0] <= bb2s[..., 1], bb2s[..., 2] <= bb2s[..., 3]\n    )\n    is_xxyy = is_xxyy.all()\n    if not is_xxyy:\n        logger.warning(\"Input 2d bbx doesn't follow xxyy convention.\")\n    return bb2s[..., [0, 2, 1, 3]]\n\n\ndef bb2_xyxy_to_xxyy(bb2s):\n    # check if the input is xyxy\n    is_xyxy = torch.logical_and(\n        bb2s[..., 0] <= bb2s[..., 2], bb2s[..., 1] <= bb2s[..., 3]\n    )\n    is_xyxy = is_xyxy.all()\n    if not is_xyxy:\n        logger.warning(\"Input 2d bbx doesn't follow xyxy convention.\")\n    return bb2s[..., [0, 2, 1, 3]]\n\n\ndef bb3_xyzxyz_to_xxyyzz(bb3s):\n    \"\"\"\n  
  take bb3 in xyzxyz format and return xxyyzz format.\n    \"\"\"\n    return bb3s[..., [0, 3, 1, 4, 2, 5]]\n\n\ndef bb3_xyz_xyz_to_xxyyzz(bb3s_min, bb3s_max):\n    \"\"\"\n    take min and max points of the bb3 and return xxyyzz format\n    \"\"\"\n    return torch.cat([bb3s_min, bb3s_max], -1)[..., [0, 3, 1, 4, 2, 5]]\n\n\ndef rnd_obbs(N: int = 1, num_semcls: int = 10, bb3_min_diag=0.1, bb2_min_diag=10):\n    pts3_min = torch.randn(N, 3)\n    pts3_max = pts3_min + bb3_min_diag + torch.randn(N, 3).abs()\n    pts2_min = torch.randn(N, 2)\n    pts2_max = pts2_min + bb2_min_diag + torch.randn(N, 2).abs()\n\n    obb = ObbTW.from_lmc(\n        bb3_object=bb3_xyzxyz_to_xxyyzz(torch.cat([pts3_min, pts3_max], -1)),\n        prob=torch.ones(N),\n        bb2_rgb=bb2_xyxy_to_xxyy(torch.cat([pts2_min, pts2_max], -1)),\n        sem_id=torch.randint(low=0, high=num_semcls - 1, size=[N]),\n        T_world_object=PoseTW.from_aa(torch.randn(N, 3), 10.0 * torch.randn(N, 3)),\n    )\n    return obb\n\n\ndef obb_time_union(obbs, pad_size=128):\n    \"\"\"\n    Take frame level ground truth shaped BxTxNxC and take the union\n    over the time dimensions using the instance id to extend to snippet level\n    obbs shaped BxNxC.\n    \"\"\"\n    # T already merged somewhere else.\n    if obbs.ndim == 3:\n        return obbs\n\n    assert obbs.ndim == 4, \"Only B x T x N x C supported\"\n    new_obbs = []\n    for obb in obbs:\n        new_obb = []\n        flat_time_obb = obb.clone().reshape(-1, 34)\n        unique = flat_time_obb.inst_id.unique()\n        for uni in unique:\n            if uni == PAD_VAL:\n                continue\n            found = int(torch.argwhere(flat_time_obb.inst_id == uni)[0, 0])\n            found_obb = flat_time_obb[found].clone()\n            new_obb.append(found_obb)\n        if len(new_obb) == 0:\n            print(f\"Adding empty OBB in time_union {obbs.shape}\")\n            new_obb.append(ObbTW().reshape(-1).to(obbs._data))\n        
new_obbs.append(torch.stack(new_obb).add_padding(pad_size))\n    new_obbs = torch.stack(new_obbs)\n    # Remove all bb2 observations since we no longer know which frame in time they came from.\n    # Note: we set the visibility for the merged obbs so that they can be used for evaluation.\n    pad_mask = new_obbs.get_padding_mask()\n    new_obbs.set_bb2(cam_id=0, bb2d=1)\n    new_obbs.set_bb2(cam_id=1, bb2d=1)\n    new_obbs.set_bb2(cam_id=2, bb2d=1)\n    new_obbs._data[pad_mask] = PAD_VAL\n    return new_obbs\n\n\ndef obb_filter_outside_volume(obbs, T_ws, T_wv, voxel_extent, border=0.1):\n    \"\"\"\n    Remove obbs outside a volume of size voxel_extent, e.g. from a lifter volume.\n    Obbs are filtered based on their center point being inside the volume, and\n    are additionally filtered near the border.\n    \"\"\"\n    assert obbs.ndim == 3, \"Only B x N x C supported\"\n    T_vs = T_wv.inverse() @ T_ws\n    obbs_v = obbs.transform(T_vs.unsqueeze(1))\n    centers_v = obbs_v.bb3_center_world\n    cx = centers_v[:, :, 0]\n    cy = centers_v[:, :, 1]\n    cz = centers_v[:, :, 2]\n    x_min = voxel_extent[0]\n    x_max = voxel_extent[1]\n    y_min = voxel_extent[2]\n    y_max = voxel_extent[3]\n    z_min = voxel_extent[4]\n    z_max = voxel_extent[5]\n    valid = (obbs_v.inst_id != PAD_VAL).squeeze(-1)\n    inside = (\n        (cx > (x_min + border))\n        & (cy > (y_min + border))\n        & (cz > (z_min + border))\n        & (cx < (x_max - border))\n        & (cy < (y_max - border))\n        & (cz < (z_max - border))\n    )\n    remove = valid & ~inside\n    obbs._data[remove, :] = PAD_VAL\n    return obbs\n\n\ndef tensor_linspace(start, end, steps, device):\n    \"\"\"\n    Vectorized version of torch.linspace.\n    Inputs:\n    - start: Tensor of any shape\n    - end: Tensor of the same shape as start\n    - steps: Integer\n    Returns:\n    - out: Tensor of shape start.size() + (steps,), such that\n      out.select(-1, 0) == start, out.select(-1, 
-1) == end,\n      and the other elements of out linearly interpolate between\n      start and end.\n    \"\"\"\n    assert start.size() == end.size()\n    view_size = start.size() + (1,)\n    w_size = (1,) * start.dim() + (steps,)\n    out_size = start.size() + (steps,)\n\n    start_w = torch.linspace(1, 0, steps=steps, device=device).to(start)\n    start_w = start_w.view(w_size).expand(out_size)\n    end_w = torch.linspace(0, 1, steps=steps, device=device).to(start)\n    end_w = end_w.view(w_size).expand(out_size)\n\n    start = start.contiguous().view(view_size).expand(out_size)\n    end = end.contiguous().view(view_size).expand(out_size)\n\n    out = start_w * start + end_w * end\n    return out\n\n\ndef make_obb(sz, position, prob=1.0, roll=0.0, pitch=0.0, yaw=0.1):\n    e_angles = torch.tensor([roll, pitch, yaw]).reshape(-1, 3)\n    R = rotation_from_euler(e_angles).reshape(3, 3)\n    T_voxel_object = PoseTW.from_Rt(R, torch.tensor(position))\n    bb3 = [\n        -sz[0] / 2.0,\n        sz[0] / 2.0,\n        -sz[1] / 2.0,\n        sz[1] / 2.0,\n        -sz[2] / 2.0,\n        sz[2] / 2.0,\n    ]\n    return ObbTW.from_lmc(\n        bb3_object=torch.tensor(bb3),\n        prob=[prob],\n        T_world_object=T_voxel_object,\n    )\n\n\n# =====> Main function for 3D IoU computation. <=======\ndef obb_iou3d(obb1: ObbTW, obb2: ObbTW, samp_per_dim=32):\n    \"\"\"\n    Computes the intersection of two boxes by sampling points uniformly in\n    x,y,z dims.\n\n    samp_per_dim: int, number of samples per dimension, e.g. 
if 8, then 8x8x8\n                       increase for more accuracy but less speed\n                       8: fast but not so accurate\n                       32: medium\n                       128: most accurate but slow\n    \"\"\"\n    assert obb1.ndim == 2\n    assert obb2.ndim == 2\n\n    B1 = obb1.shape[0]\n    B2 = obb2.shape[0]\n    vol1 = obb1.bb3_volumes\n    vol2 = obb2.bb3_volumes\n\n    dim = samp_per_dim\n    points1_w = obb1.voxel_grid(vD=dim, vH=dim, vW=dim)\n    points2_w = obb2.voxel_grid(vD=dim, vH=dim, vW=dim)\n    num_samples = points1_w.shape[1]\n\n    isin21 = is_point_inside_box(points2_w, obb1.bb3corners_world, verbose=True)\n    num21 = isin21.sum(dim=-1)\n    isin12 = is_point_inside_box(points1_w, obb2.bb3corners_world, verbose=True)\n    num12 = isin12.sum(dim=-1)\n\n    inters12 = vol1.view(B1, 1) * num12.view(B1, B2)\n    inters21 = vol2.view(B2, 1) * num21.view(B2, B1)\n    inters = (inters12 + inters21.transpose(1, 0)) / 2.0\n    union = (vol1.view(B1, 1) * num_samples) + (vol2.view(1, B2) * num_samples) - inters\n    iou = inters / union\n    return iou\n\n\ndef is_point_inside_box(points: torch.Tensor, box: torch.Tensor, verbose=False):\n    \"\"\"\n    Determines whether points are inside the boxes\n    Args:\n        points: tensor of shape (B1, P, 3) of the points\n        box: tensor of shape (B2, 8, 3) of the corners of the boxes\n    Returns:\n        inside: bool tensor of whether point (row) is in box (col) shape (B1, B2, P)\n    \"\"\"\n    device = box.device\n    B1 = points.shape[0]\n    B2 = box.shape[0]\n    P = points.shape[1]\n\n    normals = box_planar_dir(box)  # (B2, 6, 3)\n    box_planes = get_plane_verts(box)  # (B2, 6, 4, 3)\n    NP = box_planes.shape[1]  # = 6\n\n    # a point p is inside the box if it \"inside\" all planes of the box\n    # so we run the checks\n    ins = torch.zeros((B1, B2, P, NP), device=device, dtype=torch.bool)\n    # ins = []\n    for i in range(NP):\n        is_in = is_inside(points, 
box_planes[:, i], normals[:, i])\n        ins[:, :, :, i] = is_in\n        # ins.append(is_in)\n    # ins = torch.stack(ins, dim=-1)\n\n    ins = ins.all(dim=-1)\n    return ins\n\n\ndef box_planar_dir(\n    box: torch.Tensor, dot_eps: float = DOT_EPS, area_eps: float = AREA_EPS\n) -> torch.Tensor:\n    \"\"\"\n    Finds the unit vector n which is perpendicular to each plane in the box\n    and points towards the inside of the box.\n    The planes are defined by `_box_planes`.\n    Since the shape is convex, we define the interior to be the direction\n    pointing to the center of the shape.\n    Args:\n       box: tensor of shape (B, 8, 3) of the vertices of the 3D box\n    Returns:\n       n: tensor of shape (B, 6, 3) of the unit vectors orthogonal to each face,\n          pointing towards the interior of the shape\n    \"\"\"\n    assert box.shape[1] == 8 and box.shape[2] == 3\n    # center point of each box\n    box_ctr = box.mean(dim=1).view(-1, 1, 3)\n    # box planes\n    plane_verts = get_plane_verts(box)  # (B, 6, 4, 3)\n    v0, v1, v2, v3 = plane_verts.unbind(2)\n    plane_ctr, n = get_plane_center_normal(plane_verts)\n    # Check all verts are coplanar\n    normv = F.normalize(v3 - v0, dim=-1).unsqueeze(2).reshape(-1, 1, 3)\n    nn = n.unsqueeze(3).reshape(-1, 3, 1)\n    dists = normv @ nn\n    if not (dists.abs() < dot_eps).all().item():\n        msg = \"Plane vertices are not coplanar\"\n        raise ValueError(msg)\n    # Check all faces have non zero area\n    area1 = torch.cross(v1 - v0, v2 - v0, dim=-1).norm(dim=-1) / 2\n    area2 = torch.cross(v3 - v0, v2 - v0, dim=-1).norm(dim=-1) / 2\n    if (area1 < area_eps).any().item() or (area2 < area_eps).any().item():\n        msg = \"Planes have zero areas\"\n        raise ValueError(msg)\n    # We can write:  `box_ctr = plane_ctr + a * e0 + b * e1 + c * n`, (1).\n    # With <e0, n> = 0 and <e1, n> = 0, where <.,.> refers to the dot product,\n    # since e0 is orthogonal to n. 
Same for e1.\n    \"\"\"\n    # Below is how one would solve for (a, b, c)\n    # Solving for (a, b)\n    numF = verts.shape[0]\n    A = torch.ones((numF, 2, 2), dtype=torch.float32, device=device)\n    B = torch.ones((numF, 2), dtype=torch.float32, device=device)\n    A[:, 0, 1] = (e0 * e1).sum(-1)\n    A[:, 1, 0] = (e0 * e1).sum(-1)\n    B[:, 0] = ((box_ctr - plane_ctr) * e0).sum(-1)\n    B[:, 1] = ((box_ctr - plane_ctr) * e1).sum(-1)\n    ab = torch.linalg.solve(A, B)  # (numF, 2)\n    a, b = ab.unbind(1)\n    # solving for c\n    c = ((box_ctr - plane_ctr - a.view(numF, 1) * e0 - b.view(numF, 1) * e1) * n).sum(-1)\n    \"\"\"\n    # Since we know that <e0, n> = 0 and <e1, n> = 0 (e0 and e1 are orthogonal to n),\n    # the above solution is equivalent to\n    direc = F.normalize(box_ctr - plane_ctr, dim=-1)  # (6, 3)\n    c = (direc * n).sum(-1)\n    # If c is negative, then we revert the direction of n such that n points \"inside\"\n    negc = c < 0.0\n    n[negc] *= -1.0\n    # c[negc] *= -1.0\n    # Now (a, b, c) is the solution to (1)\n    return n\n\n\ndef get_plane_verts(box: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Return the vertex coordinates forming the planes of the box.\n    The computation here resembles the Meshes data structure.\n    But since we only want this tiny functionality, we abstract it out.\n    Args:\n        box: tensor of shape (B, 8, 3)\n    Returns:\n        plane_verts: tensor of shape (B, 6, 4, 3)\n    \"\"\"\n    device = box.device\n    B = box.shape[0]\n    faces = torch.tensor(_box_planes, device=device, dtype=torch.int64)  # (6, 4)\n    plane_verts = torch.stack([box[b, faces] for b in range(B)])  # (B, 6, 4, 3)\n    return plane_verts\n\n\ndef is_inside(\n    points: torch.Tensor,\n    plane: torch.Tensor,\n    normal: torch.Tensor,\n    return_proj: bool = True,\n):\n    \"\"\"\n    Computes whether point is \"inside\" the plane.\n    The definition of \"inside\" means that the point\n    has a positive component in 
the direction of the plane normal defined by n.\n    For example,\n                  plane\n                    |\n                    |         . (A)\n                    |--> n\n                    |\n         .(B)       |\n\n    Point (A) is \"inside\" the plane, while point (B) is \"outside\" the plane.\n    Args:\n      points: tensor of shape (B1, P, 3) of coordinates of a point\n      plane: tensor of shape (B2, 4, 3) of vertices of a box plane\n      normal: tensor of shape (B2, 3) of the unit \"inside\" direction on the plane\n      return_proj: bool whether to return the projected point on the plane\n    Returns:\n      is_inside: bool of shape (B2, P) of whether point is inside\n    \"\"\"\n    device = plane.device\n    assert plane.ndim == 3\n    assert normal.ndim == 2\n    assert points.ndim == 3\n    assert points.shape[2] == 3\n    B1 = points.shape[0]\n    B2 = plane.shape[0]\n    P = points.shape[1]\n    v0, v1, v2, v3 = plane.unbind(dim=1)\n    plane_ctr = plane.mean(dim=1)\n    e0 = F.normalize(v0 - plane_ctr, dim=1)\n    e1 = F.normalize(v1 - plane_ctr, dim=1)\n\n    dot1 = (e0.unsqueeze(1) @ normal.unsqueeze(2)).reshape(B2)\n    if not torch.allclose(dot1, torch.zeros((B2,), device=device), atol=1e-2):\n        raise ValueError(\"Input n is not perpendicular to the plane\")\n    dot2 = (e1.unsqueeze(1) @ normal.unsqueeze(2)).reshape(B2)\n    if not torch.allclose(dot2, torch.zeros((B2,), device=device), atol=1e-2):\n        raise ValueError(\"Input n is not perpendicular to the plane\")\n\n    # Every point p can be written as p = ctr + a e0 + b e1 + c n\n    # solving for c\n    # c = (point - ctr - a * e0 - b * e1).dot(n)\n    pts = points.view(B1, 1, P, 3)\n    ctr = plane_ctr.view(1, B2, 1, 3)\n    e0 = e0.view(1, B2, 1, 3)\n    e1 = e1.view(1, B2, 1, 3)\n    normal = normal.view(1, B2, 1, 3)\n\n    direc = torch.sum((pts - ctr) * normal, dim=-1)\n    ins = direc >= 0.0\n    return ins\n\n\ndef get_plane_center_normal(planes: 
torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Returns the center and normal of planes\n    Args:\n        planes: tensor of shape (B, P, 4, 3)\n    Returns:\n        center: tensor of shape (B, P, 3)\n        normal: tensor of shape (B, P, 3)\n    \"\"\"\n    B = planes.shape[0]\n\n    add_dim1 = False\n    if planes.ndim == 3:\n        planes = planes.unsqueeze(1)\n        add_dim1 = True\n\n    ctr = planes.mean(dim=2)  # (B, P, 3)\n    normals = torch.zeros_like(ctr)\n\n    v0, v1, v2, v3 = planes.unbind(dim=2)  # 4 x (B, P, 3)\n\n    P = planes.shape[1]\n    for t in range(P):\n        ns = torch.zeros((B, 6, 3), device=planes.device)\n        ns[:, 0] = torch.cross(v0[:, t] - ctr[:, t], v1[:, t] - ctr[:, t], dim=-1)\n        ns[:, 1] = torch.cross(v0[:, t] - ctr[:, t], v2[:, t] - ctr[:, t], dim=-1)\n        ns[:, 2] = torch.cross(v0[:, t] - ctr[:, t], v3[:, t] - ctr[:, t], dim=-1)\n        ns[:, 3] = torch.cross(v1[:, t] - ctr[:, t], v2[:, t] - ctr[:, t], dim=-1)\n        ns[:, 4] = torch.cross(v1[:, t] - ctr[:, t], v3[:, t] - ctr[:, t], dim=-1)\n        ns[:, 5] = torch.cross(v2[:, t] - ctr[:, t], v3[:, t] - ctr[:, t], dim=-1)\n        ii = torch.argmax(torch.norm(ns, dim=-1), dim=-1)\n        normals[:, t] = ns[torch.arange(B), ii]\n\n    if add_dim1:\n        ctr = ctr[:, 0]\n        normals = normals[:, 0]\n    normals = F.normalize(normals, dim=-1)\n    return ctr, normals\n"
  },
  {
    "path": "efm3d/aria/pose.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport math\nfrom typing import Dict, List, Tuple, Union\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\n\nfrom .tensor_wrapper import autocast, autoinit, smart_stack, TensorWrapper\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n# logger.setLevel(logging.DEBUG)\n\nIdentityPose = torch.tensor(\n    [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]\n).reshape(12)\n\nPAD_VAL = -1\n\n\ndef get_T_rot_z(angle: float):\n    T_rot_z = np.array(\n        [\n            [np.cos(angle), -np.sin(angle), 0.0, 0.0],\n            [np.sin(angle), np.cos(angle), 0.0, 0.0],\n            [0.0, 0.0, 1.0, 0.0],\n        ]\n    )\n    return torch.from_numpy(T_rot_z).float()\n\n\ndef skew_symmetric(v):\n    \"\"\"Create a skew-symmetric matrix from a (batched) vector of size (..., 3).\"\"\"\n    z = torch.zeros_like(v[..., 0])\n    M = torch.stack(\n        [\n            z,\n            -v[..., 2],\n            v[..., 1],\n            v[..., 2],\n            z,\n            -v[..., 0],\n            -v[..., 1],\n            v[..., 0],\n            z,\n        ],\n        dim=-1,\n    ).reshape(v.shape[:-1] + (3, 3))\n    return M\n\n\ndef inv_skew_symmetric(V):\n    \"\"\"Create a (batched) vector from a skew-symmetric matrix of size (..., 3, 3).\"\"\"\n    # average lower and uper triangular 
entries in case skew symmetric matrix\n    # has numeric errors.\n    VVT = 0.5 * (V - V.transpose(-2, -1))\n    return torch.stack(\n        [\n            -VVT[..., 1, 2],\n            VVT[..., 0, 2],\n            -VVT[..., 0, 1],\n        ],\n        -1,\n    )\n\n\ndef so3exp_map(w, eps: float = 1e-7):\n    \"\"\"Compute rotation matrices from batched twists.\n    Args:\n        w: batched 3D axis-angle vectors of size (..., 3).\n    Returns:\n        A batch of rotation matrices of size (..., 3, 3).\n    \"\"\"\n    theta = w.norm(p=2, dim=-1, keepdim=True)\n    small = theta < eps\n    div = torch.where(small, torch.ones_like(theta), theta)\n    W = skew_symmetric(w / div)\n    theta = theta[..., None]  # ... x 1 x 1\n    res = W * torch.sin(theta) + (W @ W) * (1 - torch.cos(theta))\n    res = torch.where(small[..., None], W, res)  # first-order Taylor approx\n    return torch.eye(3).to(W) + res\n\n\ndef so3log_map(R, eps: float = 1e-7):\n    trace = torch.diagonal(R, dim1=-1, dim2=-2).sum(-1)\n    cos = torch.clamp((trace - 1.0) * 0.5, -1, 1)\n    theta = torch.acos(cos).unsqueeze(-1).unsqueeze(-1)\n    ones = torch.ones_like(theta)\n    small = theta < eps\n    # compute factors and approximate them around 0 using second order\n    # taylor expansion (from WolframAlpha)\n    theta_over_sin_theta = torch.where(\n        small,\n        ones - (theta**2) / 6.0 + 7.0 * (theta**4) / 360.0,\n        theta / torch.sin(theta),\n    )\n    # compute log-map W of rotation R first\n    W = 0.5 * theta_over_sin_theta * (R - R.transpose(-1, -2))\n    omega = inv_skew_symmetric(W)\n    return omega\n\n\ndef interpolation_boundaries_alphas(times: torch.Tensor, interp_times: torch.Tensor):\n    \"\"\"\n    find the ids in times tensor that bound each of the interp_times timestamps\n    from below (lower_ids) and above (upper_ids).\n    If interp_times are outside the interval spanned by times, upper and lower\n    ids will both point to the boundary timestamps and the 
returned good boolean\n    tensor will be False at those interpolation timestamps.\n\n    Also return the alphas needed to interpolate a value as:\n    interp_value = alpha * value[lower_id] + (1-alpha)* value[upper_id]\n\n    Note that because the upper and lower ids are pointing to the boundary\n    timestamps when the interpolation time is outside the time interval,\n    applying the formula above will yield the values at the boundaries as a\n    reasonable \"interpolation\". No extrapolation will be performed. Again the\n    good values can be used to check which values are at the boundaries and\n    which ones are interpolated.\n    \"\"\"\n    times = times.unsqueeze(-2)\n    interp_times = interp_times.unsqueeze(-1)\n    dt = times - interp_times\n    if dt.dtype == torch.long:\n        dt_max = torch.iinfo(type=dt.dtype).max\n    else:\n        dt_max = torch.finfo(type=dt.dtype).max\n    dt_upper = torch.where(dt < 0.0, torch.ones_like(dt) * dt_max, dt)\n    dt_lower = torch.where(dt > 0.0, torch.ones_like(dt) * dt_max, -dt)\n    upper_alpha, upper_ids = torch.min(dt_upper, dim=-1)\n    lower_alpha, lower_ids = torch.min(dt_lower, dim=-1)\n    good = torch.logical_and(lower_alpha < dt_max, upper_alpha < dt_max)\n    upper_ids = torch.where(good, upper_ids, torch.maximum(lower_ids, upper_ids))\n    lower_ids = torch.where(good, lower_ids, torch.minimum(lower_ids, upper_ids))\n    assert (lower_ids <= upper_ids).all()\n    # nan_to_num handles the case where time and interpolation time are the same\n    # and hence this is a 0/0\n    # okay to go to floats now since the critical bit is the computation of the time difference\n    alpha = torch.nan_to_num(lower_alpha.float() / (lower_alpha + upper_alpha).float())\n    alpha = torch.where(good, alpha, torch.zeros_like(alpha))\n    return lower_ids, upper_ids, alpha, good\n\n\ndef quaternion_to_matrix(quaternions_wxyz: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Convert rotations given as quaternions to 
rotation matrices. Input quaternions\n    should be in wxyz format, with real part first, imaginary part last.\n\n    The function is copied from `quaternion_to_matrix` in Pytorch3d:\n    https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py\n\n    Args:\n        quaternions: quaternions with real part first,\n            as tensor of shape (..., 4).\n\n    Returns:\n        Rotation matrices as tensor of shape (..., 3, 3).\n    \"\"\"\n    r, i, j, k = torch.unbind(quaternions_wxyz, -1)\n    two_s = 2.0 / (quaternions_wxyz * quaternions_wxyz).sum(-1)\n\n    o = torch.stack(\n        (\n            1 - two_s * (j * j + k * k),\n            two_s * (i * j - k * r),\n            two_s * (i * k + j * r),\n            two_s * (i * j + k * r),\n            1 - two_s * (i * i + k * k),\n            two_s * (j * k - i * r),\n            two_s * (i * k - j * r),\n            two_s * (j * k + i * r),\n            1 - two_s * (i * i + j * j),\n        ),\n        -1,\n    )\n    return o.reshape(quaternions_wxyz.shape[:-1] + (3, 3))\n\n\nclass PoseTW(TensorWrapper):\n    @autocast\n    @autoinit\n    def __init__(self, data: torch.Tensor = IdentityPose):\n        assert isinstance(data, torch.Tensor)\n        assert data.shape[-1] == 12\n        super().__init__(data)\n\n    @classmethod\n    @autocast\n    def from_Rt(cls, R: torch.Tensor, t: torch.Tensor):\n        \"\"\"Pose from a rotation matrix and translation vector.\n        Accepts numpy arrays or PyTorch tensors.\n\n        Args:\n            R: rotation matrix with shape (..., 3, 3).\n            t: translation vector with shape (..., 3).\n        \"\"\"\n        assert R.shape[-2:] == (3, 3)\n        assert t.shape[-1] == 3\n        assert R.shape[:-2] == t.shape[:-1]\n        data = torch.cat([R.flatten(start_dim=-2), t], -1)\n        return cls(data)\n\n    @classmethod\n    @autocast\n    def from_qt(cls, quaternion_wxyz: torch.Tensor, t: torch.Tensor):\n        
\"\"\"Pose from quaternion and translation vectors. Quaternion should\n        be wxyz format, with real part first, and imaginary part last.\n\n        Args:\n            quaternion: quaternion with shape (..., 4).\n            t: translation vector with shape (..., 3).\n        \"\"\"\n        assert quaternion_wxyz.shape[:-1] == t.shape[:-1], (\n            f\"quaternion shape {quaternion_wxyz.shape[:-1]} must match translation shape {t.shape[:-1]} expect the last dim\"\n        )\n        assert quaternion_wxyz.shape[-1] == 4, \"quaternion must be of shape (..., 4)\"\n        assert t.shape[-1] == 3, \"translation must be of shape (..., 3)\"\n\n        R = quaternion_to_matrix(quaternion_wxyz)\n        data = torch.cat([R.flatten(start_dim=-2), t], -1)\n        return cls(data)\n\n    @classmethod\n    @autocast\n    def from_aa(cls, aa: torch.Tensor, t: torch.Tensor):\n        \"\"\"Pose from an axis-angle rotation vector and translation vector.\n        Accepts numpy arrays or PyTorch tensors.\n\n        Args:\n            aa: axis-angle rotation vector with shape (..., 3).\n            t: translation vector with shape (..., 3).\n        \"\"\"\n        assert aa.shape[-1] == 3\n        assert t.shape[-1] == 3\n        assert aa.shape[:-1] == t.shape[:-1]\n        return cls.from_Rt(so3exp_map(aa), t)\n\n    @classmethod\n    @autocast\n    def from_matrix(cls, T: torch.Tensor):\n        \"\"\"Pose from an SE(3) transformation matrix.\n        Args:\n            T: transformation matrix with shape (..., 4, 4).\n        \"\"\"\n        assert T.shape[-2:] == (4, 4)\n        R, t = T[..., :3, :3], T[..., :3, 3]\n        return cls.from_Rt(R, t)\n\n    @classmethod\n    @autocast\n    def from_matrix3x4(cls, T_3x4: torch.Tensor):\n        \"\"\"Pose from an SE(3) transformation matrix.\n        Args:\n            T: transformation matrix with shape (..., 3, 4).\n        \"\"\"\n        assert T_3x4.shape[-2:] == (3, 4)\n        R, t = T_3x4[..., :3, :3], 
T_3x4[..., :3, 3]\n        return cls.from_Rt(R, t)\n\n    @classmethod\n    @autocast\n    def exp(cls, u_omega: torch.Tensor, eps: float = 1e-7):\n        \"\"\"\n        Compute the SE3 exponential map from input se3 vectors u_omega [..., 6] where\n        the last 3 entries are the so3 entries omega and the first 3 the entries\n        for translation.\n        \"\"\"\n        # following https://www.ethaneade.com/lie.pdf and http://people.csail.mit.edu/jstraub/download/straubTransformationCookbook.pdf\n        u = u_omega[..., :3]\n        omega = u_omega[..., 3:]\n        theta = omega.norm(p=2, dim=-1, keepdim=True).unsqueeze(-1)\n        small = theta < eps\n        R = so3exp_map(omega, eps)\n        # compute V\n        shape = [1] * len(omega.shape[:-1])\n        ones = torch.ones_like(theta)\n        # compute factors and approximate them around 0 using a fourth-order\n        # Taylor expansion (from WolframAlpha)\n        b = torch.where(\n            small,\n            0.5 * ones - theta**2 / 24.0 + theta**4 / 720.0,\n            (ones - torch.cos(theta)) / theta**2,\n        )\n        c = torch.where(\n            small,\n            1.0 / 6.0 * ones - theta**2 / 120.0 + theta**4 / 5040.0,\n            (theta - torch.sin(theta)) / theta**3,\n        )\n        Identity = (\n            torch.eye(3).reshape(shape + [3, 3]).repeat(shape + [1, 1]).to(u_omega)\n        )\n        W = skew_symmetric(omega)\n        V = Identity + b * W + c * W @ W\n        # compute t\n        t = (V @ u.unsqueeze(-1)).squeeze(-1)\n        return cls.from_Rt(R, t)\n\n    # @classmethod\n    # def from_colmap(cls, image: NamedTuple):\n    #    '''Pose from a COLMAP Image.'''\n    #    return cls.from_Rt(image.qvec2rotmat(), image.tvec)\n\n    @property\n    def R(self) -> torch.Tensor:\n        \"\"\"Underlying rotation matrix with shape (..., 3, 3).\"\"\"\n        rvec = self._data[..., :9]\n        return rvec.reshape(rvec.shape[:-1] + (3, 3))\n\n    @property\n    def 
t(self) -> torch.Tensor:\n        \"\"\"Underlying translation vector with shape (..., 3).\"\"\"\n        return self._data[..., -3:]\n\n    @property\n    def q(self) -> torch.Tensor:\n        \"\"\"\n        Convert rotations of shape (..., 3, 3) to a quaternion (..., 4).\n        The returned quaternions have real part first, as wxyz.\n        The function is adapted from `matrix_to_quaternion` in Pytorch3d:\n        https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py\n\n        The major difference to the original pytorch3d function is that the returned\n        quaternions are normalized and have positive real part.\n        \"\"\"\n\n        def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:\n            \"\"\"\n            Returns torch.sqrt(torch.max(0, x))\n            but with a zero subgradient where x is 0.\n            \"\"\"\n            ret = torch.zeros_like(x)\n            positive_mask = x > 0\n            ret[positive_mask] = torch.sqrt(x[positive_mask])\n            return ret\n\n        matrix = self.R\n        batch_dim = matrix.shape[:-2]\n        m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(\n            matrix.reshape(batch_dim + (9,)), dim=-1\n        )\n\n        q_abs = _sqrt_positive_part(\n            torch.stack(\n                [\n                    1.0 + m00 + m11 + m22,\n                    1.0 + m00 - m11 - m22,\n                    1.0 - m00 + m11 - m22,\n                    1.0 - m00 - m11 + m22,\n                ],\n                dim=-1,\n            )\n        )\n\n        # we produce the desired quaternion multiplied by each of r, i, j, k\n        quat_by_wxyz = torch.stack(\n            [\n                torch.stack(\n                    [q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1\n                ),\n                torch.stack(\n                    [m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1\n                ),\n 
               torch.stack(\n                    [m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1\n                ),\n                torch.stack(\n                    [m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1\n                ),\n            ],\n            dim=-2,\n        )\n\n        # We floor here at 0.1 but the exact level is not important; if q_abs is small,\n        # the candidate won't be picked.\n        flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)\n        quat_candidates = quat_by_wxyz / (2.0 * q_abs[..., None].max(flr))\n\n        # if not for numerical problems, quat_candidates[i] should be same (up to a sign),\n        # forall i; we pick the best-conditioned one (with the largest denominator)\n        best_quat = quat_candidates[\n            F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :\n        ].reshape(batch_dim + (4,))\n\n        # normalize quaternions and make the real part to be positive for all quaternions\n        best_quat = best_quat.reshape(-1, 4)\n        neg_ind = torch.nonzero(best_quat[:, 0] < 0).squeeze()\n        best_quat[neg_ind, :] *= -1\n        best_quat = best_quat.reshape(batch_dim + (4,))\n        best_quat_normalized = F.normalize(best_quat, p=2, dim=-1)\n        return best_quat_normalized\n\n    @property\n    def q_xyzw(self) -> torch.Tensor:\n        \"\"\"\n        Get the quaternion representation similar to self.q, but the real part\n        of the quaternion comes last rather than first. This is a handy function to increase\n        interoperability, e.g. 
lietorch requires xyzw quaternions.\n        \"\"\"\n        quat_wxyz = self.q\n        return torch.concat([quat_wxyz[..., 1:4], quat_wxyz[..., 0:1]], dim=-1)\n\n    @property\n    def matrix3x4(self) -> torch.Tensor:\n        \"\"\"Underlying transformation matrix with shape (..., 3, 4).\"\"\"\n        rvec = self._data[..., :9]\n        rmat = rvec.reshape(rvec.shape[:-1] + (3, 3))\n        tvec = self._data[..., -3:].unsqueeze(-1)\n        T = torch.cat([rmat, tvec], dim=-1)\n        return T\n\n    @property\n    def matrix(self) -> torch.Tensor:\n        \"\"\"Underlying transformation matrix with shape (..., 4, 4).\"\"\"\n        T_3x4 = self.matrix3x4\n        bot_row = T_3x4.new_zeros(T_3x4.shape[:-2] + (1, 4))\n        bot_row[..., 0, 3] = 1\n        return torch.cat([T_3x4, bot_row], dim=-2)\n\n    def to_euler(self, rad=True) -> torch.Tensor:\n        \"\"\"Convert the rotation matrix to Euler angles using ZYX convention.\n\n        Reference: http://eecs.qmul.ac.uk/~gslabaugh/publications/euler.pdf\n        \"\"\"\n        # Test gimbal lock (ignore rotations that are all PAD_VAL).\n        is_pad = torch.all(torch.all(self.R == PAD_VAL, dim=-1), dim=-1)\n        assert (~torch.abs(self.R[~is_pad][..., 2, 0]).isclose(torch.tensor(1.0))).all()\n        Y_angle = -torch.asin(self.R[..., 2, 0])\n        euler_angles = (\n            torch.atan2(self.R[..., 2, 1], self.R[..., 2, 2]),\n            Y_angle,\n            torch.atan2(self.R[..., 1, 0], self.R[..., 0, 0]),\n        )\n        if not rad:\n            # return degree\n            return torch.stack(euler_angles, -1) * 180.0 / torch.pi\n        return torch.stack(euler_angles, -1)\n\n    def to_ypr(self, rad=True) -> torch.Tensor:\n        # yaw, pitch, roll from rotation matrix: http://lavalle.pl/planning/node103.html\n        R = self.R\n        yaw = torch.atan(R[..., 1, 0] / R[..., 0, 0])\n        pitch = torch.atan(\n            -R[..., 2, 0]\n            / torch.sqrt(R[..., 2, 1] * R[..., 
2, 1] + R[..., 2, 2] * R[..., 2, 2])\n        )\n        roll = torch.atan(R[..., 2, 1] / R[..., 2, 2])\n        return yaw, pitch, roll\n\n    def inverse(self) -> \"PoseTW\":\n        \"\"\"Invert an SE(3) pose.\"\"\"\n        R = self.R.transpose(-1, -2)\n        t = -(R @ self.t.unsqueeze(-1)).squeeze(-1)\n        return self.__class__.from_Rt(R, t)\n\n    def compose(self, other: \"PoseTW\") -> \"PoseTW\":\n        \"\"\"Chain two SE(3) poses: T_C_B.compose(T_B_A) -> T_C_A.\"\"\"\n        R = self.R @ other.R\n        t = self.t + (self.R @ other.t.unsqueeze(-1)).squeeze(-1)\n        return self.__class__.from_Rt(R, t)\n\n    @autocast\n    def transform(self, p3d: torch.Tensor) -> torch.Tensor:\n        \"\"\"Transform a set of 3D points.\n        Args:\n            p3d: 3D points, numpy array or PyTorch tensor with shape (..., 3).\n        \"\"\"\n        assert p3d.shape[-1] == 3\n        # use more efficient right multiply that avoids transpose of the points\n        # according to the equality:\n        # (Rp + t)^T = (Rp)^T + t^T = p^T R^T + t^T\n        # where p^T = p3d, R = self.R and t = self.t\n        return p3d @ self.R.transpose(-1, -2) + self.t.unsqueeze(-2)\n\n    @autocast\n    def batch_transform(self, p3d: torch.Tensor) -> torch.Tensor:\n        \"\"\"Transform a set of 3D points each by the associated (in batch\n        dimensions) transform in this PoseTW.\n        Args:\n            p3d: 3D points, numpy array or PyTorch tensor with shape (..., 3).\n        \"\"\"\n        assert p3d.shape == self.t.shape, f\"shapes of p3d {p3d.shape}, t {self.t.shape}\"\n        # bmm assumes one batch dimension\n        assert p3d.dim() == 2, f\"{p3d.shape}\"\n        assert self.ndim == 2, f\"{self.shape}\"\n        # use more efficient right multiply that avoids transpose of the points\n        # according to the equality:\n        # (Rp + t)^T = (Rp)^T + t^T = p^T R^T + t^T\n        # where p^T = p3d, R = self.R and t = self.t\n        return (\n     
       torch.bmm(p3d.unsqueeze(-2), self.R.transpose(-1, -2)).squeeze(-2) + self.t\n        )\n\n    @autocast\n    def rotate(self, p3d: torch.Tensor) -> torch.Tensor:\n        \"\"\"Rotate a set of 3D points. Useful for directional vectors which should not be translated.\n        Args:\n            p3d: 3D points, numpy array or PyTorch tensor with shape (..., 3).\n        \"\"\"\n        assert p3d.shape[-1] == 3\n        # use more efficient right multiply that avoids transpose of the points\n        # according to the equality:\n        # (Rp)^T = p^T R^T where p3d = p^T and self.R = R\n        return p3d @ self.R.transpose(-1, -2)\n\n    def __mul__(self, p3D: torch.Tensor) -> torch.Tensor:\n        \"\"\"Transform a set of 3D points: T_B_A * p3D_A -> p3D_B\"\"\"\n        return self.transform(p3D)\n\n    def __matmul__(self, other: \"PoseTW\") -> \"PoseTW\":\n        \"\"\"Chain two SE(3) poses: T_C_B @ T_B_A -> T_C_A.\"\"\"\n        return self.compose(other)\n\n    def numpy(self) -> Tuple[np.ndarray]:\n        return self.R.numpy(), self.t.numpy()\n\n    def magnitude(self, deg=True, eps=0) -> Tuple[torch.Tensor]:\n        \"\"\"Magnitude of the SE(3) transformation. 
The `eps` has to be\n        positive if you want to use this function as part of a training loop.\n\n        Returns:\n            dr: rotation angle in degrees (if deg=True) or in radians.\n            dt: translation distance in meters.\n        \"\"\"\n        trace = torch.diagonal(self.R, dim1=-1, dim2=-2).sum(-1)\n        cos = torch.clamp((trace - 1) / 2, min=-1.0 + eps, max=1.0 - eps)\n        dr = torch.acos(cos)\n        if deg:\n            dr = dr * 180.0 / math.pi\n        dt = torch.norm(self.t, dim=-1)\n        return dr, dt\n\n    def so3_geodesic(self, other: \"PoseTW\", deg=False) -> torch.Tensor:\n        \"\"\"Compute the geodesic distance for rotation between this pose and another pose\"\"\"\n        pose_e = self.compose(other.inverse())\n        dr, _ = pose_e.magnitude(deg=deg, eps=1e-6)\n        return dr\n\n    def log(self, eps: float = 1e-6) -> torch.Tensor:\n        \"\"\"\n        Compute the SE3 log map for these poses.\n        Returns [..., 6] where the last 3 entries are the so3 entries omega and\n        the first 3 the entries for translation.\n        \"\"\"\n        # following https://www.ethaneade.com/lie.pdf and http://people.csail.mit.edu/jstraub/download/straubTransformationCookbook.pdf\n        R, t = self.R, self.t\n        trace = torch.diagonal(self.R, dim1=-1, dim2=-2).sum(-1)\n        cos = torch.clamp((trace - 1.0) * 0.5, -1, 1)\n        theta = torch.acos(cos).unsqueeze(-1).unsqueeze(-1)\n        ones = torch.ones_like(theta)\n        small = theta < eps\n        # compute factors and approximate them around 0 using a fourth-order\n        # Taylor expansion: theta / sin(theta) = 1 + theta^2 / 6 + 7 theta^4 / 360 + ...\n        theta_over_sin_theta = torch.where(\n            small,\n            ones + (theta**2) / 6.0 + 7.0 * (theta**4) / 360.0,\n            theta / torch.sin(theta),\n        )\n        c = torch.where(\n            small,\n            0.08333333 + 0.001388889 * theta**2 + 0.0000330688 * theta**4,\n            (ones - ((0.5 * theta * 
torch.sin(theta)) / (ones - torch.cos(theta))))\n            / theta**2,\n        )\n        # compute log-map W of rotation R first\n        W = 0.5 * theta_over_sin_theta * (R - R.transpose(-1, -2))\n        # compute V_inv to be able to get u\n        shape = [1] * len(R.shape[:-2])\n        Identity = (\n            torch.eye(3).reshape(shape + [3, 3]).repeat(shape + [1, 1]).to(self._data)\n        )\n        V_inv = Identity - 0.5 * W + c * W @ W\n        u = (V_inv @ t.unsqueeze(-1)).squeeze(-1)\n        omega = inv_skew_symmetric(W)\n        return torch.cat([u, omega], -1)\n\n    def interpolate(self, times: torch.Tensor, interp_times: torch.Tensor):\n        \"\"\"\n        Return poses at the given interpolation times interp_times based on the\n        poses in this object and the provided associated timestamps times.\n\n        If interpolation timestamps are outside the interval of times, the poses\n        at the interval boundaries will be returned and the good boolean tensor\n        will indicate those boundary values with a False.\n        \"\"\"\n        assert times.shape == self._data.shape[:-1], (\n            f\"time stamps for the poses do not match poses shape {times.shape} vs {self._data.shape}\"\n        )\n\n        assert times.dim() <= 2, (\n            \"The shape of the input times should be either BxT or T.\"\n        )\n        times = times.to(self.device)\n        interp_times = interp_times.to(self.device)\n        # find the closest timestamps above and below for each interp_times in times\n        lower_ids, upper_ids, alpha, good = interpolation_boundaries_alphas(\n            times, interp_times\n        )\n        # get the bounding poses\n        upper_ids = upper_ids.unsqueeze(-1)\n        upper_ids = upper_ids.expand(*upper_ids.shape[0:-1], self._data.shape[-1])\n        lower_ids = lower_ids.unsqueeze(-1)\n        lower_ids = lower_ids.expand(*lower_ids.shape[0:-1], self._data.shape[-1])\n        T_upper = 
self.__class__(self._data.gather(times.dim() - 1, upper_ids))\n        T_lower = self.__class__(self._data.gather(times.dim() - 1, lower_ids))\n        # get se3 element connecting the lower and upper poses\n        dT = T_lower.inverse() @ T_upper\n        dx = dT.log()\n        # interpolate on se3\n        dT = self.exp(dx * alpha.unsqueeze(-1))\n        return T_lower @ dT, good\n\n    def align(self, other, self_times=None, other_times=None):\n        \"\"\"Align two trajectories using the method of Horn (closed-form).\n\n        Input:\n            other -- second PoseTW (Nx12) trajectory to align to\n\n        Output:\n            T_self_other -- relative SE3 transform (Nx12)\n            trans_error -- translational error per point (Nx1)\n\n        code inspired by: https://github.com/symao/vio_evaluation/blob/master/align.py#L6-L38\n        \"\"\"\n        if self.t.ndim != 2:\n            raise ValueError(\n                f\"Only Nx12 Pose supported in alignment, given {self.shape}\"\n            )\n        if other.t.ndim != 2:\n            raise ValueError(\n                f\"Only Nx12 Pose supported in alignment, given {other.shape}\"\n            )\n        dtype = torch.promote_types(self.dtype, other.dtype)\n\n        # Optionally interpolate other to match the size of self.\n        if self.shape[0] != other.shape[0]:\n            if self_times is None or other_times is None:\n                raise ValueError(\n                    f\"Got different length PoseTW (self {self.shape} and other {other.shape}). 
Must provide timestamps to support interpolation\"\n                )\n            # Do interpolation on temporal intersection.\n            other, goods = other.interpolate(other_times, self_times)\n            self2 = self.clone()[goods].to(dtype)\n            other2 = other.clone()[goods].to(dtype)\n        else:\n            self2 = self.clone().to(dtype)\n            other2 = other.clone().to(dtype)\n\n        P = self2.t.transpose(0, 1)\n        Q = other2.t.transpose(0, 1)\n\n        if P.shape != Q.shape:\n            raise ValueError(\"Matrices P and Q must be of the same dimensionality\")\n\n        centroids_P = torch.mean(P, dim=1)\n        centroids_Q = torch.mean(Q, dim=1)\n        A = P - torch.outer(centroids_P, torch.ones(P.shape[1], dtype=dtype))\n        B = Q - torch.outer(centroids_Q, torch.ones(Q.shape[1], dtype=dtype))\n        C = A @ B.transpose(0, 1)\n        U, S, V = torch.linalg.svd(C)\n        R = V.transpose(0, 1) @ U.transpose(0, 1)\n        L = torch.eye(3, dtype=dtype)\n        if torch.linalg.det(R) < 0:\n            L[2][2] *= -1\n\n        R = V.transpose(0, 1) @ (L @ U.transpose(0, 1))\n        t = (-R @ centroids_P) + centroids_Q\n        T_self_other = PoseTW.from_Rt(R, t).inverse().to(dtype)\n\n        other_aligned = T_self_other @ other2\n\n        error = torch.linalg.norm(other_aligned.t - self2.t, dim=-2)\n        mean_error = error.mean(dim=-1)\n\n        return T_self_other, mean_error\n\n    def fit_to_SO3(self):\n        # Math used from quora post and this berkeley pdf.\n        # https://qr.ae/pKQaG5\n        # https://people.eecs.berkeley.edu/~wkahan/Math128/NearestQ.pdf\n        assert self._data.ndim == 1\n        Q = fit_to_SO3(self.R)\n        return PoseTW.from_Rt(Q, self.t)\n\n    def __repr__(self):\n        return f\"PoseTW: {self.shape} {self.dtype} {self.device}\"\n\n\ndef interpolate_timed_poses(\n    timed_poses: Dict[\n        Union[float, int],\n        Union[PoseTW, List[PoseTW], Dict[Union[float, 
int, str], PoseTW]],\n    ],\n    time: Union[float, int],\n):\n    \"\"\"\n    interpolate timed poses given as a dict[time:container[PoseTW]] to given\n    time.  The poses container indexed by time can be given plain as poses, or a\n    list or dict of poses.  If a list or dict of poses is given, the output will\n    also be a list or dict of the interpolated poses. This allows batched\n    interpolation.\n    \"\"\"\n    ts_list = list(timed_poses.keys())\n    ts = torch.from_numpy(np.array(ts_list))\n    interp_time = torch.from_numpy(np.array([time]))\n    lower_ids, upper_ids, _, _ = interpolation_boundaries_alphas(ts, interp_time)\n    t_lower = ts_list[lower_ids[0]]\n    t_upper = ts_list[upper_ids[0]]\n    poses_lower, poses_upper = timed_poses[t_lower], timed_poses[t_upper]\n    poses_interp = None\n    times = torch.from_numpy(np.array([t_lower, t_upper])).float()\n    if isinstance(poses_lower, PoseTW):\n        poses = PoseTW(smart_stack([poses_lower, poses_upper]))\n        if poses.dim() == 3:\n            times = times.unsqueeze(-1).repeat(1, poses.shape[1])\n        poses_interp = poses.interpolate(times, interp_time)[0].squeeze()\n    elif isinstance(poses_lower, dict):\n        keys_lower = set(poses_lower.keys())\n        keys_upper = set(poses_upper.keys())\n        keys = keys_lower & keys_upper\n        poses_interp = {}\n        for key in keys:\n            poses = PoseTW(smart_stack([poses_lower[key], poses_upper[key]]))\n            if poses.dim() == 3 and times.dim() == 1:\n                times = times.unsqueeze(-1).repeat(1, poses.shape[1])\n            poses_interp[key] = poses.interpolate(times, interp_time)[0].squeeze()\n    elif isinstance(poses_lower, list):\n        assert len(poses_lower) == len(poses_upper)\n        poses_interp = []\n        for i in range(len(poses_lower)):\n            poses = PoseTW(smart_stack([poses_lower[i], poses_upper[i]]))\n            if poses.dim() == 3 and times.dim() == 1:\n                times 
= times.unsqueeze(-1).repeat(1, poses.shape[1])\n            poses_interp.append(poses.interpolate(times, interp_time)[0].squeeze())\n    return poses_interp\n\n\ndef lower_timed_poses(\n    timed_poses: Dict[\n        Union[float, int],\n        Union[PoseTW, List[PoseTW], Dict[Union[float, int, str], PoseTW]],\n    ],\n    time: Union[float, int],\n):\n    \"\"\"\n    Look up the timed poses given as a dict[time:container[PoseTW]] at the\n    closest timestamp at or below the given time.  No interpolation is\n    performed: the poses container stored at that timestamp (plain poses, or a\n    list or dict of poses) is returned as is, together with the time offset\n    t_lower - time.  If the given time is outside the range of timestamps, the\n    entry at the boundary is returned.\n    \"\"\"\n    ts_list = list(timed_poses.keys())\n    ts = torch.from_numpy(np.array(ts_list))\n    interp_time = torch.from_numpy(np.array([time]))\n    lower_ids, _, alpha, good = interpolation_boundaries_alphas(ts, interp_time)\n    t_lower = ts_list[lower_ids[0]]\n    poses_lower = timed_poses[t_lower]\n    return poses_lower, t_lower - time\n\n\ndef closest_timed_poses(\n    timed_poses: Dict[\n        Union[float, int],\n        Union[PoseTW, List[PoseTW], Dict[Union[float, int, str], PoseTW]],\n    ],\n    time: Union[float, int],\n):\n    \"\"\"\n    interpolate timed poses given as a dict[time:container[PoseTW]] to given\n    time.  The poses container indexed by time can be given plain as poses, or a\n    list or dict of poses.  If a list or dict of poses is given, the output will\n    also be a list or dict of the interpolated poses. 
This allows batched\n    interpolation.\n    \"\"\"\n    ts_list = list(timed_poses.keys())\n    ts = torch.from_numpy(np.array(ts_list))\n    interp_time = torch.from_numpy(np.array([time]))\n    lower_ids, upper_ids, alpha, good = interpolation_boundaries_alphas(ts, interp_time)\n    t_lower = ts_list[lower_ids[0]]\n    t_upper = ts_list[upper_ids[0]]\n    poses_lower, poses_upper = timed_poses[t_lower], timed_poses[t_upper]\n    if time - t_lower < t_upper - time:\n        return poses_lower, time - t_lower\n    else:\n        return poses_upper, t_upper - time\n\n\ndef all_rot90():\n    # construct all possible 90 degree rotations\n    dirs = torch.cat([torch.eye(3), -torch.eye(3)], dim=0)\n    ids = torch.arange(0, 6).long()\n    jds = torch.arange(0, 6).long()\n    ids, jds = torch.meshgrid(ids, jds)\n    ids, jds = ids.reshape(-1), jds.reshape(-1)\n    a, b = dirs[ids, :], dirs[jds, :]\n    c = torch.cross(a, b, -1)\n    Rs = torch.cat([a.unsqueeze(2), b.unsqueeze(2), c.unsqueeze(2)], dim=2)\n    # filter to valid rotations\n    det = torch.linalg.det(Rs)\n    Rs = Rs[det > 0.99]\n    return Rs\n\n\ndef find_r90(Ta, Tb, R90s):\n    N = None\n    if Tb.ndim == 2:\n        N = Tb.shape[0]\n        # 24xNx3x3\n        R90s = R90s.unsqueeze(1).repeat(1, N, 1, 1)\n    Ra_inv, Rb = Ta.inverse().R.unsqueeze(0), Tb.R\n    # 24x(Nx)3x3\n    dR = Ra_inv @ Rb.unsqueeze(0) @ R90s\n    w = so3log_map(dR)\n    ang = torch.linalg.norm(w, 2, dim=-1)\n    ang_min, id_min = torch.min(ang, dim=0)\n    if N is None:\n        R90min = R90s[id_min]\n    else:\n        R90min = R90s[id_min, torch.arange(N)]\n    Rb = Rb @ R90min\n    Tb = PoseTW.from_Rt(Rb, Tb.t)\n    return Tb, R90min\n\n\ndef stereographic_unproject(a, axis=None):\n    \"\"\"\n    Inverse of stereographic projection: https://en.wikipedia.org/wiki/Stereographic_projection\n    This is from the paper \"On the Continuity of Rotation Representations in Neural\n    Networks\" https://arxiv.org/pdf/1812.07035.pdf, 
equation [8,9],\n    used in rotation_from_ortho_5d.\n    \"\"\"\n    batch = a.shape[0]\n    if axis is None:\n        axis = a.shape[1]\n    s2 = torch.pow(a, 2).sum(1)\n    ans = torch.autograd.Variable(torch.zeros(batch, a.shape[1] + 1).to(a))\n    unproj = 2 * a / (s2 + 1).reshape(batch, 1).repeat(1, a.shape[1])\n    if axis > 0:\n        ans[:, :axis] = unproj[:, :axis]\n    ans[:, axis] = (s2 - 1) / (s2 + 1)\n    ans[:, axis + 1 :] = unproj[:, axis:]\n    return ans\n\n\ndef rotation_from_ortho_6d(ortho6d):\n    \"\"\"\n    Convert a 6-d rotation representation to rotation matrix\n\n    From the paper \"On the Continuity of Rotation Representations in Neural Networks\"\n    https://arxiv.org/pdf/1812.07035.pdf\n    \"\"\"\n    x_raw = ortho6d[..., 0:3]\n    y_raw = ortho6d[..., 3:6]\n\n    x = F.normalize(x_raw, dim=-1, eps=1e-6)\n    y = F.normalize(y_raw, dim=-1, eps=1e-6)\n\n    z = torch.cross(x, y, -1)\n    z = F.normalize(z, dim=-1, eps=1e-6)\n    y = torch.cross(z, x, -1)\n\n    x = x.reshape(-1, 3, 1)\n    y = y.reshape(-1, 3, 1)\n    z = z.reshape(-1, 3, 1)\n    matrix = torch.cat((x, y, z), 2)\n    return matrix\n\n\ndef rotation_from_ortho_5d(ortho5d):\n    \"\"\"\n    Convert a 5-d rotation representation to rotation matrix\n\n    From the paper \"On the Continuity of Rotation Representations in Neural Networks\"\n    https://arxiv.org/pdf/1812.07035.pdf\n    \"\"\"\n    batch = ortho5d.shape[0]\n    proj_scale_np = np.array([np.sqrt(2) + 1, np.sqrt(2) + 1, np.sqrt(2)])\n    proj_scale = (\n        torch.autograd.Variable(torch.FloatTensor(proj_scale_np).to(ortho5d))\n        .reshape(1, 3)\n        .repeat(batch, 1)\n    )\n\n    u = stereographic_unproject(ortho5d[:, 2:5] * proj_scale, axis=0)\n    norm = torch.sqrt(torch.pow(u[:, 1:], 2).sum(1))\n    u = u / norm.reshape(batch, 1).repeat(1, u.shape[1])\n    b = torch.cat((ortho5d[:, 0:2], u), 1)\n    matrix = rotation_from_ortho_6d(b)\n    return matrix\n\n\ndef rotation_from_euler(euler):\n   
 \"\"\"\n    Convert a 3-d Euler angle representation to rotation matrix\n    \"\"\"\n    batch = euler.shape[0]\n\n    c1 = torch.cos(euler[:, 0]).reshape(batch, 1)\n    s1 = torch.sin(euler[:, 0]).reshape(batch, 1)\n    c2 = torch.cos(euler[:, 2]).reshape(batch, 1)\n    s2 = torch.sin(euler[:, 2]).reshape(batch, 1)\n    c3 = torch.cos(euler[:, 1]).reshape(batch, 1)\n    s3 = torch.sin(euler[:, 1]).reshape(batch, 1)\n\n    row1 = torch.cat((c2 * c3, -s2, c2 * s3), 1).reshape(-1, 1, 3)\n    row2 = torch.cat(\n        (c1 * s2 * c3 + s1 * s3, c1 * c2, c1 * s2 * s3 - s1 * c3), 1\n    ).reshape(-1, 1, 3)\n    row3 = torch.cat(\n        (s1 * s2 * c3 - c1 * s3, s1 * c2, s1 * s2 * s3 + c1 * c3), 1\n    ).reshape(-1, 1, 3)\n\n    matrix = torch.cat((row1, row2, row3), 1)\n    return matrix\n\n\ndef fit_to_SO3(R):\n    # Math used from quora post and this berkeley pdf.\n    # https://qr.ae/pKQaG5\n    # https://people.eecs.berkeley.edu/~wkahan/Math128/NearestQ.pdf\n    #\n    # Input:\n    #   R - torch 3x3 rotation matrix that is not quite orthogonal\n    # Output:\n    #   Q - torch 3x3 nearest valid rotation matrix\n    assert R.ndim == 2\n    assert R.shape[0] == 3 and R.shape[1] == 3\n    B = R\n    I = torch.eye(3)\n    Y = B.transpose(-2, -1) @ B - I\n    Q = B - B @ Y @ (\n        I / 2.0 - (3.0 * Y) / 8.0 + (5 * Y @ Y) / 16 - (35 * Y @ Y @ Y @ Y) / 128\n    )\n    return Q\n"
  },
  {
    "path": "efm3d/aria/projection_utils.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\n\n\ndef sign_plus(x):\n    \"\"\"\n    return +1 for positive and for 0.0 in x, and -1 for negative values. This is\n    important for our handling of z values that should never be 0.0\n    \"\"\"\n    sgn = torch.ones_like(x)\n    # index with x (not sgn, which is all ones) so negative inputs map to -1\n    sgn[x < 0.0] = -1.0\n    return sgn\n\n\n@torch.jit.script\ndef fisheye624_project(xyz, params):\n    \"\"\"\n    Batched implementation of the FisheyeRadTanThinPrism (aka Fisheye624) camera\n    model project() function.\n\n    Inputs:\n        xyz: Bx(T)xNx3 tensor of 3D points to be projected\n        params: Bx(T)x16 tensor of Fisheye624 parameters formatted like this:\n                [f_u f_v c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}]\n                or Bx(T)x15 tensor of Fisheye624 parameters formatted like this:\n                [f c_u c_v {k_0 ... 
k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}]\n    Outputs:\n        uv: Bx(T)xNx2 tensor of 2D projections of xyz in image plane\n\n    Model for fisheye cameras with radial, tangential, and thin-prism distortion.\n    This model allows fu != fv.\n    Specifically, the model is:\n    uvDistorted = [x_r]  + tangentialDistortion  + thinPrismDistortion\n                  [y_r]\n    proj = diag(fu,fv) * uvDistorted + [cu;cv];\n    where:\n      a = x/z, b = y/z, r = (a^2+b^2)^(1/2)\n      th = atan(r)\n      cosPhi = a/r, sinPhi = b/r\n      [x_r]  = (th+ k0 * th^3 + k1* th^5 + ...) [cosPhi]\n      [y_r]                                     [sinPhi]\n      the number of terms in the series is determined by the template parameter numK.\n      tangentialDistortion = [(2 x_r^2 + rd^2)*p_0 + 2*x_r*y_r*p_1]\n                             [(2 y_r^2 + rd^2)*p_1 + 2*x_r*y_r*p_0]\n      where rd^2 = x_r^2 + y_r^2\n      thinPrismDistortion = [s0 * rd^2 + s1 rd^4]\n                            [s2 * rd^2 + s3 rd^4]\n    \"\"\"\n\n    assert (xyz.ndim == 3 and params.ndim == 2) or (\n        xyz.ndim == 4 and params.ndim == 3\n    ), f\"point dim {xyz.shape} does not match cam parameter dim {params}\"\n    assert xyz.shape[-1] == 3\n    assert params.shape[-1] == 16 or params.shape[-1] == 15, (\n        \"This model allows fx != fy\"\n    )\n    assert xyz.dtype == params.dtype, \"data type must match\"\n\n    eps = 1e-9\n    T = -1\n    if xyz.ndim == 4:\n        # has T dim\n        T, N = xyz.shape[1], xyz.shape[2]\n        xyz = xyz.reshape(-1, N, 3)  # (BxT)xNx3\n        params = params.reshape(-1, params.shape[-1])  #  (BxT)x16\n\n    B, N = xyz.shape[0], xyz.shape[1]\n\n    # Radial correction.\n    z = xyz[:, :, 2].reshape(B, N, 1)\n    # Do not use torch.sign(z) it leads to 0.0 zs if z == 0.0 which leads to a\n    # nan when we compute xy/z\n    z = torch.where(torch.abs(z) < eps, eps * sign_plus(z), z)\n    ab = xyz[:, :, :2] / z\n    # make sure abs are not too small or 0 otherwise 
gradients are nan\n    ab = torch.where(torch.abs(ab) < eps, eps * sign_plus(ab), ab)\n    r = torch.norm(ab, dim=-1, p=2, keepdim=True)\n    th = torch.atan(r)\n    th_divr = torch.where(r < eps, torch.ones_like(ab), ab / r)\n    th_k = th.reshape(B, N, 1).clone()\n    for i in range(6):\n        th_k = th_k + params[:, -12 + i].reshape(B, 1, 1) * torch.pow(th, 3 + i * 2)\n    xr_yr = th_k * th_divr\n    uv_dist = xr_yr\n\n    # Tangential correction.\n    p0 = params[:, -6].reshape(B, 1)\n    p1 = params[:, -5].reshape(B, 1)\n    xr = xr_yr[:, :, 0].reshape(B, N)\n    yr = xr_yr[:, :, 1].reshape(B, N)\n    xr_yr_sq = torch.square(xr_yr)\n    xr_sq = xr_yr_sq[:, :, 0].reshape(B, N)\n    yr_sq = xr_yr_sq[:, :, 1].reshape(B, N)\n    rd_sq = xr_sq + yr_sq\n    uv_dist_tu = uv_dist[:, :, 0] + ((2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1)\n    uv_dist_tv = uv_dist[:, :, 1] + ((2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0)\n    uv_dist = torch.stack(\n        [uv_dist_tu, uv_dist_tv], dim=-1\n    )  # Avoids in-place complaint.\n\n    # Thin Prism correction.\n    s0 = params[:, -4].reshape(B, 1)\n    s1 = params[:, -3].reshape(B, 1)\n    s2 = params[:, -2].reshape(B, 1)\n    s3 = params[:, -1].reshape(B, 1)\n    rd_4 = torch.square(rd_sq)\n    uv_dist[:, :, 0] = uv_dist[:, :, 0] + (s0 * rd_sq + s1 * rd_4)\n    uv_dist[:, :, 1] = uv_dist[:, :, 1] + (s2 * rd_sq + s3 * rd_4)\n\n    # Finally, apply standard terms: focal length and camera centers.\n    if params.shape[-1] == 15:\n        fx_fy = params[:, 0].reshape(B, 1, 1)\n        cx_cy = params[:, 1:3].reshape(B, 1, 2)\n    else:\n        fx_fy = params[:, 0:2].reshape(B, 1, 2)\n        cx_cy = params[:, 2:4].reshape(B, 1, 2)\n    result = uv_dist * fx_fy + cx_cy\n\n    if T > 0:\n        result = result.reshape(B // T, T, N, 2)\n\n    assert result.ndim == 4 or result.ndim == 3\n    assert result.shape[-1] == 2\n\n    return result\n\n\n@torch.jit.script\ndef fisheye624_unproject(uv, params, max_iters: int = 
5):\n    \"\"\"\n    Batched implementation of the FisheyeRadTanThinPrism (aka Fisheye624) camera\n    model. There is no analytical solution for the inverse of the project()\n    function so this solves an optimization problem using Newton's method to get\n    the inverse.\n\n    Inputs:\n        uv: Bx(T)xNx2 tensor of 2D pixels to be unprojected\n        params: Bx(T)x16 tensor of Fisheye624 parameters formatted like this:\n                [f_u f_v c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}]\n                or Bx(T)x15 tensor of Fisheye624 parameters formatted like this:\n                [f c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}]\n    Outputs:\n        xyz: Bx(T)xNx3 tensor of 3D rays of uv points with z = 1.\n\n    Model for fisheye cameras with radial, tangential, and thin-prism distortion.\n    This model allows fu != fv. This unproject function holds that:\n\n    X = unproject(project(X))     [for X=(x,y,z) in R^3, z>0]\n\n    and\n\n    x = project(unproject(s*x))   [for s!=0 and x=(u,v) in R^2]\n    \"\"\"\n\n    assert uv.ndim == 3 or uv.ndim == 4, \"Expected batched input shaped Bx(T)xNx2\"\n    assert uv.shape[-1] == 2\n    assert params.ndim == 2 or params.ndim == 3, (\n        \"Expected batched input shaped Bx(T)x16 or Bx(T)x15\"\n    )\n    assert params.shape[-1] == 16 or params.shape[-1] == 15, (\n        \"This model allows fx != fy\"\n    )\n    assert uv.dtype == params.dtype, \"data type must match\"\n    eps = 1e-6\n\n    T = -1\n    if uv.ndim == 4:\n        # has T dim\n        T, N = uv.shape[1], uv.shape[2]\n        uv = uv.reshape(-1, N, 2)  # (BxT)xNx2\n        params = params.reshape(-1, params.shape[-1])  #  (BxT)x16\n\n    B, N = uv.shape[0], uv.shape[1]\n\n    if params.shape[-1] == 15:\n        fx_fy = params[:, 0].reshape(B, 1, 1)\n        cx_cy = params[:, 1:3].reshape(B, 1, 2)\n    else:\n        fx_fy = params[:, 0:2].reshape(B, 1, 2)\n        cx_cy = params[:, 2:4].reshape(B, 1, 2)\n\n    uv_dist = (uv - cx_cy) 
/ fx_fy\n\n    # Compute xr_yr using Newton's method.\n    xr_yr = uv_dist.clone()  # Initial guess.\n    for _ in range(max_iters):\n        uv_dist_est = xr_yr.clone()\n        # Tangential terms.\n        p0 = params[:, -6].reshape(B, 1)\n        p1 = params[:, -5].reshape(B, 1)\n        xr = xr_yr[:, :, 0].reshape(B, N)\n        yr = xr_yr[:, :, 1].reshape(B, N)\n        xr_yr_sq = torch.square(xr_yr)\n        xr_sq = xr_yr_sq[:, :, 0].reshape(B, N)\n        yr_sq = xr_yr_sq[:, :, 1].reshape(B, N)\n        rd_sq = xr_sq + yr_sq\n        uv_dist_est[:, :, 0] = uv_dist_est[:, :, 0] + (\n            (2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1\n        )\n        uv_dist_est[:, :, 1] = uv_dist_est[:, :, 1] + (\n            (2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0\n        )\n        # Thin Prism terms.\n        s0 = params[:, -4].reshape(B, 1)\n        s1 = params[:, -3].reshape(B, 1)\n        s2 = params[:, -2].reshape(B, 1)\n        s3 = params[:, -1].reshape(B, 1)\n        rd_4 = torch.square(rd_sq)\n        uv_dist_est[:, :, 0] = uv_dist_est[:, :, 0] + (s0 * rd_sq + s1 * rd_4)\n        uv_dist_est[:, :, 1] = uv_dist_est[:, :, 1] + (s2 * rd_sq + s3 * rd_4)\n        # Compute the derivative of uv_dist w.r.t. 
xr_yr.\n        duv_dist_dxr_yr = uv.new_ones(B, N, 2, 2)\n        duv_dist_dxr_yr[:, :, 0, 0] = (\n            1.0 + 6.0 * xr_yr[:, :, 0] * p0 + 2.0 * xr_yr[:, :, 1] * p1\n        )\n        offdiag = 2.0 * (xr_yr[:, :, 0] * p1 + xr_yr[:, :, 1] * p0)\n        duv_dist_dxr_yr[:, :, 0, 1] = offdiag\n        duv_dist_dxr_yr[:, :, 1, 0] = offdiag\n        duv_dist_dxr_yr[:, :, 1, 1] = (\n            1.0 + 6.0 * xr_yr[:, :, 1] * p1 + 2.0 * xr_yr[:, :, 0] * p0\n        )\n        xr_yr_sq_norm = xr_yr_sq[:, :, 0] + xr_yr_sq[:, :, 1]\n        temp1 = 2.0 * (s0 + 2.0 * s1 * xr_yr_sq_norm)\n        duv_dist_dxr_yr[:, :, 0, 0] = duv_dist_dxr_yr[:, :, 0, 0] + (\n            xr_yr[:, :, 0] * temp1\n        )\n        duv_dist_dxr_yr[:, :, 0, 1] = duv_dist_dxr_yr[:, :, 0, 1] + (\n            xr_yr[:, :, 1] * temp1\n        )\n        temp2 = 2.0 * (s2 + 2.0 * s3 * xr_yr_sq_norm)\n        duv_dist_dxr_yr[:, :, 1, 0] = duv_dist_dxr_yr[:, :, 1, 0] + (\n            xr_yr[:, :, 0] * temp2\n        )\n        duv_dist_dxr_yr[:, :, 1, 1] = duv_dist_dxr_yr[:, :, 1, 1] + (\n            xr_yr[:, :, 1] * temp2\n        )\n        # Compute 2x2 inverse manually here since torch.inverse() is very slow.\n        # Because this is slow: inv = duv_dist_dxr_yr.inverse()\n        # About a 10x reduction in speed with above line.\n        mat = duv_dist_dxr_yr.reshape(-1, 2, 2)\n        a = mat[:, 0, 0].reshape(-1, 1, 1)\n        b = mat[:, 0, 1].reshape(-1, 1, 1)\n        c = mat[:, 1, 0].reshape(-1, 1, 1)\n        d = mat[:, 1, 1].reshape(-1, 1, 1)\n        det = 1.0 / ((a * d) - (b * c))\n        top = torch.cat([d, -b], dim=2)\n        bot = torch.cat([-c, a], dim=2)\n        inv = det * torch.cat([top, bot], dim=1)\n        inv = inv.reshape(B, N, 2, 2)\n        # Manually compute 2x2 @ 2x1 matrix multiply.\n        # Because this is slow: step = (inv @ (uv_dist - uv_dist_est)[..., None])[..., 0]\n        diff = uv_dist - uv_dist_est\n        a = inv[:, :, 0, 0]\n        b = inv[:, :, 0, 
1]\n        c = inv[:, :, 1, 0]\n        d = inv[:, :, 1, 1]\n        e = diff[:, :, 0]\n        f = diff[:, :, 1]\n        step = torch.stack([a * e + b * f, c * e + d * f], dim=-1)\n        # Newton step.\n        xr_yr = xr_yr + step\n\n    # Compute theta using Newton's method.\n    xr_yr_norm = xr_yr.norm(p=2, dim=2).reshape(B, N, 1)\n    th = xr_yr_norm.clone()\n    for _ in range(max_iters):\n        th_radial = uv.new_ones(B, N, 1)\n        dthd_th = uv.new_ones(B, N, 1)\n        for k in range(6):\n            r_k = params[:, -12 + k].reshape(B, 1, 1)\n            th_radial = th_radial + (r_k * torch.pow(th, 2 + k * 2))\n            dthd_th = dthd_th + ((3.0 + 2.0 * k) * r_k * torch.pow(th, 2 + k * 2))\n        th_radial = th_radial * th\n        step = (xr_yr_norm - th_radial) / dthd_th\n        # handle dthd_th close to 0.\n        step = torch.where(dthd_th.abs() > eps, step, sign_plus(step) * eps * 10.0)\n        th = th + step\n    # Compute the ray direction using theta and xr_yr.\n    close_to_zero = torch.logical_and(th.abs() < eps, xr_yr_norm.abs() < eps)\n    ray_dir = torch.where(close_to_zero, xr_yr, torch.tan(th) / xr_yr_norm * xr_yr)\n    ray = torch.cat([ray_dir, uv.new_ones(B, N, 1)], dim=2)\n    assert ray.shape[-1] == 3\n\n    if T > 0:\n        ray = ray.reshape(B // T, T, N, 3)\n\n    return ray\n\n\ndef pinhole_project(xyz, params):\n    \"\"\"\n    Batched implementation of the Pinhole (aka Linear) camera\n    model project() function.\n\n    Inputs:\n        xyz: Bx(T)xNx3 tensor of 3D points to be projected\n        params: Bx(T)x4 tensor of Pinhole parameters formatted like this:\n                [f_u f_v c_u c_v]\n    Outputs:\n        uv: Bx(T)xNx2 tensor of 2D projections of xyz in image plane\n    \"\"\"\n\n    assert (xyz.ndim == 3 and params.ndim == 2) or (xyz.ndim == 4 and params.ndim == 3)\n    assert params.shape[-1] == 4\n    eps = 1e-9\n\n    # Focal length and principal point\n    fx_fy = params[..., 
0:2].reshape(*xyz.shape[:-2], 1, 2)\n    cx_cy = params[..., 2:4].reshape(*xyz.shape[:-2], 1, 2)\n    # Make sure depth is not too close to zero.\n    z = xyz[..., 2:]\n    # Do not use torch.sign(z): it returns 0.0 when z == 0.0, which leads to a\n    # nan when we compute xy/z\n    z = torch.where(torch.abs(z) < eps, eps * sign_plus(z), z)\n    uv = (xyz[..., :2] / z) * fx_fy + cx_cy\n    return uv\n\n\ndef pinhole_unproject(uv, params, max_iters: int = 5):\n    \"\"\"\n    Batched implementation of the Pinhole (aka Linear) camera model.\n\n    Inputs:\n        uv: Bx(T)xNx2 tensor of 2D pixels to be unprojected\n        params: Bx(T)x4 tensor of Pinhole parameters formatted like this:\n                [f_u f_v c_u c_v]\n    Outputs:\n        xyz: Bx(T)xNx3 tensor of 3D rays of uv points with z = 1.\n\n    \"\"\"\n    assert uv.ndim == 3 or uv.ndim == 4, \"Expected batched input shaped Bx(T)xNx2\"\n    assert params.ndim == 2 or params.ndim == 3\n    assert params.shape[-1] == 4\n    assert uv.shape[-1] == 2\n\n    # Focal length and principal point\n    fx_fy = params[..., 0:2].reshape(*uv.shape[:-2], 1, 2)\n    cx_cy = params[..., 2:4].reshape(*uv.shape[:-2], 1, 2)\n\n    uv_dist = (uv - cx_cy) / fx_fy\n\n    ray = torch.cat([uv_dist, uv.new_ones(*uv.shape[:-1], 1)], dim=-1)\n    return ray\n"
  },
  {
    "path": "efm3d/aria/tensor_wrapper.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport functools\nimport inspect\nimport logging\nfrom typing import List\n\nimport numpy as np\nimport torch\nfrom torch.utils.data._utils.collate import (\n    collate,\n    collate_tensor_fn,\n    default_collate_fn_map,\n)\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n# logger.setLevel(logging.DEBUG)\n\n\ndef smart_cat(inp_arr, dim=-1):\n    devices = set()\n    for i, inp in enumerate(inp_arr):\n        if isinstance(inp, TensorWrapper):\n            inp_arr[i] = inp._data\n        else:\n            inp_arr[i] = inp\n        devices.add(inp_arr[i].device)\n    if len(devices) > 1:\n        raise RuntimeError(f\"More than one device found! {devices}\")\n    return torch.cat(inp_arr, dim=dim)\n\n\ndef smart_stack(inp_arr, dim: int = 0):\n    devices = set()\n    for i, inp in enumerate(inp_arr):\n        if isinstance(inp, TensorWrapper):\n            inp_arr[i] = inp._data\n        else:\n            inp_arr[i] = inp\n        devices.add(inp_arr[i].device)\n    if len(devices) > 1:\n        raise RuntimeError(f\"More than one device found! 
{devices}\")\n    return torch.stack(inp_arr, dim=dim)\n\n\ndef get_default_args(func):\n    signature = inspect.signature(func)\n    return {\n        k: v.default\n        for k, v in signature.parameters.items()\n        if v.default is not inspect.Parameter.empty\n    }\n\n\ndef get_nonempty_arg_names(func):\n    spec = inspect.getfullargspec(func)\n    signature = inspect.signature(func)\n    return [\n        k\n        for k in spec.args\n        if signature.parameters[k].default is not inspect.Parameter.empty\n    ]\n\n\ndef autocast(func):\n    \"\"\"Cast the inputs of a TensorWrapper method to PyTorch tensors\n    if they are numpy arrays. Use the device and dtype of the wrapper.\n    \"\"\"\n\n    @functools.wraps(func)\n    def wrap(self, *args):\n        device = torch.device(\"cpu\")\n        dtype = None\n        if isinstance(self, TensorWrapper):\n            if self._data is not None:\n                device = self.device\n                dtype = self.dtype\n        elif not inspect.isclass(self) or not issubclass(self, TensorWrapper):\n            raise ValueError(self)\n\n        cast_args = []\n        for arg in args:\n            if isinstance(arg, np.ndarray):\n                arg = torch.from_numpy(arg)\n                arg = arg.to(device=device, dtype=dtype)\n            cast_args.append(arg)\n\n        return func(self, *cast_args)\n\n    return wrap\n\n\ndef autoinit(func):\n    \"\"\"\n    Helps with initialization. 
Will auto-reshape and auto-expand input arguments\n    to match the first argument, as well as check shapes based on default tensor sizes.\n    \"\"\"\n\n    @functools.wraps(func)\n    def wrap(self, *args, **kwargs):\n        # Combine args and kwargs.\n        arg_names = get_nonempty_arg_names(func)\n        all_args = {}\n        for i, arg in enumerate(args):\n            all_args[arg_names[i]] = arg\n        for arg_name in kwargs:\n            all_args[arg_name] = kwargs[arg_name]\n\n        # Add default values to all_args if unspecified inputs.\n        default_args = get_default_args(func)\n        extra_args = {}\n        for arg_name in default_args:\n            default_arg = default_args[arg_name]\n            if not isinstance(default_arg, (TensorWrapper, torch.Tensor)):\n                # If not TW or torch tensor, pass it through unperturbed.\n                extra_args[arg_name] = all_args.pop(arg_name)\n            else:\n                if arg_name not in all_args or all_args[arg_name] is None:\n                    all_args[arg_name] = default_arg\n\n        # Auto convert numpy,lists,floats to torch, check that shapes are good.\n        for arg_name in all_args:\n            arg = all_args[arg_name]\n            if isinstance(arg, (torch.Tensor, TensorWrapper)):\n                pass\n            elif isinstance(arg, (int, float)):\n                arg = torch.tensor(arg).reshape(1)\n            elif isinstance(arg, List):\n                arg = torch.tensor(arg)\n            elif isinstance(arg, np.ndarray):\n                arg = torch.from_numpy(arg)\n            else:\n                raise ValueError(\"Unsupported initialization type\")\n            assert isinstance(arg, (torch.Tensor, TensorWrapper))\n\n            default_arg = default_args[arg_name]\n            if isinstance(default_arg, TensorWrapper):\n                # Convert list of torch.Size to tuple of ints.\n                default_dims = tuple([da[0] for da in 
default_arg.shape])\n            else:\n                default_dims = (default_arg.shape[-1],)\n            if arg.shape[-1] not in default_dims:\n                # probably need a more general solution here to handle single dim inputs.\n                if default_dims[0] == 1:\n                    arg = arg.unsqueeze(-1)\n                if arg.shape[-1] not in default_dims:\n                    raise ValueError(\n                        \"Bad shape of %d for %s, should be in %s\"\n                        % (arg.shape[-1], arg_name, default_dims)\n                    )\n\n            all_args[arg_name] = arg\n\n        # Shape of all inputs is determined by first arg.\n        first_arg_name = arg_names[0]\n        batch_shape = all_args[first_arg_name].shape[:-1]\n\n        has_cuda_tensor = False\n\n        for arg_name in all_args:\n            arg = all_args[arg_name]\n            # Try to trim any extra dimensions at the beginning of arg shape.\n            while True:\n                if arg.ndim > len(batch_shape) and arg.shape[0] == 1 and arg.ndim > 1:\n                    arg = arg.squeeze(0)\n                else:\n                    break\n            arg = arg.expand(*batch_shape, arg.shape[-1])\n            all_args[arg_name] = arg\n\n            if (\n                isinstance(all_args[arg_name], (torch.Tensor, TensorWrapper))\n                and all_args[arg_name].is_cuda\n            ):\n                has_cuda_tensor = True\n\n        if has_cuda_tensor:\n            for arg_name in all_args:\n                if (\n                    isinstance(all_args[arg_name], (torch.Tensor, TensorWrapper))\n                    and not all_args[arg_name].is_cuda\n                ):\n                    all_args[arg_name] = all_args[arg_name].cuda()\n\n        # Add the unperturbed args back to all args.\n        all_args.update(extra_args)\n\n        return func(self, **all_args)\n\n    return wrap\n\n\ndef tensor_wrapper_collate(batch, *, 
collate_fn_map=None):\n    \"\"\"Simply call stack for TensorWrapper\"\"\"\n    return torch.stack(batch, 0)\n\n\ndef float_collate(batch, *, collate_fn_map=None):\n    \"\"\"Auto convert float to float32\"\"\"\n    return torch.tensor(batch, dtype=torch.float32)\n\n\ndef list_dict_collate(batch, *, collate_fn_map=None):\n    \"\"\"collate lists; handles the case where the lists in the batch are\n    expressing a dict via List[Tuple[key, value]] and returns a Dict[key, value]\n    in that case.\"\"\"\n    if len(batch) > 0:\n        list_0 = batch[0]\n        if len(list_0) > 0:\n            elem_0 = list_0[0]\n            if isinstance(elem_0, tuple) and len(elem_0) == 2:\n                # the lists in each batch sample are (key, value) pairs and we hence return a dictionary\n                for i in range(len(batch)):\n                    batch[i] = {k: v for k, v in batch[i]}\n    return batch\n\n\ndef tensor_wrapper_collate_cat(batch, *, collate_fn_map=None):\n    \"\"\"Simply call cat for TensorWrapper\"\"\"\n    return torch.cat(batch, 0)\n\n\ndef tensor_collate_cat(batch, *, collate_fn_map=None):\n    \"\"\"identical to \"collate_tensor_fn\" but replace torch.stack with torch.cat\"\"\"\n    elem = batch[0]\n    out = None\n    if torch.utils.data.get_worker_info() is not None:\n        # If we're in a background process, concatenate directly into a\n        # shared memory tensor to avoid an extra copy\n        numel = sum(x.numel() for x in batch)\n        # Note: pytorch 1.12 doesn't have the _typed_storage() interface. 
Need to use storage() instead.\n        # storage = elem._typed_storage()._new_shared(numel, device=elem.device)\n        storage = elem.storage()._new_shared(numel, device=elem.device)\n\n        # since we are using torch.cat, we don't need to add a new dimension here\n        dims_from_one = list(elem.size())[1:]\n        out = elem.new(storage).resize_(len(batch), *dims_from_one)\n    return torch.cat(batch, 0, out=out)  # concatenate instead of stack\n\n\ndef custom_collate_fn(batch):\n    \"\"\"Custom collate function for tensor wrapper\"\"\"\n    # Get the common keys between samples. This is required when we train with\n    # multiple datasets with samples having different keys.\n    if isinstance(batch, list) and isinstance(batch[0], dict):\n        common_keys = set(batch[0].keys())\n\n        for sample in batch[1:]:\n            common_keys &= set(sample.keys())\n\n        # update the batch with new samples with only the common keys\n        new_batch = []\n        for sample in batch:\n            new_sample = {k: v for k, v in sample.items() if k in common_keys}\n            new_batch.append(new_sample)\n        batch = new_batch\n\n    default_collate_fn_map[TensorWrapper] = tensor_wrapper_collate\n    default_collate_fn_map[float] = float_collate\n    default_collate_fn_map[list] = list_dict_collate\n    default_collate_fn_map[torch.Tensor] = collate_tensor_fn\n    if \"already_collated\" in batch[0]:\n        # Use torch.cat instead of torch.stack\n        default_collate_fn_map[torch.Tensor] = tensor_collate_cat\n        default_collate_fn_map[TensorWrapper] = tensor_wrapper_collate_cat\n    batch = collate(batch, collate_fn_map=default_collate_fn_map)\n    return batch\n\n\nclass TensorWrapper:\n    \"\"\"Base class for making \"smart\" tensor objects that behave like pytorch tensors\n    Inspired by Paul-Edouard Sarlin's code here in pixloc:\n    https://github.com/cvg/pixloc/blob/master/pixloc/pixlib/geometry/wrappers.py\n    Adopted and 
modified by Daniel DeTone.\n    \"\"\"\n\n    _data = None\n\n    @autocast\n    def __init__(self, data: torch.Tensor):\n        self._data = data\n\n    @property\n    def shape(self):\n        return self._data.shape\n\n    @property\n    def device(self):\n        return self._data.device\n\n    @property\n    def dtype(self):\n        return self._data.dtype\n\n    @property\n    def ndim(self):\n        return self._data.ndim\n\n    def dim(self):\n        return self._data.dim()\n\n    def nelement(self):\n        return self._data.nelement()\n\n    def numel(self):\n        return self._data.numel()\n\n    @property\n    def collate_fn(self):\n        return custom_collate_fn\n\n    @property\n    def is_cuda(self):\n        return self._data.is_cuda\n\n    @property\n    def is_contiguous(self):\n        return self._data.is_contiguous\n\n    @property\n    def requires_grad(self):\n        return self._data.requires_grad\n\n    @property\n    def grad(self):\n        return self._data.grad\n\n    @property\n    def grad_fn(self):\n        return self._data.grad_fn\n\n    def requires_grad_(self, requires_grad: bool = True):\n        self._data.requires_grad_(requires_grad)\n\n    def __getitem__(self, index):\n        return self.__class__(self._data[index])\n\n    def __setitem__(self, index, item):\n        self._data[index] = item._data\n\n    def to(self, *args, **kwargs):\n        return self.__class__(self._data.to(*args, **kwargs))\n\n    def reshape(self, *args, **kwargs):\n        return self.__class__(self._data.reshape(*args, **kwargs))\n\n    def repeat(self, *args, **kwargs):\n        return self.__class__(self._data.repeat(*args, **kwargs))\n\n    def expand(self, *args, **kwargs):\n        return self.__class__(self._data.expand(*args, **kwargs))\n\n    def clone(self):\n        return self.__class__(self._data.clone())\n\n    def cpu(self):\n        return self.__class__(self._data.cpu())\n\n    def cuda(self, gpu_id=0):\n        return 
self.__class__(self._data.cuda(gpu_id))\n\n    def contiguous(self):\n        return self.__class__(self._data.contiguous())\n\n    def pin_memory(self):\n        return self.__class__(self._data.pin_memory())\n\n    def float(self):\n        return self.__class__(self._data.float())\n\n    def double(self):\n        return self.__class__(self._data.double())\n\n    def detach(self):\n        return self.__class__(self._data.detach())\n\n    def numpy(self):\n        return self._data.numpy()\n\n    def tensor(self):\n        return self._data\n\n    def tolist(self):\n        return self._data.tolist()\n\n    def squeeze(self, dim=None):\n        assert dim != -1 and dim != self._data.dim() - 1\n        if dim is None:\n            return self.__class__(self._data.squeeze())\n        return self.__class__(self._data.squeeze(dim=dim))\n\n    def unsqueeze(self, dim=None):\n        assert dim != -1 and dim != self._data.dim()\n        return self.__class__(self._data.unsqueeze(dim=dim))\n\n    def view(self, *shape):\n        assert shape[-1] == -1 or shape[-1] == self._data.shape[-1]\n        return self.__class__(self._data.view(*shape))\n\n    def __len__(self):\n        return self._data.shape[0]\n\n    @classmethod\n    def stack(cls, objects: List, dim=0, *, out=None):\n        data = torch.stack([obj._data for obj in objects], dim=dim, out=out)\n        return cls(data)\n\n    @classmethod\n    def cat(cls, objects: List, dim=0, *, out=None):\n        data = torch.cat([obj._data for obj in objects], dim=dim, out=out)\n        return cls(data)\n\n    @classmethod\n    def allclose(\n        cls,\n        input: torch.Tensor,\n        other: torch.Tensor,\n        rtol=1e-5,\n        atol=1e-8,\n        equal_nan=False,\n    ):\n        return torch.allclose(\n            input._data, other._data, rtol=rtol, atol=atol, equal_nan=equal_nan\n        )\n\n    @classmethod\n    def take_along_dim(cls, obj, indices, dim, *, out=None):\n        data = 
torch.take_along_dim(obj._data, indices, dim, out=out)\n        return cls(data)\n\n    @classmethod\n    def flatten(cls, obj, start_dim=0, end_dim=-1):\n        data = torch.flatten(obj._data, start_dim=start_dim, end_dim=end_dim)\n        return cls(data)\n\n    @classmethod\n    def __torch_function__(self, func, types, args=(), kwargs=None):\n        if kwargs is None:\n            kwargs = {}\n        if func is torch.stack:\n            return self.stack(*args, **kwargs)\n        elif func is torch.cat:\n            return self.cat(*args, **kwargs)\n        elif func is torch.allclose:\n            return self.allclose(*args, **kwargs)\n        elif func is torch.take_along_dim:\n            return self.take_along_dim(*args, **kwargs)\n        elif func is torch.flatten:\n            return self.flatten(*args, **kwargs)\n        else:\n            return NotImplemented\n"
  },
  {
    "path": "efm3d/config/efm_preprocessing_conf.yaml",
    "content": "atek_config_name: \"efm\"\ncamera_temporal_subsampler:\n  main_camera_label: \"camera-rgb\"\n  time_domain: \"DEVICE_TIME\"\n  main_camera_target_freq_hz: 10.0\n  sample_length_in_num_frames: 20\n  stride_length_in_num_frames: 10\nprocessors:\n  rgb:\n    selected: true\n    sensor_label: \"camera-rgb\"\n    time_domain: \"DEVICE_TIME\"\n    tolerance_ns: 10_000_000\n    undistort_to_linear_cam: false  # if set, undistort to a linear camera model\n    target_camera_resolution: [240, 240] # if set, rescale to [image_width, image_height]\n    rescale_antialias: false # to be consistent with cv2\n    rotate_image_cw90deg: false # if set, rotate image by 90 degrees clockwise\n  slam_left:\n    selected: true\n    sensor_label: \"camera-slam-left\"\n    tolerance_ns: 10_000_000\n    time_domain: \"DEVICE_TIME\"\n    target_camera_resolution: [320, 240] # if set, rescale to [image_width, image_height]\n    rescale_antialias: false # to be consistent with cv2\n  slam_right:\n    selected: true\n    sensor_label: \"camera-slam-right\"\n    tolerance_ns: 10_000_000\n    time_domain: \"DEVICE_TIME\"\n    target_camera_resolution: [320, 240] # if set, rescale to [image_width, image_height]\n    rescale_antialias: false # to be consistent with cv2\n  mps_traj:\n    selected: true\n    tolerance_ns: 10_000_000\n  mps_semidense:\n    selected: true\n    tolerance_ns: 10_000_000\n  rgb_depth:\n    selected: true\n    sensor_stream_id: \"345-1\" # 345-1 for ADT data, 214-8 for ASE data\n    tolerance_ns: 10_000_000\n    time_domain: \"DEVICE_TIME\"\n    convert_zdepth_to_dist: false\n  efm_gt:\n    selected: true\n    tolerance_ns : 10_000_000\n    category_mapping_field_name: category # {prototype_name, category}\nwds_writer:\n  prefix_string: \"\"\n  max_samples_per_shard: 8\n"
  },
  {
    "path": "efm3d/config/evl_inf.yaml",
    "content": "_target_: efm3d.model.evl.EVL\nneck_hidden_dims: [128, 256, 512]\nhead_hidden_dim: 256\nhead_layers: 2\ntaxonomy_file: efm3d/config/taxonomy/ase_sem_name_to_id.csv\n\nvideo_backbone:\n  _target_: efm3d.model.video_backbone.VideoBackboneDinov2\n  freeze_encoder: true\n  image_tokenizer:\n    _target_: efm3d.model.image_tokenizer.ImageToDinoV2Tokens\n    dinov2_name: vit_base_v25\n    freeze: true\n    handle_rotated_data: true\n    dim_out: 768\n    add_lin_layer: false\n    multilayer_output: true\n    ckpt_path: ckpt/dinov2_vitb14_reg4_pretrain.pth\n  video_streams: [rgb]\n  correct_vignette: false\n  optimize_vignette: false\nvideo_backbone3d:\n  _target_: efm3d.model.lifter.Lifter\n  in_dim: 768\n  out_dim: 64\n  patch_size: 16\n  voxel_size: [96,96,96]\n  voxel_extent: [-2.0, 2.0, 0.0, 4.0, -2.0, 2.0]\n  head_type: dpt_ori\n  streams: [rgb]\n  joint_slam_streams: false\n  joint_streams: false\n"
  },
  {
    "path": "efm3d/config/evl_inf_desktop.yaml",
    "content": "_target_: efm3d.model.evl.EVL\nneck_hidden_dims: [32, 64, 128]\nhead_hidden_dim: 256\nhead_layers: 2\ntaxonomy_file: efm3d/config/taxonomy/ase_sem_name_to_id.csv\n\nvideo_backbone:\n  _target_: efm3d.model.video_backbone.VideoBackboneDinov2\n  freeze_encoder: true\n  image_tokenizer:\n    _target_: efm3d.model.image_tokenizer.ImageToDinoV2Tokens\n    dinov2_name: vit_base_v25\n    freeze: true\n    handle_rotated_data: true\n    dim_out: 768\n    add_lin_layer: false\n    multilayer_output: true\n    ckpt_path: ckpt/dinov2_vitb14_reg4_pretrain.pth\n  video_streams: [rgb]\n  correct_vignette: false\n  optimize_vignette: false\nvideo_backbone3d:\n  _target_: efm3d.model.lifter.Lifter\n  in_dim: 768\n  out_dim: 32\n  patch_size: 16\n  voxel_size: [48,48,48]\n  voxel_extent: [-2.0, 2.0, 0.0, 4.0, -2.0, 2.0]\n  head_type: dpt_ori\n  streams: [rgb]\n  joint_slam_streams: false\n  joint_streams: false\n"
  },
  {
    "path": "efm3d/config/evl_train.yaml",
    "content": "_target_: efm3d.model.evl_train.EVLTrain\nneck_hidden_dims: [128, 256, 512]\nhead_hidden_dim: 256\nhead_layers: 2\ntaxonomy_file: efm3d/config/taxonomy/ase_sem_name_to_id.csv\n\nvideo_backbone:\n  _target_: efm3d.model.video_backbone.VideoBackboneDinov2\n  freeze_encoder: true\n  image_tokenizer:\n    _target_: efm3d.model.image_tokenizer.ImageToDinoV2Tokens\n    dinov2_name: vit_base_v25\n    freeze: true\n    handle_rotated_data: true\n    dim_out: 768\n    add_lin_layer: false\n    multilayer_output: true\n    ckpt_path: ckpt/dinov2_vitb14_reg4_pretrain.pth\n  video_streams: [rgb]\n  correct_vignette: false\n  optimize_vignette: false\nvideo_backbone3d:\n  _target_: efm3d.model.lifter.Lifter\n  in_dim: 768\n  out_dim: 64\n  patch_size: 16\n  voxel_size: [96,96,96]\n  voxel_extent: [-2.0, 2.0, 0.0, 4.0, -2.0, 2.0]\n  head_type: dpt_ori\n  streams: [rgb]\n  joint_slam_streams: false\n  joint_streams: false\n"
  },
  {
    "path": "efm3d/config/taxonomy/aeo_to_efm.csv",
    "content": "AEO Category Name,EFM Category Name,EFM Category Id\nChair,chair,3\nCouch,sofa,1\nTable,table,0\nBed,bed,4\nWallArt,picture_frame,21\nPlant,flower_pot,13\nWindow,window,28\nMirror,mirror,22\nLamp,lamp,16\n"
  },
  {
    "path": "efm3d/config/taxonomy/ase_sem_name_to_id.csv",
    "content": "sem_name,sem_id\r\ntable,0\r\nsofa,1\r\nshelf,2\r\nchair,3\r\nbed,4\r\nfloor_mat,5\r\nexercise_weight,6\r\ncutlery,7\r\ncontainer,8\r\nclock,9\r\ncart,10\r\nvase,11\r\ntent,12\r\nflower_pot,13\r\npillow,14\r\nmount,15\r\nlamp,16\r\nladder,17\r\nfan,18\r\ncabinet,19\r\njar,20\r\npicture_frame,21\r\nmirror,22\r\nelectronic_device,23\r\ndresser,24\r\nclothes_rack,25\r\nbattery_charger,26\r\nair_conditioner,27\r\nwindow,28\r\n"
  },
  {
    "path": "efm3d/config/taxonomy/atek_to_efm.csv",
    "content": "ATEK Category Name,EFM version ASE Category Name,EFM version ASE Category Id\r\ntable,table,0\r\nsofa,sofa,1\r\nshelves,shelf,2\r\nchair,chair,3\r\nbed,bed,4\r\nfloor mat,floor_mat,5\r\nexercise_weight,exercise_weight,6\r\ncutlery,cutlery,7\r\ncontainer,container,8\r\nclock,clock,9\r\ncart,cart,10\r\nvase,vase,11\r\ntent,tent,12\r\nplant,flower_pot,13\r\npillow,pillow,14\r\nmount,mount,15\r\nlamp,lamp,16\r\nladder,ladder,17\r\nfan,fan,18\r\ncabinet,cabinet,19\r\njar,jar,20\r\npicture,picture_frame,21\r\nmirror,mirror,22\r\nelectronic_device,electronic_device,23\r\ndresser,dresser,24\r\nclothes_rack,clothes_rack,25\r\nbattery_charger,battery_charger,26\r\nair_conditioner,air_conditioner,27\r\nwindow,window,28\r\n"
  },
  {
    "path": "efm3d/dataset/atek_vrs_dataset.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# pyre-strict\n\nimport logging\nimport os\nfrom typing import Dict, List, Optional\n\nfrom atek.data_loaders.atek_wds_dataloader import select_and_remap_dict_keys\nfrom atek.data_preprocess.atek_data_sample import AtekDataSample\nfrom atek.data_preprocess.sample_builders.atek_data_paths_provider import (\n    AtekDataPathsProvider,\n)\nfrom atek.data_preprocess.sample_builders.efm_sample_builder import EfmSampleBuilder\nfrom atek.data_preprocess.subsampling_lib.temporal_subsampler import (\n    CameraTemporalSubsampler,\n)\nfrom efm3d.dataset.efm_model_adaptor import EfmModelAdaptor\nfrom omegaconf.omegaconf import DictConfig, OmegaConf\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n\n\nclass AtekRawDataloaderAsEfm:\n    def __init__(\n        self,\n        vrs_file: str,\n        mps_files: Dict[str, str],\n        gt_files: Dict[str, str],\n        conf: DictConfig,\n        freq_hz: int,\n        snippet_length_s: float,\n        semidense_points_pad_to_num: int = 50000,\n        max_snippets=9999,\n    ) -> None:\n        self.max_snippets = max_snippets\n\n        # initialize the sample builder\n        self.sample_builder = EfmSampleBuilder(\n            conf=conf.processors,\n            vrs_file=vrs_file,\n            mps_files=mps_files,\n            gt_files=gt_files,\n            depth_vrs_file=\"\",\n    
        sequence_name=os.path.basename(vrs_file),\n        )\n\n        self.subsampler = CameraTemporalSubsampler(\n            vrs_file=vrs_file,\n            conf=conf.camera_temporal_subsampler,\n        )\n\n        # Create a EFM model adaptor\n        self.model_adaptor = EfmModelAdaptor(\n            freq=freq_hz,\n            snippet_length_s=snippet_length_s,\n            semidense_points_pad_to_num=semidense_points_pad_to_num,\n            atek_to_efm_taxonomy_mapping_file=f\"{os.path.dirname(__file__)}/../config/taxonomy/atek_to_efm.csv\",\n        )\n\n    def __len__(self):\n        return min(self.subsampler.get_total_num_samples(), self.max_snippets)\n\n    def get_timestamps_by_sample_index(self, index: int) -> List[int]:\n        return self.subsampler.get_timestamps_by_sample_index(index)\n\n    def get_atek_sample_at_timestamps_ns(\n        self, timestamps_ns: List[int]\n    ) -> Optional[AtekDataSample]:\n        return self.sample_builder.get_sample_by_timestamps_ns(timestamps_ns)\n\n    def get_model_specific_sample_at_timestamps_ns(\n        self, timestamps_ns: List[int]\n    ) -> Optional[Dict]:\n        atek_sample = self.get_atek_sample_at_timestamps_ns(timestamps_ns)\n        if atek_sample is None:\n            logger.warning(\n                f\"Cannot retrieve valid atek sample at timestamp {timestamps_ns}\"\n            )\n            return None\n\n        # Flatten to dict\n        atek_sample_dict = atek_sample.to_flatten_dict()\n\n        # key remapping\n        remapped_data_dict = select_and_remap_dict_keys(\n            sample_dict=atek_sample_dict,\n            key_mapping=self.model_adaptor.get_dict_key_mapping_all(),\n        )\n\n        # transform\n        model_specific_sample_gen = self.model_adaptor.atek_to_efm([remapped_data_dict])\n\n        # Obtain a dict from a generator object\n        model_specific_sample = next(model_specific_sample_gen)\n\n        return model_specific_sample\n\n    def __getitem__(self, 
index):\n        if index >= self.max_snippets:\n            raise StopIteration\n\n        timestamps = self.get_timestamps_by_sample_index(index)\n        maybe_sample = self.get_model_specific_sample_at_timestamps_ns(timestamps)\n\n        return maybe_sample\n\n\ndef create_atek_raw_data_loader_from_vrs_path(\n    vrs_path: str,\n    freq_hz: int,\n    snippet_length_s,\n    stride_length_s,\n    skip_begin_seconds: float = 0.0,\n    skip_end_seconds: float = 0.0,\n    semidense_points_pad_to_num=50000,\n    max_snippets=9999,\n):\n    vrs_dir = os.path.dirname(vrs_path)\n    data_path_provider = AtekDataPathsProvider(data_root_path=vrs_dir)\n    atek_data_paths = data_path_provider.get_data_paths()\n\n    conf = OmegaConf.load(\"efm3d/config/efm_preprocessing_conf.yaml\")\n\n    # Update snippet / stride length\n    conf.camera_temporal_subsampler.main_camera_target_freq_hz = float(freq_hz)\n    conf.camera_temporal_subsampler.sample_length_in_num_frames = int(\n        freq_hz * snippet_length_s\n    )\n    conf.camera_temporal_subsampler.stride_length_in_num_frames = int(\n        freq_hz * stride_length_s\n    )\n    conf.camera_temporal_subsampler.update(\n        {\n            \"skip_begin_seconds\": skip_begin_seconds,\n            \"skip_end_seconds\": skip_end_seconds,\n        }\n    )\n\n    data_loader = AtekRawDataloaderAsEfm(\n        vrs_file=atek_data_paths[\"video_vrs_file\"],\n        mps_files={\n            \"mps_closedloop_traj_file\": atek_data_paths[\"mps_closedloop_traj_file\"],\n            \"mps_semidense_points_file\": atek_data_paths[\"mps_semidense_points_file\"],\n            \"mps_semidense_observations_file\": atek_data_paths[\n                \"mps_semidense_observations_file\"\n            ],\n        },\n        gt_files={\n            \"obb3_file\": atek_data_paths[\"gt_obb3_file\"],\n            \"obb3_traj_file\": atek_data_paths[\"gt_obb3_traj_file\"],\n            \"obb2_file\": atek_data_paths[\"gt_obb2_file\"],\n       
     \"instance_json_file\": atek_data_paths[\"gt_instance_json_file\"],\n        },\n        conf=conf,\n        freq_hz=freq_hz,\n        snippet_length_s=snippet_length_s,\n        semidense_points_pad_to_num=semidense_points_pad_to_num,\n        max_snippets=max_snippets,\n    )\n\n    return data_loader\n"
  },
  {
    "path": "efm3d/dataset/atek_wds_dataset.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport glob\nimport tarfile\n\nimport numpy as np\nimport torch\nimport webdataset as wds\nfrom efm3d.aria import CameraTW, DEFAULT_CAM_DATA_SIZE, ObbTW, PoseTW, TensorWrapper\nfrom efm3d.aria.aria_constants import (\n    ARIA_CALIB,\n    ARIA_IMG,\n    ARIA_IMG_T_SNIPPET_RIG,\n    ARIA_OBB_PADDED,\n    ARIA_POINTS_VOL_MAX,\n    ARIA_POINTS_VOL_MIN,\n    ARIA_POINTS_WORLD,\n    ARIA_POSE_T_SNIPPET_RIG,\n    ARIA_POSE_T_WORLD_RIG,\n    ARIA_SNIPPET_T_WORLD_SNIPPET,\n)\nfrom efm3d.dataset.efm_model_adaptor import load_atek_wds_dataset_as_efm\n\n\ndef batchify(datum, device=None):\n    # Add batch dimension\n    for key in datum:\n        if isinstance(datum[key], (torch.Tensor, TensorWrapper)):\n            datum[key] = datum[key][None, ...]\n            # Move to the target device only when one is requested\n            if device is not None:\n                datum[key] = datum[key].to(device)\n        else:\n            datum[key] = [datum[key]]\n    return datum\n\n\ndef unbatchify(datum):\n    # Remove batch dimension\n    for key in datum:\n        if isinstance(datum[key], (torch.Tensor, TensorWrapper, list)):\n            datum[key] = datum[key][0]\n    return datum\n\n\nclass AtekWdsStreamDataset:\n    \"\"\"Sample 2s/1s WDS dataset to specified snippet length and stride\"\"\"\n\n    def __init__(\n        self,\n        data_path,\n        atek_to_efm_taxonomy,\n        
snippet_length_s=1.0,\n        stride_length_s=0.1,\n        wds_length_s=2.0,\n        fps=10,\n        max_snip=99999999,\n    ):\n        self.snippet_length_s = snippet_length_s\n        self.stride_length_s = stride_length_s\n        self.wds_length_s = wds_length_s\n        # wds snippets should always be generated half overlapped\n        self.wds_stride_s = wds_length_s // 2\n        self.fps = fps\n        self.max_snip = max_snip\n\n        tar_list = sorted(glob.glob(f\"{data_path}/*.tar\"))\n        sn = set()\n        with tarfile.TarFile(tar_list[0], \"r\") as tar:\n            for member in tar.getmembers():\n                sn.add(member.name.split(\".\")[0])\n        self.samples_per_tar = len(sn)\n        self.num_tars = len(tar_list)\n\n        self.dataset = load_atek_wds_dataset_as_efm(\n            urls=tar_list,\n            freq=fps,\n            snippet_length_s=wds_length_s,  # Need to use `wds_length` for model adaptor!\n            atek_to_efm_taxonomy_mapping_file=atek_to_efm_taxonomy,\n        )\n        self.dataloader = iter(self.dataset)\n\n        self.frames_wds = int(self.fps * self.wds_length_s)\n        self.frames_out = int(self.fps * self.snippet_length_s)\n        self.frames_stride_wds = int(self.fps * self.wds_stride_s)\n        self.frames_stride_out = int(self.fps * self.stride_length_s)\n\n        self.num_rest = int(\n            (self.wds_length_s - self.snippet_length_s) / self.stride_length_s\n        )\n        self.num_first = int(1 + self.num_rest)\n        self.num_snippets = (\n            self.num_first + (self.samples_per_tar * self.num_tars - 1) * self.num_rest\n        )\n\n        # for iteration\n        self.first = True\n        self.wds_snippet = None\n        self.snip_idx = 0\n        self.global_idx = 0\n\n    def __len__(self):\n        return min(self.num_snippets, self.max_snip)\n\n    def sample_snippet_(self, snippet, start, end):\n        # time crop\n        sample = snippet.copy()\n        
for k in sample:\n            if isinstance(sample[k], (torch.Tensor, TensorWrapper)):\n                if k not in [\n                    ARIA_SNIPPET_T_WORLD_SNIPPET,\n                    ARIA_POINTS_VOL_MIN,\n                    ARIA_POINTS_VOL_MAX,\n                ]:\n                    sample[k] = sample[k][start:end, ...]\n\n        return sample\n\n    def __iter__(self):\n        return self\n\n    def if_get_next_(self):\n        if self.wds_snippet is None:\n            return True\n\n        if self.first:\n            return self.snip_idx >= self.num_first\n        else:\n            return self.snip_idx >= self.num_rest\n\n    def __next__(self):\n        if self.global_idx >= self.max_snip:\n            raise StopIteration\n\n        if self.if_get_next_():\n            if self.first and self.wds_snippet is not None:\n                self.first = False\n            self.wds_snippet = next(self.dataloader)\n            self.snip_idx = 0\n\n        if self.first:\n            start = self.snip_idx * self.frames_stride_out\n        else:\n            start = (self.snip_idx + 1) * self.frames_stride_out\n\n        end = start + self.frames_out\n        sample = self.sample_snippet_(self.wds_snippet, start, end)\n        self.snip_idx += 1\n        self.global_idx += 1\n        return sample\n"
  },
  {
    "path": "efm3d/dataset/augmentation.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nfrom functools import partial\nfrom typing import Dict, List, Optional, Tuple, Union\n\nimport torch\nimport torchvision\nfrom efm3d.aria.aria_constants import (\n    ARIA_IMG,\n    ARIA_POINTS_DIST_STD,\n    ARIA_POINTS_INV_DIST_STD,\n    ARIA_POINTS_WORLD,\n)\nfrom torchvision.transforms.v2._color import RandomAdjustSharpness\nfrom webdataset import WebDataset\n\nlogging.basicConfig()\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n# logger.setLevel(logging.DEBUG)\n\n\nclass ColorJitter:\n    \"\"\"\n    Applies photometric jitter to the images in the video sequence.\n    \"\"\"\n\n    def __init__(\n        self,\n        brightness: Union[Tuple[float], float] = 0.5,\n        contrast: Union[Tuple[float], float] = 0.3,\n        saturation: Union[Tuple[float], float] = 0.3,\n        hue: Union[Tuple[float], float] = 0.05,\n        sharpness: Union[Tuple[float], float] = 2.0,\n        snippet_jitter: bool = False,\n    ):\n        \"\"\"\n        Calls torchvision on the images independently in a video using:\n        https://pytorch.org/vision/main/generated/torchvision.transforms.ColorJitter.html\n\n\n        brightness: how much to jitter brightness in range [0,val]\n        contrast: how much to jitter contrast in range [0,val]\n        saturation: how much to jitter saturation in range [0,val]\n        
hue: how much to jitter hue in range [-val,val]\n        snippet_jitter: if true, jitter equally across the snippet\n        \"\"\"\n\n        self.transform = torchvision.transforms.ColorJitter(\n            brightness=brightness, contrast=contrast, saturation=saturation, hue=hue\n        )\n        self.snippet_jitter = snippet_jitter\n        self.sharpness = sharpness\n\n    def rnd_sharpen(self, im):\n        factor = float(self.sharpness * torch.rand(1))\n        sharp_fn = RandomAdjustSharpness(sharpness_factor=factor, p=1.0)\n        return sharp_fn(im)\n\n    def apply(self, im):\n        im = self.transform(im)\n        im = self.rnd_sharpen(im)\n        return im\n\n    def __call__(self, batch: Dict):\n        for name in ARIA_IMG:\n            if name in batch:\n                batch[name] = batch[name].clone().detach()\n                if self.snippet_jitter:\n                    batch[name] = self.apply(batch[name])\n                else:\n                    for t in range(len(batch[name])):\n                        batch[name][t] = self.apply(batch[name][t])\n        return batch\n\n\nclass PointDrop:\n    \"\"\"\n    Applies point drop augmentation based on the standard deviations of the points.\n    A standard deviation is sampled within the provided range, and points exceeding\n    the sampled standard deviation are dropped.\n    Attributes:\n        dropout_all_rate (float): The rate at which all points are dropped.\n        inv_dist_std (List[float]): Range [min, max] of inverse distance standard deviations.\n        dist_std (List[float]): Range [min, max] of distance standard deviations.\n    \"\"\"\n\n    def __init__(\n        self,\n        dropout_all_rate: float = 0.2,\n        inv_dist_std: Optional[List[float]] = None,\n        dist_std: Optional[List[float]] = None,\n    ):\n        if inv_dist_std is None:\n            inv_dist_std = [0.001, 0.03]\n        if dist_std is None:\n            dist_std = [0.01, 0.3]\n        
self.dropout_all_rate = dropout_all_rate\n        self.inv_dist_std = inv_dist_std\n        self.dist_std = dist_std\n\n        assert inv_dist_std[1] >= inv_dist_std[0]\n        assert dist_std[1] >= dist_std[0]\n\n    def __call__(self, batch: Dict):\n        if ARIA_POINTS_WORLD not in batch:\n            return batch\n\n        p_drop_all = torch.rand(1).item()\n        if p_drop_all < self.dropout_all_rate:\n            # drop all points\n            batch[ARIA_POINTS_WORLD][:, :, :] = torch.nan\n        else:\n            # drop based on stds.\n            p_w = batch[ARIA_POINTS_WORLD]\n            T, N = p_w.shape[:2]\n\n            # sample inv_dist_std\n            rand_inv_dist_thres = torch.rand(1).item()\n            rand_inv_dist_thres = (\n                rand_inv_dist_thres * (self.inv_dist_std[1] - self.inv_dist_std[0])\n                + self.inv_dist_std[0]\n            )\n\n            # sample dist_std\n            rand_dist_thres = torch.rand(1).item()\n            rand_dist_thres = (\n                rand_dist_thres * (self.dist_std[1] - self.dist_std[0])\n                + self.dist_std[0]\n            )\n\n            dropped = torch.zeros(T, N, dtype=torch.bool)\n            if ARIA_POINTS_INV_DIST_STD in batch:\n                drop_inv_dist_std = (\n                    batch[ARIA_POINTS_INV_DIST_STD] > rand_inv_dist_thres\n                )\n                dropped |= drop_inv_dist_std\n\n                logger.debug(f\"drop points with max inv_dist_std {rand_inv_dist_thres}\")\n                logger.debug(f\"drop {dropped.sum()} points.\")\n            if ARIA_POINTS_DIST_STD in batch:\n                drop_dist_std = batch[ARIA_POINTS_DIST_STD] > rand_dist_thres\n                dropped |= drop_dist_std\n\n                logger.debug(f\"drop points with max dist_std {rand_dist_thres}\")\n                logger.debug(f\"drop {dropped.sum()} points.\")\n            p_w[dropped, :] = torch.nan\n            batch[ARIA_POINTS_WORLD] = 
p_w\n\n        return batch\n\n\nclass PointDropSimple:\n    \"\"\"\n    Simple point drop augmentation.\n    \"\"\"\n\n    def __init__(\n        self,\n        max_dropout_rate: float = 0.8,\n    ):\n        self.max_dropout_rate = max_dropout_rate\n        assert self.max_dropout_rate < 1.0 and self.max_dropout_rate > 0.0\n\n    def __call__(self, batch: Dict):\n        if ARIA_POINTS_WORLD not in batch:\n            return batch\n\n        dropout_rate = torch.rand(1).item()\n        if dropout_rate > self.max_dropout_rate:\n            return batch\n        else:\n            p_w = batch[ARIA_POINTS_WORLD]  # T, N, 3\n            T, N = p_w.shape[:2]\n            mask = torch.rand((T, N)) < dropout_rate\n            p_w[mask, :] = torch.nan\n            batch[ARIA_POINTS_WORLD] = p_w\n\n        return batch\n\n\nclass PointJitter:\n    \"\"\"\n    Applies point jitter augmentation.\n    \"\"\"\n\n    def __init__(\n        self,\n        depth_std_scale_min: float = 1.0,\n        depth_std_scale_max: float = 3.0,\n    ):\n        \"\"\"\n        Args:\n            depth_std_scale_min: min scale factor for depth jitter based on depth_std\n            depth_std_scale_max: max scale factor for depth jitter based on depth_std\n        \"\"\"\n        self.depth_std_scale_max = depth_std_scale_max\n        self.depth_std_scale_min = depth_std_scale_min\n\n    def __call__(self, batch: Dict):\n        if ARIA_POINTS_WORLD in batch and ARIA_POINTS_DIST_STD in batch:\n            p_w = batch[ARIA_POINTS_WORLD]\n            scale = (\n                torch.rand(1).item()\n                * (self.depth_std_scale_max - self.depth_std_scale_min)\n                + self.depth_std_scale_min\n            )\n            std = batch[ARIA_POINTS_DIST_STD] * scale\n            noise = torch.randn_like(p_w) * std.unsqueeze(-1)\n            batch[ARIA_POINTS_WORLD] = p_w + noise\n        return batch\n"
  },
  {
    "path": "efm3d/dataset/efm_model_adaptor.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport csv\nimport logging\nfrom functools import partial\nfrom typing import Callable, Dict, List, Optional\n\nimport torch\nimport webdataset as wds\nfrom atek.data_loaders.atek_wds_dataloader import (\n    load_atek_wds_dataset,\n    process_wds_sample,\n    select_and_remap_dict_keys,\n)\nfrom atek.util.tensor_utils import fill_or_trim_tensor\nfrom efm3d.aria import CameraTW, ObbTW, PoseTW, TensorWrapper\nfrom efm3d.aria.aria_constants import (\n    ARIA_CALIB,\n    ARIA_CALIB_SNIPPET_TIME_S,\n    ARIA_OBB_PADDED,\n    ARIA_POINTS_VOL_MAX,\n    ARIA_POINTS_VOL_MIN,\n    ARIA_POSE_SNIPPET_TIME_S,\n    ARIA_POSE_T_SNIPPET_RIG,\n    ARIA_SNIPPET_LENGTH_S,\n    ARIA_SNIPPET_T_WORLD_SNIPPET,\n    ARIA_SNIPPET_TIME_NS,\n)\nfrom efm3d.aria.obb import transform_obbs\nfrom efm3d.aria.tensor_wrapper import smart_stack\nfrom webdataset.filters import pipelinefilter\n\nlogger = logging.getLogger(__name__)\n\n\ndef get_local_pose_helper(snippet_origin_time_s, batch, local_coordinate):\n    \"\"\"\n    get the local coordinate system of the snippet as the pose at the\n    snippet_origin_time_s under the specified coordinate system conventions (rig, or cam_rgb)\n    \"\"\"\n    assert (\n        ARIA_POSE_T_SNIPPET_RIG in batch.keys()\n        and ARIA_POSE_SNIPPET_TIME_S in batch.keys()\n        and ARIA_SNIPPET_T_WORLD_SNIPPET in batch.keys()\n    ), 
f\"keys not in batch keys {batch.keys()}\"\n\n    T_world_snippet = batch[ARIA_SNIPPET_T_WORLD_SNIPPET]\n    Ts_world_rig = T_world_snippet @ batch[ARIA_POSE_T_SNIPPET_RIG]\n    time_s = batch[ARIA_POSE_SNIPPET_TIME_S]\n    assert Ts_world_rig.dim() in [2, 3], f\"{Ts_world_rig.shape} should be (B)xTx12\"\n\n    if local_coordinate == \"rig\":\n        T_world_local = get_snippet_cosy_from_rig(\n            Ts_world_rig=Ts_world_rig,\n            time=time_s,\n            snippet_origin_time=snippet_origin_time_s,\n        )\n    elif local_coordinate == \"cam_rgb\":\n        T_world_local = get_snippet_cosy_from_cam_rgb(\n            Ts_world_rig=Ts_world_rig,\n            time=time_s,\n            snippet_origin_time=snippet_origin_time_s,\n            cam_rgb=batch[ARIA_CALIB[0]],\n            cam_rgb_time_s=batch[ARIA_CALIB_SNIPPET_TIME_S[0]],\n        )\n    else:\n        raise NotImplementedError(\n            f\"{local_coordinate} is not a valid coordinate option\"\n        )\n\n    return T_world_local\n\n\ndef run_local_cosy(\n    batch,\n    origin_ratio=0.5,\n    local_coordinate=\"cam_rgb\",\n    align_to_gravity=False,\n    snippet_origin_time_s=None,\n):\n    new_batch = {}\n\n    if snippet_origin_time_s is None:\n        assert ARIA_SNIPPET_LENGTH_S in batch.keys()\n        # get new snippet time origin\n        snippet_length_s = batch[ARIA_SNIPPET_LENGTH_S]\n        snippet_origin_time_s = snippet_length_s * origin_ratio\n\n    # New origin time in ns.\n    snippet_origin_time_ns = (snippet_origin_time_s * 1e9).long()\n\n    # modify all time stamps to account for snippet origin change\n    new_batch[ARIA_SNIPPET_TIME_NS] = (\n        batch[ARIA_SNIPPET_TIME_NS] + snippet_origin_time_ns\n    )\n\n    # modify all snippet_time_s timestamps to account for snippet origin change\n    keys_time_s = [key for key in batch.keys() if key.endswith(\"/snippet_time_s\")]\n    for key in keys_time_s:\n        new_batch[key] = batch[key] - 
snippet_origin_time_s\n\n    # get new snippet pose origin\n    if (\n        ARIA_POSE_T_SNIPPET_RIG in batch\n        and ARIA_POSE_SNIPPET_TIME_S in batch\n        and ARIA_SNIPPET_TIME_NS in batch\n        and ARIA_SNIPPET_T_WORLD_SNIPPET in batch\n    ):\n        T_world_snippet = get_local_pose_helper(\n            snippet_origin_time_s,\n            batch,\n            local_coordinate,\n        )\n\n        # apply change of coordinates to snippet coordinate system\n        T_snippet_new_old = (\n            T_world_snippet.inverse() @ batch[ARIA_SNIPPET_T_WORLD_SNIPPET]\n        )\n        new_batch[ARIA_SNIPPET_T_WORLD_SNIPPET] = T_world_snippet\n        # apply the coordinate change to t_snippet_rigs\n        keys_t_snippet_rig = [\n            key for key in batch.keys() if key.endswith(\"t_snippet_rig\")\n        ]\n        for key in keys_t_snippet_rig:\n            new_batch[key] = T_snippet_new_old @ batch[key]\n\n        # transform obbs into the new snippet coordinate system as well\n        if ARIA_OBB_PADDED in batch.keys():\n            new_batch[ARIA_OBB_PADDED] = transform_obbs(\n                batch[ARIA_OBB_PADDED], T_snippet_new_old\n            )\n\n    return new_batch\n\n\ndef get_snippet_cosy_from_rig(\n    snippet_origin_time: torch.Tensor,\n    Ts_world_rig: PoseTW,\n    time: torch.Tensor,\n):\n    \"\"\"\n    simply interpolate the T_world_rig using the given time at the snippet_origin_time\n    to get T_world_rig_origin\n    \"\"\"\n    T_world_rig_origin, good = Ts_world_rig.interpolate(time, snippet_origin_time)\n    T = T_world_rig_origin.shape[-1]\n    if T > 1 and not good.all():\n        logger.warning(\n            f\"WARNING some interpolated poses were not good: {good} time_s {time} snippet_time {snippet_origin_time}\"\n        )\n    return T_world_rig_origin\n\n\ndef get_snippet_cosy_from_cam_rgb(\n    snippet_origin_time: torch.Tensor,\n    Ts_world_rig: PoseTW,\n    time: torch.Tensor,\n    cam_rgb: torch.Tensor,\n    
cam_rgb_time_s: torch.Tensor,\n):\n    \"\"\"\n    interpolate T_world_rig and T_camera_rig using the given time_s at the snippet_origin_time\n    and then compose the interpolated centers to get T_world_camera_origin\n    \"\"\"\n    # interpolate T_camera_rig\n    Ts_camera_rig = cam_rgb.T_camera_rig\n    T_camera_rig_origin, good = Ts_camera_rig.interpolate(\n        cam_rgb_time_s, snippet_origin_time\n    )\n\n    T = Ts_camera_rig.shape[-1]\n    if T > 1 and not good.all():\n        logger.warning(\"WARNING: some interpolated camera extrinsics were not good:\")\n    logger.debug(\n        f\"Good: {good}\\n time_s {cam_rgb_time_s}\\n snip_center {snippet_origin_time}\"\n    )\n    T_world_rig_origin = get_snippet_cosy_from_rig(\n        Ts_world_rig=Ts_world_rig, time=time, snippet_origin_time=snippet_origin_time\n    )\n    return T_world_rig_origin @ T_camera_rig_origin.inverse()\n\n\nclass EfmModelAdaptor:\n    ATEK_CAM_LABEL_TO_EFM_CAM_LABEL: Dict[str, str] = {\n        \"camera-rgb\": \"rgb\",\n        \"camera-slam-left\": \"slaml\",\n        \"camera-slam-right\": \"slamr\",\n    }\n    EFM_CAM_LABELS = [\"rgb\", \"slaml\", \"slamr\"]\n\n    EFM_GRAVITY_IN_WORLD = [0, 0, -9.81]\n\n    def __init__(\n        self,\n        freq: int,\n        snippet_length_s: float = 2.0,\n        semidense_points_pad_to_num: int = 50000,\n        atek_to_efm_taxonomy_mapping_file: Optional[str] = None,\n    ):\n        self.freq = torch.tensor([freq], dtype=torch.int32)\n\n        # EFM samples have fields padded to a fixed shape.\n        # Obtain the fixed shape dimensions\n        self.fixed_num_frames = int(snippet_length_s * freq)\n        self.fixed_semidense_num_points = semidense_points_pad_to_num\n\n        # Load optional taxonomy mapping file\n        if atek_to_efm_taxonomy_mapping_file is not None:\n            self.atek_to_efm_category_mapping = self._load_taxonomy_mapping_file(\n                atek_to_efm_taxonomy_mapping_file\n            )\n        
else:\n            self.atek_to_efm_category_mapping = None\n\n    @staticmethod\n    def get_dict_key_mapping_for_camera(atek_camera_label: str, efm_camera_label: str):\n        return {\n            f\"mfcd#{atek_camera_label}+images\": f\"{efm_camera_label}/img\",\n            f\"mfcd#{atek_camera_label}+projection_params\": f\"{efm_camera_label}/calib/projection_params\",\n            f\"mfcd#{atek_camera_label}+frame_ids\": f\"{efm_camera_label}/frame_id_in_sequence\",\n            f\"mfcd#{atek_camera_label}+capture_timestamps_ns\": f\"{efm_camera_label}/img/time_ns\",\n            f\"mfcd#{atek_camera_label}+camera_model_name\": f\"{efm_camera_label}/calib/camera_model_name\",\n            f\"mfcd#{atek_camera_label}+camera_valid_radius\": f\"{efm_camera_label}/calib/valid_radius\",\n            f\"mfcd#{atek_camera_label}+exposure_durations_s\": f\"{efm_camera_label}/calib/exposure\",\n            f\"mfcd#{atek_camera_label}+gains\": f\"{efm_camera_label}/calib/gain\",\n            f\"mfcd#{atek_camera_label}+t_device_camera\": f\"{efm_camera_label}/calib/t_device_camera\",\n        }\n\n    @staticmethod\n    def get_dict_key_mapping_all():\n        dict_key_mapping = {\n            # mps data mappings\n            \"mtd#ts_world_device\": \"pose/t_world_rig\",\n            \"mtd#capture_timestamps_ns\": \"pose/time_ns\",\n            \"mtd#gravity_in_world\": \"pose/gravity_in_world\",\n            \"msdpd#points_world\": \"points/p3s_world\",\n            \"msdpd#points_inv_dist_std\": \"points/inv_dist_std\",\n            \"msdpd#points_dist_std\": \"points/dist_std\",\n            \"msdpd#capture_timestamps_ns\": \"points/time_ns\",\n            \"msdpd#points_volumn_min\": ARIA_POINTS_VOL_MIN,\n            \"msdpd#points_volumn_max\": ARIA_POINTS_VOL_MAX,\n            \"msdpd#points\": \"points/time_ns\",\n            \"mfcd#camera-rgb-depth+images\": \"rgb/distance_m\",\n            # gt mappings\n            \"gt_data\": \"gt_data\",\n        }\n    
    # camera data related mappings\n        for (\n            atek_cam_label,\n            efm_cam_label,\n        ) in EfmModelAdaptor.ATEK_CAM_LABEL_TO_EFM_CAM_LABEL.items():\n            dict_key_mapping.update(\n                EfmModelAdaptor.get_dict_key_mapping_for_camera(\n                    atek_camera_label=atek_cam_label, efm_camera_label=efm_cam_label\n                )\n            )\n\n        return dict_key_mapping\n\n    def _get_pose_to_align_gravity(self, sample_dict: Dict) -> Optional[PoseTW]:\n        \"\"\"\n        A helper function to return a T_newWorld_oldWorld transformation to align world gravity to the EFM convention.\n        This pose needs to be later applied to all poses that include world.\n        \"\"\"\n        efm_gravity_in_world = torch.tensor(\n            self.EFM_GRAVITY_IN_WORLD, dtype=torch.float32\n        )\n        current_gravity_in_world = sample_dict[\"pose/gravity_in_world\"]\n        if torch.allclose(efm_gravity_in_world, current_gravity_in_world, atol=1e-3):\n            # print(\"gravity convention is already aligned.\")\n            return None\n        else:\n            if torch.allclose(current_gravity_in_world, torch.tensor([0, -9.81, 0])):\n                return PoseTW.from_Rt(\n                    torch.tensor(\n                        [[1, 0, 0], [0, 0, -1], [0, 1, 0]], dtype=torch.float32\n                    ),\n                    torch.tensor([0, 0, 0], dtype=torch.float32),\n                )\n            else:\n                raise ValueError(\n                    f\"unsupported gravity direction to align: {current_gravity_in_world}\"\n                )\n\n    def _load_taxonomy_mapping_file(self, filename: str) -> Dict:\n        \"\"\"\n        Load a taxonomy mapping csv file in the format of:\n        ATEK_category_name, efm_category_name, efm_category_id\n\n        returns a dict of {atek_cat_name -> (efm_cat_name, efm_cat_id)}\n        \"\"\"\n        atek_to_efm_category_mapping = {}\n  
      with open(filename, \"r\") as f:\n            csv_reader = csv.reader(f)\n            next(csv_reader)\n\n            for row in csv_reader:\n                atek_name = row[0]\n                value = (row[1], int(row[2]))  # Convert category id to an integer\n                atek_to_efm_category_mapping[atek_name] = value\n\n        return atek_to_efm_category_mapping\n\n    def _fill_dict_with_freq(self, sample_dict: Dict) -> Dict:\n        fields_to_fill = [\n            \"pose/hz\",\n            \"points/hz\",\n            \"rgb/img/hz\",\n            \"slaml/img/hz\",\n            \"slamr/img/hz\",\n        ]\n\n        # only fill obb frequency if GT exists\n        if self.gt_exists_flag:\n            fields_to_fill += [\"obbs/hz\"]\n\n        for field in fields_to_fill:\n            sample_dict[field] = self.freq\n        return sample_dict\n\n    def _convert_to_batched_camera_tw(\n        self, sample_dict: Dict, cam_label: str\n    ) -> CameraTW:\n        \"\"\"\n        A helper function to convert ATEK camera calibration to EFM camera tensor wrapper, where calibration params are replicated x `num_frames`.\n        \"\"\"\n        # calibrations are replicated by `fixed_num_frames`\n        batched_size = torch.Size((self.fixed_num_frames, 1))\n        camera_tw = CameraTW.from_surreal(\n            width=torch.full(\n                size=batched_size, fill_value=sample_dict[f\"{cam_label}/img\"].shape[3]\n            ),\n            height=torch.full(\n                size=batched_size, fill_value=sample_dict[f\"{cam_label}/img\"].shape[2]\n            ),\n            type_str=sample_dict[f\"{cam_label}/calib/camera_model_name\"],\n            params=sample_dict[f\"{cam_label}/calib/projection_params\"].unsqueeze(\n                0\n            ),  # make tensor shape [1, 15], so that it can be expanded to [num_frames, 15]\n            gain=fill_or_trim_tensor(\n                tensor=sample_dict[f\"{cam_label}/calib/gain\"],\n                
dim_size=self.fixed_num_frames,\n                dim=0,\n            ),\n            exposure_s=fill_or_trim_tensor(\n                tensor=sample_dict[f\"{cam_label}/calib/exposure\"],\n                dim_size=self.fixed_num_frames,\n                dim=0,\n            ),\n            valid_radius=sample_dict[f\"{cam_label}/calib/valid_radius\"],\n            T_camera_rig=PoseTW.from_matrix3x4(\n                sample_dict[f\"{cam_label}/calib/t_device_camera\"]\n            ).inverse(),\n        )\n\n        return camera_tw.float()\n\n    def _update_efm_obb_gt(self, atek_gt_dict: Dict) -> Dict:\n        \"\"\"\n        Helper function to convert ATEK obb gt to EFM obb gt.\n        \"\"\"\n        efm_sub_dict = {}\n\n        # loop over all timestamps\n        timestamp_list = []\n        efm_obb_all_timestamps = []\n        semantic_id_to_name = {}\n        for timestamp, obb3_dict in atek_gt_dict[\"efm_gt\"].items():\n            timestamp_list.append(int(timestamp))\n\n            # Create a hash map to query which instance is visible in which camera.\n            # The resulting map will look like: {\n            #    \"instance_1\": {\n            #    \"cam_0\": index_in_cam0,\n            #    \"cam_1\": index_in_cam1,\n            #       ...\n            # },\n            #    \"instance_2\":  {\n            #    ...\n            # }\n            #    ...\n            instance_visible_map = {}\n            for camera_label, per_cam_dict in obb3_dict.items():\n                for i in range(len(per_cam_dict[\"instance_ids\"])):\n                    instance_id = per_cam_dict[\"instance_ids\"][i].item()\n                    if instance_id not in instance_visible_map:\n                        instance_visible_map[instance_id] = {}\n                    instance_visible_map[instance_id][camera_label] = i\n\n            efm_obb_tw_list = []\n            # Loop over all instances from all cameras\n            for instance_id, instance_mapping_info in 
instance_visible_map.items():\n                # Create a ObbTW for this instance\n                # get obb3 info from any visible camera\n                cam_label_0, cam_index_0 = next(iter(instance_mapping_info.items()))\n                atek_single_bb3_dict = obb3_dict[cam_label_0]\n                bb3_dim = atek_single_bb3_dict[\"object_dimensions\"][\n                    cam_index_0\n                ]  # tensor [3]\n                object_half_sizes = bb3_dim / 2.0\n                bb3_object = torch.tensor(\n                    [\n                        -object_half_sizes[0],\n                        object_half_sizes[0],\n                        -object_half_sizes[1],\n                        object_half_sizes[1],\n                        -object_half_sizes[2],\n                        object_half_sizes[2],\n                    ],\n                    dtype=torch.float32,\n                )\n                T_world_object = PoseTW.from_matrix3x4(\n                    atek_single_bb3_dict[\"ts_world_object\"][cam_index_0]\n                )\n                inst_id = atek_single_bb3_dict[\"instance_ids\"][cam_index_0]\n\n                # perform taxonomy remapping if needed, but skip \"other\"\n                sem_id = atek_single_bb3_dict[\"category_ids\"][cam_index_0].item()\n                category_name = atek_single_bb3_dict[\"category_names\"][cam_index_0]\n                if category_name == \"other\":\n                    continue\n                if self.atek_to_efm_category_mapping is not None:\n                    category_name, sem_id = self.atek_to_efm_category_mapping[\n                        category_name\n                    ]\n\n                # Also keep track of a sem_id_to_name mapping\n                if sem_id not in semantic_id_to_name:\n                    semantic_id_to_name[sem_id] = category_name\n\n                bb2_rgb = -1 * torch.ones(4)\n                bb2_slaml = -1 * torch.ones(4)\n                bb2_slamr = -1 * 
torch.ones(4)\n                # Commenting off because obb2 are not needed\n                \"\"\"\n                if \"camera-rgb\" in instance_mapping_info:\n                    cam_label = \"camera-rgb\"\n                    cam_index = instance_mapping_info[cam_label]\n                    bb2_rgb = atek_gt_dict[\"obb2\"][cam_label][\"bbox_ranges\"][cam_index]\n                \n                if \"camera-slam-left\" in instance_mapping_info:\n                    cam_label = \"camera-slam-left\"\n                    cam_index = instance_mapping_info[cam_label]\n                    bb2_slaml = atek_gt_dict[\"obb2\"][cam_label][\"bbox_ranges\"][cam_index]\n                \n                if \"camera-slam-right\" in instance_mapping_info:\n                    cam_label = \"camera-slam-right\"\n                    cam_index = instance_mapping_info[cam_label]\n                    bb2_slamr = atek_gt_dict[\"obb2\"][cam_label][\"bbox_ranges\"][cam_index]\n                \"\"\"\n\n                # Fill in padded obbs in EFM format\n                efm_obb_tw_list.append(\n                    ObbTW.from_lmc(\n                        bb3_object=bb3_object,\n                        bb2_rgb=bb2_rgb,\n                        bb2_slaml=bb2_slaml,\n                        bb2_slamr=bb2_slamr,\n                        T_world_object=T_world_object,\n                        sem_id=torch.tensor([sem_id], dtype=torch.int64),\n                        inst_id=torch.tensor([inst_id], dtype=torch.int64),\n                    )\n                )\n            # end for instance_id\n\n            if len(efm_obb_tw_list) == 0:\n                efm_obb_tw = ObbTW()\n            else:\n                efm_obb_tw = ObbTW(smart_stack(efm_obb_tw_list, dim=0))\n            efm_obb_tw = efm_obb_tw.add_padding(max_elts=128)\n            efm_obb_all_timestamps.append(efm_obb_tw)\n\n        efm_sub_dict[\"obbs/padded_snippet\"] = ObbTW(\n            smart_stack(efm_obb_all_timestamps, 
dim=0)\n        )\n        efm_sub_dict[\"obbs/time_ns\"] = torch.tensor(timestamp_list, dtype=torch.int64)\n        efm_sub_dict[\"obbs/sem_id_to_name\"] = semantic_id_to_name\n        return efm_sub_dict\n\n    def _pad_semidense_data(self, sample_dict: Dict) -> Dict:\n        \"\"\"\n        A helper function to pad semidense data from List[Tensor, (K, 3 or 1)] to fixed shape of [numFrames, num_semidense_points, 3 or 1]\n        \"\"\"\n        result_dict = {}\n\n        fields_to_pad = [\"points/p3s_world\", \"points/dist_std\", \"points/inv_dist_std\"]\n        for field in fields_to_pad:\n            tensor_list = sample_dict[field]\n            for i in range(len(tensor_list)):\n                # First, pad each tensor in the list to fixed num points\n                tensor_list[i] = fill_or_trim_tensor(\n                    tensor=tensor_list[i],\n                    dim_size=self.fixed_semidense_num_points,\n                    dim=0,\n                    fill_value=float(\"nan\"),\n                )\n\n            # then stack\n            stacked_tensor = torch.stack(tensor_list, dim=0)\n\n            # then pad over frames\n            result_dict[field] = fill_or_trim_tensor(\n                tensor=stacked_tensor, dim_size=self.fixed_num_frames, dim=0\n            )\n\n        return result_dict\n\n    def _pad_over_frames(self, sample_dict: Dict, fields_to_pad: List[str]) -> Dict:\n        \"\"\"\n        A helper function to pad data over frames, by repeating the last element over frames.\n        \"\"\"\n        result_dict = {}\n        for field in fields_to_pad:\n            result_dict[field] = fill_or_trim_tensor(\n                tensor=sample_dict[field],\n                dim_size=self.fixed_num_frames,\n                dim=0,\n            )\n        return result_dict\n\n    def _split_pose_over_snippet(self, sample_dict: Dict) -> Dict:\n        \"\"\"\n        A helper function to split T_world_rig into T_world_snippet and 
T_snippet_rig.\n        In the meantime, Align gravity to [0, 0, -9.81]\n        \"\"\"\n        result_dict = {}\n\n        # check if world coordinates needs to be re-aligned\n        maybe_T_newWorld_oldWorld = self._get_pose_to_align_gravity(sample_dict)\n        # maybe_T_newWorld_oldWorld = None\n\n        Ts_world_rig = PoseTW.from_matrix3x4(sample_dict[\"pose/t_world_rig\"])\n        if maybe_T_newWorld_oldWorld:\n            Ts_world_rig = maybe_T_newWorld_oldWorld @ Ts_world_rig\n\n        result_dict[\"pose/t_world_rig\"] = Ts_world_rig\n\n        T_world_snippet = Ts_world_rig.clone()[0]\n        T_world_snippet = T_world_snippet.unsqueeze(0)\n        result_dict[\"snippet/t_world_snippet\"] = T_world_snippet.clone()\n        result_dict[\"pose/t_snippet_rig\"] = Ts_world_rig[0].inverse() @ Ts_world_rig\n\n        for camera_label in EfmModelAdaptor.EFM_CAM_LABELS:\n            result_dict[f\"{camera_label}/t_snippet_rig\"] = result_dict[\n                \"pose/t_snippet_rig\"\n            ].clone()\n\n        # Transform obbs poses, from old_world -> new_world -> snippet\n        if ARIA_OBB_PADDED in sample_dict:\n            if maybe_T_newWorld_oldWorld:\n                T_snippet_world = T_world_snippet.inverse() @ maybe_T_newWorld_oldWorld\n            else:\n                T_snippet_world = T_world_snippet.inverse()\n\n            result_dict[ARIA_OBB_PADDED] = transform_obbs(\n                sample_dict[ARIA_OBB_PADDED], T_snippet_world\n            )\n\n        # Also transform semidense points\n        if maybe_T_newWorld_oldWorld:\n            result_dict[\"points/p3s_world\"] = (\n                maybe_T_newWorld_oldWorld * sample_dict[\"points/p3s_world\"]\n            )\n\n        return result_dict\n\n    def _split_timestamps_over_snippet(self, sample_dict: Dict) -> Dict:\n        \"\"\"\n        A helper function to split capture_timestamps_ns into snippet/time_ns and */snippet_time_s\n        \"\"\"\n        
dict_keys_to_split_timestamps = [\n            \"pose/\",\n            \"points/\",\n        ] + [f\"{label}/img/\" for label in EfmModelAdaptor.EFM_CAM_LABELS]\n\n        # Also split obbs timestamps, if gt exists\n        if self.gt_exists_flag:\n            dict_keys_to_split_timestamps += [\"obbs/\"]\n\n        result_dict = {}\n\n        result_dict[\"snippet/time_ns\"] = sample_dict[\"rgb/img/time_ns\"][0].unsqueeze(0)\n        for key in dict_keys_to_split_timestamps:\n            result_dict[key + \"snippet_time_s\"] = (\n                sample_dict[key + \"time_ns\"] - result_dict[\"snippet/time_ns\"]\n            ) / torch.tensor(1e9, dtype=torch.float32)\n\n        return result_dict\n\n    def atek_to_efm(self, data, train=False):\n        \"\"\"\n        A helper data transform function to convert an ATEK webdataset data sample built by EfmSampleBuilder to EFM unbatched\n        samples. Yields one unbatched sample at a time, so that the collation and batching mechanism in\n        the webdataset can be used properly.\n        \"\"\"\n        for atek_wds_sample in data:\n            efm_sample = atek_wds_sample\n\n            # Check if GT exists in the sample. 
If not, all obb related operations will be skipped\n            self.gt_exists_flag = (\n                \"gt_data\" in atek_wds_sample and len(atek_wds_sample[\"gt_data\"]) > 0\n            )\n\n            # Fill frequency data from conf\n            efm_sample = self._fill_dict_with_freq(efm_sample)\n\n            # Pad semidense data, which requires 2-dim padding\n            padded_dict = self._pad_semidense_data(efm_sample)\n            efm_sample.update(padded_dict)\n\n            # Convert ATEK calibration to EFM camera calibration, where calibration params are replicated x `num_frames`,\n            # except gains and exposure_s, which are per-frame.\n            for cam_label in EfmModelAdaptor.EFM_CAM_LABELS:\n                efm_sample[f\"{cam_label}/calib\"] = self._convert_to_batched_camera_tw(\n                    efm_sample, cam_label\n                )\n\n            # Convert ATEK GT to EFM GT\n            if self.gt_exists_flag:\n                result_dict = self._update_efm_obb_gt(atek_wds_sample[\"gt_data\"])\n                efm_sample.update(result_dict)\n\n            # split T_world_rig into T_world_snippet and T_snippet_rig\n            result_dict = self._split_pose_over_snippet(efm_sample)\n            efm_sample.update(result_dict)\n\n            # split capture_timestamps_ns into snippet/time_ns and */snippet_time_s\n            result_dict = self._split_timestamps_over_snippet(efm_sample)\n            efm_sample.update(result_dict)\n\n            # Pad some data over frames by repeating last element\n            fields_to_pad = []\n            fields_to_skip_padding = [\"snippet/t_world_snippet\"]\n            for key, value in efm_sample.items():\n                if key in fields_to_skip_padding:\n                    continue\n                if isinstance(value, torch.Tensor) or isinstance(value, TensorWrapper):\n                    if value.shape[0] < self.fixed_num_frames:\n                        # pad timestamp tensors, but not 
other 1-dim tensors\n                        if (\n                            key.endswith(\"time_ns\")\n                            or key.endswith(\"time_s\")\n                            or value.ndim > 1\n                        ):\n                            fields_to_pad.append(key)\n            result_dict = self._pad_over_frames(efm_sample, fields_to_pad=fields_to_pad)\n            efm_sample.update(result_dict)\n\n            # Duplicate `camera/img/time` to `camera/calib/time`\n            for camera_name in EfmModelAdaptor.EFM_CAM_LABELS:\n                efm_sample[f\"{camera_name}/calib/time_ns\"] = efm_sample[\n                    f\"{camera_name}/img/time_ns\"\n                ]\n                efm_sample[f\"{camera_name}/calib/snippet_time_s\"] = efm_sample[\n                    f\"{camera_name}/img/snippet_time_s\"\n                ]\n\n            # Convert data types from int to float32\n            fields_to_conv2float32 = [\n                f\"{label}/img\" for label in EfmModelAdaptor.EFM_CAM_LABELS\n            ] + [\n                f\"{label}/frame_id_in_sequence\"\n                for label in EfmModelAdaptor.EFM_CAM_LABELS\n            ]\n            for field in fields_to_conv2float32:\n                efm_sample[field] = efm_sample[field].to(torch.float32)\n                if field.endswith(\"img\"):\n                    # normalize\n                    efm_sample[field] = efm_sample[field] / 255.0\n                if field == \"rgb/img\":\n                    # swap channels from [RGB] -> [BGR]\n                    # efm_sample[field] = efm_sample[field][:, [2, 1, 0], :, :]\n                    pass\n\n            # Run local cosy to shift the origin\n            # For testing only: patch snippet lengths\n            efm_sample[ARIA_SNIPPET_LENGTH_S] = torch.tensor([2.0], dtype=torch.float32)\n            result = run_local_cosy(batch=efm_sample, origin_ratio=0.5)\n            efm_sample.update(result)\n\n            # delete useless 
data\n            if train:\n                # keep only tensors\n                remove_keys = []\n                for key in efm_sample:\n                    if not isinstance(efm_sample[key], (torch.Tensor, TensorWrapper)):\n                        remove_keys.append(key)\n                for k in remove_keys:\n                    efm_sample.pop(k)\n\n            yield efm_sample\n\n\ndef load_atek_wds_dataset_as_efm(\n    urls: List,\n    freq=10,\n    snippet_length_s=2.0,\n    semidense_points_pad_to_num=50000,\n    atek_to_efm_taxonomy_mapping_file: Optional[str] = None,\n    batch_size: Optional[int] = None,\n    collation_fn: Optional[Callable] = None,\n):\n    efm_model_adaptor = EfmModelAdaptor(\n        freq=freq,\n        snippet_length_s=snippet_length_s,\n        semidense_points_pad_to_num=semidense_points_pad_to_num,\n        atek_to_efm_taxonomy_mapping_file=atek_to_efm_taxonomy_mapping_file,\n    )\n\n    return load_atek_wds_dataset(\n        urls,\n        dict_key_mapping=EfmModelAdaptor.get_dict_key_mapping_all(),\n        data_transform_fn=pipelinefilter(efm_model_adaptor.atek_to_efm)(\n            train=collation_fn is not None\n        ),\n        batch_size=batch_size,\n        collation_fn=collation_fn,\n    )\n\n\ndef load_atek_wds_dataset_as_efm_train(\n    urls: List,\n    freq=10,\n    snippet_length_s=2.0,\n    semidense_points_pad_to_num=50000,\n    atek_to_efm_taxonomy_mapping_file: Optional[str] = None,\n    batch_size: Optional[int] = None,\n    collation_fn: Optional[Callable] = None,\n):\n    efm_model_adaptor = EfmModelAdaptor(\n        freq=freq,\n        snippet_length_s=snippet_length_s,\n        semidense_points_pad_to_num=semidense_points_pad_to_num,\n        atek_to_efm_taxonomy_mapping_file=atek_to_efm_taxonomy_mapping_file,\n    )\n\n    wds_dataset = (\n        wds.WebDataset(urls, nodesplitter=None, resampled=True, repeat=True)\n        .decode(wds.imagehandler(\"torchrgb8\"))\n        .map(process_wds_sample)\n    
)\n    wds_dataset = wds_dataset.map(\n        partial(\n            select_and_remap_dict_keys,\n            key_mapping=EfmModelAdaptor.get_dict_key_mapping_all(),\n        )\n    )\n    wds_dataset = wds_dataset.compose(\n        pipelinefilter(efm_model_adaptor.atek_to_efm)(train=collation_fn is not None)\n    )\n    wds_dataset = wds_dataset.batched(batch_size, collation_fn=collation_fn)\n\n    return wds_dataset\n"
  },
  {
    "path": "efm3d/dataset/vrs_dataset.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport math\nimport os\nimport random\nfrom typing import Callable, List, Optional, Union\n\nimport numpy as np\nimport pyvrs\nimport torch\nimport torch.nn.functional as F\nfrom efm3d.aria import CameraTW, ObbTW, PoseTW, smart_stack, transform_obbs\nfrom efm3d.aria.aria_constants import (\n    ARIA_CALIB,\n    ARIA_CALIB_SNIPPET_TIME_S,\n    ARIA_CALIB_TIME_NS,\n    ARIA_CAM_INFO,\n    ARIA_FRAME_ID,\n    ARIA_IMG,\n    ARIA_IMG_SNIPPET_TIME_S,\n    ARIA_IMG_T_SNIPPET_RIG,\n    ARIA_IMG_TIME_NS,\n    ARIA_OBB_BB2,\n    ARIA_OBB_PADDED,\n    ARIA_OBB_SEM_ID_TO_NAME,\n    ARIA_OBB_SNIPPET_TIME_S,\n    ARIA_OBB_TIME_NS,\n    ARIA_POINTS_SNIPPET_TIME_S,\n    ARIA_POINTS_TIME_NS,\n    ARIA_POINTS_VOL_MAX,\n    ARIA_POINTS_VOL_MIN,\n    ARIA_POINTS_WORLD,\n    ARIA_POSE_SNIPPET_TIME_S,\n    ARIA_POSE_T_SNIPPET_RIG,\n    ARIA_POSE_T_WORLD_RIG,\n    ARIA_POSE_TIME_NS,\n    ARIA_SNIPPET_T_WORLD_SNIPPET,\n)\nfrom efm3d.utils.file_utils import (\n    exists_nonzero_path,\n    get_timestamp_list_ns,\n    load_factory_calib,\n    load_global_points_csv,\n    load_obbs_gt,\n    load_semidense_observations,\n    load_trajectory,\n    load_trajectory_adt,\n    load_trajectory_aeo,\n    read_image_snippet_from_vrs,\n    sample_from_range,\n    sample_times,\n)\nfrom efm3d.utils.obb_io import get_instance_id_in_frameset, next_obb_observations\nfrom 
efm3d.utils.rescale import rescale_obb_tw\nfrom torch.utils.data import Dataset\n\n\n# gravity direction in ADT conventions\nGRAVITY_DIRECTION_ADT = np.array([0.0, -1.0, 0.0], np.float32)\n\n\ndef is_adt(vrs_path):\n    # get folder name\n    if vrs_path.endswith(\".vrs\"):\n        vrs_path = os.path.split(vrs_path)[0]\n    if os.path.exists(os.path.join(vrs_path, \"aria_trajectory.csv\")):\n        return True\n    folder_name = os.path.basename(vrs_path)\n    return \"optitrack_release_work_seq\" in folder_name\n\n\ndef is_aeo(vrs_path):\n    return \"aeo_\" in vrs_path\n\n\ndef get_transform_to_vio_gravity_convention(gravity_direction: np.array):\n    \"\"\"\n    Get transformation to map gravity_direction to (0,0,-1) as per our (and\n    VIO/Temple) convention.\n    \"\"\"\n    # gravity_direction = (d1, d2, d3) (0,0,-1)^T; d1, d2, d3 column vectors of rotation matrix R_gravity_vio\n    # -d3 = gravity_direction\n    d3 = -gravity_direction.copy()\n    # now construct an orthonormal basis for the rotation matrix\n    # d1 is a vector thats orthogonal to gravity_direction by construction\n    d1 = np.array(\n        [\n            gravity_direction[2] - gravity_direction[1],\n            gravity_direction[0],\n            -gravity_direction[0],\n        ]\n    )\n    # get d2 via orthogonal direction vector to d3 and d1\n    d2 = np.cross(d3, d1)\n    # get rotation matrix\n    R_gravity_vio = np.concatenate(\n        [d1[:, np.newaxis], d2[:, np.newaxis], d3[:, np.newaxis]], 1\n    )\n    assert (np.linalg.det(R_gravity_vio) - 1.0) < 1e-5\n    assert (((R_gravity_vio @ R_gravity_vio.transpose()) - np.eye(3)) < 1e-5).all()\n    R_gravity_vio = torch.from_numpy(R_gravity_vio)\n    # normalize to unit length\n    R_gravity_vio = F.normalize(R_gravity_vio, p=2, dim=-2)\n    R_vio_gravity = R_gravity_vio.transpose(1, 0)\n    T_vio_gravity = PoseTW.from_Rt(R_vio_gravity, torch.zeros(3))\n    return T_vio_gravity\n\n\ndef compute_time_intersection(time_lists):\n    
min_time = -math.inf\n    max_time = math.inf\n    for ts in time_lists:\n        ts = np.array(ts)\n        min_time = max(min_time, ts.min())\n        max_time = min(max_time, ts.max())\n\n    # add an offset to the timestamp\n    safety_margin = 3_000_000  # 3ms\n    min_time = min_time - safety_margin\n    max_time = max_time - safety_margin\n\n    return min_time, max_time\n\n\ndef preprocess_inference(batch):\n    # tensor wrapper\n    for k in batch:\n        if not isinstance(batch[k], (torch.Tensor, PoseTW, CameraTW, ObbTW)):\n            continue\n\n        if k in [\n            ARIA_SNIPPET_T_WORLD_SNIPPET,\n            ARIA_POSE_T_WORLD_RIG,\n            ARIA_POSE_T_SNIPPET_RIG,\n        ] + ARIA_IMG_T_SNIPPET_RIG and not isinstance(batch[k], PoseTW):\n            batch[k] = PoseTW(batch[k])\n        elif k in ARIA_CALIB and not isinstance(batch[k], CameraTW):\n            batch[k] = CameraTW(batch[k])\n        elif k == ARIA_OBB_PADDED and not isinstance(batch[k], ObbTW):\n            batch[k] = ObbTW(batch[k])\n\n    return batch\n\n\ndef preprocess(\n    batch,\n    device,\n    subsample: int = 10,\n    aug_funcs: Optional[Union[Callable, List[Callable]]] = None,\n):\n    # tensor wrapper\n    for k in batch:\n        if not isinstance(batch[k], (torch.Tensor, PoseTW, CameraTW, ObbTW)):\n            continue\n\n        if k in [\n            ARIA_SNIPPET_T_WORLD_SNIPPET,\n            ARIA_POSE_T_WORLD_RIG,\n            ARIA_POSE_T_SNIPPET_RIG,\n        ] + ARIA_IMG_T_SNIPPET_RIG and not isinstance(batch[k], PoseTW):\n            batch[k] = PoseTW(batch[k])\n        elif k in ARIA_CALIB and not isinstance(batch[k], CameraTW):\n            batch[k] = CameraTW(batch[k])\n        elif k == ARIA_OBB_PADDED and not isinstance(batch[k], ObbTW):\n            batch[k] = ObbTW(batch[k])\n\n    # time crop\n    T = batch[ARIA_IMG[0]].shape[1]\n    if subsample != T:\n        s = random.randint(0, T - subsample - 1)\n        for k in batch:\n            if (\n 
               isinstance(batch[k], (torch.Tensor, PoseTW, CameraTW, ObbTW))\n                and batch[k].shape[1] == T\n            ):\n                batch[k] = batch[k][:, s : s + subsample, ...]\n\n    # move to GPU\n    for k in batch:\n        if isinstance(batch[k], (torch.Tensor, PoseTW, CameraTW, ObbTW)):\n            batch[k] = batch[k].to(device)\n\n    # data augmentations\n    if aug_funcs is not None:\n        if isinstance(aug_funcs, Callable):\n            aug_funcs = [aug_funcs]\n        for aug in aug_funcs:\n            batch = aug(batch)\n\n    return batch\n\n\ndef tensor_unify(tensor, dim_size: int, dim: int = 0):\n    \"\"\"Fill or trim a torch or numpy tensor to the given `dim_size`, along the given dim\n\n    Inputs:\n        tensor (torch or np.array): input tensor\n        dim_size (int): the size to fill or trim to (e.g. predefined batch size)\n        dim (int): the dimension to fill or trim\n\n    Returns:\n        tensor2 (a torch or np.array): output tensor with the dim size = `dim_size`.\n    \"\"\"\n    assert tensor.shape[dim] > 0, \"Input tensor must have at least 1 element\"\n\n    if isinstance(tensor, list):\n        tensor = np.array(tensor)\n    if isinstance(tensor, np.ndarray):\n        np_tensor = tensor\n        tensor_bs = np_tensor.shape[dim]\n        if tensor_bs > dim_size:\n            tensor2 = np.take(np_tensor, indices=np.arange(dim_size), axis=dim)\n        elif tensor_bs < dim_size:\n            last = np.take(np_tensor, tensor_bs - 1, dim)\n            fill = np.expand_dims(last, axis=dim)  # fill with last element\n            fill = np.repeat(fill, dim_size - tensor_bs, axis=dim)\n            tensor2 = np.concatenate([np_tensor, fill], axis=dim)\n        else:\n            tensor2 = tensor\n    else:\n        tensor_bs = tensor.shape[dim]\n        if tensor_bs > dim_size:\n            indices = torch.arange(dim_size)\n            for i in range(tensor.ndim):\n                if i != dim:\n                  
  indices = indices.unsqueeze(i)\n            tensor2 = torch.take_along_dim(tensor, indices, dim)\n        elif tensor_bs < dim_size:\n            shape = [1 for _ in range(tensor.ndim)]\n            indices = torch.ones(shape).long()\n            indices[0] = tensor_bs - 1\n            last = torch.take_along_dim(tensor, indices, dim)\n            fill_shape = shape\n            fill_shape[dim] = dim_size - tensor_bs\n            fill = last.repeat(fill_shape)\n            tensor2 = torch.cat([tensor, fill], dim=dim)\n        else:\n            tensor2 = tensor\n    return tensor2\n\n\ndef run_sensor_poses(batch, num_notified=-1, max_notified=10):\n    if (\n        ARIA_POSE_T_SNIPPET_RIG in batch.keys()\n        and ARIA_POSE_SNIPPET_TIME_S in batch.keys()\n    ):\n        new_batch = {}\n        Ts_snippet_rig = batch[ARIA_POSE_T_SNIPPET_RIG]\n        ts = batch[ARIA_POSE_SNIPPET_TIME_S]\n        assert Ts_snippet_rig.dim() in [\n            2,\n            3,\n        ], f\"need to be of shape (B) x T x 12 but are {Ts_snippet_rig.shape}\"\n        for i, img_time_key in enumerate(ARIA_IMG_SNIPPET_TIME_S):\n            if (\n                img_time_key in batch.keys()\n                and ARIA_IMG_T_SNIPPET_RIG[i] not in batch.keys()\n            ):\n                ts_interp = batch[img_time_key]\n                Ts_world_rig_i, good = Ts_snippet_rig.interpolate(ts, ts_interp)\n                new_batch[ARIA_IMG_T_SNIPPET_RIG[i]] = Ts_world_rig_i\n                if not good.all():\n                    counts = good.sum(dim=-1).squeeze()\n                    if num_notified > 0 and num_notified < max_notified:\n                        print(\n                            f\"some interpolated poses were bad (fraction good per batch: {counts / good.shape[-1]}); likely because tried to interpolated past given input timed poses.\"\n                        )\n        return new_batch\n\n\nclass VrsSequenceDataset(Dataset):\n    def __init__(\n        self,\n       
 vrs_path,\n        frame_rate,\n        sdi,\n        snippet_length_s,\n        stride_length_s,\n        max_snippets=9999,\n        skip_snippets=0,\n        preprocess=None,\n    ):\n        self.frame_rate = frame_rate\n        self.vrs_path = vrs_path\n        self.vrs_folder = os.path.split(vrs_path)[0]\n        self.reader = pyvrs.SyncVRSReader(\n            vrs_path, auto_read_configuration_records=True\n        )\n        self.max_snippets = max_snippets\n        self.sdi = sdi\n        self.preprocess = preprocess\n        self.cam_calib = load_factory_calib(self.reader)\n        self.is_adt = is_adt(vrs_path)\n        self.is_aeo = is_aeo(vrs_path)\n        self.max_objects_per_frameset = 128\n\n        fps = self.cam_calib[\"fps\"]\n        self.fps = [fps[\"rgb\"], fps[\"slaml\"], fps[\"slamr\"]]\n\n        ts_lists = []\n        # Add images\n        for idx in range(3):\n            img_ts_list = get_timestamp_list_ns(self.reader, ARIA_CAM_INFO[\"id\"][idx])\n            ts_lists.append(img_ts_list)\n\n        # Add poses\n        timed_Ts_world_rig = self.load_poses(self.vrs_folder, subsample=1)\n        pose_times_ns = list(timed_Ts_world_rig.keys())\n        pose_freq = int(1.0 / (1e-9 * (pose_times_ns[1] - pose_times_ns[0])))\n        pose_subsample = int(pose_freq / frame_rate)\n        pose_times_ns = pose_times_ns[::pose_subsample]\n\n        self.T_world_rig_time_ns = pose_times_ns\n        self.Ts_world_rig = torch.stack(\n            [timed_Ts_world_rig[key] for key in pose_times_ns]\n        )\n        ts_lists.append(pose_times_ns)\n\n        # Add obbs GT if available\n        self.obs = None\n        self.obs = self.load_objects()\n        if self.obs is not None:\n            obb_freq = int(1.0 / (1e-9 * (self.obb_times[1] - self.obb_times[0])))\n            obb_subsample = max(1, int(obb_freq / frame_rate))\n            self.obb_times = self.obb_times[::obb_subsample]\n\n        # Add points\n        
self.load_semidense(self.vrs_folder)\n\n        # intersect all data modalities\n        min_time, max_time = compute_time_intersection(ts_lists)\n\n        play_times_ns = get_timestamp_list_ns(self.reader, ARIA_CAM_INFO[\"id\"][idx])\n        play_times_ns = [\n            ts for ts in play_times_ns if (ts > min_time and ts < max_time)\n        ]\n        play_times_ns = np.unique(play_times_ns).tolist()\n\n        # compute snippets start and end time\n        seq_start_time = play_times_ns[0]\n        seq_end_time = play_times_ns[-1]\n        snip_start = seq_start_time\n        snip_end = snip_start + snippet_length_s * 1e9\n        self.snippet_times = []\n\n        while snip_end < seq_end_time:\n            self.snippet_times.append((snip_start, snip_end))\n            snip_start += stride_length_s * 1e9\n            snip_end = snip_start + snippet_length_s * 1e9\n\n        if skip_snippets > 0:\n            self.snippet_times = self.snippet_times[skip_snippets:]\n\n    def load_objects(self):\n        self.obs = load_obbs_gt(\n            self.vrs_folder,\n            load_2d_bbs=True,\n            filter_outside_2d_bbs=True,\n            rgb_only=False,\n        )\n        if len(self.obs) == 0:\n            return None\n\n        # inverse map from proto to a linear id and filter the interested objects if given.\n        instance2proto = self.obs[\"inst2proto\"]\n        unique_proto_names = np.unique(list(instance2proto.values())).tolist()\n        self.obs[\"proto2id\"] = {name: i for i, name in enumerate(unique_proto_names)}\n\n        if self.is_aeo:\n            aeo_to_efm = (\n                f\"{os.path.dirname(__file__)}/../config/taxonomy/aeo_to_efm.csv\"\n            )\n            self.global_name_to_id = {}\n            with open(aeo_to_efm, \"r\") as f:\n                lines = f.readlines()\n            for li in lines[1:]:\n                ori_name, class_name, class_id = li.strip().split(\",\")\n                
self.global_name_to_id[str(ori_name)] = (str(class_name), int(class_id))\n\n            filtered_proto_names = set(self.global_name_to_id.keys()).intersection(\n                set(unique_proto_names)\n            )\n\n            # remap the proto names and semantic ids given the taxonomy mapping\n            self.obs[\"proto2id\"] = {\n                self.global_name_to_id[name][0]: self.global_name_to_id[name][1]\n                for name in filtered_proto_names\n            }\n            self.obs[\"inst2proto\"] = {\n                inst: self.global_name_to_id[name][0]\n                for inst, name in instance2proto.items()\n                if name in filtered_proto_names\n            }\n        else:\n            # use the class name to id mapping in the sequence\n            self.obs[\"proto2id\"] = {\n                name: i for i, name in enumerate(unique_proto_names)\n            }\n\n        # compute inverse map\n        self.obs[\"id2proto\"] = {id: name for name, id in self.obs[\"proto2id\"].items()}\n\n        timedTs_world_object = self.obs[\"timedTs_world_object\"]\n        static_Ts_world_object = {}\n        assert len(timedTs_world_object) != 0, (\n            \"Warning: no observations found for entire sequence\"\n        )\n        # timedTs_world_object captures static object at the -1 timestamp\n        if -1 in timedTs_world_object.keys():\n            static_Ts_world_object = timedTs_world_object[-1]\n        self.obs[\"static_Ts_world_object\"] = static_Ts_world_object\n        self.obb_times = sorted(set(self.obs[ARIA_OBB_BB2[0]].keys()))\n\n        if self.is_adt:\n            T_vio_gravity = get_transform_to_vio_gravity_convention(\n                GRAVITY_DIRECTION_ADT\n            )\n            for time, idT_wo in self.obs[\"timedTs_world_object\"].items():\n                for inst, T_wo in idT_wo.items():\n                    # we go from gravity world coordinate system to the new one that follows vio conventions\n             
       self.obs[\"timedTs_world_object\"][time][inst] = (\n                        T_vio_gravity @ T_wo.float()\n                    )\n\n        return self.obs\n\n    def load_semidense(self, vrs_path, max_inv_depth_std=0.005, max_depth_std=0.05):\n        possible_global_points_paths = [\n            os.path.join(vrs_path, \"multi_global_points.csv.gz\"),\n            os.path.join(vrs_path, \"multi_global_points.csv\"),\n            os.path.join(vrs_path, \"global_points.csv.gz\"),\n            os.path.join(vrs_path, \"global_points.csv\"),\n            os.path.join(vrs_path, \"semidense_points.csv.gz\"),\n            os.path.join(vrs_path, \"maps/maps_v1/globalcloud_GT.csv\"),  # ASE\n            os.path.join(vrs_path, \"mps/slam/semidense_points.csv.gz\"),  # ADT\n        ]\n        possible_obs_paths = [\n            os.path.join(vrs_path, \"semidense_observations.csv.gz\"),\n            os.path.join(vrs_path, \"semidense_observations.csv\"),\n            os.path.join(vrs_path, \"maps/maps_v1/observations.csv\"),  # ASE\n            os.path.join(vrs_path, \"semidense_points.csv\"),\n            os.path.join(vrs_path, \"mps/slam/semidense_observations.csv.gz\"),  # ADT\n        ]\n        global_points_path = exists_nonzero_path(possible_global_points_paths)\n        self.uid_to_p3, self.uid_to_inv_dist_std, self.uid_to_dist_std = (\n            load_global_points_csv(global_points_path, max_inv_depth_std, max_depth_std)\n        )\n\n        if self.is_adt:\n            T_vio_gravity = get_transform_to_vio_gravity_convention(\n                GRAVITY_DIRECTION_ADT\n            ).double()\n            for uid, p3 in self.uid_to_p3.items():\n                self.uid_to_p3[uid] = (T_vio_gravity * p3).reshape(-1)\n\n        semidense_obs_path = exists_nonzero_path(possible_obs_paths)\n        self.time_to_uids, self.uid_to_times = load_semidense_observations(\n            semidense_obs_path\n        )\n\n        if self.time_to_uids is not None:\n            
self.pts_times_ns = sorted(self.time_to_uids.keys())\n            (\n                self.time_to_pc,\n                self.time_to_dist_std,\n                self.time_to_inv_dist_std,\n                no_points_times,\n            ) = ({}, {}, {}, [])\n            for time in self.pts_times_ns:\n                uids = self.time_to_uids[time]\n                p3s = [self.uid_to_p3[uid] for uid in uids if uid in self.uid_to_p3]\n                if len(p3s) > 0:\n                    # sort by inv dist std to make any cropping later use the best points\n                    inv_dist_std = [\n                        self.uid_to_inv_dist_std[uid]\n                        for uid in uids\n                        if uid in self.uid_to_inv_dist_std\n                    ]\n                    inv_dist_std = np.array(inv_dist_std)\n                    dist_std = [\n                        self.uid_to_dist_std[uid]\n                        for uid in uids\n                        if uid in self.uid_to_dist_std\n                    ]\n                    dist_std = np.array(dist_std)\n                    ids = np.argsort(inv_dist_std)\n                    p3s = [p3s[i] for i in ids]\n                    p3s = torch.stack(p3s)\n                    inv_dist_std = torch.from_numpy(inv_dist_std[ids])\n                    dist_std = torch.from_numpy(dist_std[ids])\n                else:\n                    no_points_times.append(time)\n                    p3s = torch.zeros((0, 3), dtype=torch.float32)\n                    inv_dist_std = torch.zeros((0), dtype=torch.float32)\n                    dist_std = torch.zeros((0), dtype=torch.float32)\n                self.time_to_pc[time] = p3s\n                self.time_to_dist_std[time] = dist_std\n                self.time_to_inv_dist_std[time] = inv_dist_std\n        print(\n            f\"Found {len(self.uid_to_p3)} semidense points; time range {min(self.pts_times_ns) / 1e9}s-{max(self.pts_times_ns) / 1e9}s\"\n        )\n\n        # 
aggregate all the points\n        all_p3s = [self.uid_to_p3[uid] for uid in self.uid_to_p3]\n        all_inv_dist_std = [\n            self.uid_to_inv_dist_std[uid] for uid in self.uid_to_inv_dist_std\n        ]\n        ids = np.argsort(all_inv_dist_std)\n        # ranked by inverse depth std\n        self.all_p3s = torch.stack([all_p3s[i] for i in ids])  # [N, 3]\n        assert self.all_p3s.shape[0] > 0, \"no points loaded\"\n\n        # compute a [q, 1-q] percentile as the global range\n        q = 0.001\n        self.vol_min = torch.quantile(self.all_p3s, q, dim=0)\n        self.vol_max = torch.quantile(self.all_p3s, 1 - q, dim=0)\n        self.vol_min = self.vol_min.detach()\n        self.vol_max = self.vol_max.detach()\n\n    def load_poses(self, vrs_path, subsample):\n        timed_Ts_world_rig = None\n        # ADT sequences\n        timed_Ts_world_rig = load_trajectory_adt(vrs_path, subsample=subsample)\n        if timed_Ts_world_rig is not None:\n            # handle ADT sequence gravity rotation\n            T_vio_gravity = get_transform_to_vio_gravity_convention(\n                GRAVITY_DIRECTION_ADT\n            ).double()\n            for k, T_wr in timed_Ts_world_rig.items():\n                timed_Ts_world_rig[k] = T_vio_gravity @ T_wr\n            return timed_Ts_world_rig\n\n        # AEO sequences\n        timed_Ts_world_rig = load_trajectory_aeo(\n            vrs_path,\n            time_in_secs=False,\n            load_torch=True,\n            subsample=subsample,\n        )\n        if timed_Ts_world_rig is not None:\n            if self.is_adt:\n                T_vio_gravity = get_transform_to_vio_gravity_convention(\n                    GRAVITY_DIRECTION_ADT\n                ).double()\n                for k, T_wr in timed_Ts_world_rig.items():\n                    timed_Ts_world_rig[k] = T_vio_gravity @ T_wr\n            return timed_Ts_world_rig\n\n        # Other sequences\n        timed_Ts_world_rig = load_trajectory(\n            
vrs_path,\n            time_in_secs=False,\n            load_torch=True,\n            subsample=subsample,\n        )\n\n        return timed_Ts_world_rig\n\n    def load_snippet_pose(self, start, end):\n        idx_i, idx_j = sample_times(self.T_world_rig_time_ns, start, end)\n        Ts_wr = self.Ts_world_rig[idx_i:idx_j, :]\n        pose_times_ns = torch.LongTensor(self.T_world_rig_time_ns[idx_i:idx_j])\n\n        T_ws = Ts_wr[0].clone().unsqueeze(0)\n        Ts_sr = T_ws.inverse() @ Ts_wr\n        pose_times_s = (\n            pose_times_ns - torch.tensor(start, dtype=torch.long)\n        ).float() * 1e-9\n        return T_ws, Ts_wr, Ts_sr, pose_times_ns, pose_times_s\n\n    def load_snippet_semidense(self, start, end, max_size=20000):\n        idx_i, idx_j = sample_times(self.pts_times_ns, start, end)\n        points_times_ns = self.pts_times_ns[idx_i:idx_j]\n        points_world = [self.time_to_pc[time] for time in points_times_ns]\n\n        for idx, ps in enumerate(points_world):\n            ps = ps[:max_size, :]\n            pad_num = max_size - ps.shape[0]\n            assert pad_num >= 0, f\"padding must be non-negative, but got {pad_num}\"\n\n            points_world[idx] = F.pad(\n                ps,\n                (0, 0, 0, pad_num),\n                \"constant\",\n                float(\"nan\"),\n            )\n        points_world = torch.stack(points_world)\n        points_times_ns = torch.LongTensor(points_times_ns)\n        points_times_s = (\n            points_times_ns - torch.tensor(start, dtype=torch.long)\n        ).float() * 1e-9\n        return points_world, points_times_ns, points_times_s\n\n    def load_snippet_objects(self, start, end):\n        def get_obbs_for_time(t: int, inst_ids: List):\n            (\n                bb2s_rgb,\n                bb2s_slaml,\n                bb2s_slamr,\n                bb3s,\n                Ts_world_object,\n                sem_ids,\n                inst_ids,\n            ) = 
next_obb_observations(\n                obs=self.obs,\n                time=t,\n                inst_ids=inst_ids,\n                cam_names=[\"rgb\", \"slaml\", \"slamr\"],\n                load_dynamic_objects=True,\n                interpolate_poses=True,\n                dt_threshold_ns=10_000_000,\n            )\n            obbs = ObbTW.from_lmc(\n                bb3s,\n                bb2s_rgb,\n                bb2s_slaml,\n                bb2s_slamr,\n                Ts_world_object,\n                sem_ids,\n                inst_ids,\n            )\n            # scale 2d bbs to image size\n            obbs = rescale_obb_tw(\n                obbs,\n                cam_size_before_rgb=[1408, 1408, 3],  # Aria rgb size\n                cam_size_before_slam=[480, 640, 1],  # Aria slam size\n                down_scale=self.sdi,\n                wh_multiple_of=16,\n            )\n            # center object bounding box in the object coordinate system\n            # T_world_object so that origin is the center of the object\n            obbs = obbs.center()\n\n            # get object sem_id to name mapping\n            sem_id_to_name = {\n                self.obs[\"proto2id\"][self.obs[\"inst2proto\"][iid.item()]]: self.obs[\n                    \"inst2proto\"\n                ][iid.item()]\n                for iid in inst_ids\n            }\n            return obbs, sem_id_to_name\n\n        obbs_snippet, sem_id_to_name, snippet_times = [], {}, []\n        probably_snippet_times = [t for t in self.obb_times if start < t and t <= end]\n        for t in probably_snippet_times:\n            # keep only the instances that are visible, i.e. those that have 2d bb annotations\n            inst_ids = get_instance_id_in_frameset(\n                self.obs,\n                t,\n                load_dynamic_objects=True,\n                interpolate_poses=True,\n                dt_threshold_ns=10_000_000,\n            )\n            snippet_times.append(t)\n   
         if len(inst_ids) == 0:\n                obbs_snippet.append(\n                    ObbTW(-1 * torch.ones(self.max_objects_per_frameset, 34))\n                )\n                continue\n            obbs, sem2names = get_obbs_for_time(t, inst_ids)\n            obbs_snippet.append(obbs.add_padding(self.max_objects_per_frameset))\n            sem_id_to_name.update(sem2names)\n\n        if len(obbs_snippet) > 0:\n            obbs_padded = ObbTW(smart_stack(obbs_snippet))\n        else:\n            obbs_padded = ObbTW(-1 * torch.ones((0, self.max_objects_per_frameset, 34)))\n            print(f\"could not find obbs for snippet times {snippet_times}\")\n        obbs_time_ns = torch.LongTensor(snippet_times)\n        obbs_time_s = (\n            obbs_time_ns - torch.tensor(start, dtype=torch.long)\n        ).float() * 1e-9\n\n        # subsample\n        obj_idxs = sample_from_range(\n            0, len(obbs_padded), sample_rate=1, add_random=False\n        )\n        obbs_padded = obbs_padded[obj_idxs].contiguous()\n        obbs_time_ns = obbs_time_ns[obj_idxs].contiguous()\n        obbs_time_s = obbs_time_s[obj_idxs].contiguous()\n\n        return obbs_padded, sem_id_to_name, obbs_time_ns, obbs_time_s\n\n    def __len__(self):\n        return min(len(self.snippet_times), self.max_snippets)\n\n    def __getitem__(self, index):\n        if index >= self.max_snippets:\n            raise StopIteration\n\n        sample = {}\n        start, end = self.snippet_times[index]\n\n        rgb_calib = {key: self.cam_calib[key][\"rgb\"] for key in self.cam_calib}\n\n        # img\n        for i in range(3):\n            subsample = int(self.fps[i] / self.frame_rate)\n            imgs, img_times_ns, cam_tws, frame_ids = read_image_snippet_from_vrs(\n                self.reader,\n                ARIA_CAM_INFO[\"id\"][i],\n                start,\n                end,\n                rgb_calib,\n                subsample=subsample,\n                
scale_down_images=self.sdi,\n            )\n            img_times_s = (\n                img_times_ns - torch.tensor(start, dtype=torch.long).float()\n            ) * 1e-9\n\n            sample.update(\n                {\n                    ARIA_IMG[i]: imgs,\n                    ARIA_IMG_TIME_NS[i]: img_times_ns,\n                    ARIA_IMG_SNIPPET_TIME_S[i]: img_times_s,\n                    ARIA_FRAME_ID[i]: frame_ids,\n                    ARIA_CALIB[i]: cam_tws,\n                    ARIA_CALIB_TIME_NS[i]: img_times_ns,\n                    ARIA_CALIB_SNIPPET_TIME_S[i]: img_times_s,\n                }\n            )\n\n        # pose\n        T_ws, Ts_wr, Ts_sr, pose_times_ns, pose_times_s = self.load_snippet_pose(\n            start, end\n        )\n        sample.update(\n            {\n                ARIA_SNIPPET_T_WORLD_SNIPPET: T_ws,\n                ARIA_POSE_T_WORLD_RIG: Ts_wr,\n                ARIA_POSE_T_SNIPPET_RIG: Ts_sr,\n                ARIA_POSE_TIME_NS: pose_times_ns,\n                ARIA_POSE_SNIPPET_TIME_S: pose_times_s,\n            }\n        )\n\n        # interpolate slam poses to get img poses\n        sample.update(run_sensor_poses(sample))\n\n        # semidense points\n        pts_world, pts_times_ns, pts_times_s = self.load_snippet_semidense(start, end)\n        sample.update(\n            {\n                ARIA_POINTS_WORLD: pts_world,\n                ARIA_POINTS_TIME_NS: pts_times_ns,\n                ARIA_POINTS_SNIPPET_TIME_S: pts_times_s,\n                ARIA_POINTS_VOL_MIN: self.vol_min,\n                ARIA_POINTS_VOL_MAX: self.vol_max,\n            }\n        )\n\n        # objects\n        if self.obs:\n            obbs_padded, sem_id_to_name, obbs_time_ns, obbs_time_s = (\n                self.load_snippet_objects(start, end)\n            )\n            # transform obbs into snippet coordinate system\n            obbs_padded = transform_obbs(obbs_padded, T_ws.float().inverse())\n            sample.update(\n            
    {\n                    ARIA_OBB_PADDED: obbs_padded,\n                    ARIA_OBB_SEM_ID_TO_NAME: sem_id_to_name,\n                    ARIA_OBB_TIME_NS: obbs_time_ns,\n                    ARIA_OBB_SNIPPET_TIME_S: obbs_time_s,\n                }\n            )\n\n        for key in sample:\n            if isinstance(sample[key], (PoseTW, CameraTW, ObbTW)):\n                sample[key] = sample[key].tensor()\n\n            if isinstance(sample[key], torch.Tensor):\n                sample[key] = sample[key].float()\n\n            if key not in [\n                ARIA_SNIPPET_T_WORLD_SNIPPET,\n                ARIA_POINTS_VOL_MIN,\n                ARIA_POINTS_VOL_MAX,\n                ARIA_OBB_SEM_ID_TO_NAME,\n            ]:\n                if isinstance(sample[key], torch.Tensor) and sample[key].shape[0] == 0:\n                    continue\n                sample[key] = tensor_unify(sample[key], self.frame_rate)\n\n        if self.preprocess:\n            sample = self.preprocess(sample)\n\n        return sample\n"
  },
  {
    "path": "efm3d/dataset/wds_dataset.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport glob\nimport tarfile\n\nimport numpy as np\nimport torch\nimport webdataset as wds\nfrom efm3d.aria import CameraTW, DEFAULT_CAM_DATA_SIZE, ObbTW, PoseTW, TensorWrapper\nfrom efm3d.aria.aria_constants import (\n    ARIA_CALIB,\n    ARIA_IMG,\n    ARIA_IMG_T_SNIPPET_RIG,\n    ARIA_OBB_PADDED,\n    ARIA_POINTS_VOL_MAX,\n    ARIA_POINTS_VOL_MIN,\n    ARIA_POINTS_WORLD,\n    ARIA_POSE_T_SNIPPET_RIG,\n    ARIA_POSE_T_WORLD_RIG,\n    ARIA_SNIPPET_T_WORLD_SNIPPET,\n)\n\n\ndef convert_to_aria_multimodal_dataset(sample):\n    \"\"\"\n    Convert a data sample from Aria multimodal data in webdataset format\n    to training/validation sample format.\n    \"\"\"\n\n    def to_mm_key(k, end_separator=\".\"):\n        k = k[: k.rfind(end_separator)]  # remove suffix\n        # move keys back to the \"/\" convention from the \"-\" separator needed for webdataset paths.\n        k = k.replace(\"-\", \"/\")\n        return k\n\n    image_snippets = {}\n    mm_sample = {}\n    for k, v in sample.items():\n        # Compose images to image snippet\n        if k.endswith(\".jpg\"):\n            img_key = to_mm_key(k, \"_\")\n            if img_key not in image_snippets:\n                image_snippets[img_key] = [v]\n            else:\n                image_snippets[img_key].append(v)\n\n        # np.float32 tensors\n        elif k.endswith(\".pyd\"):\n    
        k = to_mm_key(k, \".\")\n            if k in [\n                ARIA_POSE_T_SNIPPET_RIG,\n                ARIA_POSE_T_WORLD_RIG,\n                ARIA_SNIPPET_T_WORLD_SNIPPET,\n                ARIA_IMG_T_SNIPPET_RIG[0],\n                ARIA_IMG_T_SNIPPET_RIG[1],\n                ARIA_IMG_T_SNIPPET_RIG[2],\n            ]:\n                mm_sample[k] = PoseTW.from_matrix3x4(v.float())\n            elif k in ARIA_CALIB:\n                assert v.shape[-1] == DEFAULT_CAM_DATA_SIZE, (\n                    \"only allow Fisheye624 cameras\"\n                )\n                mm_sample[k] = CameraTW(v)\n            elif k == ARIA_OBB_PADDED:\n                mm_sample[ARIA_OBB_PADDED] = ObbTW(v)\n            elif k == ARIA_POINTS_WORLD:\n                # load as float32\n                mm_sample[ARIA_POINTS_WORLD] = v.float()\n            elif isinstance(v, dict):\n                # store dicts as (key, datum) lists in order to be able to collate them\n                mm_sample[k] = [(kv, vv) for kv, vv in v.items()]\n            else:\n                mm_sample[k] = v\n\n        # str\n        elif k.endswith(\".txt\"):\n            k = to_mm_key(k, \".\")\n            mm_sample[k] = v\n\n        # int\n        elif k.endswith(\".cls\"):\n            k = to_mm_key(k, \".\")\n            mm_sample[k] = v\n\n        else:\n            pass  # silently ignore data field not used for training\n\n    # images to image snippets\n    for k, v in image_snippets.items():\n        mm_sample[k] = np.transpose(np.stack(v, axis=0), (0, 3, 1, 2))\n        # convert to one-channel for SLAM images\n        if k == ARIA_IMG[1] or k == ARIA_IMG[2]:\n            mm_sample[k] = mm_sample[k][:, :1, :, :]\n        mm_sample[k] = torch.from_numpy(mm_sample[k])\n\n    for key in mm_sample:\n        if \"time_s\" in key:\n            if isinstance(mm_sample[key], np.ndarray):\n                mm_sample[key] = torch.from_numpy(mm_sample[key])\n            assert mm_sample[key].dtype 
== torch.float32\n        if \"time_ns\" in key:\n            if isinstance(mm_sample[key], np.ndarray):\n                mm_sample[key] = torch.from_numpy(mm_sample[key])\n            assert mm_sample[key].dtype == torch.int64\n    return mm_sample\n\n\ndef batchify(datum, device=None):\n    # Add batch dimension; move to `device` only when one is given\n    for key in datum:\n        if isinstance(datum[key], (torch.Tensor, TensorWrapper)):\n            datum[key] = datum[key][None, ...]\n            if device is not None:\n                datum[key] = datum[key].to(device)\n        else:\n            datum[key] = [datum[key]]\n    return datum\n\n\ndef unbatchify(datum):\n    # Remove batch dimension\n    for key in datum:\n        if isinstance(datum[key], (torch.Tensor, TensorWrapper, list)):\n            datum[key] = datum[key][0]\n    return datum\n\n\ndef get_tar_sample_num(tar_file):\n    sn = set()\n    with tarfile.TarFile(tar_file, \"r\") as tar:\n        for member in tar.getmembers():\n            sn.add(member.name.split(\".\")[0])\n    return len(sn)\n\n\nclass WdsStreamDataset:\n    \"\"\"Sample 2s/1s WDS dataset to specified snippet length and stride\"\"\"\n\n    def __init__(\n        self,\n        data_path,\n        snippet_length_s=1.0,\n        stride_length_s=0.1,\n        wds_length_s=2.0,\n        fps=10,\n        max_snip=99999999,\n    ):\n        self.snippet_length_s = snippet_length_s\n        self.stride_length_s = stride_length_s\n        self.wds_length_s = wds_length_s\n        # wds snippets should always be generated half overlapped\n        self.wds_stride_s = wds_length_s // 2\n        self.fps = fps\n        self.max_snip = max_snip\n\n        tar_list = sorted(glob.glob(f\"{data_path}/*.tar\"))\n        self.samples_per_tar = get_tar_sample_num(tar_list[0])\n        self.num_tars = len(tar_list)\n\n        self.dataset = wds.DataPipeline(\n            
wds.SimpleShardList(tar_list),\n            wds.tarfile_to_samples(),\n            wds.decode(\"rgb\"),\n            wds.map(convert_to_aria_multimodal_dataset),\n        )\n        self.dataloader = iter(self.dataset)\n\n        self.frames_wds = int(self.fps * self.wds_length_s)\n        self.frames_out = int(self.fps * self.snippet_length_s)\n        self.frames_stride_wds = int(self.fps * self.wds_stride_s)\n        self.frames_stride_out = int(self.fps * self.stride_length_s)\n\n        self.num_rest = int(\n            (self.wds_length_s - self.snippet_length_s) / self.stride_length_s\n        )\n        self.num_first = int(1 + self.num_rest)\n        self.num_snippets = (\n            self.num_first + (self.samples_per_tar * self.num_tars - 1) * self.num_rest\n        )\n\n        # for iteration\n        self.first = True\n        self.wds_snippet = None\n        self.snip_idx = 0\n        self.global_idx = 0\n\n    def __len__(self):\n        return min(self.num_snippets, self.max_snip)\n\n    def sample_snippet_(self, snippet, start, end):\n        # time crop\n        sample = snippet.copy()\n        for k in sample:\n            if isinstance(sample[k], (torch.Tensor, TensorWrapper)):\n                if k not in [\n                    ARIA_SNIPPET_T_WORLD_SNIPPET,\n                    ARIA_POINTS_VOL_MIN,\n                    ARIA_POINTS_VOL_MAX,\n                ]:\n                    sample[k] = sample[k][start:end, ...]\n\n        return sample\n\n    def __iter__(self):\n        return self\n\n    def if_get_next_(self):\n        if self.wds_snippet is None:\n            return True\n\n        if self.first:\n            return self.snip_idx >= self.num_first\n        else:\n            return self.snip_idx >= self.num_rest\n\n    def __next__(self):\n        if self.global_idx >= self.max_snip:\n            raise StopIteration\n\n        if self.if_get_next_():\n            if self.first and self.wds_snippet is not None:\n                
self.first = False\n            self.wds_snippet = next(self.dataloader)\n            self.snip_idx = 0\n\n        if self.first:\n            start = self.snip_idx * self.frames_stride_out\n        else:\n            start = (self.snip_idx + 1) * self.frames_stride_out\n\n        end = start + self.frames_out\n        sample = self.sample_snippet_(self.wds_snippet, start, end)\n        self.snip_idx += 1\n        self.global_idx += 1\n        return sample\n"
  },
  {
    "path": "efm3d/inference/__init__.py",
    "content": ""
  },
  {
    "path": "efm3d/inference/eval.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\n\nimport numpy as np\nimport torch\nfrom efm3d.utils.obb_csv_writer import ObbCsvReader\nfrom efm3d.utils.obb_metrics import ObbMetrics\nfrom efm3d.utils.obb_utils import (\n    draw_prec_recall_curve,\n    prec_recall_bb3,\n    prec_recall_curve,\n)\n\n\ndef check_sem_id_conflict(ids_pred, ids_gt):\n    all_sem_ids = set(list(ids_pred.keys()) + list(ids_gt.keys()))\n    for sem_id in all_sem_ids:\n        if sem_id in ids_pred and sem_id in ids_gt:\n            assert ids_pred[sem_id] == ids_gt[sem_id], (\n                f\"Mismatch id to name for sem id {sem_id}, {ids_pred[sem_id]} in pred but {ids_gt[sem_id]} in GT\"\n            )\n        elif sem_id not in ids_pred:\n            print(f\"sem_id {sem_id} not found in pred\")\n        else:\n            print(f\"sem_id {sem_id} not found in GT\")\n\n\ndef evaluate_obb_csv(\n    pred_csv: str,\n    gt_csv: str,\n    iou: float = 0.2,\n    pr_curve: bool = False,\n):\n    pred_reader = ObbCsvReader(pred_csv)\n    gt_reader = ObbCsvReader(gt_csv)\n    pred_obbs = pred_reader.obbs\n    gt_obbs = gt_reader.obbs\n\n    sem_id_to_name = pred_reader.sem_ids_to_names.copy()\n    sem_id_to_name_gt = gt_reader.sem_ids_to_names.copy()\n    check_sem_id_conflict(sem_id_to_name, sem_id_to_name_gt)\n    sem_id_to_name.update(sem_id_to_name_gt)\n\n    result = {}\n    mAP = 
ObbMetrics(\n        cam_ids=[0],\n        cam_names=[\"rgb\"],\n        class_metrics=True,\n        eval_2d=False,\n        eval_3d=True,\n        global_name_to_id={\n            name: int(sem_id) for sem_id, name in sem_id_to_name.items()\n        },\n    )\n\n    ts = list(pred_obbs.keys()) + list(gt_obbs.keys())\n    ts = list(set(ts))\n    ts.sort()\n\n    gt_ts_miss = 0\n    pred_ts_miss = 0\n    for t in ts:\n        if t not in pred_obbs:\n            print(f\"pred obbs not found for {t}\")\n            pred_ts_miss += 1\n            continue\n        if t not in gt_obbs:\n            print(f\"gt obbs not found for {t}\")\n            gt_ts_miss += 1\n            continue\n\n        # we should not have any paddings\n        assert pred_obbs[t].shape[0] == pred_obbs[t].remove_padding().shape[0]\n        assert gt_obbs[t].shape[0] == gt_obbs[t].remove_padding().shape[0]\n\n        # always do precision recall calculation\n        prec, rec, match_mat, ious, per_class_results = prec_recall_bb3(\n            pred_obbs[t],\n            gt_obbs[t],\n            iou_thres=iou,\n            return_ious=True,\n            per_class=True,\n        )\n        tps = match_mat.any(-1)\n        fps = (~match_mat).all(-1)\n        result[f\"precision@IoU{iou}\"] = float(prec)\n        result[f\"recall@IoU{iou}\"] = float(rec)\n        result[f\"num_true_positives@IoU{iou}\"] = int(tps.sum())\n        result[\"num_dets\"] = match_mat.shape[0]\n        result[\"num_gts\"] = match_mat.shape[1]\n        for sem_id, per_class_result in per_class_results.items():\n            result[f\"precision@IoU{iou}@Class_{sem_id_to_name[sem_id.item()]}\"] = float(\n                per_class_result[\"precision\"]\n            )\n            result[f\"recall@IoU{iou}@Class_{sem_id_to_name[sem_id.item()]}\"] = float(\n                per_class_result[\"recall\"]\n            )\n        # check if the preds contain probabilities\n        prob = pred_obbs[t].prob.squeeze()\n        assert 
not torch.all(prob.eq(-1.0)), (\n            \"the obbs don't contain valid probabilities for mAP calculation.\"\n        )\n        # add pred/gt pair to mAP calculator.\n        mAP.update(pred_obbs[t], gt_obbs[t])\n\n        output_dir = os.path.dirname(pred_csv)\n        if pr_curve and len(ts) == 1:\n            precs, recalls, probs = prec_recall_curve([(pred_obbs[t], gt_obbs[t])])\n            draw_prec_recall_curve(\n                precs, recalls, save_folder=output_dir, iou_thres=iou\n            )\n\n    result[\"num_timestamps\"] = len(ts)\n    result[\"num_timestamp_miss_pred\"] = pred_ts_miss\n    result[\"num_timestamp_miss_gt\"] = gt_ts_miss\n\n    result_map = mAP.compute()\n    # ignore average recall\n    result_map = {\n        k: v.item() for k, v in result_map.items() if not k.startswith(\"rgb/mar_\")\n    }\n    result.update(result_map)\n    return result\n\n\ndef obb_eval_dataset(input_folder: str, iou: float = 0.2):\n    \"\"\"\n    Obb eval at dataset-level\n    \"\"\"\n\n    GT_OBB_FILENAME = \"gt_scene_obbs.csv\"\n    PRED_OBB_FILENAME = \"tracked_scene_obbs.csv\"\n\n    # get all the pred and gt csv files\n    pred_csv_paths, gt_csv_paths = [], []\n    filenames = os.listdir(input_folder)\n    dirs = [os.path.join(input_folder, f) for f in filenames]\n    dirs = [d for d in dirs if os.path.isdir(d)]\n    for d in dirs:\n        pred_csv = os.path.join(d, PRED_OBB_FILENAME)\n        gt_csv = os.path.join(d, GT_OBB_FILENAME)\n        if os.path.exists(gt_csv) and os.path.exists(pred_csv):\n            pred_csv_paths.append(pred_csv)\n            gt_csv_paths.append(gt_csv)\n\n    result = {}\n    result[\"num_seqs\"] = len(pred_csv_paths)\n    if len(pred_csv_paths) == 0 or len(gt_csv_paths) == 0:\n        return result\n\n    pred_obbs, gt_obbs = [], []\n    sem_id_to_name = {}\n\n    for pred_csv, gt_csv in zip(pred_csv_paths, gt_csv_paths):\n        pred_reader = ObbCsvReader(pred_csv)\n        gt_reader = ObbCsvReader(gt_csv)\n       
 p_obbs = pred_reader.obbs\n        g_obbs = gt_reader.obbs\n        # p_obbs, g_obbs are single-item dicts\n        p_obbs = next(iter(p_obbs.values()))\n        g_obbs = next(iter(g_obbs.values()))\n        pred_obbs.append(p_obbs)\n        gt_obbs.append(g_obbs)\n\n        sem_id_to_name_pred = pred_reader.sem_ids_to_names.copy()\n        sem_id_to_name_gt = gt_reader.sem_ids_to_names.copy()\n        check_sem_id_conflict(sem_id_to_name_pred, sem_id_to_name_gt)\n        sem_id_to_name.update(sem_id_to_name_gt)\n\n    mAP = ObbMetrics(\n        cam_ids=[0],\n        cam_names=[\"rgb\"],\n        class_metrics=True,\n        eval_2d=False,\n        eval_3d=True,\n        global_name_to_id={\n            name: int(sem_id) for sem_id, name in sem_id_to_name.items()\n        },\n    )\n\n    precs, recs = [], []\n    for p_obbs, g_obbs in zip(pred_obbs, gt_obbs):\n        prec, rec, match_mat, ious, per_class_results = prec_recall_bb3(\n            p_obbs,\n            g_obbs,\n            iou_thres=iou,\n            return_ious=True,\n            per_class=True,\n        )\n        # store plain floats (as in evaluate_obb_csv) so the result stays JSON-serializable\n        precs.append(float(prec))\n        recs.append(float(rec))\n        mAP.update(p_obbs, g_obbs)\n    result[f\"precision@IoU{iou}\"] = float(np.mean(precs))\n    result[f\"recall@IoU{iou}\"] = float(np.mean(recs))\n\n    precs, recalls, probs = prec_recall_curve(\n        [(p_obbs, g_obbs) for p_obbs, g_obbs in zip(pred_obbs, gt_obbs)]\n    )\n\n    # save precision-recall curve to png\n    save_dir = input_folder\n    draw_prec_recall_curve(precs, recalls, save_folder=save_dir, iou_thres=iou)\n\n    result_map = mAP.compute()\n    # ignore average recall (e.g. 
\"rgb/mar_220_3D\")\n    result_map = {\n        k: v.item() for k, v in result_map.items() if not k.startswith(\"rgb/mar_\")\n    }\n    result.update(result_map)\n    return result\n\n\ndef main():\n    import argparse\n    import json\n\n    parser = argparse.ArgumentParser(description=\"Run EFM eval pipeline\")\n    parser.add_argument(\n        \"--input_folder\",\n        type=str,\n        help=\"The input folder that contains the gt and pred obbs csv files. If this is provided, the eval will be done at dataset-level\",\n        default=None,\n    )\n    parser.add_argument(\n        \"--pred_csv\",\n        type=str,\n        help=\"The prediction obbs csv file, can be snippet-level snippet_obbs.csv or scene-level tracked_scene_obbs.csv\",\n        default=None,\n    )\n    parser.add_argument(\n        \"--gt_csv\",\n        type=str,\n        help=\"The ground truth obbs csv file, can be snippet-level gt_obbs.csv or scene-level gt_scene_obbs.csv\",\n        default=None,\n    )\n    parser.add_argument(\n        \"--iou\",\n        type=float,\n        default=0.2,\n        help=\"IoU threshold used to match predicted and ground truth obbs\",\n    )\n    parser.add_argument(\n        \"--pr_curve\",\n        action=\"store_true\",\n        help=\"Whether to draw precision recall curve\",\n    )\n    args = parser.parse_args()\n\n    if args.input_folder:\n        # forward the user-specified IoU threshold instead of silently using the default\n        metrics = obb_eval_dataset(args.input_folder, args.iou)\n        print(json.dumps(metrics, indent=2, sort_keys=True))\n    else:\n        assert args.pred_csv is not None, \"pred_csv is required\"\n        assert args.gt_csv is not None, \"gt_csv is required\"\n\n        metrics = evaluate_obb_csv(args.pred_csv, args.gt_csv, args.iou, args.pr_curve)\n        output_dir = os.path.dirname(args.pred_csv)\n        print(json.dumps(metrics, indent=2, sort_keys=True))\n        with open(os.path.join(output_dir, \"metrics.json\"), \"w\") as f:\n            json.dump(metrics, f, indent=2, sort_keys=True)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "efm3d/inference/fuse.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport glob\nimport logging\nimport os\nfrom typing import List\n\nimport numpy as np\nimport torch\nimport tqdm\nimport trimesh\nfrom efm3d.aria.pose import PoseTW\nfrom efm3d.utils.marching_cubes import marching_cubes_scaled\nfrom efm3d.utils.reconstruction import pc_to_vox, sample_voxels\nfrom efm3d.utils.voxel import create_voxel_grid\n\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n# logger.setLevel(logging.DEBUG)\n\n\ndef set_boundary_value(x, val, thickness):\n    if thickness == 0:\n        return x\n    x[..., :thickness, :, :] = val\n    x[..., -thickness:, :, :] = val\n    x[..., :, :thickness, :] = val\n    x[..., :, -thickness:, :] = val\n    x[..., :, :, :thickness] = val\n    x[..., :, :, -thickness:] = val\n    return x\n\n\ndef load_tensor(fname, device):\n    data = torch.load(fname, map_location=device)\n    if \"_8b\" in fname:\n        data = data.dequantize()\n    return data\n\n\nclass VolumeFusion:\n    def __init__(\n        self,\n        voxel_size: List[float],\n        voxel_extent: List[float],\n        device: str = \"cuda\",\n        dtype=torch.float32,\n        w_min: float = 5.0,\n        w_max: float = 100.0,\n        init_value: float = 0.0,\n        surface_thres: float = 0.99,\n        boundary_thres: int = 1,\n    ):\n        self.voxel_size = voxel_size  # D x H x W\n        
self.voxel_extent = voxel_extent  # W x H x D\n        self.vD, self.vH, self.vW = self.voxel_size\n        self.vD = int(self.vD)\n        self.vH = int(self.vH)\n        self.vW = int(self.vW)\n        self.w_max = w_max\n        self.w_min = w_min\n        self.surface_thres = surface_thres\n        self.boundary_thres = boundary_thres\n\n        self.global_volume = torch.ones(\n            self.vD, self.vH, self.vW, device=device, dtype=dtype\n        )  # D H W\n        self.global_volume = self.global_volume * init_value\n        self.global_volume_weights = torch.zeros_like(self.global_volume)  # D H W\n        self.global_volume_points = create_voxel_grid(\n            self.vW, self.vH, self.vD, self.voxel_extent, device\n        ).to(dtype=dtype)  # W, H, D, 3\n        # reshaping\n        self.global_volume_points = self.global_volume_points.permute(\n            2, 1, 0, 3\n        )  # D H W 3\n        self.global_volume_points = self.global_volume_points.reshape(\n            -1, 3\n        )  # (D*H*W) x 3\n        self.global_volume_weights = self.global_volume_weights.reshape(-1)  # D*H*W\n        self.global_volume = self.global_volume.reshape(-1)  # D*H*W\n\n        self.device = device\n\n    def set_boundary_mask(self, mask):\n        # mask is laid out as D x H x W (same as the local volume)\n        thickness = self.boundary_thres\n        mask[:thickness] = False  # Clear the first 'thickness' layers along Depth\n        mask[-thickness:] = False  # Clear the last 'thickness' layers along Depth\n        mask[:, :thickness] = False  # Clear the first 'thickness' layers along Height\n        mask[:, -thickness:] = False  # Clear the last 'thickness' layers along Height\n        mask[:, :, :thickness] = (\n            False  # Clear the first 'thickness' layers along Width\n        )\n        mask[:, :, -thickness:] = (\n            False  # Clear the last 'thickness' layers along Width\n        )\n        return mask\n\n    def fuse(\n        self,\n        local_volume: torch.Tensor,\n       
 local_extent: List[float],\n        T_l_w: PoseTW,\n        new_obs_w=1.5,\n        visiblity_mask=None,\n    ):\n        local_volume = local_volume.to(self.global_volume.device)\n        T_l_w = T_l_w.to(self.global_volume.device)\n\n        vD, vH, vW = local_volume.shape\n        # transform global_volume to local volume\n        global_volume_l = T_l_w * self.global_volume_points\n        global_volume_l_coord, valid_global_points = pc_to_vox(\n            global_volume_l, vW, vH, vD, local_extent\n        )\n        local_samples, valid_samples = sample_voxels(\n            local_volume.unsqueeze(0).unsqueeze(0).float(),\n            global_volume_l_coord.view(1, -1, 3).float(),\n        )\n        local_samples = (\n            local_samples.squeeze(0).squeeze(0).to(dtype=self.global_volume.dtype)\n        )\n        valid_samples = valid_samples.squeeze(0)\n\n        # making a mask\n        surface_mask = local_volume < self.surface_thres\n        if visiblity_mask is not None:\n            surface_mask &= visiblity_mask.to(surface_mask)\n        # we don't trust the boundary voxels from CNNS\n        if self.boundary_thres > 0:\n            surface_mask = self.set_boundary_mask(surface_mask)\n        surface_mask_f = surface_mask.float()\n        surface_mask_f[~surface_mask] = torch.nan\n        # sample the mask\n        surface_mask_samples, _ = sample_voxels(\n            surface_mask_f.unsqueeze(0).unsqueeze(0).float(),\n            global_volume_l_coord.view(1, -1, 3).float(),\n        )\n        surface_mask = ~surface_mask_samples.isnan()\n        valid_samples = valid_samples & surface_mask\n        mask = valid_samples & valid_global_points\n        mask = mask.squeeze()\n        w = self.global_volume_weights[mask]\n\n        self.global_volume[mask] = (\n            self.global_volume[mask] * w + local_samples[mask] * 2.0\n        ) / (w + 2.0)\n\n        # update weights\n        self.global_volume_weights[mask] = w + new_obs_w\n        
self.global_volume_weights[mask] = self.global_volume_weights[mask].clamp(\n            max=self.w_max\n        )\n\n    def get_volume(self, reshape=True):\n        if reshape:\n            return self.global_volume.reshape(self.vD, self.vH, self.vW)\n        else:\n            return self.global_volume\n\n    def get_weights(self, reshape=True):\n        if reshape:\n            return self.global_volume_weights.reshape(self.vD, self.vH, self.vW)\n        else:\n            return self.global_volume_weights\n\n    def get_mask(self, reshape=True):\n        mask = self.global_volume_weights >= self.w_min\n        if reshape:\n            return mask.reshape(self.vD, self.vH, self.vW)\n        else:\n            return mask\n\n    def get_trimesh(self, iso_level=0.5):\n        global_vol = self.get_volume()\n        mask = self.get_mask()\n        verts_w, faces, _ = marching_cubes_scaled(\n            global_vol.cpu().detach().float(),\n            iso_level,\n            self.voxel_extent,\n            mask,\n        )\n        sem_rgb = None\n        mesh = trimesh.Trimesh(verts_w, faces, vertex_colors=sem_rgb)\n\n        return mesh\n\n\nclass VolumetricFusion:\n    def __init__(\n        self,\n        input_folder,\n        w_min=5.0,\n        w_max=9999999.0,\n        voxel_res=0.04,\n        device=\"cuda\",\n    ):\n        self.input_folder = input_folder\n        self.per_snip_folder = os.path.join(input_folder, \"per_snip\")\n        f_vol_min = os.path.join(self.per_snip_folder, \"scene_vol_min.pt\")\n        f_vol_max = os.path.join(self.per_snip_folder, \"scene_vol_max.pt\")\n        assert os.path.exists(f_vol_min) and os.path.exists(f_vol_max), (\n            \"missing scene volume info\"\n        )\n        self.vol_min = load_tensor(f_vol_min, \"cpu\").numpy()\n        self.vol_max = load_tensor(f_vol_max, \"cpu\").numpy()\n        self.w_min = w_min\n        self.w_max = w_max\n        self.voxel_res = voxel_res\n        self.device = device\n\n        
self.vis_norm_grad_occ_thr = 0.2\n        # we remove a 1 voxel wide boundary on the volumes to remove cnn artifacts\n        self.boundary_thresh = 1\n\n        self.f_occ_preds = sorted(\n            glob.glob(os.path.join(self.per_snip_folder, \"occ_pr*.pt\"))\n        )\n        Ts_wv_pt = os.path.join(self.per_snip_folder, \"Ts_wv.pt\")\n        self.Ts_wv = torch.load(Ts_wv_pt, map_location=\"cpu\")  # need to be on cpu\n        assert len(self.f_occ_preds) == self.Ts_wv.shape[0], (\n            f\"occ snippets {len(self.f_occ_preds)} should match with Ts_wv {self.Ts_wv.shape[0]}\"\n        )\n\n        # load voxel extent for initialization\n        ve_path = os.path.join(self.per_snip_folder, \"voxel_extent.pt\")\n        self.local_extent = torch.load(ve_path).cpu()\n        if self.local_extent.ndim == 2:\n            self.local_extent = self.local_extent.squeeze(0)\n        self.local_extent = self.local_extent.tolist()\n        self.global_vol = None\n\n        self.init_from_range(self.vol_min, self.vol_max)\n\n    def reinit(self):\n        # reinit with the same voxel extent\n        if self.global_vol is not None:\n            del self.global_vol\n        self.init_from_range(self.vol_min, self.vol_max)\n\n    def init_from_range(self, xyz_min, xyz_max):\n        # Add a little buffer around the bounds. Computed out of place so the\n        # stored vol_min / vol_max are not grown again on every reinit().\n        xyz_min = xyz_min - 2 * self.voxel_res\n        xyz_max = xyz_max + 2 * self.voxel_res\n        if xyz_min.ndim == 2:\n            xyz_min = xyz_min[0]\n        if xyz_max.ndim == 2:\n            xyz_max = xyz_max[0]\n\n        global_extent = [\n            xyz_min[0],\n            xyz_max[0],\n            xyz_min[1],\n            xyz_max[1],\n            xyz_min[2],\n            xyz_max[2],\n        ]\n        voxel_size = np.ceil((xyz_max - xyz_min) / self.voxel_res).tolist()\n        voxel_size.reverse()  # change to DxHxW\n        self.global_vol = VolumeFusion(\n            voxel_size,\n            global_extent,\n            device=self.device,\n       
     w_min=self.w_min,\n            w_max=self.w_max,\n            init_value=1.0,\n            surface_thres=0.99,\n        )\n\n    def get_trimesh(self):\n        return self.global_vol.get_trimesh()\n\n    def run_step(self, i):\n        # run one step of volume fusion\n        if i >= len(self.f_occ_preds):\n            logger.info(\n                f\"{i}-th snippet exceeding the number of snippets {len(self.f_occ_preds)}\"\n            )\n            return\n        T_wv = self.Ts_wv[i]\n        occ_pred = load_tensor(self.f_occ_preds[i], self.device)  # [1, 1, D, H, W]\n        occ_pred = occ_pred[0][0]  # [D, H, W]\n\n        self.global_vol.fuse(\n            local_volume=occ_pred,\n            local_extent=self.local_extent,\n            T_l_w=T_wv.inverse(),\n        )\n\n    def run(self):\n        logger.info(\"Fusing voxel occupancy using volume fusion...\")\n        for i, _ in tqdm.tqdm(enumerate(self.f_occ_preds), total=len(self.f_occ_preds)):\n            self.run_step(i)\n"
  },
  {
    "path": "efm3d/inference/model.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nimport shutil\nimport time\n\nimport torch\nimport tqdm\nfrom efm3d.aria.aria_constants import (\n    ARIA_IMG_TIME_NS,\n    ARIA_OBB_PADDED,\n    ARIA_OBB_PRED_SEM_ID_TO_NAME,\n    ARIA_OBB_PRED_VIZ,\n    ARIA_OBB_SEM_ID_TO_NAME,\n    ARIA_POINTS_VOL_MAX,\n    ARIA_POINTS_VOL_MIN,\n    ARIA_SNIPPET_T_WORLD_SNIPPET,\n)\nfrom efm3d.aria.obb import obb_time_union\nfrom efm3d.dataset.wds_dataset import batchify\nfrom efm3d.utils.obb_csv_writer import ObbCsvWriter\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n# logger.setLevel(logging.DEBUG)\n\n\nclass EfmInference:\n    def __init__(self, streamer, model, output_dir, device, zip, obb_only=False):\n        self.streamer = streamer\n        self.model = model\n        self.output_dir = output_dir\n        self.device = device\n        self.zip = zip\n        self.obb_only = obb_only\n        self.metadata_saved = False\n        self.Ts_wv = []  # all T_world_voxel as one tensor\n\n        self.obb_csv_path = os.path.join(output_dir, \"snippet_obbs.csv\")\n        self.obb_writer = None\n        self.per_snip_dir = os.path.join(output_dir, \"per_snip\")\n        shutil.rmtree(self.per_snip_dir, ignore_errors=True)\n        os.makedirs(self.per_snip_dir, exist_ok=True)\n\n        # obb GT\n        self.gt_obb_csv_path = os.path.join(output_dir, 
\"gt_obbs.csv\")\n        self.gt_obb_writer = None\n        self.scene_gt_obbs_w = []\n\n    def __del__(self):\n        if not self.zip:\n            return\n\n        # compress the output folder\n        if os.path.exists(self.output_dir) and os.listdir(self.output_dir):\n            logger.info(f\"zipping file to {self.output_dir}.zip\")\n            shutil.make_archive(\n                self.output_dir.rstrip(\"/\"), \"zip\", self.output_dir, verbose=True\n            )\n        logger.info(f\"zip file saved to {self.output_dir}.zip\")\n\n    def save_tensor(self, tensor, key, idx=None, output_dir=\"\"):\n        if idx is not None:\n            pt_name = os.path.join(output_dir, f\"{key}_{idx:06}.pt\")\n        else:\n            pt_name = os.path.join(output_dir, f\"{key}.pt\")\n        torch.save(tensor.cpu(), pt_name)\n\n    def save_output(self, data, idx, output_dir):\n        \"\"\"\n        Save per-snippet 3D obb output and occupancy tensor to disk.\n        \"\"\"\n        # assuming single sample batch\n        bid = 0\n\n        # 3d obb predictions\n        if ARIA_OBB_PRED_VIZ in data:\n            obb_preds_s = data[ARIA_OBB_PRED_VIZ][bid].remove_padding()\n            T_ws = data[ARIA_SNIPPET_T_WORLD_SNIPPET][bid]\n            obb_preds_w = obb_preds_s.transform(T_ws)\n            first_rgb_time_ns = data[ARIA_IMG_TIME_NS[0]][bid, 0].item()\n            if self.obb_writer is None:\n                self.obb_writer = ObbCsvWriter(self.obb_csv_path)\n            self.obb_writer.write(\n                obb_preds_w, first_rgb_time_ns, data[ARIA_OBB_PRED_SEM_ID_TO_NAME]\n            )\n\n            if ARIA_OBB_PADDED in data and ARIA_OBB_SEM_ID_TO_NAME in data:\n                gt_obbs_s = obb_time_union(data[ARIA_OBB_PADDED])[bid].remove_padding()\n                gt_obbs_w = gt_obbs_s.transform(T_ws)\n                self.scene_gt_obbs_w.append(gt_obbs_w.add_padding(128))\n\n                if self.gt_obb_writer is None:\n                    
self.gt_obb_writer = ObbCsvWriter(self.gt_obb_csv_path)\n\n                gt_sem_id_to_name = {}\n                gt_sem_id_to_name.update(data[ARIA_OBB_SEM_ID_TO_NAME][bid])\n                self.gt_obb_writer.write(\n                    gt_obbs_w,\n                    first_rgb_time_ns,\n                    sem_id_to_name=gt_sem_id_to_name,\n                )\n\n        # occupancy predictions (skipped in obb_only mode)\n        if (\n            not self.obb_only\n            and \"occ_pr\" in data\n            and ARIA_POINTS_VOL_MIN in data\n            and ARIA_POINTS_VOL_MAX in data\n        ):\n            if not self.metadata_saved:\n                self.save_tensor(\n                    data[\"voxel_extent\"],\n                    \"voxel_extent\",\n                    idx=None,\n                    output_dir=output_dir,\n                )\n                self.metadata_saved = True\n                self.save_tensor(\n                    data[ARIA_POINTS_VOL_MIN][0],  # tensor(3)\n                    \"scene_vol_min\",\n                    idx=None,\n                    output_dir=output_dir,\n                )\n                self.save_tensor(\n                    data[ARIA_POINTS_VOL_MAX][0],  # tensor(3)\n                    \"scene_vol_max\",\n                    idx=None,\n                    output_dir=output_dir,\n                )\n\n            self.save_tensor(data[\"occ_pr\"], \"occ_pr\", idx, output_dir)\n            self.Ts_wv.append(data[\"voxel/T_world_voxel\"][0])\n\n    def run(self):\n        # feed the per-snippet data to the model\n        gt_sem_id = {}\n        idx = 0\n\n        start = time.time()\n        for batch in tqdm.tqdm(self.streamer, total=len(self.streamer)):\n            # convert single sample to batch and move to GPU\n            batchify(batch, device=self.device)\n\n            with torch.no_grad():\n                output = self.model(batch, obb_only=self.obb_only)\n                batch.update(output)\n         
       self.save_output(batch, idx, self.per_snip_dir)\n            if ARIA_OBB_SEM_ID_TO_NAME in batch:\n                gt_sem_id.update(batch[ARIA_OBB_SEM_ID_TO_NAME][0])\n            idx += 1\n\n        print(f\"\\ninference speed {idx / (time.time() - start):.02f} sample/s\")\n\n        # save all T_wv as one tensor to avoid writing small files\n        if len(self.Ts_wv) > 0:\n            Ts_wv = torch.stack(self.Ts_wv, dim=0)\n            self.save_tensor(Ts_wv, \"Ts_wv\", None, self.per_snip_dir)\n\n        # write scene-level obbs\n        if len(self.scene_gt_obbs_w) > 0:\n            max_obbs = 512\n            merged_gts = torch.stack(self.scene_gt_obbs_w, dim=0)\n            merged_gts = obb_time_union(merged_gts.unsqueeze(0), pad_size=max_obbs)\n            merged_gts = merged_gts[0].remove_padding()\n\n            gt_scene_obb_csv_path = os.path.join(self.output_dir, \"gt_scene_obbs.csv\")\n            gt_scene_obb_writer = ObbCsvWriter(gt_scene_obb_csv_path)\n            gt_scene_obb_writer.write(merged_gts, -1, gt_sem_id)\n"
  },
  {
    "path": "efm3d/inference/pipeline.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport math\nimport os\nimport shutil\n\nimport hydra\nimport numpy as np\nimport omegaconf\nimport torch\nimport trimesh\nfrom efm3d.dataset.vrs_dataset import preprocess_inference, VrsSequenceDataset\nfrom efm3d.inference.fuse import VolumetricFusion\nfrom efm3d.inference.model import EfmInference\nfrom efm3d.inference.viz import generate_video\nfrom efm3d.utils.gravity import correct_adt_mesh_gravity\nfrom efm3d.utils.mesh_utils import eval_mesh_to_mesh\n\n\ndef get_gt_mesh_ply(data_path):\n    \"\"\"\n    Return ASE or ADT GT mesh path. 
If not exist, return empty str.\n    \"\"\"\n    if data_path.endswith(\".vrs\"):\n        seq_name = os.path.basename(os.path.dirname(data_path))\n    else:\n        seq_name = os.path.basename(data_path.strip(\"/\"))\n\n    adt_mesh_ply = f\"./data/adt_mesh/{seq_name}/gt_mesh.ply\"\n    ase_mesh_ply = f\"./data/ase_mesh/scene_ply_{seq_name}.ply\"\n    if os.path.exists(adt_mesh_ply):\n        return adt_mesh_ply\n    elif os.path.exists(ase_mesh_ply):\n        return ase_mesh_ply\n    return \"\"\n\n\ndef compute_avg_metrics(paths):\n    \"\"\"\n    Given metrics path list, compute the average metrics\n    Note that simply averaging is not a good way to compute mAP metrics.\n    \"\"\"\n    avg_ret = {}\n\n    for path in paths:\n        with open(path, \"r\") as f:\n            metrics = json.load(f)\n            for k, v in metrics.items():\n                if k not in avg_ret:\n                    avg_ret[k] = [v]\n                else:\n                    avg_ret[k].append(v)\n    for k, v in avg_ret.items():\n        avg_ret[k] = np.mean(v)\n    return avg_ret\n\n\ndef create_streamer(\n    data_path, snippet_length_s, stride_length_s, max_snip, skip_snips=0\n):\n    # infer data type\n    def is_atek_wds_input(data_path):\n        ATEK_WDS_TAR = \"shards-0000.tar\"\n        first_tar = os.path.join(data_path, ATEK_WDS_TAR)\n        return os.path.exists(first_tar)\n\n    if is_atek_wds_input(data_path):\n        from efm3d.dataset.atek_wds_dataset import AtekWdsStreamDataset\n\n        streamer = AtekWdsStreamDataset(\n            data_path,\n            atek_to_efm_taxonomy=f\"{os.path.dirname(__file__)}/../config/taxonomy/atek_to_efm.csv\",\n            snippet_length_s=snippet_length_s,\n            stride_length_s=stride_length_s,\n            max_snip=max_snip,\n        )\n    elif data_path.endswith(\".vrs\"):\n        # Use the native vrs sequence processor\n        streamer = VrsSequenceDataset(\n            data_path,\n            frame_rate=10,\n 
           sdi=2,\n            snippet_length_s=snippet_length_s,\n            stride_length_s=stride_length_s,\n            max_snippets=max_snip,\n            skip_snippets=skip_snips,\n            preprocess=preprocess_inference,\n        )\n\n        # (optional) use the ATEK data loader If it is installed\n        # from efm3d.dataset.atek_vrs_dataset import create_atek_raw_data_loader_from_vrs_path\n        # streamer = create_atek_raw_data_loader_from_vrs_path(\n        #     vrs_path=data_path,\n        #     freq_hz=10,\n        #     snippet_length_s=snippet_length_s,\n        #     stride_length_s=stride_length_s,\n        #     skip_begin_seconds=20.0,\n        #     skip_end_seconds=5.0,\n        #     max_snippets=max_snip,\n        # )\n    else:\n        print(\n            f\"Input error {data_path}, expect the input to be a folder to WDS tars or a .vrs file\"\n        )\n        exit(-1)\n    return streamer\n\n\ndef create_output_dir(output_dir, model_ckpt, data_path):\n    # create output path from model ckpt and data path\n    # e.g. 
result will be output to <output_dir>/<ckpt_name>/<seq_name>\n    model_name = os.path.basename(os.path.splitext(model_ckpt)[0])\n    seq_name = data_path\n    if data_path.endswith(\".vrs\"):\n        seq_name = os.path.basename(os.path.dirname(data_path))\n    else:\n        seq_name = os.path.basename(data_path.strip(\"/\"))\n    output_dir = os.path.join(output_dir, f\"{model_name}\", f\"{seq_name}\")\n    if not os.path.exists(output_dir):\n        os.makedirs(output_dir, exist_ok=True)\n    return output_dir\n\n\ndef run_one(\n    data_path,\n    model_ckpt,\n    model_cfg,\n    max_snip=9999,\n    snip_stride=0.1,\n    voxel_res=0.04,\n    output_dir=\"./output\",\n    obb_only=False,\n    skip_video=False,\n    skip_snips=0,\n):\n    output_dir = create_output_dir(output_dir, model_ckpt, data_path)\n\n    # create model\n    if torch.cuda.is_available():\n        device = \"cuda\"\n    elif torch.backends.mps.is_available():\n        device = \"mps\"\n    else:\n        device = \"cpu\"\n\n    checkpoint = torch.load(model_ckpt, weights_only=True, map_location=device)\n    model_config = omegaconf.OmegaConf.load(model_cfg)\n    model = hydra.utils.instantiate(model_config)\n    model.load_state_dict(checkpoint[\"state_dict\"], strict=True)\n    model.to(device)\n    model.eval()\n    print(\"model init done\")\n\n    # create dataset\n    streamer = create_streamer(\n        data_path,\n        snippet_length_s=1.0,\n        stride_length_s=snip_stride,\n        max_snip=max_snip,\n        skip_snips=skip_snips,\n    )\n\n    # per-snippet inference\n    efm_inf = EfmInference(\n        streamer, model, output_dir, device=device, zip=False, obb_only=obb_only\n    )\n    efm_inf.run()\n    del efm_inf\n    del model\n\n    if torch.cuda.is_available():\n        torch.cuda.empty_cache()\n\n    # track obbs\n    try:\n        from efm3d.inference.track import track_obbs\n\n        track_obbs(output_dir)\n    except Exception:\n        print(\"Skip tracking obb due to 
missing dependency, please see INSTALL.md\")\n\n    # eval obb\n    metrics = {}\n    pred_csv = os.path.join(output_dir, \"tracked_scene_obbs.csv\")\n    gt_csv = os.path.join(output_dir, \"gt_scene_obbs.csv\")\n    if os.path.exists(pred_csv) and os.path.exists(gt_csv):\n        try:\n            from efm3d.inference.eval import evaluate_obb_csv\n\n            obb_metrics = evaluate_obb_csv(pred_csv=pred_csv, gt_csv=gt_csv, iou=0.2)\n            metrics.update(obb_metrics)\n        except Exception:\n            print(\n                \"Skip obb evaluation due to missing dependency, please see INSTALL.md\"\n            )\n\n    vol_fusion = None\n    if not obb_only:\n        # fuse mesh\n        vol_fusion = VolumetricFusion(output_dir, voxel_res=voxel_res, device=device)\n        vol_fusion.run()\n        fused_mesh = vol_fusion.get_trimesh()\n        pred_mesh_ply = os.path.join(output_dir, \"fused_mesh.ply\")\n        if fused_mesh.vertices.shape[0] > 0 and fused_mesh.faces.shape[0] > 0:\n            fused_mesh.export(pred_mesh_ply)\n\n        # eval mesh\n        gt_mesh_ply = get_gt_mesh_ply(data_path)\n        if os.path.exists(pred_mesh_ply) and os.path.exists(gt_mesh_ply):\n            pred_trimesh = trimesh.load(pred_mesh_ply)\n            gt_trimesh = trimesh.load(gt_mesh_ply)\n            if \"adt\" in gt_mesh_ply:\n                gt_trimesh = correct_adt_mesh_gravity(gt_trimesh)\n\n            mesh_metrics, _, _ = eval_mesh_to_mesh(\n                pred=pred_trimesh,\n                gt=gt_trimesh,\n                sample_num=1000,\n            )\n            metrics.update(mesh_metrics)\n    else:\n        print(\"Skipping volume fusion (--obb_only)\")\n\n    # write metrics\n    if len(metrics) > 0:\n        with open(os.path.join(output_dir, \"metrics.json\"), \"w\") as f:\n            json.dump(metrics, f, indent=2, sort_keys=True)\n        print(json.dumps(metrics, indent=2, sort_keys=True))\n\n    # viz\n    if not skip_video:\n        streamer = 
create_streamer(\n            data_path=data_path,\n            snippet_length_s=1.0,\n            stride_length_s=1.0,\n            max_snip=math.ceil((max_snip - 1) * snip_stride),\n            skip_snips=int(skip_snips * snip_stride),\n        )\n        if vol_fusion is not None:\n            vol_fusion.reinit()\n        viz_path = generate_video(\n            streamer, output_dir=output_dir, vol_fusion=vol_fusion, stride_s=snip_stride\n        )\n        print(f\"output viz file to {os.path.abspath(viz_path)}\")\n\n    # rm per-snippet occupancy tensors\n    per_snip_dir = os.path.join(output_dir, \"per_snip\")\n    if os.path.exists(per_snip_dir):\n        shutil.rmtree(per_snip_dir)\n"
  },
  {
    "path": "efm3d/inference/track.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\n\nimport torch\nfrom efm3d.utils.obb_csv_writer import ObbCsvReader, ObbCsvWriter\nfrom efm3d.utils.obb_trackers import ObbTracker\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n# logger.setLevel(logging.DEBUG)\n\n\ndef track_obbs(input_path, prob_inst_thr=0.3, prob_assoc_thr=0.25):\n    \"\"\"\n    Run ObbTracker on input csv file.\n\n    input_path: path to input folder or obbs csv file. 
if folder, will look for 'snippet_obbs.csv'\n    as the input obbs csv file.\n    prob_inst_thr: minimum probability threshold for instantiating a new world obb\n    prob_assoc_thr: minimum probability threshold for associating a new obb with existing world obbs\n    \"\"\"\n    if not os.path.exists(input_path):\n        logger.error(f\"Input folder {input_path} does not exist\")\n        return\n\n    if input_path.endswith(\".csv\"):\n        obb_csv_path = input_path\n        obb_folder = os.path.dirname(input_path)\n    else:\n        obb_csv_path = os.path.join(input_path, \"snippet_obbs.csv\")\n        obb_folder = input_path\n    assert os.path.exists(obb_csv_path), f\"No obb csv file found {obb_csv_path}\"\n\n    tracked_obbs_path = os.path.join(obb_folder, \"tracked_obbs.csv\")\n    reader = ObbCsvReader(obb_csv_path)\n    writer = ObbCsvWriter(tracked_obbs_path)\n    tracker = ObbTracker(\n        track_best=False,\n        track_running_average=True,\n        max_assoc_dist=0.1,\n        max_assoc_iou2=0.0,  # disabled\n        max_assoc_iou3=0.2,\n        prob_inst_thr=prob_inst_thr,\n        prob_assoc_thr=prob_assoc_thr,\n        nms_iou3_thr=0.1,\n        nms_iou2_thr=0.0,  # disabled\n        w_max=30,\n        w_min=5,\n        dt_max_inst=1.0,\n        dt_max_occ=999999.0,  # never delete\n    )\n\n    # write snippet-level tracked obbs\n    for t_ns, obbs in reader:\n        tracked_obbs, unviz_obbs = tracker.track(obbs)\n        # seq_obb_eval use both tracked and unviz obbs\n        all_tracked_obbs = torch.cat([tracked_obbs, unviz_obbs], dim=-2)\n        writer.write(all_tracked_obbs, t_ns, reader.sem_ids_to_names)\n\n    # write scene-level tracked obbs\n    tracked_scene_obbs_path = os.path.join(obb_folder, \"tracked_scene_obbs.csv\")\n    scene_writer = ObbCsvWriter(tracked_scene_obbs_path)\n    final_scene_obbs, unviz_obbs = tracker.obbs_world\n    final_scene_obbs_all = torch.cat([final_scene_obbs, unviz_obbs], dim=-2)\n    
scene_writer.write(final_scene_obbs_all, -1, reader.sem_ids_to_names)\n    logger.info(f\"Wrote scene-level tracked obbs to {tracked_scene_obbs_path}\")\n\n\ndef main():\n    import argparse\n\n    parser = argparse.ArgumentParser(description=\"Run Obb tracker on obbs csv file.\")\n    parser.add_argument(\n        \"--input\",\n        type=str,\n        help=\"The input folder containing snippet_obbs.csv, or the path to an obbs csv file\",\n        required=True,\n    )\n    parser.add_argument(\n        \"--prob_inst_thr\",\n        type=float,\n        default=0.3,\n        help=\"minimum probability threshold for instantiating a new world obb\",\n    )\n    parser.add_argument(\n        \"--prob_assoc_thr\",\n        type=float,\n        default=0.25,\n        help=\"minimum probability threshold for associating a new obb with existing world obbs\",\n    )\n\n    args = parser.parse_args()\n    # forward the thresholds so the CLI flags actually take effect\n    track_obbs(\n        input_path=args.input,\n        prob_inst_thr=args.prob_inst_thr,\n        prob_assoc_thr=args.prob_assoc_thr,\n    )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "efm3d/inference/viz.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nfrom bisect import bisect_left\nfrom typing import Optional\n\nimport cv2\nimport numpy as np\nimport torch\nimport tqdm\nimport trimesh\nfrom efm3d.aria.aria_constants import (\n    ARIA_CALIB,\n    ARIA_IMG,\n    ARIA_IMG_T_SNIPPET_RIG,\n    ARIA_IMG_TIME_NS,\n    ARIA_MESH_FACES,\n    ARIA_MESH_VERT_NORMS_W,\n    ARIA_MESH_VERTS_W,\n    ARIA_OBB_PADDED,\n    ARIA_OBB_PRED_VIZ,\n    ARIA_OBB_TRACKED,\n    ARIA_SNIPPET_T_WORLD_SNIPPET,\n)\nfrom efm3d.aria.obb import ObbTW\nfrom efm3d.inference.fuse import VolumetricFusion\nfrom efm3d.utils.image import put_text, smart_resize, torch2cv2\nfrom efm3d.utils.obb_csv_writer import ObbCsvReader\nfrom efm3d.utils.render import draw_obbs_snippet\nfrom efm3d.utils.viz import draw_snippet_scene_3d, SceneView\n\n\nVIZ_RGB = \"RGB/GT\"\nVIZ_SLAM = \"SLAM\"\nVIZ_PRED_OBB = \"Snippet Prediction\"\nVIZ_TRACKED_OBB = \"Tracked Prediction\"\nVIZ_GT_OBB = \"Ground Truth\"\n\n\ndef find_nearest(array, value):\n    \"\"\"Find the index of the nearest value in an array.\"\"\"\n    idx = bisect_left(array, value)\n    if idx == len(array):\n        return idx - 1\n    if idx == 0:\n        return 0\n    before = array[idx - 1]\n    after = array[idx]\n    if after - value < value - before:\n        return idx\n    return idx - 1\n\n\ndef fill_obbs_to_snippet(obbs, rgb_ts, T_ws):\n    obbs_out = []\n    
obbs_ts = sorted(obbs.keys())\n    for ts in rgb_ts:\n        if ts in obbs:\n            obbs_out.append(obbs[ts].add_padding(128))\n        elif len(obbs_ts) == 0:\n            obbs_out.append(ObbTW().add_padding(128))\n        else:\n            # find the nearest timestamp within 1s\n            nidx = find_nearest(obbs_ts, ts)\n            if abs(obbs_ts[nidx] - ts) / 1e9 < 1:\n                obbs_out.append(obbs[obbs_ts[nidx]].add_padding(128))\n            else:\n                obbs_out.append(ObbTW().add_padding(128))\n    obbs_w = torch.stack(obbs_out, dim=0)\n    obbs_s = obbs_w.transform(T_ws.inverse())\n    return obbs_s\n\n\ndef compose_views(view_dict, keys, vertical=True):\n    \"\"\"stack snippet images into a single image, vertical or horizontal\"\"\"\n    keys = [k for k in keys if k in view_dict]\n    if len(keys) == 0:\n        return None\n    if len(keys) == 1:\n        return view_dict[keys[0]]\n\n    output_imgs = []\n    T = len(view_dict[keys[0]])\n    for i in range(T):\n        img_list = [view_dict[key][i] for key in keys]\n        axis = 0 if vertical else 1\n        combine_img = np.concatenate(img_list, axis=axis)\n        output_imgs.append(combine_img)\n    return output_imgs\n\n\ndef draw_scene_with_mesh_and_obbs(\n    snippet,\n    w,\n    h,\n    scene,\n    snip_obbs=None,\n    tracked_obbs=None,\n    gt_obbs=None,\n    mesh=None,\n    sem_ids_to_names=None,\n):\n    \"\"\"\n    Draw 3d scene view of a snippet, with optionally obbs and mesh.\n    \"\"\"\n    # put pred obbs into the snippet\n    rgb_ts = snippet[ARIA_IMG_TIME_NS[0]]\n    rgb_ts = [ts.item() for ts in rgb_ts]\n    T_ws = snippet[ARIA_SNIPPET_T_WORLD_SNIPPET]\n\n    if snip_obbs is not None:\n        snippet[ARIA_OBB_PRED_VIZ] = fill_obbs_to_snippet(snip_obbs, rgb_ts, T_ws)\n    if tracked_obbs is not None:\n        snippet[ARIA_OBB_TRACKED] = fill_obbs_to_snippet(tracked_obbs, rgb_ts, T_ws)\n    if gt_obbs is not None:\n        snippet[ARIA_OBB_PADDED] = 
fill_obbs_to_snippet(gt_obbs, rgb_ts, T_ws)\n    if mesh is not None and mesh.vertices.shape[0] > 0 and mesh.faces.shape[0] > 0:\n        snippet[ARIA_MESH_VERTS_W] = torch.tensor(mesh.vertices)\n        snippet[ARIA_MESH_FACES] = torch.tensor(mesh.faces)\n        # normals for pred should be minus due to marching cube\n        snippet[ARIA_MESH_VERT_NORMS_W] = -torch.tensor(mesh.vertex_normals)\n\n    scene_imgs = draw_snippet_scene_3d(\n        snippet, sem_ids_to_names=sem_ids_to_names, width=w, height=h, scene=scene\n    )\n    return scene_imgs\n\n\ndef render_views(snippet, h, w, pred_sem_ids_to_names, gt_sem_ids_to_names):\n    Ts_sr = snippet[ARIA_IMG_T_SNIPPET_RIG[0]]\n    cams = snippet[ARIA_CALIB[0]]\n    T_ws = snippet[ARIA_SNIPPET_T_WORLD_SNIPPET]\n    Ts_wr = T_ws @ Ts_sr\n    rgb_ts = snippet[ARIA_IMG_TIME_NS[0]]\n    time_s = [f\"{ts.item() * 1e-9:.02f}s\" for ts in rgb_ts]\n\n    imgs = {}\n    # RGB and SLAM\n    rgb_imgs = snippet[ARIA_IMG[0]].clone().numpy()\n    rgb_imgs = [\n        torch2cv2(im, rotate=True, ensure_rgb=True, rgb2bgr=False) for im in rgb_imgs\n    ]\n    imgs[VIZ_RGB] = rgb_imgs\n\n    if ARIA_IMG[1] in snippet and ARIA_IMG[2] in snippet:\n        slaml_imgs = snippet[ARIA_IMG[1]].clone().numpy()\n        slamr_imgs = snippet[ARIA_IMG[2]].clone().numpy()\n        slaml_imgs = [\n            torch2cv2(im, rotate=True, ensure_rgb=True, rgb2bgr=False)\n            for im in slaml_imgs\n        ]\n        slamr_imgs = [\n            torch2cv2(im, rotate=True, ensure_rgb=True, rgb2bgr=False)\n            for im in slamr_imgs\n        ]\n        imgs[VIZ_SLAM] = []\n        for iml, imr in zip(slaml_imgs, slamr_imgs):\n            imgs[VIZ_SLAM].append(np.concatenate([iml, imr], axis=1))\n\n    if ARIA_OBB_PRED_VIZ in snippet:\n        imgs[VIZ_PRED_OBB] = draw_obbs_snippet(\n            snippet[ARIA_IMG[0]].clone(),\n            snippet[ARIA_OBB_PRED_VIZ].transform(T_ws),\n            Ts_wr,\n            cams,\n            
rgb2bgr=False,\n            draw_cosy=False,\n            white_backing_line=False,\n            draw_bb2=False,\n            sem_id_to_name_mapping=pred_sem_ids_to_names,\n            draw_label=True,\n            draw_score=True,\n            prob_threshold=0.001,  # keep this very low, obbs are already thresholded.\n        )\n\n    if ARIA_OBB_TRACKED in snippet:\n        imgs[VIZ_TRACKED_OBB] = draw_obbs_snippet(\n            snippet[ARIA_IMG[0]].clone(),\n            snippet[ARIA_OBB_TRACKED].transform(T_ws),\n            Ts_wr,\n            cams,\n            rgb2bgr=False,\n            draw_cosy=False,\n            white_backing_line=False,\n            draw_bb2=False,\n            sem_id_to_name_mapping=pred_sem_ids_to_names,\n            draw_label=True,\n            draw_score=True,\n            prob_threshold=0.001,  # keep this very low, obbs are already thresholded.\n        )\n\n    if ARIA_OBB_PADDED in snippet:\n        # if gt obb (VIZ_GT_OBB) is present, overlay it on top of the RGB view\n        imgs[VIZ_RGB] = draw_obbs_snippet(\n            snippet[ARIA_IMG[0]].clone(),\n            snippet[ARIA_OBB_PADDED].transform(T_ws),\n            Ts_wr,\n            cams,\n            rgb2bgr=False,\n            draw_cosy=False,\n            white_backing_line=False,\n            draw_bb2=False,\n            sem_id_to_name_mapping=gt_sem_ids_to_names,\n            draw_label=True,\n            draw_inst_id=True,\n            draw_score=True,\n        )\n\n    # add text to the images\n    for text, grid_imgs in imgs.items():\n        for i, img in enumerate(grid_imgs):\n            img = smart_resize(img, h, w, pad_image=True)\n            img = put_text(img, text)\n            imgs[text][i] = put_text(img, time_s[i], line=-1)\n    return imgs\n\n\ndef generate_video(\n    streamer,\n    output_dir,\n    fps=10,\n    vol_fusion: Optional[VolumetricFusion] = None,\n    stride_s: float = 0.1,\n):\n    \"\"\"\n    streamer: the data iterator, assuming 
input snippets are 1s at 10 FPS.\n    output_dir: the output folder for the video, will also load obbs and per_snip artifacts from the same folder\n    fps: the output video fps\n    vol_fusion: A volumetric fusion class instance. If not None, will use it to show the incremental mesh, refreshed once per 1s input snippet.\n    stride_s: stride in seconds between the snippets consumed by vol_fusion; int(1.0 / stride_s) fusion steps are run per input snippet.\n    \"\"\"\n\n    # read snippet obbs\n    snip_obbs_csv = os.path.join(output_dir, \"snippet_obbs.csv\")\n    snip_obbs = None\n    sem_ids_to_names = None\n    if os.path.exists(snip_obbs_csv):\n        snip_obb_reader = ObbCsvReader(snip_obbs_csv)\n        snip_obbs = snip_obb_reader.obbs\n        sem_ids_to_names = snip_obb_reader.sem_ids_to_names\n\n    # read tracked obbs\n    tracked_obbs_csv = os.path.join(output_dir, \"tracked_obbs.csv\")\n    tracked_obbs = None\n    if os.path.exists(tracked_obbs_csv):\n        tracked_obb_reader = ObbCsvReader(tracked_obbs_csv)\n        tracked_obbs = tracked_obb_reader.obbs\n\n    # read GT obbs\n    gt_obbs_csv = os.path.join(output_dir, \"gt_obbs.csv\")\n    gt_obbs = None\n    gt_sem_ids_to_names = None\n    if os.path.exists(gt_obbs_csv):\n        gt_obb_reader = ObbCsvReader(gt_obbs_csv)\n        gt_obbs = gt_obb_reader.obbs\n        gt_sem_ids_to_names = gt_obb_reader.sem_ids_to_names\n\n    # read fused mesh\n    fused_mesh = os.path.join(output_dir, \"fused_mesh.ply\")\n    pred_mesh = None\n    if os.path.exists(fused_mesh):\n        pred_mesh = trimesh.load(fused_mesh)\n\n    # write video\n    fourcc = cv2.VideoWriter_fourcc(*\"mp4v\")\n    output_path = os.path.join(output_dir, \"video.mp4\")\n\n    # two columns for 2d views (RGB+SLAM, output), 1 column for 3d scene\n    gW, gH = 360, 360  # 2d grid size\n    sH = 2 * gH\n    sW = sH\n    W = sW + 2 * gW\n    H = sH\n\n    out = cv2.VideoWriter(output_path, fourcc, fps, (W, H))\n    scene = SceneView(width=sW, height=sH)\n    num_snip_per_s = int(1.0 / stride_s)\n    for idx, snippet in tqdm.tqdm(enumerate(streamer), total=len(streamer)):\n        # 
show incremental fusion if vol_fusion is given\n        if vol_fusion is not None:\n            for i in range(num_snip_per_s):\n                vol_fusion.run_step(idx * num_snip_per_s + i)\n            pred_mesh = vol_fusion.get_trimesh()\n\n        scene_imgs = draw_scene_with_mesh_and_obbs(\n            snippet,\n            w=sW,\n            h=sH,\n            scene=scene,\n            snip_obbs=snip_obbs,\n            tracked_obbs=tracked_obbs,\n            gt_obbs=gt_obbs,\n            mesh=pred_mesh,\n            sem_ids_to_names=sem_ids_to_names,\n        )\n        view_imgs = render_views(\n            snippet,\n            gH,\n            gW,\n            pred_sem_ids_to_names=sem_ids_to_names,\n            gt_sem_ids_to_names=gt_sem_ids_to_names,\n        )\n\n        input_col = compose_views(view_imgs, [VIZ_RGB, VIZ_SLAM])\n        output_col = compose_views(view_imgs, [VIZ_PRED_OBB, VIZ_TRACKED_OBB])\n\n        for i, scene_img in enumerate(scene_imgs):\n            final_img = np.zeros((H, W, 3), dtype=np.uint8)  # black background\n            h, w = input_col[i].shape[:2]\n            final_img[:h, :w] = input_col[i]\n            final_img[:sH, gW : gW + sW, :] = scene_img\n            if output_col is not None:\n                h, w = output_col[i].shape[:2]\n                final_img[:h, gW + sW : gW + sW + w] = output_col[i]\n\n            out.write(final_img[:, :, ::-1])  # convert rgb to bgr before writing\n    out.release()\n    return output_path\n"
  },
  {
    "path": "efm3d/model/__init__.py",
    "content": ""
  },
  {
    "path": "efm3d/model/cnn.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport math\n\nimport einops\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef cnn_weight_initialization(modules):\n    for m in modules:\n        if isinstance(m, nn.Conv2d):\n            nn.init.kaiming_uniform_(m.weight.data, nonlinearity=\"relu\")\n            if m.bias is not None:\n                nn.init.constant_(m.bias.data, 0)\n        elif isinstance(m, nn.BatchNorm2d):\n            nn.init.constant_(m.weight.data, 1)\n            nn.init.constant_(m.bias.data, 0)\n        elif isinstance(m, nn.Linear):\n            nn.init.kaiming_uniform_(m.weight.data)\n            nn.init.constant_(m.bias.data, 0)\n\n\nclass GELU(nn.Module):\n    def forward(self, x):\n        return F.gelu(x)\n\n\nclass LayerNorm2d(nn.LayerNorm):\n    \"\"\"LayerNorm for channels of '2D' spatial NCHW tensors, taken from\n    https://github.com/huggingface/pytorch-image-models/blob/d7b55a9429f3d56a991e604cbc2e9fdf1901612f/timm/models/layers/norm.py#L26\n    \"\"\"\n\n    def __init__(self, num_channels, eps=1e-6, affine=True):\n        super().__init__(num_channels, eps=eps, elementwise_affine=affine)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        return F.layer_norm(\n            x.permute(0, 2, 3, 1),\n            self.normalized_shape,\n            self.weight,\n            self.bias,\n            self.eps,\n  
      ).permute(0, 3, 1, 2)\n\n\nclass UpsampleCNN(nn.Module):\n    def __init__(\n        self,\n        input_dim: int = 3,\n        first_hidden_dim: int = 32,\n        final_dim: int = 1,\n        upsample_power: int = 4,\n        fix_hidden_dim: bool = True,\n    ):\n        \"\"\"\n        Upsample a feature map by a given factor = 2^upsample_power\n\n        Args:\n            input_dim (int): number of input channels\n            first_hidden_dim (int): the first hidden layer output dimension. If set to -1, we use the input dimension.\n            final_dim (int): number of output channels\n            upsample_power (int): 2^upsample_power is the factor of image resolution upsampling\n            fix_hidden_dim (bool): if True, all layers have the same hidden dims. Otherwise, hidden dims are subsequently halved by 2x starting from first_hidden_dim\n        \"\"\"\n        super(UpsampleCNN, self).__init__()\n        assert upsample_power <= 4, \"only upsampling power <= 4 is supported\"\n\n        if fix_hidden_dim:\n            # all layers have the same hidden dims\n            c = [first_hidden_dim] * (upsample_power + 1)\n        else:\n            first_hidden_dim = first_hidden_dim if first_hidden_dim > 0 else input_dim\n            assert first_hidden_dim // 2 ** (upsample_power) >= 1, (\n                f\"first_hidden_dim must be at least {2 ** (upsample_power)}, but got {first_hidden_dim}.\"\n            )\n            # subsequently halve the hidden dim by 2x\n            c = [first_hidden_dim] + [\n                first_hidden_dim // 2 ** (i + 1) for i in range(upsample_power)\n            ]\n\n        self.conv1 = nn.Conv2d(input_dim, c[0], kernel_size=3, stride=1, padding=1)\n        self.bn1 = nn.BatchNorm2d(c[0])\n\n        if upsample_power >= 1:\n            self.conv1u = nn.Conv2d(c[0], c[1], kernel_size=3, stride=1, padding=1)\n            self.bn1u = nn.BatchNorm2d(c[1])\n        if upsample_power >= 2:\n            self.conv2u = 
nn.Conv2d(c[1], c[2], kernel_size=3, stride=1, padding=1)\n            self.bn2u = nn.BatchNorm2d(c[2])\n        if upsample_power >= 3:\n            self.conv3u = nn.Conv2d(c[2], c[3], kernel_size=3, stride=1, padding=1)\n            self.bn3u = nn.BatchNorm2d(c[3])\n        if upsample_power >= 4:\n            self.conv4u = nn.Conv2d(c[3], c[4], kernel_size=3, stride=1, padding=1)\n            self.bn4u = nn.BatchNorm2d(c[4])\n        self.conv_final = nn.Conv2d(\n            c[-1], final_dim, kernel_size=1, stride=1, padding=0\n        )\n        self.relu = nn.ReLU(inplace=True)\n        self.upsample = nn.Upsample(scale_factor=2, mode=\"bilinear\", align_corners=True)\n        self.upsample_power = upsample_power\n        cnn_weight_initialization(self.modules())\n\n        print(f\"==> [UpsampleCNN]: initialized with hidden layers: {c}\")\n\n    def forward(self, x, force_hw=None):\n        \"\"\"\n        Inputs:\n            x : torch.Tensor : Bx(T)xCxhxw tensor\n            force_hw: (int, int) : tuple of ints of height and width to be forced upsampled to\n        Returns:\n            x : torch.Tensor: Upsampled to Bx(T)xCxHxW, where H = h * 2**upsample_power and W = w * 2**upsample_power (or H, W = force_hw if given)\n        \"\"\"\n        ndim = x.ndim\n        if ndim == 5:\n            T = x.shape[1]\n            x = einops.rearrange(x, \"b t c h w -> (b t) c h w\")\n        x = self.relu(self.bn1(self.conv1(x)))\n        if self.upsample_power >= 1:\n            x = self.upsample(x)\n            x = self.relu(self.bn1u(self.conv1u(x)))\n        if self.upsample_power >= 2:\n            x = self.upsample(x)\n            x = self.relu(self.bn2u(self.conv2u(x)))\n        if self.upsample_power >= 3:\n            x = self.upsample(x)\n            x = self.relu(self.bn3u(self.conv3u(x)))\n        if self.upsample_power >= 4:\n            x = self.upsample(x)\n            x = self.relu(self.bn4u(self.conv4u(x)))\n\n        # Force upsampling, useful for patch_size=14 ViTs for 
example.\n        if force_hw is not None and (\n            x.shape[-2] != force_hw[0] or x.shape[-1] != force_hw[1]\n        ):\n            x = torch.nn.functional.interpolate(x, size=force_hw, mode=\"bilinear\")\n\n        x = self.conv_final(x)\n\n        if ndim == 5:\n            x = einops.rearrange(x, \"(b t) c h w -> b t c h w\", t=T)\n        return x\n\n\nclass LayerNorm3d(nn.LayerNorm):\n    \"\"\"LayerNorm for channels of '3D' spatial NCDHW tensors, taken from\n    https://github.com/huggingface/pytorch-image-models/blob/d7b55a9429f3d56a991e604cbc2e9fdf1901612f/timm/models/layers/norm.py#L26\n    \"\"\"\n\n    def __init__(self, num_channels, eps=1e-6, affine=True):\n        super().__init__(num_channels, eps=eps, elementwise_affine=affine)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        return F.layer_norm(\n            x.permute(0, 2, 3, 4, 1),  # NCDHW -> NDHWC\n            self.normalized_shape,\n            self.weight,\n            self.bias,\n            self.eps,\n        ).permute(0, 4, 1, 2, 3)\n\n\nclass UpConv3d(torch.nn.Module):\n    def __init__(self, dim_in, dim_out):\n        super().__init__()\n        self.dim_in = dim_in\n        self.dim_out = dim_out\n\n        self.upsample = torch.nn.Upsample(\n            scale_factor=2, mode=\"trilinear\", align_corners=True\n        )\n        self.cnn_up = torch.nn.Conv3d(dim_in, dim_out, 3, stride=1, padding=1)\n\n        self.norm = torch.nn.BatchNorm3d(dim_out)\n        cnn_weight_initialization(self.modules())\n\n    def forward(self, x_up):\n        assert x_up.shape[1] == self.dim_in, f\"{x_up.shape}, {self.dim_in}\"\n        x_up = self.upsample(x_up)\n        return self.norm(self.cnn_up(x_up))\n\n\nclass FpnUpConv3d(torch.nn.Module):\n    def __init__(self, dim_in, dim_out):\n        super().__init__()\n        self.dim_in = dim_in\n        self.dim_out = dim_out\n\n        self.upsample = torch.nn.Upsample(\n            scale_factor=2, mode=\"trilinear\", 
align_corners=True\n        )\n        self.cnn_up = torch.nn.Conv3d(dim_in, dim_out, 3, stride=1, padding=1)\n        self.cnn_lat = torch.nn.Conv3d(dim_out, dim_in, 1)\n\n        self.norm = torch.nn.BatchNorm3d(dim_out)\n        cnn_weight_initialization(self.modules())\n\n    def forward(self, x_up, x_lat):\n        assert x_up.shape[1] == self.dim_in, f\"{x_up.shape}, {self.dim_in}\"\n        assert x_lat.shape[1] == self.dim_out, f\"{x_lat.shape}, {self.dim_out}\"\n        x_up = self.upsample(x_up)\n        x_lat = self.cnn_lat(x_lat)\n        return self.norm(self.cnn_up(x_up + x_lat))\n\n\nclass InvBottleNeck3d(torch.nn.Module):\n    def __init__(self, dim_in, dim_out, stride: int = 1, expansion: float = 1.0):\n        super().__init__()\n        self.dim_hidden = int(math.floor(dim_in * expansion))\n        self.stride = stride\n        self.dim_in = dim_in\n        self.dim_out = dim_out\n\n        self.relu = torch.nn.ReLU()\n        self.norm = torch.nn.BatchNorm3d(self.dim_out)\n\n        self.cnn1 = torch.nn.Conv3d(dim_in, self.dim_hidden, 1)\n        self.cnn2 = torch.nn.Conv3d(\n            self.dim_hidden, self.dim_hidden, 3, stride=stride, padding=1\n        )\n        self.cnn3 = torch.nn.Conv3d(self.dim_hidden, dim_out, 1)\n        cnn_weight_initialization(self.modules())\n\n    def forward(self, x):\n        y = self.relu(self.cnn1(x))\n        y = self.relu(self.cnn2(y))\n        y = self.cnn3(y)\n        if self.stride != 1 or self.dim_in != self.dim_out:\n            return self.norm(y)\n        return self.norm(y + x)\n\n\nclass InvResnetBlock3d(torch.nn.Module):\n    def __init__(\n        self, dim_in, dim_out, num_bottles, in_stride: int = 1, expansion: float = 1.0\n    ):\n        super().__init__()\n        self.inv_bottles = torch.nn.ModuleList(\n            [InvBottleNeck3d(dim_in, dim_out, in_stride, expansion)]\n        )\n        for _ in range(1, num_bottles):\n            self.inv_bottles.append(InvBottleNeck3d(dim_out, 
dim_out, 1, expansion))\n\n        self.num_bottles = num_bottles\n\n    def forward(self, x):\n        for i in range(self.num_bottles):\n            x = self.inv_bottles[i](x)\n        return x\n\n\nclass InvResnetFpn3d(torch.nn.Module):\n    def __init__(self, dims, num_bottles, strides, expansions, freeze=False):\n        super().__init__()\n\n        assert len(dims) == len(num_bottles) + 1\n        assert len(dims) == len(strides) + 1\n        assert len(dims) == len(expansions) + 1\n        assert strides[0] == 1\n        assert all([s == 2 for s in strides[1:]])\n\n        self.block1 = InvResnetBlock3d(\n            dims[0], dims[1], num_bottles[0], strides[0], expansions[0]\n        )\n        self.block2 = InvResnetBlock3d(\n            dims[1], dims[2], num_bottles[1], strides[1], expansions[1]\n        )\n        self.block3 = InvResnetBlock3d(\n            dims[2], dims[3], num_bottles[2], strides[2], expansions[2]\n        )\n        self.block4 = InvResnetBlock3d(\n            dims[3], dims[4], num_bottles[3], strides[3], expansions[3]\n        )\n        self.fpn1 = FpnUpConv3d(dims[2], dims[1])\n        self.fpn2 = FpnUpConv3d(dims[3], dims[2])\n        self.fpn3 = FpnUpConv3d(dims[4], dims[3])\n\n        if freeze:\n            for param in self.parameters():\n                param.requires_grad = False\n            self.eval()\n\n    def forward(self, x):\n        x1 = self.block1(x)\n        x2 = self.block2(x1)\n        x3 = self.block3(x2)\n        x = self.block4(x3)\n\n        x = self.fpn3(x, x3)\n        del x3\n        x = self.fpn2(x, x2)\n        del x2\n        x = self.fpn1(x, x1)\n        del x1\n        return x\n\n\nclass VolumeCNN(nn.Module):\n    \"\"\"A 3d UNet structure that takes in a `hidden_dims` vector (e.g. [c0, c1, c2, c3],\n    c0 <= c1 <= c2 <= c3). 
It outputs a shared feature layer with ReLU and BN applied.\n    The shape on the channel dimension looks like c0->c1->c2->c3->c2->c1 (the output has c1 channels, see `out_dim`).\n    \"\"\"\n\n    def __init__(self, hidden_dims, conv3=nn.Conv3d, freeze=False):\n        super(VolumeCNN, self).__init__()\n        self.relu = nn.ReLU(inplace=True)\n        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)\n        self.upsample = nn.Upsample(\n            scale_factor=2, mode=\"trilinear\", align_corners=True\n        )\n\n        c0, c1, c2, c3 = tuple(hidden_dims)\n        self.conv1 = conv3(c0, c1, kernel_size=3, stride=1, padding=1)\n        self.conv2 = conv3(c1, c2, kernel_size=3, stride=1, padding=1)\n        self.conv3 = conv3(c2, c3, kernel_size=3, stride=1, padding=1)\n        self.conv2u = conv3(c2 + c3, c2, kernel_size=3, stride=1, padding=1)\n        self.conv1u = conv3(c1 + c2, c1, kernel_size=3, stride=1, padding=1)\n        self.bn1 = nn.BatchNorm3d(c1)\n        self.bn2 = nn.BatchNorm3d(c2)\n        self.bn3 = nn.BatchNorm3d(c3)\n        self.bn2u = nn.BatchNorm3d(c2)\n        self.bn1u = nn.BatchNorm3d(c1)\n\n        cnn_weight_initialization(self.modules())\n        self.out_dim = c1\n\n        if freeze:\n            for param in self.parameters():\n                param.requires_grad = False\n            self.eval()\n\n    def forward(self, x):\n        # Simple U-Net like structure.\n        conv1 = self.relu(self.bn1(self.conv1(x)))\n        x = self.pool(conv1)\n        conv2 = self.relu(self.bn2(self.conv2(x)))\n        x = self.pool(conv2)\n        x = self.relu(self.bn3(self.conv3(x)))\n        x = self.upsample(x)\n        x = torch.cat([x, conv2], dim=1)\n        x = self.relu(self.bn2u(self.conv2u(x)))\n        x = self.upsample(x)\n        x = torch.cat([x, conv1], dim=1)\n        x = self.relu(self.bn1u(self.conv1u(x)))\n        return x\n\n\nclass VolumeCNNHead(nn.Module):\n    def __init__(\n        self,\n        input_dim,\n        hidden_dim,\n        final_dim,\n       
 num_layers=2,\n        name=\"\",\n        bias=None,\n        freeze=False,\n    ):\n        super(VolumeCNNHead, self).__init__()\n        self.num_layers = num_layers\n        self.relu = nn.ReLU(inplace=True)\n\n        assert num_layers in [2, 3, 4], f\"num_layers {num_layers} must be 2, 3, or 4\"\n\n        # first conv layer is the same for all num_layers = {2,3,4}\n        self.conv1 = torch.nn.Conv3d(\n            input_dim, hidden_dim, kernel_size=3, stride=1, padding=1\n        )\n        self.bn1 = nn.BatchNorm3d(hidden_dim)\n\n        if num_layers == 2:\n            self.conv2 = torch.nn.Conv3d(hidden_dim, final_dim, kernel_size=1)\n        elif num_layers == 3:\n            self.conv2 = torch.nn.Conv3d(\n                hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1\n            )\n            self.conv3 = torch.nn.Conv3d(hidden_dim, final_dim, kernel_size=1)\n            self.bn2 = nn.BatchNorm3d(hidden_dim)\n        elif num_layers == 4:\n            self.conv2 = torch.nn.Conv3d(\n                hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1\n            )\n            self.conv3 = torch.nn.Conv3d(\n                hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1\n            )\n            self.conv4 = torch.nn.Conv3d(hidden_dim, final_dim, kernel_size=1)\n            self.bn2 = nn.BatchNorm3d(hidden_dim)\n            self.bn3 = nn.BatchNorm3d(hidden_dim)\n\n        cnn_weight_initialization(self.modules())\n        model_msg = f\"==> Init {num_layers}-layer 3DCNN with {hidden_dim} hidden dims and {final_dim} outputs\"\n        if name:\n            model_msg += f\", with name {name}\"\n        print(model_msg)\n\n        if bias:\n            print(\"overwriting bias to %f\" % bias)\n            self.conv2.bias.data.fill_(bias)\n\n        if freeze:\n            for param in self.parameters():\n                param.requires_grad = False\n            self.eval()\n\n    def forward(self, x):\n        x = 
self.relu(self.bn1(self.conv1(x)))\n\n        if self.num_layers == 2:\n            x = self.conv2(x)\n        elif self.num_layers == 3:\n            x = self.relu(self.bn2(self.conv2(x)))\n            x = self.conv3(x)\n        elif self.num_layers == 4:\n            x = self.relu(self.bn2(self.conv2(x)))\n            x = self.relu(self.bn3(self.conv3(x)))\n            x = self.conv4(x)\n        return x\n\n\nclass ResidualConvUnit3d(nn.Module):\n    # From \"Vision Transformers for Dense Prediction\": https://arxiv.org/abs/2103.13413\n    # adapted from https://github.com/isl-org/DPT/blob/main/dpt/blocks.py\n    def __init__(self, features, kernel_size):\n        super().__init__()\n        # an odd kernel with padding = kernel_size // 2 preserves the spatial shape,\n        # which the residual addition in forward() relies on\n        assert kernel_size % 2 == 1, \"Kernel size needs to be odd\"\n        padding = kernel_size // 2\n        self.conv = nn.Sequential(\n            nn.Conv3d(features, features, kernel_size, padding=padding),\n            nn.ReLU(True),\n            nn.Conv3d(features, features, kernel_size, padding=padding),\n            nn.ReLU(True),\n        )\n\n    def forward(self, x):\n        return self.conv(x) + x\n\n\nclass FeatureFusionBlock3d(nn.Module):\n    # From \"Vision Transformers for Dense Prediction\": https://arxiv.org/abs/2103.13413\n    # adapted from https://github.com/isl-org/DPT/blob/main/dpt/blocks.py\n    def __init__(self, features, kernel_size, with_skip=True):\n        super().__init__()\n        self.with_skip = with_skip\n        if self.with_skip:\n            self.resConfUnit1 = ResidualConvUnit3d(features, kernel_size)\n\n        self.resConfUnit2 = ResidualConvUnit3d(features, kernel_size)\n\n    def forward(self, x, skip_x=None):\n        if skip_x is not None:\n            assert self.with_skip and skip_x.shape == x.shape\n            x = self.resConfUnit1(x) + skip_x\n\n        x = self.resConfUnit2(x)\n        return x\n\n\nclass VolumeResnet(nn.Module):\n    def __init__(self, hidden_dims, conv3=nn.Conv3d, freeze=False):\n        super(VolumeResnet, 
self).__init__()\n        self.relu = nn.ReLU(inplace=True)\n        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)\n        self.upsample = nn.Upsample(\n            scale_factor=2, mode=\"trilinear\", align_corners=True\n        )\n\n        c0, c1, c2, c3 = tuple(hidden_dims)\n        self.resconv1 = ResidualConvUnit3d(c0, kernel_size=3)\n        self.conv1 = conv3(c0, c1, kernel_size=3, stride=1, padding=1)\n\n        self.resconv2 = ResidualConvUnit3d(c1, kernel_size=3)\n        self.conv2 = conv3(c1, c2, kernel_size=3, stride=1, padding=1)\n\n        self.resconv3 = ResidualConvUnit3d(c2, kernel_size=3)\n        self.conv3 = conv3(c2, c3, kernel_size=3, stride=1, padding=1)\n\n        self.conv2u = conv3(c2 + c3, c2, kernel_size=3, stride=1, padding=1)\n        self.conv1u = conv3(c1 + c2, c1, kernel_size=3, stride=1, padding=1)\n        self.bn1 = nn.BatchNorm3d(c1)\n        self.bn2 = nn.BatchNorm3d(c2)\n        self.bn3 = nn.BatchNorm3d(c3)\n        self.bn2u = nn.BatchNorm3d(c2)\n        self.bn1u = nn.BatchNorm3d(c1)\n\n        cnn_weight_initialization(self.modules())\n        self.out_dim = c1\n\n        if freeze:\n            for param in self.parameters():\n                param.requires_grad = False\n            self.eval()\n\n    def forward(self, x):\n        # Simple U-Net like structure.\n        conv1 = self.relu(self.bn1(self.conv1(self.resconv1(x))))\n        x = self.pool(conv1)\n        conv2 = self.relu(self.bn2(self.conv2(self.resconv2(x))))\n        x = self.pool(conv2)\n        x = self.relu(self.bn3(self.conv3(self.resconv3(x))))\n        x = self.upsample(x)\n        x = torch.cat([x, conv2], dim=1)\n        x = self.relu(self.bn2u(self.conv2u(x)))\n        x = self.upsample(x)\n        x = torch.cat([x, conv1], dim=1)\n        x = self.relu(self.bn1u(self.conv1u(x)))\n        return x\n"
  },
  {
    "path": "efm3d/model/dinov2_utils.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# References:\n#   https://github.com/facebookresearch/dino/blob/main/vision_transformer.py\n#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py\n#   https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py\n\nimport logging\nimport math\nimport os\nfrom functools import partial\nfrom typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport torchvision.transforms as T\nfrom torch import Tensor\nfrom torch.nn.init import trunc_normal_\nfrom torch.nn.utils import weight_norm\n\n\nlogger = logging.getLogger(\"dinov2\")\n\n\ntry:\n    from xformers.ops import (\n        fmha,\n        index_select_cat,\n        memory_efficient_attention,\n        scaled_index_add,\n        unbind,\n    )\n\n    XFORMERS_AVAILABLE = True\nexcept ImportError:\n    logger.warning(\"xFormers not available\")\n    XFORMERS_AVAILABLE = False\n\n\nclass Attention(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        num_heads: int = 8,\n        qkv_bias: bool = False,\n        proj_bias: bool = True,\n        attn_drop: float = 0.0,\n        proj_drop: float = 0.0,\n    ) -> None:\n        super().__init__()\n        self.num_heads = num_heads\n  
      head_dim = dim // num_heads\n        self.scale = head_dim**-0.5\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim, bias=proj_bias)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x: Tensor) -> Tensor:\n        B, N, C = x.shape\n        qkv = (\n            self.qkv(x)\n            .reshape(B, N, 3, self.num_heads, C // self.num_heads)\n            .permute(2, 0, 3, 1, 4)\n        )\n\n        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]\n        attn = q @ k.transpose(-2, -1)\n\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass MemEffAttention(Attention):\n    def forward(self, x: Tensor, attn_bias=None) -> Tensor:\n        if not XFORMERS_AVAILABLE:\n            assert attn_bias is None, \"xFormers is required for nested tensors usage\"\n            return super().forward(x)\n\n        B, N, C = x.shape\n        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)\n\n        q, k, v = unbind(qkv, 2)\n\n        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)\n        x = x.reshape([B, N, C])\n\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass CrossAttention(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        num_heads: int = 8,\n        qkv_bias: bool = False,\n        proj_bias: bool = True,\n        attn_drop: float = 0.0,\n        proj_drop: float = 0.0,\n    ) -> None:\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = head_dim**-0.5\n\n        self.q = nn.Linear(dim, dim, bias=qkv_bias)\n        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        
self.proj = nn.Linear(dim, dim, bias=proj_bias)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x_kv: Tensor, x_q: Tensor) -> Tensor:\n        B, N, C = x_kv.shape\n        kv = (\n            self.kv(x_kv)\n            .reshape(B, N, 2, self.num_heads, C // self.num_heads)\n            .permute(2, 0, 3, 1, 4)\n        )\n        B, N, C = x_q.shape\n        q = (\n            self.q(x_q)\n            .reshape(B, N, 1, self.num_heads, C // self.num_heads)\n            .permute(2, 0, 3, 1, 4)\n        )\n\n        q, k, v = q[0] * self.scale, kv[0], kv[1]\n        attn = q @ k.transpose(-2, -1)\n\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass MemEffCrossAttention(CrossAttention):\n    def forward(self, x_kv: Tensor, x_q: Tensor, attn_bias=None) -> Tensor:\n        if not XFORMERS_AVAILABLE:\n            assert attn_bias is None, \"xFormers is required for nested tensors usage\"\n            return super().forward(x_kv, x_q)\n\n        B, N, C = x_kv.shape\n        kv = self.kv(x_kv).reshape(B, N, 2, self.num_heads, C // self.num_heads)\n        k, v = unbind(kv, 2)\n\n        B, N, C = x_q.shape\n        q = self.q(x_q).reshape(B, N, self.num_heads, C // self.num_heads)\n\n        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)\n        x = x.reshape([B, N, C])\n\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass Mlp(nn.Module):\n    def __init__(\n        self,\n        in_features: int,\n        hidden_features: Optional[int] = None,\n        out_features: Optional[int] = None,\n        act_layer: Callable[..., nn.Module] = nn.GELU,\n        drop: float = 0.0,\n        bias: bool = True,\n    ) -> None:\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features 
or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x: Tensor) -> Tensor:\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass Block(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        num_heads: int,\n        mlp_ratio: float = 4.0,\n        qkv_bias: bool = False,\n        proj_bias: bool = True,\n        ffn_bias: bool = True,\n        drop: float = 0.0,\n        attn_drop: float = 0.0,\n        init_values=None,\n        drop_path: float = 0.0,\n        act_layer: Callable[..., nn.Module] = nn.GELU,\n        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,\n        attn_class: Callable[..., nn.Module] = Attention,\n        ffn_layer: Callable[..., nn.Module] = Mlp,\n    ) -> None:\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = attn_class(\n            dim,\n            num_heads=num_heads,\n            qkv_bias=qkv_bias,\n            proj_bias=proj_bias,\n            attn_drop=attn_drop,\n            proj_drop=drop,\n        )\n        self.ls1 = (\n            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()\n        )\n        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = ffn_layer(\n            in_features=dim,\n            hidden_features=mlp_hidden_dim,\n            act_layer=act_layer,\n            drop=drop,\n            bias=ffn_bias,\n        )\n        self.ls2 = (\n            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()\n        )\n        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 
 else nn.Identity()\n\n        self.sample_drop_ratio = drop_path\n\n    def forward(self, x: Tensor) -> Tensor:\n        def attn_residual_func(x: Tensor) -> Tensor:\n            return self.ls1(self.attn(self.norm1(x)))\n\n        def ffn_residual_func(x: Tensor) -> Tensor:\n            return self.ls2(self.mlp(self.norm2(x)))\n\n        if self.training and self.sample_drop_ratio > 0.1:\n            # the overhead is compensated only for a drop path rate larger than 0.1\n            x = drop_add_residual_stochastic_depth(\n                x,\n                residual_func=attn_residual_func,\n                sample_drop_ratio=self.sample_drop_ratio,\n            )\n            x = drop_add_residual_stochastic_depth(\n                x,\n                residual_func=ffn_residual_func,\n                sample_drop_ratio=self.sample_drop_ratio,\n            )\n        elif self.training and self.sample_drop_ratio > 0.0:\n            x = x + self.drop_path1(attn_residual_func(x))\n            x = x + self.drop_path2(ffn_residual_func(x))\n        else:\n            x = x + attn_residual_func(x)\n            x = x + ffn_residual_func(x)\n        return x\n\n\nclass CrossBlock(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        num_heads: int,\n        mlp_ratio: float = 4.0,\n        qkv_bias: bool = False,\n        proj_bias: bool = True,\n        ffn_bias: bool = True,\n        drop: float = 0.0,\n        attn_drop: float = 0.0,\n        init_values=None,\n        drop_path: float = 0.0,\n        act_layer: Callable[..., nn.Module] = nn.GELU,\n        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,\n        # forward() calls self.attn(x_kv, x_q), so the default must be a cross-attention module\n        attn_class: Callable[..., nn.Module] = CrossAttention,\n        ffn_layer: Callable[..., nn.Module] = Mlp,\n    ) -> None:\n        super().__init__()\n        self.norm1_q = norm_layer(dim)\n        self.norm1_kv = norm_layer(dim)\n        self.attn = attn_class(\n            dim,\n            num_heads=num_heads,\n    
        qkv_bias=qkv_bias,\n            proj_bias=proj_bias,\n            attn_drop=attn_drop,\n            proj_drop=drop,\n        )\n        self.ls1 = (\n            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()\n        )\n        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = ffn_layer(\n            in_features=dim,\n            hidden_features=mlp_hidden_dim,\n            act_layer=act_layer,\n            drop=drop,\n            bias=ffn_bias,\n        )\n        self.ls2 = (\n            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()\n        )\n        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n\n        self.sample_drop_ratio = drop_path\n\n    def forward(self, x_kv: Tensor, x_q: Tensor) -> Tensor:\n        def attn_residual_func(x_kv: Tensor, x_q: Tensor) -> Tensor:\n            return self.ls1(self.attn(self.norm1_kv(x_kv), self.norm1_q(x_q)))\n\n        def ffn_residual_func(x_q: Tensor) -> Tensor:\n            return self.ls2(self.mlp(self.norm2(x_q)))\n\n        if self.training and self.sample_drop_ratio > 0.0:\n            x_q = x_q + self.drop_path1(attn_residual_func(x_kv, x_q))\n            x_q = x_q + self.drop_path2(ffn_residual_func(x_q))\n        else:\n            x_q = x_q + attn_residual_func(x_kv, x_q)\n            x_q = x_q + ffn_residual_func(x_q)\n        return x_q\n\n\ndef drop_add_residual_stochastic_depth(\n    x: Tensor,\n    residual_func: Callable[[Tensor], Tensor],\n    sample_drop_ratio: float = 0.0,\n) -> Tensor:\n    # 1) extract subset using permutation\n    b, n, d = x.shape\n    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)\n    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]\n    x_subset = x[brange]\n\n    # 2) apply residual_func to get residual\n    residual 
= residual_func(x_subset)\n\n    x_flat = x.flatten(1)\n    residual = residual.flatten(1)\n\n    residual_scale_factor = b / sample_subset_size\n\n    # 3) add the residual\n    x_plus_residual = torch.index_add(\n        x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor\n    )\n    return x_plus_residual.view_as(x)\n\n\ndef get_branges_scales(x, sample_drop_ratio=0.0):\n    b, n, d = x.shape\n    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)\n    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]\n    residual_scale_factor = b / sample_subset_size\n    return brange, residual_scale_factor\n\n\ndef add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):\n    if scaling_vector is None:\n        x_flat = x.flatten(1)\n        residual = residual.flatten(1)\n        x_plus_residual = torch.index_add(\n            x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor\n        )\n    else:\n        x_plus_residual = scaled_index_add(\n            x,\n            brange,\n            residual.to(dtype=x.dtype),\n            scaling=scaling_vector,\n            alpha=residual_scale_factor,\n        )\n    return x_plus_residual\n\n\nattn_bias_cache: Dict[Tuple, Any] = {}\n\n\ndef get_attn_bias_and_cat(x_list, branges=None):\n    \"\"\"\n    this will perform the index select, cat the tensors, and provide the attn_bias from cache\n    \"\"\"\n    batch_sizes = (\n        [b.shape[0] for b in branges]\n        if branges is not None\n        else [x.shape[0] for x in x_list]\n    )\n    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))\n    if all_shapes not in attn_bias_cache.keys():\n        seqlens = []\n        for b, x in zip(batch_sizes, x_list):\n            for _ in range(b):\n                seqlens.append(x.shape[1])\n        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)\n        attn_bias._batch_sizes = batch_sizes\n        
attn_bias_cache[all_shapes] = attn_bias\n\n    if branges is not None:\n        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(\n            1, -1, x_list[0].shape[-1]\n        )\n    else:\n        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)\n        cat_tensors = torch.cat(tensors_bs1, dim=1)\n\n    return attn_bias_cache[all_shapes], cat_tensors\n\n\ndef drop_add_residual_stochastic_depth_list(\n    x_list: List[Tensor],\n    residual_func: Callable[[Tensor, Any], Tensor],\n    sample_drop_ratio: float = 0.0,\n    scaling_vector=None,\n) -> Tensor:\n    # 1) generate random set of indices for dropping samples in the batch\n    branges_scales = [\n        get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list\n    ]\n    branges = [s[0] for s in branges_scales]\n    residual_scale_factors = [s[1] for s in branges_scales]\n\n    # 2) get attention bias and index+concat the tensors\n    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)\n\n    # 3) apply residual_func to get residual, and split the result\n    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore\n\n    outputs = []\n    for x, brange, residual, residual_scale_factor in zip(\n        x_list, branges, residual_list, residual_scale_factors\n    ):\n        outputs.append(\n            add_residual(\n                x, brange, residual, residual_scale_factor, scaling_vector\n            ).view_as(x)\n        )\n    return outputs\n\n\nclass NestedTensorBlock(Block):\n    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:\n        \"\"\"\n        x_list contains a list of tensors to nest together and run\n        \"\"\"\n        assert isinstance(self.attn, MemEffAttention)\n\n        if self.training and self.sample_drop_ratio > 0.0:\n\n            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:\n                return self.attn(self.norm1(x), 
 attn_bias=attn_bias)\n\n            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:\n                return self.mlp(self.norm2(x))\n\n            x_list = drop_add_residual_stochastic_depth_list(\n                x_list,\n                residual_func=attn_residual_func,\n                sample_drop_ratio=self.sample_drop_ratio,\n                scaling_vector=(\n                    self.ls1.gamma if isinstance(self.ls1, LayerScale) else None\n                ),\n            )\n            x_list = drop_add_residual_stochastic_depth_list(\n                x_list,\n                residual_func=ffn_residual_func,\n                sample_drop_ratio=self.sample_drop_ratio,\n                scaling_vector=(\n                    self.ls2.gamma if isinstance(self.ls2, LayerScale) else None\n                ),\n            )\n            return x_list\n        else:\n\n            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:\n                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))\n\n            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:\n                return self.ls2(self.mlp(self.norm2(x)))\n\n            attn_bias, x = get_attn_bias_and_cat(x_list)\n            x = x + attn_residual_func(x, attn_bias=attn_bias)\n            x = x + ffn_residual_func(x)\n            return attn_bias.split(x)\n\n    def forward(self, x_or_x_list):\n        if isinstance(x_or_x_list, Tensor):\n            return super().forward(x_or_x_list)\n        elif isinstance(x_or_x_list, list):\n            assert XFORMERS_AVAILABLE, (\n                \"Please install xFormers for nested tensors usage\"\n            )\n            return self.forward_nested(x_or_x_list)\n        else:\n            raise AssertionError\n\n\nclass DINOHead(nn.Module):\n    def __init__(\n        self,\n        in_dim,\n        out_dim,\n        use_bn=False,\n        nlayers=3,\n        hidden_dim=2048,\n        bottleneck_dim=256,\n        
mlp_bias=True,\n    ):\n        super().__init__()\n        nlayers = max(nlayers, 1)\n        self.mlp = _build_mlp(\n            nlayers,\n            in_dim,\n            bottleneck_dim,\n            hidden_dim=hidden_dim,\n            use_bn=use_bn,\n            bias=mlp_bias,\n        )\n        self.apply(self._init_weights)\n        self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))\n        self.last_layer.weight_g.data.fill_(1)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=0.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n\n    def forward(self, x):\n        x = self.mlp(x)\n        eps = 1e-6 if x.dtype == torch.float16 else 1e-12\n        x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)\n        x = self.last_layer(x)\n        return x\n\n\ndef _build_mlp(\n    nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True\n):\n    if nlayers == 1:\n        return nn.Linear(in_dim, bottleneck_dim, bias=bias)\n    else:\n        layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]\n        if use_bn:\n            layers.append(nn.BatchNorm1d(hidden_dim))\n        layers.append(nn.GELU())\n        for _ in range(nlayers - 2):\n            layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))\n            if use_bn:\n                layers.append(nn.BatchNorm1d(hidden_dim))\n            layers.append(nn.GELU())\n        layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))\n        return nn.Sequential(*layers)\n\n\ndef drop_path(x, drop_prob: float = 0.0, training: bool = False):\n    if drop_prob == 0.0 or not training:\n        return x\n    keep_prob = 1 - drop_prob\n    shape = (x.shape[0],) + (1,) * (\n        x.ndim - 1\n    )  # work with diff dim tensors, not just 2D ConvNets\n    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)\n    if keep_prob > 
0.0:\n        random_tensor.div_(keep_prob)\n    output = x * random_tensor\n    return output\n\n\nclass DropPath(nn.Module):\n    \"\"\"Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).\"\"\"\n\n    def __init__(self, drop_prob=None):\n        super(DropPath, self).__init__()\n        self.drop_prob = drop_prob\n\n    def forward(self, x):\n        return drop_path(x, self.drop_prob, self.training)\n\n\nclass LayerScale(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        init_values: Union[float, Tensor] = 1e-5,\n        inplace: bool = False,\n    ) -> None:\n        super().__init__()\n        self.inplace = inplace\n        self.gamma = nn.Parameter(init_values * torch.ones(dim))\n\n    def forward(self, x: Tensor) -> Tensor:\n        return x.mul_(self.gamma) if self.inplace else x * self.gamma\n\n\nclass Mlp(nn.Module):\n    def __init__(\n        self,\n        in_features: int,\n        hidden_features: Optional[int] = None,\n        out_features: Optional[int] = None,\n        act_layer: Callable[..., nn.Module] = nn.GELU,\n        drop: float = 0.0,\n        bias: bool = True,\n    ) -> None:\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x: Tensor) -> Tensor:\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\ndef make_2tuple(x):\n    if isinstance(x, tuple):\n        assert len(x) == 2\n        return x\n\n    assert isinstance(x, int)\n    return (x, x)\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\"\n    2D image to patch embedding: (B,C,H,W) -> (B,N,D)\n\n    Args:\n       
 img_size: Image size.\n        patch_size: Patch token size.\n        in_chans: Number of input image channels.\n        embed_dim: Number of linear projection output channels.\n        norm_layer: Normalization layer.\n    \"\"\"\n\n    def __init__(\n        self,\n        img_size: Union[int, Tuple[int, int]] = 224,\n        patch_size: Union[int, Tuple[int, int]] = 16,\n        in_chans: int = 3,\n        embed_dim: int = 768,\n        norm_layer: Optional[Callable] = None,\n        flatten_embedding: bool = True,\n    ) -> None:\n        super().__init__()\n\n        image_HW = make_2tuple(img_size)\n        patch_HW = make_2tuple(patch_size)\n        patch_grid_size = (\n            image_HW[0] // patch_HW[0],\n            image_HW[1] // patch_HW[1],\n        )\n\n        self.img_size = image_HW\n        self.patch_size = patch_HW\n        self.patches_resolution = patch_grid_size\n        self.num_patches = patch_grid_size[0] * patch_grid_size[1]\n\n        self.in_chans = in_chans\n        self.embed_dim = embed_dim\n\n        self.flatten_embedding = flatten_embedding\n\n        self.proj = nn.Conv2d(\n            in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW\n        )\n        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()\n\n    def forward(self, x: Tensor) -> Tensor:\n        _, _, H, W = x.shape\n        patch_H, patch_W = self.patch_size\n\n        if H % patch_H > 0 or W % patch_W > 0:\n            H_new = math.ceil(H / patch_H) * patch_H\n            W_new = math.ceil(W / patch_W) * patch_W\n            x = F.interpolate(\n                x, size=(H_new, W_new), mode=\"bilinear\", align_corners=False\n            )\n\n        x = self.proj(x)  # B C H W\n        H, W = x.size(2), x.size(3)\n        x = x.flatten(2).transpose(1, 2)  # B HW C\n        x = self.norm(x)\n        if not self.flatten_embedding:\n            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C\n        return x\n\n    def flops(self) -> 
float:\n        Ho, Wo = self.patches_resolution\n        flops = (\n            Ho\n            * Wo\n            * self.embed_dim\n            * self.in_chans\n            * (self.patch_size[0] * self.patch_size[1])\n        )\n        if self.norm is not None:\n            flops += Ho * Wo * self.embed_dim\n        return flops\n\n\nclass SwiGLUFFN(nn.Module):\n    def __init__(\n        self,\n        in_features: int,\n        hidden_features: Optional[int] = None,\n        out_features: Optional[int] = None,\n        act_layer: Callable[..., nn.Module] = None,\n        drop: float = 0.0,\n        bias: bool = True,\n    ) -> None:\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)\n        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)\n\n    def forward(self, x: Tensor) -> Tensor:\n        x12 = self.w12(x)\n        x1, x2 = x12.chunk(2, dim=-1)\n        hidden = F.silu(x1) * x2\n        return self.w3(hidden)\n\n\ntry:\n    from xformers.ops import SwiGLU\n\n    XFORMERS_AVAILABLE = True\nexcept ImportError:\n    SwiGLU = SwiGLUFFN\n    XFORMERS_AVAILABLE = False\n\n\nclass SwiGLUFFNFused(SwiGLU):\n    def __init__(\n        self,\n        in_features: int,\n        hidden_features: Optional[int] = None,\n        out_features: Optional[int] = None,\n        act_layer: Callable[..., nn.Module] = None,\n        drop: float = 0.0,\n        bias: bool = True,\n    ) -> None:\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8\n        super().__init__(\n            in_features=in_features,\n            hidden_features=hidden_features,\n            out_features=out_features,\n            bias=bias,\n        )\n\n\ndef named_apply(\n    fn: Callable, module: 
nn.Module, name=\"\", depth_first=True, include_root=False\n) -> nn.Module:\n    if not depth_first and include_root:\n        fn(module=module, name=name)\n    for child_name, child_module in module.named_children():\n        child_name = \".\".join((name, child_name)) if name else child_name\n        named_apply(\n            fn=fn,\n            module=child_module,\n            name=child_name,\n            depth_first=depth_first,\n            include_root=True,\n        )\n    if depth_first and include_root:\n        fn(module=module, name=name)\n    return module\n\n\nclass BlockChunk(nn.ModuleList):\n    def forward(self, x):\n        for b in self:\n            x = b(x)\n        return x\n\n\nclass DinoVisionTransformer(nn.Module):\n    def __init__(\n        self,\n        img_size=224,\n        patch_size=16,\n        in_chans=3,\n        embed_dim=768,\n        depth=12,\n        num_heads=12,\n        mlp_ratio=4.0,\n        qkv_bias=True,\n        ffn_bias=True,\n        proj_bias=True,\n        drop_path_rate=0.0,\n        drop_path_uniform=False,\n        init_values=None,  # for layerscale: None or 0 => no layerscale\n        embed_layer=PatchEmbed,\n        act_layer=nn.GELU,\n        block_fn=Block,\n        ffn_layer=\"mlp\",\n        block_chunks=1,\n        num_register_tokens=0,\n        interpolate_antialias=False,\n        interpolate_offset=0.1,\n    ):\n        \"\"\"\n        Args:\n            img_size (int, tuple): input image size\n            patch_size (int, tuple): patch size\n            in_chans (int): number of input channels\n            embed_dim (int): embedding dimension\n            depth (int): depth of transformer\n            num_heads (int): number of attention heads\n            mlp_ratio (int): ratio of mlp hidden dim to embedding dim\n            qkv_bias (bool): enable bias for qkv if True\n            proj_bias (bool): enable bias for proj in attn if True\n            ffn_bias (bool): enable bias for ffn if True\n  
          drop_path_rate (float): stochastic depth rate\n            drop_path_uniform (bool): apply uniform drop rate across blocks\n            init_values (float): layer-scale init values\n            embed_layer (nn.Module): patch embedding layer\n            act_layer (nn.Module): MLP activation layer\n            block_fn (nn.Module): transformer block class\n            ffn_layer (str): \"mlp\", \"swiglu\", \"swiglufused\" or \"identity\"\n            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap\n            num_register_tokens: (int) number of extra cls tokens (so-called \"registers\")\n            interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings\n            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings\n        \"\"\"\n        super().__init__()\n        norm_layer = partial(nn.LayerNorm, eps=1e-6)\n\n        self.num_features = self.embed_dim = (\n            embed_dim  # num_features for consistency with other models\n        )\n        self.num_tokens = 1\n        self.n_blocks = depth\n        self.num_heads = num_heads\n        self.patch_size = patch_size\n        self.num_register_tokens = num_register_tokens\n        self.interpolate_antialias = interpolate_antialias\n        self.interpolate_offset = interpolate_offset\n\n        self.patch_embed = embed_layer(\n            img_size=img_size,\n            patch_size=patch_size,\n            in_chans=in_chans,\n            embed_dim=embed_dim,\n        )\n        num_patches = self.patch_embed.num_patches\n\n        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n        self.pos_embed = nn.Parameter(\n            torch.zeros(1, num_patches + self.num_tokens, embed_dim)\n        )\n        assert num_register_tokens >= 0\n        self.register_tokens = (\n            nn.Parameter(torch.zeros(1, 
num_register_tokens, embed_dim))\n            if num_register_tokens\n            else None\n        )\n\n        if drop_path_uniform is True:\n            dpr = [drop_path_rate] * depth\n        else:\n            dpr = [\n                x.item() for x in torch.linspace(0, drop_path_rate, depth)\n            ]  # stochastic depth decay rule\n\n        if ffn_layer == \"mlp\":\n            logger.info(\"using MLP layer as FFN\")\n            ffn_layer = Mlp\n        elif ffn_layer == \"swiglufused\" or ffn_layer == \"swiglu\":\n            logger.info(\"using SwiGLU layer as FFN\")\n            ffn_layer = SwiGLUFFNFused\n        elif ffn_layer == \"identity\":\n            logger.info(\"using Identity layer as FFN\")\n\n            def f(*args, **kwargs):\n                return nn.Identity()\n\n            ffn_layer = f\n        else:\n            raise NotImplementedError\n\n        blocks_list = [\n            block_fn(\n                dim=embed_dim,\n                num_heads=num_heads,\n                mlp_ratio=mlp_ratio,\n                qkv_bias=qkv_bias,\n                proj_bias=proj_bias,\n                ffn_bias=ffn_bias,\n                drop_path=dpr[i],\n                norm_layer=norm_layer,\n                act_layer=act_layer,\n                ffn_layer=ffn_layer,\n                init_values=init_values,\n            )\n            for i in range(depth)\n        ]\n        if block_chunks > 0:\n            self.chunked_blocks = True\n            chunked_blocks = []\n            chunksize = depth // block_chunks\n            for i in range(0, depth, chunksize):\n                # this is to keep the block index consistent if we chunk the block list\n                chunked_blocks.append(\n                    [nn.Identity()] * i + blocks_list[i : i + chunksize]\n                )\n            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])\n        else:\n            self.chunked_blocks = False\n            self.blocks = 
nn.ModuleList(blocks_list)\n\n        self.norm = norm_layer(embed_dim)\n        self.head = nn.Identity()\n\n        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))\n\n        self.init_weights()\n\n    def init_weights(self):\n        trunc_normal_(self.pos_embed, std=0.02)\n        nn.init.normal_(self.cls_token, std=1e-6)\n        if self.register_tokens is not None:\n            nn.init.normal_(self.register_tokens, std=1e-6)\n        named_apply(init_weights_vit_timm, self)\n\n    def interpolate_pos_encoding(self, x, w, h):\n        previous_dtype = x.dtype\n        npatch = x.shape[1] - 1\n        N = self.pos_embed.shape[1] - 1\n        if npatch == N and w == h:\n            return self.pos_embed\n        pos_embed = self.pos_embed.float()\n        class_pos_embed = pos_embed[:, 0]\n        patch_pos_embed = pos_embed[:, 1:]\n        dim = x.shape[-1]\n        w0 = w // self.patch_size\n        h0 = h // self.patch_size\n        # we add a small number to avoid floating point error in the interpolation\n        # see discussion at https://github.com/facebookresearch/dino/issues/8\n        w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset\n\n        sqrt_N = math.sqrt(N)\n        sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N\n        patch_pos_embed = nn.functional.interpolate(\n            patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(\n                0, 3, 1, 2\n            ),\n            scale_factor=(sx, sy),\n            mode=\"bicubic\",\n            antialias=self.interpolate_antialias,\n        )\n\n        assert int(w0) == patch_pos_embed.shape[-2]\n        assert int(h0) == patch_pos_embed.shape[-1]\n        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)\n        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(\n            previous_dtype\n        )\n\n    def prepare_tokens_with_masks(self, x, masks=None):\n        B, nc, w, h = x.shape\n     
   x = self.patch_embed(x)\n        if masks is not None:\n            x = torch.where(\n                masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x\n            )\n\n        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)\n        x = x + self.interpolate_pos_encoding(x, w, h)\n\n        if self.register_tokens is not None:\n            x = torch.cat(\n                (\n                    x[:, :1],\n                    self.register_tokens.expand(x.shape[0], -1, -1),\n                    x[:, 1:],\n                ),\n                dim=1,\n            )\n\n        return x\n\n    def forward_features_list(self, x_list, masks_list):\n        x = [\n            self.prepare_tokens_with_masks(x, masks)\n            for x, masks in zip(x_list, masks_list)\n        ]\n        for blk in self.blocks:\n            x = blk(x)\n\n        all_x = x\n        output = []\n        for x, masks in zip(all_x, masks_list):\n            x_norm = self.norm(x)\n            output.append(\n                {\n                    \"x_norm_clstoken\": x_norm[:, 0],\n                    \"x_norm_regtokens\": x_norm[:, 1 : self.num_register_tokens + 1],\n                    \"x_norm_patchtokens\": x_norm[:, self.num_register_tokens + 1 :],\n                    \"x_prenorm\": x,\n                    \"masks\": masks,\n                }\n            )\n        return output\n\n    def forward_features(self, x, masks=None):\n        if isinstance(x, list):\n            return self.forward_features_list(x, masks)\n\n        x = self.prepare_tokens_with_masks(x, masks)\n\n        for blk in self.blocks:\n            x = blk(x)\n\n        x_norm = self.norm(x)\n        return {\n            \"x_norm_clstoken\": x_norm[:, 0],\n            \"x_norm_regtokens\": x_norm[:, 1 : self.num_register_tokens + 1],\n            \"x_norm_patchtokens\": x_norm[:, self.num_register_tokens + 1 :],\n            \"x_prenorm\": x,\n            \"masks\": masks,\n      
  }\n\n    def forward_features_multi(self, x, masks=None):\n        \"\"\"\n        Extract multilayer features from the model for dense prediction.\n        Fixing the number of layers to 4 following \"Vision Transformers for Dense Prediction\"\n        https://arxiv.org/abs/2103.13413.\n        \"\"\"\n        if isinstance(x, list):\n            return self.forward_features_list(x, masks)\n\n        x = self.prepare_tokens_with_masks(x, masks)\n\n        feats = []\n        num_layers = len(self.blocks)\n        feat_layers = [\n            num_layers // 4 - 1,\n            num_layers // 2 - 1,\n            num_layers // 4 * 3 - 1,\n            num_layers - 1,\n        ]\n        for i, blk in enumerate(self.blocks):\n            x = blk(x)\n            if i in feat_layers:\n                feats.append(self.norm(x))\n\n        return {\n            \"x_norm_clstoken\": [feat[:, 0] for feat in feats],\n            \"x_norm_regtokens\": [\n                feat[:, 1 : self.num_register_tokens + 1] for feat in feats\n            ],\n            \"x_norm_patchtokens\": [\n                feat[:, self.num_register_tokens + 1 :] for feat in feats\n            ],\n            \"masks\": masks,\n        }\n\n    def _get_intermediate_layers_not_chunked(self, x, n=1):\n        x = self.prepare_tokens_with_masks(x)\n        # If n is an int, take the n last blocks. 
If it's a list, take them\n        output, total_block_len = [], len(self.blocks)\n        blocks_to_take = (\n            range(total_block_len - n, total_block_len) if isinstance(n, int) else n\n        )\n        for i, blk in enumerate(self.blocks):\n            x = blk(x)\n            if i in blocks_to_take:\n                output.append(x)\n        assert len(output) == len(blocks_to_take), (\n            f\"only {len(output)} / {len(blocks_to_take)} blocks found\"\n        )\n        return output\n\n    def _get_intermediate_layers_chunked(self, x, n=1):\n        x = self.prepare_tokens_with_masks(x)\n        output, i, total_block_len = [], 0, len(self.blocks[-1])\n        # If n is an int, take the n last blocks. If it's a list, take them\n        blocks_to_take = (\n            range(total_block_len - n, total_block_len) if isinstance(n, int) else n\n        )\n        for block_chunk in self.blocks:\n            for blk in block_chunk[i:]:  # Passing the nn.Identity()\n                x = blk(x)\n                if i in blocks_to_take:\n                    output.append(x)\n                i += 1\n        assert len(output) == len(blocks_to_take), (\n            f\"only {len(output)} / {len(blocks_to_take)} blocks found\"\n        )\n        return output\n\n    def get_intermediate_layers(\n        self,\n        x: torch.Tensor,\n        n: Union[int, Sequence] = 1,  # Layers or n last layers to take\n        reshape: bool = False,\n        return_class_token: bool = False,\n        norm=True,\n    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:\n        if self.chunked_blocks:\n            outputs = self._get_intermediate_layers_chunked(x, n)\n        else:\n            outputs = self._get_intermediate_layers_not_chunked(x, n)\n        if norm:\n            outputs = [self.norm(out) for out in outputs]\n        class_tokens = [out[:, 0] for out in outputs]\n        outputs = [out[:, 1:] for out in outputs]\n        if reshape:\n            B, 
_, w, h = x.shape\n            outputs = [\n                out.reshape(B, w // self.patch_size, h // self.patch_size, -1)\n                .permute(0, 3, 1, 2)\n                .contiguous()\n                for out in outputs\n            ]\n        if return_class_token:\n            return tuple(zip(outputs, class_tokens))\n        return tuple(outputs)\n\n    def forward(self, *args, is_training=False, **kwargs):\n        ret = self.forward_features(*args, **kwargs)\n        if is_training:\n            return ret\n        else:\n            return self.head(ret[\"x_norm_clstoken\"])\n\n\ndef init_weights_vit_timm(module: nn.Module, name: str = \"\"):\n    \"\"\"ViT weight initialization, original timm impl (for reproducibility)\"\"\"\n    if isinstance(module, nn.Linear):\n        trunc_normal_(module.weight, std=0.02)\n        if module.bias is not None:\n            nn.init.zeros_(module.bias)\n\n\ndef vit_small(patch_size, **kwargs):\n    model = DinoVisionTransformer(\n        patch_size=patch_size,\n        embed_dim=384,\n        depth=12,\n        num_heads=6,\n        mlp_ratio=4,\n        # block_fn=partial(Block, attn_class=MemEffAttention),\n        block_fn=partial(Block, attn_class=Attention),\n        **kwargs,\n    )\n    return model\n\n\ndef vit_small_reg(patch_size, **kwargs):\n    model = DinoVisionTransformer(\n        patch_size=patch_size,\n        embed_dim=384,\n        depth=12,\n        num_heads=6,\n        mlp_ratio=4,\n        # block_fn=partial(Block, attn_class=MemEffAttention),\n        block_fn=partial(Block, attn_class=Attention),\n        num_register_tokens=4,\n        interpolate_antialias=True,\n        interpolate_offset=0.0,\n        **kwargs,\n    )\n    return model\n\n\ndef vit_base(patch_size, **kwargs):\n    model = DinoVisionTransformer(\n        patch_size=patch_size,\n        embed_dim=768,\n        depth=12,\n        num_heads=12,\n        mlp_ratio=4,\n        block_fn=partial(Block, 
attn_class=MemEffAttention),\n        **kwargs,\n    )\n    return model\n\n\ndef vit_base_reg(patch_size, **kwargs):\n    model = DinoVisionTransformer(\n        patch_size=patch_size,\n        embed_dim=768,\n        depth=12,\n        num_heads=12,\n        mlp_ratio=4,\n        block_fn=partial(Block, attn_class=MemEffAttention),\n        num_register_tokens=4,\n        interpolate_antialias=True,\n        interpolate_offset=0.0,\n        **kwargs,\n    )\n    return model\n\n\ndef vit_large(patch_size, **kwargs):\n    model = DinoVisionTransformer(\n        patch_size=patch_size,\n        embed_dim=1024,\n        depth=24,\n        num_heads=16,\n        mlp_ratio=4,\n        block_fn=partial(Block, attn_class=MemEffAttention),\n        **kwargs,\n    )\n    return model\n\n\ndef vit_large_reg(patch_size, **kwargs):\n    model = DinoVisionTransformer(\n        patch_size=patch_size,\n        embed_dim=1024,\n        depth=24,\n        num_heads=16,\n        mlp_ratio=4,\n        block_fn=partial(Block, attn_class=MemEffAttention),\n        num_register_tokens=4,\n        interpolate_antialias=True,\n        interpolate_offset=0.0,\n        **kwargs,\n    )\n    return model\n\n\ndef vit_giant2(patch_size, **kwargs):\n    \"\"\"\n    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64\n    \"\"\"\n    model = DinoVisionTransformer(\n        patch_size=patch_size,\n        embed_dim=1536,\n        depth=40,\n        num_heads=24,\n        mlp_ratio=4,\n        block_fn=partial(Block, attn_class=MemEffAttention),\n        ffn_layer=\"swiglufused\",\n        **kwargs,\n    )\n    return model\n\n\ndino_name_mappings = {\n    \"vit_small\": {\n        \"weights\": \"dinov2_vits14_pretrain.pth\",\n        \"feats\": 384,\n        \"func\": vit_small,\n    },\n    \"vit_base\": {\n        \"weights\": \"dinov2_vitb14_pretrain.pth\",\n        \"feats\": 768,\n        \"func\": vit_base,\n    },\n    \"vit_large\": {\n        \"weights\": 
\"dinov2_vitl14_pretrain.pth\",\n        \"feats\": 1024,\n        \"func\": vit_large,\n    },\n    \"vit_giant2\": {\n        \"weights\": \"dinov2_vitg14_pretrain.pth\",\n        \"feats\": 1536,\n        \"func\": vit_giant2,\n    },\n    # v2.5 models\n    \"vit_small_v25\": {\n        \"weights\": \"dinov2_vits14_reg4_pretrain.pth\",\n        \"feats\": 384,\n        \"func\": vit_small_reg,\n    },\n    \"vit_base_v25\": {\n        \"weights\": \"dinov2_vitb14_reg4_pretrain.pth\",\n        \"feats\": 768,\n        \"func\": vit_base_reg,\n    },\n    \"vit_large_v25\": {\n        \"weights\": \"dinov2_vitl14_reg4_pretrain.pth\",\n        \"feats\": 1024,\n        \"func\": vit_large_reg,\n    },\n}\n\n\nclass DinoV2Wrapper(torch.nn.Module):\n    \"\"\"\n    runs DinoV2 on input images\n    \"\"\"\n\n    def __init__(\n        self,\n        name: str = \"vit_small\",\n        img_size: Optional[Union[int, Tuple[int, int]]] = None,\n        multilayer_output: bool = False,\n        ckpt_path: str = \"\",\n    ):\n        super().__init__()\n\n        assert name in dino_name_mappings.keys(), (\n            f\"Dino model name should be one of {dino_name_mappings.keys()}\"\n        )\n\n        assert os.path.exists(ckpt_path), f\"Missing DinoV2 checkpoint path {ckpt_path}\"\n        print(f\"Use the provided DinoV2 checkpoint path {ckpt_path}\")\n\n        # If no image size is provided, use the recommended image size from DinoV2 models.\n        if img_size is None:\n            img_size = 518\n\n        patch_size = 14\n        feat_dim = dino_name_mappings[name][\"feats\"]\n        model_constructor = dino_name_mappings[name][\"func\"]\n        # reference: https://github.com/facebookresearch/dinov2/blob/9a4564ce5ebfe66a37fd16c6a233fb04ffb0a752/dinov2/hub/backbones.py#L21\n        model = model_constructor(\n            patch_size=patch_size,\n            img_size=img_size,\n            block_chunks=0,\n            init_values=1.0,\n        )\n        
print(f\"Constructed DinoV2 model {name}\")\n        checkpoint = torch.load(ckpt_path, weights_only=True)\n        # Dino models should all be loaded with strict=True.\n        model.load_state_dict(checkpoint, strict=True)\n\n        normalize_fn = T.Normalize(\n            mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)\n        )\n\n        self.model = model\n        self.feat_dim = feat_dim\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.normalize_fn = normalize_fn\n        if multilayer_output:\n            self.forward_fn = self.model.forward_features_multi\n        else:\n            self.forward_fn = self.model.forward_features\n\n    def forward(self, img):\n        \"\"\"\n        Input:\n            img : torch.Tensor : batch of images shaped BxCxHxW of type float32 in range [0,1],\n                                 where C can be 1 or 3, they get resized to a fixed size\n        Returns:\n            feats : torch.Tensor : shaped BxDxpHxpW where D is feature dim, pH & pW = img_size / patch_size,\n        \"\"\"\n        B, C, H, W = img.shape\n        assert C in [1, 3], \"must be either 1 or 3 channel input (BxCxHxW)\"\n        if C == 1:\n            # Fake RGB by repeating gray channel.\n            img = img.repeat(1, 3, 1, 1)\n        img = self.normalize_fn(img)  # Apply imagenet normalization.\n        feats = self.forward_fn(img)[\"x_norm_patchtokens\"]\n        H, W = img.shape[-2:]\n        assert H % self.patch_size == 0 and W % self.patch_size == 0, (\n            \"Resize the images to a multiple of patch size\"\n        )\n        pH = H // self.patch_size\n        pW = W // self.patch_size\n        if isinstance(feats, List):\n            feats = [f.reshape(-1, pH, pW, self.feat_dim) for f in feats]\n            feats = [f.permute(0, 3, 1, 2) for f in feats]  # BxHxWxD => BxDxHxW\n        else:\n            feats = feats.reshape(-1, pH, pW, self.feat_dim)\n            feats = feats.permute(0, 3, 
1, 2)  # BxHxWxD => BxDxHxW\n        return feats\n"
  },
  {
    "path": "efm3d/model/dpt.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport einops\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn.functional import interpolate\n\n\nclass ResidualConvUnit(nn.Module):\n    # From \"Vision Transformers for Dense Prediction\": https://arxiv.org/abs/2103.13413\n    # adapted from https://github.com/isl-org/DPT/blob/main/dpt/blocks.py\n    def __init__(self, features, kernel_size):\n        super().__init__()\n        assert kernel_size % 2 == 1, \"Kernel size needs to be odd\"\n        padding = kernel_size // 2\n        self.conv = nn.Sequential(\n            nn.Conv2d(features, features, kernel_size, padding=padding),\n            nn.ReLU(True),\n            nn.Conv2d(features, features, kernel_size, padding=padding),\n            nn.ReLU(True),\n        )\n\n    def forward(self, x):\n        return self.conv(x) + x\n\n\nclass FeatureFusionBlock(nn.Module):\n    # From \"Vision Transformers for Dense Prediction\": https://arxiv.org/abs/2103.13413\n    # adapted from https://github.com/isl-org/DPT/blob/main/dpt/blocks.py\n    def __init__(self, features, kernel_size, with_skip=True):\n        super().__init__()\n        self.with_skip = with_skip\n        if self.with_skip:\n            self.resConfUnit1 = ResidualConvUnit(features, kernel_size)\n\n        self.resConfUnit2 = ResidualConvUnit(features, kernel_size)\n\n    def forward(self, x, 
skip_x=None):\n        if skip_x is not None:\n            assert self.with_skip, \"Must init with with_skip=True\"\n            assert skip_x.shape == x.shape, (\n                f\"skip {skip_x.shape} and x {x.shape} shape mismatch\"\n            )\n            x = self.resConfUnit1(x) + skip_x\n\n        x = self.resConfUnit2(x)\n        return x\n\n\nclass Interpolate(nn.Module):\n    \"\"\"\n    Interpolation module. https://github.com/isl-org/DPT/blob/main/dpt/blocks.py#L138\n    \"\"\"\n\n    def __init__(self, scale_factor, mode, align_corners=False):\n        \"\"\"Init.\n\n        Args:\n            scale_factor (float): scaling\n            mode (str): interpolation mode\n        \"\"\"\n        super(Interpolate, self).__init__()\n\n        self.interp = nn.functional.interpolate\n        self.scale_factor = scale_factor\n        self.mode = mode\n        self.align_corners = align_corners\n\n    def forward(self, x):\n        \"\"\"Forward pass.\n\n        Args:\n            x (tensor): input\n\n        Returns:\n            tensor: interpolated data\n        \"\"\"\n\n        x = self.interp(\n            x,\n            scale_factor=self.scale_factor,\n            mode=self.mode,\n            align_corners=self.align_corners,\n        )\n        return x\n\n\nclass DPTOri(nn.Module):\n    \"\"\"\n    Implementation of DPT according to the paper description https://arxiv.org/pdf/2103.13413\n    \"\"\"\n\n    def __init__(self, input_dim, hidden_dim=256, output_dim=256, depth=False):\n        \"\"\"\n        input_dim: dimension of the DinoV2 tokens (384/768/...)\n        hidden_dim: dense feature dimension(=256, D^{hat} in the paper) in DPT\n        output_dim: final output feature dimension\n        \"\"\"\n        super().__init__()\n        self.depth = depth\n        if self.depth:\n            # DPT depth head https://github.com/isl-org/DPT/blob/main/dpt/models.py#L89\n            self.depth_head = nn.Sequential(\n                nn.Conv2d(\n     
               hidden_dim, hidden_dim // 2, kernel_size=3, stride=1, padding=1\n                ),\n                Interpolate(scale_factor=2, mode=\"bilinear\", align_corners=True),\n                nn.Conv2d(hidden_dim // 2, 32, kernel_size=3, stride=1, padding=1),\n                nn.ReLU(True),\n                nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),\n                nn.ReLU(True),  # require depth to be non-negative\n                nn.Identity(),\n            )\n            output_dim = output_dim - 1  # last dim is depth\n\n        # 1x1 convs to map (H/p x W/p x D) -> (H/s x W/s x D^{hat})\n        self.conv_0 = nn.Conv2d(input_dim, hidden_dim, 1, padding=0)\n        self.conv_1 = nn.Conv2d(input_dim, hidden_dim, 1, padding=0)\n        self.conv_2 = nn.Conv2d(input_dim, hidden_dim, 1, padding=0)\n        self.conv_3 = nn.Conv2d(input_dim, hidden_dim, 1, padding=0)\n\n        # (strided) convs for upsampling (feat_0/1/2) and downsample (feat_3)\n        # image - WxW, padding - P, kernel - FxF, stride - S\n        # conv size - (W-F+2P) / S + 1\n        # transpose conv size - (H-1)*S+F-2P\n        self.resample_conv0 = nn.ConvTranspose2d(\n            hidden_dim, hidden_dim, 3, stride=4, padding=0\n        )\n        self.resample_conv1 = nn.ConvTranspose2d(\n            hidden_dim, hidden_dim, 3, stride=2, padding=1\n        )\n        self.resample_conv2 = nn.Conv2d(hidden_dim, hidden_dim, 3, stride=1, padding=1)\n        self.resample_conv3 = nn.Conv2d(hidden_dim, hidden_dim, 3, stride=2, padding=1)\n\n        # fusion blocks\n        self.ref_0 = FeatureFusionBlock(hidden_dim, 3)\n        self.ref_1 = FeatureFusionBlock(hidden_dim, 3)\n        self.ref_2 = FeatureFusionBlock(hidden_dim, 3)\n        self.ref_3 = FeatureFusionBlock(hidden_dim, 3, with_skip=False)\n\n        # final upsample head\n        self.conv_up1 = nn.Conv2d(\n            hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1\n        )\n        self.conv_final = 
nn.Conv2d(\n            hidden_dim, output_dim, kernel_size=1, stride=1, padding=0\n        )\n        self.relu = nn.ReLU(inplace=True)\n\n    def forward(self, feats):\n        \"\"\"\n        feats: tokens from multi-layers, for ViT-base these are [3,6,9,12] (starting from 1, not 0)\n        \"\"\"\n        assert len(feats) == 4, (\n            \"feats must be multi-level as a list of 4 tensors, probably set model.video_backbone.image_tokenizer.multilayer_output=True\"\n        )\n        ndim = feats[0].ndim\n        if ndim == 5:\n            T = feats[0].shape[1]\n            feats = [einops.rearrange(f, \"b t c h w -> (b t) c h w\") for f in feats]\n\n        # [T, D, H/p, W/p]\n        feats[0] = self.conv_0(feats[0])\n        feats[1] = self.conv_1(feats[1])\n        feats[2] = self.conv_2(feats[2])\n        feats[3] = self.conv_3(feats[3])\n\n        # add single-side padding here after feat0 and feat1 to make upsampling 4x and 2x the token map size\n        padding = (0, 1, 0, 1)  # left, right, top, bottom\n        feats[0] = self.resample_conv0(feats[0])\n        feats[0] = F.pad(feats[0], padding, mode=\"constant\", value=0)\n        feats[1] = self.resample_conv1(feats[1])\n        feats[1] = F.pad(feats[1], padding, mode=\"constant\", value=0)\n        feats[2] = self.resample_conv2(feats[2])\n        feats[3] = self.resample_conv3(feats[3])\n\n        out = self.ref_3(feats[3], None)\n        out = interpolate(\n            out, size=feats[2].shape[-2:], mode=\"bilinear\", align_corners=True\n        )\n        out = self.ref_2(feats[2], out)\n        out = interpolate(\n            out, size=feats[1].shape[-2:], mode=\"bilinear\", align_corners=True\n        )\n        out = self.ref_1(feats[1], out)\n        out = interpolate(\n            out, size=feats[0].shape[-2:], mode=\"bilinear\", align_corners=True\n        )\n        out = self.ref_0(feats[0], out)\n        h, w = feats[0].shape[-2:]\n        feat = interpolate(\n            out, 
size=(h * 2, w * 2), mode=\"bilinear\", align_corners=True\n        )\n\n        # upsample by 2x (In the paper DPT outputs 1/2 original size feature maps)\n        out = self.relu(self.conv_up1(feat))\n        h, w = out.shape[-2:]\n        out = interpolate(out, size=(h * 2, w * 2), mode=\"bilinear\", align_corners=True)\n        out = self.conv_final(out)\n\n        if self.depth:\n            inv_depth = self.depth_head(feat) + 1e-3  # predict inv depth, add epsilon\n            out = torch.cat([out, inv_depth], dim=1)\n\n        if ndim == 5:\n            out = einops.rearrange(out, \"(b t) c h w -> b t c h w\", t=T)\n        return out\n"
  },
  {
    "path": "efm3d/model/evl.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import List, Optional, Union\n\nimport torch\nfrom efm3d.aria.aria_constants import (\n    ARIA_OBB_PRED,\n    ARIA_OBB_PRED_PROBS_FULL,\n    ARIA_OBB_PRED_PROBS_FULL_VIZ,\n    ARIA_OBB_PRED_SEM_ID_TO_NAME,\n    ARIA_OBB_PRED_VIZ,\n    ARIA_SNIPPET_T_WORLD_SNIPPET,\n)\nfrom efm3d.model.cnn import InvResnetFpn3d, VolumeCNNHead\nfrom efm3d.model.lifter import VideoBackbone3d\nfrom efm3d.model.video_backbone import VideoBackbone\nfrom efm3d.utils.detection_utils import simple_nms3d, voxel2obb\nfrom efm3d.utils.file_utils import parse_global_name_to_id_csv\nfrom hydra.utils import instantiate\nfrom omegaconf import DictConfig\n\n\nclass EVL(torch.nn.Module):\n    def __init__(\n        self,\n        video_backbone: Union[VideoBackbone, DictConfig],\n        video_backbone3d: Union[VideoBackbone3d, DictConfig],\n        neck_hidden_dims: Optional[List] = None,\n        head_hidden_dim: int = 128,\n        head_layers: int = 2,\n        taxonomy_file: Optional[str] = None,\n        det_thresh: float = 0.2,\n        yaw_max: float = 1.6,\n    ):\n        \"\"\"\n        Args:\n            video_backbone: 2D backbone to extract features from images.\n            video_backbone3d: 3D backbone to lift 2d to 3d voxels.\n            neck_hidden_dims: hidden dims of the 3D CNN neck.\n            head_hidden_dim: hidden dim of the 3D CNN 
head.\n\n            # obb params\n            det_thresh: Detection threshold for NMS.\n            yaw_max: Maximum yaw angle for object orientation.\n        \"\"\"\n        super().__init__()\n\n        if neck_hidden_dims is None:\n            neck_hidden_dims = [64, 128, 256]\n\n        self.backbone2d = video_backbone\n        self.backbone3d = video_backbone3d\n        self.head_layers = head_layers\n\n        if isinstance(video_backbone, DictConfig):\n            self.backbone2d = instantiate(video_backbone)\n        if isinstance(video_backbone3d, DictConfig):\n            self.backbone3d = instantiate(video_backbone3d)\n\n        backbone3d_out_dim = self.backbone3d.output_dim()\n\n        # 3d U-Net\n        c = backbone3d_out_dim  # c = 66 (64 + 1 + 1)\n        dims = [c, 64, 96, 128, 160]\n        neck_final = dims[1]\n        print(f\"==> Init 3D InvResnetFpn3d neck with hidden layers: {dims}\")\n        self.neck = InvResnetFpn3d(\n            dims=dims,\n            num_bottles=[2, 2, 2, 2],\n            strides=[1, 2, 2, 2],\n            expansions=[2.0, 2.0, 2.0, 2.0],\n        )\n\n        print(\n            f\"==> Init 3D CNN Head with final dim = {neck_final}, hidden dim = {head_hidden_dim}\"\n        )\n        # occupancy head\n        self.occ_head = VolumeCNNHead(\n            neck_final,\n            head_hidden_dim,\n            final_dim=1,\n            num_layers=self.head_layers,\n            name=\"Occupancy\",\n        )\n\n        # obb part\n        if taxonomy_file is not None:\n            taxonomy = parse_global_name_to_id_csv(taxonomy_file)\n            self.sem2name = {int(sem_id): name for name, sem_id in taxonomy.items()}\n        self.num_class = len(self.sem2name)\n\n        # Centerness head (center of the bounding box).\n        self.cent_head = VolumeCNNHead(\n            neck_final,\n            head_hidden_dim,\n            final_dim=1,\n            name=\"Centerness\",\n            bias=-5,\n        )\n        # 
Box size head (height, width, depth, offset_h, offset_w, offset_d, yaw rotation of box).\n        self.bbox_head = VolumeCNNHead(\n            neck_final,\n            head_hidden_dim,\n            final_dim=7,\n            name=\"BoundingBox\",\n        )\n        self.clas_head = VolumeCNNHead(\n            neck_final,\n            head_hidden_dim,\n            final_dim=self.num_class,\n            name=\"Class\",\n        )\n\n        self.det_thresh = det_thresh\n        self.bbox_min = 0.1  # Min bbox dim\n        self.bbox_max = 6.0  # Max bbox dim\n        # Scale the bbox offset max based on voxel size in meters.\n        self.offset_max = 2 * self.backbone3d.voxel_meters\n        self.splat_sigma = max(1, int(0.12 / self.backbone3d.voxel_meters))\n        self.iou_thres = 0.2\n        self.ve = self.backbone3d.voxel_extent  # voxel extent\n        self.yaw_max = yaw_max\n        self.scene = None\n\n    def post_process(self, batch, out):\n        cent_pr = out[\"cent_pr\"]\n        bbox_pr = out[\"bbox_pr\"]\n        clas_pr = out[\"clas_pr\"]\n        # Run NMS + convert voxel outputs to ObbTW.\n        with torch.no_grad():\n            # First NMS is a simple heatmap suppression.\n            cent_pr_nms = simple_nms3d(cent_pr, nms_radius=self.splat_sigma + 1)\n            vD, vH, vW = cent_pr.shape[-3:]\n            # Convert dense predictions to sparse ObbTW predictions.\n            obbs_pr_nms, _, clas_prob_nms = voxel2obb(\n                cent_pr_nms,\n                bbox_pr,\n                clas_pr,\n                self.ve,\n                top_k=128,\n                thresh=self.det_thresh,\n                return_full_prob=True,\n            )\n            out[\"obbs_pr_nms\"] = obbs_pr_nms\n            out[\"cent_pr_nms\"] = cent_pr_nms\n\n        # obb tracker expects ARIA_OBB_PRED and ARIA_OBB_PRED_VIZ to be in snippet coord system\n        obbs_pr_nms_s = obbs_pr_nms.clone()\n        T_ws = batch[ARIA_SNIPPET_T_WORLD_SNIPPET]  # B x 1 
x 12\n        T_wv = out[\"voxel/T_world_voxel\"].unsqueeze(1)  # B x 1 x 12\n        T_sv = T_ws.inverse() @ T_wv\n        obbs_pr_nms_s = obbs_pr_nms_s.transform(T_sv)  # transform to snippet coords\n        out[ARIA_OBB_PRED_SEM_ID_TO_NAME] = self.sem2name\n        out[ARIA_OBB_PRED] = obbs_pr_nms_s\n        out[ARIA_OBB_PRED_VIZ] = obbs_pr_nms_s\n        out[ARIA_OBB_PRED_PROBS_FULL] = [item for item in clas_prob_nms]\n        out[ARIA_OBB_PRED_PROBS_FULL_VIZ] = out[ARIA_OBB_PRED_PROBS_FULL]\n        return out\n\n    def forward(self, batch, obb_only=False):\n        out = {}\n        # Run 2D backbone on images to get one 2D feature map per image.\n        backbone2d_out_all = self.backbone2d(batch)\n        for stream in [\"rgb\", \"slaml\", \"slamr\"]:\n            if stream in backbone2d_out_all:\n                # add to batch for lifter\n                batch[f\"{stream}/feat\"] = backbone2d_out_all[stream]\n\n        # Run explicit 3D backbone to lift 2D features to a 3D voxel grid.\n        backbone3d_out = self.backbone3d(batch)\n        voxel_feats = backbone3d_out[\"voxel/feat\"]\n\n        # Run 3D encoder-decoder CNN, acting as a \"neck\" to the heads.\n        neck_feats1 = self.neck(voxel_feats)\n        neck_feats2 = neck_feats1\n\n        # ---------- Run the occ head ------------\n        if not obb_only:\n            occ_logits = self.occ_head(neck_feats1)\n            occ_pr = torch.sigmoid(occ_logits)  # logits => prob.\n            out[\"occ_pr\"] = occ_pr\n            out[\"voxel_extent\"] = torch.tensor(self.ve).to(neck_feats1)\n\n        # ---------- Run the obb head ------------\n        # Run the centerness head.\n        cent_logits = self.cent_head(neck_feats2)\n        cent_pr = torch.sigmoid(cent_logits)  # logits => prob.\n\n        # Run the box size head.\n        bbox_pr = self.bbox_head(neck_feats2)\n        bbox_pr[:, 0:3] = (self.bbox_max - self.bbox_min) * torch.sigmoid(\n            bbox_pr[:, :3]\n        ) + 
self.bbox_min\n        bbox_pr[:, 3:6] = self.offset_max * torch.tanh(bbox_pr[:, 3:6])\n        bbox_pr[:, 6] = self.yaw_max * torch.tanh(bbox_pr[:, 6])\n\n        # Run the classification head.\n        clas_pr = self.clas_head(neck_feats2)\n        clas_pr = torch.nn.functional.softmax(clas_pr, dim=1)\n\n        out.update(backbone3d_out)\n        out[\"neck/occ_feat\"] = neck_feats1\n        out[\"neck/obb_feat\"] = neck_feats2\n        #  Copy data from head outputs.\n        out[\"cent_pr\"] = cent_pr\n        out[\"bbox_pr\"] = bbox_pr\n        out[\"clas_pr\"] = clas_pr\n\n        out = self.post_process(batch, out)\n\n        return out\n"
  },
  {
    "path": "efm3d/model/evl_train.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import List, Optional, Union\n\nimport numpy as np\nimport torch\nfrom efm3d.aria.aria_constants import (\n    ARIA_CALIB,\n    ARIA_IMG,\n    ARIA_IMG_T_SNIPPET_RIG,\n    ARIA_OBB_PADDED,\n    ARIA_OBB_PRED,\n    ARIA_SNIPPET_T_WORLD_SNIPPET,\n)\nfrom efm3d.aria.obb import ObbTW\nfrom efm3d.aria.pose import PoseTW\nfrom efm3d.model.evl import EVL\nfrom efm3d.model.lifter import VideoBackbone3d\nfrom efm3d.model.video_backbone import VideoBackbone\nfrom efm3d.utils.evl_loss import compute_obb_losses, compute_occ_losses, get_gt_obbs\nfrom efm3d.utils.image import put_text\nfrom efm3d.utils.marching_cubes import marching_cubes_scaled\nfrom efm3d.utils.obb_utils import prec_recall_bb3\nfrom efm3d.utils.pointcloud import (\n    get_points_world,\n    pointcloud_occupancy_samples,\n    pointcloud_to_occupancy_snippet,\n)\nfrom efm3d.utils.render import draw_obbs_snippet, get_colors_from_sem_map\nfrom efm3d.utils.viz import (\n    render_cosy,\n    render_frustum,\n    render_linestrip,\n    render_obb_line,\n    render_obbs_line,\n    render_points,\n    render_rgb_tri_mesh,\n    render_scalar_field_points,\n    render_tri_mesh,\n    SceneView,\n)\nfrom efm3d.utils.voxel import erode_voxel_mask\nfrom efm3d.utils.voxel_sampling import pc_to_vox, sample_voxels\nfrom omegaconf import DictConfig\n\n\nclass EVLTrain(EVL):\n    def 
__init__(\n        self,\n        video_backbone: Union[VideoBackbone, DictConfig],\n        video_backbone3d: Union[VideoBackbone3d, DictConfig],\n        neck_hidden_dims: Optional[List] = None,\n        head_hidden_dim: int = 128,\n        head_layers: int = 2,\n        taxonomy_file: Optional[str] = None,\n        det_thresh: float = 0.2,\n        yaw_max: float = 1.6,\n    ):\n        super().__init__(\n            video_backbone,\n            video_backbone3d,\n            neck_hidden_dims,\n            head_hidden_dim,\n            head_layers,\n            taxonomy_file,\n            det_thresh,\n            yaw_max,\n        )\n\n    def compute_losses(self, outputs, batch):\n        total_loss = 0\n        losses = {\"rgb\": {}}\n\n        self.occ_weight = 10.0\n        self.tv_weight = 0.01\n        occ_losses, occ_total_loss = compute_occ_losses(\n            outputs,\n            batch,\n            self.ve,\n            occ_weight=self.occ_weight,\n            tv_weight=self.tv_weight,\n        )\n        for k in losses:  # for ['rgb', 'slaml', 'slamr']\n            losses[k].update(occ_losses[k])\n        total_loss += occ_total_loss\n\n        self.cent_weight = 10.0\n        self.bbox_weight = 0.0\n        self.cham_weight = 0.0\n        self.clas_weight = 0.1\n        self.iou_weight = 0.5\n        obb_losses, obb_total_loss = compute_obb_losses(\n            outputs,\n            batch,\n            self.ve,\n            self.num_class,\n            self.splat_sigma,\n            cent_weight=self.cent_weight,\n            clas_weight=self.clas_weight,\n            iou_weight=self.iou_weight,\n            bbox_weight=self.bbox_weight,\n            cham_weight=self.cham_weight,\n        )\n        for k in losses:  # for ['rgb', 'slaml', 'slamr']\n            losses[k].update(obb_losses[k])\n        total_loss += obb_total_loss\n\n        return losses, total_loss\n\n    def render2d(self, imgs, obbs, Ts_wr, cams):\n        \"\"\"Render a 2D 
visualization overlaid on the RGB image of the given obbs.\"\"\"\n        # Draw the 3D bb overlaid on the image.\n        obb_img = draw_obbs_snippet(\n            imgs.clone(),\n            obbs,\n            Ts_wr,\n            cams,\n            rgb2bgr=False,\n            draw_cosy=True,\n            white_backing_line=False,\n            draw_bb2=False,\n            sem_id_to_name_mapping=self.sem2name,\n            draw_label=True,\n            draw_score=True,\n            prob_threshold=0.001,  # keep this very low, obbs are already thresholded.\n        )\n        return np.array(obb_img)\n\n    def log_single_obb(self, batch, outputs, batch_idx):\n        \"\"\"Log a single element from the batch based on \"batch_idx\".\"\"\"\n        log_ims = {}\n\n        # Get stuff.\n        rgb_img = batch[ARIA_IMG[0]][batch_idx].cpu().detach()\n        T = rgb_img.shape[0]\n        cams = batch[ARIA_CALIB[0]][batch_idx].cpu().detach()\n        Ts_sr = batch[ARIA_IMG_T_SNIPPET_RIG[0]][batch_idx].cpu().detach()\n        T_ws = batch[ARIA_SNIPPET_T_WORLD_SNIPPET][batch_idx].cpu().detach()\n        voxel_w = outputs[\"voxel/pts_world\"][batch_idx].cpu().detach()\n        T_wv = outputs[\"voxel/T_world_voxel\"][batch_idx]\n        obbs_gt = get_gt_obbs(batch, self.ve, T_wv)\n        obbs_gt = obbs_gt[batch_idx].cpu()\n        T_wv = T_wv.cpu().detach()\n        cent_pr = outputs[\"cent_pr\"][batch_idx].cpu().detach()\n        obbs_pr = outputs[\"obbs_pr_nms\"][batch_idx].cpu().detach()\n        occ_input = None\n\n        # Get some convenience transforms.\n        Ts_wr = T_ws @ Ts_sr\n\n        # Transform Objects to world coords.\n        obbs_pr = obbs_pr.transform(T_wv)  # T_wo = T_wv @ T_vo\n\n        # Transform lifter volume obb to world coordinates.\n        extent = torch.tensor(self.ve).to(T_wv._data)\n        voxel_obb = ObbTW()[0]\n        voxel_obb.set_bb3_object(extent, use_mask=False)\n        voxel_obb.set_T_world_object(T_wv)\n\n        occ_input = 
outputs[\"voxel/occ_input\"][batch_idx].cpu().detach().reshape(-1)\n        mask = occ_input > 1e-4\n        log_ims[\"voxel/occ_input\"] = self.render3d_obb(\n            occ_input[mask],\n            obbs_pr,\n            Ts_wr,\n            T_ws,\n            cams,\n            voxel_w[mask],\n            voxel_obb,\n            view=\"follow\",\n            alpha_min=0.1,\n        )\n\n        # compute precision and recall and add text to the pred 2d\n        log_ims[\"rgb_pred\"] = self.render2d(rgb_img, obbs_pr, Ts_wr, cams)\n        if ARIA_OBB_PADDED in batch:\n            obbs_pr_nms = outputs[ARIA_OBB_PRED][batch_idx].cpu()\n            prec, rec, match_mat = prec_recall_bb3(\n                obbs_pr_nms.remove_padding(),\n                obbs_gt.remove_padding(),\n                iou_thres=self.iou_thres,\n            )\n            if match_mat is not None:\n                num_tp = match_mat.any(-1).sum().item()\n                num_pred = match_mat.shape[0]\n                num_gt = match_mat.shape[1]\n                precision = f\"Prec@{self.iou_thres}: {prec:.2f} ({num_tp}/{num_pred})\"\n                recall = f\"Recall@{self.iou_thres}: {rec:.2f} ({num_tp}/{num_gt})\"\n            else:\n                precision = f\"Prec@{self.iou_thres}: {prec:.2f}\"\n                recall = f\"Recall@{self.iou_thres}: {rec:.2f}\"\n            imgs_pred = log_ims[\"rgb_pred\"]\n            imgs_pred = [put_text(img, precision, line=-2) for img in imgs_pred]\n            imgs_pred = [put_text(img, recall, line=-1) for img in imgs_pred]\n            log_ims[\"rgb_pred\"] = np.array(imgs_pred)\n\n        log_ims[\"3D_pred\"] = self.render3d_obb(\n            cent_pr,\n            obbs_pr,\n            Ts_wr,\n            T_ws,\n            cams,\n            voxel_w,\n            voxel_obb,\n            view=\"follow\",\n            alpha_min=0.1,\n        )\n\n        if \"cent_gt\" in outputs:\n            self.compute_losses(outputs, batch)\n\n            
obbs_gt = obbs_gt[~obbs_gt.get_padding_mask()]\n            obbs_gt = obbs_gt.transform(T_ws)  # T_wo = T_ws @ T_so\n            log_ims[\"rgb_gt\"] = self.render2d(rgb_img, obbs_gt, Ts_wr, cams)\n\n            cent_gt = outputs[\"cent_gt\"][batch_idx].cpu().reshape(-1)\n            log_ims[\"3D_gt\"] = self.render3d_obb(\n                cent_gt,\n                obbs_gt,\n                Ts_wr,\n                T_ws,\n                cams,\n                voxel_w,\n                voxel_obb,\n                alpha_min=0.1,\n            )\n\n        return log_ims\n\n    def log_single(self, batch, outputs, batch_idx):\n        \"\"\"Log a single element from the batch based on \"batch_idx\".\"\"\"\n        log_ims = self.log_single_obb(batch, outputs, batch_idx)\n\n        cams = batch[ARIA_CALIB[0]][batch_idx].cpu().detach()\n        Ts_sr = batch[ARIA_IMG_T_SNIPPET_RIG[0]][batch_idx].cpu().detach()\n        T_ws = batch[ARIA_SNIPPET_T_WORLD_SNIPPET][batch_idx].cpu().detach()\n        voxel_w = outputs[\"voxel/pts_world\"][batch_idx].cpu().detach()\n        T_wv = outputs[\"voxel/T_world_voxel\"][batch_idx].cpu().detach()\n        Ts_wr = T_ws @ Ts_sr\n        T = cams.shape[0]\n\n        occ = outputs[\"occ_pr\"].squeeze(1)\n        voxel_counts = outputs[\"voxel/counts\"][batch_idx].cpu().detach()\n\n        B, D, H, W = occ.shape\n        Df, Hf, Wf = voxel_counts.shape\n        if D != Df or H != Hf or W != Wf:\n            resize = torch.nn.Upsample(size=(D, H, W))\n            voxel_w = voxel_w.view(Df, Hf, Wf, 3).permute(3, 0, 1, 2)\n            voxel_w = resize(voxel_w.unsqueeze(0)).squeeze(0)\n            voxel_w = voxel_w.permute(1, 2, 3, 0).view(-1, 3)\n            voxel_counts = resize(voxel_counts.unsqueeze(0).unsqueeze(0).float())\n            voxel_counts = voxel_counts.squeeze(0).squeeze(0)\n\n        visible = voxel_counts > 0\n        visible = erode_voxel_mask(visible.unsqueeze(0)).squeeze(0)\n\n        # Get some convenience transforms.\n    
    Ts_wr = T_ws @ Ts_sr\n        Ts_wc = Ts_wr @ cams.T_camera_rig.inverse()\n\n        # Transform lifter volume obb to world coordinates.\n        extent = torch.tensor(self.ve).to(T_wv._data)\n        voxel_obb = ObbTW()[0]\n        voxel_obb.set_bb3_object(extent, use_mask=False)\n        voxel_obb.set_T_world_object(T_wv)\n\n        # -------------------- draw occ -----------------------\n        occ_pr = outputs[\"occ_pr\"][batch_idx].cpu().detach().squeeze(0)\n        alpha_min = 0.5 if self.occ_weight > 0.0 else 0.04\n        log_ims[\"occ/mesh_pred\"] = self.render3d_mesh(\n            occ_pr,\n            Ts_wr,\n            T_ws,\n            cams,\n            voxel_obb,\n            view=\"follow\",\n            alpha_min=alpha_min,\n            T_wv=T_wv,\n            voxel_mask=visible,\n        )\n\n        log_ims[\"occ/occ_pred\"] = self.render3d_occ(\n            occ_pr,\n            Ts_wr,\n            T_ws,\n            cams,\n            voxel_w,\n            voxel_obb,\n            view=\"follow\",\n            alpha_min=alpha_min,\n            voxel_mask=visible,\n        )\n\n        vD, vH, vW = occ_pr.shape\n        pc_w = get_points_world(batch, batch_idx)[0].cpu().detach()\n        (\n            p3s_occ_w,\n            p3s_surf_w,\n            p3s_free_w,\n            valid,\n        ) = pointcloud_occupancy_samples(\n            pc_w.unsqueeze(0),\n            Ts_wc.unsqueeze(0),\n            cams.unsqueeze(0),\n            vW,\n            vH,\n            vD,\n            self.ve,\n            S=1,\n            T_wv=T_wv,\n        )\n        p3s_occ_w[~valid] = float(\"nan\")\n        p3s_surf_w[~valid] = float(\"nan\")\n        p3s_free_w[~valid] = float(\"nan\")\n\n        log_ims[\"occ/occ_gt_samples\"] = self.render3d_points(\n            p3s_surf_w.squeeze(0),\n            Ts_wr,\n            T_ws,\n            cams,\n            voxel_obb,\n            view=\"follow\",\n            more_p3s_w=p3s_free_w.squeeze(0),\n          
  more2_p3s_w=p3s_occ_w.squeeze(0),\n        )\n\n        # get occ gt\n        occ_gt, mask = pointcloud_to_occupancy_snippet(\n            pc_w,\n            Ts_wc,\n            cams,\n            T_wv,\n            vW,\n            vH,\n            vD,\n            self.ve,\n            S=1,\n        )\n        mask = torch.logical_and(mask.bool(), visible)\n        log_ims[\"occ/mesh_gt\"] = self.render3d_mesh(\n            occ_gt,\n            Ts_wr,\n            T_ws,\n            cams,\n            voxel_obb,\n            view=\"follow\",\n            alpha_min=alpha_min,\n            T_wv=T_wv,\n            voxel_mask=mask,\n        )\n\n        return log_ims\n\n    @torch.no_grad()\n    def render3d_mesh(\n        self,\n        voxel_vals,\n        Ts_wr,\n        T_ws,\n        cams,\n        voxel_obb,\n        view=\"follow\",\n        alpha_min=0.5,\n        T_wv=None,\n        voxel_mask=None,\n        volume_feat=None,\n    ):\n        \"\"\"Render a 3D visualization of the given voxel values and obbs.\"\"\"\n        if self.scene is None:\n            self.scene = SceneView(width=320, height=320)\n        scene = self.scene\n        lifter_imgs = []\n        verts_v, faces, normals_v = marching_cubes_scaled(\n            voxel_vals.cpu().detach().float(),\n            alpha_min,\n            self.ve,\n            voxel_mask,\n        )\n        feats = torch.tensor([])\n        if volume_feat is not None and len(verts_v) > 0:\n            vD, vH, vW = voxel_vals.shape\n            p3s_surf_vox, _ = pc_to_vox(verts_v, vW, vH, vD, self.ve)\n            feats, _ = sample_voxels(\n                volume_feat.unsqueeze(0), p3s_surf_vox.unsqueeze(0)\n            )\n            feats = feats.squeeze(0).permute(1, 0)\n            print(\"[WARN] No PCA compressor provided. 
Take the first 3 channels.\")\n            rgb = feats[:, :3]\n            maxs = rgb.max(dim=-1, keepdim=True)[0]\n            mins = rgb.min(dim=-1, keepdim=True)[0]\n            rgb = (rgb - mins) / (maxs - mins + 1e-4)\n        black = (0.0, 0.0, 0.0, 1.0)\n        green = (0.0, 1.0, 0.0, 1.0)\n        Ts_wc = Ts_wr @ cams.T_camera_rig.inverse()\n        for T_wr, T_wc, cam in zip(Ts_wr, Ts_wc, cams):\n            scene.clear()\n            if view == \"follow\":\n                scene.set_follow_view(T_wc, zoom_factor=4)\n            elif view == \"bird\":\n                scene.set_birds_eye_view(T_wc, zoom_factor=8)\n            else:\n                raise ValueError(\"bad option for 3d view style\")\n            if len(verts_v) > 0:\n                verts_w = T_wv * verts_v.to(T_wv.device)\n                normals_w = T_wv.rotate(normals_v.to(T_wv.device))\n                if volume_feat is not None:\n                    render_rgb_tri_mesh(\n                        verts_w,\n                        -normals_w,\n                        faces,\n                        rgb=rgb,\n                        prog=scene.prog_mesh_rgb,\n                        ctx=scene.ctx,\n                    )\n                else:\n                    render_tri_mesh(\n                        verts_w,\n                        normals_w,\n                        faces,\n                        prog=scene.prog_mesh,\n                        ctx=scene.ctx,\n                    )\n\n            # draw voxel bounding volume\n            render_obb_line(\n                voxel_obb, scene.prog, scene.ctx, rgba=black, draw_cosy=True\n            )\n            # draw trajectory.\n            render_linestrip(Ts_wr.t, prog=scene.prog, ctx=scene.ctx, rgba=black)\n            render_frustum(T_wr, cam, prog=scene.prog, ctx=scene.ctx, rgba=black)\n            # Render snippet origin.\n            render_cosy(T_ws, ctx=scene.ctx, prog=scene.prog, scale=0.3)\n            # Render world 
origin.\n            render_cosy(PoseTW(), ctx=scene.ctx, prog=scene.prog, scale=1.0)\n            img = scene.finish()\n            lifter_imgs.append(np.array(img))\n        lifter_imgs = np.array(lifter_imgs)\n        if volume_feat is None:\n            return lifter_imgs\n        else:\n            return lifter_imgs, feats, verts_v, faces, normals_v\n\n    @torch.no_grad()\n    def render3d_points(\n        self,\n        p3s_w,\n        Ts_wr,\n        T_ws,\n        cams,\n        voxel_obb,\n        view=\"follow\",\n        values=None,\n        alpha_min=0.01,\n        mask=None,\n        more_p3s_w=None,\n        more2_p3s_w=None,\n    ):\n        \"\"\"Render a 3D visualization of the given voxel values and obbs.\"\"\"\n        if self.scene is None:\n            self.scene = SceneView(width=320, height=320)\n        scene = self.scene\n        lifter_imgs = []\n        black = (0.0, 0.0, 0.0, 1.0)\n        Ts_wc = Ts_wr @ cams.T_camera_rig.inverse()\n\n        for t, (T_wr, T_wc, cam) in enumerate(zip(Ts_wr, Ts_wc, cams)):\n            scene.clear()\n            if view == \"follow\":\n                scene.set_follow_view(T_wc, zoom_factor=8)\n            elif view == \"bird\":\n                scene.set_birds_eye_view(T_wc, zoom_factor=12)\n            else:\n                raise ValueError(\"bad option for 3d view style\")\n\n            if values is not None:\n                alphas = torch.ones_like(values[t])\n                if alpha_min is not None:\n                    alphas[values[t] < alpha_min] = 0\n                else:\n                    alpha_min = 0.0\n            if mask is not None:\n                p3s_wt = p3s_w[t][mask[t]]\n                if values is not None:\n                    values_t = values[t][mask[t]]\n                    alphas_t = alphas[mask[t]]\n            else:\n                p3s_wt = p3s_w[t]\n                if values is not None:\n                    values_t = values[t]\n                    alphas_t = 
alphas\n\n            if values is not None:\n                render_scalar_field_points(\n                    p3s_wt,\n                    values_t.float(),\n                    prog=scene.prog_scalar_field,\n                    ctx=scene.ctx,\n                    point_size=1.0,\n                    alphas=alphas_t,\n                    val_min=alpha_min,\n                )\n            else:\n                render_points(p3s_wt, (1.0, 0, 0, 1.0), scene.prog, scene.ctx, 1.0)\n\n            if more_p3s_w is not None:\n                render_points(\n                    more_p3s_w[t], (0.0, 1.0, 0, 0.5), scene.prog, scene.ctx, 1.0\n                )\n            if more2_p3s_w is not None:\n                render_points(\n                    more2_p3s_w[t], (0.0, 0.0, 1.0, 1.0), scene.prog, scene.ctx, 1.0\n                )\n            # draw voxel bounding volume\n            render_obb_line(\n                voxel_obb, scene.prog, scene.ctx, rgba=black, draw_cosy=True\n            )\n            # draw trajectory.\n            render_linestrip(Ts_wr.t, prog=scene.prog, ctx=scene.ctx, rgba=black)\n            render_frustum(T_wr, cam, prog=scene.prog, ctx=scene.ctx, rgba=black)\n            # Render snippet origin.\n            render_cosy(T_ws, ctx=scene.ctx, prog=scene.prog, scale=0.3)\n            # Render world origin.\n            render_cosy(PoseTW(), ctx=scene.ctx, prog=scene.prog, scale=1.0)\n            img = scene.finish()\n            lifter_imgs.append(np.array(img))\n        lifter_imgs = np.array(lifter_imgs)\n        return lifter_imgs\n\n    @torch.no_grad()\n    def render3d_obb(\n        self,\n        voxel_vals,\n        obb,\n        Ts_wr,\n        T_ws,\n        cams,\n        voxel_w,\n        voxel_obb,\n        view=\"follow\",\n        alpha_min=None,\n    ):\n        \"\"\"Render a 3D visualization of the given voxel values and obbs.\"\"\"\n        if self.scene is None:\n            self.scene = SceneView(width=320, height=320)\n     
   scene = self.scene\n        lifter_imgs = []\n        alphas = torch.ones_like(voxel_vals)\n        if alpha_min is not None:\n            alphas[voxel_vals < alpha_min] = 0\n        blue = (0.1, 0.1, 1.0, 1.0)\n        black = (0.0, 0.0, 0.0, 1.0)\n        Ts_wc = Ts_wr @ cams.T_camera_rig.inverse()\n        for T_wr, T_wc, cam in zip(Ts_wr, Ts_wc, cams):\n            scene.clear()\n            if view == \"follow\":\n                scene.set_follow_view(T_wc, zoom_factor=8)\n            elif view == \"bird\":\n                scene.set_birds_eye_view(T_wc, zoom_factor=8)\n            else:\n                raise ValueError(\"bad option for 3d view style\")\n            # draw obbs\n            if obb:\n                colors = get_colors_from_sem_map(self.sem2name, scale_to_255=False)\n                render_obbs_line(\n                    obb,\n                    scene.prog,\n                    scene.ctx,\n                    line_width=2,\n                    colors=colors,\n                    draw_cosy=True,\n                )\n            # draw voxel bounding volume\n            render_obb_line(\n                voxel_obb, scene.prog, scene.ctx, rgba=black, draw_cosy=True\n            )\n            # draw trajectory.\n            render_linestrip(Ts_wr.t, prog=scene.prog, ctx=scene.ctx, rgba=black)\n            render_frustum(T_wr, cam, prog=scene.prog, ctx=scene.ctx, rgba=black)\n            # Render snippet origin.\n            render_cosy(T_ws, ctx=scene.ctx, prog=scene.prog, scale=0.3)\n            # Render world origin.\n            render_cosy(PoseTW(), ctx=scene.ctx, prog=scene.prog, scale=1.0)\n            # \"scalar_field_points\" supports colored point cloud, will rescale based on min/max.\n            render_scalar_field_points(\n                voxel_w,\n                voxel_vals,\n                prog=scene.prog_scalar_field,\n                ctx=scene.ctx,\n                point_size=3,\n                alphas=alphas,\n            )\n  
          img = scene.finish()\n            lifter_imgs.append(np.array(img))\n        lifter_imgs = np.array(lifter_imgs)\n        return lifter_imgs\n\n    @torch.no_grad()\n    def render3d_occ(\n        self,\n        voxel_vals,\n        Ts_wr,\n        T_ws,\n        cams,\n        voxel_w,\n        voxel_obb,\n        view=\"follow\",\n        alpha_min=None,\n        voxel_mask=None,\n    ):\n        \"\"\"Render a 3D visualization of the given voxel values and obbs.\"\"\"\n        if self.scene is None:\n            self.scene = SceneView(width=320, height=320)\n        scene = self.scene\n        lifter_imgs = []\n        black = (0.0, 0.0, 0.0, 1.0)\n        Ts_wc = Ts_wr @ cams.T_camera_rig.inverse()\n        for t, (T_wr, T_wc, cam) in enumerate(zip(Ts_wr, Ts_wc, cams)):\n            if voxel_vals.ndim == 4:\n                v_vals = voxel_vals[t]\n            else:\n                v_vals = voxel_vals\n            alp = torch.ones_like(v_vals)\n            if alpha_min is not None:\n                if isinstance(alpha_min, torch.Tensor):\n                    alp_min = alpha_min[t]\n                else:\n                    alp_min = alpha_min\n                alp[v_vals < alp_min] = 0\n            else:\n                alp_min = 0.0\n\n            scene.clear()\n            if view == \"follow\":\n                scene.set_follow_view(T_wc, zoom_factor=4)\n            elif view == \"bird\":\n                scene.set_birds_eye_view(T_wc, zoom_factor=8)\n            else:\n                raise ValueError(\"bad option for 3d view style\")\n            # \"scalar_field_points\" supports colored point cloud, will rescale based on min/max.\n            if voxel_mask is not None:\n                render_scalar_field_points(\n                    voxel_w[voxel_mask.view(-1)],\n                    v_vals[voxel_mask].float(),\n                    prog=scene.prog_scalar_field,\n                    ctx=scene.ctx,\n                    point_size=3,\n            
        alphas=alp[voxel_mask].float(),\n                    val_min=alp_min,\n                )\n            else:\n                render_scalar_field_points(\n                    voxel_w,\n                    v_vals.float(),\n                    prog=scene.prog_scalar_field,\n                    ctx=scene.ctx,\n                    point_size=3,\n                    alphas=alp.float(),\n                    val_min=alp_min,\n                )\n            # draw voxel bounding volume\n            render_obb_line(\n                voxel_obb, scene.prog, scene.ctx, rgba=black, draw_cosy=True\n            )\n            # draw trajectory.\n            render_linestrip(Ts_wr.t, prog=scene.prog, ctx=scene.ctx, rgba=black)\n            render_frustum(T_wr, cam, prog=scene.prog, ctx=scene.ctx, rgba=black)\n            # Render snippet origin.\n            render_cosy(T_ws, ctx=scene.ctx, prog=scene.prog, scale=0.3)\n            # Render world origin.\n            render_cosy(PoseTW(), ctx=scene.ctx, prog=scene.prog, scale=1.0)\n            img = scene.finish()\n            lifter_imgs.append(np.array(img))\n        lifter_imgs = np.array(lifter_imgs)\n        return lifter_imgs\n\n    def get_log_images(self, batch, outputs):\n        B = len(batch[\"rgb/img\"])\n        with torch.no_grad():\n            # Visualize one random element from the batch.\n            batch_idx = torch.randint(low=0, high=B, size=(1,)).item()\n            log_ims = self.log_single(batch, outputs, batch_idx=batch_idx)\n        return log_ims\n\n    def reset_metrics(self):\n        self.metrics = {}\n\n        # obb\n        self.metrics[f\"precision@{self.iou_thres}\"] = []\n        self.metrics[f\"recall@{self.iou_thres}\"] = []\n\n        # occ\n        self.metrics[\"mesh/acc\"] = []\n        self.metrics[\"mesh/comp\"] = []\n        self.metrics[\"mesh/prec\"] = []\n        self.metrics[\"mesh/recall\"] = []\n\n    def update_metrics(self, outputs, batch):\n        # don't compute 
metrics on training since it takes long to compute.\n        if self.training:\n            return\n\n        obbs_pred = outputs[ARIA_OBB_PRED]\n        T_wv = outputs[\"voxel/T_world_voxel\"]\n        obbs_gt = get_gt_obbs(batch, self.ve, T_wv)\n        precs, recs = [], []\n        for obbs_pred_s, obbs_gt_s in zip(obbs_pred, obbs_gt):\n            prec, rec, _ = prec_recall_bb3(\n                obbs_pred_s.remove_padding(),\n                obbs_gt_s.remove_padding(),\n                iou_thres=self.iou_thres,\n            )\n\n            if prec != -1.0 and rec != -1.0:\n                precs.append(prec)\n                recs.append(rec)\n        self.metrics[f\"precision@{self.iou_thres}\"].extend(precs)\n        self.metrics[f\"recall@{self.iou_thres}\"].extend(recs)\n\n    def compute_metrics(self):\n        metrics = {}\n        if self.training:\n            return metrics\n\n        metrics[\"rgb\"] = {}\n        metrics[\"rgb\"][\"metrics\"] = {}\n        for key in self.metrics:\n            val = torch.tensor(self.metrics[key]).mean()\n            metrics[\"rgb\"][\"metrics\"][key] = val\n        return metrics\n"
  },
  {
    "path": "efm3d/model/image_tokenizer.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport math\nfrom typing import List\n\nimport einops\nimport torch\nimport torch.nn.functional as F\nfrom efm3d.model.dinov2_utils import dino_name_mappings, DinoV2Wrapper\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n# logger.setLevel(logging.DEBUG)\n\n\nclass ImageToDinoV2Tokens(torch.nn.Module):\n    \"\"\"\n    Tokenize an image snippet using DinoV2.\n    \"\"\"\n\n    def __init__(\n        self,\n        dinov2_name: str = \"vit_small\",\n        freeze: bool = False,\n        handle_rotated_data: bool = True,\n        dim_out: int = 768,  # ignored if add_linear_layer = False\n        add_lin_layer: bool = False,  # add a linear layer to get to any output dim\n        out_patch_size: int = 14,  # 14 is default but can set to 16 to get resampled into a more compatible feature size\n        multilayer_output: bool = False,  # if True, return a list of features\n        ckpt_path: str = \"\",  # if not empty, load the pretrained weights from the given path\n    ):\n        super().__init__()\n        assert dinov2_name in dino_name_mappings.keys()\n        self.freeze = freeze\n        self.handle_rotated_data = handle_rotated_data\n\n        self.model = DinoV2Wrapper(\n            dinov2_name, multilayer_output=multilayer_output, ckpt_path=ckpt_path\n        )\n\n        if self.freeze:\n        
    for param in self.model.parameters():\n                param.requires_grad_(False)\n            self.model.eval()\n\n        self.lin = None\n        if not add_lin_layer:\n            assert dim_out == self.model.feat_dim, (\n                f\"dim_out must match DinoV2 feature dim if not adding linear layer, but get dim_out: {dim_out} and feat_dim: {self.model.feat_dim}.\"\n            )\n        else:\n            self.lin = torch.nn.Linear(self.model.feat_dim, dim_out)\n            print(\n                f\"Add linear layer to project features from {self.model.feat_dim} to {dim_out}\"\n            )\n        self.dim_out = dim_out\n\n        logger.info(\n            f\"DinoV2 InputTokenizer {dinov2_name}, is frozen {freeze}, dim_out of {self.dim_out}\"\n        )\n        self.out_patch_size = out_patch_size\n\n    def feat_dim(self):\n        return self.dim_out\n\n    def patch_size(self):\n        return self.out_patch_size\n\n    def post_process(self, feats, B, T, out_size=None):\n        \"\"\"\n        Post processing to convert Dino features, e.g. 
feature interpolation to the desired size,\n        handling Aria image rotation, the linear mapping to increase feature dimension.\n\n        Args:\n            feats: [B x T x C x H x W]\n            B: batch size\n            T: number of frames\n            out_size: (h, w) token feature map output size, if None, don't resize the feature map size.\n        \"\"\"\n        if out_size is not None:\n            # resize to desired size\n            feats = F.interpolate(feats, out_size, mode=\"bilinear\")\n        if self.handle_rotated_data:\n            feats = torch.rot90(feats, 1, [-2, -1])\n        # to token sequence BxNxC\n        feats = einops.rearrange(feats, \"(b t) c h w -> b t h w c\", b=B, t=T)\n        if self.lin is not None:\n            # increase feature dimension to desired output dimension\n            feats = self.lin(feats)\n        return feats\n\n    def forward_resize(self, img: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Return the round-up image size to match a multiple of patch size, which will be used as the input size\n        to the DinoV2 model.\n\n        Args:\n            img: [..., H, W] image tensor\n        \"\"\"\n        H_ori, W_ori = img.shape[-2:]\n        # Dino models have a fixed patch size of 14\n        H_new = math.ceil(H_ori / 14) * 14\n        W_new = math.ceil(W_ori / 14) * 14\n        if H_new != H_ori or W_new != W_ori:\n            img = F.interpolate(\n                img, size=(H_new, W_new), mode=\"bilinear\", align_corners=False\n            )\n        return img\n\n    def forward(self, img: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            img: [B x T x C x H x W] A sequence / snippet of Image Frames (typically used for Pose Regression)\n        \"\"\"\n        assert img.dim() == 5, f\"expecting BxTxCxHxW but got {img.shape}\"\n        B, T, C, H, W = img.shape\n        if self.handle_rotated_data:\n            # rotate image 90 degrees clockwise to give it 
expected upright\n            # orientation for the pretrained DinoV2 model\n            img = torch.rot90(img, 1, [-1, -2])\n        # fold the time dim into the batch dim for the DinoV2 model\n        img = einops.rearrange(img, \"b t c h w -> (b t) c h w\")\n\n        H_ori, W_ori = img.shape[-2:]\n        img = self.forward_resize(img)\n        feats = self.model.forward(img)\n\n        out_size = None\n        # if out_patch_size is not 14, then we need to resize the feature map to the desired size\n        if self.patch_size() != 14:\n            out_size = H_ori // self.patch_size(), W_ori // self.patch_size()\n            if (\n                out_size[0] * self.patch_size() != H_ori\n                or out_size[1] * self.patch_size() != W_ori\n            ):\n                logger.warning(\n                    f\"Image size {(H_ori, W_ori)} not divisible by output patch size {self.patch_size()}\"\n                )\n\n        if isinstance(feats, List):\n            feats = [self.post_process(f, B, T, out_size) for f in feats]\n        else:\n            feats = self.post_process(feats, B, T, out_size)\n        return feats\n"
  },
  {
    "path": "efm3d/model/lifter.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport math\nfrom abc import ABC\nfrom typing import List, Literal, Optional\n\nimport numpy as np\nimport torch\nfrom efm3d.aria.aria_constants import (\n    ARIA_CALIB,\n    ARIA_IMG,\n    ARIA_IMG_T_SNIPPET_RIG,\n    ARIA_POINTS_WORLD,\n    ARIA_SNIPPET_T_WORLD_SNIPPET,\n)\nfrom efm3d.model.cnn import UpsampleCNN\nfrom efm3d.model.dpt import DPTOri\nfrom efm3d.utils.gravity import gravity_align_T_world_cam, GRAVITY_DIRECTION_VIO\nfrom efm3d.utils.image_sampling import sample_images\nfrom efm3d.utils.pointcloud import pointcloud_to_voxel_counts\nfrom efm3d.utils.ray import sample_depths_in_grid, transform_rays\nfrom efm3d.utils.voxel import create_voxel_grid\nfrom torch.nn import functional as F\n\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n# logger.setLevel(logging.DEBUG)\n\n\nclass VideoBackbone3d(torch.nn.Module, ABC):\n    \"\"\"\n    Abstract Video Backbone that creates an explicit 3D feature volume from a video stream.\n    \"\"\"\n\n    def __init__(\n        self,\n        feat_dim: int,\n    ):\n        \"\"\"\n        Args:\n            feat_dim: number of channels in voxel grid, the C in BxCxDxHxW\n        \"\"\"\n        super().__init__()\n        self._feat_dim = feat_dim\n\n    @property\n    def feat_dim(self):\n        return self._feat_dim\n\n    def forward_impl(self, batch):\n   
     pass\n\n    def forward(self, batch):\n        out = {}\n\n        assert \"rgb/feat\" in batch, \"must run 2d backbone to get rgb feature maps first\"\n\n        out.update(self.forward_impl(batch))\n\n        # Shaped B x C x D x H x W\n        assert \"voxel/feat\" in out, \"3d backbone must output voxel features\"\n\n        # Shaped B x N x 3 (where N=D*H*W)\n        assert \"voxel/pts_world\" in out, \"3d backbone must output voxel positions\"\n\n        # Shaped B x 12 (PoseTW object)\n        assert \"voxel/T_world_voxel\" in out, \"3d backbone must output voxel coord frame\"\n\n        return out\n\n\nclass Lifter(VideoBackbone3d):\n    \"\"\"\n    Abstract Video Backbone that creates an explicit 3D feature volume from a set of 2D features.\n    \"\"\"\n\n    def __init__(\n        self,\n        in_dim: int,\n        out_dim: int,\n        patch_size: int,\n        voxel_size: List[float],\n        voxel_extent: List[float],\n        head_type: Literal[\"none\", \"dpt_ori\", \"cnn\"] = \"cnn\",\n        streams: Optional[List[str]] = None,  # default is just rgb stream\n        joint_slam_streams: bool = False,\n        joint_streams: bool = False,  # join all streams\n    ):\n        \"\"\"\n        Args:\n            in_dim: input feature dimension (the 2d image or feature image channel dim)\n            out_dim: output feature dimension (in 3d volume - FPN in 2D is used to get to that dim)\n            patch_size: size of the patch to use for upsampling\n            voxel_size: size of the voxel grid (D, H, W)\n            voxel_extent: extent of the voxel grid (x_min, x_max, y_min, y_max, z_min, z_max)\n            head_type: 2D feature upsampling head to use (\"none\", \"dpt_ori\" or \"cnn\")\n            streams: list of streams to use for the 2D features (\"rgb\", \"slaml\", \"slamr\"). Lifting gets run per stream (unless joint_slam_streams is True) and then concatenated in 3d.\n            joint_slam_streams: if True, use the slaml and slamr streams as a single stream (i.e. 
don't concatenate lifted volumes from the two slam streams - lift them as if they were one camera)\n            joint_streams: if True, join all streams into a single stream\n        \"\"\"\n\n        super().__init__(in_dim)\n        self.streams = streams\n        if streams is None:\n            self.streams = [\"rgb\"]  # default is just rgb stream\n        self.stream2id = {\"rgb\": 0, \"slaml\": 1, \"slamr\": 2}\n        # feature map upsampling network\n        final_dim = out_dim\n\n        if head_type == \"none\":\n            self.head = None\n            self.out_dim = in_dim\n        elif head_type == \"cnn\":\n            assert patch_size > 0, f\"{patch_size} should be > 0 for UpsampleCNN\"\n            upsample_power = np.sqrt(patch_size)\n            logger.info(\"True upsample_power: %f\" % upsample_power)\n            upsample_power = int(round(upsample_power))\n            logger.info(\"Rounded upsample_power: %d\" % upsample_power)\n            self.head = UpsampleCNN(\n                input_dim=in_dim,\n                first_hidden_dim=-1,\n                final_dim=final_dim,\n                upsample_power=upsample_power,\n                fix_hidden_dim=False,\n            )\n            self.out_dim = out_dim\n        elif head_type == \"dpt_ori\":\n            self.head = DPTOri(\n                input_dim=in_dim,\n                output_dim=final_dim,\n                depth=False,\n            )\n            self.out_dim = out_dim\n        else:\n            raise ValueError(f\"{head_type} is not supported\")\n\n        self.voxel_size = voxel_size  # D x H x W\n        self.voxel_extent = list(voxel_extent)  # W x H x D\n        self.joint_streams = joint_streams\n        self.joint_slam_streams = (\n            joint_slam_streams and \"slaml\" in self.streams and \"slamr\" in self.streams\n        )\n\n        x_meters = (voxel_extent[1] - voxel_extent[0]) / self.voxel_size[2]\n        y_meters = (voxel_extent[3] - voxel_extent[2]) / self.voxel_size[1]\n        z_meters = (voxel_extent[5] - voxel_extent[4]) / 
self.voxel_size[0]\n        assert abs(x_meters - y_meters) < 1e-5 and abs(x_meters - z_meters) < 1e-5, (\n            f\"Voxels should be cubes {x_meters}x{y_meters}x{z_meters}\"\n        )\n        self.voxel_meters = x_meters\n        self.num_free_samples = 16\n\n    def output_dim(self):\n        num_streams = len(self.streams)\n        if self.joint_slam_streams:\n            num_streams -= 1\n        if self.joint_streams:\n            num_streams = 1\n        out_dim = 0\n        out_dim = self.out_dim * num_streams\n\n        out_dim += 1  # point mask\n        out_dim += 1  # freespace token\n        return out_dim\n\n    def get_freespace_world(self, batch, batch_idx, T_wv, vW, vH, vD, S=1):\n        \"\"\"\n        Get points (semi-dense or GT points) of a snippet in the batch.\n        \"\"\"\n        cams = batch[ARIA_CALIB[0]][batch_idx]\n        T_ws = batch[ARIA_SNIPPET_T_WORLD_SNIPPET][\n            batch_idx\n        ]  # T_world_rig (one per snippet)\n        Ts_sr = batch[ARIA_IMG_T_SNIPPET_RIG[0]][\n            batch_idx\n        ]  # Ts_snippet_rig (T per snippet)\n        Ts_wr = T_ws @ Ts_sr\n        Ts_wc = Ts_wr @ cams.T_camera_rig.inverse()  # Ts_world_cam\n\n        # compute rays and max depths\n        p_w = batch[ARIA_POINTS_WORLD][batch_idx]  # TxNx3\n        T, N = p_w.shape[:2]\n        p0_w = Ts_wc.t.unsqueeze(1)  # Tx1x3\n        diff_w = p_w - p0_w\n        ds = torch.norm(diff_w, 2.0, dim=-1)\n        dir_w = F.normalize(diff_w, 2.0, dim=-1)\n        # filter out nans\n        good = ~p_w.isnan().any(dim=-1)\n        p0_w = p0_w.repeat(1, N, 1)[good]\n        ds = ds[good]\n        dir_w = dir_w[good]\n        rays_w = torch.cat([p0_w, dir_w], dim=-1)\n        rays_v = transform_rays(rays_w, T_wv.inverse())\n\n        x_min, x_max, y_min, y_max, z_min, z_max = self.voxel_extent\n        dW = (x_max - x_min) / vW\n        dH = (y_max - y_min) / vH\n        dD = (z_max - z_min) / vD\n        diag = math.sqrt(dW**2 + dH**2 + 
dD**2)\n        # subtract diagonal of voxel size to not label the occupied voxel as free\n        ds = ds - diag\n        # sample depths that lie within the feature volume grid (same function as used for nerf3d!)\n        depths, _, _ = sample_depths_in_grid(\n            rays_v.view(1, 1, -1, 6),\n            ds.view(1, 1, -1),\n            self.voxel_extent,\n            vW,\n            vH,\n            vD,\n            S,\n        )\n        depths = depths.view(-1, S)\n        rays_v = rays_v.view(-1, 1, 6)\n        pts_v = rays_v[..., :3] + depths.unsqueeze(-1) * rays_v[..., 3:]\n        pts_v = pts_v.view(-1, 3)\n        return T_wv * pts_v\n\n    def get_points_world(self, batch, batch_idx, keep_T=False):\n        \"\"\"\n        Get points (semi-dense or GT points) of a snippet in the batch.\n        \"\"\"\n\n        def filter_points(p_w):\n            p_w = p_w.reshape(-1, 3)\n            # filter out nans\n            bad = p_w.isnan().any(dim=-1)\n            p_w = p_w[~bad]\n            # filter out duplicates from the collapsing of the time dimension\n            p_w = torch.unique(p_w, dim=0)\n            return p_w\n\n        p_w_Ts = []\n        p_w = batch[ARIA_POINTS_WORLD][batch_idx]\n        if not keep_T:\n            p_w = filter_points(p_w)\n        else:\n            T = p_w.shape[0]\n            for t in range(T):\n                p_w_t = p_w[t, ...]\n                p_w_t = filter_points(p_w_t)\n                p_w_Ts.append(p_w_t)\n\n        if keep_T:\n            return p_w_Ts\n        else:\n            return p_w\n\n    def get_freespace_counts(\n        self,\n        batch,\n        T_wv,\n        vW,\n        vH,\n        vD,\n        MAX_NUM_POINTS_VOXEL=50,\n        return_mask=False,\n    ):\n        \"\"\"\n        Get free-space samples as a voxel grid where each voxel is assigned a count of how many samples fall inside it.\n        If return_mask is True, the function returns a binary mask instead of normalized counts.\n        
\"\"\"\n        B, T, _, H, W = batch[ARIA_IMG[0]].shape\n        point_counts = []\n        for b in range(B):\n            p_w = self.get_freespace_world(\n                batch, b, T_wv[b], vW, vH, vD, self.num_free_samples\n            )\n            # transform points into voxel coordinate.\n            p_v = T_wv[b].inverse() * p_w\n            point_count = pointcloud_to_voxel_counts(p_v, self.voxel_extent, vW, vH, vD)\n            point_counts.append(point_count)\n        point_counts = torch.stack(point_counts, dim=0)  # B x 1 x vD, vH, vW\n        # Normalize\n        point_counts = (\n            point_counts.clamp(0, MAX_NUM_POINTS_VOXEL) / MAX_NUM_POINTS_VOXEL\n        )\n        if return_mask:\n            # Only use as a mask. Comment out if want to use real point counts.\n            point_counts[point_counts > 1e-4] = 1.0\n\n        return point_counts\n\n    def get_points_counts(\n        self,\n        batch,\n        T_wv,\n        vW,\n        vH,\n        vD,\n        MAX_NUM_POINTS_VOXEL=50,\n        return_mask=False,\n        keep_T=False,\n    ):\n        \"\"\"\n        Get points as voxel grid where each voxel is assigned a count of how many points are inside it.\n        If return_mask is trued the function returns the binary occupancy instead of point counts.\n        \"\"\"\n        B, T, _, H, W = batch[ARIA_IMG[0]].shape\n        point_counts = []\n        for b in range(B):\n            p_w = self.get_points_world(batch, b, keep_T)\n            if not keep_T:\n                assert isinstance(p_w, torch.Tensor)\n                # transform points into voxel coordinate.\n                p_v = T_wv[b].inverse() * p_w\n                point_count = pointcloud_to_voxel_counts(\n                    p_v, self.voxel_extent, vW, vH, vD\n                )\n            else:\n                assert isinstance(p_w, list)\n                point_count = []\n                for p_w_t in p_w:\n                    p_v_t = T_wv[b].inverse() * 
p_w_t\n                    point_count_t = pointcloud_to_voxel_counts(\n                        p_v_t, self.voxel_extent, vW, vH, vD\n                    )\n                    point_count.append(point_count_t)\n                point_count = torch.cat(point_count, dim=0)\n            point_counts.append(point_count)\n        point_counts = torch.stack(point_counts, dim=0)  # B x 1 x vD, vH, vW\n        # Normalize\n        point_counts = (\n            point_counts.clamp(0, MAX_NUM_POINTS_VOXEL) / MAX_NUM_POINTS_VOXEL\n        )\n        if return_mask:\n            # Only use as a mask. Comment out if want to use real point counts.\n            point_counts[point_counts > 1e-4] = 1.0\n\n        return point_counts\n\n    def get_voxelgrid_pose(self, cams, T_ws, Ts_sr):\n        B, T = cams.shape[:2]\n        Ts_wr = T_ws @ Ts_sr\n        Ts_wc = Ts_wr @ cams.T_camera_rig.inverse()  # Ts_world_cam\n        # Select last frame in snippet.\n        selectT = torch.tensor(T - 1).repeat(B).long()\n\n        # Create the voxel grid by aligning selected frame with gravity.\n        T_wc_select = Ts_wc[torch.arange(B), selectT, :]\n        T_wv = gravity_align_T_world_cam(\n            T_wc_select, gravity_w=GRAVITY_DIRECTION_VIO, z_grav=True\n        )\n        # T_wv should only have yaw value\n        rpy = T_wv.to_euler()\n        assert torch.allclose(torch.tensor(0.0), rpy[:, :2], atol=1e-4)\n        return T_wv, selectT\n\n    def lift(self, feats2d, vox_w, cam, Ts_wr, vD, vH, vW):\n        B, T = cam.shape[:2]\n        F = feats2d.shape[2]\n        Ts_wc = Ts_wr @ cam.T_camera_rig.inverse()  # Ts_world_cam\n        vox_w = torch.flatten(vox_w, 0, 1)\n        cam = torch.flatten(cam, 0, 1)\n        Ts_wc = torch.flatten(Ts_wc, 0, 1)\n        feats2d = torch.flatten(feats2d, 0, 1)\n\n        vox_cam = Ts_wc.inverse() * vox_w\n        vox_feats, vox_valid = sample_images(\n            feats2d, vox_cam, cam, n_by_c=False, warn=False, single_channel_mask=True\n        
)\n        vox_feats = vox_feats.reshape(B, T, F, vD, vH, vW)\n        vox_valid = vox_valid.reshape(B, T, 1, vD, vH, vW)\n        return vox_feats, vox_valid\n\n    def aggregate(self, vox_feats, vox_valid):\n        def basic_mean(x, dim, valid, keepdim=False):\n            count = torch.sum(valid, dim=dim, keepdim=True)  # B 1 C D H W\n            invalid = (~valid).expand_as(x)\n            x[invalid] = 0.0\n            x_sum = torch.sum(x, dim=dim, keepdim=True)\n            count[count == 0] = 1.0  # just so we dont divide by zero\n            mean = x_sum / count\n            del x_sum\n            mean[count.expand_as(mean) < 1] = 0.0\n            if not keepdim:\n                return mean.squeeze(dim), count.squeeze(dim)\n            return mean, count\n\n        vox_feats, count_feats_m = basic_mean(\n            vox_feats, 1, valid=vox_valid, keepdim=False\n        )\n        return vox_feats, count_feats_m[:, [0]]\n\n    def lift_aggregate_centers(self, batch, feats2d, vox_w, Ts_wr, T_wv=None):\n        vD, vH, vW = self.voxel_size\n        B, T = batch[ARIA_IMG[0]].shape[:2]\n        # Lift to 3D. 
Project 3D voxel centers into each image and sample.\n        vox_w = vox_w.reshape(B, 1, -1, 3).repeat(1, T, 1, 1)\n        vox_feats, vox_valid, stream2pos = [], [], {}\n        for stream in self.streams:\n            stream_id = self.stream2id[stream]\n            cam = batch[ARIA_CALIB[stream_id]]\n            _vox_feats, _vox_valid = self.lift(\n                feats2d[stream], vox_w, cam, Ts_wr, vD, vH, vW\n            )\n            stream2pos[stream] = len(vox_feats)\n            vox_feats.append(_vox_feats)\n            vox_valid.append(_vox_valid)\n        if self.joint_slam_streams:\n            vox_feats_rgb, vox_valid_rgb = None, None\n            if \"rgb\" in stream2pos:\n                i = stream2pos[\"rgb\"]\n                vox_feats_rgb, vox_valid_rgb = vox_feats[i], vox_valid[i]\n            vox_feats_slam = [\n                vox_feats[stream2pos[stream]] for stream in [\"slaml\", \"slamr\"]\n            ]\n            vox_valid_slam = [\n                vox_valid[stream2pos[stream]] for stream in [\"slaml\", \"slamr\"]\n            ]\n            vox_feats_slam = torch.cat(vox_feats_slam, 1)\n            vox_valid_slam = torch.cat(vox_valid_slam, 1)\n            count_feats = torch.sum(vox_valid_slam, dim=1, keepdim=True)  # B 1 C D H W\n            if vox_valid_rgb is not None:\n                # add the valid rgb projections (not the slam ones a second time)\n                count_feats = count_feats + torch.sum(\n                    vox_valid_rgb, dim=1, keepdim=True\n                )\n            vox_feats_m, vox_valid_m = vox_feats_slam, vox_valid_slam\n            vox_feats, count_feats_m = self.aggregate(vox_feats_m, vox_valid_m)\n            if vox_valid_rgb is not None:\n                vox_feats_m, vox_valid_m = vox_feats_rgb, vox_valid_rgb\n                vox_feats_rgb, count_feats_rgb_m = self.aggregate(\n                    vox_feats_m, vox_valid_m\n                )\n                vox_feats = torch.cat([vox_feats, vox_feats_rgb], 1)\n                count_feats_m = count_feats_m + count_feats_rgb_m\n     
   elif self.joint_streams:\n            vox_feats = torch.cat(vox_feats, 1)\n            vox_valid = torch.cat(vox_valid, 1)\n            # Sum up number of valid projections into each camera for each voxel.\n            count_feats = torch.sum(vox_valid, dim=1, keepdim=True)  # B 1 C D H W\n            vox_feats_m, vox_valid_m = vox_feats, vox_valid\n            vox_feats, count_feats_m = self.aggregate(vox_feats_m, vox_valid_m)\n        else:\n            # concat lifted volumes for all selected video streams\n            vox_feats = torch.cat(vox_feats, 2)\n            vox_valid = torch.cat(vox_valid, 2)  # B T C D H W\n            # Sum up number of valid projections into each camera for each voxel.\n            count_feats = torch.sum(vox_valid, dim=1, keepdim=True)  # B 1 C D H W\n            vox_feats_m, vox_valid_m = vox_feats, vox_valid\n            vox_feats, count_feats_m = self.aggregate(vox_feats_m, vox_valid_m)\n        count_feats = count_feats[:, :, 0]\n        assert count_feats.shape == (B, 1, vD, vH, vW), f\"{count_feats.shape}\"\n        assert count_feats_m.shape == (B, 1, vD, vH, vW), f\"{count_feats_m.shape}\"\n        return vox_feats, count_feats, count_feats_m\n\n    def forward(self, batch):\n        B, T, _, H, W = batch[ARIA_IMG[0]].shape\n\n        # Run the head (if any) on the 2D features to upsample them back toward full resolution.\n        feats2d = {}\n        tokens2d = {}\n        for stream in self.streams:\n            feats2d[stream] = batch[f\"{stream}/feat\"]\n            # for visualizations\n            if not isinstance(feats2d[stream], list):\n                tokens2d[stream] = feats2d[stream].detach().cpu()\n            else:\n                # multi-layer 2d features. 
Needed by DPT head in Lifter\n                tokens2d[stream] = [f.detach().cpu() for f in feats2d[stream]]\n            if self.head:\n                feats2d[stream] = self.head.forward(feats2d[stream])\n\n        # Compute voxel grid pose.\n        cams = batch[ARIA_CALIB[0]]\n        device = cams.device\n        T_ws = batch[ARIA_SNIPPET_T_WORLD_SNIPPET]  # T_world_rig (one per snippet)\n        Ts_sr = batch[ARIA_IMG_T_SNIPPET_RIG[0]]  # Ts_snippet_rig (T per snippet)\n        Ts_wr = T_ws @ Ts_sr\n        T_wv, selectT = self.get_voxelgrid_pose(cams, T_ws, Ts_sr)\n\n        # Generate voxel grid.\n        vD, vH, vW = self.voxel_size\n        point_info = []\n        point_masks = self.get_points_counts(batch, T_wv, vW, vH, vD, return_mask=True)\n        point_info.append(point_masks)\n        free_masks = self.get_freespace_counts(\n            batch, T_wv, vW, vH, vD, return_mask=True\n        )\n        point_info.append(free_masks)\n        vox_v_orig = create_voxel_grid(vW, vH, vD, self.voxel_extent, device)\n        vox_v_orig = vox_v_orig.permute(2, 1, 0, 3)  # D H W 3\n        vox_v = vox_v_orig.reshape(-1, 3)\n        vox_v = vox_v.unsqueeze(0).repeat(B, 1, 1)\n        vox_w = T_wv * vox_v\n        vox_w = vox_w.reshape(B, vD, vH, vW, 3)\n        vox_w = vox_w.reshape(B, -1, 3)  # B DHW 3\n\n        if len(feats2d) > 0:\n            # Lift image features to 3D. 
Project 3D voxel centers into each\n            # image and sample.\n            vox_feats, count_feats, count_feats_m = self.lift_aggregate_centers(\n                batch,\n                feats2d,\n                vox_w,\n                Ts_wr,\n                T_wv,\n            )\n            vox_feats = torch.concatenate([vox_feats] + point_info, dim=1)\n        else:\n            vox_feats = torch.concatenate(point_info, dim=1)\n            count_feats = torch.ones(B, 1, vD, vH, vW, device=device)\n            count_feats_m = torch.ones(B, 1, vD, vH, vW, device=device)\n        out = {}\n\n        # Don't use the masked out versions (_m) because loss functions later on need these.\n        for stream, feat2d in feats2d.items():\n            out[f\"{stream}/feat2d_upsampled\"] = feat2d\n        for stream, token2d in tokens2d.items():\n            out[f\"{stream}/token2d\"] = token2d\n\n        out[\"voxel/feat\"] = vox_feats  # B x F x D x H x W\n        out[\"voxel/counts\"] = count_feats[:, 0]  # B x D x H x W\n        # Pass the masked version of counts for debugging.\n        out[\"voxel/counts_m\"] = count_feats_m[:, 0]  # B x D x H x W\n        # We don't need the repeat across time anymore.\n        vox_w = vox_w.reshape(B, vD * vH * vW, 3)\n        out[\"voxel/pts_world\"] = vox_w  # B x N x 3 (N=D*H*W)\n        out[\"voxel/T_world_voxel\"] = T_wv  # B x 12\n        out[\"voxel/selectT\"] = selectT  # B x 1 (frame that voxel grid is anchored to)\n        out[\"voxel/occ_input\"] = point_info[0]\n        return out\n"
  },
  {
    "path": "efm3d/model/video_backbone.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nfrom abc import ABC, abstractproperty\nfrom typing import Dict, List, Optional\n\nimport einops\nimport torch\nimport torch.nn as nn\nfrom efm3d.aria.aria_constants import ARIA_IMG\nfrom hydra.utils import instantiate\nfrom omegaconf import DictConfig\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n# logger.setLevel(logging.DEBUG)\n\n\nclass VideoBackbone(torch.nn.Module, ABC):\n    \"\"\"\n    Snippet Feature Backbone runs image feature extractors for video snippets.\n    This lets us easily try out various different backbones.\n    \"\"\"\n\n    def __init__(\n        self,\n        video_streams: Optional[List[str]] = None,\n        pass_batch: bool = True,\n        feat_dim: Optional[int] = None,\n        correct_vignette: bool = False,\n        optimize_vignette: bool = False,\n        ensure_rgb: bool = False,\n    ):\n        \"\"\"\n        Args:\n            video_streams: a list of video streams to extract features for.\n                Supported is \"rgb\", \"slaml\", \"slamr\".\n            pass_batch: pass whole batch dict to the forward_impl if set to\n                true. 
Otherwise passing the image tensors associated with the stream,\n                instead of passing a dictionary of batch.\n            correct_vignette: correct vignette for the image streams.\n            optimize_vignette: optimize vignette correction for the image streams. This enables backpropagating into the vignettes.\n            ensure_rgb: if set to true, will ensure that the output streams are all 3 channels.\n        \"\"\"\n        super().__init__()\n        self.ensure_rgb = ensure_rgb\n        self._feat_dim = -1\n        if feat_dim is not None:\n            # Note that FPN will be constructed if feat_dim is passed in by construction (and fpn_levels > 0).\n            self.feat_dim = feat_dim\n        self.video_streams = video_streams\n        if self.video_streams is None:\n            self.video_streams = [\"rgb\"]\n        self.pass_batch = pass_batch\n        self.stream_to_id = {\"rgb\": 0, \"slaml\": 1, \"slamr\": 2}\n        assert set(self.video_streams).issubset(set(self.stream_to_id.keys())), (\n            f\"{self.video_streams} are not all valid (need to be a subset of {self.stream_to_id.keys()})\"\n        )\n\n        self.vignette_correction = {}\n        self.vignette_correction = nn.ModuleDict(self.vignette_correction)\n\n    @property\n    def feat_dim(self):\n        return self._feat_dim\n\n    @feat_dim.setter\n    def feat_dim(self, _feat_dim: int):\n        self._feat_dim = _feat_dim\n\n    @abstractproperty\n    def patch_size(self):\n        pass\n\n    def forward_impl(self, img, stream) -> Dict[str, torch.Tensor]:\n        \"\"\"\n        forward_impl should return a dict with keys of the desired streams mapping to the extracted feature images.\n        Other additional outputs can be added as well as needed. 
A suggested way\n        to return additional outputs is to nest their keys under the\n        corresponding streams such as: \"rgb/feature_scale2\" for additional\n        feature outputs for the rgb stream.\n        \"\"\"\n        pass\n\n    def forward(self, batch):\n        out = {}\n        if self.pass_batch:\n            out = self.forward_impl(batch, self.video_streams)\n        else:\n            for stream in self.video_streams:\n                # if we have a batch dictionary retrieve the corresponding video. If not assume that we are just\n                key = ARIA_IMG[self.stream_to_id[stream]]\n                if isinstance(batch, dict) and key in batch:\n                    im = batch[key]\n                elif isinstance(batch, torch.Tensor) and len(self.video_streams) == 1:\n                    im = batch\n                else:\n                    raise ValueError(\n                        f\"batch not passed correctly {type(batch)} for video streams {self.video_streams}, {key}\"\n                    )\n                if self.ensure_rgb and stream in [\"slaml\", \"slamr\"]:\n                    # greyscale -> rgb\n                    im = torch.cat([im, im, im], 2)\n                # correct vignette if desired\n                if stream in self.vignette_correction:\n                    im = self.vignette_correction[stream](im)\n                # accumulate updates into one flat dict\n                out.update(self.forward_impl(im, stream))\n\n        assert isinstance(out, dict), (\n            f\"Output of forward must be of type dict, got {type(out)}\"\n        )\n        assert set(self.video_streams).issubset(set(out.keys()))\n        return out\n\n\nclass VideoBackboneDinov2(VideoBackbone):\n    \"\"\"\n    Get a snippet feature extractor from Dino v2.\n    \"\"\"\n\n    def __init__(\n        self,\n        image_tokenizer: DictConfig,\n        video_streams: Optional[List[str]] = None,\n        freeze_encoder: bool = False,\n        
correct_vignette: bool = False,\n        optimize_vignette: bool = False,\n    ):\n        super().__init__(\n            video_streams=video_streams,\n            pass_batch=False,\n            correct_vignette=correct_vignette,\n            optimize_vignette=optimize_vignette,\n        )\n        self.image_tokenizer = image_tokenizer\n        if isinstance(image_tokenizer, DictConfig):\n            self.image_tokenizer = instantiate(self.image_tokenizer)\n\n        # assert freeze_encoder == self.image_tokenizer.freeze\n\n        # get feature dimension\n        self.feat_dim = self.image_tokenizer.feat_dim()\n        self._patch_size = self.image_tokenizer.patch_size()\n        logging.info(\"feature dim is %d\" % self.feat_dim)\n        logging.info(\"down_scale factor is %d\" % self.patch_size)\n\n    @property\n    def patch_size(self):\n        return self._patch_size\n\n    def forward_impl(self, img, stream):\n        # Run tokenizer. handles SLAM images internally.\n        img_tokens = self.image_tokenizer.forward(img)\n        # BxTxHxWxC -> B, T, C, H, W\n\n        if isinstance(img_tokens, List):\n            return {\n                stream: [\n                    einops.rearrange(t, \"b t h w c -> b t c h w\") for t in img_tokens\n                ]\n            }\n\n        return {stream: einops.rearrange(img_tokens, \"b t h w c -> b t c h w\")}\n"
  },
  {
    "path": "efm3d/thirdparty/__init__.py",
    "content": ""
  },
  {
    "path": "efm3d/thirdparty/mmdetection3d/LICENSE",
    "content": "Copyright 2018-2019 Open-MMLab. All rights reserved.\n\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" 
shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright 2018-2019 Open-MMLab.\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "efm3d/thirdparty/mmdetection3d/__init__.py",
    "content": ""
  },
  {
    "path": "efm3d/thirdparty/mmdetection3d/cuda/cuda_utils.h",
    "content": "// @lint-ignore-every LICENSELINT\n\n#ifndef _CUDA_UTILS_H\n#define _CUDA_UTILS_H\n\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <cmath>\n#include <vector>\n\n#define TOTAL_THREADS 512\n\ninline int opt_n_thread(int work_size) {\n  const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);\n  return max(min(1 << pow_2, TOTAL_THREADS), 1);\n}\n\ninline dim3 opt_block_config(int x, int y) {\n  const int x_thread = opt_n_thread(x);\n  const int y_thread = max(min(opt_n_thread(y), TOTAL_THREADS / x_thread), 1);\n  dim3 block_config(x_thread, y_thread, 1);\n\n  return block_config;\n}\n\n#define CUDA_CHECK_ERRORS()                              \\\n  do {                                                   \\\n    cudaError_t err = cudaGetLastError();                \\\n    if (cudaSuccess != err) {                            \\\n      fprintf(                                           \\\n          stderr,                                        \\\n          \"CUDA kernel failed : %s\\n%s at L:%d in %s\\n\", \\\n          cudaGetErrorString(err),                       \\\n          __PRETTY_FUNCTION__,                           \\\n          __LINE__,                                      \\\n          __FILE__);                                     \\\n      exit(-1);                                          \\\n    }                                                    \\\n  } while (0)\n\n#endif\n"
  },
  {
    "path": "efm3d/thirdparty/mmdetection3d/cuda/iou3d.cpp",
    "content": "// @lint-ignore-every LICENSELINT\n\n// Modified from\n// https://github.com/open-mmlab/OpenPCDet/blob/master/pcdet/ops/iou3d_nms/src/iou3d_nms.cpp\n// License under Apache 2.0\n// https://github.com/open-mmlab/OpenPCDet/blob/master/LICENSE\n\n#include <cuda.h>\n#include <cuda_runtime_api.h>\n#include <torch/extension.h>\n#include <torch/serialize/tensor.h>\n\n#include <vector>\n\n#define CHECK_CUDA(x) \\\n  TORCH_CHECK(x.device().is_cuda(), #x, \" must be a CUDAtensor \")\n#define CHECK_CONTIGUOUS(x) \\\n  TORCH_CHECK(x.is_contiguous(), #x, \" must be contiguous \")\n#define CHECK_INPUT(x) \\\n  CHECK_CUDA(x);       \\\n  CHECK_CONTIGUOUS(x)\n\n#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))\n\n#define CHECK_ERROR(ans) \\\n  { gpuAssert((ans), __FILE__, __LINE__); }\ninline void\ngpuAssert(cudaError_t code, const char* file, int line, bool abort = true) {\n  if (code != cudaSuccess) {\n    fprintf(\n        stderr, \"GPUassert: %s %s %d\\n\", cudaGetErrorString(code), file, line);\n    if (abort)\n      exit(code);\n  }\n}\n\nconst int THREADS_PER_BLOCK_NMS = sizeof(unsigned long long) * 8;\n\nvoid boxesoverlapLauncher(\n    const int num_a,\n    const float* boxes_a,\n    const int num_b,\n    const float* boxes_b,\n    float* ans_overlap);\nvoid boxesioubevLauncher(\n    const int num_a,\n    const float* boxes_a,\n    const int num_b,\n    const float* boxes_b,\n    float* ans_iou);\nvoid nmsLauncher(\n    const float* boxes,\n    unsigned long long* mask,\n    int boxes_num,\n    float nms_overlap_thresh);\nvoid nmsNormalLauncher(\n    const float* boxes,\n    unsigned long long* mask,\n    int boxes_num,\n    float nms_overlap_thresh);\n\nint boxes_overlap_bev_gpu(\n    at::Tensor boxes_a,\n    at::Tensor boxes_b,\n    at::Tensor ans_overlap) {\n  // params boxes_a: (N, 5) [x1, y1, x2, y2, ry]\n  // params boxes_b: (M, 5)\n  // params ans_overlap: (N, M)\n\n  CHECK_INPUT(boxes_a);\n  CHECK_INPUT(boxes_b);\n  CHECK_INPUT(ans_overlap);\n\n  
int num_a = boxes_a.size(0);\n  int num_b = boxes_b.size(0);\n\n  const float* boxes_a_data = boxes_a.data_ptr<float>();\n  const float* boxes_b_data = boxes_b.data_ptr<float>();\n  float* ans_overlap_data = ans_overlap.data_ptr<float>();\n\n  boxesoverlapLauncher(\n      num_a, boxes_a_data, num_b, boxes_b_data, ans_overlap_data);\n\n  return 1;\n}\n\nint boxes_iou_bev_gpu(\n    at::Tensor boxes_a,\n    at::Tensor boxes_b,\n    at::Tensor ans_iou) {\n  // params boxes_a: (N, 5) [x1, y1, x2, y2, ry]\n  // params boxes_b: (M, 5)\n  // params ans_iou: (N, M)\n\n  CHECK_INPUT(boxes_a);\n  CHECK_INPUT(boxes_b);\n  CHECK_INPUT(ans_iou);\n\n  int num_a = boxes_a.size(0);\n  int num_b = boxes_b.size(0);\n\n  const float* boxes_a_data = boxes_a.data_ptr<float>();\n  const float* boxes_b_data = boxes_b.data_ptr<float>();\n  float* ans_iou_data = ans_iou.data_ptr<float>();\n\n  boxesioubevLauncher(num_a, boxes_a_data, num_b, boxes_b_data, ans_iou_data);\n\n  return 1;\n}\n\nint nms_gpu(\n    at::Tensor boxes,\n    at::Tensor keep,\n    float nms_overlap_thresh,\n    int device_id) {\n  // params boxes: (N, 5) [x1, y1, x2, y2, ry]\n  // params keep: (N)\n\n  CHECK_INPUT(boxes);\n  CHECK_CONTIGUOUS(keep);\n  cudaSetDevice(device_id);\n\n  int boxes_num = boxes.size(0);\n  const float* boxes_data = boxes.data_ptr<float>();\n  long* keep_data = keep.data_ptr<long>();\n\n  const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);\n\n  unsigned long long* mask_data = NULL;\n  CHECK_ERROR(cudaMalloc(\n      (void**)&mask_data, boxes_num * col_blocks * sizeof(unsigned long long)));\n  nmsLauncher(boxes_data, mask_data, boxes_num, nms_overlap_thresh);\n\n  // unsigned long long mask_cpu[boxes_num * col_blocks];\n  // unsigned long long *mask_cpu = new unsigned long long [boxes_num *\n  // col_blocks];\n  std::vector<unsigned long long> mask_cpu(boxes_num * col_blocks);\n\n  //    printf(\"boxes_num=%d, col_blocks=%d\\n\", boxes_num, col_blocks);\n  CHECK_ERROR(cudaMemcpy(\n 
     &mask_cpu[0],\n      mask_data,\n      boxes_num * col_blocks * sizeof(unsigned long long),\n      cudaMemcpyDeviceToHost));\n\n  cudaFree(mask_data);\n\n  unsigned long long remv_cpu[col_blocks];\n  memset(remv_cpu, 0, col_blocks * sizeof(unsigned long long));\n\n  int num_to_keep = 0;\n\n  for (int i = 0; i < boxes_num; i++) {\n    int nblock = i / THREADS_PER_BLOCK_NMS;\n    int inblock = i % THREADS_PER_BLOCK_NMS;\n\n    if (!(remv_cpu[nblock] & (1ULL << inblock))) {\n      keep_data[num_to_keep++] = i;\n      unsigned long long* p = &mask_cpu[0] + i * col_blocks;\n      for (int j = nblock; j < col_blocks; j++) {\n        remv_cpu[j] |= p[j];\n      }\n    }\n  }\n  if (cudaSuccess != cudaGetLastError())\n    printf(\"Error!\\n\");\n\n  return num_to_keep;\n}\n\nint nms_normal_gpu(\n    at::Tensor boxes,\n    at::Tensor keep,\n    float nms_overlap_thresh,\n    int device_id) {\n  // params boxes: (N, 5) [x1, y1, x2, y2, ry]\n  // params keep: (N)\n\n  CHECK_INPUT(boxes);\n  CHECK_CONTIGUOUS(keep);\n  cudaSetDevice(device_id);\n\n  int boxes_num = boxes.size(0);\n  const float* boxes_data = boxes.data_ptr<float>();\n  long* keep_data = keep.data_ptr<long>();\n\n  const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);\n\n  unsigned long long* mask_data = NULL;\n  CHECK_ERROR(cudaMalloc(\n      (void**)&mask_data, boxes_num * col_blocks * sizeof(unsigned long long)));\n  nmsNormalLauncher(boxes_data, mask_data, boxes_num, nms_overlap_thresh);\n\n  // unsigned long long mask_cpu[boxes_num * col_blocks];\n  // unsigned long long *mask_cpu = new unsigned long long [boxes_num *\n  // col_blocks];\n  std::vector<unsigned long long> mask_cpu(boxes_num * col_blocks);\n\n  //    printf(\"boxes_num=%d, col_blocks=%d\\n\", boxes_num, col_blocks);\n  CHECK_ERROR(cudaMemcpy(\n      &mask_cpu[0],\n      mask_data,\n      boxes_num * col_blocks * sizeof(unsigned long long),\n      cudaMemcpyDeviceToHost));\n\n  cudaFree(mask_data);\n\n  unsigned long long 
remv_cpu[col_blocks];\n  memset(remv_cpu, 0, col_blocks * sizeof(unsigned long long));\n\n  int num_to_keep = 0;\n\n  for (int i = 0; i < boxes_num; i++) {\n    int nblock = i / THREADS_PER_BLOCK_NMS;\n    int inblock = i % THREADS_PER_BLOCK_NMS;\n\n    if (!(remv_cpu[nblock] & (1ULL << inblock))) {\n      keep_data[num_to_keep++] = i;\n      unsigned long long* p = &mask_cpu[0] + i * col_blocks;\n      for (int j = nblock; j < col_blocks; j++) {\n        remv_cpu[j] |= p[j];\n      }\n    }\n  }\n  if (cudaSuccess != cudaGetLastError())\n    printf(\"Error!\\n\");\n\n  return num_to_keep;\n}\n"
  },
  {
    "path": "efm3d/thirdparty/mmdetection3d/cuda/iou3d.h",
    "content": "// @lint-ignore-every LICENSELINT\n\n// Modified from\n// https://github.com/open-mmlab/OpenPCDet/blob/master/pcdet/ops/iou3d_nms/src/iou3d_nms.h\n// License under Apache 2.0\n// https://github.com/open-mmlab/OpenPCDet/blob/master/LICENSE\n\n#pragma once\n#include <torch/extension.h> // @manual=//caffe2:torch-cpp\n\nint boxes_overlap_bev_gpu(\n    at::Tensor boxes_a,\n    at::Tensor boxes_b,\n    at::Tensor ans_overlap);\n\nint boxes_iou_bev_gpu(\n    at::Tensor boxes_a,\n    at::Tensor boxes_b,\n    at::Tensor ans_iou);\n\nint nms_gpu(\n    at::Tensor boxes,\n    at::Tensor keep,\n    float nms_overlap_thresh,\n    int device_id);\n\nint nms_normal_gpu(\n    at::Tensor boxes,\n    at::Tensor keep,\n    float nms_overlap_thresh,\n    int device_id);\n"
  },
  {
    "path": "efm3d/thirdparty/mmdetection3d/cuda/iou3d_kernel.cu",
    "content": "// @lint-ignore-every LICENSELINT\n\n// Modified from\n// https://github.com/open-mmlab/OpenPCDet/blob/master/pcdet/ops/iou3d_nms/src/iou3d_nms_kernel.cu\n// License under Apache 2.0\n// https://github.com/open-mmlab/OpenPCDet/blob/master/LICENSE\n\n#include <stdio.h>\n#define THREADS_PER_BLOCK 16\n#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))\n\n// #define DEBUG\nconst int THREADS_PER_BLOCK_NMS = sizeof(unsigned long long) * 8;\nconst float EPS = 1e-8;\nstruct Point {\n  float x, y;\n  __device__ Point() {}\n  __device__ Point(double _x, double _y) {\n    x = _x, y = _y;\n  }\n\n  __device__ void set(float _x, float _y) {\n    x = _x;\n    y = _y;\n  }\n\n  __device__ Point operator+(const Point& b) const {\n    return Point(x + b.x, y + b.y);\n  }\n\n  __device__ Point operator-(const Point& b) const {\n    return Point(x - b.x, y - b.y);\n  }\n};\n\n__device__ inline float cross(const Point& a, const Point& b) {\n  return a.x * b.y - a.y * b.x;\n}\n\n__device__ inline float\ncross(const Point& p1, const Point& p2, const Point& p0) {\n  return (p1.x - p0.x) * (p2.y - p0.y) - (p2.x - p0.x) * (p1.y - p0.y);\n}\n\n__device__ int check_rect_cross(\n    const Point& p1,\n    const Point& p2,\n    const Point& q1,\n    const Point& q2) {\n  int ret = min(p1.x, p2.x) <= max(q1.x, q2.x) &&\n      min(q1.x, q2.x) <= max(p1.x, p2.x) &&\n      min(p1.y, p2.y) <= max(q1.y, q2.y) && min(q1.y, q2.y) <= max(p1.y, p2.y);\n  return ret;\n}\n\n__device__ inline int check_in_box2d(const float* box, const Point& p) {\n  // params: box (5) [x1, y1, x2, y2, angle]\n  const float MARGIN = 1e-5;\n\n  float center_x = (box[0] + box[2]) / 2;\n  float center_y = (box[1] + box[3]) / 2;\n  float angle_cos = cos(-box[4]),\n        angle_sin =\n            sin(-box[4]); // rotate the point in the opposite direction of box\n  float rot_x =\n      (p.x - center_x) * angle_cos + (p.y - center_y) * angle_sin + center_x;\n  float rot_y =\n      -(p.x - center_x) * angle_sin + 
(p.y - center_y) * angle_cos + center_y;\n#ifdef DEBUG\n  printf(\n      \"box: (%.3f, %.3f, %.3f, %.3f, %.3f)\\n\",\n      box[0],\n      box[1],\n      box[2],\n      box[3],\n      box[4]);\n  printf(\n      \"center: (%.3f, %.3f), cossin(%.3f, %.3f), src(%.3f, %.3f), rot(%.3f, \"\n      \"%.3f)\\n\",\n      center_x,\n      center_y,\n      angle_cos,\n      angle_sin,\n      p.x,\n      p.y,\n      rot_x,\n      rot_y);\n#endif\n  return (\n      rot_x > box[0] - MARGIN && rot_x < box[2] + MARGIN &&\n      rot_y > box[1] - MARGIN && rot_y < box[3] + MARGIN);\n}\n\n__device__ inline int intersection(\n    const Point& p1,\n    const Point& p0,\n    const Point& q1,\n    const Point& q0,\n    Point& ans) {\n  // fast exclusion\n  if (check_rect_cross(p0, p1, q0, q1) == 0)\n    return 0;\n\n  // check cross standing\n  float s1 = cross(q0, p1, p0);\n  float s2 = cross(p1, q1, p0);\n  float s3 = cross(p0, q1, q0);\n  float s4 = cross(q1, p1, q0);\n\n  if (!(s1 * s2 > 0 && s3 * s4 > 0))\n    return 0;\n\n  // calculate intersection of two lines\n  float s5 = cross(q1, p1, p0);\n  if (fabs(s5 - s1) > EPS) {\n    ans.x = (s5 * q0.x - s1 * q1.x) / (s5 - s1);\n    ans.y = (s5 * q0.y - s1 * q1.y) / (s5 - s1);\n\n  } else {\n    float a0 = p0.y - p1.y, b0 = p1.x - p0.x, c0 = p0.x * p1.y - p1.x * p0.y;\n    float a1 = q0.y - q1.y, b1 = q1.x - q0.x, c1 = q0.x * q1.y - q1.x * q0.y;\n    float D = a0 * b1 - a1 * b0;\n\n    ans.x = (b0 * c1 - b1 * c0) / D;\n    ans.y = (a1 * c0 - a0 * c1) / D;\n  }\n\n  return 1;\n}\n\n__device__ inline void rotate_around_center(\n    const Point& center,\n    const float angle_cos,\n    const float angle_sin,\n    Point& p) {\n  float new_x =\n      (p.x - center.x) * angle_cos + (p.y - center.y) * angle_sin + center.x;\n  float new_y =\n      -(p.x - center.x) * angle_sin + (p.y - center.y) * angle_cos + center.y;\n  p.set(new_x, new_y);\n}\n\n__device__ inline int\npoint_cmp(const Point& a, const Point& b, const Point& center) {\n  return 
atan2(a.y - center.y, a.x - center.x) >\n      atan2(b.y - center.y, b.x - center.x);\n}\n\n__device__ inline float box_overlap(const float* box_a, const float* box_b) {\n  // params: box_a (5) [x1, y1, x2, y2, angle]\n  // params: box_b (5) [x1, y1, x2, y2, angle]\n\n  float a_x1 = box_a[0], a_y1 = box_a[1], a_x2 = box_a[2], a_y2 = box_a[3],\n        a_angle = box_a[4];\n  float b_x1 = box_b[0], b_y1 = box_b[1], b_x2 = box_b[2], b_y2 = box_b[3],\n        b_angle = box_b[4];\n\n  Point center_a((a_x1 + a_x2) / 2, (a_y1 + a_y2) / 2);\n  Point center_b((b_x1 + b_x2) / 2, (b_y1 + b_y2) / 2);\n#ifdef DEBUG\n  printf(\n      \"a: (%.3f, %.3f, %.3f, %.3f, %.3f), b: (%.3f, %.3f, %.3f, %.3f, %.3f)\\n\",\n      a_x1,\n      a_y1,\n      a_x2,\n      a_y2,\n      a_angle,\n      b_x1,\n      b_y1,\n      b_x2,\n      b_y2,\n      b_angle);\n  printf(\n      \"center a: (%.3f, %.3f), b: (%.3f, %.3f)\\n\",\n      center_a.x,\n      center_a.y,\n      center_b.x,\n      center_b.y);\n#endif\n\n  Point box_a_corners[5];\n  box_a_corners[0].set(a_x1, a_y1);\n  box_a_corners[1].set(a_x2, a_y1);\n  box_a_corners[2].set(a_x2, a_y2);\n  box_a_corners[3].set(a_x1, a_y2);\n\n  Point box_b_corners[5];\n  box_b_corners[0].set(b_x1, b_y1);\n  box_b_corners[1].set(b_x2, b_y1);\n  box_b_corners[2].set(b_x2, b_y2);\n  box_b_corners[3].set(b_x1, b_y2);\n\n  // get oriented corners\n  float a_angle_cos = cos(a_angle), a_angle_sin = sin(a_angle);\n  float b_angle_cos = cos(b_angle), b_angle_sin = sin(b_angle);\n\n  for (int k = 0; k < 4; k++) {\n#ifdef DEBUG\n    printf(\n        \"before corner %d: a(%.3f, %.3f), b(%.3f, %.3f) \\n\",\n        k,\n        box_a_corners[k].x,\n        box_a_corners[k].y,\n        box_b_corners[k].x,\n        box_b_corners[k].y);\n#endif\n    rotate_around_center(center_a, a_angle_cos, a_angle_sin, box_a_corners[k]);\n    rotate_around_center(center_b, b_angle_cos, b_angle_sin, box_b_corners[k]);\n#ifdef DEBUG\n    printf(\n        \"corner %d: a(%.3f, %.3f), 
b(%.3f, %.3f) \\n\",\n        k,\n        box_a_corners[k].x,\n        box_a_corners[k].y,\n        box_b_corners[k].x,\n        box_b_corners[k].y);\n#endif\n  }\n\n  box_a_corners[4] = box_a_corners[0];\n  box_b_corners[4] = box_b_corners[0];\n\n  // get intersection of lines\n  Point cross_points[16];\n  Point poly_center;\n  int cnt = 0, flag = 0;\n\n  poly_center.set(0, 0);\n  for (int i = 0; i < 4; i++) {\n    for (int j = 0; j < 4; j++) {\n      flag = intersection(\n          box_a_corners[i + 1],\n          box_a_corners[i],\n          box_b_corners[j + 1],\n          box_b_corners[j],\n          cross_points[cnt]);\n      if (flag) {\n        poly_center = poly_center + cross_points[cnt];\n        cnt++;\n      }\n    }\n  }\n\n  // check corners\n  for (int k = 0; k < 4; k++) {\n    if (check_in_box2d(box_a, box_b_corners[k])) {\n      poly_center = poly_center + box_b_corners[k];\n      cross_points[cnt] = box_b_corners[k];\n      cnt++;\n    }\n    if (check_in_box2d(box_b, box_a_corners[k])) {\n      poly_center = poly_center + box_a_corners[k];\n      cross_points[cnt] = box_a_corners[k];\n      cnt++;\n    }\n  }\n\n  poly_center.x /= cnt;\n  poly_center.y /= cnt;\n\n  // sort the points of polygon\n  Point temp;\n  for (int j = 0; j < cnt - 1; j++) {\n    for (int i = 0; i < cnt - j - 1; i++) {\n      if (point_cmp(cross_points[i], cross_points[i + 1], poly_center)) {\n        temp = cross_points[i];\n        cross_points[i] = cross_points[i + 1];\n        cross_points[i + 1] = temp;\n      }\n    }\n  }\n\n#ifdef DEBUG\n  printf(\"cnt=%d\\n\", cnt);\n  for (int i = 0; i < cnt; i++) {\n    printf(\n        \"All cross point %d: (%.3f, %.3f)\\n\",\n        i,\n        cross_points[i].x,\n        cross_points[i].y);\n  }\n#endif\n\n  // get the overlap areas\n  float area = 0;\n  for (int k = 0; k < cnt - 1; k++) {\n    area += cross(\n        cross_points[k] - cross_points[0],\n        cross_points[k + 1] - cross_points[0]);\n  }\n\n  return 
fabs(area) / 2.0;\n}\n\n__device__ inline float iou_bev(const float* box_a, const float* box_b) {\n  // params: box_a (5) [x1, y1, x2, y2, angle]\n  // params: box_b (5) [x1, y1, x2, y2, angle]\n  float sa = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1]);\n  float sb = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1]);\n  float s_overlap = box_overlap(box_a, box_b);\n  return s_overlap / fmaxf(sa + sb - s_overlap, EPS);\n}\n\n__global__ void boxes_overlap_kernel(\n    const int num_a,\n    const float* boxes_a,\n    const int num_b,\n    const float* boxes_b,\n    float* ans_overlap) {\n  const int a_idx = blockIdx.y * THREADS_PER_BLOCK + threadIdx.y;\n  const int b_idx = blockIdx.x * THREADS_PER_BLOCK + threadIdx.x;\n\n  if (a_idx >= num_a || b_idx >= num_b) {\n    return;\n  }\n  const float* cur_box_a = boxes_a + a_idx * 5;\n  const float* cur_box_b = boxes_b + b_idx * 5;\n  float s_overlap = box_overlap(cur_box_a, cur_box_b);\n  ans_overlap[a_idx * num_b + b_idx] = s_overlap;\n}\n\n__global__ void boxes_iou_bev_kernel(\n    const int num_a,\n    const float* boxes_a,\n    const int num_b,\n    const float* boxes_b,\n    float* ans_iou) {\n  const int a_idx = blockIdx.y * THREADS_PER_BLOCK + threadIdx.y;\n  const int b_idx = blockIdx.x * THREADS_PER_BLOCK + threadIdx.x;\n\n  if (a_idx >= num_a || b_idx >= num_b) {\n    return;\n  }\n\n  const float* cur_box_a = boxes_a + a_idx * 5;\n  const float* cur_box_b = boxes_b + b_idx * 5;\n  float cur_iou_bev = iou_bev(cur_box_a, cur_box_b);\n  ans_iou[a_idx * num_b + b_idx] = cur_iou_bev;\n}\n\n__global__ void nms_kernel(\n    const int boxes_num,\n    const float nms_overlap_thresh,\n    const float* boxes,\n    unsigned long long* mask) {\n  // params: boxes (N, 5) [x1, y1, x2, y2, ry]\n  // params: mask (N, N/THREADS_PER_BLOCK_NMS)\n\n  const int row_start = blockIdx.y;\n  const int col_start = blockIdx.x;\n\n  // if (row_start > col_start) return;\n\n  const int row_size = fminf(\n      boxes_num - row_start * 
THREADS_PER_BLOCK_NMS, THREADS_PER_BLOCK_NMS);\n  const int col_size = fminf(\n      boxes_num - col_start * THREADS_PER_BLOCK_NMS, THREADS_PER_BLOCK_NMS);\n\n  __shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 5];\n\n  if (threadIdx.x < col_size) {\n    block_boxes[threadIdx.x * 5 + 0] =\n        boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 0];\n    block_boxes[threadIdx.x * 5 + 1] =\n        boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 1];\n    block_boxes[threadIdx.x * 5 + 2] =\n        boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 2];\n    block_boxes[threadIdx.x * 5 + 3] =\n        boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 3];\n    block_boxes[threadIdx.x * 5 + 4] =\n        boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 4];\n  }\n  __syncthreads();\n\n  if (threadIdx.x < row_size) {\n    const int cur_box_idx = THREADS_PER_BLOCK_NMS * row_start + threadIdx.x;\n    const float* cur_box = boxes + cur_box_idx * 5;\n\n    int i = 0;\n    unsigned long long t = 0;\n    int start = 0;\n    if (row_start == col_start) {\n      start = threadIdx.x + 1;\n    }\n    for (i = start; i < col_size; i++) {\n      if (iou_bev(cur_box, block_boxes + i * 5) > nms_overlap_thresh) {\n        t |= 1ULL << i;\n      }\n    }\n    const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);\n    mask[cur_box_idx * col_blocks + col_start] = t;\n  }\n}\n\n__device__ inline float iou_normal(float const* const a, float const* const b) {\n  float left = fmaxf(a[0], b[0]), right = fminf(a[2], b[2]);\n  float top = fmaxf(a[1], b[1]), bottom = fminf(a[3], b[3]);\n  float width = fmaxf(right - left, 0.f), height = fmaxf(bottom - top, 0.f);\n  float interS = width * height;\n  float Sa = (a[2] - a[0]) * (a[3] - a[1]);\n  float Sb = (b[2] - b[0]) * (b[3] - b[1]);\n  return interS / fmaxf(Sa + Sb - interS, EPS);\n}\n\n__global__ void nms_normal_kernel(\n    const int boxes_num,\n    const float 
nms_overlap_thresh,\n    const float* boxes,\n    unsigned long long* mask) {\n  // params: boxes (N, 5) [x1, y1, x2, y2, ry]\n  // params: mask (N, N/THREADS_PER_BLOCK_NMS)\n\n  const int row_start = blockIdx.y;\n  const int col_start = blockIdx.x;\n\n  // if (row_start > col_start) return;\n\n  const int row_size = fminf(\n      boxes_num - row_start * THREADS_PER_BLOCK_NMS, THREADS_PER_BLOCK_NMS);\n  const int col_size = fminf(\n      boxes_num - col_start * THREADS_PER_BLOCK_NMS, THREADS_PER_BLOCK_NMS);\n\n  __shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 5];\n\n  if (threadIdx.x < col_size) {\n    block_boxes[threadIdx.x * 5 + 0] =\n        boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 0];\n    block_boxes[threadIdx.x * 5 + 1] =\n        boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 1];\n    block_boxes[threadIdx.x * 5 + 2] =\n        boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 2];\n    block_boxes[threadIdx.x * 5 + 3] =\n        boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 3];\n    block_boxes[threadIdx.x * 5 + 4] =\n        boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 4];\n  }\n  __syncthreads();\n\n  if (threadIdx.x < row_size) {\n    const int cur_box_idx = THREADS_PER_BLOCK_NMS * row_start + threadIdx.x;\n    const float* cur_box = boxes + cur_box_idx * 5;\n\n    int i = 0;\n    unsigned long long t = 0;\n    int start = 0;\n    if (row_start == col_start) {\n      start = threadIdx.x + 1;\n    }\n    for (i = start; i < col_size; i++) {\n      if (iou_normal(cur_box, block_boxes + i * 5) > nms_overlap_thresh) {\n        t |= 1ULL << i;\n      }\n    }\n    const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);\n    mask[cur_box_idx * col_blocks + col_start] = t;\n  }\n}\n\nvoid boxesoverlapLauncher(\n    const int num_a,\n    const float* boxes_a,\n    const int num_b,\n    const float* boxes_b,\n    float* ans_overlap) {\n  dim3 blocks(\n      DIVUP(num_b, 
THREADS_PER_BLOCK),\n      DIVUP(num_a, THREADS_PER_BLOCK)); // blockIdx.x(col), blockIdx.y(row)\n  dim3 threads(THREADS_PER_BLOCK, THREADS_PER_BLOCK);\n\n  boxes_overlap_kernel<<<blocks, threads>>>(\n      num_a, boxes_a, num_b, boxes_b, ans_overlap);\n#ifdef DEBUG\n  cudaDeviceSynchronize(); // for using printf in kernel function\n#endif\n}\n\nvoid boxesioubevLauncher(\n    const int num_a,\n    const float* boxes_a,\n    const int num_b,\n    const float* boxes_b,\n    float* ans_iou) {\n  dim3 blocks(\n      DIVUP(num_b, THREADS_PER_BLOCK),\n      DIVUP(num_a, THREADS_PER_BLOCK)); // blockIdx.x(col), blockIdx.y(row)\n  dim3 threads(THREADS_PER_BLOCK, THREADS_PER_BLOCK);\n\n  boxes_iou_bev_kernel<<<blocks, threads>>>(\n      num_a, boxes_a, num_b, boxes_b, ans_iou);\n}\n\nvoid nmsLauncher(\n    const float* boxes,\n    unsigned long long* mask,\n    int boxes_num,\n    float nms_overlap_thresh) {\n  dim3 blocks(\n      DIVUP(boxes_num, THREADS_PER_BLOCK_NMS),\n      DIVUP(boxes_num, THREADS_PER_BLOCK_NMS));\n  dim3 threads(THREADS_PER_BLOCK_NMS);\n  nms_kernel<<<blocks, threads>>>(boxes_num, nms_overlap_thresh, boxes, mask);\n}\n\nvoid nmsNormalLauncher(\n    const float* boxes,\n    unsigned long long* mask,\n    int boxes_num,\n    float nms_overlap_thresh) {\n  dim3 blocks(\n      DIVUP(boxes_num, THREADS_PER_BLOCK_NMS),\n      DIVUP(boxes_num, THREADS_PER_BLOCK_NMS));\n  dim3 threads(THREADS_PER_BLOCK_NMS);\n  nms_normal_kernel<<<blocks, threads>>>(\n      boxes_num, nms_overlap_thresh, boxes, mask);\n}\n"
  },
  {
    "path": "efm3d/thirdparty/mmdetection3d/cuda/setup.py",
    "content": "# @lint-ignore-every LICENSELINT\n\nfrom setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\nsetup(\n    name=\"mmdet_iou3d\",\n    ext_modules=[\n        CUDAExtension(\n            \"mmdet_iou3d\",\n            [\n                \"iou3d_kernel.cu\",\n                \"iou3d.cpp\",\n                \"sort_vert_kernel.cu\",\n                \"sort_vert.cpp\",\n            ],\n        )\n    ],\n    headers=[\n        \"iou3d.h\",\n        \"sort_vert.h\",\n        \"cuda_utils.h\",\n        \"utils.h\",\n    ],\n    cmdclass={\"build_ext\": BuildExtension},\n)\n"
  },
  {
    "path": "efm3d/thirdparty/mmdetection3d/cuda/sort_vert.cpp",
    "content": "// @lint-ignore-every LICENSELINT\n\n// Modified from\n// https://github.com/lilanxiao/Rotated_IoU/blob/master/cuda_op/sort_vert.cpp\n// License https://github.com/lilanxiao/Rotated_IoU/blob/master/LICENSE\n\n#include \"sort_vert.h\"\n#include \"iou3d.h\"\n#include \"utils.h\"\n\nvoid sort_vertices_wrapper(\n    int b,\n    int n,\n    int m,\n    const float* vertices,\n    const bool* mask,\n    const int* num_valid,\n    int* idx);\n\nat::Tensor\nsort_vertices(at::Tensor vertices, at::Tensor mask, at::Tensor num_valid) {\n  CHECK_CONTIGUOUS(vertices);\n  CHECK_CONTIGUOUS(mask);\n  CHECK_CONTIGUOUS(num_valid);\n  CHECK_CUDA(vertices);\n  CHECK_CUDA(mask);\n  CHECK_CUDA(num_valid);\n  CHECK_IS_FLOAT(vertices);\n  CHECK_IS_BOOL(mask);\n  CHECK_IS_INT(num_valid);\n\n  int b = vertices.size(0);\n  int n = vertices.size(1);\n  int m = vertices.size(2);\n  at::Tensor idx = torch::zeros(\n      {b, n, MAX_NUM_VERT_IDX},\n      at::device(vertices.device()).dtype(at::ScalarType::Int));\n\n  // fix issue with multi-gpu (kernel only works for cuda:0)\n  const at::cuda::OptionalCUDAGuard device_guard(device_of(idx));\n\n  sort_vertices_wrapper(\n      b,\n      n,\n      m,\n      vertices.data_ptr<float>(),\n      mask.data_ptr<bool>(),\n      num_valid.data_ptr<int>(),\n      idx.data_ptr<int>());\n\n  return idx;\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\n      \"sort_vertices_forward\",\n      &sort_vertices,\n      \"sort vertices of a convex polygon. forward only\");\n  m.def(\n      \"boxes_overlap_bev_gpu\",\n      &boxes_overlap_bev_gpu,\n      \"oriented boxes overlap\");\n  m.def(\"boxes_iou_bev_gpu\", &boxes_iou_bev_gpu, \"oriented boxes iou\");\n  m.def(\"nms_gpu\", &nms_gpu, \"oriented nms gpu\");\n  m.def(\"nms_normal_gpu\", &nms_normal_gpu, \"nms gpu\");\n}\n"
  },
  {
    "path": "efm3d/thirdparty/mmdetection3d/cuda/sort_vert.h",
    "content": "// @lint-ignore-every LICENSELINT\n\n// Modified from\n// https://github.com/lilanxiao/Rotated_IoU/blob/master/cuda_op/sort_vert.h\n// License https://github.com/lilanxiao/Rotated_IoU/blob/master/LICENSE\n\n#pragma once\n#include <torch/extension.h> // @manual=//caffe2:torch-cpp\n\n#define MAX_NUM_VERT_IDX 9\n\nat::Tensor\nsort_vertices(at::Tensor vertices, at::Tensor mask, at::Tensor num_valid);\n"
  },
  {
    "path": "efm3d/thirdparty/mmdetection3d/cuda/sort_vert_kernel.cu",
    "content": "// @lint-ignore-every LICENSELINT\n\n// Modified from\n// https://github.com/lilanxiao/Rotated_IoU/blob/master/cuda_op/sort_vert_kernel.cu\n// License https://github.com/lilanxiao/Rotated_IoU/blob/master/LICENSE\n\n#include <math.h>\n#include <stdio.h>\n#include <stdlib.h>\n#include \"cuda_utils.h\"\n\n#define MAX_NUM_VERT_IDX 9\n#define INTERSECTION_OFFSET 8\n#define EPSILON 1e-8\n\n/*\ncompare normalized vertices (vertices around (0,0))\nif vertex1 < vertex2 return ture.\norder: minimum at x-aixs, become larger in anti-clockwise direction\n*/\n__device__ bool compare_vertices(float x1, float y1, float x2, float y2) {\n  if (fabs(x1 - x2) < EPSILON && fabs(y2 - y1) < EPSILON)\n    return false; // if equal, return false\n\n  if (y1 > 0 && y2 < 0)\n    return true;\n  if (y1 < 0 && y2 > 0)\n    return false;\n\n  float n1 = x1 * x1 + y1 * y1 + EPSILON;\n  float n2 = x2 * x2 + y2 * y2 + EPSILON;\n\n  if (y1 > 0 && y2 > 0) {\n    if (fabs(x1) * x1 / n1 - fabs(x2) * x2 / n2 > EPSILON)\n      return true;\n    else\n      return false;\n  }\n  if (y1 < 0 && y2 < 0) {\n    if (fabs(x1) * x1 / n1 - fabs(x2) * x2 / n2 < EPSILON)\n      return true;\n    else\n      return false;\n  }\n}\n\n__global__ void sort_vertices_kernel(\n    int b,\n    int n,\n    int m,\n    const float* __restrict__ vertices,\n    const bool* __restrict__ mask,\n    const int* __restrict__ num_valid,\n    int* __restrict__ idx) {\n  int batch_idx = blockIdx.x;\n  vertices += batch_idx * n * m * 2;\n  mask += batch_idx * n * m;\n  num_valid += batch_idx * n;\n  idx += batch_idx * n * MAX_NUM_VERT_IDX;\n\n  int index = threadIdx.x; // index of polygon\n  int stride = blockDim.x;\n  for (int i = index; i < n; i += stride) {\n    int pad; // index of arbitrary invalid intersection point (not box corner!)\n    for (int j = INTERSECTION_OFFSET; j < m; ++j) {\n      if (!mask[i * m + j]) {\n        pad = j;\n        break;\n      }\n    }\n    if (num_valid[i] < 3) {\n      // not 
enough vertices, take an invalid intersection point\n      // (zero padding)\n      for (int j = 0; j < MAX_NUM_VERT_IDX; ++j) {\n        idx[i * MAX_NUM_VERT_IDX + j] = pad;\n      }\n    } else {\n      // sort the valid vertices\n      // note the number of valid vertices is known\n      for (int j = 0; j < num_valid[i]; ++j) {\n        // initialize with a \"big\" value\n        float x_min = 1;\n        float y_min = -EPSILON;\n        int i_take = 0;\n        for (int k = 0; k < m; ++k) {\n          float x = vertices[i * m * 2 + k * 2 + 0];\n          float y = vertices[i * m * 2 + k * 2 + 1];\n          if (j == 0) {\n            if (mask[i * m + k] && compare_vertices(x, y, x_min, y_min)) {\n              x_min = x;\n              y_min = y;\n              i_take = k;\n            }\n          } else {\n            int i2 = idx[i * MAX_NUM_VERT_IDX + j - 1];\n            float x2 = vertices[i * m * 2 + i2 * 2 + 0];\n            float y2 = vertices[i * m * 2 + i2 * 2 + 1];\n            if (mask[i * m + k] && compare_vertices(x, y, x_min, y_min) &&\n                compare_vertices(x2, y2, x, y)) {\n              x_min = x;\n              y_min = y;\n              i_take = k;\n            }\n          }\n          idx[i * MAX_NUM_VERT_IDX + j] = i_take;\n        }\n      }\n      // duplicate the first idx\n      idx[i * MAX_NUM_VERT_IDX + num_valid[i]] = idx[i * MAX_NUM_VERT_IDX + 0];\n\n      // pad zeros\n      for (int j = num_valid[i] + 1; j < MAX_NUM_VERT_IDX; ++j) {\n        idx[i * MAX_NUM_VERT_IDX + j] = pad;\n      }\n\n      // for corner case: the two boxes are exactly the same.\n      // in this case, idx would have duplicate elements, which breaks the\n      // shoelace formula. by definition, the duplicate\n      // elements only appear in the first 8 positions (they are \"corners in\n      // box\", not \"intersection of edges\")\n      if (num_valid[i] == 8) {\n        int counter = 0;\n        for (int j = 0; j < 4; ++j) {\n 
         int check = idx[i * MAX_NUM_VERT_IDX + j];\n          for (int k = 4; k < INTERSECTION_OFFSET; ++k) {\n            if (idx[i * MAX_NUM_VERT_IDX + k] == check)\n              counter++;\n          }\n        }\n        if (counter == 4) {\n          idx[i * MAX_NUM_VERT_IDX + 4] = idx[i * MAX_NUM_VERT_IDX + 0];\n          for (int j = 5; j < MAX_NUM_VERT_IDX; ++j) {\n            idx[i * MAX_NUM_VERT_IDX + j] = pad;\n          }\n        }\n      }\n    }\n  }\n}\n\nvoid sort_vertices_wrapper(\n    int b,\n    int n,\n    int m,\n    const float* vertices,\n    const bool* mask,\n    const int* num_valid,\n    int* idx) {\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n  sort_vertices_kernel<<<b, opt_n_thread(n)>>>(\n      b, n, m, vertices, mask, num_valid, idx);\n  CUDA_CHECK_ERRORS();\n}\n"
  },
  {
    "path": "efm3d/thirdparty/mmdetection3d/cuda/utils.h",
    "content": "// @lint-ignore-every LICENSELINT\n\n// Modified from\n// https://github.com/lilanxiao/Rotated_IoU/blob/master/cuda_op/utils.h\n// License https://github.com/lilanxiao/Rotated_IoU/blob/master/LICENSE\n\n#pragma once\n#include <ATen/cuda/CUDAContext.h>\n#include <c10/cuda/CUDAGuard.h>\n#include <torch/extension.h> // @manual=//caffe2:torch-cpp\n\n#define CHECK_CUDA(x)                                      \\\n  do {                                                     \\\n    TORCH_CHECK(x.is_cuda(), #x \" must be a CUDA tensor\"); \\\n  } while (0)\n\n#define CHECK_CONTIGUOUS(x)                                            \\\n  do {                                                                 \\\n    TORCH_CHECK(x.is_contiguous(), #x \" must be a contiguous tensor\"); \\\n  } while (0)\n\n#define CHECK_IS_INT(x)                                                       \\\n  do {                                                                        \\\n    TORCH_CHECK(                                                              \\\n        x.scalar_type() == at::ScalarType::Int, #x \" must be an int tensor\"); \\\n  } while (0)\n\n#define CHECK_IS_FLOAT(x)                         \\\n  do {                                            \\\n    TORCH_CHECK(                                  \\\n        x.scalar_type() == at::ScalarType::Float, \\\n        #x \" must be a float tensor\");            \\\n  } while (0)\n\n#define CHECK_IS_BOOL(x)                                                       \\\n  do {                                                                         \\\n    TORCH_CHECK(                                                               \\\n        x.scalar_type() == at::ScalarType::Bool, #x \" must be a bool tensor\"); \\\n  } while (0)\n"
  },
  {
    "path": "efm3d/thirdparty/mmdetection3d/iou3d.py",
    "content": "# Copyright (c) OpenMMLab. All rights reserved.\n# Adapted from https://github.com/lilanxiao/Rotated_IoU/blob/master/box_intersection_2d.py  # noqa\n# Adapted from https://github.com/lilanxiao/Rotated_IoU/blob/master/oriented_iou_loss.py  # noqa\n# License https://github.com/lilanxiao/Rotated_IoU/blob/master/LICENSE\n\nfrom typing import Tuple\n\nimport mmdet_iou3d\n\nimport torch\nfrom torch import Tensor\nfrom torch.autograd import Function\n\nEPSILON = 1e-8\n\n\nclass SortVertices(Function):\n    @staticmethod\n    def forward(ctx, vertices, mask, num_valid):\n        idx = mmdet_iou3d.sort_vertices_forward(vertices, mask, num_valid)\n        ctx.mark_non_differentiable(idx)\n        return idx\n\n    @staticmethod\n    def backward(ctx, gradout):\n        return ()\n\n\ndef box_intersection(corners1: Tensor, corners2: Tensor) -> Tuple[Tensor, Tensor]:\n    \"\"\"Find intersection points of rectangles.\n    Convention: if two edges are collinear, there is no intersection point.\n\n    Args:\n        corners1 (Tensor): (B, N, 4, 2) First batch of boxes.\n        corners2 (Tensor): (B, N, 4, 2) Second batch of boxes.\n\n    Returns:\n        Tuple:\n         - Tensor: (B, N, 4, 4, 2) Intersections.\n         - Tensor: (B, N, 4, 4) Valid intersections mask.\n    \"\"\"\n    # build edges from corners\n    # B, N, 4, 4: Batch, Box, edge, point\n    line1 = torch.cat([corners1, corners1[:, :, [1, 2, 3, 0], :]], dim=3)\n    line2 = torch.cat([corners2, corners2[:, :, [1, 2, 3, 0], :]], dim=3)\n    # duplicate data to pair each edges from the boxes\n    # (B, N, 4, 4) -> (B, N, 4, 4, 4) : Batch, Box, edge1, edge2, point\n    line1_ext = line1.unsqueeze(3)\n    line2_ext = line2.unsqueeze(2)\n    x1, y1, x2, y2 = line1_ext.split([1, 1, 1, 1], dim=-1)\n    x3, y3, x4, y4 = line2_ext.split([1, 1, 1, 1], dim=-1)\n    # math: https://en.wikipedia.org/wiki/Line%E2%80%93line_intersection\n    numerator = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4)\n    
denumerator_t = (x1 - x3) * (y3 - y4) - (y1 - y3) * (x3 - x4)\n    t = denumerator_t / numerator\n    t[numerator == 0.0] = -1.0\n    mask_t = (t > 0) & (t < 1)  # intersection on line segment 1\n    denumerator_u = (x1 - x2) * (y1 - y3) - (y1 - y2) * (x1 - x3)\n    u = -denumerator_u / numerator\n    u[numerator == 0.0] = -1.0\n    mask_u = (u > 0) & (u < 1)  # intersection on line segment 2\n    mask = mask_t * mask_u\n    # overwrite with EPSILON. otherwise numerically unstable\n    t = denumerator_t / (numerator + EPSILON)\n    intersections = torch.stack([x1 + t * (x2 - x1), y1 + t * (y2 - y1)], dim=-1)\n    intersections = intersections * mask.float().unsqueeze(-1)\n    return intersections, mask\n\n\ndef box1_in_box2(corners1: Tensor, corners2: Tensor) -> Tensor:\n    \"\"\"Check if corners of box1 lie in box2.\n    Convention: if a corner is exactly on the edge of the other box,\n    it's also a valid point.\n\n    Args:\n        corners1 (Tensor): (B, N, 4, 2) First batch of boxes.\n        corners2 (Tensor): (B, N, 4, 2) Second batch of boxes.\n\n    Returns:\n        Tensor: (B, N, 4) Intersection.\n    \"\"\"\n    # a, b, c, d - 4 vertices of box2\n    a = corners2[:, :, 0:1, :]  # (B, N, 1, 2)\n    b = corners2[:, :, 1:2, :]  # (B, N, 1, 2)\n    d = corners2[:, :, 3:4, :]  # (B, N, 1, 2)\n    # ab, am, ad - vectors between corresponding vertices\n    ab = b - a  # (B, N, 1, 2)\n    am = corners1 - a  # (B, N, 4, 2)\n    ad = d - a  # (B, N, 1, 2)\n    prod_ab = torch.sum(ab * am, dim=-1)  # (B, N, 4)\n    norm_ab = torch.sum(ab * ab, dim=-1)  # (B, N, 1)\n    prod_ad = torch.sum(ad * am, dim=-1)  # (B, N, 4)\n    norm_ad = torch.sum(ad * ad, dim=-1)  # (B, N, 1)\n    # NOTE: the expression looks ugly but is stable if the two boxes\n    # are exactly the same also stable with different scale of bboxes\n    cond1 = (prod_ab / norm_ab > -1e-6) * (prod_ab / norm_ab < 1 + 1e-6)  # (B, N, 4)\n    cond2 = (prod_ad / norm_ad > -1e-6) * (prod_ad / norm_ad < 1 + 
1e-6)  # (B, N, 4)\n    return cond1 * cond2\n\n\ndef box_in_box(corners1: Tensor, corners2: Tensor) -> Tuple[Tensor, Tensor]:\n    \"\"\"Check if corners of two boxes lie in each other.\n\n    Args:\n        corners1 (Tensor): (B, N, 4, 2) First batch of boxes.\n        corners2 (Tensor): (B, N, 4, 2) Second batch of boxes.\n\n    Returns:\n        Tuple:\n         - Tensor: (B, N, 4) True if i-th corner of box1 is in box2.\n         - Tensor: (B, N, 4) True if i-th corner of box2 is in box1.\n    \"\"\"\n    c1_in_2 = box1_in_box2(corners1, corners2)\n    c2_in_1 = box1_in_box2(corners2, corners1)\n    return c1_in_2, c2_in_1\n\n\ndef build_vertices(\n    corners1: Tensor,\n    corners2: Tensor,\n    c1_in_2: Tensor,\n    c2_in_1: Tensor,\n    intersections: Tensor,\n    valid_mask: Tensor,\n) -> Tuple[Tensor, Tensor]:\n    \"\"\"Find vertices of intersection area.\n\n    Args:\n        corners1 (Tensor): (B, N, 4, 2) First batch of boxes.\n        corners2 (Tensor): (B, N, 4, 2) Second batch of boxes.\n        c1_in_2 (Tensor): (B, N, 4) True if i-th corner of box1 is in box2.\n        c2_in_1 (Tensor): (B, N, 4) True if i-th corner of box2 is in box1.\n        intersections (Tensor): (B, N, 4, 4, 2) Intersections.\n        valid_mask (Tensor): (B, N, 4, 4) Valid intersections mask.\n\n    Returns:\n        Tuple:\n         - Tensor: (B, N, 24, 2) Vertices of intersection area;\n               only some elements are valid.\n         - Tensor: (B, N, 24) Mask of valid elements in vertices.\n    \"\"\"\n    # NOTE: inter has elements equals zero and has zeros gradient\n    # (masked by multiplying with 0); can be used as trick\n    B = corners1.size()[0]\n    N = corners1.size()[1]\n    # (B, N, 4 + 4 + 16, 2)\n    vertices = torch.cat([corners1, corners2, intersections.view([B, N, -1, 2])], dim=2)\n    # Bool (B, N, 4 + 4 + 16)\n    mask = torch.cat([c1_in_2, c2_in_1, valid_mask.view([B, N, -1])], dim=2)\n    return vertices, mask\n\n\ndef sort_indices(vertices: 
Tensor, mask: Tensor) -> Tensor:\n    \"\"\"Sort indices.\n    Note:\n        why 9? the polygon has maximal 8 vertices.\n        +1 to duplicate the first element.\n        the index should have following structure:\n            (A, B, C, ... , A, X, X, X)\n        and X indicates the index of arbitrary elements in the last\n        16 (intersections not corners) with value 0 and mask False.\n        (cause they have zero value and zero gradient)\n\n    Args:\n        vertices (Tensor): (B, N, 24, 2) Box vertices.\n        mask (Tensor): (B, N, 24) Mask.\n\n    Returns:\n        Tensor: (B, N, 9) Sorted indices.\n\n    \"\"\"\n    num_valid = torch.sum(mask.int(), dim=2).int()  # (B, N)\n    mean = torch.sum(\n        vertices * mask.float().unsqueeze(-1), dim=2, keepdim=True\n    ) / num_valid.unsqueeze(-1).unsqueeze(-1)\n    vertices_normalized = vertices - mean  # normalization makes sorting easier\n    return SortVertices.apply(vertices_normalized, mask, num_valid).long()\n\n\ndef calculate_area(idx_sorted: Tensor, vertices: Tensor) -> Tuple[Tensor, Tensor]:\n    \"\"\"Calculate area of intersection.\n\n    Args:\n        idx_sorted (Tensor): (B, N, 9) Sorted vertex ids.\n        vertices (Tensor): (B, N, 24, 2) Vertices.\n\n    Returns:\n        Tuple:\n         - Tensor (B, N): Area of intersection.\n         - Tensor: (B, N, 9, 2) Vertices of polygon with zero padding.\n    \"\"\"\n    idx_ext = idx_sorted.unsqueeze(-1).repeat([1, 1, 1, 2])\n    selected = torch.gather(vertices, 2, idx_ext)\n    total = (\n        selected[:, :, 0:-1, 0] * selected[:, :, 1:, 1]\n        - selected[:, :, 0:-1, 1] * selected[:, :, 1:, 0]\n    )\n    total = torch.sum(total, dim=2)\n    area = torch.abs(total) / 2\n    return area, selected\n\n\ndef oriented_box_intersection_2d(\n    corners1: Tensor, corners2: Tensor\n) -> Tuple[Tensor, Tensor]:\n    \"\"\"Calculate intersection area of 2d rotated boxes.\n\n    Args:\n        corners1 (Tensor): (B, N, 4, 2) First batch of 
boxes.\n        corners2 (Tensor): (B, N, 4, 2) Second batch of boxes.\n\n    Returns:\n        Tuple:\n         - Tensor (B, N): Area of intersection.\n         - Tensor (B, N, 9, 2): Vertices of polygon with zero padding.\n    \"\"\"\n    intersections, valid_mask = box_intersection(corners1, corners2)\n    c12, c21 = box_in_box(corners1, corners2)\n    vertices, mask = build_vertices(\n        corners1, corners2, c12, c21, intersections, valid_mask\n    )\n    sorted_indices = sort_indices(vertices, mask)\n    return calculate_area(sorted_indices, vertices)\n\n\ndef box2corners(box: Tensor) -> Tensor:\n    \"\"\"Convert rotated 2d box coordinate to corners.\n\n    Args:\n        box (Tensor): (B, N, 5) with x, y, w, h, alpha.\n\n    Returns:\n        Tensor: (B, N, 4, 2) Corners.\n    \"\"\"\n    B = box.size()[0]\n    x, y, w, h, alpha = box.split([1, 1, 1, 1, 1], dim=-1)\n    x4 = torch.FloatTensor([0.5, -0.5, -0.5, 0.5]).to(box.device)\n    x4 = x4 * w  # (B, N, 4)\n    y4 = torch.FloatTensor([0.5, 0.5, -0.5, -0.5]).to(box.device)\n    y4 = y4 * h  # (B, N, 4)\n    corners = torch.stack([x4, y4], dim=-1)  # (B, N, 4, 2)\n    sin = torch.sin(alpha)\n    cos = torch.cos(alpha)\n    row1 = torch.cat([cos, sin], dim=-1)\n    row2 = torch.cat([-sin, cos], dim=-1)  # (B, N, 2)\n    rot_T = torch.stack([row1, row2], dim=-2)  # (B, N, 2, 2)\n    rotated = torch.bmm(corners.view([-1, 4, 2]), rot_T.view([-1, 2, 2]))\n    rotated = rotated.view([B, -1, 4, 2])  # (B * N, 4, 2) -> (B, N, 4, 2)\n    rotated[..., 0] += x\n    rotated[..., 1] += y\n    return rotated\n\n\ndef diff_iou_rotated_2d(box1: Tensor, box2: Tensor) -> Tensor:\n    \"\"\"Calculate differentiable iou of rotated 2d boxes.\n\n    Args:\n        box1 (Tensor): (B, N, 5) First box.\n        box2 (Tensor): (B, N, 5) Second box.\n\n    Returns:\n        Tensor: (B, N) IoU.\n    \"\"\"\n    corners1 = box2corners(box1)\n    corners2 = box2corners(box2)\n    intersection, _ = 
oriented_box_intersection_2d(corners1, corners2)  # (B, N)\n    area1 = box1[:, :, 2] * box1[:, :, 3]\n    area2 = box2[:, :, 2] * box2[:, :, 3]\n    union = area1 + area2 - intersection\n    iou = intersection / union\n    return iou\n\n\ndef diff_iou_rotated_3d(box3d1: Tensor, box3d2: Tensor) -> Tensor:\n    \"\"\"Calculate differentiable iou of rotated 3d boxes.\n\n    Args:\n        box3d1 (Tensor): (B, N, 3+3+1) First box (x,y,z,w,h,l,alpha).\n        box3d2 (Tensor): (B, N, 3+3+1) Second box (x,y,z,w,h,l,alpha).\n\n    Returns:\n        Tensor: (B, N) IoU.\n    \"\"\"\n    box1 = box3d1[..., [0, 1, 3, 4, 6]]  # 2d box\n    box2 = box3d2[..., [0, 1, 3, 4, 6]]\n    corners1 = box2corners(box1)\n    corners2 = box2corners(box2)\n    intersection, _ = oriented_box_intersection_2d(corners1, corners2)\n    zmax1 = box3d1[..., 2] + box3d1[..., 5] * 0.5\n    zmin1 = box3d1[..., 2] - box3d1[..., 5] * 0.5\n    zmax2 = box3d2[..., 2] + box3d2[..., 5] * 0.5\n    zmin2 = box3d2[..., 2] - box3d2[..., 5] * 0.5\n    z_overlap = (torch.min(zmax1, zmax2) - torch.max(zmin1, zmin2)).clamp_(min=0.0)\n    intersection_3d = intersection * z_overlap\n    volume1 = box3d1[..., 3] * box3d1[..., 4] * box3d1[..., 5]\n    volume2 = box3d2[..., 3] * box3d2[..., 4] * box3d2[..., 5]\n    union_3d = volume1 + volume2 - intersection_3d\n    return intersection_3d / union_3d\n\n\ndef rotated_iou_3d_loss(pred, target):\n    \"\"\"Calculate the IoU loss (1-IoU) of two sets of rotated bounding boxes.\n    Note that predictions and targets are one-to-one corresponded.\n\n    Args:\n        pred (torch.Tensor): Bbox predictions with shape [N, 7]\n            (x, y, z, w, l, h, alpha).\n        target (torch.Tensor): Bbox targets (gt) with shape [N, 7]\n            (x, y, z, w, l, h, alpha).\n\n    Returns:\n        torch.Tensor: IoU loss between predictions and targets.\n    \"\"\"\n    iou_loss = 1 - diff_iou_rotated_3d(pred.unsqueeze(0), target.unsqueeze(0))[0]\n    return iou_loss\n\n\nclass 
RotatedIoU3DLoss(torch.nn.Module):\n    \"\"\"Calculate the IoU loss (1-IoU) of rotated bounding boxes.\n\n    Args:\n        loss_weight (float, optional): Weight of loss. Defaults to 1.0.\n    \"\"\"\n\n    def __init__(self, loss_weight=1.0):\n        super().__init__()\n        self.loss_weight = loss_weight\n\n    def forward(\n        self,\n        pred,\n        target,\n    ):\n        \"\"\"Forward function of loss calculation.\n\n        Args:\n            pred (torch.Tensor): Bbox predictions with shape [..., 7]\n                (x, y, z, w, l, h, alpha).\n            target (torch.Tensor): Bbox targets (gt) with shape [..., 7]\n                (x, y, z, w, l, h, alpha).\n\n        Returns:\n            torch.Tensor: IoU loss between predictions and targets.\n        \"\"\"\n        # print(pred.shape, target.shape)\n        if pred.shape[0] == 0 or target.shape[0] == 0:\n            return 0.0 * pred.sum()\n        loss = self.loss_weight * rotated_iou_3d_loss(pred, target)\n        return loss\n\n\ndef boxes_iou_bev(boxes_a, boxes_b):\n    \"\"\"Calculate boxes IoU in the bird view.\n\n    Args:\n        boxes_a (torch.Tensor): Input boxes a with shape (M, 5).\n        boxes_b (torch.Tensor): Input boxes b with shape (N, 5).\n\n    Returns:\n        ans_iou (torch.Tensor): IoU result with shape (M, N).\n    \"\"\"\n    ans_iou = boxes_a.new_zeros(torch.Size((boxes_a.shape[0], boxes_b.shape[0])))\n\n    mmdet_iou3d.boxes_iou_bev_gpu(boxes_a.contiguous(), boxes_b.contiguous(), ans_iou)\n\n    return ans_iou\n\n\ndef nms_gpu(boxes, scores, thresh, pre_maxsize=None, post_max_size=None):\n    \"\"\"Nms function with gpu implementation.\n\n    Args:\n        boxes (torch.Tensor): Input boxes with the shape of [N, 5]\n            ([x1, y1, x2, y2, ry]).\n        scores (torch.Tensor): Scores of boxes with the shape of [N].\n        thresh (int): Threshold.\n        pre_maxsize (int): Max size of boxes before nms. 
Default: None.\n        post_max_size (int): Max size of boxes after nms. Default: None.\n\n    Returns:\n        torch.Tensor: Indices after nms.\n    \"\"\"\n    order = scores.sort(0, descending=True)[1]\n\n    if pre_maxsize is not None:\n        order = order[:pre_maxsize]\n    boxes = boxes[order].contiguous()\n\n    keep = torch.zeros(boxes.size(0), dtype=torch.long)\n    num_out = mmdet_iou3d.nms_gpu(boxes, keep, thresh, boxes.device.index)\n    keep = order[keep[:num_out].cuda(boxes.device)].contiguous()\n    if post_max_size is not None:\n        keep = keep[:post_max_size]\n    return keep\n\n\ndef nms_normal_gpu(boxes, scores, thresh):\n    \"\"\"Normal non maximum suppression on GPU.\n\n    Args:\n        boxes (torch.Tensor): Input boxes with shape (N, 5).\n        scores (torch.Tensor): Scores of predicted boxes with shape (N).\n        thresh (float): Threshold of non maximum suppression.\n\n    Returns:\n        torch.Tensor: Remaining indices with scores in descending order.\n    \"\"\"\n    order = scores.sort(0, descending=True)[1]\n\n    boxes = boxes[order].contiguous()\n\n    keep = torch.zeros(boxes.size(0), dtype=torch.long)\n    num_out = mmdet_iou3d.nms_normal_gpu(boxes, keep, thresh, boxes.device.index)\n    return order[keep[:num_out].cuda(boxes.device)].contiguous()\n"
  },
  {
    "path": "efm3d/utils/__init__.py",
    "content": ""
  },
  {
    "path": "efm3d/utils/common.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport numpy as np\nimport torch\n\n\ndef sample_nearest(value_a, value_b, array_b):\n    array_b_at_a = []\n    for v_a in value_a:\n        idx = find_nearest(value_b, v_a, return_index=True)\n        array_b_at_a.append(array_b[idx])\n    return torch.stack(array_b_at_a)\n\n\ndef find_nearest(array, value, return_index=False):\n    array = np.asarray(array)\n    idx = (np.abs(array - value)).argmin()\n    if return_index:\n        return idx\n    else:\n        return array[idx]\n"
  },
  {
    "path": "efm3d/utils/depth.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nfrom efm3d.utils.ray import ray_grid\n\n\ndef dist_im_to_point_cloud_im(dist_m, cams):\n    B, T = None, None\n    if cams.ndim == 3:\n        B, T, _ = cams.shape\n        cams = cams.view(B * T, -1)\n        dist_m = dist_m.flatten(0, 1)\n    elif cams.ndim == 2:\n        B, _ = cams.shape\n    elif cams.ndim == 1:\n        cams = cams.view(1, -1)\n        H, W = dist_m.shape\n        dist_m = dist_m.view(1, H, W)\n    BT, H, W = dist_m.shape\n    rays_rig, valids = ray_grid(cams)\n    p3s_rig = rays_rig[..., :3] + rays_rig[..., 3:] * dist_m.unsqueeze(-1)\n    p3s_c = cams.T_camera_rig * p3s_rig.view(BT, -1, 3)\n    # distances > 0.0 are valid\n    valids = torch.logical_and(valids, dist_m > 0.0)\n\n    if T is not None:\n        p3s_c = p3s_c.view(B, T, H, W, 3)\n        valids = valids.view(B, T, H, W)\n    elif B is not None:\n        p3s_c = p3s_c.view(B, H, W, 3)\n        valids = valids.view(B, H, W)\n    else:\n        p3s_c = p3s_c.view(H, W, 3)\n        valids = valids.view(H, W)\n    return p3s_c, valids\n"
  },
  {
    "path": "efm3d/utils/detection_utils.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport numpy as np\nimport torch\nimport torchvision\nfrom efm3d.aria.obb import ObbTW\nfrom efm3d.aria.pose import PAD_VAL, PoseTW, rotation_from_euler\n\n\ndef norm2ind(norm_xyz, vD, vH, vW):\n    \"\"\"Converts normalized xyz coords [-1,1] to DxHxW indices.\"\"\"\n    if isinstance(norm_xyz, np.ndarray):\n        inds_dhw = norm_xyz.copy()\n    else:\n        inds_dhw = norm_xyz.clone()\n    inds_dhw[..., 0] = torch.ceil((norm_xyz[..., 2] + 1.0) * vD / 2.0) - 1\n    inds_dhw[..., 1] = torch.ceil((norm_xyz[..., 1] + 1.0) * vH / 2.0) - 1\n    inds_dhw[..., 2] = torch.ceil((norm_xyz[..., 0] + 1.0) * vW / 2.0) - 1\n\n    inds_dhw = inds_dhw.round()\n    outside = (\n        (inds_dhw[..., 0] <= 0)\n        | (inds_dhw[..., 0] >= (vD - 1))\n        | (inds_dhw[..., 1] <= 0)\n        | (inds_dhw[..., 1] >= (vH - 1))\n        | (inds_dhw[..., 2] <= 0)\n        | (inds_dhw[..., 2] >= (vW - 1))\n    )\n    inside = ~outside\n    if isinstance(inds_dhw, np.ndarray):\n        inds_dhw = inds_dhw.astype(int)\n    else:\n        inds_dhw = inds_dhw.int()\n    return inds_dhw, inside\n\n\ndef ind2norm(inds_dhw, vD, vH, vW):\n    \"\"\"Converts DxHxW indices to normalized xyz coords [-1,1].\"\"\"\n    if isinstance(inds_dhw, np.ndarray):\n        norm_xyz = inds_dhw.copy().astype(float)\n    else:\n        norm_xyz = inds_dhw.clone().float()\n    
norm_xyz[..., 0] = 2.0 * (inds_dhw[..., 2] + 0.5) / vW - 1.0\n    norm_xyz[..., 1] = 2.0 * (inds_dhw[..., 1] + 0.5) / vH - 1.0\n    norm_xyz[..., 2] = 2.0 * (inds_dhw[..., 0] + 0.5) / vD - 1.0\n\n    return norm_xyz\n\n\ndef normalize_coord3d(xyz, extent):\n    if isinstance(xyz, np.ndarray):\n        xyz_n = xyz.copy()\n    else:\n        xyz_n = xyz.clone()\n    x_min, x_max, y_min, y_max, z_min, z_max = extent\n    xyz_n[..., 0] = ((xyz[..., 0] - x_min) / ((x_max - x_min) / 2.0)) - 1.0\n    xyz_n[..., 1] = ((xyz[..., 1] - y_min) / ((y_max - y_min) / 2.0)) - 1.0\n    xyz_n[..., 2] = ((xyz[..., 2] - z_min) / ((z_max - z_min) / 2.0)) - 1.0\n    return xyz_n\n\n\ndef unnormalize_coord3d(xyz_n, extent):\n    if isinstance(xyz_n, np.ndarray):\n        xyz = xyz_n.copy()\n    else:\n        xyz = xyz_n.clone()\n    x_min, x_max, y_min, y_max, z_min, z_max = extent\n    xyz[..., 0] = ((xyz_n[..., 0] + 1.0) * ((x_max - x_min) / 2.0)) + x_min\n    xyz[..., 1] = ((xyz_n[..., 1] + 1.0) * ((y_max - y_min) / 2.0)) + y_min\n    xyz[..., 2] = ((xyz_n[..., 2] + 1.0) * ((z_max - z_min) / 2.0)) + z_min\n    return xyz\n\n\ndef create_heatmap_gt(mu_xy, H, W, valid=None):\n    \"\"\"\n    Inputs:\n        mu_xy : torch.Tensor : shaped BxNx2 of pixel locations in range [0,H-1] and [0,W-1]\n        H : image height\n        W : image width\n        valid : torch.Tensor : optional boolean mask shaped BxN of whether to use each point or not\n    returns:\n        heat_gt : torch.Tensor : BxHxW tensor of splatted 2D points\n    \"\"\"\n\n    B = mu_xy.shape[0]\n    inside = (\n        (mu_xy[..., 0] >= 0)\n        & (mu_xy[..., 0] <= (H - 1))\n        & (mu_xy[..., 1] >= 0)\n        & (mu_xy[..., 1] <= (W - 1))\n    )\n    if valid is not None:  # if we have additional valid signal, use it\n        inside = inside & valid\n    inds_xy = mu_xy.round().long()\n    inds_xy = inds_xy.reshape(B, -1, 2)\n    inds = (\n        inds_xy[:, :, 1] * W + inds_xy[:, :, 0]\n    )  # flatten 
matrix index into vector index\n    # the largest valid flat index is (H - 1) * W + (W - 1) = H * W - 1\n    inds = torch.clip(inds, min=0, max=H * W - 1)\n    inside = inside.reshape(B, -1).to(inds)\n    heat_gt = torch.zeros((B, H * W)).to(inds)\n    heat_gt.scatter_(1, inds, inside)\n    heat_gt = heat_gt.reshape(B, H, W).float()\n    blur = torchvision.transforms.functional.gaussian_blur\n    kernel = 25\n    heat_gt = blur(heat_gt, kernel)\n    if heat_gt.sum() > 0:\n        # Normalize such that peak is ~1.\n        heat_gt = heat_gt * 100\n        heat_gt = torch.clip(heat_gt, min=0, max=1)\n    return heat_gt\n\n\ndef simple_nms(scores, nms_radius: int):\n    \"\"\"Approximate + Fast Non-maximum suppression to remove nearby points,\n    works by running max pool twice on GPU.\"\"\"\n    assert nms_radius >= 0\n\n    def max_pool(x):\n        return torch.nn.functional.max_pool2d(\n            x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius\n        )\n\n    zeros = torch.zeros_like(scores)\n    max_mask = scores == max_pool(scores)\n    for _ in range(2):\n        supp_mask = max_pool(max_mask.float()) > 0\n        supp_scores = torch.where(supp_mask, zeros, scores)\n        new_max_mask = supp_scores == max_pool(supp_scores)\n        max_mask = max_mask | (new_max_mask & (~supp_mask))\n    return torch.where(max_mask, scores, zeros)\n\n\ndef simple_nms3d(scores, nms_radius: int):\n    \"\"\"Approximate + Fast Non-maximum suppression on 3D heatmap to remove nearby points,\n    works by running max pool twice on GPU.\"\"\"\n    assert nms_radius >= 0\n\n    def max_pool(x):\n        return torch.nn.functional.max_pool3d(\n            x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius\n        )\n\n    zeros = torch.zeros_like(scores)\n    max_mask = scores == max_pool(scores)\n    for _ in range(2):\n        supp_mask = max_pool(max_mask.float()) > 0\n        supp_scores = torch.where(supp_mask, zeros, scores)\n        new_max_mask = supp_scores == max_pool(supp_scores)\n        max_mask = 
max_mask | (new_max_mask & (~supp_mask))\n    return torch.where(max_mask, scores, zeros)\n\n\ndef heatmap2obb(scores, threshold=0.3, size=20, max_elts=1000):\n    \"\"\"Runs argmax on a 2D heatmaps to return (x,y) positions\n    in the heatmap in the ObbTW class, above a threshold. Creates\n    fake 2D bounding boxes of size 20x20 by default.\"\"\"\n    # Extract keypoints\n    hsize = int(round(size / 2))\n    obbs = []\n    dev = scores.device\n    for score in scores:\n        keypoint = torch.nonzero(score > threshold)\n        ymin = keypoint[:, 0] - hsize\n        ymax = keypoint[:, 0] + hsize\n        xmin = keypoint[:, 1] - hsize\n        xmax = keypoint[:, 1] + hsize\n        bb2 = torch.stack([xmin, xmax, ymin, ymax], dim=1).float()\n        obb = ObbTW().repeat(bb2.shape[0], 1).clone().to(dev)\n        # Set bb2_rgb\n        obb.set_bb2(cam_id=0, bb2d=bb2, use_mask=False)  # Set to RGB.\n        # Set probability\n        probs = score[tuple(keypoint.t())]\n        obb.set_prob(probs)\n        obbs.append(obb.add_padding(max_elts=max_elts))\n    return torch.stack(obbs, dim=0)\n\n\n# Centerness loss, binary cross entropy (evaluated densely per voxel position).\ndef compute_focal_loss(pred, gt, focal_gamma=2, focal_alpha=0.25):\n    \"\"\"focal loss for imbalanced classification\n    https://pytorch.org/vision/stable/_modules/torchvision/ops/focal_loss.html\n    Args:\n        pred (torch.tensor): predicted probabilities\n        gt (torch.tensor): GT probabilities\n    Returns:\n        nll_loss: negative log-likelihood loss\n    \"\"\"\n    assert pred.shape == gt.shape\n    gt = gt.double()\n    pred = pred.double()\n    eps = 1e-9\n    # Simple negative log-likelihood (aka binary cross-entropy). 
Assume sigmoid already applied.\n    nll = -(torch.log(pred + eps) * gt + torch.log((1.0 - pred) + eps) * (1.0 - gt))\n\n    if focal_gamma > 0:\n        p_t = pred * gt + (1 - pred) * (1 - gt)\n        nll = nll * ((1 - p_t) ** focal_gamma)\n\n    # class-wise balancing\n    if focal_alpha >= 0:\n        alpha_t = focal_alpha * gt + (1 - focal_alpha) * (1.0 - gt)\n        nll = alpha_t * nll\n\n    return nll.float()\n\n\ndef compute_chamfer_loss(vals, target):\n    B = vals.shape[0]\n    xx = vals.view(B, 8, 1, 3)\n    yy = target.view(B, 1, 8, 3)\n    l1_dist = (xx - yy).abs().sum(-1)\n\n    gt_to_pred = l1_dist.min(1).values.mean(-1)\n    pred_to_gt = l1_dist.min(2).values.mean(-1)\n    l1 = 0.1 * pred_to_gt + gt_to_pred\n    return l1\n\n\ndef obb2voxel(obb_v, vD, vH, vW, voxel_extent, num_class, splat_sigma=2):\n    \"\"\"\n    Inputs:\n        obb_v : ObbTW : shaped BxNx34 of obbs in voxel coordinates.\n        vD : voxel depth\n        vH : voxel height\n        vW : voxel width\n        voxel_extent: size of voxel grid in meters, with order W, H, D\n        num_class: number of classes to detect\n        splat_sigma: how big to splat the Obbs\n    returns:\n        cent_gt : torch.Tensor : Bx1xDxHxW tensor of splatted 3D points\n        bbox_gt : torch.Tensor : Bx7xDxHxW tensor of bounding box params\n        clas_gt : torch.Tensor : Bxnum_classxDxHxW one hot tensor of class\n        valid_gt : torch.Tensor : Bx1xDxHxW bool tensor of where splatting is valid\n    \"\"\"\n    B = obb_v.shape[0]\n    device = obb_v.device\n    cent_gt = torch.zeros((B, 1, vD, vH, vW), device=device)\n    bbox_gt = torch.zeros((B, 7, vD, vH, vW), device=device)\n    clas_gt = torch.zeros((B, num_class, vD, vH, vW), device=device)\n    # Where to apply non-centerness losses.\n    valid_gt = torch.zeros((B, 1, vD, vH, vW), device=device, dtype=torch.bool)\n    # Gaussian kernel for splatting.\n    size = 2 * splat_sigma + 1\n    rng = torch.arange(0, size, 1).to(device)\n    
xx, yy, zz = torch.meshgrid(rng, rng, rng, indexing=\"ij\")\n    x0 = y0 = z0 = size // 2\n    eps = 1e-6\n    gauss = torch.exp(\n        -((xx - x0) ** 2 + (yy - y0) ** 2 + (zz - z0) ** 2) / (2 * splat_sigma**2 + eps)\n    )\n    # Convert obb centers to voxel indices.\n    cent_v = obb_v.bb3_center_world\n    cent_vn = normalize_coord3d(cent_v, voxel_extent)\n    inds, inside = norm2ind(cent_vn, vD, vH, vW)\n    # Get index offsets for splatting.\n    if splat_sigma == 0:\n        dd = torch.tensor([0]).reshape(1, 1, 1).to(device)\n        hh = torch.tensor([0]).reshape(1, 1, 1).to(device)\n        ww = torch.tensor([0]).reshape(1, 1, 1).to(device)\n    elif splat_sigma > 0:\n        rng_d = torch.arange(start=-splat_sigma, end=splat_sigma + 1).to(device)\n        rng_h = torch.arange(start=-splat_sigma, end=splat_sigma + 1).to(device)\n        rng_w = torch.arange(start=-splat_sigma, end=splat_sigma + 1).to(device)\n        dd, hh, ww = torch.meshgrid(rng_d, rng_h, rng_w, indexing=\"ij\")\n    else:\n        raise ValueError(\"splat sigma must be non-negative\")\n    offsets_dhw = torch.stack((dd.reshape(-1), hh.reshape(-1), ww.reshape(-1)), dim=-1)\n    offsets_dhw = offsets_dhw.unsqueeze(0).repeat(B, 1, 1)\n    # Use broadcasting to apply the offset indices to the voxel indices.\n    O = offsets_dhw.shape[1]\n    N = inds.shape[1]\n    inds_dhw = inds.reshape(B, N, 1, 3) + offsets_dhw.reshape(B, 1, O, 3)\n    inds_dhw = inds_dhw.reshape(B, N * O, 3)\n    inside = inside.reshape(B, N, 1).repeat(1, 1, O).reshape(B, N * O).float()\n    # Avoid accessing OOB.\n    ones = torch.ones_like(inds_dhw[:, :, 0])\n    inds_dhw[:, :, 0] = torch.maximum(inds_dhw[:, :, 0], 0 * ones)\n    inds_dhw[:, :, 1] = torch.maximum(inds_dhw[:, :, 1], 0 * ones)\n    inds_dhw[:, :, 2] = torch.maximum(inds_dhw[:, :, 2], 0 * ones)\n    inds_dhw[:, :, 0] = torch.minimum(inds_dhw[:, :, 0], (vD - 1) * ones)\n    inds_dhw[:, :, 1] = torch.minimum(inds_dhw[:, :, 1], (vH - 1) * ones)\n    
inds_dhw[:, :, 2] = torch.minimum(inds_dhw[:, :, 2], (vW - 1) * ones)\n\n    # keep the (d, h, w) indices before flattening\n    inds_dhw_3d = inds_dhw.clone()\n\n    # Convert D,H,W indices into flat array indices.\n    inds_d = inds_dhw[:, :, 0]\n    inds_h = inds_dhw[:, :, 1]\n    inds_w = inds_dhw[:, :, 2]\n    inds_dhw = inds_d * (vH * vW) + inds_h * vW + inds_w\n    b_inds = torch.arange(B).reshape(-1, 1).repeat(1, N * O)\n    # Set centerness GT.\n    cent_gt = cent_gt.reshape(B, -1)\n    gauss = gauss.reshape(1, 1, -1).repeat(B, N, 1).reshape(B, N * O)\n    cent_gt[b_inds, inds_dhw] = gauss * inside\n    cent_gt = cent_gt.reshape(B, 1, vD, vH, vW)\n    # Semantic class.\n    CL = num_class\n    sem_id = torch.clip(obb_v.sem_id, 0, CL - 1).long()\n    one_hot = torch.nn.functional.one_hot(sem_id, num_classes=CL)\n    one_hot = one_hot.reshape(B, N, CL).permute(0, 2, 1)\n    one_hot = one_hot.reshape(B, CL, N, 1).repeat(1, 1, 1, O)\n    one_hot = one_hot.reshape(B, CL, N * O)\n    val = one_hot * inside.reshape(B, 1, -1)\n    clas_gt = clas_gt.reshape(B, -1, vD * vH * vW)\n    b_inds_rep = b_inds.reshape(B, 1, -1).repeat(1, CL, 1)\n    cl_inds_rep = torch.arange(CL).reshape(1, CL, 1).repeat(B, 1, O * N)\n    inds_dhw_rep = inds_dhw.reshape(B, 1, -1).repeat(1, CL, 1)\n    clas_gt[b_inds_rep, cl_inds_rep, inds_dhw_rep] = val\n    clas_gt = clas_gt.reshape(B, -1, vD, vH, vW)\n    # Get gravity aligned rotation from obb\n    T_voxel_object = obb_v.T_world_object.clone()\n    # HACK to avoid gimbal lock for padded entries.\n    mask = obb_v.get_padding_mask()\n    T_voxel_object.R[mask] = PAD_VAL\n    rpy = T_voxel_object.to_euler()\n    yaw = rpy[:, :, 2]\n    # BBox size (in voxel coordinates.)\n    bb3 = obb_v.bb3_object\n    xsize = bb3[:, :, 1] - bb3[:, :, 0]\n    ysize = bb3[:, :, 3] - bb3[:, :, 2]\n    zsize = bb3[:, :, 5] - bb3[:, :, 4]\n    # Discretized centers.\n    centd_vn = ind2norm(inds_dhw_3d, vD, vH, vW)\n    centd_v = 
unnormalize_coord3d(centd_vn, voxel_extent)\n    # Compute offset between discretized centers and obb centers.\n    cent_v_rep = cent_v.reshape(B, -1, 1, 3).repeat(1, 1, O, 1).reshape(B, N * O, 3)\n    offsets = centd_v - cent_v_rep\n    xoff = offsets[:, :, 0]\n    yoff = offsets[:, :, 1]\n    zoff = offsets[:, :, 2]\n    # Splat via repeat.\n    xsize = xsize.reshape(B, -1, 1).repeat(1, 1, O).reshape(B, N * O)\n    ysize = ysize.reshape(B, -1, 1).repeat(1, 1, O).reshape(B, N * O)\n    zsize = zsize.reshape(B, -1, 1).repeat(1, 1, O).reshape(B, N * O)\n    yaw = yaw.reshape(B, -1, 1).repeat(1, 1, O).reshape(B, N * O)\n    # Assign bbox parameters into voxel GT.\n    bbox_gt = bbox_gt.reshape(B, 7, -1)\n    BB = bbox_gt.shape[1]\n    bb_inds = torch.arange(BB).reshape(-1, 1).repeat(1, N * O)\n    bbox_gt[b_inds, bb_inds[0, :], inds_dhw] = xsize * inside\n    bbox_gt[b_inds, bb_inds[1, :], inds_dhw] = ysize * inside\n    bbox_gt[b_inds, bb_inds[2, :], inds_dhw] = zsize * inside\n    bbox_gt[b_inds, bb_inds[3, :], inds_dhw] = xoff * inside\n    bbox_gt[b_inds, bb_inds[4, :], inds_dhw] = yoff * inside\n    bbox_gt[b_inds, bb_inds[5, :], inds_dhw] = zoff * inside\n    bbox_gt[b_inds, bb_inds[6, :], inds_dhw] = yaw * inside\n    bbox_gt = bbox_gt.reshape(B, 7, vD, vH, vW)\n    # Set valid mask.\n    valid_gt = valid_gt.reshape(B, -1)\n    valid_gt[b_inds, inds_dhw] = inside.bool()\n    valid_gt = valid_gt.reshape(B, 1, vD, vH, vW)\n    return cent_gt, bbox_gt, clas_gt, valid_gt\n\n\ndef voxel2obb(\n    cent_pr,\n    bbox_pr,\n    clas_pr,\n    voxel_extent,\n    top_k=None,\n    thresh=None,\n    return_full_prob=False,\n):\n    \"\"\"Convert 3D centerness, size, rotation voxel grids to ObbTW objects,\n    returning objects in the voxel coordinate frame. 
Can optionally threshold\n    by probability and keep only the top-K predictions.\n    \"\"\"\n    assert cent_pr.ndim == 5\n    B, _, vD, vH, vW = cent_pr.shape\n    device = cent_pr.device\n    # Get extent.\n    xhalf = bbox_pr[:, 0] / 2.0\n    yhalf = bbox_pr[:, 1] / 2.0\n    zhalf = bbox_pr[:, 2] / 2.0\n    bb3 = torch.stack(\n        [\n            -xhalf,\n            +xhalf,\n            -yhalf,\n            +yhalf,\n            -zhalf,\n            +zhalf,\n        ],\n        dim=-1,\n    )\n    # Get rotation to set T_world_object.\n    yaw = bbox_pr[:, 6]\n    zeros = torch.zeros_like(yaw)\n    e_angles = torch.stack([zeros, zeros, yaw], dim=-1)\n    R = rotation_from_euler(e_angles.reshape(-1, 3))\n    R = R.reshape(B, vD, vH, vW, 3, 3)\n    t_zero = torch.zeros(B, vD, vH, vW, 3).to(device)\n    T_voxel_object = PoseTW.from_Rt(R, t_zero)\n    rngd = torch.arange(vD).to(device)\n    rngh = torch.arange(vH).to(device)\n    rngw = torch.arange(vW).to(device)\n    xx, yy, zz = torch.meshgrid(rngd, rngh, rngw, indexing=\"ij\")\n    inds = torch.stack([xx.reshape(-1), yy.reshape(-1), zz.reshape(-1)], dim=-1)\n    norm_centers = ind2norm(inds, vD, vH, vW)\n    centers_v = unnormalize_coord3d(norm_centers, voxel_extent)\n    centers_v = centers_v.reshape(1, vD, vH, vW, 3).repeat(B, 1, 1, 1, 1)\n    # The object center is the voxel center minus the offset, since obb2voxel\n    # stores offset = voxel center - object center.\n    xoff = bbox_pr[:, 3]\n    yoff = bbox_pr[:, 4]\n    zoff = bbox_pr[:, 5]\n    t_off = torch.stack([xoff, yoff, zoff], dim=-1)\n    T_voxel_object.t[:] = centers_v - t_off\n    # Get prob.\n    prob = cent_pr.reshape(B, vD, vH, vW, 1)\n    N = inds.shape[0]\n    # Get instance id, use voxel location for this.\n    inst_id = torch.arange(N).reshape(1, vD, vH, vW, 1).repeat(B, 1, 1, 1, 1)\n    # Get semantic id\n    sem_id = torch.argmax(clas_pr, dim=1).unsqueeze(-1)\n    # Construct ObbTW object.\n    obbs = ObbTW.from_lmc(\n        bb3_object=bb3,\n        T_world_object=T_voxel_object,\n        
prob=prob,\n        inst_id=inst_id,\n        sem_id=sem_id,\n    )\n    # Optionally remove detections below threshold.\n    if thresh is not None:\n        below = (obbs.prob < thresh).squeeze(-1)\n        obbs._data[below, :] = PAD_VAL\n\n    # Optionally subselect top K.\n    if top_k is not None:\n        prob = obbs.prob.reshape(B, N)\n        s_vals, s_inds = torch.sort(prob, dim=1, descending=True)\n        n_inds = s_inds[:, :top_k].reshape(-1)\n        b_inds = torch.arange(B).reshape(B, 1).repeat(1, top_k).to(device).reshape(-1)\n        obbs = obbs.reshape(B, N, -1)\n        # B x K x 34\n        obbs = obbs[b_inds, n_inds].reshape(B, top_k, -1)\n        # B x K\n        prob = prob[b_inds, n_inds].reshape(B, top_k)\n        # B x K x C\n        clas_pr = clas_pr.reshape(B, -1, N)[b_inds, :, n_inds].reshape(B, top_k, -1)\n\n    if return_full_prob:\n        return obbs, prob, clas_pr\n    else:\n        return obbs\n"
  },
  {
    "path": "efm3d/utils/evl_loss.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nfrom efm3d.aria.aria_constants import (\n    ARIA_CALIB,\n    ARIA_IMG_T_SNIPPET_RIG,\n    ARIA_OBB_PADDED,\n    ARIA_SNIPPET_T_WORLD_SNIPPET,\n)\nfrom efm3d.aria.obb import obb_filter_outside_volume, obb_time_union, ObbTW, PAD_VAL\nfrom efm3d.thirdparty.mmdetection3d.iou3d import RotatedIoU3DLoss\nfrom efm3d.utils.detection_utils import (\n    compute_chamfer_loss,\n    compute_focal_loss,\n    obb2voxel,\n    voxel2obb,\n)\nfrom efm3d.utils.pointcloud import get_points_world\nfrom efm3d.utils.reconstruction import compute_occupancy_loss_subvoxel, compute_tv_loss\n\n\ndef get_gt_obbs(batch, voxel_extent, T_wv=None):\n    \"\"\"\n    Get the GT Obbs from the batch.\n\n    voxel_extent: used to filter GT Obbs outside of voxel grid.\n    T_wv: if not None, filter GT Obbs outside of voxel grid.\n    \"\"\"\n    if ARIA_OBB_PADDED not in batch:\n        B = batch[ARIA_SNIPPET_T_WORLD_SNIPPET].shape[0]\n        return ObbTW().view(1, -1).repeat(B, 1)\n    obbs_gt = batch[ARIA_OBB_PADDED].clone()\n    # Optionally filter GT.\n    if batch[ARIA_OBB_PADDED].ndim == 4:\n        # Filter by time Union.\n        obbs_gt = obb_time_union(obbs_gt)\n    if T_wv is not None:\n        # Filter outside of voxel grid.\n        T_ws = batch[ARIA_SNIPPET_T_WORLD_SNIPPET].squeeze(1)\n        obbs_gt = obb_filter_outside_volume(\n            obbs_gt, 
T_ws, T_wv, voxel_extent=voxel_extent\n        )\n    return obbs_gt\n\n\ndef obbs_to_7d(obbs):\n    obbs_cent = obbs.bb3_center_world  # center in voxel coords\n    wlh = obbs.bb3_max_object - obbs.bb3_min_object\n\n    # Get gravity aligned rotation from obb\n    T_voxel_object = obbs.T_world_object.clone()\n    # HACK to avoid gimbal lock for padded entries.\n    mask = obbs.get_padding_mask()\n    T_voxel_object.R[mask] = PAD_VAL\n    rpy = T_voxel_object.to_euler()\n    yaw = rpy[..., 2].unsqueeze(-1)\n\n    obbs_7d = torch.concat([obbs_cent, wlh, yaw], dim=-1)\n    return obbs_7d\n\n\ndef iou_3d_loss(obbs_pr, obbs_gt, cent_pr, cent_gt, valid_gt):\n    \"\"\"\n    obbs_pr: N x 34\n    obbs_gt: N x 34\n    \"\"\"\n    assert obbs_pr.ndim == 2 and obbs_gt.ndim == 2, \"obbs dimension should be Nx34\"\n    obbs_pr_7d = obbs_to_7d(obbs_pr)\n    obbs_gt_7d = obbs_to_7d(obbs_gt)\n    iou_loss = RotatedIoU3DLoss(loss_weight=1.0)\n\n    # weighted by validness and GT centerness\n    obbs_weight = cent_gt.reshape(-1) * valid_gt.reshape(-1)\n    valid_idx = torch.nonzero(obbs_weight > 0).squeeze()\n    obbs_weight = obbs_weight[valid_idx]\n    obbs_pr_7d = obbs_pr_7d[valid_idx, :]\n    obbs_gt_7d = obbs_gt_7d[valid_idx, :]\n\n    loss = obbs_weight * iou_loss.forward(obbs_pr_7d, obbs_gt_7d)\n    loss = loss.mean()\n\n    return loss\n\n\ndef compute_obb_losses(\n    outputs,\n    batch,\n    voxel_extent,\n    num_class,\n    splat_sigma,\n    cent_weight,\n    clas_weight,\n    iou_weight,\n    bbox_weight,\n    cham_weight,\n):\n    B, _, vD, vH, vW = outputs[\"cent_pr\"].shape\n    ve = voxel_extent\n    N = vD * vH * vW\n    T_ws = batch[ARIA_SNIPPET_T_WORLD_SNIPPET].squeeze(1)\n    T_wv = outputs[\"voxel/T_world_voxel\"]\n    obb_gt_s = get_gt_obbs(batch, voxel_extent, T_wv)\n\n    # Put GT in voxel coordinate frame.\n    T_vs = T_wv.inverse() @ T_ws\n    obb_gt_v = obb_gt_s.transform(T_vs.unsqueeze(1))\n\n    # Create 3D GT tensors.\n    cent_gt, bbox_gt, clas_gt, 
valid_gt = obb2voxel(\n        obb_gt_v, vD, vH, vW, ve, num_class, splat_sigma\n    )\n    outputs[\"cent_gt\"] = cent_gt\n    outputs[\"bbox_gt\"] = bbox_gt\n    outputs[\"clas_gt\"] = clas_gt\n    outputs[\"valid_gt\"] = valid_gt\n\n    # Get Obbs from densified predictions + GT.\n    cent_pr = outputs[\"cent_pr\"]\n    bbox_pr = outputs[\"bbox_pr\"]\n    clas_pr = outputs[\"clas_pr\"]\n    obbs_pr_dense = voxel2obb(cent_pr, bbox_pr, clas_pr, ve, top_k=None, thresh=None)\n    obbs_gt_dense = voxel2obb(cent_gt, bbox_gt, clas_gt, ve, top_k=None, thresh=None)\n    obbs_pr_dense = obbs_pr_dense.reshape(B * N, -1)\n    obbs_gt_dense = obbs_gt_dense.reshape(B * N, -1)\n    outputs[\"obbs_gt_dense\"] = obbs_gt_dense\n\n    losses = {\"rgb\": {}}\n    total_loss = 0.0\n\n    # Centerness loss.\n    if cent_weight > 0:\n        cent_pr = outputs[\"cent_pr\"]\n        cent_loss = compute_focal_loss(cent_pr, cent_gt)\n        cent_loss = cent_loss.reshape(B, -1)\n        cent_loss = cent_loss.mean()\n        cent_loss = cent_loss * cent_weight\n        losses[\"rgb\"][\"cent\"] = cent_loss\n        total_loss += cent_loss\n\n    # Classification loss.\n    if clas_weight > 0:\n        clas_pr = outputs[\"clas_pr\"]\n        clas_loss = compute_focal_loss(clas_pr, clas_gt)\n        clas_loss = clas_loss.sum(dim=1).reshape(-1)\n        clas_loss[~valid_gt.reshape(-1)] = 0.0\n        clas_loss = torch.sum(clas_loss) / (valid_gt.sum() + 1)\n        clas_loss = clas_loss * clas_weight\n        losses[\"rgb\"][\"clas\"] = clas_loss\n        total_loss += clas_loss\n\n    # 3D IoU loss (gravity aligned 7 DoF loss).\n    if iou_weight > 0:\n        iou_loss = iou_3d_loss(obbs_pr_dense, obbs_gt_dense, cent_pr, cent_gt, valid_gt)\n        iou_loss = iou_loss * iou_weight\n        losses[\"rgb\"][\"iou\"] = iou_loss\n        total_loss += iou_loss\n\n    # Supervise directly on D, H, W dimensions with L1 loss.\n    if bbox_weight > 0:\n        dhw_gt = obbs_gt_dense.bb3_diagonal\n    
    dhw_pr = obbs_pr_dense.bb3_diagonal\n        bbox_loss = torch.mean(torch.abs(dhw_pr - dhw_gt), dim=-1)\n        bbox_loss[~valid_gt.reshape(-1)] = 0.0\n        bbox_loss = torch.sum(bbox_loss) / (valid_gt.sum() + 1)\n        bbox_loss = bbox_loss * bbox_weight\n        losses[\"rgb\"][\"bbox\"] = bbox_loss\n        total_loss += bbox_loss\n\n    # Chamfer loss for rotation.\n    if cham_weight > 0:\n        corners_pr = obbs_pr_dense.bb3corners_world  # world is voxel\n        corners_gt = obbs_gt_dense.bb3corners_world  # world is voxel\n        cham_loss = compute_chamfer_loss(corners_pr, corners_gt)\n        cham_loss[~valid_gt.reshape(-1)] = 0.0\n        cham_loss = torch.sum(cham_loss) / (valid_gt.sum() + 1)\n        cham_loss = cham_loss * cham_weight\n        losses[\"rgb\"][\"cham\"] = cham_loss\n        total_loss += cham_loss\n\n    return losses, total_loss\n\n\ndef compute_occ_losses(\n    outputs,\n    batch,\n    voxel_extent,\n    occ_weight,\n    tv_weight,\n):\n    B, T, vD, vH, vW = outputs[\"occ_pr\"].shape\n    T_ws = batch[ARIA_SNIPPET_T_WORLD_SNIPPET].squeeze(1)\n    T_wv = outputs[\"voxel/T_world_voxel\"]\n\n    losses = {\"rgb\": {}}\n    total_loss = 0.0\n    p3s_w, dist_stds = get_points_world(batch)\n\n    # Occupancy loss.\n    cams = batch[ARIA_CALIB[0]]\n    Ts_sr = batch[ARIA_IMG_T_SNIPPET_RIG[0]]\n    T_ws = batch[ARIA_SNIPPET_T_WORLD_SNIPPET]\n    Ts_wr = T_ws @ Ts_sr\n    Ts_cw = cams.T_camera_rig @ Ts_wr.inverse()\n    Ts_wc = Ts_cw.inverse()\n\n    occ = outputs[\"occ_pr\"].squeeze(1)\n    voxel_counts = outputs[\"voxel/counts\"]\n\n    B, D, H, W = occ.shape\n    B, Df, Hf, Wf = voxel_counts.shape\n    if D != Df or H != Hf or W != Wf:\n        resize = torch.nn.Upsample(size=(D, H, W))\n        voxel_counts = resize(voxel_counts.unsqueeze(1).float()).squeeze(1)\n\n    visible = voxel_counts > 0\n\n    if occ_weight > 0:\n        occ_loss = compute_occupancy_loss_subvoxel(\n            occ,\n            visible,\n           
 p3s_w,\n            Ts_wc,\n            cams,\n            T_wv,\n            voxel_extent,\n            loss_type=\"l2\",\n        )\n        occ_loss = occ_loss * occ_weight\n        total_loss += occ_loss\n        losses[\"rgb\"][\"occ\"] = occ_loss.cpu().detach()\n\n    if tv_weight > 0.0:\n        tv_loss = compute_tv_loss(occ)\n        tv_loss = tv_loss * tv_weight\n        total_loss += tv_loss\n        losses[\"rgb\"][\"tv\"] = tv_loss.cpu().detach()\n\n    return losses, total_loss\n"
  },
  {
    "path": "efm3d/utils/file_utils.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport gzip\nimport json\nimport os\nimport pickle\nimport random\nfrom bisect import bisect_left\nfrom collections import defaultdict\nfrom typing import Dict, List, Optional, Tuple, Union\n\nimport fsspec\nimport numpy as np\nimport pandas as pd\nimport pyvrs\nimport torch\nimport tqdm\nfrom efm3d.aria import CameraTW, PoseTW\nfrom efm3d.aria.aria_constants import ARIA_CAM_INFO, ARIA_OBB_BB2, ARIA_OBB_BB3\nfrom efm3d.utils.rescale import rescale_camera_tw, rescale_image\nfrom pyquaternion import Quaternion\nfrom pyvrs import SyncVRSReader\nfrom vrsbindings import ImageConversion, RecordType\n\n\ndef load_gt_calibration(\n    calib_path: Union[str, dict], load_torch=False, timestamps=None\n):\n    \"\"\"load ground truth calibration json from simulation\"\"\"\n\n    if isinstance(calib_path, str):\n        with fsspec.open(calib_path, \"r\") as f:\n            calib = json.load(f)\n    elif isinstance(calib_path, dict):\n        calib = calib_path\n    else:\n        raise IOError(\"calib_path must be str or dict\")\n    gt_calib = {}\n    gt_calib[\"T_rig_views\"] = {}\n    gt_calib[\"intr_type\"] = {}\n    gt_calib[\"intr_params\"] = {}\n\n    cam_names = ARIA_CAM_INFO[\"name\"]\n    # Maps names from the gt_calib.json file to the ARIA_CAM_INFO convention.\n    name_map = {\n        \"camera-rgb\": cam_names[0],\n        
\"camera-slam-left\": cam_names[1],\n        \"camera-slam-right\": cam_names[2],\n    }\n    for camera in calib[\"CameraCalibrations\"]:\n        cn = camera[\"Label\"]\n        if cn not in name_map:  # Ignore other cameras like eye tracking.\n            continue\n        cam_name = name_map[cn]\n        [tx, ty, tz] = camera[\"T_Device_Camera\"][\"Translation\"]\n        [qw, [qx, qy, qz]] = camera[\"T_Device_Camera\"][\"UnitQuaternion\"]\n\n        rot_mat = Quaternion(qw, qx, qy, qz).rotation_matrix\n        translation = torch.tensor([tx, ty, tz]).view(3, 1)\n        T_rig_view = torch.concat([torch.tensor(rot_mat), translation], dim=1)\n        T_rig_view = PoseTW.from_matrix3x4(T_rig_view)\n        T_rig_view = T_rig_view.fit_to_SO3()\n        if not load_torch:\n            T_rig_view = T_rig_view.numpy()\n        gt_calib[\"T_rig_views\"][cam_name] = T_rig_view\n\n        intr_type = camera[\"Projection\"][\"Name\"]\n        # This is the case for Fisheye62 which has 6+2+3=11 parameters, morphed as Fisheye624\n        # Add zeros to make it 15 params (same as Fisheye624)\n        if intr_type == \"Fisheye624\":\n            N = 15 - len(camera[\"Projection\"][\"Params\"])\n            if N > 0:\n                for _i in range(N):\n                    camera[\"Projection\"][\"Params\"].append(0)\n        intr_params = np.array(camera[\"Projection\"][\"Params\"])\n        if load_torch:\n            intr_params = torch.from_numpy(intr_params)\n        gt_calib[\"intr_type\"][cam_name] = intr_type\n        gt_calib[\"intr_params\"][cam_name] = intr_params\n\n    if timestamps is not None:\n        time2calib = {}\n        for timestamp in timestamps:\n            time2calib[timestamp] = gt_calib\n        return time2calib\n\n    return gt_calib\n\n\ndef get_image_info(image_reader: SyncVRSReader) -> Tuple[Dict, Dict]:\n    \"\"\"\n    Get image info such as sizes and frame rate. 
These fields are not\n    part of calibration so we have to query them through VRSReader.\n    \"\"\"\n    image_sizes = {}\n    fps = {}\n    image_config_reader = image_reader.filtered_by_fields(\n        record_types=[\"configuration\"]\n    )\n    for image_config in image_config_reader:\n        assert image_config.record_type == \"configuration\"\n        stream_id = image_config.stream_id\n        if stream_id not in ARIA_CAM_INFO[\"id_to_name\"]:\n            continue\n        name = ARIA_CAM_INFO[\"id_to_name\"][stream_id]\n        metadata = image_config.metadata_blocks[0]\n        image_sizes[name] = metadata[\"image_height\"], metadata[\"image_width\"]\n        fps[name] = metadata[\"nominal_rate\"]\n    return image_sizes, fps\n\n\ndef load_factory_calib(\n    reader: SyncVRSReader,\n    calib: Optional[str] = None,\n    map_radius_to_cam_height: bool = False,\n):\n    \"\"\"\n    Augment `load_gt_calibration` by adding `image_sizes`, `camera_tw`\n    (CameraTW objects), and `fps` for each camera. The reader has to be\n    an image VRSReader.\n    video_stream_name is needed for eye tracking images. 
Unlike slaml and slamr where their vrs ids are 1201-1 and 1201-2,\n    eye tracking vrs id is only 211-1 for both left and right eye images\n    \"\"\"\n    image_sizes, fps = get_image_info(reader)\n    if \"calib_json\" in reader.file_tags:\n        calib = json.loads(reader.file_tags[\"calib_json\"])\n    elif calib is None:\n        return None\n    cam_calib = load_gt_calibration(calib, load_torch=True, timestamps=None)\n    cam_calib[\"image_sizes\"] = image_sizes\n    cam_calib[\"fps\"] = fps\n    cam_calib[\"camera_tw\"] = {}\n\n    # Hack to override the camera model instead of using cam_calib[\"intr_type\"][cam_name] which is set to \"Fisheye62\"\n    for cam_name in image_sizes:\n        if map_radius_to_cam_height:\n            cam_calib[\"camera_tw\"][cam_name] = CameraTW.from_surreal(\n                height=image_sizes[cam_name][0],\n                width=image_sizes[cam_name][1],\n                type_str=cam_calib[\"intr_type\"][cam_name],\n                params=cam_calib[\"intr_params\"][cam_name],\n                T_camera_rig=cam_calib[\"T_rig_views\"][cam_name].inverse(),\n                valid_radius=image_sizes[cam_name][0],\n            )\n        else:\n            cam_calib[\"camera_tw\"][cam_name] = CameraTW.from_surreal(\n                height=image_sizes[cam_name][0],\n                width=image_sizes[cam_name][1],\n                type_str=cam_calib[\"intr_type\"][cam_name],\n                params=cam_calib[\"intr_params\"][cam_name],\n                T_camera_rig=cam_calib[\"T_rig_views\"][cam_name].inverse(),\n            )\n    return cam_calib\n\n\ndef load_2d_bounding_boxes(bb2d_path, time_in_secs=False):\n    bb2ds = {}\n\n    try:\n        with fsspec.open(bb2d_path).open() as f:\n            # genfromtxt handles missing values and lets us specify dtypes.\n            # #Object_UID, timestamp [nanoseconds], x_min [pixel], x_max [pixel], y_min [pixel], y_max [pixel]\n            lines = np.genfromtxt(\n                f,\n    
            dtype=[int] * 2 + [float] * 4,\n                names=True,\n                delimiter=\",\",\n                usecols=range(6),\n            )\n    except Exception:\n        try:\n            # sometimes the last row is bad for some reason so we just skip it\n            with fsspec.open(bb2d_path).open() as f:\n                # genfromtxt handles missing values and lets us specify dtypes.\n                # #Object_UID, timestamp [nanoseconds], x_min [pixel], x_max [pixel], y_min [pixel], y_max [pixel]\n                lines = np.genfromtxt(\n                    f,\n                    dtype=[int] * 2 + [float] * 4,\n                    names=True,\n                    delimiter=\",\",\n                    usecols=range(6),\n                    skip_footer=1,\n                )\n        except Exception as e:\n            print(f\"could not load {bb2d_path}; error {e}\")\n            return bb2ds\n\n    count = 0\n    for line in lines:\n        object_id = line[0]\n        timestamp_ns = line[1]\n        if time_in_secs:\n            timestamp = timestamp_ns / 1e9\n        else:\n            timestamp = timestamp_ns\n        x_min = max(0, line[2])\n        x_max = max(0, line[3])\n        y_min = max(0, line[4])\n        y_max = max(0, line[5])\n        # invalid entries will have nan as fill value; we skip them.\n        if any(x != x for x in [x_min, x_max, y_min, y_max]):\n            continue\n        if timestamp not in bb2ds:\n            bb2ds[timestamp] = [(object_id, x_min, x_max, y_min, y_max)]\n        else:\n            bb2ds[timestamp].append((object_id, x_min, x_max, y_min, y_max))\n        count += 1\n    print(f\"loaded {count} 2d bbs for {len(bb2ds)} timestamps from {bb2d_path}\")\n    return bb2ds\n\n\ndef load_2d_bounding_boxes_adt(bb2d_path):\n    bb2ds_rgb = {}\n    bb2ds_slaml = {}\n    bb2ds_slamr = {}\n\n    with fsspec.open(bb2d_path).open() as f:\n        lines = f.readlines()\n\n    # expected header:\n    # 
stream_id,object_uid,timestamp[ns],x_min[pixel],x_max[pixel],y_min[pixel],y_max[pixel],visibility_ratio[%]\\n'\n\n    count = 0\n    for ii, line in enumerate(lines):\n        if ii == 0:\n            continue  # skip header\n        line = line.decode(\"utf-8\").rstrip().split(\",\")\n        device_id = str(line[0])\n        object_id = int(line[1])\n        timestamp = int(line[2])  # ns\n        x_min = max(0, float(line[3]))\n        x_max = max(0, float(line[4]))\n        y_min = max(0, float(line[5]))\n        y_max = max(0, float(line[6]))\n        # invalid entries will have nan as fill value; we skip them.\n        if any(x != x for x in [x_min, x_max, y_min, y_max]):\n            continue\n\n        if device_id == \"214-1\":\n            if timestamp not in bb2ds_rgb:\n                bb2ds_rgb[timestamp] = [(object_id, x_min, x_max, y_min, y_max)]\n            else:\n                bb2ds_rgb[timestamp].append((object_id, x_min, x_max, y_min, y_max))\n\n        elif device_id == \"1201-1\":\n            if timestamp not in bb2ds_slaml:\n                bb2ds_slaml[timestamp] = [(object_id, x_min, x_max, y_min, y_max)]\n            else:\n                bb2ds_slaml[timestamp].append((object_id, x_min, x_max, y_min, y_max))\n\n        elif device_id == \"1201-2\":\n            if timestamp not in bb2ds_slamr:\n                bb2ds_slamr[timestamp] = [(object_id, x_min, x_max, y_min, y_max)]\n            else:\n                bb2ds_slamr[timestamp].append((object_id, x_min, x_max, y_min, y_max))\n        else:\n            raise IOError(f\"unexpected device id {device_id} in 2d observations\")\n\n        count += 1\n    print(\n        f\"loaded {count} 2d bbs for {len(bb2ds_rgb)}[rgb] {len(bb2ds_slaml)}[slaml] {len(bb2ds_slamr)}[slamr] timestamps from {bb2d_path}\"\n    )\n    return bb2ds_rgb, bb2ds_slaml, bb2ds_slamr\n\n\ndef remove_invalid_2d_bbs(timed_bb2s, filter_bb2_area=-1):\n    \"\"\"\n    remove bbs with x, y <= 0. 
In some datasets (DlrSim) these 2d bbs indicate\n    object is not visible!\n    \"\"\"\n    bb2s_filtered = defaultdict(list)\n    for time, bb2s in timed_bb2s.items():\n        for bb2 in bb2s:\n            if not ((bb2[1] <= 0 and bb2[2] <= 0) or (bb2[3] <= 0 and bb2[4] <= 0)):\n                if filter_bb2_area > 0:\n                    bb2_area = (bb2[2] - bb2[1]) * (bb2[4] - bb2[3])\n                    if bb2_area >= filter_bb2_area:\n                        bb2s_filtered[time].append(bb2)\n                else:\n                    bb2s_filtered[time].append(bb2)\n    return bb2s_filtered\n\n\ndef load_instances(instances_path):\n    instance2proto = {}\n\n    assert os.path.exists(instances_path), (\n        f\"instances path {instances_path} does not exist\"\n    )\n    with open(instances_path, \"r\") as f:\n        lines = f.readlines()\n\n    for line in lines[1:]:  # skip first line\n        line = line.rstrip().split(\",\")\n        instance_uid = int(line[0])\n        prototype_uid = str(line[1]).strip()\n        instance2proto[instance_uid] = prototype_uid\n\n    return instance2proto\n\n\ndef load_instances_adt(instances_path):\n    instance2proto = {}\n    with fsspec.open(instances_path).open() as f:\n        content = json.load(f)\n    for inst_id in content:\n        instance2proto[int(inst_id)] = content[inst_id][\"category\"]\n\n    # lot of other info available, for example:\n    #  {'instance_id': 5691266090916432, 'instance_name': 'Hook_4',\n    #  'prototype_name': 'Hook', 'category': 'hook', 'category_uid': 643,\n    #  'motion_type': 'static', 'instance_type': 'object', 'rigidity': 'rigid',\n    #  'rotational_symmetry': {'is_annotated': False},\n    #  'canonical_pose': {'up_vector': [0, 1, 0], 'front_vector': [0, 0, 1]}}\n\n    return instance2proto\n\n\ndef load_3d_bounding_box_transforms(scene_path, time_in_secs=False, load_torch=False):\n    T_world_object = {}\n\n    with fsspec.open(scene_path).open() as f:\n        lines = 
np.genfromtxt(\n            f,\n            dtype=[int] * 2 + [float] * 7,\n            names=True,\n            delimiter=\",\",\n            usecols=range(9),\n        )\n        if lines.size == 1:\n            lines = lines[np.newaxis]\n\n    for line in lines:\n        object_id = line[0]\n        timestamp_ns = line[1]\n        if time_in_secs and timestamp_ns != -1:\n            timestamp = timestamp_ns / 1e9\n        else:\n            timestamp = timestamp_ns\n        tx = line[2]\n        ty = line[3]\n        tz = line[4]\n        qw = line[5]\n        qx = line[6]\n        qy = line[7]\n        qz = line[8]\n        # invalid entries will have nan as fill value; we skip them.\n        if any(x != x for x in [tx, ty, tz, qw, qx, qy, qz]):\n            continue\n\n        rot_mat = Quaternion(w=qw, x=qx, y=qy, z=qz).rotation_matrix\n        translation = torch.tensor([tx, ty, tz]).view(3, 1)\n        T_wo = torch.concat([torch.tensor(rot_mat), translation], dim=1)\n        T_wo = PoseTW.from_matrix3x4(T_wo)\n        T_wo = T_wo.fit_to_SO3()\n        if not load_torch:\n            T_wo = T_wo.numpy()\n\n        if timestamp not in T_world_object:\n            T_world_object[timestamp] = {}\n        T_world_object[timestamp][object_id] = T_wo\n    return T_world_object\n\n\ndef load_3d_bounding_box_local_extents(bb3d_path, load_torch=False):\n    bb3ds_local = {}\n    with fsspec.open(bb3d_path).open() as f:\n        # Object UID, Timestamp ( ns ), p_local_obj.xmin, p_local_obj.xmax, p_local_obj.ymin, p_local_obj.ymax, p_local_obj.zmin, p_local_obj.zmax\n        lines = np.genfromtxt(\n            f,\n            dtype=[int] * 2 + [float] * 6,\n            names=True,\n            delimiter=\",\",\n            usecols=range(8),\n        )\n        if lines.size == 1:\n            lines = lines[np.newaxis]\n    for line in lines:\n        object_id = line[0]\n        xmin = line[2]\n        xmax = line[3]\n        ymin = line[4]\n        ymax = line[5]\n    
    zmin = line[6]\n        zmax = line[7]\n        # invalid entries will have nan as fill value; we skip them.\n        if any(x != x for x in [xmin, xmax, ymin, ymax, zmin, zmax]):\n            continue\n        local = np.array([xmin, xmax, ymin, ymax, zmin, zmax])\n        if load_torch:\n            local = torch.from_numpy(local)\n        bb3ds_local[object_id] = local\n    return bb3ds_local\n\n\ndef load_obbs_gt(\n    input_dir,\n    load_2d_bbs=True,\n    filter_outside_2d_bbs: bool = False,\n    rgb_only=False,\n    filter_bb2_area=-1,\n):\n    obs = {}\n    if load_2d_bbs:\n        # Load 2d bbs from CSV.\n        bb2s_path_rgb = exists_nonzero_path(\n            [\n                os.path.join(input_dir, \"2d_bounding_box.csv\"),\n                os.path.join(input_dir, \"2d_bounding_box_rgb.csv\"),\n                os.path.join(input_dir, \"sensor_0_2d_bounding_box.csv\"),\n            ]\n        )\n        bb2s_path_slaml, bb2s_path_slamr = False, False\n        if not rgb_only:\n            bb2s_path_slaml = exists_nonzero_path(\n                [\n                    os.path.join(input_dir, \"2d_bounding_box_2.csv\"),\n                    os.path.join(input_dir, \"2d_bounding_box_left_slam.csv\"),\n                    os.path.join(input_dir, \"sensor_1_2d_bounding_box.csv\"),\n                ]\n            )\n            bb2s_path_slamr = exists_nonzero_path(\n                [\n                    os.path.join(input_dir, \"2d_bounding_box_3.csv\"),\n                    os.path.join(input_dir, \"2d_bounding_box_right_slam.csv\"),\n                    os.path.join(input_dir, \"sensor_2_2d_bounding_box.csv\"),\n                ]\n            )\n\n        bb2_loaded = False\n        if bb2s_path_rgb:\n            # ADT dataset packs all three bb2 observations into one file\n            with fsspec.open(bb2s_path_rgb).open() as f:\n                header = f.readline()\n                header = str(header).split(\",\")\n                if len(header) 
== 8:\n                    (\n                        obs[ARIA_OBB_BB2[0]],\n                        obs[ARIA_OBB_BB2[1]],\n                        obs[ARIA_OBB_BB2[2]],\n                    ) = load_2d_bounding_boxes_adt(bb2s_path_rgb)\n                    bb2_loaded = True\n\n        if not bb2_loaded and bb2s_path_rgb and bb2s_path_slaml and bb2s_path_slamr:\n            # Load 2d bounding boxes separately for three cameras\n            obs[ARIA_OBB_BB2[0]] = load_2d_bounding_boxes(\n                bb2s_path_rgb, time_in_secs=False\n            )\n            obs[ARIA_OBB_BB2[1]] = load_2d_bounding_boxes(\n                bb2s_path_slaml, time_in_secs=False\n            )\n            obs[ARIA_OBB_BB2[2]] = load_2d_bounding_boxes(\n                bb2s_path_slamr, time_in_secs=False\n            )\n            bb2_loaded = True\n        elif not bb2_loaded and bb2s_path_rgb:\n            # sometimes we only have RGB 2d bounding boxes.\n            obs[ARIA_OBB_BB2[0]] = load_2d_bounding_boxes(\n                bb2s_path_rgb, time_in_secs=False\n            )\n            obs[ARIA_OBB_BB2[1]] = {}\n            obs[ARIA_OBB_BB2[2]] = {}\n            bb2_loaded = True\n        elif not bb2_loaded:\n            print(\"Warning: could not find 2d bbs\")\n            return {}\n    else:\n        obs[ARIA_OBB_BB2[0]] = {}\n        obs[ARIA_OBB_BB2[1]] = {}\n        obs[ARIA_OBB_BB2[2]] = {}\n        print(\"not loading 2d bb information\")\n\n    # most of the time bbs with x, y <= 0 indicate object is visible but we dont\n    # know where. 
In the DlrSim dataset it indicates object not observed!\n    if filter_outside_2d_bbs:\n        for bb2_key in ARIA_OBB_BB2:\n            obs[bb2_key] = remove_invalid_2d_bbs(obs[bb2_key], filter_bb2_area)\n\n    # Load bounding box local 3D extents.\n    bb3d_path = exists_nonzero_path(\n        [\n            os.path.join(input_dir, \"scene/3d_bounding_box.csv\"),\n            os.path.join(input_dir, \"3d_bounding_box.csv\"),\n        ]\n    )\n    if bb3d_path:\n        obs[ARIA_OBB_BB3] = load_3d_bounding_box_local_extents(bb3d_path)\n\n    # Load scene object centers + object_ids from scene_objects.csv\n    scene_path = exists_nonzero_path(\n        [\n            os.path.join(input_dir, \"scene/scene_objects.csv\"),\n            os.path.join(input_dir, \"scene_objects.csv\"),\n        ]\n    )\n    if scene_path:\n        obs[\"timedTs_world_object\"] = load_3d_bounding_box_transforms(\n            scene_path, time_in_secs=False, load_torch=True\n        )\n    # Load label mapping from instances to prototypes.\n    instance_path = exists_nonzero_path(\n        [\n            # fixed some wrong 'rug' labels\n            os.path.join(input_dir, \"scene/instances_fix.csv\"),\n            os.path.join(input_dir, \"scene/instances.csv\"),\n            os.path.join(input_dir, \"instances.json\"),\n        ]\n    )\n    if instance_path:\n        if instance_path.endswith(\".csv\"):\n            obs[\"inst2proto\"] = load_instances(instance_path)\n        elif instance_path.endswith(\".json\"):\n            obs[\"inst2proto\"] = load_instances_adt(instance_path)\n        else:\n            raise IOError(\"Unknown instances extension\")\n\n    return obs\n\n\ndef load_trajectory_adt(\n    traj_path,\n    subsample: Union[float, int] = 1,\n    load_first_n=99999999999,\n):\n    print(\"checking \" + traj_path)\n    fs = fsspec.get_mapper(traj_path).fs\n    if not fs.exists(traj_path):\n        return None\n    if not fs.isfile(traj_path):\n        traj_path = 
exists_nonzero_path(\n            [\n                os.path.join(traj_path, \"aria_trajectory.csv\"),  # ADT ground truth\n            ]\n        )\n    if traj_path is None:\n        return None\n    print(\"loading \" + traj_path)\n\n    T_world_rigs = {}\n    # check for number of columns first\n    with fsspec.open(traj_path, \"r\").open() as f:\n        header = f.readline()\n        num_cols = len(header.split(\",\"))\n        if num_cols not in [20]:\n            return None\n\n    # load data without header\n    with fsspec.open(traj_path, \"rb\").open() as f:\n        lines = f.readlines()\n\n    N = min(len(lines), load_first_n)\n    idxs = sample_from_range(0, N, subsample)\n    for ii in idxs:\n        if ii == 0:\n            continue  # skip header\n        line = lines[ii]\n        line = str(line).split(\",\")\n        timestamp_us = int(line[1])\n        timestamp_ns = timestamp_us * 1000\n        timestamp = timestamp_ns\n        sub_line = line[3:10]\n        tx, ty, tz, qx, qy, qz, qw = [float(e) for e in sub_line]\n        rot_mat = Quaternion(qw, qx, qy, qz).rotation_matrix\n        translation = torch.tensor([tx, ty, tz]).view(3, 1)\n        T_world_rig = torch.concat([torch.tensor(rot_mat), translation], dim=1)\n        T_world_rig = PoseTW.from_matrix3x4(T_world_rig)\n        T_world_rig = T_world_rig.fit_to_SO3()\n        T_world_rigs[timestamp] = T_world_rig\n\n    return T_world_rigs\n\n\ndef load_trajectory_aeo(\n    csv_path: str,\n    load_torch: bool = False,\n    subsample: Union[float, int] = 1,\n    time_in_secs: bool = False,\n    load_first_n: int = 99999999999,\n):\n    assert not time_in_secs, \"Only support time in ns for now\"\n    vio_filenames = [\n        \"closed_loop_framerate_trajectory.csv\",\n        \"closed_loop_trajectory.csv\",\n        \"mps/slam/closed_loop_trajectory.csv\",\n    ]\n    lines = None\n    for vio_filename in vio_filenames:\n        traj_csv_path = os.path.join(csv_path, vio_filename)\n        
print(\"checking \" + traj_csv_path)\n        if os.path.exists(traj_csv_path):\n            with open(traj_csv_path, \"r\") as f:\n                lines = f.readlines()\n            print(f\"loaded {len(lines)} from \" + traj_csv_path)\n            break\n\n    if lines is None:\n        print(f\"No file found in {csv_path}.\")\n        return None\n\n    T_world_rigs = {}\n    header = lines[0].strip().split(\",\")\n    if len(header) not in {20, 26, 28, 29}:\n        print(\n            f\"Invalid header, expected 20, 26, 28 or 29 columns, but got {len(header)}\"\n        )\n        print(header)\n        return None\n\n    start_index = 0\n    if len(header) in {20, 28}:  # no recording_source field in this version\n        start_index = -1\n\n    N = min(len(lines), load_first_n)\n    idxs = sample_from_range(1, N, subsample)\n    for ii in idxs:\n        line = lines[ii]\n        # Handle data error\n        line = line.strip()\n        if len(line) == 0:\n            continue\n\n        cols = line.split(\",\")\n        timestamp_ns = int(cols[start_index + 2]) * 1000\n        tx, ty, tz, qx, qy, qz, qw = [\n            float(num) for num in cols[start_index + 4 : start_index + 11]\n        ]\n        rot_mat = Quaternion(w=qw, x=qx, y=qy, z=qz).rotation_matrix\n        translation = torch.tensor([tx, ty, tz]).view(3, 1)\n        T_world_rig = torch.concat([torch.tensor(rot_mat), translation], dim=1)\n        T_world_rig = PoseTW.from_matrix3x4(T_world_rig)\n        # T_world_rig = T_world_rig.fit_to_SO3()\n        if not load_torch:\n            T_world_rig = T_world_rig.numpy()\n\n        T_world_rigs[timestamp_ns] = T_world_rig\n\n    return T_world_rigs\n\n\ndef load_trajectory(\n    traj_path,\n    time_in_secs=False,\n    load_torch=False,\n    subsample: Union[float, int] = 1,\n    load_quaternion=False,\n    load_first_n=99999999999,\n):\n    print(\"checking \" + traj_path)\n    fs = fsspec.get_mapper(traj_path).fs\n    if not 
fs.exists(traj_path):\n        return None\n    if not fs.isfile(traj_path):\n        traj_path = exists_nonzero_path(\n            [\n                os.path.join(traj_path, \"trajectory.csv\"),  # default\n                os.path.join(traj_path, \"traj000.csv\"),  # ASE\n            ]\n        )\n    if traj_path is None:\n        return None\n    print(\"loading \" + traj_path)\n\n    T_world_rigs = {}\n    # check for number of columns first\n    with fsspec.open(traj_path, \"r\").open() as f:\n        header = f.readline()\n        num_cols = len(header.split(\",\"))\n        if num_cols not in [8, 14, 17]:\n            return None\n    # load data without header\n    with fsspec.open(traj_path, \"rb\").open() as f:\n        lines = np.loadtxt(f, delimiter=\",\", skiprows=1)\n\n    N = min(len(lines), load_first_n)\n    idxs = sample_from_range(0, N, subsample)\n    for ii in idxs:\n        line = lines[ii]\n        timestamp_ns = int(line[0])\n        if time_in_secs:\n            timestamp = timestamp_ns / 1e9\n        else:\n            timestamp = timestamp_ns\n        sub_line = line[1:8]\n        tx, ty, tz, qw, qx, qy, qz = [float(e) for e in sub_line]\n        rot_mat = Quaternion(w=qw, x=qx, y=qy, z=qz).rotation_matrix\n\n        if load_quaternion:\n            # allow \"raw\" loading\n            T_world_rig = np.array([tx, ty, tz, qw, qx, qy, qz])\n        else:\n            translation = torch.tensor([tx, ty, tz]).view(3, 1)\n            T_world_rig = torch.concat([torch.tensor(rot_mat), translation], dim=1)\n            T_world_rig = PoseTW.from_matrix3x4(T_world_rig)\n            T_world_rig = T_world_rig.fit_to_SO3()\n            if not load_torch:\n                T_world_rig = T_world_rig.numpy()\n        T_world_rigs[timestamp] = T_world_rig\n\n    return T_world_rigs\n\n\ndef parse_global_name_to_id_csv(csv_path: str, verbose: bool = True) -> Dict[str, int]:\n    \"\"\"\n    Loads a csv with 2 columns: old_sem_name, sem_id and returns it 
as\n    a dictionary of {old_sem_name: sem_id}\n    \"\"\"\n    global_name_to_id = None\n    if len(csv_path) > 0:\n        if verbose:\n            print(f\"trying to load taxonomy from csv at {csv_path}\")\n        with fsspec.open(csv_path) as f:\n            global_name_to_id = dict(\n                np.loadtxt(\n                    f,\n                    delimiter=\",\",\n                    skiprows=1,\n                    dtype={\n                        \"names\": (\"Object Name\", \"Object cls ID\"),\n                        \"formats\": (\"U30\", int),\n                    },\n                    ndmin=1,\n                )\n            )\n        if verbose:\n            print(f\"loaded {len(global_name_to_id)} name-to-id mappings from csv.\")\n    return global_name_to_id\n\n\ndef exists_nonzero_path(path: Union[str, list]) -> Optional[str]:\n    \"\"\"Helper function: iterate through paths and return the first one that exists.\n\n    Input:\n        path - can be str or list of str\n    Returns:\n        found - None if no path exists, otherwise the first existing path\n    \"\"\"\n    if isinstance(path, str):\n        paths = [path]\n    else:\n        paths = path\n    # Iterate through each path, breaking if good file is found.\n    found = None\n    for path in paths:\n        try:\n            fs = fsspec.core.url_to_fs(path)[0]\n        except Exception as e:\n            print(f\"skipping {path}: {e}\")\n            continue\n        if fs.exists(path):\n            found = path\n            break\n    return found\n\n\ndef get_timestamp_list_ns(reader, stream_id=None):\n    if stream_id is None:\n        filtered_reader = reader\n    else:\n        filtered_reader = reader.filtered_by_fields(\n            stream_ids=[stream_id], record_types=[\"data\"]\n        )\n    time_list = filtered_reader.get_timestamp_list()\n    # go from vrs output of float times in seconds to long times in nanoseconds\n    timestamp_list = [int(t * 1e9) for t in time_list]\n    
return timestamp_list\n\n\ndef sample_times(time_list: List, start_time: int, end_time: int) -> Tuple[int, int]:\n    \"\"\"\n    Sample timestamps within the interval [start_time, end_time) using binary\n    search, making sure that at least one sample is taken.\n\n    Inputs:\n        time_list: list of sorted times\n        start_time: start of range to sample\n        end_time: end of range to sample\n    Returns:\n        (idx_i,idx_j): tuple of indices into time_list of sampled range.\n\n    Suppose the anchor modality is IMU. Using `bisect_left` would give\n    a time order like this:\n    image1, image2, image3, image4,   image5\n       |      |       |       |          |\n            img_i    ...   img_(j-1)   img_j\n          |                       |\n         imu1                    imu2\n          |                       |\n         audio1                  audio2\n    \"\"\"\n    idx_i = bisect_left(time_list, start_time)\n    idx_j = bisect_left(time_list, end_time)\n    # Make sure sampled image data is in between the start and end time\n    if idx_j > idx_i:\n        assert (\n            start_time <= time_list[idx_i]\n            and time_list[idx_i] <= time_list[idx_j - 1]\n            and time_list[idx_j - 1] < end_time\n        ), (\n            f\"start {start_time} end {end_time}, time_list[idx_i], {time_list[idx_i]}, time_list[idx_j - 1], {time_list[idx_j - 1]}\"\n        )\n    else:\n        # make sure idx_j is greater than idx_i\n        idx_j = max(idx_j, idx_i + 1)\n    return idx_i, idx_j\n\n\ndef sample_from_range(\n    start: int,\n    end: int,\n    sample_rate: Union[float, int],\n    add_random: bool = True,\n) -> List[int]:\n    \"\"\"\n    Sample from a range using the given sample_rate.\n    Args:\n        start (int): start of the range (inclusive).\n        end (int): end of the range (exclusive).\n        sample_rate (Union[float, int]): target sampling rate.\n        add_random (bool): whether to add randomness to 
the final samples.\n    Returns:\n        list: a list of integers sampled from the range, in increasing order.\n    Example:\n        1. sample_rate is integer. We just return range(start, end, sample_rate).\n        2. sample_rate is float.\n            We first round up the sample_rate and then add the missing numbers from the remainder of the entire list.\n            For example, we'd like to sample from [0, 1, 2, ..., 9] with sample rate 1.25, which will result in 8 samples.\n            We first round the sample rate to 2, then get the list of numbers by range(0, 10, 2) which is [0, 2, 4, 6, 8].\n            And then we randomly get 3 numbers from the remainder of the list [1, 3, 5, 7, 9] to add to the final sample list.\n    \"\"\"\n    assert end >= start, \"the end of the range must not be less than the start.\"\n    assert sample_rate > 0, \"sample rate must be positive.\"\n\n    if end == start:\n        print(f\"[Warn] end equals start ({start}, {end}), returning empty list\")\n        return []\n\n    # if sample rate is an integer, we just return the sampling by using sample_rate as the step size.\n    if type(sample_rate) is int or sample_rate.is_integer():\n        return list(range(start, end, int(sample_rate)))\n\n    # Otherwise, we do sampling with non-integer sampling rate.\n    if (end - start) % sample_rate != 0:\n        print(\n            f\"[WARN] range length is not divisible by sample_rate: got sample_rate {sample_rate}, start {start}, end {end}. 
Can not achieve the desired sampling rate in the end.\"\n        )\n\n    step = int(np.ceil(sample_rate))  # round-up the sampling rate\n    num = int((end - start) / sample_rate)  # number of final samples\n\n    # Generate the evenly spaced integers\n    integers = list(range(start, end, step))\n\n    # If we don't have enough integers, sample the missing ones randomly\n    if len(integers) < num:\n        missing_num = num - len(integers)\n        # Create a list of potential candidates that excludes already selected integers\n        candidates = [i for i in range(start, end) if i not in integers]\n        # Add the missing integers\n        if add_random:\n            integers.extend(random.sample(candidates, missing_num))\n        else:\n            integers = list(\n                np.linspace(start, end, num, endpoint=False).round().astype(int)\n            )\n\n    return sorted(integers)\n\n\ndef read_image_from_vrs(\n    reader: pyvrs.filter.FilteredVRSReader,\n    cam_id: str,\n    image_ts_ns: int,\n    intr_type: str,\n    intr_params: Union[List, np.array],\n    T_rig_camera: PoseTW,\n    scale_down_images: int = 0,\n    valid_radius: Optional[torch.Tensor] = None,\n    wh_multiple_of: int = 16,\n):\n    \"\"\"\n    Expect all the input time is in vrs capture time domain.\n    \"\"\"\n    cam_name = ARIA_CAM_INFO[\"id_to_name\"][cam_id]\n\n    # Read image from time-associated VRS block.\n    ret_error = (None, None, None)\n    try:\n        # convert from nanoseconds to seconds for vrs reader\n        image_ts = image_ts_ns / 1e9\n        record = reader.read_record_by_time(\n            cam_id, image_ts, record_type=RecordType.DATA\n        )\n    except ValueError as e:\n        return ret_error\n    if record is None:\n        return ret_error\n    if len(record.image_blocks) < 1:  # Bad image block.\n        return ret_error\n    else:\n        image = record.image_blocks[0]\n    cam_hw_before = image.shape\n\n    exposure_s = 
record.metadata_blocks[0][\"exposure_duration_s\"]\n    gain = record.metadata_blocks[0][\"gain\"]\n    # note that currently capture_time_ns is equal to image_ts_ns but this might\n    # change (?) so we rely on this meta data instead and pass it back out.\n    capture_time_ns = record.metadata_blocks[0][\"capture_timestamp_ns\"]\n\n    cam = CameraTW.from_surreal(\n        height=cam_hw_before[0],\n        width=cam_hw_before[1],\n        type_str=intr_type,\n        params=intr_params,\n        T_camera_rig=T_rig_camera.inverse(),\n        exposure_s=exposure_s,\n        gain=gain,\n        valid_radius=valid_radius,\n    )\n\n    image = rescale_image(image, cam_name, scale_down_images, wh_multiple_of)\n    cam = rescale_camera_tw(\n        cam, cam_hw_before, cam_name, scale_down_images, wh_multiple_of\n    )\n\n    if image.ndim == 2:\n        image = np.expand_dims(image, axis=2)\n    image = image.transpose(2, 0, 1)  # HxWxC -> CxHxW\n    image = torch.tensor(image.astype(np.float32) / 255.0)\n\n    return image, cam, capture_time_ns\n\n\ndef read_image_snippet_from_vrs(\n    image_reader: SyncVRSReader,\n    cam_id: str,\n    start_time_ns: int,\n    end_time_ns: int,\n    cam_calib,\n    subsample: Union[float, int] = 1,\n    scale_down_images: int = 0,\n    valid_radius: Optional[torch.Tensor] = None,\n    wh_multiple_of: int = 16,\n):\n    \"\"\"\n    If time code mapping provided, assume the input time is the timecode time domain.\n    Need to convert it to capture time domain to read data.\n    Otherwise, the start_time_ns and end_time_ns need to be in the capture time domain.\n    Output time domain is always aligned with the input time domain.\n    \"\"\"\n    image_reader.set_image_conversion(conversion=ImageConversion.NORMALIZE)\n    filtered_reader = image_reader.filtered_by_fields(\n        stream_ids=[cam_id], record_types=[\"data\"]\n    )\n    capture_time_list_ns = get_timestamp_list_ns(filtered_reader)\n\n    img_i, img_j = 
sample_times(capture_time_list_ns, start_time_ns, end_time_ns)\n\n    images = []\n    times_ns = []\n    cam_tws = []\n    frame_ids = []\n    sample_range = sample_from_range(\n        img_i, img_j, sample_rate=subsample, add_random=False\n    )\n    for i in sample_range:\n        image, cam_tw, capture_image_time_ns = read_image_from_vrs(\n            reader=filtered_reader,\n            cam_id=cam_id,\n            image_ts_ns=capture_time_list_ns[i],\n            intr_type=cam_calib[\"intr_type\"],\n            intr_params=cam_calib[\"intr_params\"],\n            T_rig_camera=cam_calib[\"T_rig_views\"],\n            scale_down_images=scale_down_images,\n            valid_radius=valid_radius,\n            wh_multiple_of=wh_multiple_of,\n        )\n        if (\n            image is not None\n            and capture_image_time_ns is not None\n            and cam_tw is not None\n        ):\n            images.append(image)\n            times_ns.append(capture_image_time_ns)\n            cam_tws.append(cam_tw)\n            frame_ids.append(i)\n\n    images = torch.stack(images)\n    # Long to hold timestamp in ns to not lose accuracy\n    times_ns = torch.LongTensor(times_ns)\n    cam_tws = torch.stack(cam_tws)\n    frame_ids = torch.LongTensor(frame_ids)\n    return images, times_ns, cam_tws, frame_ids\n\n\ndef load_global_points_csv(\n    path: str,\n    max_inv_depth_std: float = 0.001,\n    min_observations: int = 5,\n):\n    print(f\"loading global points from {path}\")\n    uid_to_p3 = {}\n    uid_to_inv_dist_std = {}\n    uid_to_dist_std = {}\n    if path.split(\".\")[-1] == \"gz\" or \"maps/maps_v1\" in path:\n        compression = \"gzip\"\n    else:\n        compression = None\n\n    cache_path = path + \".pickle.gz\"\n    if not os.path.exists(cache_path):\n        with fsspec.open(path, \"rb\") as f:\n            csv = pd.read_csv(f, compression=compression)\n            # filter by inverse distance std\n            csv = csv[csv.inv_dist_std < 
max_inv_depth_std]\n            if \"num_observations\" in csv.columns:\n                csv = csv[csv.num_observations > min_observations]\n            print(csv.columns)\n            # select points and uids and return mapping\n            uid_pts = csv[\n                [\"uid\", \"inv_dist_std\", \"dist_std\", \"px_world\", \"py_world\", \"pz_world\"]\n            ]\n\n            for row in tqdm.tqdm(uid_pts.values):\n                uid = int(row[0])\n                inv_dist_std = float(row[1])\n                dist_std = float(row[2])\n                p3 = row[3:]\n                uid_to_p3[uid] = p3\n                uid_to_inv_dist_std[uid] = inv_dist_std\n                uid_to_dist_std[uid] = dist_std\n\n        try:\n            # cache points\n            with gzip.open(cache_path, \"wb\") as f:\n                pickle.dump(uid_to_p3, f, protocol=pickle.HIGHEST_PROTOCOL)\n                pickle.dump(uid_to_inv_dist_std, f, protocol=pickle.HIGHEST_PROTOCOL)\n                pickle.dump(uid_to_dist_std, f, protocol=pickle.HIGHEST_PROTOCOL)\n\n            print(f\"Cached global points to {cache_path}\")\n        except Exception as e:\n            # caching is best-effort; report the reason and continue\n            print(f\"Failed to cache the semidense points (likely a write permission issue): {e}\")\n    else:\n        # load from the cached file\n        with gzip.open(cache_path, \"rb\") as f:\n            uid_to_p3 = pickle.load(f)\n            uid_to_inv_dist_std = pickle.load(f)\n            uid_to_dist_std = pickle.load(f)\n        print(f\"Loaded global points from cached file {cache_path}\")\n\n    uid_to_p3 = {uid: torch.from_numpy(p3) for uid, p3 in uid_to_p3.items()}\n    return uid_to_p3, uid_to_inv_dist_std, uid_to_dist_std\n\n\ndef load_semidense_observations(path: str):\n    print(f\"loading semidense observations from {path}\")\n    time_to_uids = defaultdict(list)\n    uid_to_times = defaultdict(list)\n    if path.split(\".\")[-1] == \"gz\" or \"maps/maps_v1\" in path:\n        compression = \"gzip\"\n    else:\n        
compression = None\n\n    cache_path = path + \".pickle.gz\"\n    if not os.path.exists(cache_path):\n        with fsspec.open(path, \"rb\") as f:\n            csv = pd.read_csv(f, compression=compression)\n            csv = csv[[\"uid\", \"frame_tracking_timestamp_us\"]]\n            for row in tqdm.tqdm(csv.values):\n                uid = int(row[0])\n                time_ns = int(row[1]) * 1000\n                time_to_uids[time_ns].append(uid)\n                uid_to_times[uid].append(time_ns)\n\n        try:\n            with gzip.open(cache_path, \"wb\") as f:\n                pickle.dump(time_to_uids, f, protocol=pickle.HIGHEST_PROTOCOL)\n                pickle.dump(uid_to_times, f, protocol=pickle.HIGHEST_PROTOCOL)\n            print(f\"Cached semidense observations to {cache_path}\")\n        except Exception as e:\n            # caching is best-effort; report the reason and continue\n            print(\n                f\"Failed to cache the semidense observations (likely a write permission issue): {e}\"\n            )\n    else:\n        with gzip.open(cache_path, \"rb\") as f:\n            time_to_uids = pickle.load(f)\n            uid_to_times = pickle.load(f)\n        print(f\"Loaded semidense observations from cached file {cache_path}\")\n\n    return time_to_uids, uid_to_times\n"
  },
  {
    "path": "efm3d/utils/gravity.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom efm3d.aria.pose import PoseTW, rotation_from_euler\n\n\nGRAVITY_DIRECTION_DLR = np.array([0.0, -1.0, 0.0], np.float32)\nGRAVITY_DIRECTION_VIO = np.array([0.0, 0.0, -1.0], np.float32)\n\n\ndef get_transform_to_vio_gravity_convention(gravity_direction: np.array):\n    \"\"\"\n    Get transformation to map gravity_direction to (0,0,-1) as per our (and\n    VIO/Temple) convention.\n    \"\"\"\n    # gravity_direction = (d1, d2, d3) (0,0,-1)^T; d1, d2, d3 column vectors of rotation matrix R_gravity_vio\n    # -d3 = gravity_direction\n    d3 = -gravity_direction.copy()\n    # now construct an orthonormal basis for the rotation matrix\n    # d1 is a vector thats orthogonal to gravity_direction by construction\n    d1 = np.array(\n        [\n            gravity_direction[2] - gravity_direction[1],\n            gravity_direction[0],\n            -gravity_direction[0],\n        ]\n    )\n    # get d2 via orthogonal direction vector to d3 and d1\n    d2 = np.cross(d3, d1)\n    # get rotation matrix\n    R_gravity_vio = np.concatenate(\n        [d1[:, np.newaxis], d2[:, np.newaxis], d3[:, np.newaxis]], 1\n    )\n    assert (np.linalg.det(R_gravity_vio) - 1.0) < 1e-5\n    assert (((R_gravity_vio @ R_gravity_vio.transpose()) - np.eye(3)) < 1e-5).all()\n    R_gravity_vio = 
torch.from_numpy(R_gravity_vio)\n    # normalize to unit length\n    R_gravity_vio = F.normalize(R_gravity_vio, p=2, dim=-2)\n    R_vio_gravity = R_gravity_vio.transpose(1, 0)\n    T_vio_gravity = PoseTW.from_Rt(R_vio_gravity, torch.zeros(3))\n    return T_vio_gravity\n\n\ndef correct_adt_mesh_gravity(mesh):\n    \"\"\"\n    Change gravity direction of ADT mesh\n    \"\"\"\n    gravity_direction = np.array([0.0, -1.0, 0.0], np.float32)\n    T_vio_gravity = get_transform_to_vio_gravity_convention(gravity_direction).double()\n    print(\"Changing ADT gravity convention to VIO convention.\")\n    mesh.apply_transform(T_vio_gravity.matrix.numpy())\n    return mesh\n\n\ndef reject_vector_a_from_b(a, b):\n    # https://en.wikipedia.org/wiki/Vector_projection\n    b_norm = torch.sqrt((b**2).sum(-1, keepdim=True))\n    b_unit = b / b_norm\n    # batched dot product for variable dimensions\n    a_proj = b_unit * (a * b_unit).sum(-1, keepdim=True)\n    a_rej = a - a_proj\n    return a_rej\n\n\ndef gravity_align_T_world_cam(\n    T_world_cam, gravity_w=GRAVITY_DIRECTION_VIO, z_grav=False\n):\n    \"\"\"\n    get T_world_gravity from T_world_cam such that the x axis of T_world_gravity is gravity.\n    \"\"\"\n    assert T_world_cam.dim() > 1, f\"{T_world_cam} has wrong dimension; expected >1\"\n    dim = T_world_cam.dim()\n    device = T_world_cam.device\n    R_wc = T_world_cam.R\n    dir_shape = [1] * (dim - 1) + [3]\n    g_w = torch.from_numpy(gravity_w.copy()).view(dir_shape).to(R_wc)\n    g_w = g_w.expand_as(R_wc[..., 1])\n    # forward vector (z) that is orthogonal to gravity direction\n    d3 = reject_vector_a_from_b(a=R_wc[..., 2], b=g_w)\n    # optionally add a tiny offset to avoid cross product two identical vectors.\n    d3_is_zeros = (d3 == 0.0).all(dim=-1).unsqueeze(-1).expand_as(d3)\n    d3_offset = torch.zeros(*d3.shape).to(T_world_cam._data.device)\n    d3_offset[..., 1] += 0.001\n    d3 = torch.where(d3_is_zeros, d3 + d3_offset, d3)\n    d2 = 
torch.linalg.cross(d3, g_w, dim=-1)\n    # camera down vector is x direction since Aria cameras are rotated by 90 degrees CW\n    # hence the new x direction is gravity\n    R_wcg = torch.cat([g_w.unsqueeze(-1), d2.unsqueeze(-1), d3.unsqueeze(-1)], -1)\n    # normalize each column to unit length\n    R_world_cg = torch.nn.functional.normalize(R_wcg, p=2, dim=-2)\n    if z_grav:\n        # add extra rotation to make z the gravity direction, not x.\n        R_cg_cgz = rotation_from_euler(\n            torch.tensor([[-np.pi / 2.0, 0.0, np.pi / 2.0]])\n        ).to(device)\n        R_world_cgz = R_world_cg @ R_cg_cgz.inverse()\n        T_world_cgz = PoseTW.from_Rt(R_world_cgz, T_world_cam.t)\n        return T_world_cgz\n    else:\n        T_world_cg = PoseTW.from_Rt(R_world_cg, T_world_cam.t)\n        return T_world_cg\n"
  },
  {
    "path": "efm3d/utils/image.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Optional, Tuple, Union\n\nimport cv2\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport torch\n\n# Some globals for opencv drawing functions.\nBLU = (255, 0, 0)\nGRN = (0, 255, 0)\nRED = (0, 0, 255)\nWHT = (255, 255, 255)\nBLK = (0, 0, 0)\nFONT = cv2.FONT_HERSHEY_DUPLEX\nFONT_PT = (5, 15)\nFONT_SZ = 0.5\nFONT_TH = 1.0\n\n\ndef string2color(string):\n    string = string.lower()\n    if string == \"white\":\n        return WHT\n    elif string == \"green\":\n        return GRN\n    elif string == \"red\":\n        return RED\n    elif string == \"black\":\n        return BLK\n    elif string == \"blue\":\n        return BLU\n    else:\n        raise ValueError(\"input color string %s not supported\" % string)\n\n\ndef normalize(img, robust=0.0, eps=1e-6):\n    if isinstance(img, torch.Tensor):\n        vals = img.view(-1).cpu().numpy()\n    elif isinstance(img, np.ndarray):\n        vals = img.flatten()\n\n    if robust > 0.0:\n        v_min = np.quantile(vals, robust)\n        v_max = np.quantile(vals, 1.0 - robust)\n    else:\n        v_min = vals.min()\n        v_max = vals.max()\n    # make sure we are not dividing by 0\n    dv = max(eps, v_max - v_min)\n    # normalize to 0-1\n    img = (img - v_min) / dv\n    if isinstance(img, torch.Tensor):\n        img = img.clamp(0, 1)\n    elif isinstance(img, 
np.ndarray):\n        img = img.clip(0, 1)\n    return img\n\n\ndef put_text(\n    img: np.ndarray,\n    text: str,\n    scale: float = 1.0,\n    line: int = 0,\n    color: Tuple[Tuple, str] = WHT,\n    font_pt: Optional[Tuple[int, int]] = None,\n    truncate: int = None,\n):\n    \"\"\"Writes text with a shadow in the back at various lines and autoscales it.\n\n    Args:\n        image: image HxWx3 or BxHxWx3, should be uint8 for anti-aliasing to work\n        text: text to write\n        scale: 0.5 for small, 1.0 for normal, 1.5 for big font\n        line: vertical line to write on (0: first, 1: second, -1: last, etc)\n        color: text color, tuple of BGR integers between 0-255, e.g. (0,0,255) is red,\n               can also be a few strings like \"white\", \"black\", \"green\", etc\n        truncate: if not None, only show the first N characters\n    Returns:\n        image with text drawn on it\n\n    \"\"\"\n    if isinstance(img, list) or len(img.shape) == 4:  # B x H x W x 3\n        for i in range(len(img)):\n            img[i] = put_text(img[i], text, scale, line, color, font_pt, truncate)\n    else:  # H x W x 3\n        if truncate and len(text) > truncate:\n            text = text[:truncate] + \"...\"  # Add \"...\" to denote truncation.\n        height = img.shape[0]\n        scale = scale * (height / 320.0)\n        wht_th = max(int(FONT_TH * scale), 1)\n        blk_th = 2 * wht_th\n        text_ht = 15 * scale\n        if not font_pt:\n            font_pt = int(FONT_PT[0] * scale), int(FONT_PT[1] * scale)\n            font_pt = font_pt[0], int(font_pt[1] + line * text_ht)\n        if line < 0:\n            font_pt = font_pt[0], int(font_pt[1] + (height - text_ht * 0.5))\n        cv2.putText(img, text, font_pt, FONT, FONT_SZ * scale, BLK, blk_th, lineType=16)\n\n        if isinstance(color, str):\n            color = string2color(color)\n\n        cv2.putText(\n            img, text, font_pt, FONT, FONT_SZ * scale, color, wht_th, lineType=16\n     
   )\n    return img\n\n\ndef rotate_image90(image: np.ndarray, k: int = 3):\n    \"\"\"Rotates an image and then re-allocates memory to avoid problems with opencv\n    Input:\n        image: numpy image, HxW or HxWxC\n        k: number of times to rotate by 90 degrees counter clockwise\n    Returns\n        rotated image: numpy image, HxW or HxWxC\n    \"\"\"\n    return np.ascontiguousarray(np.rot90(image, k=k))\n\n\ndef smart_resize(\n    image: np.ndarray, height: int = -1, width: int = -1, pad_image: bool = False\n):\n    \"\"\"Resize with opencv, auto-inferring height or width to maintain aspect ratio.\"\"\"\n    if image.ndim == 4:\n        return np.stack([smart_resize(im, height, width, pad_image) for im in image])\n    assert image.ndim == 3, \"only three channel image currently supported\"\n    if width == -1 and height == -1:\n        return image\n    hh, ww = image.shape[0], image.shape[1]\n    if width == -1:\n        width = int(round((float(ww) / float(hh)) * height))\n        width = int(width / 2) * 2  # enforce divisible by 2\n    if height == -1:\n        height = int(round((float(hh) / float(ww)) * width))\n        height = int(height / 2) * 2  # enforce divisible by 2\n    if pad_image:\n        ar_orig = ww / hh\n        ar_new = width / height\n\n        if ar_new > ar_orig:  # pad the sides.\n            h_scale = height / hh\n            new_w = h_scale * ww\n            pad = (width - new_w) / 2\n            pad_before = int(pad / h_scale)\n            dtype = image.dtype\n            pad_img = np.zeros((hh, pad_before, 3), dtype=dtype)\n            image = np.hstack([pad_img, image, pad_img])\n        elif ar_new < ar_orig:  # pad the top and bottom\n            w_scale = width / ww\n            new_h = w_scale * hh\n            pad = (height - new_h) / 2\n            pad_before = int(pad / w_scale)\n            dtype = image.dtype\n            pad_img = np.zeros((pad_before, ww, 3), dtype=dtype)\n            image = np.vstack([pad_img, 
image, pad_img])\n\n    return cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)\n\n\ndef torch2cv2(\n    img: Union[np.ndarray, torch.Tensor],\n    rotate: bool = False,\n    rgb2bgr: bool = True,\n    ensure_rgb: bool = False,\n    apply_colormap: Optional[str] = None,\n    robust_quant: float = 0.0,\n):\n    \"\"\"\n    Converts numpy/torch float32 image [0,1] CxHxW to numpy uint8 [0,255] HxWxC\n\n    Args:\n        img: image CxHxW float32 image\n        rotate: if True, rotate image 90 degrees\n        rgb2bgr: convert image to BGR\n        ensure_rgb: ensure RGB if True (i.e. replicate the single color channel 3 times)\n        apply_colormap: apply colormap if specified (matplotlib color map names\n            i.e. \"jet\") to a single channel image. Overwrites ensure_rgb. This\n            lets you display single channel images outside the 0-1 range. (image\n            is normalized to [0,1] before applying the colormap.)\n        robust_quant: quantile to robustly compute min and max for normalization of the image.\n    \"\"\"\n\n    if isinstance(img, torch.Tensor):\n        if img.dim() == 4:\n            if img.shape[0] == 1:\n                # pre-serve old way of just squeezing 0th dim\n                img = img[0]\n            else:\n                # run torch2cv2 on all frames of the video\n                return np.stack(\n                    [\n                        torch2cv2(\n                            im,\n                            rotate,\n                            rgb2bgr,\n                            ensure_rgb,\n                            apply_colormap,\n                            robust_quant,\n                        )\n                        for im in img\n                    ]\n                )\n        img = img.data.cpu().float().numpy()\n    if img.ndim == 2:\n        img = img[np.newaxis, :, :]\n\n    # CxHxW -> HxWxC\n    img = img.transpose(1, 2, 0)\n    if img.shape[2] == 1 and apply_colormap is not 
None:\n        # make sure to normalize so min is 0 and max is 1.\n        img = normalize(img, robust=robust_quant)\n        # plt.cm.get_cmap was removed in newer matplotlib releases; plt.get_cmap is equivalent.\n        cm = plt.get_cmap(apply_colormap)\n        img = cm(img[:, :, 0])[:, :, :3]\n    img_cv2 = (img * 255.0).astype(np.uint8)\n\n    if rgb2bgr:\n        img_cv2 = img_cv2[:, :, ::-1]\n    if rotate:\n        img_cv2 = rotate_image90(img_cv2)\n    else:\n        img_cv2 = np.ascontiguousarray(img_cv2)\n    if ensure_rgb and img_cv2.shape[2] == 1:\n        img_cv2 = img_cv2[:, :, 0]\n    if ensure_rgb and img_cv2.ndim == 2:\n        img_cv2 = np.stack([img_cv2, img_cv2, img_cv2], -1)\n    return img_cv2\n\n\ndef numpy2mp4(imgs, output_path, fps=10):\n    \"\"\"\n    Write a numpy array of frames to an mp4 file.\n\n    imgs: frames shaped T x H x W x 3\n    output_path: path of the mp4 file to write\n    fps: frames per second of the output video\n    \"\"\"\n    T, H, W, C = imgs.shape\n    assert C == 3, \"input image should be 3-channel\"\n    fourcc = cv2.VideoWriter_fourcc(*\"mp4v\")\n    out = cv2.VideoWriter(output_path, fourcc, fps, (W, H))\n\n    for i in range(T):\n        out.write(imgs[i])\n    out.release()\n"
  },
  {
    "path": "efm3d/utils/image_sampling.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Literal\n\nimport einops\nimport torch\n\n\ndef compute_factor(size):\n    return 1.0 * size / 2\n\n\ndef convert_pixel_to_coordinates(coordinates, factor):\n    return (coordinates / factor) - 1.0\n\n\ndef normalize_keypoints(kpts, height, width):\n    # compute conversion factor\n    x_factor = compute_factor(width)\n    y_factor = compute_factor(height)\n    pts_dst = kpts\n    pts_dst[..., 0] = convert_pixel_to_coordinates(pts_dst[..., 0], x_factor)\n    pts_dst[..., 1] = convert_pixel_to_coordinates(pts_dst[..., 1], y_factor)\n    return pts_dst\n\n\ndef sample_images(\n    feat2d,\n    query_pts_cam,\n    cams,\n    n_by_c=True,\n    warn=True,\n    padding_mode: Literal[\"border\", \"zeros\", \"reflection\"] = \"border\",\n    interp_mode: Literal[\"bilinear\", \"nearest\", \"bicubic\"] = \"bilinear\",\n    single_channel_mask: bool = False,\n):\n    \"\"\"\n    Uses 3D points and calibrated cameras to sample features from 2D feature maps.\n\n    Inputs:\n        feat2d: torch.tensor - feature maps to sample from shaped B(xT)xCxHxW\n        query_pts_cam: torch.tensor - 3D points in camera coordinates shaped B(xT)xNx3\n        cams: CameraTW - calibrated camera objects shaped B(xT)x15\n        n_by_c: return shapes ending in NxC or CxN\n    Returns:\n        samp_feats: torch.tensor - sampled features from 2D feature 
maps shaped B(xT)xCxN\n        valid: torch.tensor - boolean of whether there was a valid sampling B(xT)xCxN\n    \"\"\"\n    assert query_pts_cam.dim() == feat2d.dim() - 1\n\n    T = None\n    if feat2d.dim() == 5:\n        B, T, C, H, W = feat2d.shape\n        feat2d = feat2d.view(-1, C, H, W)\n        query_pts_cam = query_pts_cam.view(B * T, -1, 3)\n        cams = cams.view(B * T, -1)\n    elif feat2d.dim() == 4:\n        B, C, H, W = feat2d.shape\n    else:\n        raise ValueError(f\"feat2d.dim must be 5 or 4 {feat2d.shape}\")\n\n    camH = cams[0].size[1]\n    featH = feat2d.shape[-2]\n    camW = cams[0].size[0]\n    featW = feat2d.shape[-1]\n\n    # Cams may need to be rescaled to match the feature map spatial dimensions.\n    if camH != featH or camW != featW:\n        cams_resize = cams.scale_to(feat2d)\n    else:\n        cams_resize = cams\n\n    # size[0] is the camera width and size[1] the camera height.\n    assert round(cams_resize[0].size[0].item()) == featW, (\n        f\"width of cam and feature image do not match. {cams_resize[0].size[0]} != {feat2d.shape}\"\n    )\n    assert round(cams_resize[0].size[1].item()) == featH, (\n        f\"height of cam and feature image do not match. {cams_resize[0].size[1]} != {feat2d.shape}\"\n    )\n\n    samp_pts, valid = cams_resize.project(query_pts_cam)\n    if warn:\n        frac_valid = valid.count_nonzero() / valid.numel()\n        if frac_valid < 0.05:\n            print(\n                f\"[Warning] not many valids! 
{frac_valid} {valid.count_nonzero()} {valid.shape}\"\n            )\n    samp_pts[~valid] = 0.0\n    samp_pts = samp_pts.float()\n    # Sample into the 2D feature maps.\n    norm_samp_pts = normalize_keypoints(\n        samp_pts.clone(), height=cams_resize[0].size[1], width=cams_resize[0].size[0]\n    )\n    device = feat2d.device\n    padding_mode = \"zeros\" if \"mps\" in str(device) else padding_mode\n    samp_feats = torch.nn.functional.grid_sample(\n        feat2d,\n        norm_samp_pts.unsqueeze(-2),\n        align_corners=False,\n        padding_mode=padding_mode,\n        mode=interp_mode,  # bilinear allows differentiating.\n    )\n    # squeeze back down the dimension of 1 we unsqueezed for norm_samp_pts to comply with interface\n    samp_feats = samp_feats.squeeze(-1)\n\n    # Overwrite invalid projections with zeros.\n    BT = samp_feats.shape[0]\n    valid = valid.reshape(BT, 1, -1)\n    if single_channel_mask:\n        samp_feats[(~valid).expand_as(samp_feats)] = 0.0\n    else:\n        valid = valid.repeat(1, C, 1)\n        samp_feats[~valid] = 0.0\n    if T is None:\n        if n_by_c:\n            samp_feats = einops.rearrange(samp_feats, \"b c n -> b n c\", b=B)\n            valid = einops.rearrange(valid, \"b c n -> b n c\", b=B)[..., 0]\n    else:\n        if n_by_c:\n            samp_feats = einops.rearrange(samp_feats, \"(b t) c n -> b t n c\", t=T, b=B)\n            valid = einops.rearrange(valid, \"(b t) c n -> b t n c\", t=T, b=B)[..., 0]\n        else:\n            samp_feats = einops.rearrange(samp_feats, \"(b t) c n -> b t c n\", t=T, b=B)\n            valid = einops.rearrange(valid, \"(b t) c n -> b t c n\", t=T, b=B)[..., 0]\n    return samp_feats, valid\n"
  },
  {
    "path": "efm3d/utils/marching_cubes.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\n\nimport torch\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n# logger.setLevel(logging.DEBUG)\n\n\ndef marching_cubes_scaled(values, isolevel, voxel_extent, voxel_mask):\n    \"\"\"\n    Runs marching cubes on a values tensor (D H W) at the specified isolevel.\n    Voxel_mask is used to tell marching cubes where to run in the voxel grid.\n    Uses scikit implementation which runs only on CPU.\n\n    Returns vertices, face ids, and normals in the voxel coordinate system\n    scaled to the given voxel_extent.\n    \"\"\"\n\n    from skimage.measure import marching_cubes as mc_scikit\n\n    device = values.device\n    values = values.cpu()  # CPU only\n    assert values.ndim == 3, f\"skicit can only do non-batched inputs, {values.shape}\"\n    isolevel = max(values.min(), min(isolevel, values.max()))\n    logging.info(f\"mc min {values.min()}, max {values.max()}, isolevel {isolevel}\")\n    voxel_mask = voxel_mask.cpu().numpy() if voxel_mask is not None else None\n    try:\n        if voxel_mask is not None:\n            verts, faces, normals, _ = mc_scikit(\n                values.contiguous().numpy(), isolevel, mask=voxel_mask\n            )\n        else:\n            verts, faces, normals, _ = mc_scikit(values.contiguous().numpy(), isolevel)\n        logging.info(f\"{verts.shape}, {faces.shape}\")\n   
 except Exception as e:  # also covers RuntimeError\n        # voxel_mask may be None, so guard the shape lookup used for logging.\n        mask_shape = None if voxel_mask is None else voxel_mask.shape\n        logging.error(f\"{e} {values.shape}, {mask_shape}\")\n        return torch.tensor([]), torch.tensor([]), torch.tensor([])\n\n    # copy to get around negative stride\n    # go back to x, y, z ordering\n    verts, faces, normals = (\n        torch.from_numpy(verts.copy()),\n        torch.from_numpy(faces.copy()),\n        torch.from_numpy(normals.copy()),\n    )\n    verts = verts[:, [2, 1, 0]]\n    normals = normals[:, [2, 1, 0]]\n    verts, faces, normals = verts.to(device), faces.to(device), normals.to(device)\n\n    logging.info(f\"{verts.shape}, {faces.shape}, {normals.shape}\")\n\n    vD, vH, vW = values.shape\n    logging.info(f\"{vD}, {vH}, {vW}, {voxel_extent}\")\n    x_min, x_max, y_min, y_max, z_min, z_max = voxel_extent\n    dW = (x_max - x_min) / vW\n    dH = (y_max - y_min) / vH\n    dD = (z_max - z_min) / vD\n\n    dVox = torch.tensor([dW, dH, dD]).view(1, 3).to(device)\n    vox_min = torch.tensor([x_min, y_min, z_min]).view(1, 3).to(device)\n    logging.info(f\"{verts.shape}\")\n    verts = verts * dVox + vox_min + dVox * 0.5\n    return verts, faces, normals\n"
  },
  {
    "path": "efm3d/utils/mesh_utils.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport copy\nimport random\nfrom typing import Union\n\nimport numpy as np\nimport torch\nimport trimesh\nfrom matplotlib import pyplot as plt\n\n\ndef point_to_closest_vertex_dist(pts, verts, tris):\n    # pts N 3 float\n    # verts M 3 float\n    # norms M 3 float\n    # tris O 3 int\n    assert verts.ndim == 2, f\"{verts.shape}\"\n    assert tris.ndim == 2, f\"{tris.shape}\"\n    assert pts.ndim == 2, f\"{pts.shape}\"\n    v0s = verts[None, tris[:, 0], :]\n    v1s = verts[None, tris[:, 1], :]\n    v2s = verts[None, tris[:, 2], :]\n    pts = pts.unsqueeze(1)\n    # compute distance to closest vertex\n    vs = torch.cat([v0s, v1s, v2s], 0)\n    dist_vs = torch.linalg.norm(vs.unsqueeze(1) - pts.unsqueeze(0), 2.0, -1)  # 3, N, M\n    dist_vs = torch.min(dist_vs, 0)[0]\n    dist_vs = torch.min(dist_vs, 1)[0]  # N\n    return dist_vs\n\n\ndef point_to_closest_tri_dist(pts, verts, tris):\n    \"\"\"\n    Compute the min distance of points to triangles. 
If a point doesn't intersect with any triangles\n    return a big number (1e6) for that point.\n    \"\"\"\n    assert verts.ndim == 2, f\"{verts.shape}\"\n    assert tris.ndim == 2, f\"{tris.shape}\"\n    assert pts.ndim == 2, f\"{pts.shape}\"\n\n    def dot(a, b):\n        return (a * b).sum(-1, keepdim=True)\n\n    # pts N 3 float\n    # verts M 3 float\n    # norms M 3 float\n    # tris O 3 int\n    v0s = verts[None, tris[:, 0], :]\n    v1s = verts[None, tris[:, 1], :]\n    v2s = verts[None, tris[:, 2], :]\n    pts = pts.unsqueeze(1)\n\n    # compute if point projects inside triangle\n    u = v1s - v0s\n    v = v2s - v0s\n    n = torch.cross(u, v)\n    w = pts - v0s\n    nSq = dot(n, n)\n    gamma = dot(torch.cross(u, w, -1), n) / nSq\n    beta = dot(torch.cross(w, v, -1), n) / nSq\n    alpha = 1.0 - gamma - beta\n    valid_alpha = torch.logical_and(0.0 <= alpha, alpha <= 1.0)\n    valid_beta = torch.logical_and(0.0 <= beta, beta <= 1.0)\n    valid_gamma = torch.logical_and(0.0 <= gamma, gamma <= 1.0)\n    projs_to_tri = torch.logical_and(valid_alpha, valid_beta)\n    projs_to_tri = torch.logical_and(projs_to_tri, valid_gamma)\n    num_proj = projs_to_tri.count_nonzero(1)\n    projs_to_tri = projs_to_tri.squeeze(-1)\n\n    # compute distance to triangle plane\n    n = n / torch.sqrt(nSq)\n    dist_tri = dot(n, w).squeeze(-1).abs()\n    # set distance to large for point-triangle combinations that do not project\n    dist_tri[~projs_to_tri] = 1e6\n\n    dist_tri = torch.min(dist_tri, 1)[0]  # N\n    num_proj = num_proj.squeeze(-1)\n\n    return dist_tri, num_proj\n\n\ndef compute_pts_to_mesh_dist(pts, faces, verts, step):\n    dev = pts.device\n    N = pts.shape[0]\n    err = torch.from_numpy(np.array(N, np.finfo(np.float32).max)).to(dev)\n    dist_tri = torch.from_numpy(np.array(N, np.finfo(np.float32).max)).to(dev)\n    dist_ver = torch.from_numpy(np.array(N, np.finfo(np.float32).max)).to(dev)\n    num_proj = torch.zeros(N).to(dev)\n    for i in range(0, 
faces.shape[0], step):\n        dist_tri_i, num_proj_i = point_to_closest_tri_dist(\n            pts, verts, faces[i : i + step]\n        )\n        dist_ver_i = point_to_closest_vertex_dist(pts, verts, faces[i : i + step])\n        dist_tri = torch.min(dist_tri_i, dist_tri)\n        dist_ver = torch.min(dist_ver_i, dist_ver)\n        num_proj = num_proj + num_proj_i\n\n        prog_perc = min((i + step) / faces.shape[0] * 100, 100)\n        print(f\"Compute pts to mesh progress: {prog_perc:.01f}%\", end=\"\\r\")\n    err = torch.where(num_proj == 0, dist_ver, dist_tri)\n    err = err.detach().cpu().numpy()\n    return err\n\n\ndef eval_mesh_to_mesh(\n    pred: Union[str, trimesh.Trimesh],\n    gt: Union[str, trimesh.Trimesh],\n    threshold=0.05,\n    sample_num=10000,\n    step=50000,\n    cut_height=None,\n):\n    \"\"\"\n    Eval point to faces distance using `point_to_closest_tri_dist`.\n    \"\"\"\n    rnd_seed = 0\n    random.seed(0)\n    np.random.seed(0)\n\n    if isinstance(gt, str):\n        print(f\"load gt mesh {gt}\")\n        gt_mesh = trimesh.load_mesh(gt)\n    else:\n        gt_mesh = gt\n    if isinstance(pred, str):\n        print(f\"load pred mesh {pred}\")\n        pred_mesh = trimesh.load_mesh(pred)\n    else:\n        pred_mesh = pred\n\n    if cut_height is not None:\n        cutting_plane = [[0, 0, -1], [0, 0, cut_height]]\n        gt_mesh = gt_mesh.slice_plane(\n            plane_origin=cutting_plane[1], plane_normal=cutting_plane[0]\n        )\n        pred_mesh = pred_mesh.slice_plane(\n            plane_origin=cutting_plane[1], plane_normal=cutting_plane[0]\n        )\n\n    if torch.cuda.is_available():\n        dev = \"cuda:0\"\n    elif torch.backends.mps.is_available():\n        dev = \"mps\"\n    else:\n        dev = \"cpu\"\n    print(f\"==> [eval_mesh_to_mesh] use device {dev}\")\n\n    pred_vertices = torch.from_numpy(pred_mesh.vertices.view(np.ndarray)).to(dev)\n    gt_vertices = 
torch.from_numpy(gt_mesh.vertices.view(np.ndarray)).to(dev)\n    pred_faces = torch.from_numpy(pred_mesh.faces.view(np.ndarray)).to(dev)\n    gt_faces = torch.from_numpy(gt_mesh.faces.view(np.ndarray)).to(dev)\n    print(f\"gt vertices and faces {gt_vertices.shape}, {gt_faces.shape}\")\n    print(f\"pred vertices and faces {pred_vertices.shape}, {pred_faces.shape}\")\n\n    # accuracy (from sampled point in pred to GT)\n    acc = torch.from_numpy(np.array(sample_num, np.finfo(np.float32).max)).to(dev)\n    pred_pts, _ = trimesh.sample.sample_surface(pred_mesh, sample_num, seed=rnd_seed)\n    pred_pts = torch.from_numpy(pred_pts.view(np.ndarray)).to(dev)\n    acc = compute_pts_to_mesh_dist(pred_pts, gt_faces, gt_vertices, step)\n\n    # completeness\n    gt_pts, _ = trimesh.sample.sample_surface(gt_mesh, sample_num, seed=rnd_seed)\n    gt_pts = torch.from_numpy(gt_pts.view(np.ndarray)).to(dev)\n    comp = compute_pts_to_mesh_dist(gt_pts, pred_faces, pred_vertices, step)\n\n    precision5 = np.mean((acc < 0.05).astype(\"float\"))\n    recal5 = np.mean((comp < 0.05).astype(\"float\"))\n    precision1 = np.mean((acc < 0.01).astype(\"float\"))\n    recal1 = np.mean((comp < 0.01).astype(\"float\"))\n    fscore5 = 2 * precision5 * recal5 / (precision5 + recal5)\n    fscore1 = 2 * precision1 * recal1 / (precision1 + recal1)\n    # sort to get percentile numbers.\n    acc_sorted = np.sort(acc)\n    comp_sorted = np.sort(comp)\n    metrics = {\n        \"acc_mean\": np.mean(acc),\n        \"comp_mean\": np.mean(comp),\n        \"prec@0.05\": precision5,\n        \"recal@0.05\": recal5,\n        \"fscore@0.05\": fscore5,\n    }\n\n    # Create some visualizations for debugging.\n    cmap = plt.cm.jet\n    # accuracy heatmap (as a pointcloud) on predicted mesh\n    norm = plt.Normalize(acc.min(), acc.max())\n    colors = cmap(norm(acc))\n    acc_pc = trimesh.points.PointCloud(pred_pts.detach().cpu().numpy())\n    acc_pc.colors = colors\n    # completeness heatmap (as a 
pointcloud) on gt mesh\n    norm = plt.Normalize(comp.min(), comp.max())\n    colors = cmap(norm(comp))\n    com_pc = trimesh.points.PointCloud(gt_pts.detach().cpu().numpy())\n    com_pc.colors = colors\n\n    viz = {\n        \"acc_pc\": acc_pc,\n        \"comp_pc\": com_pc,\n        \"gt_mesh\": gt_mesh,\n    }\n\n    for threshold in [0.01, 0.05]:\n        prec_inliers = acc < threshold\n        recall_inliers = comp < threshold\n\n        # create visualizations for precision and recall\n        prec_pc = copy.deepcopy(acc_pc)\n        recal_pc = copy.deepcopy(com_pc)\n        prec_pc.colors[prec_inliers] = [0, 255, 0, 255]  # green\n        prec_pc.colors[~prec_inliers] = [255, 0, 0, 255]  # red\n        recal_pc.colors[recall_inliers] = [0, 255, 0, 255]  # green\n        recal_pc.colors[~recall_inliers] = [255, 0, 0, 255]  # red\n        viz[f\"prec@{threshold:.2}_pc\"] = prec_pc\n        viz[f\"recal@{threshold:.2}_pc\"] = recal_pc\n\n    raw_data = {\"acc\": acc, \"comp\": comp}\n    return metrics, viz, raw_data\n"
  },
  {
    "path": "efm3d/utils/obb_csv_writer.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport csv\nfrom typing import Dict, Optional\n\nimport fsspec\nimport torch\nfrom efm3d.aria.obb import ObbTW\nfrom efm3d.aria.pose import PoseTW\nfrom pyquaternion import Quaternion\n\n\nclass ObbCsvReader:\n    def __init__(self, file_name):\n        self.file_name = file_name\n        self.file_reader = fsspec.open(self.file_name, \"r\").open()\n        self.csv_reader = csv.DictReader(self.file_reader)\n        try:\n            self.next_row = next(self.csv_reader)\n        except Exception:  # StopIteration\n            self.next_row = None\n        self.all_obbs = None\n        self.sem_ids_to_names = {}\n\n    def parse_row(self, row):\n        t_ns = int(row[\"time_ns\"])\n        tx_wo = float(row[\"tx_world_object\"])\n        ty_wo = float(row[\"ty_world_object\"])\n        tz_wo = float(row[\"tz_world_object\"])\n        qw_wo = float(row[\"qw_world_object\"])\n        qx_wo = float(row[\"qx_world_object\"])\n        qy_wo = float(row[\"qy_world_object\"])\n        qz_wo = float(row[\"qz_world_object\"])\n        sx = float(row[\"scale_x\"])\n        sy = float(row[\"scale_y\"])\n        sz = float(row[\"scale_z\"])\n        if \"instance\" in row:\n            inst_id = int(row[\"instance\"])\n        else:\n            inst_id = -1\n        sem_id = int(row[\"sem_id\"])\n        name = row[\"name\"]\n        if sem_id not in 
self.sem_ids_to_names:\n            self.sem_ids_to_names[sem_id] = name\n        else:\n            assert name == self.sem_ids_to_names[sem_id]\n        if \"prob\" in row:\n            prob = float(row[\"prob\"])\n        else:\n            # methods like ObjectMapper may not have probabilities\n            prob = -1.0\n\n        # create obbs\n        xmin = -sx / 2.0\n        xmax = sx / 2.0\n        ymin = -sy / 2.0\n        ymax = sy / 2.0\n        zmin = -sz / 2.0\n        zmax = sz / 2.0\n        bb3s = torch.tensor([xmin, xmax, ymin, ymax, zmin, zmax])\n\n        # create poses\n        rot_mat = Quaternion(w=qw_wo, x=qx_wo, y=qy_wo, z=qz_wo).rotation_matrix\n        translation = torch.tensor([tx_wo, ty_wo, tz_wo]).view(3, 1)\n        T_wo = torch.concat([torch.tensor(rot_mat), translation], dim=1)\n        T_wo = PoseTW.from_matrix3x4(T_wo)\n        T_wo = T_wo.fit_to_SO3()\n        T_world_object = T_wo._data\n\n        # sem ids\n        sem_ids = sem_id\n        inst_ids = inst_id\n        probs = prob\n        # moveable: assuming static for now.\n        # bb2s also decide the visibility of the 3D obbs in the corresponding camera.\n        # we now just assume the obbs are visible in all the cameras\n        bb2_rgbs = torch.ones(4)\n        bb2_slamls = torch.ones(4)\n        bb2_slamrs = torch.ones(4)\n        # assume everything is static for now.\n        moveables = torch.zeros(1)\n\n        obb_tw = ObbTW.from_lmc(\n            bb3_object=bb3s,\n            bb2_rgb=bb2_rgbs,\n            bb2_slaml=bb2_slamls,\n            bb2_slamr=bb2_slamrs,\n            T_world_object=T_world_object,\n            sem_id=sem_ids,\n            inst_id=inst_ids,\n            prob=probs,\n            moveable=moveables,\n        ).float()\n        return t_ns, obb_tw\n\n    def __iter__(self):\n        return self\n\n    def __next__(self):\n        \"\"\"\n        Get the next obbs set with the same timestamp.\n        \"\"\"\n        if self.next_row is 
None:\n            raise StopIteration\n\n        t0_ns, obb = self.parse_row(self.next_row)\n        obbs = [obb]\n        for row in self.csv_reader:\n            t_ns = int(row[\"time_ns\"])\n            if t_ns != t0_ns:\n                self.next_row = row\n                return t0_ns, torch.stack(obbs)\n            t_ns, obb = self.parse_row(row)\n            obbs.append(obb)\n\n        self.next_row = None\n        return t0_ns, torch.stack(obbs)\n\n    @property\n    def obbs(self):\n        if self.all_obbs is not None:\n            return self.all_obbs\n\n        all_obbs = {}\n        for t_ns, obbs in self:\n            all_obbs[t_ns] = obbs\n        self.all_obbs = all_obbs\n        return all_obbs\n\n\nclass ObbCsvWriter:\n    def __init__(self, file_name=\"\"):\n        if not file_name:\n            file_name = \"/tmp/obbs.csv\"\n\n        print(f\"starting obb writer to {file_name}\")\n        self.file_name = file_name\n        self.file_writer = fsspec.open(self.file_name, \"w\").open()\n        headers_str = \"time_ns,tx_world_object,ty_world_object,tz_world_object,qw_world_object,qx_world_object,qy_world_object,qz_world_object,scale_x,scale_y,scale_z,name,instance,sem_id,prob\"\n        headers = headers_str.split(\",\")\n        self.num_cols = len(headers)\n        header_row = \",\".join(headers)\n        self.file_writer.write(header_row + \"\\n\")\n        self.rows = []\n\n    def write_rows(self):\n        for row in self.rows:\n            self.file_writer.write(row + \"\\n\")\n        self.file_writer.flush()\n        self.rows = []\n\n    def write(\n        self,\n        obb_padded: ObbTW,\n        timestamps_ns: int = -1,\n        sem_id_to_name: Optional[Dict[int, str]] = None,\n        flush_at_end: bool = True,\n    ):\n        obb = obb_padded.remove_padding().clone().cpu()\n        time_ns = str(int(timestamps_ns))\n\n        N = obb.shape[0]\n        if N == 0:\n            # write all -1 to indicate the obbs for this 
timestamp is missing\n            # null_row = [time_ns] + [\"-1\" for _ in range(self.num_cols - 1)]\n            # self.file_writer.write(\",\".join(null_row) + \"\\n\")\n            return\n\n        obbs_poses = obb.T_world_object\n        obbs_dims = obb.bb3_diagonal.numpy()\n        obb_sems = obb.sem_id.squeeze(-1).numpy()\n        obb_inst = obb.inst_id.squeeze(-1).numpy()\n        obb_prob = obb.prob.squeeze(-1).numpy()\n        for i in range(N):\n            sem_id = obb_sems[i]\n            if sem_id_to_name and sem_id in sem_id_to_name:\n                name = sem_id_to_name[sem_id]\n            else:\n                name = str(int(sem_id))\n\n            qwxyz = obbs_poses[i].q  # torch.Tensor [4]\n            qwxyz = \",\".join(qwxyz.numpy().astype(str))\n\n            txyz = obbs_poses[i].t  # torch.Tensor [3]\n            txyz = \",\".join(txyz.numpy().astype(str))\n\n            sxyz = \",\".join(obbs_dims[i].astype(str))\n            self.file_writer.write(\n                f\"{time_ns},{txyz},{qwxyz},{sxyz},{name},{obb_inst[i]},{obb_sems[i]},{obb_prob[i]}\\n\"\n            )\n        if flush_at_end:\n            self.file_writer.flush()\n\n    def flush(self):\n        self.file_writer.flush()\n\n    def __del__(self):\n        if hasattr(self, \"file_writer\"):\n            self.file_writer.close()\n"
  },
  {
    "path": "efm3d/utils/obb_io.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport numpy as np\nimport torch\nfrom efm3d.aria.aria_constants import ARIA_OBB_BB2, ARIA_OBB_BB3\nfrom efm3d.aria.pose import closest_timed_poses, interpolate_timed_poses, PoseTW\nfrom efm3d.utils.common import find_nearest\n\n\ndef bb2extent(bb):\n    if bb.ndim == 1:\n        bb = bb.reshape(1, -1)\n    x_min = bb[:, 0].min()\n    x_max = bb[:, 0].max()\n    y_min = bb[:, 1].min()\n    y_max = bb[:, 1].max()\n    z_min = bb[:, 2].min()\n    z_max = bb[:, 2].max()\n    out = np.stack([x_min, x_max, y_min, y_max, z_min, z_max], axis=0)\n    return out\n\n\ndef extent2bb(extent):\n    if extent.ndim == 1:\n        extent = extent.reshape(1, -1)\n\n    x_min, x_max = extent[:, 0], extent[:, 1]\n    y_min, y_max = extent[:, 2], extent[:, 3]\n    z_min, z_max = extent[:, 4], extent[:, 5]\n    arr = (\n        [\n            x_min,\n            y_min,\n            z_min,\n            x_max,\n            y_min,\n            z_min,\n            x_max,\n            y_max,\n            z_min,\n            x_min,\n            y_max,\n            z_min,\n            x_min,\n            y_min,\n            z_max,\n            x_max,\n            y_min,\n            z_max,\n            x_max,\n            y_max,\n            z_max,\n            x_min,\n            y_max,\n            z_max,\n        ],\n    )\n    if torch.is_tensor(extent):\n        
bb3d = torch.stack(arr, dim=-1).reshape(-1, 8, 3)\n    elif isinstance(extent, np.ndarray):\n        bb3d = np.stack(arr, axis=-1).reshape(-1, 8, 3)\n    else:\n        raise TypeError(\"Unknown type\")\n\n    return bb3d.squeeze()\n\n\ndef get_all_Ts_world_object_for_time(\n    obs,\n    time,\n    load_dynamic_objects=True,\n    interpolate_poses=True,\n    dt_threshold_ns: int = 10_000_000,\n):\n    # concat static obb poses and dynamic ones at the current time\n    static_Ts_world_object = obs[\"static_Ts_world_object\"]\n    have_dynamic_objects = len(obs[\"timedTs_world_object\"]) > 1\n    if load_dynamic_objects and have_dynamic_objects:\n        if time in obs[\"timedTs_world_object\"].keys():\n            dynamic_Ts_world_object = obs[\"timedTs_world_object\"][time]\n        else:\n            if interpolate_poses:\n                dynamic_Ts_world_object = interpolate_timed_poses(\n                    obs[\"timedTs_world_object\"], time\n                )\n                print(\n                    f\"Warning: did not find time {time} in dynamic objects pose map - so interpolated poses\"\n                )\n            else:\n                dynamic_Ts_world_object, dt = closest_timed_poses(\n                    obs[\"timedTs_world_object\"], time\n                )\n                if abs(dt) > dt_threshold_ns:\n                    dynamic_Ts_world_object = {}\n                else:\n                    print(\n                        f\"Warning: no time {time} in dynamic objects pose map - picked closest pose before in time {dt}\"\n                    )\n    else:\n        dynamic_Ts_world_object = {}\n    all_Ts_world_object = {}\n    all_Ts_world_object.update(static_Ts_world_object)\n    all_Ts_world_object.update(dynamic_Ts_world_object)\n    static_inst = set(static_Ts_world_object.keys())\n    dynamic_inst = set(dynamic_Ts_world_object.keys())\n    if len(static_inst.intersection(dynamic_inst)):\n        print(\n            \"Warning: static and 
dynamic instances overlap overwriting static poses with dynamic ones! \"\n        )\n\n    return all_Ts_world_object\n\n\ndef get_inst_id_in_camera(\n    bb2s_camera,\n    time: int,\n    camera_name: str,\n):\n    if bb2s_camera and time in bb2s_camera.keys():\n        inst_ids = [line[0] for line in bb2s_camera[time]]\n    else:\n        bb2_times = list(bb2s_camera.keys())\n        nearest_idx = find_nearest(bb2_times, float(time), return_index=True)\n        nearest_time = bb2_times[nearest_idx]\n        if abs(time - nearest_time) >= 1_000_000:\n            print(\n                f\"Error: {camera_name}: target time {time}ns has too large gap from the found nearest time {nearest_time}ns with gap {abs(time - nearest_time)}ns, skip this frame.\"\n            )\n            return []\n        print(\n            f\"{camera_name}:\",\n            time,\n            nearest_time,\n            time - nearest_time,\n        )\n        inst_ids = [line[0] for line in bb2s_camera[nearest_time]]\n    return inst_ids\n\n\ndef get_instance_id_in_frameset(\n    obs,\n    time: int,\n    load_dynamic_objects: bool,\n    interpolate_poses: bool = True,\n    dt_threshold_ns: int = 10_000_000,\n):\n    # Get 3D object transforms that are visible in this frameset.\n    bb2s_rgb = obs[ARIA_OBB_BB2[0]]\n    bb2s_slaml = obs[ARIA_OBB_BB2[1]]\n    bb2s_slamr = obs[ARIA_OBB_BB2[2]]\n    bb2_time_rgb = time\n\n    all_Ts_world_object = get_all_Ts_world_object_for_time(\n        obs,\n        bb2_time_rgb,\n        load_dynamic_objects,\n        interpolate_poses=interpolate_poses,\n        dt_threshold_ns=dt_threshold_ns,\n    )\n    instance2proto = obs[\"inst2proto\"]\n    local_extents = obs[ARIA_OBB_BB3]\n\n    inst_ids_rgb = get_inst_id_in_camera(bb2s_rgb, bb2_time_rgb, \"rgb\")\n\n    # Support having visibility for only RGB.\n    if len(bb2s_slaml) == 0:\n        inst_ids_slaml = []\n    else:\n        inst_ids_slaml = get_inst_id_in_camera(bb2s_slaml, bb2_time_rgb, 
\"slaml\")\n    if len(bb2s_slamr) == 0:\n        inst_ids_slamr = []\n    else:\n        inst_ids_slamr = get_inst_id_in_camera(bb2s_slamr, bb2_time_rgb, \"slamr\")\n\n    # Get union of all instance ids.\n    inst_ids = list(\n        set(inst_ids_rgb).union(set(inst_ids_slaml)).union(set(inst_ids_slamr))\n    )\n    # Make sure that all 2D BB instance ids have a 3D pose, prototype and local extent.\n    warning_ids = [\n        id\n        for id in inst_ids\n        if id not in all_Ts_world_object\n        or id not in instance2proto\n        or id not in local_extents\n    ]\n    if len(warning_ids) > 0:\n        [inst_ids.remove(warning_id) for warning_id in warning_ids]\n\n    inst_ids = np.unique(inst_ids)\n    return inst_ids\n\n\ndef get_bb2s_for_instances(obs, time, inst_ids, cam_names, cam_scales=None):\n    \"\"\"\n    Args:\n        obs (dict): observation dict from Hive table\n        time (int): nanoseconds timestamp of observation\n        inst_ids (list): list of instance ids to get 2D BBs for\n        cam_names (list): list of camera names\n        cam_scales (dict): dict of camera scale for each camera (via cam_name) {cam_name:[x_scal, y_scale]}\n    \"\"\"\n    # visible bounding boixes are >=0; invisible ones are < 0\n    no_bb2 = [-1, -1, -1, -1]\n    bb2_time_rgb = time\n    bb2s = {cam_name: [] for cam_name in cam_names}\n    for bb2_name, cam_name in zip(ARIA_OBB_BB2, cam_names):\n        if bb2_time_rgb not in obs[bb2_name].keys():\n            bb2_insts = [no_bb2] * len(inst_ids)\n        else:\n            bb2_obs_at_time = obs[bb2_name][bb2_time_rgb]\n            bb2_insts = bb2s[cam_name]\n            for iid in inst_ids:\n                bb2 = None\n                for line in bb2_obs_at_time:\n                    if line[0] == iid:\n                        bb2 = line[1:]\n                        break\n                if bb2:\n                    bb2_insts.append(bb2)\n                else:\n                    
bb2_insts.append(no_bb2)\n        bb2_insts = torch.from_numpy(np.array(bb2_insts)).float()\n        if cam_scales:\n            # boxes are (xmin, xmax, ymin, ymax): scale the x columns by x_scale\n            # and the y columns by y_scale (index the last dim, not the rows).\n            bb2_insts[..., :2] = bb2_insts[..., :2] * cam_scales[cam_name][0]\n            bb2_insts[..., 2:] = bb2_insts[..., 2:] * cam_scales[cam_name][1]\n        bb2s[cam_name] = bb2_insts\n    return bb2s\n\n\ndef next_obb_observations(\n    obs,\n    time,\n    inst_ids,\n    cam_names,\n    cam_scales=None,\n    load_dynamic_objects: bool = True,\n    interpolate_poses: bool = True,\n    dt_threshold_ns: int = 10_000_000,\n):\n    \"\"\"\n    Args:\n        obs (dict): observation dict from Hive table\n        time (float): timestamp of observation\n        inst_ids (list): list of instance ids to get 2D BBs for\n        cam_names (list): list of camera names\n        cam_scales (dict): dict of camera scale for each camera (via cam_name) {cam_name:[x_scale, y_scale]}\n    \"\"\"\n    all_Ts_world_object = get_all_Ts_world_object_for_time(\n        obs,\n        time,\n        load_dynamic_objects=load_dynamic_objects,\n        interpolate_poses=interpolate_poses,\n        dt_threshold_ns=dt_threshold_ns,\n    )\n    # make sure we have a pose for all instances at this time.\n    inst_ids = list(set(inst_ids).intersection(set(all_Ts_world_object.keys())))\n    # make sure we have an obb extent for all instances\n    inst_ids = list(set(inst_ids).intersection(set(obs[ARIA_OBB_BB3].keys())))\n\n    # get data\n    Ts_wo = [all_Ts_world_object[iid] for iid in inst_ids]\n    proto_names = [obs[\"inst2proto\"][iid] for iid in inst_ids]\n    proto_ids = [obs[\"proto2id\"][name] for name in proto_names]\n    exs = [obs[ARIA_OBB_BB3][iid] for iid in inst_ids]\n    bbs_object = np.array([extent2bb(ex) for ex in exs])\n    bbs_object = torch.tensor(bbs_object).float()\n    # handle no obbs case.\n    if Ts_wo:\n        Ts_world_object = torch.stack(Ts_wo).float()\n    else:\n        Ts_world_object = PoseTW(torch.zeros(0, 12))\n    inst_ids = torch.tensor(inst_ids)\n    
sem_ids = torch.tensor(proto_ids)\n    bb3 = torch.from_numpy(np.array([bb2extent(bb) for bb in bbs_object]))\n    # get 2D BBs for this frame\n    bb2s = get_bb2s_for_instances(obs, time, inst_ids, cam_names, cam_scales)\n    return (\n        bb2s[\"rgb\"],\n        bb2s[\"slaml\"],\n        bb2s[\"slamr\"],\n        bb3,\n        Ts_world_object,\n        sem_ids,\n        inst_ids,\n    )\n"
  },
  {
    "path": "efm3d/utils/obb_matchers.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\n\nimport torch\nfrom efm3d.aria.obb import bb2_xxyy_to_xyxy, ObbTW\nfrom efm3d.utils.obb_utils import box3d_overlap_wrapper\nfrom scipy.optimize import linear_sum_assignment\nfrom torchvision.ops import generalized_box_iou\nfrom torchvision.ops.boxes import box_area\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n# logger.setLevel(logging.DEBUG)\n\n\nclass HungarianMatcher2d3d(torch.nn.Module):\n    \"\"\"This class computes an assignment between the targets and the predictions of the network\n    For efficiency reasons, the targets don't include the no_object. Because of this, in general,\n    there are more predictions than targets. 
In this case, we do a 1-to-1 matching of the best predictions,\n    while the others are un-matched (and thus treated as non-objects).\n    \"\"\"\n\n    def __init__(\n        self,\n        cost_class: float = 1,\n        cost_bbox2: float = 1,\n        cost_giou2: float = 1,\n        cost_bbox3: float = 1,\n        cost_iou3: float = 1,\n    ):\n        \"\"\"Creates the matcher\n        Params:\n            cost_class: This is the relative weight of the classification error in the matching cost\n            cost_bbox2: This is the relative weight of the L1 error of the 2D bounding box coordinates in the matching cost\n            cost_giou2: This is the relative weight of the giou loss of the 2D bounding box in the matching cost\n            cost_bbox3: This is the relative weight of the L1 error of the 3D bounding box centers in the matching cost\n            cost_iou3: This is the relative weight of the 3D IoU of the 3D bounding boxes in the matching cost\n        \"\"\"\n        super().__init__()\n        self.cost_class = cost_class\n        self.cost_bbox2 = cost_bbox2\n        self.cost_bbox3 = cost_bbox3\n        self.cost_giou2 = cost_giou2\n        self.cost_iou3 = cost_iou3\n        assert (\n            cost_class != 0\n            or cost_bbox2 != 0\n            or cost_bbox3 != 0\n            or cost_giou2 != 0\n            or cost_iou3 != 0\n        ), \"all costs can't be 0\"\n\n    @torch.no_grad()\n    def forward_obbs(\n        self,\n        prd: ObbTW,\n        tgt: ObbTW,\n        prd_logits=None,\n        logits_is_prob: bool = False,\n    ):\n        if prd.ndim == 2:\n            return self.forward(\n                pred_logits=prd_logits.unsqueeze(0),\n                pred_bb2s=prd.bb2_rgb.unsqueeze(0),\n                pred_bb3s=prd.bb3corners_world.unsqueeze(0),\n                pred_center_world=prd.bb3_center_world.unsqueeze(0),\n                tgt_labels=[tgt.sem_id.squeeze(-1)],\n                tgt_bb2s=[tgt.bb2_rgb],\n                tgt_bb3s=[tgt.bb3corners_world],\n                tgt_center_world=[tgt.bb3_center_world],\n                logits_is_prob=logits_is_prob,\n            )[0]\n        elif prd.ndim == 3:\n            if isinstance(tgt, ObbTW):\n       
         tgt = tgt.remove_padding()\n            return self.forward(\n                pred_logits=prd_logits,\n                pred_bb2s=prd.bb2_rgb,\n                pred_bb3s=prd.bb3corners_world,\n                pred_center_world=prd.bb3_center_world,\n                tgt_labels=[tt.sem_id.squeeze(-1) for tt in tgt],\n                tgt_bb2s=[tt.bb2_rgb for tt in tgt],\n                tgt_bb3s=[tt.bb3corners_world for tt in tgt],\n                tgt_center_world=[tt.bb3_center_world for tt in tgt],\n                logits_is_prob=logits_is_prob,\n            )\n        else:\n            raise ValueError(f\"Unsupported shape {prd.shape}\")\n\n    @torch.no_grad()\n    def forward(\n        self,\n        pred_logits=None,\n        pred_bb2s=None,\n        pred_center_world=None,\n        pred_bb3s=None,\n        tgt_labels=None,\n        tgt_bb2s=None,\n        tgt_center_world=None,\n        tgt_bb3s=None,\n        logits_is_prob: bool = False,\n    ):\n        \"\"\"Performs the matching\n        Params:\n            outputs: This is a dict that contains at least these entries:\n                 \"pred_logits\": Tensor of dim [batch_size, snippet_frames, num_queries, num_semcls] with the classification logits\n                 \"pred_bb2s\": Tensor of dim [batch_size, snippet_frames, num_queries, 4] with the predicted 2d box coordinates\n            targets: This is a list of batch_size targets:\n                 \"tgt_labels\": Tensor of dim [snippet_frames, num_target_boxes] (where num_target_boxes is the number of ground-truth\n                           objects in the target) containing the class labels\n                 \"tgt_bb2s\": Tensor of dim [snippet_frames, num_target_boxes, 4] containing the target box coordinates\n        Returns:\n            A list of size batch_size, containing tuples of (index_i, index_j) where:\n                - index_i is the indices of the selected predictions (in order)\n                - index_j is the indices of 
the corresponding selected targets (in order)\n            For each batch element, it holds:\n                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)\n        \"\"\"\n        assert pred_bb2s.dim() == 3, f\"{pred_bb2s.shape}\"\n        assert pred_center_world.dim() == 3, f\"{pred_center_world.shape}\"\n        B, N = pred_bb2s.shape[:2]\n        assert len(tgt_bb2s) == B, \"number of targets should be equal to batch size\"\n        assert len(tgt_center_world) == B, (\n            \"number of targets should be equal to batch size\"\n        )\n\n        cost_class = None\n        if pred_logits is not None:\n            assert pred_logits.dim() == 3, f\"{pred_logits.shape}\"\n            assert len(tgt_labels) == B, (\n                \"number of targets should be equal to batch size\"\n            )\n            # We flatten to compute the cost matrices in a batch\n            # [batch_size * num_queries, num_semcls]\n            out_prob = pred_logits.flatten(0, 1)\n            if not logits_is_prob:\n                out_prob = out_prob.softmax(-1)\n            tgt_ids = torch.cat(tgt_labels)\n            assert tgt_ids.ndim == 1, f\"{tgt_ids.shape} is not right\"\n\n            # Compute the classification cost. 
Contrary to the loss, we don't use the NLL,\n            # but approximate it in 1 - proba[target class].\n            # The 1 is a constant that doesn't change the matching, it can be omitted.\n            cost_class = -out_prob[:, tgt_ids]\n            if cost_class.isnan().any():\n                logger.warning(\n                    f\"have {cost_class.isnan().sum()} nan values in cost_class\"\n                )\n            cost_class = torch.nan_to_num(cost_class, nan=1e6)\n\n        # [batch_size * num_queries, 4]\n        pred_bb2s = pred_bb2s.flatten(0, 1)\n        pred_center_world = pred_center_world.flatten(0, 1)\n        # remember sizes for later\n        sizes = [len(v) for v in tgt_bb2s]\n        # Also concat the target boxes\n        tgt_bb2s = torch.cat(tgt_bb2s)\n        tgt_center_world = torch.cat(tgt_center_world)\n\n        # Compute the L1 cost between boxes\n        cost_bbox2 = torch.cdist(pred_bb2s, tgt_bb2s, p=1)\n        if cost_bbox2.isnan().any():\n            logger.warning(f\"have {cost_bbox2.isnan().sum()} nan values in cost_bbox\")\n        cost_bbox2 = torch.nan_to_num(cost_bbox2, nan=1e6)\n        # 3d bbs\n        cost_bbox3 = torch.cdist(pred_center_world, tgt_center_world, p=1)\n        if cost_bbox3.isnan().any():\n            logger.warning(f\"have {cost_bbox3.isnan().sum()} nan values in cost_bbox\")\n        cost_bbox3 = torch.nan_to_num(cost_bbox3, nan=1e6)\n\n        # 3d bbs iou\n        cost_iou3 = None\n        if pred_bb3s is not None and tgt_bb3s is not None and self.cost_iou3 > 0.0:\n            pred_bb3s = pred_bb3s.flatten(0, 1)\n            tgt_bb3s = torch.cat(tgt_bb3s)\n            cost_iou3 = -box3d_overlap_wrapper(pred_bb3s, tgt_bb3s).iou\n            if cost_iou3.isnan().any():\n                logger.warning(\n                    f\"have {cost_iou3.isnan().sum()} nan values in cost_iou3\"\n                )\n            cost_iou3 = torch.nan_to_num(cost_iou3, nan=1e6)\n\n        # Compute the giou cost 
between boxes\n        cost_giou = -generalized_box_iou(\n            bb2_xxyy_to_xyxy(pred_bb2s), bb2_xxyy_to_xyxy(tgt_bb2s)\n        )\n        # set invalid costs to high value so they are not chosen in linear assignment\n        # invalid predictions have size 0.0\n        pred_areas = box_area(bb2_xxyy_to_xyxy(pred_bb2s))\n        pred_invalid = pred_areas <= 0.0\n        cost_giou[pred_invalid, :] = 1e6\n        if cost_giou.isnan().any():\n            logger.warning(f\"have {cost_giou.isnan().sum()} nan values in cost_giou\")\n        cost_giou = torch.nan_to_num(cost_giou, nan=1e6)\n\n        # Final cost matrix\n        C = (\n            self.cost_bbox2 * cost_bbox2\n            + self.cost_bbox3 * cost_bbox3\n            + self.cost_giou2 * cost_giou\n        )\n        if cost_class is not None:\n            C = C + self.cost_class * cost_class\n        if cost_iou3 is not None:\n            C = C + self.cost_iou3 * cost_iou3\n        C = C.view(B, N, -1).cpu()\n\n        indices = [\n            linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))\n        ]\n        return [\n            (\n                torch.as_tensor(row_id, dtype=torch.int64),\n                torch.as_tensor(col_id, dtype=torch.int64),\n            )\n            for row_id, col_id in indices\n        ]\n"
  },
  {
    "path": "efm3d/utils/obb_metrics.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nfrom time import time\nfrom typing import Dict, List, Optional\n\nimport torch\nfrom efm3d.aria.camera import CameraTW\nfrom efm3d.aria.obb import ObbTW\nfrom efm3d.utils.file_utils import parse_global_name_to_id_csv\nfrom efm3d.utils.obb_utils import MeanAveragePrecision3D\n\nARIA_CAM_IDS = list(range(3))\nARIA_CAM_NAMES = [\"rgb\", \"slaml\", \"slamr\"]\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n# logger.setLevel(logging.DEBUG)\n\n\nclass ObbMetrics(torch.nn.Module):\n    \"\"\"\n    Metrics that directly work with our ObbTW class\n    It is a torch.nn.Module to be able to behave like a torchmetrics object\n    \"\"\"\n\n    def __init__(\n        self,\n        cam_ids=ARIA_CAM_IDS,\n        cam_names=ARIA_CAM_NAMES,\n        class_metrics: bool = False,\n        volume_range_metrics: bool = False,\n        eval_2d: bool = True,\n        eval_3d: bool = False,\n        ignore_bb2d_visibility: bool = False,\n        global_name_to_id_file: Optional[str] = None,\n        global_name_to_id: Optional[Dict] = None,\n        ret_all_prec_rec: Optional[bool] = False,\n        max_detection_thresholds: Optional[List[float]] = None,\n    ) -> None:\n        \"\"\"\n        Args:\n            cam_ids (list): list of camera ids to evaluate\n            cam_names (list): list of camera names to evaluate\n    
        class_metrics (bool): if True, computes per-class metrics\n            volume_range_metrics (bool): if True, computes volume range metrics\n            eval_2d (bool): if True, evaluate 2d detections\n            eval_3d (bool): if True, evaluate 3d detections\n        \"\"\"\n        from torchmetrics.detection.mean_ap import MeanAveragePrecision\n\n        super().__init__()\n        assert eval_2d or eval_3d, (\n            \"At least eval_2d or eval_3d needs to be set to True.\"\n        )\n        self.eval_2d = eval_2d\n        self.eval_3d = eval_3d\n        self.ignore_bb2d_visibility = ignore_bb2d_visibility\n\n        self.metric_2d = torch.nn.ModuleDict(\n            {\n                cam_name: MeanAveragePrecision(class_metrics=class_metrics)\n                for cam_name in cam_names\n            }\n            if eval_2d\n            else {}\n        )\n        bbox_area_ranges = None\n        if volume_range_metrics:\n            # Using category statistics from SUN_3D dataset: D42985037.\n            bbox_area_ranges = {\n                \"all\": (0, 1e5),\n                \"small\": (0, 1e-2),  # pen, remote, toilet paper, etc.\n                \"medium\": (1e-2, 1),  # chair, bin, monitor, etc.\n                \"large\": (1, 1e5),  # bed, sofa, etc.\n            }\n        if max_detection_thresholds is None:\n            # max number of detections to evaluate - 220 is sufficient for ASE scenes\n            max_detection_thresholds = [220]\n\n        self.metric_3d = torch.nn.ModuleDict(\n            {\n                cam_name: MeanAveragePrecision3D(\n                    class_metrics=class_metrics,\n                    bbox_area_ranges=bbox_area_ranges,\n                    max_detection_thresholds=max_detection_thresholds,\n                    ret_all_prec_rec=ret_all_prec_rec,\n                )\n                for cam_name in cam_names\n            }\n            if eval_3d\n            else {}\n        )\n        self.cam_ids = 
cam_ids\n        self.cam_names = cam_names\n        self.cam_id_to_name = {id: name for id, name in zip(cam_ids, cam_names)}\n        self.sem_id_to_name = None\n\n        if global_name_to_id_file is not None:\n            global_name_to_id = parse_global_name_to_id_csv(global_name_to_id_file)\n        if global_name_to_id is not None:\n            self.sem_id_to_name = {\n                int(sem_id): name for name, sem_id in global_name_to_id.items()\n            }\n\n    def update(self, prediction: ObbTW, target: ObbTW, cam: Optional[CameraTW] = None):\n        \"\"\" \"\"\"\n        for cam_id in self.cam_ids:\n            if self.eval_2d:\n                self.update_2d(\n                    prediction.bb2(cam_id),\n                    prediction.prob.squeeze(),\n                    prediction.sem_id.squeeze(),\n                    target.bb2(cam_id),\n                    target.sem_id.squeeze(),\n                    cam_id,\n                )\n            if self.eval_3d:\n                visible_predictions_ind = prediction.visible_bb3_ind(cam_id)\n                visible_targets_ind = target.visible_bb3_ind(cam_id)\n                if self.ignore_bb2d_visibility:\n                    visible_predictions_ind[:] = True\n                    visible_targets_ind[:] = True\n                    if not visible_predictions_ind.any():\n                        print(\"WARNING: no predictions are visible\")\n                    if not visible_targets_ind.any():\n                        print(\"WARNING: no targets are visible\")\n\n                # Use visible boxes in the camera for evaluation\n                self.update_3d(\n                    prediction.bb3corners_world[visible_predictions_ind],\n                    prediction.prob[visible_predictions_ind].view(-1),\n                    prediction.sem_id[visible_predictions_ind].view(-1),\n                    target.bb3corners_world[visible_targets_ind],\n                    
target.sem_id[visible_targets_ind].view(-1),\n                    cam_id,\n                )\n\n    def forward(self, prediction: ObbTW, target: ObbTW):\n        self.update(prediction, target)\n        return self.compute()\n\n    def update_3d(\n        self,\n        pred_bb3corners: torch.Tensor,\n        pred_scores: torch.Tensor,\n        pred_labels: torch.Tensor,\n        tgt_bb3corners: torch.Tensor,\n        tgt_labels: torch.Tensor,\n        cam_id: int = 0,\n    ):\n        assert pred_bb3corners.dim() == 3\n        assert tgt_bb3corners.dim() == 3\n        assert pred_scores.dim() == 1\n        assert pred_labels.dim() == 1\n        assert tgt_labels.dim() == 1\n        p = [\n            {\n                \"boxes\": pred_bb3corners,\n                \"scores\": pred_scores,\n                \"labels\": pred_labels,\n            }\n        ]\n        t = [\n            {\n                \"boxes\": tgt_bb3corners,\n                \"labels\": tgt_labels,\n            }\n        ]\n        self.metric_3d[self.cam_id_to_name[cam_id]].update(p, t)\n\n    def update_2d(\n        self,\n        pred_bb2: torch.Tensor,\n        pred_scores: torch.Tensor,\n        pred_labels: torch.Tensor,\n        tgt_bb2: torch.Tensor,\n        tgt_labels: torch.Tensor,\n        cam_id: int = 0,\n    ):\n        assert pred_scores.dim() == 1\n        assert pred_labels.dim() == 1\n        assert tgt_labels.dim() == 1\n        p = [\n            {\n                \"boxes\": pred_bb2,\n                \"scores\": pred_scores,\n                \"labels\": pred_labels,\n            }\n        ]\n        t = [\n            {\n                \"boxes\": tgt_bb2,\n                \"labels\": tgt_labels,\n            }\n        ]\n        self.metric_2d[self.cam_id_to_name[cam_id]].update(p, t)\n\n    def update_2d_instances(\n        self,\n        preds,  #: List[Instances],\n        tgts,  #: List[Instances],\n        cam_id: int = 0,\n    ):\n        for pred, tgt in 
zip(preds, tgts):\n            self.update_2d(\n                pred_bb2=pred.pred_boxes.tensor,\n                pred_scores=pred.scores,\n                pred_labels=pred.pred_classes,\n                tgt_bb2=tgt.gt_boxes.tensor,\n                tgt_labels=tgt.gt_classes,\n                cam_id=cam_id,\n            )\n\n    def compute(self):\n        metrics = {}\n        for cam_name in self.cam_names:\n            if self.eval_2d:\n                m2d = self.metric_2d[cam_name].compute()\n                for metric_name, val in m2d.items():\n                    if (\n                        \"small\" not in metric_name\n                        and \"medium\" not in metric_name\n                        and \"large\" not in metric_name\n                    ):\n                        metrics[f\"{cam_name}/{metric_name}_2D\"] = val\n            if self.eval_3d:\n                logger.info(f\"Computing metric {self.metric_3d[cam_name]}\")\n                t0 = time()\n                m3d = self.metric_3d[cam_name].compute(self.sem_id_to_name)\n                t1 = time()\n                logger.info(\n                    f\"DONE Computing metric {self.metric_3d[cam_name]} in {t1 - t0} seconds\"\n                )\n                for metric_name, val in m3d.items():\n                    metrics[f\"{cam_name}/{metric_name}_3D\"] = val\n        return metrics\n\n    def reset(self):\n        for metric in self.metric_2d.values():\n            metric.reset()\n        for metric in self.metric_3d.values():\n            metric.reset()\n"
  },
  {
    "path": "efm3d/utils/obb_trackers.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nfrom typing import Optional\n\nimport torch\nfrom efm3d.aria.camera import CameraTW\nfrom efm3d.aria.obb import bb2_xxyy_to_xyxy, bb3_xyzxyz_to_xxyyzz, ObbTW\nfrom efm3d.aria.pose import all_rot90, find_r90, PoseTW\nfrom efm3d.utils.obb_matchers import HungarianMatcher2d3d\nfrom efm3d.utils.obb_utils import box3d_overlap_wrapper, remove_invalid_box3d\nfrom torch.nn import functional as F\nfrom torchvision.ops.boxes import box_iou\n\nlogging.basicConfig()\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n\n\ndef nms_3d(\n    obbs,\n    nms_iou3_thr: float = 0.1,\n    nms_iou3_max_thr: float = 0.15,\n    verbose: bool = False,\n    mark_in_place: bool = False,\n):\n    \"\"\"\n    NMS based on 3D bbs. 
When a duplicate is found, the obb with higher probability is retained.\n    \"\"\"\n    ids_keep, ids_bad = list(range(obbs.shape[0])), []\n    if mark_in_place:\n        if obbs is None or obbs.shape[0] == 0:\n            return (ids_keep, ids_bad)\n        remove_invalid_box3d(obbs, mark_in_place)\n        if obbs.shape[0] == 0:\n            return (ids_keep, ids_bad)\n    else:\n        if obbs is None or obbs.shape[0] == 0:\n            return obbs, None, (ids_keep, ids_bad)\n        obbs, _ = remove_invalid_box3d(obbs, mark_in_place)\n        if obbs.shape[0] == 0:\n            return obbs, None, (ids_keep, ids_bad)\n\n    bb3s = obbs.bb3corners_world\n    N = bb3s.shape[0]\n    # iou and make the diagonal negative (since we don't want to check overlap with self.)\n    iou3 = box3d_overlap_wrapper(bb3s, bb3s).iou - 2.0 * torch.eye(\n        N, device=bb3s.device\n    )\n    # we want bb3s to overlap and be in the same class\n    same_ids = obbs.sem_id == obbs.sem_id.view(1, -1)\n    overlap = iou3 > nms_iou3_thr\n    overlap_overwrite = iou3 > nms_iou3_max_thr\n    nms = torch.logical_or(torch.logical_and(overlap, same_ids), overlap_overwrite)\n\n    if verbose and overlap.count_nonzero() > 0:\n        logger.debug(\"overlap %s\", iou3[overlap])\n    if verbose and overlap_overwrite.count_nonzero() > 0:\n        logger.debug(\"overlap_overwrite %s\", iou3[overlap_overwrite])\n\n    # ids where we want to NMS\n    ids = torch.nonzero(nms, as_tuple=True)\n    # ids of the obbs we want to remove (lower probability)\n    ids_bad = torch.where(\n        (obbs[ids[0]].prob < obbs[ids[1]].prob).squeeze(-1), ids[0], ids[1]\n    )\n    ids_bad = set(ids_bad.tolist())\n    if len(ids_bad) > 0:\n        # have some obbs to remove; compute which to keep and return those.\n        if verbose:\n            logger.debug(\n                f\"NMS3d: found {len(ids_bad)} non-maxima to suppress. 
{ids_bad}/ {bb3s.shape[0]}\"\n            )\n        ids_keep = list(set(range(bb3s.shape[0])) - ids_bad)\n        ids_bad = list(ids_bad)\n        if mark_in_place:\n            obbs._mark_invalid_ids(torch.tensor(ids_bad, dtype=torch.long))\n            return (ids_keep, ids_bad)\n        return obbs[ids_keep], obbs[ids_bad], (ids_keep, ids_bad)\n    if mark_in_place:\n        return (ids_keep, ids_bad)\n    return obbs, None, (ids_keep, ids_bad)\n\n\ndef nms_2d(obbs, nms_iou2_thr: float, verbose: bool = False):\n    \"\"\"\n    NMS based on 2D bbs. When a duplicate is found the obb with higher probability is retained.\n    \"\"\"\n    if obbs is None or obbs.shape[0] == 0:\n        return obbs, None\n    bb2s = bb2_xxyy_to_xyxy(obbs.bb2_rgb)\n    N = bb2s.shape[0]\n    # iou and make the diagonal negative (since we dont want to check overlap with self.)\n    iou2 = box_iou(bb2s, bb2s) - 2.0 * torch.eye(N, device=bb2s.device)\n    # we want bb2s to overlap and be in the same class\n    same_ids = obbs.sem_id == obbs.sem_id.view(1, -1)\n    overlap = torch.logical_and(iou2 > nms_iou2_thr, same_ids)\n    # ids where we have overlap\n    ids = torch.nonzero(overlap, as_tuple=True)\n    # ids of the obbs we want to remove (lower probability)\n    ids_bad = torch.where(\n        (obbs[ids[0]].prob < obbs[ids[1]].prob).squeeze(-1), ids[0], ids[1]\n    )\n    ids_bad = set(ids_bad.tolist())\n    if len(ids_bad) > 0:\n        # have some obbs to remove; compute which to keep and return those.\n        if verbose:\n            logger.debug(\n                f\"NMS2d: found {len(ids_bad)} non-maxima to suppress. 
{ids_bad}/ {bb2s.shape[0]}\"\n            )\n        ids_keep = list(set(range(bb2s.shape[0])) - ids_bad)\n        return obbs[ids_keep], obbs[list(ids_bad)]\n    return obbs, None\n\n\nclass ObbTracker:\n    \"\"\"\n    A simple obb tracker that uses Hungarian matching to associate new detected\n    obbs with a set of \"world\"-state obbs it maintains incrementally.\n    \"\"\"\n\n    def __init__(\n        self,\n        track_best: bool = False,\n        track_running_average: bool = True,\n        max_assoc_dist: float = 0.1,\n        max_assoc_iou2: float = 0.2,\n        max_assoc_iou3: float = 0.2,\n        prob_inst_thr: float = 0.3,\n        prob_assoc_thr: float = 0.25,\n        nms_iou3_thr: float = 0.1,\n        nms_iou2_thr: float = 0.5,\n        w_max: int = 30,\n        w_min: int = 5,\n        dt_max_inst: float = 1.0,\n        dt_max_occ: float = 5.0,\n    ):\n        \"\"\"\n        Args:\n            track_best: choose the highest probability obb for obbs that have\n                been associated. This is the most basic fusion strategy.\n            track_running_average: maintain a running average of obbs that have\n                been associated. This allows denoising obb parameters but\n                struggles with detections that are not consistently in the same\n                canonical orientation.\n            max_assoc_dist: maximum distance to associate an obb with another\n                obb. 
Obbs that are further apart are assumed to be distinct and lead\n                to a new instantiation.\n            max_assoc_iou2: 2D IoU threshold for association; a matched pair is\n                associated if its 2D IoU exceeds this value (<= 0 disables the 2D check).\n            max_assoc_iou3: 3D IoU threshold for association; a matched pair is\n                associated if its 3D IoU exceeds this value. Pairs that pass\n                neither the 2D nor the 3D check lead to a new obb.\n            prob_inst_thr: minimum probability threshold for instantiating a new\n                world obb.\n            prob_assoc_thr: minimum probability threshold for associating a new\n                obb with existing world obbs.\n            nms_iou3_thr: 3D IoU threshold to consider a world obb to be a\n                duplicate and suppress it.\n            nms_iou2_thr: 2D IoU threshold to consider a world obb to be a\n                duplicate and suppress it.\n            w_max: maximum weight accumulated in the running average.\n            w_min: minimum weight needed to return the scene_obb.\n            dt_max_inst: how long it can take for an object to be instantiated; in seconds\n            dt_max_occ: how long it is okay for an instantiated object to be occluded; in seconds\n        \"\"\"\n        self.matcher = HungarianMatcher2d3d(\n            cost_class=8.0,\n            cost_bbox2=0.0,\n            cost_bbox3=1.0,\n            cost_giou2=4.0,\n            cost_iou3=0.0,\n            # cost_class=8.0,\n            # cost_bbox2=0.0,\n            # cost_bbox3=0.0,\n            # cost_giou2=4.0,\n            # cost_iou3=4.0,\n        )\n        # the set of scene obbs\n        self.scene_obbs_w = None\n        # w is the weight (count) of each of the scene obbs\n        self.w = None\n        self.scene_probs_full = None\n        self.num_semcls = 128\n\n        # when last got an observation associated\n        self.last_obs_time = None\n        # when last possible to observe (based on 2d bb in frame)\n        self.last_possible_obs_time = None\n        # time of tracker (pseudo 
time incremented by 1 each track() call)\n        self.time = 0\n        self.hz = 10.0\n        # how long it can take for an object to be instantiated; in seconds\n        self.dt_max_inst = dt_max_inst\n        # how long it is okay for an instantiated object to be occluded; in seconds\n        self.dt_max_occ = dt_max_occ\n\n        self.w_max = w_max\n        self.w_min = w_min\n        self.track_best = track_best\n        self.track_running_average = track_running_average\n        self.max_assoc_dist = max_assoc_dist\n        self.max_assoc_iou2 = max_assoc_iou2\n        self.max_assoc_iou3 = max_assoc_iou3\n        self.prob_inst_thr = prob_inst_thr\n        self.prob_assoc_thr = prob_assoc_thr\n        self.nms_iou3_thr = nms_iou3_thr\n        self.nms_iou3_max_thr = 0.15\n        self.nms_iou2_thr = nms_iou2_thr\n        self.R90s = all_rot90()\n        self.counts_as_prob = False\n        self.device = torch.device(\"cpu\")\n        self.num_instances_so_far = 0\n\n    def reset(self):\n        self.scene_obbs_w = None\n        self.w = None\n        self.scene_probs_full = None\n        self.last_obs_time = None\n        self.last_possible_obs_time = None\n        self.time = 0\n\n    def set_hz(self, hz: float):\n        # adjust obb framerate\n        self.hz = float(hz)\n\n    @property\n    def obbs_world(self):\n        \"\"\"\n        The main function to access the tracked obbs that pass a set of gates.\n        The returned objects are a subset of the full set of world obbs.\n        \"\"\"\n        if self.scene_obbs_w is None:\n            return ObbTW().to(self.device), ObbTW().to(self.device)\n        sem_ids = self.scene_probs_full.argmax(dim=1)\n        if (sem_ids != self.scene_obbs_w.sem_id.squeeze(-1)).any():\n            change = sem_ids != self.scene_obbs_w.sem_id.squeeze(-1)\n            logger.debug(\n                \"semantic id has changed because of probs_full averaging \",\n                sem_ids[change].tolist(),\n            
    self.scene_obbs_w.sem_id.squeeze(-1)[change].tolist(),\n            )\n        self.scene_obbs_w.set_sem_id(sem_ids)\n        # which obbs have we seen recently?\n        dt = self.last_possible_obs_time - self.last_obs_time\n        seen_uninst = dt < self.dt_max_inst\n        seen_occlusion = dt < self.dt_max_occ\n        # remove obbs that do not have enough observations\n        enough_observations = self.w > self.w_min\n\n        # categories of obbs\n        good_visible = torch.logical_and(enough_observations, seen_occlusion)\n        good_invisible = torch.logical_and(enough_observations, ~seen_occlusion)\n        uninst_visible = torch.logical_and(~enough_observations, seen_uninst)\n        uninst_delete = torch.logical_and(~enough_observations, ~seen_uninst)\n\n        # return all good ones\n        obbs_w = self.scene_obbs_w[good_visible]\n        # return the stale visible ones for debugging\n        obbs_invis_w = self.scene_obbs_w[\n            torch.logical_or(uninst_visible, good_invisible)\n        ]\n\n        # delete uninst obbs\n        if uninst_delete.count_nonzero() > 0:\n            logger.debug(\n                f\"removing un-instantiated obbs {uninst_delete.count_nonzero()}\"\n            )\n            self.scene_obbs_w = self.scene_obbs_w[~uninst_delete]\n            self.scene_probs_full = self.scene_probs_full[~uninst_delete]\n            self.last_obs_time = self.last_obs_time[~uninst_delete]\n            self.last_possible_obs_time = self.last_possible_obs_time[~uninst_delete]\n            self.w = self.w[~uninst_delete]\n\n        # NMS based on 3D IoU\n        if self.nms_iou3_thr > 0.0:\n            obbs_w, obbs_non_max_w = self.nms_3d(obbs_w)\n        # NMS based on 2D IoU\n        if self.nms_iou2_thr > 0.0:\n            obbs_w, obbs_non_max_w = self.nms_2d(obbs_w)\n        return obbs_w, obbs_invis_w\n\n    def track(\n        self,\n        obbs_w: ObbTW,\n        probs_full: Optional[torch.Tensor] = None,\n        cam: 
Optional[CameraTW] = None,\n        T_world_rig: Optional[PoseTW] = None,\n    ):\n        \"\"\"\n        Args:\n            obbs_w: new obb detections to track. shape: Nx34\n            probs_full: full probability distribution over the classes of each of the obb detections.\n            cam: camera used to compute 2d bbs and visibility; if None, 2d bbs\n                are left unchanged and all scene obbs are marked visible.\n            T_world_rig: rig pose in world coordinates; used together with cam.\n        \"\"\"\n        self.device = obbs_w.device\n        assert obbs_w.ndim == 2, f\"{obbs_w.shape}\"\n        # if we don't have any new obbs, return\n        if obbs_w.shape[0] == 0:\n            return self.obbs_world\n        # set 2d bbs\n        obbs_w = self.set_2d_bbs(obbs_w, cam, T_world_rig)\n        # filter out obbs whose probability is too low to be associated\n        assoc = obbs_w.prob.squeeze(-1) > self.prob_assoc_thr\n        # remove probs_full padding\n        if probs_full is not None:\n            probs_full = probs_full[: obbs_w.shape[0], :]\n        else:\n            # create one-hot probability encoding based on semantic id\n            probs_full = F.one_hot(\n                obbs_w.sem_id.squeeze(-1).long(), num_classes=self.num_semcls\n            ).float()\n\n        obbs_w = obbs_w[assoc]\n        probs_full = probs_full[assoc] if probs_full is not None else None\n        # if we don't have any good new obbs, return\n        if obbs_w.shape[0] == 0:\n            return self.obbs_world\n        # if we don't have any scene obbs yet (at the beginning), initialize the\n        # tracker state and return it.\n        if self.scene_obbs_w is None:\n            self.add_new_obbs(obbs_w, probs_full)\n            return self.obbs_world\n        # find matches\n        indices = self.matcher.forward_obbs(\n            prd=obbs_w,\n            tgt=self.scene_obbs_w,\n            prd_logits=probs_full,\n            logits_is_prob=True,\n        )\n        # if we have no matches, we return\n        if len(indices[0]) == 0:\n            return self.obbs_world\n        # get matched obbs\n        pids, tids = indices[0], indices[1]\n        pobbs, tobbs = obbs_w[pids], 
self.scene_obbs_w[tids]\n        pprobs_full = probs_full[pids] if probs_full is not None else None\n        # find good associations based on the 2d and 3d iou\n        dist = torch.linalg.norm(\n            pobbs.bb3_center_world - tobbs.bb3_center_world, 2, dim=-1\n        ).cpu()\n        if self.max_assoc_iou2 > 0:\n            iou2 = (\n                box_iou(\n                    bb2_xxyy_to_xyxy(pobbs.bb2_rgb), bb2_xxyy_to_xyxy(tobbs.bb2_rgb)\n                )\n                .cpu()\n                .diagonal()\n            )\n        else:\n            iou2 = None\n\n        # filter out invalid bboxes; if none remain we return\n        pobbs, valid_ind = remove_invalid_box3d(pobbs)\n        # index the matched probabilities (pprobs_full), which are aligned with pobbs\n        pprobs_full = pprobs_full[valid_ind] if pprobs_full is not None else None\n        if pobbs.shape[0] == 0:\n            return self.obbs_world\n\n        if iou2 is not None:\n            iou2 = iou2[valid_ind]\n        tobbs = tobbs[valid_ind]\n        dist = dist[valid_ind]\n        pids = pids[valid_ind]\n        tids = tids[valid_ind]\n\n        # this function could fail due to thin object (ValueError: Planes have zero areas).\n        # if we can't compute iou3 we return\n        try:\n            iou3 = (\n                box3d_overlap_wrapper(pobbs.bb3corners_world, tobbs.bb3corners_world)\n                .iou.cpu()\n                .diagonal()\n            )\n        except Exception as e:\n            print(e)\n            return self.obbs_world\n\n        # assoc = torch.logical_or(dist < self.max_assoc_dist, iou2 > self.max_assoc_iou)\n        assoc = iou3 > self.max_assoc_iou3\n        if self.max_assoc_iou2 > 0.0:\n            assoc = torch.logical_or(assoc, iou2 > self.max_assoc_iou2)\n\n        # new obbs\n        new_ids = list(set(range(obbs_w.shape[0])) - set(pids.tolist()))\n        if assoc.count_nonzero() > 0:\n            logger.debug(\n                f\"{assoc.count_nonzero()} associated\",\n                \"dist\",\n              
  dist[assoc],\n                \"iou2\",\n                iou2[assoc] if iou2 is not None else None,\n                \"iou3\",\n                iou3[assoc],\n            )\n        if (~assoc).count_nonzero() > 0:\n            logger.debug(\n                f\"{(~assoc).count_nonzero()} not associated\",\n                \"dist\",\n                dist[~assoc],\n                \"iou2\",\n                iou2[~assoc] if iou2 is not None else None,\n                \"iou3\",\n                iou3[~assoc],\n            )\n        new_obbs = torch.cat([pobbs[~assoc].clone(), obbs_w[new_ids]])\n        new_insts = new_obbs.prob.squeeze(-1) > self.prob_inst_thr\n        new_obbs = new_obbs[new_insts]\n        if pprobs_full is not None:\n            new_probs_full = torch.cat(\n                [pprobs_full[~assoc].clone(), probs_full[new_ids]]\n            )\n            new_probs_full = new_probs_full[new_insts]\n        # associated obbs\n        pids, tids = pids[assoc], tids[assoc]\n        pobbs, tobbs = pobbs[assoc], tobbs[assoc]\n        pprobs_full = pprobs_full[assoc] if pprobs_full is not None else None\n        # deal with associations\n        if self.track_best and tids.shape[0] > 0:\n            better_pred = (pobbs.prob > tobbs.prob).squeeze(-1).cpu()\n            # update better obbs\n            better_tids = tids[better_pred]\n            self.scene_obbs_w._data[better_tids] = pobbs._data[better_pred]\n            # increment weights\n            self.w[tids] = self.w[tids] + 1.0\n            # update times\n            self.last_obs_time[tids] = self.time\n\n            # update counts as probabilities\n            if self.counts_as_prob:\n                scene_obbs = self.scene_obbs_w[tids].clone()\n                scene_obbs.set_prob(self.w[tids])\n                self.scene_obbs_w._data[tids] = scene_obbs._data\n\n        elif self.track_running_average and tids.shape[0] > 0:\n            wpp = (self.w[tids] + 1.0).unsqueeze(-1)\n            pdiag 
= pobbs.bb3_diagonal\n            # running average T_world_object\n            dT_tobj_pobj = tobbs.T_world_object.inverse() @ pobbs.T_world_object\n            xi_tobj_pobj = dT_tobj_pobj.log()\n            # check if any relative pose is further than 45 degree which\n            # indicates that there is a 90 deg rotation that is closer.\n            dr = xi_tobj_pobj[..., 3:]\n            dr_norm = torch.linalg.norm(dr, 2, dim=-1)\n            too_big = dr_norm > 3.14 * 0.25\n            if too_big.any():\n                # find closest 90 degree rotation\n                pT_wo, R90min = find_r90(\n                    tobbs[too_big].T_world_object,\n                    pobbs[too_big].T_world_object,\n                    self.R90s.to(tobbs.device),\n                )\n                # update xi with the 90 deg closest rotation\n                dT_tobj_pobj = tobbs[too_big].T_world_object.inverse() @ pT_wo\n                xi_tobj_pobj[too_big] = dT_tobj_pobj.log()\n                # also permute the diagonal according to the 90 deg rotation\n                pdiag[too_big] = (\n                    (R90min @ pdiag[too_big].unsqueeze(-1)).squeeze(-1).abs()\n                )\n            # apply updates\n            ppT_world_object = tobbs.T_world_object @ PoseTW.exp(xi_tobj_pobj / wpp)\n            # running average over scale / diagonal of obb\n            ppdiag = (tobbs.bb3_diagonal * self.w[tids].unsqueeze(-1) + pdiag) / wpp\n            ppbb3 = bb3_xyzxyz_to_xxyyzz(\n                torch.cat([-ppdiag * 0.5, ppdiag * 0.5], dim=-1)\n            )\n            # running average over prob\n            ppprob = (tobbs.prob * self.w[tids].unsqueeze(-1) + pobbs.prob) / wpp\n\n            if pprobs_full is not None:\n                # running average over the full probability distribution\n                pprobs_full = (\n                    self.scene_probs_full[tids] * self.w[tids].unsqueeze(-1)\n                    + pprobs_full\n                ) / wpp\n       
         self.scene_probs_full[tids] = pprobs_full\n\n            # update target parameters\n            tobbs.set_T_world_object(ppT_world_object)\n            tobbs.set_bb3_object(ppbb3)\n            tobbs.set_prob(ppprob.squeeze(-1))\n            self.scene_obbs_w._data[tids] = tobbs._data\n            # update weights\n            self.w[tids] = wpp.clamp(max=self.w_max).squeeze(-1)\n            # update times\n            self.last_obs_time[tids] = self.time\n\n            # update counts as probabilities\n            if self.counts_as_prob:\n                scene_obbs = self.scene_obbs_w[tids].clone()\n                scene_obbs.set_prob(self.w[tids])\n                self.scene_obbs_w._data[tids] = scene_obbs._data\n\n        # add new obbs\n        if new_obbs.shape[0] > 0:\n            self.add_new_obbs(new_obbs, new_probs_full)\n\n        # update last possible obs time based on 2d visibility\n        self.update_last_obs_time(cam, T_world_rig)\n\n        # update time\n        self.time += 1.0 / self.hz\n        return self.obbs_world\n\n    def update_last_obs_time(self, cam, T_world_rig):\n        if cam is None or T_world_rig is None:\n            # mark all visible\n            self.last_possible_obs_time[:] = self.time\n            return\n        # compute visibility of scene obbs:\n        # - at least 50% of object has to be in 2d bb\n        # - 2d bb has to be at least 100 pixel area and each side has to be at least 10 pixels\n        bb2s, _, frac = self.scene_obbs_w.get_pseudo_bb2(\n            cam.unsqueeze(0), T_world_rig.unsqueeze(0), 10, return_frac_valids=True\n        )\n        bb2s, frac = bb2s.squeeze(0), frac.squeeze(0)\n        visible = frac > 0.5\n        # area = box_area(bb2_xxyy_to_xyxy(bb2s))\n        # visible = torch.logical_and(visible, area > 100)\n        visible = torch.logical_and(visible, bb2s[..., 1] - bb2s[..., 0] > 10)\n        visible = torch.logical_and(visible, bb2s[..., 3] - bb2s[..., 2] > 10)\n        # 
update last possible times\n        self.last_possible_obs_time[visible] = self.time\n\n    def add_new_obbs(self, new_obbs, new_probs_full):\n        new_w = torch.ones(new_obbs.shape[0], device=new_obbs.device)\n        new_obbs_time = self.time * torch.ones(\n            new_obbs.shape[0], device=new_obbs.device\n        )\n        # Set instance ids for new obbs\n        new_obbs.set_inst_id(\n            torch.arange(\n                self.num_instances_so_far,\n                self.num_instances_so_far + new_obbs.shape[0],\n                device=new_obbs.device,\n            )\n        )\n        # Increment number of instances we have seen so far\n        self.num_instances_so_far += new_obbs.shape[0]\n\n        if self.scene_obbs_w is None:\n            self.scene_obbs_w = new_obbs\n            self.scene_probs_full = new_probs_full\n            self.w = new_w\n            self.last_obs_time = new_obbs_time\n            self.last_possible_obs_time = new_obbs_time.clone()\n        else:\n            self.scene_obbs_w = torch.cat([self.scene_obbs_w, new_obbs], dim=0)\n            self.scene_probs_full = torch.cat(\n                [self.scene_probs_full, new_probs_full], dim=0\n            )\n            self.w = torch.cat([self.w, new_w])\n            self.last_obs_time = torch.cat([self.last_obs_time, new_obbs_time])\n            self.last_possible_obs_time = torch.cat(\n                [self.last_possible_obs_time, new_obbs_time]\n            )\n\n    def nms_3d(self, obbs):\n        obbs_keep, obbs_rm, _ = nms_3d(obbs, self.nms_iou3_thr, self.nms_iou3_max_thr)\n        return obbs_keep, obbs_rm\n\n    def nms_2d(self, obbs):\n        return nms_2d(obbs, self.nms_iou2_thr)\n\n    def set_2d_bbs(self, obbs_w: ObbTW, cam: CameraTW, T_world_rig: PoseTW):\n        if cam is None or T_world_rig is None:\n            return obbs_w\n        if obbs_w.shape[0] > 0:\n            bb2s, valids, frac = obbs_w.get_pseudo_bb2(\n                cam.unsqueeze(0), 
T_world_rig.unsqueeze(0), 10, return_frac_valids=True\n            )\n            invisible = ~valids  # frac < 0.1\n            bb2s[invisible] = -1.0\n            obbs_w.set_bb2(0, bb2s.squeeze(0))\n        if self.scene_obbs_w is not None and self.scene_obbs_w.shape[0] > 0:\n            bb2s, valids, frac = self.scene_obbs_w.get_pseudo_bb2(\n                cam.unsqueeze(0), T_world_rig.unsqueeze(0), 10, return_frac_valids=True\n            )\n            invisible = ~valids  # frac < 0.1\n            bb2s[invisible] = -1.0\n            self.scene_obbs_w.set_bb2(0, bb2s.squeeze(0))\n        return obbs_w\n"
  },
  {
    "path": "efm3d/utils/obb_utils.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nfrom dataclasses import dataclass\nfrom typing import Any, Callable, Dict, List, Optional, Sequence, Tuple\n\nimport torch\nfrom efm3d.aria.obb import ObbTW\nfrom pytorch3d.ops.iou_box3d import (\n    _box3d_overlap,\n    _box_planes,\n    _box_triangles,\n    _check_nonzero,\n)\nfrom torch import IntTensor, Tensor\nfrom torch.nn import functional as F\nfrom torchmetrics.detection.mean_ap import (\n    _fix_empty_tensors,\n    BaseMetricResults,\n    MARMetricResults,\n    MeanAveragePrecision,\n)\nfrom torchmetrics.utilities.imports import _TORCHVISION_GREATER_EQUAL_0_8\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n# logger.setLevel(logging.DEBUG)\n\n\n@dataclass\nclass IouOutputs:\n    vol: torch.Tensor\n    iou: torch.Tensor\n\n\ndef input_validator_box3d(  # noqa\n    preds: Sequence[Dict[str, Tensor]], targets: Sequence[Dict[str, Tensor]]\n) -> None:\n    \"\"\"Ensure the correct input format of `preds` and `targets`\"\"\"\n    if not isinstance(preds, Sequence):\n        raise ValueError(\"Expected argument `preds` to be of type Sequence\")\n    if not isinstance(targets, Sequence):\n        raise ValueError(\"Expected argument `target` to be of type Sequence\")\n    if len(preds) != len(targets):\n        raise ValueError(\n            \"Expected argument `preds` and `target` to have 
the same length\"\n        )\n\n    for k in [\"boxes\", \"scores\", \"labels\"]:\n        if any(k not in p for p in preds):\n            raise ValueError(f\"Expected all dicts in `preds` to contain the `{k}` key\")\n\n    for k in [\"boxes\", \"labels\"]:\n        if any(k not in p for p in targets):\n            raise ValueError(f\"Expected all dicts in `target` to contain the `{k}` key\")\n\n    if any(type(pred[\"boxes\"]) is not Tensor for pred in preds):\n        raise ValueError(\"Expected all boxes in `preds` to be of type Tensor\")\n    if any(type(pred[\"scores\"]) is not Tensor for pred in preds):\n        raise ValueError(\"Expected all scores in `preds` to be of type Tensor\")\n    if any(type(pred[\"labels\"]) is not Tensor for pred in preds):\n        raise ValueError(\"Expected all labels in `preds` to be of type Tensor\")\n    if any(type(target[\"boxes\"]) is not Tensor for target in targets):\n        raise ValueError(\"Expected all boxes in `target` to be of type Tensor\")\n    if any(type(target[\"labels\"]) is not Tensor for target in targets):\n        raise ValueError(\"Expected all labels in `target` to be of type Tensor\")\n\n    for i, item in enumerate(targets):\n        if item[\"boxes\"].size(0) != item[\"labels\"].size(0):\n            raise ValueError(\n                f\"Input boxes and labels of sample {i} in targets have a\"\n                f\" different length (expected {item['boxes'].size(0)} labels, got {item['labels'].size(0)})\"\n            )\n        if item[\"boxes\"].shape[-2:] != (8, 3):\n            raise ValueError(\n                f\"Input boxes of sample {i} in targets have a\"\n                f\" wrong shape (expected (...,8, 3), got {item['boxes'].shape})\"\n            )\n    for i, item in enumerate(preds):\n        if not (\n            item[\"boxes\"].size(0) == item[\"labels\"].size(0) == item[\"scores\"].size(0)\n        ):\n            raise ValueError(\n                f\"Input boxes, labels and scores 
of sample {i} in predictions have a\"\n                f\" different length (expected {item['boxes'].size(0)} labels and scores,\"\n                f\" got {item['labels'].size(0)} labels and {item['scores'].size(0)})\"\n            )\n\n\nclass MAPMetricResults3D(BaseMetricResults):\n    \"\"\"Class to wrap the final mAP results.\"\"\"\n\n    __slots__ = (\n        \"map\",\n        \"map_25\",\n        \"map_50\",\n        \"map_small\",\n        \"map_medium\",\n        \"map_large\",\n    )\n\n\ndef box3d_volume(boxes: Tensor) -> Tensor:\n    \"\"\"\n    Computes the volume of a set of 3d bounding boxes.\n\n    Args:\n        boxes (Tensor[N, 8, 3]): 3d boxes for which the volume will be computed.\n\n    Returns:\n        Tensor[N]: the volume for each box\n    \"\"\"\n    if boxes.numel() == 0:\n        return torch.zeros(0).to(boxes)\n    # Per-box scalar triple product |(a x b) . c| of the three edges at corner 0\n    a = boxes[:, 1, :] - boxes[:, 0, :]\n    b = boxes[:, 3, :] - boxes[:, 0, :]\n    c = boxes[:, 4, :] - boxes[:, 0, :]\n    vol = torch.abs((torch.cross(a, b, dim=-1) * c).sum(dim=-1))\n    return vol\n\n\ndef box3d_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor:\n    \"\"\"\n    Convert 3d box coordinate conventions.\n    \"\"\"\n    assert in_fmt == \"xyz8\"\n    assert out_fmt == \"xyz8\"\n    return boxes\n\n\nclass MeanAveragePrecision3D(MeanAveragePrecision):\n    def __init__(\n        self,\n        box_format: str = \"xyz8\",\n        bbox_area_ranges: Optional[Dict[str, Tuple[float, float]]] = None,\n        iou_thresholds: Optional[List[float]] = None,\n        rec_thresholds: Optional[List[float]] = None,\n        max_detection_thresholds: Optional[List[int]] = None,\n        class_metrics: bool = False,  # compute per class metrics\n        compute_on_step: bool = True,\n        dist_sync_on_step: bool = False,\n        process_group: Optional[Any] = None,\n        dist_sync_fn: Callable = None,\n        ret_all_prec_rec: bool = False,\n    ) -> None:  # type: 
ignore\n        # Use Omni3D iOU thresholds by default\n        iou_thresholds = (\n            iou_thresholds\n            or torch.linspace(\n                0.05, 0.5, round((0.5 - 0.05) / 0.05) + 1, dtype=torch.float64\n            ).tolist()\n        )\n        rec_thresholds = (\n            rec_thresholds\n            or torch.linspace(\n                0.0, 1.00, round(1.00 / 0.01) + 1, dtype=torch.float64\n            ).tolist()\n        )\n        super().__init__(\n            iou_thresholds=iou_thresholds,\n            rec_thresholds=rec_thresholds,\n            compute_on_step=compute_on_step,\n            dist_sync_on_step=dist_sync_on_step,\n            process_group=process_group,\n            dist_sync_fn=dist_sync_fn,\n        )\n\n        if not _TORCHVISION_GREATER_EQUAL_0_8:\n            raise ModuleNotFoundError(\n                \"`MeanAveragePrecision` metric requires that `torchvision` version 0.8.0 or newer is installed.\"\n                \" Please install with `pip install torchvision>=0.8` or `pip install torchmetrics[detection]`.\"\n            )\n        allowed_box_formats = [\"xyz8\"]\n        if box_format not in allowed_box_formats:\n            raise ValueError(\n                f\"Expected argument `box_format` to be one of {allowed_box_formats} but got {box_format}\"\n            )\n        self.box_format = box_format\n        max_det_thr, _ = torch.sort(IntTensor(max_detection_thresholds or [1, 10, 100]))\n        self.max_detection_thresholds = max_det_thr.tolist()\n\n        if not isinstance(class_metrics, bool):\n            raise ValueError(\"Expected argument `class_metrics` to be a boolean\")\n\n        self.class_metrics = class_metrics\n        # important to overwrite after the __init__() call since they are otherwise overwritten by super().__init__()\n        self.bbox_area_ranges = bbox_area_ranges\n        if bbox_area_ranges is None:\n            self.bbox_area_ranges = {\"all\": (0, 1e5)}\n\n        
self.ret_all_prec_rec = ret_all_prec_rec\n        self.eval_imgs = [] if self.ret_all_prec_rec else None\n\n    def update(\n        self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]\n    ) -> None:  # type: ignore\n        \"\"\"Add detections and ground truth to the metric.\n\n        Args:\n            preds: A list consisting of dictionaries each containing the key-values\n            (each dictionary corresponds to a single image):\n            - ``boxes``: ``torch.FloatTensor`` of shape\n                [num_boxes, 8, 3] containing `num_boxes` detection boxes of the format\n                specified in the constructor. By default, this method expects\n                (4) +---------+. (5)\n                    | ` .     |  ` .\n                    | (0) +---+-----+ (1)\n                    |     |   |     |\n                (7) +-----+---+. (6)|\n                    ` .   |     ` . |\n                    (3) ` +---------+ (2)\n                box_corner_vertices = [\n                    [xmin, ymin, zmin],\n                    [xmax, ymin, zmin],\n                    [xmax, ymax, zmin],\n                    [xmin, ymax, zmin],\n                    [xmin, ymin, zmax],\n                    [xmax, ymin, zmax],\n                    [xmax, ymax, zmax],\n                    [xmin, ymax, zmax],\n                ]\n            - ``scores``: ``torch.FloatTensor`` of shape\n                [num_boxes] containing detection scores for the boxes.\n            - ``labels``: ``torch.IntTensor`` of shape\n                [num_boxes] containing 0-indexed detection classes for the boxes.\n\n            target: A list consisting of dictionaries each containing the key-values\n            (each dictionary corresponds to a single image):\n            - ``boxes``: ``torch.FloatTensor`` of shape\n                [num_boxes, 8, 3] containing `num_boxes` ground truth boxes of the format\n                specified in the constructor.\n            - ``labels``: 
``torch.IntTensor`` of shape\n                [num_boxes] containing 1-indexed ground truth classes for the boxes.\n\n        Raises:\n            ValueError:\n                If ``preds`` is not of type List[Dict[str, Tensor]]\n            ValueError:\n                If ``target`` is not of type List[Dict[str, Tensor]]\n            ValueError:\n                If ``preds`` and ``target`` are not of the same length\n            ValueError:\n                If any of ``preds.boxes``, ``preds.scores``\n                and ``preds.labels`` are not of the same length\n            ValueError:\n                If any of ``target.boxes`` and ``target.labels`` are not of the same length\n            ValueError:\n                If any box is not type float and of length 4\n            ValueError:\n                If any class is not type int and of length 1\n            ValueError:\n                If any score is not type float and of length 1\n        \"\"\"\n        input_validator_box3d(preds, target)\n\n        for item in preds:\n            boxes = _fix_empty_tensors(item[\"boxes\"])\n            boxes = box3d_convert(boxes, in_fmt=self.box_format, out_fmt=\"xyz8\")\n            if hasattr(self, \"detection_boxes\"):\n                self.detection_boxes.append(boxes)\n            else:\n                self.detections.append(boxes)\n\n            self.detection_labels.append(item[\"labels\"])\n            self.detection_scores.append(item[\"scores\"])\n\n        for item in target:\n            boxes = _fix_empty_tensors(item[\"boxes\"])\n            boxes = box3d_convert(boxes, in_fmt=self.box_format, out_fmt=\"xyz8\")\n            if hasattr(self, \"groundtruth_boxes\"):\n                self.groundtruth_boxes.append(boxes)\n            else:\n                self.groundtruths.append(boxes)\n            self.groundtruth_labels.append(item[\"labels\"])\n\n    def _compute_iou(self, id: int, class_id: int, max_det: int) -> Tensor:\n        \"\"\"Computes the 
Intersection over Union (IoU) for ground truth and detection bounding boxes for the given\n        image and class.\n\n        Args:\n            id:\n                Image Id, equivalent to the index of supplied samples\n            class_id:\n                Class Id of the supplied ground truth and detection labels\n            max_det:\n                Maximum number of evaluated detection bounding boxes\n        \"\"\"\n        if hasattr(self, \"detection_boxes\"):\n            gt = self.groundtruth_boxes[id]\n            det = self.detection_boxes[id]\n        else:\n            gt = self.groundtruths[id]\n            det = self.detections[id]\n        gt_label_mask = self.groundtruth_labels[id] == class_id\n        det_label_mask = self.detection_labels[id] == class_id\n        if len(gt_label_mask) == 0 or len(det_label_mask) == 0:\n            return Tensor([])\n        gt = gt[gt_label_mask]\n        det = det[det_label_mask]\n        if len(gt) == 0 or len(det) == 0:\n            return Tensor([])\n\n        # Sort by scores and use only max detections\n        scores = self.detection_scores[id]\n        scores_filtered = scores[self.detection_labels[id] == class_id]\n        inds = torch.argsort(scores_filtered, descending=True)\n        det = det[inds]\n        if len(det) > max_det:\n            det = det[:max_det]\n\n        # generalized_box_iou\n        # both det and gt are List of \"boxes\"\n        ious = box3d_overlap_wrapper(det, gt).iou\n        return ious\n\n    def _evaluate_image(\n        self,\n        id: int,\n        class_id: int,\n        area_range: Tuple[int, int],\n        max_det: int,\n        ious: dict,\n    ) -> Optional[dict]:\n        \"\"\"Perform evaluation for single class and image.\n\n        Args:\n            id:\n                Image Id, equivalent to the index of supplied samples.\n            class_id:\n                Class Id of the supplied ground truth and detection labels.\n            area_range:\n       
         List of lower and upper bounding box area threshold.\n            max_det:\n                Maximum number of evaluated detection bounding boxes.\n            ious:\n                IoU results for image and class.\n        \"\"\"\n        if hasattr(self, \"detection_boxes\"):\n            gt = self.groundtruth_boxes[id]\n            det = self.detection_boxes[id]\n        else:\n            gt = self.groundtruths[id]\n            det = self.detections[id]\n\n        gt_label_mask = self.groundtruth_labels[id] == class_id\n        det_label_mask = self.detection_labels[id] == class_id\n        if len(gt_label_mask) == 0 or len(det_label_mask) == 0:\n            return None\n        gt = gt[gt_label_mask]\n        det = det[det_label_mask]\n        if len(gt) == 0 and len(det) == 0:\n            return None\n\n        areas = box3d_volume(gt)\n        ignore_area = (areas < area_range[0]) | (areas > area_range[1])\n\n        # sort dt highest score first, sort gt ignore last\n        ignore_area_sorted, gtind = torch.sort(ignore_area.to(torch.uint8))\n        # Convert to uint8 temporarily and back to bool, because \"Sort currently does not support bool dtype on CUDA\"\n        ignore_area_sorted = ignore_area_sorted.to(torch.bool)\n        gt = gt[gtind]\n        scores = self.detection_scores[id]\n        scores_filtered = scores[det_label_mask]\n        scores_sorted, dtind = torch.sort(scores_filtered, descending=True)\n        det = det[dtind]\n        if len(det) > max_det:\n            det = det[:max_det]\n        # load computed ious\n        ious = (\n            ious[id, class_id][:, gtind]\n            if len(ious[id, class_id]) > 0\n            else ious[id, class_id]\n        )\n\n        nb_iou_thrs = len(self.iou_thresholds)\n        nb_gt = len(gt)\n        nb_det = len(det)\n        gt_matches = torch.zeros(\n            (nb_iou_thrs, nb_gt), dtype=torch.bool, device=det.device\n        )\n        det_matches = torch.zeros(\n            
(nb_iou_thrs, nb_det), dtype=torch.bool, device=det.device\n        )\n        gt_ignore = ignore_area_sorted\n        det_ignore = torch.zeros(\n            (nb_iou_thrs, nb_det), dtype=torch.bool, device=det.device\n        )\n\n        if torch.numel(ious) > 0:\n            for idx_iou, t in enumerate(self.iou_thresholds):\n                for idx_det, _ in enumerate(det):\n                    m = MeanAveragePrecision._find_best_gt_match(\n                        t, gt_matches, idx_iou, gt_ignore, ious, idx_det\n                    )\n                    if m != -1:\n                        det_ignore[idx_iou, idx_det] = gt_ignore[m]\n                        det_matches[idx_iou, idx_det] = 1\n                        gt_matches[idx_iou, m] = 1\n\n        # set unmatched detections outside of area range to ignore\n        det_areas = box3d_volume(det)\n        det_ignore_area = (det_areas < area_range[0]) | (det_areas > area_range[1])\n        ar = det_ignore_area.reshape((1, nb_det))\n        det_ignore = torch.logical_or(\n            det_ignore,\n            torch.logical_and(\n                det_matches == 0, torch.repeat_interleave(ar, nb_iou_thrs, 0)\n            ),\n        )\n        det_matches = det_matches.cpu()\n        gt_matches = gt_matches.cpu()\n        scores_sorted = scores_sorted.cpu()\n        gt_ignore = gt_ignore.cpu()\n        det_ignore = det_ignore.cpu()\n\n        ret = {\n            \"dtMatches\": det_matches,\n            \"gtMatches\": gt_matches,\n            \"dtScores\": scores_sorted,\n            \"gtIgnore\": gt_ignore,\n            \"dtIgnore\": det_ignore,\n        }\n\n        if self.ret_all_prec_rec:\n            self.eval_imgs.append(ret)\n\n        return ret\n\n    def _summarize_results(\n        self, precisions: Tensor, recalls: Tensor\n    ) -> Tuple[MAPMetricResults3D, MARMetricResults]:\n        \"\"\"Summarizes the precision and recall values to calculate mAP/mAR.\n\n        Args:\n            precisions:\n      
          Precision values for different thresholds\n            recalls:\n                Recall values for different thresholds\n        \"\"\"\n        results = dict(precision=precisions, recall=recalls)\n        map_metrics = MAPMetricResults3D()\n        last_max_det_thr = self.max_detection_thresholds[-1]\n        map_metrics.map = self._summarize(results, True, max_dets=last_max_det_thr)\n        if 0.25 in self.iou_thresholds:\n            map_metrics.map_25 = self._summarize(\n                results, True, iou_threshold=0.25, max_dets=last_max_det_thr\n            )\n        if 0.5 in self.iou_thresholds:\n            map_metrics.map_50 = self._summarize(\n                results, True, iou_threshold=0.5, max_dets=last_max_det_thr\n            )\n\n        mar_metrics = MARMetricResults()\n        for max_det in self.max_detection_thresholds:\n            mar_metrics[f\"mar_{max_det}\"] = self._summarize(\n                results, False, max_dets=max_det\n            )\n\n        if \"small\" in self.bbox_area_ranges:\n            map_metrics.map_small = self._summarize(\n                results, True, area_range=\"small\", max_dets=last_max_det_thr\n            )\n            mar_metrics.mar_small = self._summarize(\n                results, False, area_range=\"small\", max_dets=last_max_det_thr\n            )\n        if \"medium\" in self.bbox_area_ranges:\n            map_metrics.map_medium = self._summarize(\n                results, True, area_range=\"medium\", max_dets=last_max_det_thr\n            )\n            mar_metrics.mar_medium = self._summarize(\n                results, False, area_range=\"medium\", max_dets=last_max_det_thr\n            )\n        if \"large\" in self.bbox_area_ranges:\n            map_metrics.map_large = self._summarize(\n                results, True, area_range=\"large\", max_dets=last_max_det_thr\n            )\n            mar_metrics.mar_large = self._summarize(\n                results, False, 
area_range=\"large\", max_dets=last_max_det_thr\n            )\n\n        return map_metrics, mar_metrics\n\n    def compute(self, sem_id_to_name_mapping: Optional[Dict[int, str]] = None) -> dict:\n        metrics = MeanAveragePrecision.compute(self)\n        final_results = {}\n\n        # resemble class-based results.\n        if self.class_metrics:\n            seen_classes = self._get_classes()\n            if sem_id_to_name_mapping is None:\n                logger.warning(\"No sem_id to name mapping. Falling back on id=name\")\n                sem_id_to_name_mapping = {\n                    sem_id: str(sem_id) for sem_id in seen_classes\n                }\n\n            for k, v in metrics.items():\n                # Deal with per-class metrics\n                if \"per_class\" in k:\n                    # populate per class numbers\n                    mapped, unmapped = set(), set()\n                    for idx, pcr in enumerate(v):\n                        if seen_classes[idx] not in sem_id_to_name_mapping:\n                            unmapped.add(seen_classes[idx])\n                        else:\n                            mapped.add(seen_classes[idx])\n                            final_results[\n                                f\"{k}@{sem_id_to_name_mapping[seen_classes[idx]]}\"\n                            ] = pcr\n                    if len(unmapped) > 0:\n                        logger.warning(\n                            f\"Mapped sem_ids {mapped} but DID NOT MAP sem_ids {unmapped}\"\n                        )\n                else:\n                    final_results[k] = v\n        else:\n            final_results = metrics\n        return final_results\n\n\ndef coplanar_mask(boxes: torch.Tensor, eps: float = 1e-4) -> None:\n    faces = torch.tensor(_box_planes, dtype=torch.int64, device=boxes.device)\n    verts = boxes.index_select(index=faces.view(-1), dim=1)\n    B = boxes.shape[0]\n    P, V = faces.shape\n    # (B, P, 4, 3) -> (B, P, 3)\n    
v0, v1, v2, v3 = verts.reshape(B, P, V, 3).unbind(2)\n\n    # Compute the normal\n    e0 = F.normalize(v1 - v0, dim=-1)\n    e1 = F.normalize(v2 - v0, dim=-1)\n    normal = F.normalize(torch.cross(e0, e1, dim=-1), dim=-1)\n\n    # Check the fourth vertex is also on the same plane\n    mat1 = (v3 - v0).view(B, 1, -1)  # (B, 1, P*3)\n    mat2 = normal.view(B, -1, 1)  # (B, P*3, 1)\n\n    good = (mat1.bmm(mat2).abs() < eps).view(-1)\n    return good\n\n\ndef nonzero_area_mask(boxes: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:\n    \"\"\"\n    Checks that the sides of the box have a non-zero area.\n    \"\"\"\n    faces = torch.tensor(_box_triangles, dtype=torch.int64, device=boxes.device)\n    verts = boxes.index_select(index=faces.view(-1), dim=1)\n    B = boxes.shape[0]\n    T, V = faces.shape\n    # (B, T, 3, 3) -> (B, T, 3)\n    v0, v1, v2 = verts.reshape(B, T, V, 3).unbind(2)\n\n    normals = torch.cross(v1 - v0, v2 - v0, dim=-1)  # (B, T, 3)\n    face_areas = normals.norm(dim=-1) / 2\n    return (face_areas > eps).all(-1)\n\n\ndef bb3_valid(boxes: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:\n    \"\"\"\n    Checks that the box is valid.\n    \"\"\"\n    # Check that the box is not degenerate\n    return nonzero_area_mask(boxes, eps) & coplanar_mask(boxes, eps)\n\n\ndef box3d_overlap_wrapper(\n    boxes1: torch.Tensor, boxes2: torch.Tensor, eps: float = 1e-3\n) -> IouOutputs:\n    \"\"\"\n    Only compute IoUs and volumes for good boxes and recompose with 0s for all bad boxes.\n    This is better because it can handle the case where only a subset of boxes is bad. 
But it costs more compute.\n    \"\"\"\n    if not all((8, 3) == box.shape[1:] for box in [boxes1, boxes2]):\n        raise ValueError(\"Each box in the batch must be of shape (8, 3)\")\n    m1 = bb3_valid(boxes1, eps)\n    m2 = bb3_valid(boxes2, eps)\n    b1_good = boxes1[m1]\n    b2_good = boxes2[m2]\n    vol = torch.zeros(boxes1.shape[0], boxes2.shape[0], device=boxes1.device)\n    iou = torch.zeros_like(vol)\n    if b1_good.shape[0] == 0 or b2_good.shape[0] == 0:\n        logger.info(\"no valid bbs returning 0 volumes and ious\")\n    else:\n        try:\n            vol_good, iou_good = _box3d_overlap.apply(b1_good, b2_good)\n            m_good = m1.unsqueeze(-1) & m2.unsqueeze(0)\n            vol[m_good] = vol_good.view(-1)\n            iou[m_good] = iou_good.view(-1)\n        except Exception:\n            logger.exception(\"returning 0 volumes and ious because of an exception\")\n    return IouOutputs(vol=vol, iou=iou)\n\n\ndef remove_invalid_box3d(obbs: ObbTW, mark_in_place: bool = False) -> torch.Tensor:\n    boxes = obbs.bb3corners_world\n    assert boxes.dim() == 3\n    assert (8, 3) == boxes.shape[1:]\n    valid_ind, invalid_ind = [], []\n    for b in range(boxes.shape[0]):\n        try:\n            # no need for co planarity check since our obbs are good by construction.\n            # _check_coplanar(boxes[b : b + 1, :, :])\n            _check_nonzero(boxes[b : b + 1, :, :])\n            valid_ind.append(b)\n        except Exception:\n            invalid_ind.append(b)\n\n    if mark_in_place:\n        obbs._mark_invalid_ids(torch.tensor(invalid_ind, dtype=torch.long))\n        return valid_ind\n    return obbs[valid_ind], valid_ind\n\n\ndef prec_recall_bb3(\n    padded_pred: ObbTW,\n    padded_target: ObbTW,\n    iou_thres=0.2,\n    return_ious=False,\n    per_class=False,\n):\n    \"\"\"Compute precision and recall based on 3D IoU.\"\"\"\n    assert padded_pred.ndim == 2 and padded_target.ndim == 2, (\n        f\"input ObbTWs must be Nx34, but got 
{padded_pred.shape} and {padded_target.shape}\"\n    )\n\n    pred = padded_pred.remove_padding()\n    target = padded_target.remove_padding()\n    pred_shape = pred.shape\n    target_shape = target.shape\n\n    pred, _ = remove_invalid_box3d(pred)\n    target, _ = remove_invalid_box3d(target)\n    if pred.shape != pred_shape:\n        logging.warning(\n            f\"Warning: predicted obbs filtered from {pred_shape[0]} to {pred.shape[0]}\"\n        )\n    if target.shape != target_shape:\n        logging.warning(\n            f\"Warning: target obbs filtered from {target_shape[0]} to {target.shape[0]}\"\n        )\n\n    prec_recall = (-1.0, -1.0, None)\n    # deal with edge cases first\n    if pred.shape[0] == 0:\n        # invalid precision and 0 recall\n        prec_recall = (-1.0, 0.0, None)\n        return prec_recall\n    elif target.shape[0] == 0:\n        # invalid recall and 0 precision\n        prec_recall = (0.0, -1.0, None)\n        return prec_recall\n\n    pred_sems = pred.sem_id\n    target_sems = target.sem_id.squeeze(-1).unsqueeze(0)\n    # 1. Match classes\n    sem_id_match = pred_sems == target_sems\n    # 2. Match IoUs\n    ious = box3d_overlap_wrapper(pred.bb3corners_world, target.bb3corners_world).iou\n    iou_match = ious > iou_thres\n    # 3. Match both\n    sem_iou_match = torch.logical_and(sem_id_match, iou_match)\n    # make final matching matrix\n    final_sem_iou_match = torch.zeros_like(sem_iou_match).bool()\n    num_pred = sem_iou_match.shape[0]  # TP + FP\n    num_target = sem_iou_match.shape[1]  # TP + FN\n    # 4. 
Deal with the case where one prediction correspond to multiple GTs.\n    # In this case, only the GT with highest IoU is considered the match.\n    for pred_idx in range(int(num_pred)):\n        if sem_iou_match[pred_idx, :].sum() <= 1:\n            final_sem_iou_match[pred_idx, :] = sem_iou_match[pred_idx, :].clone()\n        else:\n            tgt_ious = ious[pred_idx, :].clone()\n            tgt_ious[~sem_iou_match[pred_idx, :]] = -1.0\n            sorted_ids = torch.argsort(tgt_ious, descending=True)\n            tp_id = sorted_ids[0]\n            # Set the pred with highest iou\n            final_sem_iou_match[pred_idx, :] = False\n            final_sem_iou_match[pred_idx, tp_id] = True\n\n    # 5. Deal with the case where one GT correspond to multiple predictions.\n    # In this case, if the predictions contain probabilities, we take the one with the highest score, otherwise, we take the one with the highest iou.\n    for gt_idx in range(int(num_target)):\n        if final_sem_iou_match[:, gt_idx].sum() <= 1:\n            continue\n        else:\n            pred_scores = pred.prob.squeeze(-1).clone()\n            if torch.all(pred_scores.eq(-1.0)):\n                # go with highest iou\n                pred_ious = ious[:, gt_idx].clone()\n                pred_ious[~final_sem_iou_match[:, gt_idx]] = -1.0\n                sorted_ids = torch.argsort(pred_ious, descending=True)\n                tp_id = sorted_ids[0]\n                # Set the pred with highest iou\n                final_sem_iou_match[:, gt_idx] = False\n                final_sem_iou_match[tp_id, gt_idx] = True\n            else:\n                # go with the highest score\n                pred_scores[~final_sem_iou_match[:, gt_idx]] = -1.0\n                sorted_ids = torch.argsort(pred_scores, descending=True)\n                tp_id = sorted_ids[0]\n                final_sem_iou_match[:, gt_idx] = False\n                final_sem_iou_match[tp_id, gt_idx] = True\n\n    TPs = 
final_sem_iou_match.any(-1)\n    # precision = TP / (TP + FP) = TP / #Preds\n    num_tp = TPs.sum().item()\n    prec = num_tp / num_pred\n    # recall = TP / (TP + FN) = TP / #GTs\n    rec = num_tp / num_target\n\n    ret = (prec, rec, final_sem_iou_match)\n    if return_ious:\n        ret = ret + (ious,)\n\n    if per_class:\n        # per class prec and recalls\n        per_class_results = {}\n        all_sems = torch.cat([pred_sems.squeeze(-1), target_sems.squeeze(0)], dim=0)\n        unique_classes = torch.unique(all_sems.squeeze(-1))\n        for sem_id in unique_classes:\n            pred_obbs_sem = pred_sems.squeeze(-1) == sem_id\n            TPs_sem = (TPs & pred_obbs_sem).sum().item()\n            num_pred_sem = pred_obbs_sem.sum().item()\n            gt_obbs_sem = target_sems.squeeze(0) == sem_id\n            num_gt_sem = gt_obbs_sem.sum().item()\n            prec_sem = TPs_sem / num_pred_sem if num_pred_sem > 0 else -1.0\n            rec_sem = TPs_sem / num_gt_sem if num_gt_sem > 0 else -1.0\n            per_class_results[sem_id] = {}\n            per_class_results[sem_id][\"num_true_positives\"] = TPs_sem\n            per_class_results[sem_id][\"num_dets\"] = num_pred_sem\n            per_class_results[sem_id][\"num_gts\"] = num_gt_sem\n            per_class_results[sem_id][\"precision\"] = prec_sem\n            per_class_results[sem_id][\"recall\"] = rec_sem\n        ret = ret + (per_class_results,)\n\n    return ret\n\n\ndef prec_recall_curve(\n    pred_gt_pairs: List[Tuple[ObbTW, ObbTW]], iou_thres=0.2, interp=True\n):\n    # get all probs\n    probs = torch.empty(0)\n    for pred, _ in pred_gt_pairs:\n        pred_no_padding = pred.cpu().remove_padding()\n        ps = pred_no_padding.prob.squeeze(-1)\n        probs = torch.concatenate([probs, ps])\n\n    # truncate\n    probs = (probs * 100).int() / 100.0\n    # combine too close probs\n    probs = torch.unique(probs)\n    probs = probs.tolist()\n    probs.sort(reverse=True)\n\n    precs = []\n    
recalls = []\n\n    eps = 1e-6\n    for prob in probs:\n        tps = 0\n        dets = 0\n        gts = 0\n        for pred, gt in pred_gt_pairs:\n            pred_no_padding = pred.remove_padding()\n            gt_no_padding = gt.remove_padding()\n            # thresholding\n            pred_no_padding = pred_no_padding[pred_no_padding.prob.squeeze(-1) >= prob]\n            dets += pred_no_padding.shape[0]\n            gts += gt_no_padding.shape[0]\n            pred_no_padding = (\n                pred_no_padding.cuda() if torch.cuda.is_available() else pred_no_padding\n            )\n            gt_no_padding = (\n                gt_no_padding.cuda() if torch.cuda.is_available() else gt_no_padding\n            )\n            _, _, match_mat = prec_recall_bb3(\n                pred_no_padding, gt_no_padding, iou_thres=iou_thres\n            )\n            if match_mat is None:\n                continue\n            tps += match_mat.any(-1).sum().item()\n        prec = tps / (dets + eps)\n        rec = tps / (gts + eps)\n        precs.append(prec)\n        recalls.append(rec)\n\n    if interp:\n        precs = torch.Tensor(precs)\n        precs_interp = []\n        for idx, _ in enumerate(precs):\n            precs_interp.append(precs[idx:].max().item())\n        precs = precs_interp\n    return precs, recalls, probs\n\n\ndef draw_prec_recall_curve(\n    prec: List,\n    recall: List,\n    save_folder: str,\n    name: str = \"pr_curve.png\",\n    iou_thres: Optional[float] = None,\n):\n    import matplotlib.pyplot as plt\n\n    fig_title = \"Prec-Recall Curve\"\n    if iou_thres is not None:\n        fig_title += f\" @IoU={iou_thres:.2f}\"\n    figure_path = os.path.join(save_folder, name)\n    plt.figure(figsize=(4, 4))\n    plt.title(fig_title)\n    plt.xlim([0, 1.1])\n    plt.ylim([0, 1.1])\n    plt.xlabel(\"recall\")\n    plt.ylabel(\"precision\")\n    # append prec recall if the last recall is not 1\n    if recall[-1] != 1:\n        prec.append(0)\n        
recall.append(recall[-1])\n\n    plt.plot(recall, prec)\n    plt.savefig(figure_path)\n    print(f\"Save precision recall curve to {figure_path}\")\n    return figure_path\n"
  },
  {
    "path": "efm3d/utils/pointcloud.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport math\n\nimport torch\nfrom efm3d.aria.aria_constants import (\n    ARIA_CALIB,\n    ARIA_DISTANCE_M,\n    ARIA_IMG,\n    ARIA_IMG_T_SNIPPET_RIG,\n    ARIA_POINTS_DIST_STD,\n    ARIA_POINTS_WORLD,\n    ARIA_SNIPPET_T_WORLD_SNIPPET,\n)\nfrom efm3d.utils.depth import dist_im_to_point_cloud_im\nfrom efm3d.utils.ray import sample_depths_in_grid, transform_rays\nfrom efm3d.utils.voxel import tensor_wrap_voxel_extent\nfrom torch.nn import functional as F\n\n\ndef get_points_world(batch, batch_idx=None, dist_std0=0.04, prefer_points=False):\n    if ARIA_DISTANCE_M[0] in batch and not prefer_points:\n        dists = batch[ARIA_DISTANCE_M[0]].squeeze(2)\n        cams = batch[ARIA_CALIB[0]]\n        B, T = cams.shape[:2]\n        Ts_sr = batch[ARIA_IMG_T_SNIPPET_RIG[0]]\n        T_ws = batch[ARIA_SNIPPET_T_WORLD_SNIPPET]\n        Ts_wr = T_ws @ Ts_sr\n        Ts_cw = cams.T_camera_rig @ Ts_wr.inverse()\n        Ts_wc = Ts_cw.inverse()\n        pc_c, valids = dist_im_to_point_cloud_im(dists, cams)\n        B, T, H, W = pc_c.shape[:4]\n        pc_w = Ts_wc * pc_c.view(B, T, -1, 3)\n        pc_w = pc_w.view(B, T, H, W, 3)\n\n        pc_w[~valids] = float(\"nan\")  # nan\n        # remove all points that are invalid across all time and batches.\n        all_valid = ~(~valids).all(0).all(0)\n        all_valid = all_valid.view(1, 1, H, W).repeat(B, T, 
1, 1)\n        pc_w = pc_w[all_valid].view(B, T, -1, 3)\n\n        dist_stds = torch.ones(pc_w.shape[:-1], device=pc_w.device) * dist_std0\n    elif ARIA_POINTS_WORLD in batch:\n        pc_w = batch[ARIA_POINTS_WORLD]\n\n        if ARIA_POINTS_DIST_STD in batch:\n            dist_stds = batch[ARIA_POINTS_DIST_STD]\n        else:\n            dist_stds = torch.ones(pc_w.shape[:-1], device=pc_w.device) * 0.01\n\n    else:\n        raise NotImplementedError(\n            f\"do need either points or depth image! {batch.keys()}\"\n        )\n\n    if batch_idx is not None:\n        return pc_w[batch_idx], dist_stds[batch_idx]\n    return pc_w, dist_stds\n\n\ndef get_freespace_world(\n    batch,\n    batch_idx,\n    T_wv,\n    vW,\n    vH,\n    vD,\n    voxel_extent,\n    S=1,\n    prefer_points=False,\n    dropout_points=False,\n    drop_points_rate_max=0.5,\n):\n    \"\"\"\n    Get points (semi-dense or GT points) of a snippet in the batch.\n    \"\"\"\n    cams = batch[ARIA_CALIB[0]][batch_idx]\n    T_ws = batch[ARIA_SNIPPET_T_WORLD_SNIPPET][\n        batch_idx\n    ]  # T_world_rig (one per snippet)\n    Ts_sr = batch[ARIA_IMG_T_SNIPPET_RIG[0]][\n        batch_idx\n    ]  # Ts_snippet_rig (T per snippet)\n    Ts_wr = T_ws @ Ts_sr\n    Ts_wc = Ts_wr @ cams.T_camera_rig.inverse()  # Ts_world_cam\n\n    # compute rays and max depths\n    if ARIA_DISTANCE_M[0] in batch and not prefer_points:\n        # get gt distances into world points\n        gt_dist = batch[ARIA_DISTANCE_M[0]][batch_idx]\n        cams = batch[ARIA_CALIB[0]][batch_idx]\n        # invalid depth has values 0 or NaN (padding used by semidense stream).\n        valid_depths = gt_dist.squeeze(1) > 1e-4\n        p3cs, valids = dist_im_to_point_cloud_im(\n            gt_dist.squeeze(1),\n            cams,\n        )\n        valids = torch.logical_and(valids, valid_depths)\n        p3cs = p3cs.reshape(p3cs.shape[0], -1, 3)\n        T, N = p3cs.shape[:2]\n        ds = torch.norm(p3cs, 2.0, dim=-1)\n        
dirs_c = F.normalize(p3cs, 2.0, dim=-1)\n        rays_c = torch.cat([torch.zeros_like(dirs_c), dirs_c], dim=-1)\n        T_vc = T_wv.inverse() @ Ts_wc\n        rays_v = transform_rays(rays_c, T_vc)\n        rays_v = rays_v.view(-1, 6)\n        ds = ds.view(-1)\n        valids = valids.reshape(-1)\n        rays_v = rays_v[valids]\n        ds = ds[valids]\n    else:\n        p_w = batch[ARIA_POINTS_WORLD][batch_idx]  # TxNx3\n        T, N = p_w.shape[:2]\n        p0_w = Ts_wc.t.unsqueeze(1)  # Tx1x3\n        diff_w = p_w - p0_w\n        ds = torch.norm(diff_w, 2.0, dim=-1)\n        dir_w = F.normalize(diff_w, 2.0, dim=-1)\n        # filter out nans\n        good = ~p_w.isnan().any(dim=-1)\n        p0_w = p0_w.repeat(1, N, 1)[good]\n        ds = ds[good]\n        dir_w = dir_w[good]\n        rays_w = torch.cat([p0_w, dir_w], dim=-1)\n        rays_v = transform_rays(rays_w, T_wv.inverse())\n\n    # dropout rays if desired\n    if dropout_points:\n        N = rays_v.shape[0]\n        p = drop_points_rate_max\n        # Ndrop is the number of rays kept, between (1 - p) * N and N\n        Ndrop = int(N * (torch.rand(1).item() * p + (1.0 - p)))\n        print(f\"dropout: keeping {Ndrop}/{N} points\")\n        # use rays_v.device: p_w is only defined in the points branch above\n        rnd = torch.randperm(N, device=rays_v.device)[:Ndrop]\n        rays_v = rays_v[rnd, :]\n        ds = ds[rnd]\n\n    x_min, x_max, y_min, y_max, z_min, z_max = voxel_extent\n    dW = (x_max - x_min) / vW\n    dH = (y_max - y_min) / vH\n    dD = (z_max - z_min) / vD\n    diag = math.sqrt(dW**2 + dH**2 + dD**2)\n    # subtract diagonal of voxel size to not label the occupied voxel as free\n    ds = ds - diag\n    # sample depths that lie within the feature volume grid (same function as used for nerf3d!)\n    depths, _, _ = sample_depths_in_grid(\n        rays_v.view(1, 1, -1, 6),\n        ds.view(1, 1, -1),\n        voxel_extent,\n        vW,\n        vH,\n        vD,\n        S,\n    )\n    depths = depths.view(-1, S)\n    rays_v = rays_v.view(-1, 1, 6)\n    pts_v = rays_v[..., :3] + depths.unsqueeze(-1) * rays_v[..., 3:]\n    pts_v = pts_v.view(-1, 
3)\n    return T_wv * pts_v\n\n\ndef collapse_pointcloud_time(pc_w):\n    pc_w = pc_w.reshape(-1, 3)\n    # filter out nans\n    bad = pc_w.isnan().any(dim=-1)\n    pc_w = pc_w[~bad]\n    # filter out duplicates from the collapsing of the time dimension\n    pc_w = torch.unique(pc_w, dim=0)\n    pc_w = pc_w.reshape(-1, 3)\n    return pc_w\n\n\ndef pointcloud_to_voxel_ids(pc_v, vW, vH, vD, voxel_extent):\n    \"\"\"\n    converts a point cloud in voxel grid coordinates into voxel ids.\n    \"\"\"\n    assert pc_v.ndim == 3, f\"{pc_v.shape}\"  # T N 3\n    assert isinstance(voxel_extent, torch.Tensor)\n    assert voxel_extent.ndim == 1, f\"{voxel_extent.shape}\"  # 6\n    device = pc_v.device\n    x_min, x_max, y_min, y_max, z_min, z_max = voxel_extent.tolist()\n    valid = pc_v[..., 0] > x_min\n    valid = torch.logical_and(pc_v[..., 0] < x_max, valid)\n    valid = torch.logical_and(pc_v[..., 1] > y_min, valid)\n    valid = torch.logical_and(pc_v[..., 1] < y_max, valid)\n    valid = torch.logical_and(pc_v[..., 2] > z_min, valid)\n    valid = torch.logical_and(pc_v[..., 2] < z_max, valid)\n    dW = (x_max - x_min) / vW\n    dH = (y_max - y_min) / vH\n    dD = (z_max - z_min) / vD\n    s = [1] * (pc_v.ndim - 1) + [3]\n    dVox = torch.tensor([dW, dH, dD]).view(s).to(device)\n    vox_min = torch.tensor([x_min, y_min, z_min]).view(s).to(device)\n    pc_id = ((pc_v - vox_min) / dVox).floor().long()\n    valid = torch.logical_and(pc_id[..., 0] >= 0, valid)\n    valid = torch.logical_and(pc_id[..., 0] < vW, valid)\n    valid = torch.logical_and(pc_id[..., 1] >= 0, valid)\n    valid = torch.logical_and(pc_id[..., 1] < vH, valid)\n    valid = torch.logical_and(pc_id[..., 2] >= 0, valid)\n    valid = torch.logical_and(pc_id[..., 2] < vD, valid)\n    # to match the D H W ordering of the voxel tensors\n    pc_id = pc_id[..., [2, 1, 0]]\n    return pc_id, valid\n\n\ndef pointcloud_to_occupancy_snippet(\n    pcs_w, Ts_wc, cams, T_wv, vW, vH, vD, voxel_extent, S=1\n):\n    
\"\"\"\n    converts a pointcloud to an occupancy grid (and mask where there are\n    points).\n\n    All voxels which have a point in them are marked occupied\n    Along rays to the points of the cloud we sample S points and mark them as\n    not occupied.\n    \"\"\"\n    assert pcs_w.ndim == 3, f\"{pcs_w.shape}\"  # T N 3\n    assert Ts_wc.ndim == 2, f\"{Ts_wc.shape}\"  # T C\n    assert cams.ndim == 2, f\"{cams.shape}\"  # T C\n    assert T_wv.ndim in [1, 2], f\"{T_wv.shape}\"  # 1 C\n    voxel_extent = tensor_wrap_voxel_extent(voxel_extent)\n    device = pcs_w.device\n    occ = torch.zeros((vD, vH, vW), device=device)\n    mask = torch.zeros_like(occ)\n\n    # get invalid mask as the points that are nan and do not project into the\n    # camera.\n    Ts_vc = T_wv.inverse() @ Ts_wc\n    pc_c = Ts_wc.inverse() * pcs_w\n    invalid = pc_c.isnan().any(-1)  # T N\n    pc_im, valid = cams.project(pc_c)\n    invalid = torch.logical_or(invalid, ~valid)\n    depth = torch.sqrt((pc_c**2).sum(-1))\n    ray_c = pc_c / depth.unsqueeze(-1)\n\n    # camera origins are not occupied\n    rayP_c = torch.zeros_like(Ts_wc.t)\n    rayP_v = Ts_vc * rayP_c\n    pc_ids, valid = pointcloud_to_voxel_ids(rayP_v, vW, vH, vD, voxel_extent)\n    pc_ids = pc_ids[valid]\n    pc_ids = pc_ids.view(-1, 3)\n    if pc_ids.numel() > 0:\n        occ[pc_ids[:, 0], pc_ids[:, 1], pc_ids[:, 2]] = 0.0\n        mask[pc_ids[:, 0], pc_ids[:, 1], pc_ids[:, 2]] = 1.0\n\n    # sample along the ray\n    x_min, x_max, y_min, y_max, z_min, z_max = voxel_extent\n    dW = (x_max - x_min) / vW\n    dH = (y_max - y_min) / vH\n    dD = (z_max - z_min) / vD\n    diag = math.sqrt(dW**2 + dH**2 + dD**2)\n    T, N = ray_c.shape[:2]\n    rayP_c = rayP_c.view(T, 1, 3).repeat(1, N, 1)\n    # sample depths conservatively up to the depth - diagonal of a voxel\n    ds = depth.unsqueeze(-1) - diag\n    ds = torch.rand((T, N, S), device=device) * ds\n    samples_c = rayP_c.unsqueeze(2) + ds.unsqueeze(3) * ray_c.unsqueeze(2)\n    
samples_c = samples_c.view(T, -1, 3)\n    samples_v = Ts_vc * samples_c\n    pc_ids, valid = pointcloud_to_voxel_ids(samples_v, vW, vH, vD, voxel_extent)\n    invalid_ = invalid.unsqueeze(-1).repeat(1, 1, S).view(T, -1)\n    valid = torch.logical_and(valid, ~invalid_)\n    pc_ids = pc_ids[valid]\n    pc_ids = pc_ids.view(-1, 3)\n    if pc_ids.numel() > 0:\n        occ[pc_ids[:, 0], pc_ids[:, 1], pc_ids[:, 2]] = 0.0\n        mask[pc_ids[:, 0], pc_ids[:, 1], pc_ids[:, 2]] = 1.0\n\n    # add points as occupied\n    pc_v = T_wv.inverse() * pcs_w\n    pc_ids, valid = pointcloud_to_voxel_ids(pc_v, vW, vH, vD, voxel_extent)\n    valid = torch.logical_and(valid, ~invalid)\n    pc_ids = pc_ids[valid]\n    pc_ids = pc_ids.view(-1, 3)\n    if pc_ids.numel() > 0:\n        occ[pc_ids[:, 0], pc_ids[:, 1], pc_ids[:, 2]] = 1.0\n        mask[pc_ids[:, 0], pc_ids[:, 1], pc_ids[:, 2]] = 1.0\n\n    return occ, mask\n\n\ndef pointcloud_occupancy_samples(\n    p3s_w,\n    Ts_wc,\n    cams,\n    vW,\n    vH,\n    vD,\n    voxel_extent,\n    S=16,\n    sample_beyond=False,\n    vox_diag_scale=1.0,\n    T_wv=None,\n    sample_mode=\"random\",\n):\n    \"\"\"\n    compute occupied points and sample S freespace points along rays.\n    \"\"\"\n    assert p3s_w.ndim == 4, f\"{p3s_w.shape}\"  # B T N 3\n    assert Ts_wc.ndim == 3, f\"{Ts_wc.shape}\"  # B T C\n    assert not sample_beyond, \"not supported\"\n    B = p3s_w.shape[0]\n    # precompute things\n    pc_c = Ts_wc.inverse() * p3s_w\n    invalid = pc_c.isnan().any(-1)  # B T N\n    pc_im, valid = cams.project(pc_c)\n    invalid = torch.logical_or(invalid, ~valid)\n    depth = torch.sqrt((pc_c**2).sum(-1)).unsqueeze(-1)\n    rayD_c = pc_c / depth\n    B, T, N = rayD_c.shape[:3]\n    rayP_c = torch.zeros_like(Ts_wc.t)\n    rayP_c = rayP_c.view(B, T, 1, 3).repeat(1, 1, N, 1)\n    T_vc = T_wv.inverse().unsqueeze(-2) @ Ts_wc\n    voxel_extent = tensor_wrap_voxel_extent(voxel_extent, B, device=depth.device)\n    diag = voxel_extent[..., 1::2] 
- voxel_extent[..., 0::2]\n    diag = diag / torch.tensor([vW, vH, vD], device=voxel_extent.device)\n    diag = torch.sqrt((diag**2).sum(-1)) * vox_diag_scale\n    delta = diag.view(B, 1, 1, 1)\n    ds_free_max = depth - delta  # BxTxNx1\n    # sample depths conservatively up to the depth - diagonal of a voxel\n    rays_c = torch.cat([rayP_c, rayD_c], dim=-1)\n    rays_v = transform_rays(rays_c, T_vc)\n    ds_free, _, _ = sample_depths_in_grid(\n        rays_v,\n        ds_free_max.squeeze(-1),\n        voxel_extent,\n        vW,\n        vH,\n        vD,\n        S,\n        d_near=0.01,\n        d_far=10.0,\n        sample_mode=sample_mode,\n    )\n    free_c = rayP_c.unsqueeze(3) + ds_free.unsqueeze(4) * rayD_c.unsqueeze(3)\n    free_c = free_c.view(B, T, -1, 3)\n    free_w = Ts_wc * free_c\n\n    ds_occ = depth + delta\n    occ_c = rayP_c + ds_occ * rayD_c\n    occ_c = occ_c.view(B, T, -1, 3)\n    occ_w = Ts_wc * occ_c\n    # occupied, on surface, free space\n    return occ_w, p3s_w, free_w, ~invalid\n\n\ndef pointcloud_to_occupancy(\n    pc_w, T_wc, cam, T_wv, vW, vH, vD, voxel_extent, S=1, occ=None, mask=None\n):\n    device = pc_w.device\n    if occ is None:\n        occ = torch.zeros((vD, vH, vW), device=device)\n    if mask is None:\n        mask = torch.zeros_like(occ)\n\n    T_vc = T_wv.inverse() @ T_wc\n    pc_c = T_wc.inverse() * pc_w\n    invalid = pc_c.isnan().any(-1)\n    pc_c = pc_c[~invalid]\n    pc_im, valid = cam.unsqueeze(0).project(pc_c.unsqueeze(0))\n    pc_im, valid = pc_im.squeeze(0), valid.squeeze(0)\n    depth = torch.sqrt((pc_c**2).sum(-1))\n    ray_c = pc_c / depth.unsqueeze(-1)\n    ray_c = ray_c[valid]\n    depth = depth[valid]\n\n    # camera origins are not occupied\n    rayP_c = torch.zeros_like(T_wc.t)\n    rayP_v = T_vc * rayP_c\n    pc_ids, valid = pointcloud_to_voxel_ids(rayP_v, vW, vH, vD, voxel_extent)\n    pc_ids = pc_ids[valid]\n    occ[pc_ids[:, 0], pc_ids[:, 1], pc_ids[:, 2]] = 0.0\n    mask[pc_ids[:, 0], pc_ids[:, 1], 
pc_ids[:, 2]] = 1.0\n\n    # sample along the ray\n    N = ray_c.shape[0]\n    rayP_c = rayP_c.view(1, 3).repeat(N, 1)\n    ds = torch.rand((N, S), device=device) * depth.unsqueeze(1)\n    samples_c = rayP_c.unsqueeze(1) + ds.unsqueeze(2) * ray_c.unsqueeze(1)\n    samples_c = samples_c.view(-1, 3)\n    samples_v = T_vc * samples_c\n    pc_ids, valid = pointcloud_to_voxel_ids(samples_v, vW, vH, vD, voxel_extent)\n    pc_ids = pc_ids[valid]\n    occ[pc_ids[:, 0], pc_ids[:, 1], pc_ids[:, 2]] = 0.0\n    mask[pc_ids[:, 0], pc_ids[:, 1], pc_ids[:, 2]] = 1.0\n\n    # add points as occupied\n    pc_v = T_wv.inverse() * pc_w\n    invalid = pc_v.isnan().any(-1)\n    pc_v = pc_v[~invalid]\n    pc_ids, valid = pointcloud_to_voxel_ids(pc_v, vW, vH, vD, voxel_extent)\n    pc_ids = pc_ids[valid]\n    occ[pc_ids[:, 0], pc_ids[:, 1], pc_ids[:, 2]] = 1.0\n    mask[pc_ids[:, 0], pc_ids[:, 1], pc_ids[:, 2]] = 1.0\n\n    return occ, mask\n\n\ndef pointcloud_to_voxel_counts(points_v, voxel_extent, vW, vH, vD):\n    \"\"\"\n    Convert a pointcloud in the voxel coordinate to a voxel grid where each voxel value indicates the number of points falling into this voxel.\n    \"\"\"\n    assert points_v.ndim == 2, f\"{points_v.shape}\"\n    voxel_extent = tensor_wrap_voxel_extent(voxel_extent).to(points_v.device)\n    assert voxel_extent.ndim == 1, f\"{voxel_extent.shape}\"\n    if points_v.shape[0] == 0:\n        print(\"WARNING: No 3D points provided. 
\")\n        return torch.zeros((1, vD, vH, vW), device=points_v.device, dtype=torch.int64)\n    num_voxels_x, num_voxels_y, num_voxels_z = vW, vH, vD\n    bb_min, bb_max = voxel_extent[..., 0::2], voxel_extent[..., 1::2]\n    dim = torch.tensor([vW, vH, vD], device=points_v.device)\n    voxel_sizes = (bb_max - bb_min) / dim\n    voxel_min = bb_min\n    point_count = torch.zeros(\n        (num_voxels_x, num_voxels_y, num_voxels_z), device=points_v.device\n    )\n    voxel_indices = torch.floor((points_v - voxel_min) / voxel_sizes).to(torch.int64)\n    # Filter out points that fall outside the voxel grid\n    valid_indices = (voxel_indices >= 0) & (\n        voxel_indices\n        < torch.tensor([num_voxels_x, num_voxels_y, num_voxels_z]).to(voxel_indices)\n    )\n    valid_indices = valid_indices.all(dim=-1)\n    voxel_indices = voxel_indices[valid_indices]\n\n    # get flat index so we can use bincount to get counts\n    voxel_indices_flat = (\n        voxel_indices[..., 0]\n        + voxel_indices[..., 1] * vW\n        + voxel_indices[..., 2] * vW * vH\n    )\n    # get counts of how many points per voxel\n    point_count = torch.bincount(voxel_indices_flat, minlength=vW * vH * vD)\n    # reshape back to vD x vH x vW convention.\n    point_count = point_count.view(1, vD, vH, vW)\n    return point_count\n\n\ndef get_points_counts(\n    batch,\n    T_wv,\n    vW,\n    vH,\n    vD,\n    voxel_extent,\n    prefer_points=True,\n    MAX_NUM_POINTS_VOXEL=50,\n    return_mask=False,\n    dropout_points=False,\n    dropout_points_rate_max=0.0,\n):\n    \"\"\"\n    Get points as voxel grid where each voxel is assigned a count of how many points are inside it.\n    If return_mask is trued the function returns the binary occupancy instead of point counts.\n    \"\"\"\n    B, T, _, H, W = batch[ARIA_IMG[0]].shape\n    point_counts = []\n    for b in range(B):\n        p_w = get_points_world(batch, b, prefer_points=prefer_points)[0]\n        p_w = 
collapse_pointcloud_time(p_w)\n        if dropout_points:\n            print(\"drop points \", p_w.shape)\n            N = p_w.shape[0]\n            p = dropout_points_rate_max\n            Ndrop = int(N * (torch.rand(1).item() * p + (1.0 - p)))\n            print(f\"dropout {N - Ndrop}/{N} points\")\n            rnd = torch.randperm(N, device=p_w.device)[:Ndrop]\n            p_w = p_w[rnd, :]\n        # transform points into voxel coordinate.\n        p_v = T_wv[b].inverse() * p_w\n        if isinstance(voxel_extent, list):\n            ve_b = voxel_extent\n        else:\n            ve_b = voxel_extent[b].tolist()\n        point_count = pointcloud_to_voxel_counts(p_v, ve_b, vW, vH, vD)\n        point_counts.append(point_count)\n    point_counts = torch.stack(point_counts, dim=0)  # B x 1 x vD, vH, vW\n    # Normalize\n    point_counts = point_counts.clamp(0, MAX_NUM_POINTS_VOXEL) / MAX_NUM_POINTS_VOXEL\n    if return_mask:\n        # Only use as a mask. Comment out if want to use real point counts.\n        point_counts[point_counts > 1e-4] = 1.0\n\n    return point_counts\n\n\ndef get_freespace_counts(\n    batch,\n    T_wv,\n    vW,\n    vH,\n    vD,\n    voxel_extent,\n    num_free_samples=1,\n    prefer_points=True,\n    MAX_NUM_POINTS_VOXEL=50,\n    return_mask=False,\n    dropout_points=False,\n    dropout_points_rate_max=0.0,\n):\n    \"\"\"\n    Get free-space samples as a voxel grid where each voxel is assigned a count of how many free-space samples fall inside it.\n    If return_mask is true, the function returns a binary mask instead of the normalized counts.\n    \"\"\"\n    B, T, _, H, W = batch[ARIA_IMG[0]].shape\n    point_counts = []\n    for b in range(B):\n        if isinstance(voxel_extent, list):\n            ve_b = voxel_extent\n        else:\n            ve_b = voxel_extent[b].tolist()\n        p_w = get_freespace_world(\n            batch,\n            b,\n            T_wv[b],\n            vW,\n            vH,\n            vD,\n            ve_b,\n            
num_free_samples,\n            prefer_points,\n            dropout_points,\n            dropout_points_rate_max,\n        )\n        # transform points into voxel coordinate.\n        p_v = T_wv[b].inverse() * p_w\n        point_count = pointcloud_to_voxel_counts(p_v, ve_b, vW, vH, vD)\n        point_counts.append(point_count)\n    point_counts = torch.stack(point_counts, dim=0)  # B x 1 x vD, vH, vW\n    # Normalize\n    point_counts = point_counts.clamp(0, MAX_NUM_POINTS_VOXEL) / MAX_NUM_POINTS_VOXEL\n    if return_mask:\n        # Only use as a mask. Comment out if want to use real point counts.\n        point_counts[point_counts > 1e-4] = 1.0\n\n    return point_counts\n"
  },
  {
    "path": "efm3d/utils/ray.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Literal\n\nimport einops\nimport torch\nfrom efm3d.aria.camera import CameraTW, pixel_grid\nfrom efm3d.utils.voxel import tensor_wrap_voxel_extent\nfrom torch.nn import functional as F\n\n\ndef grid_ray(pixel_grid, camera):\n    \"\"\"\n    grid_ray:\n            Given a 2D grid size, this function creates a 2D grid and then unprojects the grid\n            into rays in their respective rig coordinate systems.\n\n    Args:\n        grid_width: self-explanatory\n        grid_height: self-explanatory\n        camera: Batch of Camera objects [B x object_params]\n\n    Returns:\n        Rays: [B x grid_height x grid_width x 6] rays in their respective rig coordinates\n        Each ray grid in a batch may have different rig coordinate systems.\n        Valid: Valid rays in the batch\n    \"\"\"\n    eps = 1e-6\n    grid_height, grid_width = pixel_grid.shape[0], pixel_grid.shape[1]\n    batch_size = camera.shape[0]\n    pixel_grid = pixel_grid.reshape(-1, 2)\n    pixel_grid = einops.repeat(pixel_grid, \"n c -> b n c\", b=batch_size)\n    rays, valid = camera.double().unproject(pixel_grid.double())\n    rays = rays.float()\n    assert not torch.isnan(rays).any(), (\n        f\"have {torch.isnan(rays).count_nonzero().item()} nans in rays. 
Camera params: {camera.params}\"\n    )\n    rays = F.normalize(rays, p=2, dim=-1, eps=eps)\n    rays = torch.where(valid.unsqueeze(-1), rays, torch.zeros_like(rays))\n    T_rig_camera = camera.T_camera_rig.inverse()\n    T_rig_camera = T_rig_camera.to(dtype=rays.dtype)\n    rays = T_rig_camera.rotate(rays)\n    ray_origins = einops.repeat(\n        T_rig_camera.t, \"b c -> b n c\", n=grid_width * grid_height\n    )\n\n    # set invalid rays to zeros\n    rays = F.normalize(rays, p=2, dim=-1, eps=eps)\n    rays = torch.where(valid.unsqueeze(-1), rays, torch.zeros_like(rays))\n    ray_origins = torch.where(\n        valid.unsqueeze(-1), ray_origins, torch.zeros_like(ray_origins)\n    )\n\n    rays = torch.cat([ray_origins, rays], dim=-1)\n    return rays.view([batch_size, grid_height, grid_width, -1]), valid.view(\n        [batch_size, grid_height, grid_width]\n    )\n\n\ndef ray_grid(cam: CameraTW):\n    \"\"\"\n    rays returned are in rig coordinate system\n    \"\"\"\n    if cam.ndim == 1:\n        px = pixel_grid(cam)\n        rays, valid = grid_ray(px, cam.unsqueeze(0))\n        return rays.squeeze(0), valid.squeeze(0)\n    elif cam.ndim == 2:\n        px = pixel_grid(cam[0])  # assuming camera sizes are all the same in a batch!\n        return grid_ray(px, cam)\n    else:\n        raise ValueError(f\"Camera must be 1 or 2 dimensional: {cam.shape}\")\n\n\ndef transform_rays(rays_old: torch.Tensor, T_new_old):\n    \"\"\"\n    Expects rays to be in old coordinate frame\n    \"\"\"\n    # check the last dimension explicitly; a bare shape[-1] is truthy for any non-zero size\n    assert rays_old.shape[-1] == 6, (\n        \"Rays must be 6 dimensional in the following order: [ray_origins, ray_directions]\"\n    )\n    ray_origins = T_new_old.transform(rays_old[..., :3])\n    ray_directions = T_new_old.rotate(rays_old[..., 3:])\n    return torch.cat([ray_origins, ray_directions], dim=-1)\n\n\ndef ray_obb_intersection(\n    rays_v, voxel_extent, t_min=-1e9, t_max=1e9, return_points=False\n):\n    assert rays_v.ndim == 3, f\"{rays_v.shape}\"\n    assert 
rays_v.shape[-1] == 6, f\"{rays_v.shape}\"\n\n    device = rays_v.device\n    B, N = rays_v.shape[:2]\n    x_min, x_max, y_min, y_max, z_min, z_max = voxel_extent\n    raysP_v = rays_v[..., :3]\n    raysD_v = rays_v[..., 3:]  # assume normalized!\n\n    ns_bb = [\n        [1.0, 0.0, 0.0],\n        [-1.0, 0.0, 0.0],\n        [0.0, 1.0, 0.0],\n        [0.0, -1.0, 0.0],\n        [0.0, 0.0, 1.0],\n        [0.0, 0.0, -1.0],\n    ]\n    ps_bb = [\n        [x_max, 0.0, 0.0],\n        [x_min, 0.0, 0.0],\n        [0.0, y_max, 0.0],\n        [0.0, y_min, 0.0],\n        [0.0, 0.0, z_max],\n        [0.0, 0.0, z_min],\n    ]\n\n    eps = 1e-3\n    minmaxs_bb = [\n        [x_max - eps, y_min, z_min, x_max + eps, y_max, z_max],\n        [x_min - eps, y_min, z_min, x_min + eps, y_max, z_max],\n        [x_min, y_max - eps, z_min, x_max, y_max + eps, z_max],\n        [x_min, y_min - eps, z_min, x_max, y_min + eps, z_max],\n        [x_min, y_min, z_max - eps, x_max, y_max, z_max + eps],\n        [x_min, y_min, z_min - eps, x_max, y_max, z_min + eps],\n    ]\n\n    t_upper = torch.ones((B, N), device=device) * t_max\n    t_lower = torch.ones((B, N), device=device) * t_min\n    ts = torch.stack([t_upper, t_lower], dim=-1)\n    for n_bb, p_bb, minmax_bb in zip(ns_bb, ps_bb, minmaxs_bb):\n        n_bb = torch.tensor(n_bb).view(1, 1, 3).to(device)\n        p_bb = torch.tensor(p_bb).view(1, 1, 3).to(device)\n        min_bb = torch.tensor(minmax_bb[:3]).view(1, 1, 3).to(device)\n        max_bb = torch.tensor(minmax_bb[3:]).view(1, 1, 3).to(device)\n        # dot product\n        denom = (raysD_v * n_bb).sum(-1)\n        valid = denom.abs() > 1e-6\n        dp = p_bb - raysP_v\n        t = (dp * n_bb).sum(-1) / denom\n        valid = torch.logical_and(valid, t > t_min)\n        valid = torch.logical_and(valid, t < t_max)\n        # points on surface\n        ps_v = raysP_v + raysD_v * t.unsqueeze(-1)\n        valid = torch.logical_and(valid, (ps_v > min_bb).all(-1))\n        valid = 
torch.logical_and(valid, (ps_v < max_bb).all(-1))\n\n        ts_min = torch.where(valid, t, t_upper)\n        ts_max = torch.where(valid, t, t_lower)\n\n        ts[..., 0] = torch.minimum(ts_min, ts[..., 0])\n        ts[..., 1] = torch.maximum(ts_max, ts[..., 1])\n\n    if return_points:\n        # a single intersection means the ray starts inside the box: clamp the entry to t_min\n        one_int = ts[..., 0] == ts[..., 1]\n        ts[..., 0] = torch.where(\n            one_int, torch.ones_like(ts[..., 0]) * t_min, ts[..., 0]\n        )\n        no_int = ts[..., 0] > ts[..., 1]\n        ts[no_int] = t_min\n\n        ps_min_v = raysP_v + raysD_v * ts[..., 0].unsqueeze(-1)\n        ps_max_v = raysP_v + raysD_v * ts[..., 1].unsqueeze(-1)\n        return ts, ps_min_v, ps_max_v\n    return ts\n\n\ndef sample_depths_in_grid(\n    rays_v,\n    ds_max,\n    voxel_extent,\n    W,\n    H,\n    D,\n    num_samples,\n    d_near=0.01,\n    d_far=10.0,\n    sample_mode: Literal[\"random\", \"uniform\"] = \"random\",\n    ds_min=None,\n):\n    assert rays_v.ndim == 4, f\"{rays_v.shape}\"  # BxTxNx6\n    assert ds_max.ndim == 3, f\"{ds_max.shape}\"  # BxTxN\n    B = rays_v.shape[0]\n    voxel_extent = tensor_wrap_voxel_extent(voxel_extent, B).to(rays_v.device)\n\n    def safe_extent(voxel_extent, W, H, D):\n        # compute a \"safe\" voxel extent that is shrunk by half a voxel in all\n        # directions\n        bb_min, bb_max = voxel_extent[::2], voxel_extent[1::2]\n        dim = torch.tensor([W, H, D], device=voxel_extent.device)\n        dd = 0.5 * (bb_max - bb_min) / dim\n        bb_min = bb_min + dd\n        bb_max = bb_max - dd\n        voxel_extent_safe = torch.zeros_like(voxel_extent)\n        voxel_extent_safe[::2] = bb_min\n        voxel_extent_safe[1::2] = bb_max\n        return voxel_extent_safe\n\n    B, T, N = rays_v.shape[:3]\n    ts = []\n    for b in range(B):\n        voxel_extent_safe = safe_extent(voxel_extent[b], W, H, D)\n        ts.append(\n            ray_obb_intersection(\n                rays_v[b].view(T, N, 6),\n                
voxel_extent_safe,\n                t_min=d_near,\n                t_max=d_far,\n            )\n        )\n    ts = torch.stack(ts, 0)  # BxTxNx2\n\n    no_int = ts[..., 0] > ts[..., 1]\n    one_int = ts[..., 0] == ts[..., 1]\n    depths_min = torch.where(\n        one_int, torch.ones_like(ts[..., 0]) * d_near, ts[..., 0]\n    )  # BxTxN\n    depths_max = ts[..., 1]\n    depths_min[no_int] = torch.nan\n    depths_max[no_int] = torch.nan\n\n    if ds_max is not None:\n        depths_max = torch.minimum(ds_max, depths_max)\n    if ds_min is not None:\n        depths_min = torch.maximum(ds_min, depths_min)\n\n    ddepths = depths_max - depths_min\n    ddepths[ddepths < 1e-3] = torch.nan\n    # go to d_min to d_max per ray\n    depths = torch.linspace(0.0, 1.0, num_samples).to(rays_v.device)\n    depths = depths.view(1, 1, 1, num_samples).repeat(B, T, N, 1)\n    depths = depths_min.unsqueeze(-1) + ddepths.unsqueeze(-1) * depths\n    if sample_mode == \"uniform\":\n        return depths, depths_max, ~no_int.view(B, T, N)\n    elif sample_mode == \"random\":\n        # add noise\n        noise = torch.rand((B, T, N, num_samples), device=rays_v.device)\n        noise = noise * (ddepths.unsqueeze(-1) / num_samples)\n        if num_samples > 1:\n            noise[..., -1] = 0.0\n        depths = depths + noise\n        return depths, depths_max, ~no_int.view(B, T, N)\n    else:\n        raise ValueError(f\"Unknown sample mode {sample_mode}\")\n"
  },
  {
    "path": "efm3d/utils/reconstruction.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Literal\n\nimport torch\nfrom efm3d.aria.aria_constants import (\n    ARIA_CALIB,\n    ARIA_DISTANCE_M,\n    ARIA_IMG_T_SNIPPET_RIG,\n    ARIA_SNIPPET_T_WORLD_SNIPPET,\n)\nfrom efm3d.utils.depth import dist_im_to_point_cloud_im\nfrom efm3d.utils.detection_utils import compute_focal_loss\nfrom efm3d.utils.pointcloud import (\n    pointcloud_occupancy_samples,\n    pointcloud_to_occupancy_snippet,\n    pointcloud_to_voxel_ids,\n)\nfrom efm3d.utils.voxel_sampling import pc_to_vox, sample_voxels\nfrom einops import rearrange\nfrom torch.nn import functional as F\n\n\ndef build_gt_occupancy(occ, visible, p3s_w, Ts_wc, cams, T_wv, voxel_extent):\n    \"\"\"\n    build GT occupancy from GT point cloud, return batched occupancy with masks.\n    \"\"\"\n    B, vD, vH, vW = occ.shape\n    occ_gts, masks = [], []\n    for b in range(B):\n        occ_gt, mask = pointcloud_to_occupancy_snippet(\n            p3s_w[b],\n            Ts_wc[b],\n            cams[b],\n            T_wv[b],\n            vW,\n            vH,\n            vD,\n            voxel_extent,\n            S=1,\n        )\n        mask = torch.logical_and(mask.bool(), visible[b])\n        occ_gts.append(occ_gt)\n        masks.append(mask)\n    occ_gts = torch.stack(occ_gts)\n    masks = torch.stack(masks)\n    return occ_gts, masks\n\n\ndef get_fused_gt_feat(\n    
visible,\n    p3s_w,\n    Ts_wc,\n    cams,\n    T_wv,\n    voxel_extent,\n    img_feat_gt,\n    feat_pred,\n    gt_dists,\n    vD,\n    vH,\n    vW,\n):\n    feat_dim = img_feat_gt.shape[2]\n    gt_feat_volume = torch.zeros_like(feat_pred).detach()  # BxCxDxHxW\n    gt_feat_volume = gt_feat_volume.permute(\n        0, 2, 3, 4, 1\n    )  # BxDxHxWxC for easier indexing\n    gt_feat_volume_counts = (\n        torch.zeros(*gt_feat_volume.shape[:4]).to(feat_pred).detach()\n    )  # BxDxHxW\n\n    dists = gt_dists.squeeze(2)\n    B, T = cams.shape[:2]\n    pc_c, valids = dist_im_to_point_cloud_im(dists, cams)\n    pc_c = pc_c.reshape(B, T, -1, 3)  # BxTxNx3\n    T_vc = T_wv.inverse() @ Ts_wc\n    pc_v = T_vc * pc_c\n    for b in range(B):\n        pc_ids, valid_v = pointcloud_to_voxel_ids(pc_v[b], vW, vH, vD, voxel_extent)\n        for t in range(T):\n            valid = torch.logical_and(valid_v[t], valids[b, t].reshape(-1))\n            pc_ids_t = pc_ids[t][valid]\n            feat_gt_t = img_feat_gt[b, t].permute(1, 2, 0).reshape(-1, feat_dim)\n            feat_gt_t = feat_gt_t[valid]\n            gt_feat_volume_counts[b][\n                pc_ids_t[:, 0], pc_ids_t[:, 1], pc_ids_t[:, 2]\n            ] += 1.0\n            gt_feat_volume[b][pc_ids_t[:, 0], pc_ids_t[:, 1], pc_ids_t[:, 2]] += (\n                feat_gt_t\n            )\n        gt_feat_volume[b][gt_feat_volume_counts[b] > 1e-4] /= gt_feat_volume_counts[b][\n            gt_feat_volume_counts[b] > 1e-4\n        ].unsqueeze(-1)\n    surface_mask = gt_feat_volume_counts > 1e-4  # BxDxHxW\n    return gt_feat_volume, surface_mask\n\n\ndef get_feats_world(batch, tgt_feats):\n    B = tgt_feats.shape[0]\n    tgt_H, tgt_W = tgt_feats.shape[-2], tgt_feats.shape[-1]\n    dists_ori = batch[ARIA_DISTANCE_M[0]]\n    cams_ori = batch[ARIA_CALIB[0]]\n    # rescale dist and camera to tgt feat size\n    dists = rearrange(dists_ori, \"b t c h w -> (b t) c h w\")\n    dists = F.interpolate(dists, [tgt_H, tgt_W], 
mode=\"nearest\")\n    dists = rearrange(dists, \"(b t) c h w -> b t c h w\", b=B).squeeze(2)\n    cams = cams_ori.scale_to_size((tgt_W, tgt_H))\n\n    B, T = cams.shape[:2]\n    Ts_sr = batch[ARIA_IMG_T_SNIPPET_RIG[0]]\n    T_ws = batch[ARIA_SNIPPET_T_WORLD_SNIPPET]\n    Ts_wr = T_ws @ Ts_sr\n    Ts_cw = cams.T_camera_rig @ Ts_wr.inverse()\n    Ts_wc = Ts_cw.inverse()\n    pc_c, valids = dist_im_to_point_cloud_im(dists, cams)\n    B, T, H, W = pc_c.shape[:4]\n    pc_w = Ts_wc * pc_c.view(B, T, -1, 3)\n    pc_w = pc_w.view(B, T, H, W, 3)\n    pc_w[~valids] = float(\"nan\")  # nan\n    # remove all points that are invalid across all time and batches.\n    all_valid = ~(~valids).all(0).all(0)\n    all_valid = all_valid.view(1, 1, H, W).repeat(B, T, 1, 1)\n    pc_w = pc_w[all_valid].view(B, T, -1, 3)\n    feat_dim = tgt_feats.shape[2]\n    tgt_feats = tgt_feats.permute(0, 1, 3, 4, 2)\n    tgt_feats = tgt_feats[all_valid].view(B, T, -1, feat_dim)\n\n    return pc_w, tgt_feats\n\n\ndef compute_tv_loss(occ):\n    # B 1 D H W\n    tv_d = (occ[:, 1:, :, :] - occ[:, :-1, :, :]).abs().mean()\n    tv_h = (occ[:, :, 1:, :] - occ[:, :, :-1, :]).abs().mean()\n    tv_w = (occ[:, :, :, 1:] - occ[:, :, :, :-1]).abs().mean()\n    tv_loss = tv_d + tv_h + tv_w\n    return tv_loss\n\n\ndef compute_occupancy_loss_subvoxel(\n    occ,\n    visible,\n    p3s_w_all,\n    Ts_wc,\n    cams,\n    T_wv,\n    voxel_extent,\n    S=1,\n    sample_beyond=False,\n    surf_val=0.5,\n    subsample=1,\n    free_surf_occ_weights=None,\n    loss_type: Literal[\"l2\", \"l1\", \"logl1\", \"ce\", \"focal\"] = \"focal\",\n):\n    \"\"\"\n    sample occupied, surface and freespace GT points\n    obtain predictions at those sample points by sampling into the occ voxel\n    grid via tri-linear interpolation.\n    \"\"\"\n    assert p3s_w_all.ndim == 4, f\"{p3s_w_all.shape}\"  # B T N 3\n    assert occ.ndim == 4, f\"{occ.shape}\"  # B D H W\n    assert visible.ndim == 4, f\"{visible.shape}\"  # B D H W\n    
assert not sample_beyond, \"not supported\"\n    device = occ.device\n    B, vD, vH, vW = occ.shape\n\n    if subsample > 1:\n        # subsample\n        B, T, N = p3s_w_all.shape[:3]\n        ids = torch.randperm(N)[: N // subsample].to(device)\n        p3s_w = p3s_w_all[:, :, ids]\n        # print(\"subsample\", subsample, p3s_w.shape, p3s_w_all.shape)\n    else:\n        p3s_w = p3s_w_all\n    B, T, N = p3s_w.shape[:3]\n\n    p3s_occ_w, p3s_surf_w, p3s_free_w, valid = pointcloud_occupancy_samples(\n        p3s_w,\n        Ts_wc,\n        cams,\n        vD,\n        vH,\n        vW,\n        voxel_extent,\n        S=S,\n        sample_beyond=sample_beyond,\n        vox_diag_scale=1.0,\n        T_wv=T_wv,\n    )\n    Ts_vw = T_wv.inverse().view(B, 1, -1).repeat(1, T, 1)\n\n    p3s_occ_v = Ts_vw * p3s_occ_w\n    p3s_surf_v = Ts_vw * p3s_surf_w\n    p3s_free_v = Ts_vw * p3s_free_w\n\n    B, vD, vH, vW = occ.shape\n    # free points\n    p3s_free_vox, valid_free = pc_to_vox(p3s_free_v, vW, vH, vD, voxel_extent)\n    valid_free = torch.logical_and(valid_free, valid)\n    free_samples, valid_samples = sample_voxels(\n        occ.unsqueeze(1), p3s_free_vox.view(B, -1, 3)\n    )\n    free_samples, valid_samples = (\n        free_samples.view(B, T, -1),\n        valid_samples.view(B, T, -1),\n    )\n    valid_free = torch.logical_and(valid_samples, valid_free)\n    free_samples = free_samples[valid_free].clamp(0.0, 1.0)\n    free_gt = torch.zeros_like(free_samples)\n\n    # surface points\n    p3s_surf_vox, valid_surf = pc_to_vox(p3s_surf_v, vW, vH, vD, voxel_extent)\n    valid_surf = torch.logical_and(valid_surf, valid)\n    surf_samples, valid_samples = sample_voxels(\n        occ.unsqueeze(1), p3s_surf_vox.view(B, -1, 3)\n    )\n    surf_samples, valid_samples = (\n        surf_samples.view(B, T, -1),\n        valid_samples.view(B, T, -1),\n    )\n    valid_surf = torch.logical_and(valid_samples, valid_surf)\n    surf_samples = surf_samples[valid_surf].clamp(0.0, 
1.0)\n    surf_gt = surf_val * torch.ones_like(surf_samples)\n\n    # occupied points\n    p3s_occ_vox, valid_occ = pc_to_vox(p3s_occ_v, vW, vH, vD, voxel_extent)\n    valid_occ = torch.logical_and(valid_occ, valid)\n    occ_samples, valid_samples = sample_voxels(\n        occ.unsqueeze(1), p3s_occ_vox.view(B, -1, 3)\n    )\n    occ_samples, valid_samples = (\n        occ_samples.view(B, T, -1),\n        valid_samples.view(B, T, -1),\n    )\n    valid_occ = torch.logical_and(valid_samples, valid_occ)\n    occ_samples = occ_samples[valid_occ].clamp(0.0, 1.0)\n    occ_gt = torch.ones_like(occ_samples)\n\n    if free_surf_occ_weights is None:\n        num = free_samples.numel() + surf_samples.numel() + occ_samples.numel()\n        if loss_type == \"l2\":\n            # L2 loss\n            pred = torch.cat([free_samples, surf_samples, occ_samples], -1)\n            gt = torch.cat([free_gt, surf_gt, occ_gt], -1)\n            loss = ((pred - gt) ** 2).sum()\n        elif loss_type == \"l1\":\n            # L1 loss\n            pred = torch.cat([free_samples, surf_samples, occ_samples], -1)\n            gt = torch.cat([free_gt, surf_gt, occ_gt], -1)\n            loss = (pred - gt).abs().sum()\n        elif loss_type == \"logl1\":\n            # logl1 loss\n            pred = torch.cat([free_samples, surf_samples, occ_samples], -1)\n            gt = torch.cat([free_gt, surf_gt, occ_gt], -1)\n            loss = (torch.log(pred + 1e-5) - torch.log(gt + 1e-5)).abs().sum()\n        elif loss_type == \"ce\":\n            # CE on free and occ and L1 on surf\n            pred = torch.cat([free_samples, occ_samples], -1)\n            gt = torch.cat([free_gt, occ_gt], -1)\n            loss = F.binary_cross_entropy(pred, gt, reduction=\"sum\")\n            loss = loss + (surf_samples - surf_gt).abs().sum()\n        elif loss_type == \"focal\":\n            # like used to using focal loss\n            pred = torch.cat([free_samples, surf_samples, occ_samples], -1)\n            gt = 
torch.cat([free_gt, surf_gt, occ_gt], -1)\n            loss = compute_focal_loss(pred, gt)\n        assert not loss.isnan().any(), (\n            f\"have nans in loss {loss.isnan().count_nonzero()}\"\n        )\n        # handle no samples case in mean\n        num = max(1.0, num)\n        loss = loss.sum() / num\n        return loss\n\n    assert loss_type == \"focal\", f\"{loss_type} not supported\"\n    loss_free = compute_focal_loss(free_samples, free_gt).sum()\n    loss_surf = compute_focal_loss(surf_samples, surf_gt).sum()\n    loss_occ = compute_focal_loss(occ_samples, occ_gt).sum()\n\n    # normalize by the number of samples; the summed losses are scalars, so\n    # calling numel() on them would always return 1\n    loss_free = loss_free / max(1.0, free_samples.numel())\n    loss_surf = loss_surf / max(1.0, surf_samples.numel())\n    loss_occ = loss_occ / max(1.0, occ_samples.numel())\n\n    loss = loss_free * free_surf_occ_weights[0] + loss_occ * free_surf_occ_weights[2]\n    return loss\n"
  },
  {
    "path": "efm3d/utils/render.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport colorsys\nfrom typing import Dict, Literal\n\nimport cv2\nimport numpy as np\nimport torch\nfrom efm3d.aria import CameraTW, ObbTW, PoseTW\nfrom efm3d.utils.image import put_text, torch2cv2\n\n\nAXIS_COLORS_RGB = {\n    0: (255, 0, 0),  # red\n    3: (0, 255, 0),  # green\n    8: (0, 0, 255),  # blue\n}  # use RGB for xyz axes respectively\n\n\ndef get_colors(num_colors: int, scale_to_255: bool = False):\n    assert num_colors > 0, f\"Number of colors {num_colors} has to be positive.\"\n    colors = []\n    for i in range(num_colors):\n        hue = i / num_colors  # Spread out the colors in the hue space\n        saturation = 1.0  # Use maximum saturation for bright colors\n        value = 1.0  # Use maximum value for bright colors\n        rgb = colorsys.hsv_to_rgb(hue, saturation, value)\n        if scale_to_255:\n            colors.append((int(rgb[0] * 255), int(rgb[1] * 255), int(rgb[2] * 255)))\n        else:\n            colors.append((rgb[0], rgb[1], rgb[2]))\n    return colors\n\n\n# RGB values in [0, 1] used in Static Structure Index\nSSI_SEM_COLORS = {\n    \"floor\": (1, 0.75, 0.75),\n    \"mirror\": (0.5, 0.5, 0.5),\n    \"ceiling\": (1, 1, 0.75),\n    \"chair\": (0.2, 0.6, 1),\n    \"bench\": (0.2, 0.6, 1),\n    \"ottoman\": (0.2, 0.6, 1),\n    \"table\": (1, 1, 0),\n    \"desk\": (1, 1, 0),\n    \"storage\": (0.7, 0.4, 
0.05),\n    \"plant\": (0, 1, 0),\n    \"plant_or_flower_pot\": (0, 1, 0),\n    \"vase\": (0, 1, 0),\n    \"screen\": (1, 0, 0),\n    \"wallart\": (0.6, 0.3, 0.95),\n    \"picture_frame_or_painting\": (0.6, 0.3, 0.95),\n    \"bed\": (0.55, 0.9, 0),\n    # \"couch\": (0, 1, 1),   # SSI color\n    \"couch\": (0.1, 0.5, 0.1),  # dark green\n    # \"sofa\": (0, 1, 1),    # SSI color\n    \"sofa\": (0.1, 0.5, 0.1),  # dark green\n    \"wall\": (1, 1, 1),\n    \"lamp\": (1, 0.8, 0.25),\n    \"door\": (0.95, 0.25, 0.85),\n    \"window\": (0.5, 1, 1),\n    \"unknown\": (0.4, 0.4, 0.8),\n    \"other\": (0.6, 0.6, 0.6),\n    # hard code 'floor_mat' to dark red\n    \"floor_mat\": (0.8, 0.15, 0.15),  # dark red\n}\n\n\ndef get_colors_from_sem_map(\n    sem_ids_to_names: Dict[int, str],\n    scale_to_255: bool = True,\n    match_with_ssi: bool = True,\n):\n    \"\"\"\n    sem_ids_to_names: taxonomy map from semantic id to semantic name.\n    scale_to_255: whether to scale the colors to [0, 255].\n    match_with_ssi: whether to match the colors with the Static Structure Index taxonomy for\n    the overlapped classes.\n    \"\"\"\n\n    if len(sem_ids_to_names) == 0:\n        num_sem_ids = 1\n    else:\n        num_sem_ids = max(sem_ids_to_names.keys()) + 1\n    colors = get_colors(num_sem_ids, scale_to_255=scale_to_255)\n\n    if match_with_ssi:\n        for sem_id, sem_name in sem_ids_to_names.items():\n            sn = sem_name.lower()\n            if sn in SSI_SEM_COLORS:\n                clr = SSI_SEM_COLORS[sn]\n                if scale_to_255:\n                    clr2 = (\n                        int(round(clr[0] * 255)),\n                        int(round(clr[1] * 255)),\n                        int(round(clr[2] * 255)),\n                    )\n                else:\n                    clr2 = clr\n                colors[sem_id] = clr2\n\n    return colors\n\n\ndef draw_bb2s(\n    viz,\n    bb2s,\n    line_type=cv2.LINE_AA,\n    bb2s_center=None,\n    labels=None,\n    
rotate_text=True,\n    color=None,\n    text_size=0.6,\n):\n    \"\"\"\n    Args:\n        viz: numpy array image\n        bb2s: a list of bounding boxes as numpy array Nx 4 where (x_min, x_max, y_min, y_max) per row\n        color: either a 3-tuple/list or a list 3-tuples, or an np.array shaped Nx3\n    \"\"\"\n    height = viz.shape[0]\n    if height > 320:\n        thickness = 2\n    else:\n        thickness = 1\n\n    if color is None:\n        color = (255, 100, 100)  # brighter red\n\n    if bb2s.shape[0] == 0:\n        return viz\n\n    def _draw_bb2_line(img, p1, p2, clr):\n        cv2.line(img, p1, p2, clr, thickness, lineType=line_type)\n\n    if isinstance(color[0], (list, tuple, np.ndarray)):\n        assert len(color) == len(bb2s), (\n            \"need either single color or same # of colors as bb2s\"\n        )\n        if isinstance(color[0], np.ndarray):\n            colors = [clr.tolist() for clr in color]\n        else:\n            colors = color\n    elif isinstance(color[0], (int, float)):\n        colors = [color for _ in range(len(bb2s))]\n    else:\n        raise TypeError(\"Unknown type for 'color' argument of draw_bb2s()\")\n\n    for i, (bb2, clr) in enumerate(zip(bb2s, colors)):\n        x_min, y_min = int(round(bb2[0].item())), int(round(bb2[2].item()))  # min pt\n        x_max, y_max = int(round(bb2[1].item())), int(round(bb2[3].item()))  # max pt\n        # if x_min < 0 or y_min < 0:\n        #    print(\"WARNING line point outside image\")\n        _draw_bb2_line(viz, (x_min, y_min), (x_min, y_max), clr)\n        _draw_bb2_line(viz, (x_min, y_max), (x_max, y_max), clr)\n        _draw_bb2_line(viz, (x_max, y_max), (x_max, y_min), clr)\n        _draw_bb2_line(viz, (x_max, y_min), (x_min, y_min), clr)\n        if bb2s_center is not None:\n            cx = int(round(float(bb2s_center[i, 0])))\n            cy = int(round(float(bb2s_center[i, 1])))\n            cv2.circle(viz, (cx, cy), 1, clr, 1, lineType=line_type)\n        if labels is 
not None:\n            text = labels[i]\n            x = int(round((x_min + x_max) / 2.0))\n            y = int(round((y_min + y_max) / 2.0))\n\n            if rotate_text:\n                viz = cv2.rotate(viz, cv2.ROTATE_90_CLOCKWISE)\n                center_rot90 = (height - y, x)\n                x, y = center_rot90\n            ((txt_w, txt_h), _) = cv2.getTextSize(\n                text, cv2.FONT_HERSHEY_DUPLEX, text_size, 1\n            )\n            x = x - int(round(txt_w / 4))\n            y = y + int(round(txt_h / 4))\n            put_text(viz, text, scale=text_size, font_pt=(x, y))\n            if rotate_text:\n                viz = cv2.rotate(viz, cv2.ROTATE_90_COUNTERCLOCKWISE)\n\n    return viz\n\n\ndef draw_bb3_lines(\n    viz,\n    T_world_cam: PoseTW,\n    cam: CameraTW,\n    obbs: ObbTW,\n    draw_cosy: bool,\n    T: int,\n    line_type=cv2.LINE_AA,\n    colors=None,\n    thickness=1,\n):\n    bb3corners_world = obbs.T_world_object * obbs.bb3edge_pts_object(T)\n    bb3corners_cam = T_world_cam.inverse() * bb3corners_world\n    B = bb3corners_cam.shape[0]\n    pt3s_cam = bb3corners_cam.view(B, -1, 3)\n    pt2s, valids = cam.project(pt3s_cam)\n    sem_ids = obbs.sem_id.int()\n    # reshape to lines each composed of T segments\n    pt2s = pt2s.round().int().view(B * 12, T, 2)\n    valids = valids.view(B * 12, T)\n    for line in range(pt2s.shape[0]):\n        line_id = line % 12\n        obb_id = line // 12\n        sem_id = sem_ids[obb_id]\n        # if colors is not None and sem_id >= len(colors):\n        #     print(\"warning sem_id too big\", sem_id, len(colors))\n        if colors is None or sem_id >= len(colors):\n            color = (255, 255, 255)\n        else:\n            color = colors[sem_id]\n        for i in range(T - 1):\n            j = i + 1\n            if valids[line, i] and valids[line, j]:\n                # check if we should color this line in a special way\n                if draw_cosy and line_id in AXIS_COLORS_RGB:\n     
               color = AXIS_COLORS_RGB[line_id]\n                pt1 = (\n                    int(round(float(pt2s[line, i, 0]))),\n                    int(round(float(pt2s[line, i, 1]))),\n                )\n                pt2 = (\n                    int(round(float(pt2s[line, j, 0]))),\n                    int(round(float(pt2s[line, j, 1]))),\n                )\n                cv2.line(\n                    viz,\n                    pt1,\n                    pt2,\n                    color,\n                    thickness,\n                    lineType=line_type,\n                )\n\n\ndef draw_bb3s(\n    viz,\n    T_world_rig: PoseTW,\n    cam: CameraTW,\n    obbs: ObbTW,\n    draw_bb3_center=False,\n    draw_bb3=True,\n    draw_label=False,\n    draw_cosy=True,\n    draw_score=True,\n    render_obb_corner_steps=10,\n    line_type=cv2.LINE_AA,\n    sem_id_to_name_mapping: Dict[int, str] = None,\n    rotate_label=True,\n    colors=None,\n    white_backing_line=True,\n    draw_inst_id=False,\n):\n    # Get pose of camera.\n    T_world_cam = T_world_rig.float() @ cam.T_camera_rig.inverse()\n    # Project the 3D BB center into the image.\n    if draw_bb3:\n        # auto set the thickness of the bb3 lines\n        thickness = 1\n\n        # draw white background lines\n        if white_backing_line:\n            draw_bb3_lines(\n                viz,\n                T_world_cam,\n                cam,\n                obbs,\n                draw_cosy=draw_cosy,\n                T=render_obb_corner_steps,\n                line_type=cv2.LINE_AA,\n                colors=None,\n                thickness=thickness + 1,\n            )\n        # draw semantic colors\n        draw_bb3_lines(\n            viz,\n            T_world_cam,\n            cam,\n            obbs,\n            draw_cosy=draw_cosy,\n            T=render_obb_corner_steps,\n            line_type=cv2.LINE_AA,\n            colors=colors,\n            thickness=thickness,\n        )\n\n    if draw_label 
or draw_bb3_center:\n        bb3center_cam = T_world_cam.inverse() * obbs.bb3_center_world\n        bb2center_im, valids = cam.unsqueeze(0).project(bb3center_cam.unsqueeze(0))\n        bb2center_im, valids = bb2center_im.squeeze(0), valids.squeeze(0)\n        for idx, (pt2, valid) in enumerate(zip(bb2center_im, valids)):\n            if valid:\n                center = (int(pt2[0]), int(pt2[1]))\n                if draw_bb3_center:\n                    cv2.circle(viz, center, 3, (255, 0, 0), 1, lineType=line_type)\n                if draw_label:\n                    height = viz.shape[0]\n                    sem_id = int(obbs.sem_id.squeeze(-1)[idx])\n                    if sem_id_to_name_mapping and sem_id in sem_id_to_name_mapping:\n                        text = sem_id_to_name_mapping[sem_id]\n                    else:\n                        # display sem_id if no mapping is provided.\n                        text = str(sem_id)\n                    if draw_inst_id:\n                        inst_id = int(obbs.inst_id.squeeze(-1)[idx])\n                        text = f\"{inst_id}: {text}\"\n                    # rot 90 degree before drawing the text\n                    if rotate_label:\n                        viz = cv2.rotate(viz, cv2.ROTATE_90_CLOCKWISE)\n                        center_rot90 = (height - center[1], center[0])\n                        x, y = center_rot90\n                    else:\n                        x, y = center\n                    ((txt_w, txt_h), _) = cv2.getTextSize(\n                        text, cv2.FONT_HERSHEY_DUPLEX, 0.4, 1\n                    )\n                    x = x - txt_w // 4\n                    y = y + txt_h // 4\n\n                    # Show text on top of the 3d boxes\n                    bb2_ymin = obbs.bb2_rgb[idx][2]\n                    bb2_ymax = obbs.bb2_rgb[idx][3]\n                    up = int((bb2_ymax - bb2_ymin) / 2.0)\n                    if y - up > 0:\n                        put_text(viz, text, 
scale=0.8, font_pt=(x, y - up))\n                        if draw_score and obbs.prob is not None:\n                            score = float(obbs.prob.squeeze(-1)[idx])\n                            score_text = f\"{score:.2f}\"\n                            score_pos = (x, y + int(txt_h + 0.5) - up)\n                            put_text(\n                                viz,\n                                score_text,\n                                scale=0.5,\n                                font_pt=score_pos,\n                                color=(200, 200, 200),\n                            )\n\n                    if rotate_label:\n                        viz = cv2.rotate(viz, cv2.ROTATE_90_COUNTERCLOCKWISE)\n    return viz\n\n\ndef draw_obbs_image(\n    img: torch.Tensor,\n    obbs_padded: ObbTW,\n    T_world_rig: PoseTW = None,\n    cam: CameraTW = None,\n    aria_cam_id: Literal[0, 1, 2] = 0,\n    draw_bb2=False,\n    draw_bb3=True,\n    draw_bb3_center=False,\n    draw_label=False,\n    draw_cosy=True,\n    draw_score=False,\n    render_obb_corner_steps=10,\n    post_rotate_viz=True,  # whether to rotate the image 90 degrees before (pre) or after (post) rendering 3d bbs; only for debugging. 
resulting image should be the same!\n    rgb2bgr=True,\n    rotate_viz=True,\n    background_sem_id: int = None,\n    prob_threshold: float = 0.5,\n    sem_id_to_name_mapping: Dict[int, str] = None,\n    draw_label_2d: bool = False,  # Draw label on 2D viz also.\n    white_backing_line: bool = True,\n    draw_inst_id: bool = False,\n    draw_conic: bool = False,\n):\n    assert img.dim() == 3, f\"image input must be 3D tensor {img.shape}\"\n    assert obbs_padded.dim() == 2, (\n        f\"assuming one set of obbs per frame {obbs_padded.shape}\"\n    )\n\n    viz = torch2cv2(img, rotate=False, ensure_rgb=True, rgb2bgr=rgb2bgr)\n    if not post_rotate_viz and rotate_viz:\n        viz = cv2.rotate(viz, cv2.ROTATE_90_CLOCKWISE)\n\n    # get valid obbs\n    obbs = obbs_padded.remove_padding()\n    if obbs.shape[0] == 0:  # Handle no valid OBBs.\n        if post_rotate_viz and rotate_viz:\n            viz = cv2.rotate(viz, cv2.ROTATE_90_CLOCKWISE)\n        return viz\n\n    # filter out low probability obbs\n    good = obbs.prob >= prob_threshold\n    colors = None\n    if sem_id_to_name_mapping is not None:\n        colors = get_colors_from_sem_map(sem_id_to_name_mapping)\n    if obbs.shape[0] > 0 and good.any():\n        obbs = obbs[good.squeeze(-1), :]\n        # if we have background id given, then filter out background obbs\n        if background_sem_id is not None:\n            background = obbs.sem_id == background_sem_id\n            obbs = obbs[~background.squeeze(-1), :]\n        if obbs.shape[0] > 0:\n            # Draw 2D bounding box.\n            if not draw_label_2d or sem_id_to_name_mapping is None:\n                labels = None\n            else:\n                sem_id = obbs.sem_id.squeeze(-1)\n                labels = [sem_id_to_name_mapping[int(si)] for si in sem_id]\n                if draw_inst_id:\n                    inst_ids = obbs.inst_id.squeeze(-1)\n                    labels = [f\"{inst}:{n}\" for inst, n in zip(inst_ids, labels)]\n         
   if draw_bb2:\n                viz = draw_bb2s(\n                    viz,\n                    obbs.bb2(aria_cam_id),\n                    bb2s_center=obbs.get_bb2_centers(aria_cam_id),\n                    labels=labels,\n                )\n\n            if draw_conic and cam and T_world_rig:\n                pass\n\n            # Draw 3D bounding box (requires poses from VIO).\n            if cam and T_world_rig and (draw_bb3 or draw_bb3_center):\n                if not post_rotate_viz and rotate_viz:\n                    cam = cam.rotate_90_cw()\n                viz = draw_bb3s(\n                    viz,\n                    T_world_rig,\n                    cam,\n                    obbs,\n                    draw_bb3_center=draw_bb3_center,\n                    draw_bb3=draw_bb3,\n                    draw_label=draw_label,\n                    draw_cosy=draw_cosy,\n                    draw_score=draw_score,\n                    render_obb_corner_steps=render_obb_corner_steps,\n                    sem_id_to_name_mapping=sem_id_to_name_mapping,\n                    rotate_label=rotate_viz,\n                    colors=colors,\n                    white_backing_line=white_backing_line,\n                    draw_inst_id=draw_inst_id,\n                )\n\n    # Rotate everything before displaying.\n    if post_rotate_viz and rotate_viz:\n        viz = cv2.rotate(viz, cv2.ROTATE_90_CLOCKWISE)\n    return viz\n\n\ndef draw_obbs_snippet(\n    imgs: torch.Tensor,\n    obbs_padded: ObbTW,\n    Ts_world_rig: PoseTW = None,\n    cams: CameraTW = None,\n    aria_cam_id: Literal[0, 1, 2] = 0,\n    draw_bb2=True,\n    draw_bb3=True,\n    draw_bb3_center=False,\n    render_obb_corner_steps=10,\n    post_rotate_viz=True,  # whether to rotate the image 90 degrees before (pre) or after (post) rendering 3d bbs; only for debugging. 
resulting image should be the same!\n    rgb2bgr=True,\n    rotate_viz=True,\n    background_sem_id: int = None,\n    prob_threshold: float = 0.5,\n    sem_id_to_name_mapping: Dict[int, str] = None,\n    draw_label: bool = False,\n    draw_label_2d: bool = False,  # Draw label on 2D viz also.\n    white_backing_line: bool = True,\n    draw_cosy: bool = True,\n    draw_score: bool = False,\n    draw_inst_id: bool = False,\n    draw_conic: bool = False,\n):\n    assert imgs.dim() == 4, f\"snippet input must be 4D tensor {imgs.shape}\"\n    T = imgs.shape[0]\n    viz = []\n    for t in range(T):\n        if obbs_padded.dim() == 2:\n            cur_obbs_padded = obbs_padded\n        elif obbs_padded.dim() == 3:\n            cur_obbs_padded = obbs_padded[t]\n        else:\n            raise ValueError(\n                f\"obbs_padded must have 2 or 3 dimensions {obbs_padded.shape}\"\n            )\n\n        viz.append(\n            draw_obbs_image(\n                img=imgs[t],\n                obbs_padded=cur_obbs_padded,\n                T_world_rig=Ts_world_rig[t],\n                cam=cams[t],\n                aria_cam_id=aria_cam_id,\n                draw_bb2=draw_bb2,\n                draw_bb3=draw_bb3,\n                draw_bb3_center=draw_bb3_center,\n                render_obb_corner_steps=render_obb_corner_steps,\n                post_rotate_viz=post_rotate_viz,\n                rgb2bgr=rgb2bgr,\n                rotate_viz=rotate_viz,\n                background_sem_id=background_sem_id,\n                prob_threshold=prob_threshold,\n                sem_id_to_name_mapping=sem_id_to_name_mapping,\n                draw_label=draw_label,\n                draw_label_2d=draw_label_2d,\n                white_backing_line=white_backing_line,\n                draw_cosy=draw_cosy,\n                draw_score=draw_score,\n                draw_inst_id=draw_inst_id,\n                draw_conic=draw_conic,\n            )\n        )\n    return viz\n\n\ndef 
discretize_values(values: torch.Tensor, precision: int):\n    \"\"\"\n    Discretize the values of an input tensor with a given precision. The lower the precision, the coarser the output.\n    Duplicate rows are removed after discretization, which thins out a dense point cloud so that it renders more cleanly.\n    \"\"\"\n    d_values = (values * precision).int()\n    d_values = (torch.unique(d_values, dim=0) / precision).float()\n    return d_values\n"
  },
  {
    "path": "efm3d/utils/rescale.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Literal\n\nimport cv2\nimport numpy as np\nimport torch\nfrom efm3d.aria import CameraTW, ObbTW\nfrom efm3d.aria.aria_constants import RESOLUTION_MAP\n\n\ndef get_crops_scale(\n    W: int,\n    H: int,\n    cam_name: Literal[\"rgb\", \"slaml\", \"slamr\"],\n    down_scale: Literal[0, 1, 2, 4, 5, 6, 7, 8, 9] = 0,\n    wh_multiple_of: int = 16,\n):\n    # Pre-cropping is universal to all down_scaling.\n    # Handle RGB properly with binning\n    pre_crop = None\n    if cam_name == \"rgb\" and W == 2880 and H == 2880:\n        # crop image to 2816x2816\n        pre_crop = [32, 32]\n        H, W = H - 64, W - 64\n\n    if down_scale in [0, 1, 2, 4]:\n        factor = 1\n        if cam_name == \"rgb\":\n            if W == 2816 and H == 2816:\n                # downsample to 1408x1408\n                factor = 2\n            if down_scale > 0:\n                factor = 2 * down_scale * factor\n        else:\n            factor = down_scale\n        if factor <= 1:\n            factor = None\n        if factor:\n            # W, H after scaling down\n            W, H = W // factor, H // factor\n\n        # post-crop to reach size divisible by wh_multiple_of\n        w_crop = (W % wh_multiple_of) // 2\n        h_crop = (H % wh_multiple_of) // 2\n        post_crop = [w_crop, h_crop]\n        # set outputs none if they are not 
needed\n        if w_crop == 0 and h_crop == 0:\n            post_crop = None\n    elif down_scale in RESOLUTION_MAP:\n        if cam_name == \"rgb\":\n            target_h = RESOLUTION_MAP[down_scale][0]\n            target_w = RESOLUTION_MAP[down_scale][0]\n        elif cam_name in [\"slaml\", \"slamr\"]:\n            target_w = RESOLUTION_MAP[down_scale][1]\n            target_h = RESOLUTION_MAP[down_scale][2]\n        else:\n            raise ValueError(\"Specified cam_name of %s is not supported\" % cam_name)\n\n        if target_h % wh_multiple_of != 0 or target_w % wh_multiple_of != 0:\n            raise ValueError(\n                f\"only wh_multiple_of 16 is guaranteed when using down_scale in [5,6,7,8,9] {target_h} % {wh_multiple_of}\"\n            )\n\n        # This rescale factor can be non-integer.\n        factor_w = W / target_w\n        factor_h = H / target_h\n        assert factor_w == factor_h, (\n            \"rescale factor must maintain original aspect ratio\"\n        )\n        factor = factor_w\n        post_crop = None\n    else:\n        raise ValueError(\"Specified down_scale of %d is not supported\" % down_scale)\n\n    return pre_crop, factor, post_crop\n\n\ndef rescale_camera_tw(\n    cam: CameraTW,\n    cam_size_before,  # tuple of (height, width, ...)\n    cam_name: Literal[\"rgb\", \"slaml\", \"slamr\"],\n    down_scale: Literal[0, 1, 2, 4, 5, 6, 7, 8, 9] = 0,\n    wh_multiple_of: int = 16,\n):\n    \"\"\"\n    Rescale CameraTW tensors by passing the camera size, camera name, and a down scale factor.\n    cam shape should be [..., N] where N is the valid camera calibration dimension (25 or 33)\n    \"\"\"\n\n    H, W = cam_size_before[:2]\n\n    if (cam.c > 1000.0).any():\n        # it can happen that the calibration was stored with respect to the full\n        # 2880 x 2880 resolution although the rgb video stream is binned to 1408 x\n        # 1408. 
We catch this by looking at the principal point which should be\n        # about [704, 704] and fix the calibration.\n        H, W = 2880, 2880\n        if cam.valid_radius[0].item() < 1000.0:\n            # it is likely that the valid_radius was set on the wrong cam\n            # size (2x too small) so we fix it here.\n            cam.set_valid_radius(cam.valid_radius * 2.0)\n\n    pre_crop, factor, post_crop = get_crops_scale(\n        W, H, cam_name, down_scale, wh_multiple_of\n    )\n    if pre_crop:\n        # new width and height after center crop\n        W, H = W - 2 * pre_crop[0], H - 2 * pre_crop[1]\n        cam = cam.crop(pre_crop, (W, H))\n    if factor:\n        cam = cam.scale(1.0 / factor)\n        # after scaling\n        W, H = W // factor, H // factor\n    if post_crop:\n        # new width and height after center crop\n        W, H = W - 2 * post_crop[0], H - 2 * post_crop[1]\n        cam = cam.crop(post_crop, (W, H))\n\n    return cam\n\n\ndef rescale_calib(\n    calib,\n    cam_size_before,  # tuple of (height, width, ...)\n    cam_name: Literal[\"rgb\", \"slaml\", \"slamr\"],\n    down_scale: Literal[0, 1, 2, 4, 5, 6, 7, 8, 9] = 0,\n    wh_multiple_of: int = 16,\n):\n    \"\"\"\n    rescale raw camera parameters\n    \"\"\"\n    # fisheye264\n    assert calib.shape[-1] == 15\n\n    H, W = cam_size_before[:2]\n    # it can happen that the calibration was stored with respect to the full\n    # 2880 x 2880 resolution although the rgb video stream is binned to 1408 x\n    # 1408. 
We catch this by looking at the principal point which should be\n    # about [704, 704] and fix the calibration.\n    if (calib[1:3] > 1000.0).any():\n        H, W = 2880, 2880\n\n    pre_crop, factor, post_crop = get_crops_scale(\n        W, H, cam_name, down_scale, wh_multiple_of\n    )\n    if pre_crop:\n        calib[1:3] = calib[1:3] - np.array(pre_crop)\n    if factor:\n        calib[0] = calib[0] / factor\n        calib[1:3] = (calib[1:3] + 0.5) / factor - 0.5\n    if post_crop:\n        calib[1:3] = calib[1:3] - np.array(post_crop)\n\n    return calib\n\n\ndef rescale_image(\n    img,\n    cam_name: Literal[\"rgb\", \"slaml\", \"slamr\"],\n    down_scale: Literal[0, 1, 2, 4, 5, 6, 7, 8, 9] = 0,\n    wh_multiple_of: int = 16,\n):\n    H, W = img.shape[:2]\n    pre_crop, factor, post_crop = get_crops_scale(\n        W, H, cam_name, down_scale, wh_multiple_of\n    )\n    if pre_crop:\n        img = img[pre_crop[1] : H - pre_crop[1], pre_crop[0] : W - pre_crop[0], ...]\n\n    if factor:\n        # When factor is integer, then cv2.INTER_AREA behaves identically\n        # to skimage.downscale_local_mean, as described in the blog post:\n        # https://medium.com/@wenrudong/what-is-opencvs-inter-area-actually-doing-282a626a09b3\n        H, W = img.shape[:2]\n        target_wh = int(round(W / factor)), int(round(H / factor))\n        orig_ndim = img.ndim\n        img = cv2.resize(img, target_wh, interpolation=cv2.INTER_AREA)\n        if orig_ndim == 3 and img.ndim == 2:\n            img = np.expand_dims(img, axis=2)  # Preserve HxWx1 vs HxW to match input.\n    if post_crop:\n        H, W = img.shape[:2]\n        img = img[post_crop[1] : H - post_crop[1], post_crop[0] : W - post_crop[0], ...]\n    return img\n\n\ndef rescale_image_tensor(\n    img,\n    cam_name: Literal[\"rgb\", \"slaml\", \"slamr\"],\n    down_scale: Literal[0, 1, 2, 4, 5, 6, 7, 8, 9] = 0,\n    wh_multiple_of: int = 16,\n    interpolate_mode: str = \"bilinear\",\n):\n    \"\"\"\n    Rescale 
the Aria image tensor. `img` is a torch Tensor, which is expected to have\n    [..., H, W] shape, where ... means an arbitrary number of leading dimensions.\n    `down_scale` specifies the degree of down-sampling.\n    \"\"\"\n    from torchvision.transforms.functional import InterpolationMode, resize\n\n    str2torchvision_mapping = {\n        \"nearest\": InterpolationMode.NEAREST,\n        \"nearest-exact\": InterpolationMode.NEAREST_EXACT,\n        \"bilinear\": InterpolationMode.BILINEAR,\n        \"bicubic\": InterpolationMode.BICUBIC,\n        \"box\": InterpolationMode.BOX,\n        \"hamming\": InterpolationMode.HAMMING,\n        \"lanczos\": InterpolationMode.LANCZOS,\n    }\n\n    H, W = img.shape[-2:]\n    pre_crop, factor, post_crop = get_crops_scale(\n        W, H, cam_name, down_scale, wh_multiple_of\n    )\n    if pre_crop:\n        img = img[..., pre_crop[1] : H - pre_crop[1], pre_crop[0] : W - pre_crop[0]]\n\n    if factor:\n        H, W = img.shape[-2:]\n        target_hw = int(round(H / factor)), int(round(W / factor))\n        img = resize(\n            img,\n            target_hw,\n            interpolation=str2torchvision_mapping[interpolate_mode],\n            antialias=True,\n        )\n    if post_crop:\n        H, W = img.shape[-2:]\n        img = img[..., post_crop[1] : H - post_crop[1], post_crop[0] : W - post_crop[0]]\n    return img\n\n\ndef rescale_depth_img(\n    depth_img, scale_down, filter_boundary=True, valid=None, wh_multiple_of: int = 16\n):\n    # Use torch to re-scale since opencv doesn't re-scale.\n    # And make sure it's 1xHxW\n    depth_img = torch.tensor(depth_img).squeeze().unsqueeze(0)\n    depth_img_rescale = rescale_image_tensor(\n        depth_img,\n        \"rgb\",\n        scale_down,\n        wh_multiple_of=wh_multiple_of,\n        interpolate_mode=\"nearest\",\n    )\n    if not filter_boundary:\n        return depth_img_rescale\n\n    # Change the mask to float to capture the boundaries of invalid area.\n    
if valid is None:\n        d_mask = (depth_img > 0).float()\n    else:\n        d_mask = torch.tensor(valid).float().unsqueeze(0)\n\n    d_mask_rescale = rescale_image_tensor(\n        d_mask,\n        \"rgb\",\n        scale_down,\n        wh_multiple_of=wh_multiple_of,\n        interpolate_mode=\"nearest\",\n    )\n    # only the mask pixels that are close to 1.0 are valid.\n    depth_img_rescale[abs(d_mask_rescale - 1.0) > 1e-5] = 0.0\n    return depth_img_rescale\n\n\ndef rescale_obb_tw(\n    obbs: ObbTW,\n    cam_size_before_rgb,\n    cam_size_before_slam,\n    down_scale: Literal[0, 1, 2, 4, 5, 6, 7, 8, 9] = 0,\n    wh_multiple_of: int = 16,\n):\n    \"\"\"\n    Rescale ObbTW 2d bb tensors given the RGB and SLAM camera sizes and a down scale factor.\n    \"\"\"\n    H_rgb, W_rgb = cam_size_before_rgb[:2]\n    H_slam, W_slam = cam_size_before_slam[:2]\n    pre_crop_rgb, factor_rgb, post_crop_rgb = get_crops_scale(\n        W_rgb, H_rgb, \"rgb\", down_scale, wh_multiple_of\n    )\n    pre_crop_slam, factor_slam, post_crop_slam = get_crops_scale(\n        W_slam, H_slam, \"slaml\", down_scale, wh_multiple_of\n    )\n    if pre_crop_rgb or pre_crop_slam:\n        if not pre_crop_rgb:\n            pre_crop_rgb = [0, 0]\n        if not pre_crop_slam:\n            pre_crop_slam = [0, 0]\n        obbs = obbs.crop_bb2(left_top_rgb=pre_crop_rgb, left_top_slam=pre_crop_slam)\n    if factor_rgb or factor_slam:\n        # a missing factor means that camera is not rescaled, so use 1.0\n        if not factor_slam:\n            factor_slam = 1.0\n        if not factor_rgb:\n            factor_rgb = 1.0\n        obbs = obbs.scale_bb2(scale_rgb=1.0 / factor_rgb, scale_slam=1.0 / factor_slam)\n    if post_crop_rgb or post_crop_slam:\n        if not post_crop_rgb:\n            post_crop_rgb = [0, 0]\n        if not post_crop_slam:\n            post_crop_slam = [0, 0]\n        obbs = obbs.crop_bb2(left_top_rgb=post_crop_rgb, left_top_slam=post_crop_slam)\n\n    return obbs\n"
  },
  {
    "path": "efm3d/utils/viz.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport platform\nfrom typing import Optional, Tuple, Union\n\nimport moderngl\nimport numpy as np\nimport torch\nfrom efm3d.aria.aria_constants import (\n    ARIA_CALIB,\n    ARIA_CALIB_TIME_NS,\n    ARIA_DISTANCE_M,\n    ARIA_DISTANCE_M_PRED,\n    ARIA_IMG_T_SNIPPET_RIG,\n    ARIA_MESH_FACES,\n    ARIA_MESH_VERT_NORMS_W,\n    ARIA_MESH_VERTS_W,\n    ARIA_OBB_PADDED,\n    ARIA_OBB_PRED_VIZ,\n    ARIA_OBB_TRACKED,\n    ARIA_OBB_UNINST,\n    ARIA_POINTS_WORLD,\n    ARIA_POSE_T_SNIPPET_RIG,\n    ARIA_POSE_TIME_NS,\n    ARIA_SNIPPET_T_WORLD_SNIPPET,\n)\nfrom efm3d.aria.camera import CameraTW\nfrom efm3d.aria.obb import BB3D_LINE_ORDERS, OBB_LINE_INDS, OBB_MESH_TRI_INDS, ObbTW\nfrom efm3d.aria.pose import PoseTW\nfrom efm3d.utils.common import sample_nearest\nfrom efm3d.utils.depth import dist_im_to_point_cloud_im\nfrom efm3d.utils.gravity import gravity_align_T_world_cam, GRAVITY_DIRECTION_VIO\nfrom efm3d.utils.render import discretize_values, get_colors_from_sem_map\nfrom PIL import Image\nfrom torch.nn import functional as F\n\n# mapping from edge ids to colors for visualizing the xyz axes\nAXIS_COLORS_GL = {\n    0: (1.0, 0.0, 0.0, 1.0),  # red\n    3: (0.0, 1.0, 0.0, 1.0),  # green\n    8: (0.0, 0.0, 1.0, 1.0),  # blue\n}  # use RGB for xyz axes respectively\n\n\ndef render_points(pts, rgba, prog=None, ctx=None, point_size=1.0, scene=None):\n  
  if isinstance(pts, torch.Tensor):\n        pts = pts.detach().cpu().float().numpy()\n    if pts.shape[0] == 0:\n        return\n    if scene is not None:\n        prog, ctx = scene.prog, scene.ctx\n    prog[\"global_color\"].value = rgba\n    prog[\"point_size\"].value = point_size\n    vbo = ctx.buffer(pts.astype(\"float32\").tobytes())\n    vao = ctx.vertex_array(prog, [(vbo, \"3f\", \"in_vert\")])\n    vao.render(moderngl.POINTS)\n\n    vao.release()\n    vbo.release()\n\n\ndef render_cubes(centers, bb3_halfdiag, prog, ctx, rgb=None):\n    cs = centers.reshape(-1, 3)\n    offs = [\n        torch.tensor([-1.0, -1.0, -1.0], device=cs.device),\n        torch.tensor([1.0, -1.0, -1.0], device=cs.device),\n        torch.tensor([1.0, 1.0, -1.0], device=cs.device),\n        torch.tensor([-1.0, 1.0, -1.0], device=cs.device),\n        torch.tensor([-1.0, -1.0, 1.0], device=cs.device),\n        torch.tensor([1.0, -1.0, 1.0], device=cs.device),\n        torch.tensor([1.0, 1.0, 1.0], device=cs.device),\n        torch.tensor([-1.0, 1.0, 1.0], device=cs.device),\n    ]\n\n    offs = torch.stack(offs, dim=0)\n    corners = (\n        cs.unsqueeze(1) + (offs * bb3_halfdiag.unsqueeze(0)).unsqueeze(0)\n    ).clone()\n    tris = (\n        torch.tensor(OBB_MESH_TRI_INDS, dtype=torch.int32, device=cs.device)\n        .transpose(1, 0)\n        .unsqueeze(0)\n    )\n    tris_offset = 8 * torch.arange(\n        0, corners.shape[0], dtype=torch.int32, device=cs.device\n    ).view(-1, 1, 1)\n    tris = (tris + tris_offset).clone()\n    normals = F.normalize((offs * bb3_halfdiag.unsqueeze(0)), 2.0, -1)\n    normals = normals.unsqueeze(0).repeat(corners.shape[0], 1, 1).clone()\n\n    # render_rgb_points(corners, normals, prog, ctx)\n    if rgb is not None:\n        render_rgb_tri_mesh(corners, normals, tris, rgb, prog, ctx)\n    else:\n        render_tri_mesh(corners, normals, tris, prog, ctx)\n\n\ndef render_tri_mesh(pts, normals, tris, prog, ctx):\n    if isinstance(pts, 
torch.Tensor):\n        pts = pts.detach().cpu().float().numpy()\n    if isinstance(tris, torch.Tensor):\n        tris = tris.detach().cpu().numpy()\n    if isinstance(normals, torch.Tensor):\n        normals = normals.detach().cpu().float().numpy()\n    if pts.shape[0] == 0:\n        return\n    prog[\"point_size\"].value = 1.0\n    vbo = ctx.buffer(pts.astype(\"float32\").tobytes())\n    nbo = ctx.buffer(normals.astype(\"float32\").tobytes())\n    ibo = ctx.buffer(tris.astype(\"int32\").tobytes())\n    vao = ctx.vertex_array(\n        prog, [(vbo, \"3f\", \"in_vert\"), (nbo, \"3f\", \"in_normal\")], ibo\n    )\n    vao.render(moderngl.TRIANGLES)\n\n    vao.release()\n    ibo.release()\n    nbo.release()\n    vbo.release()\n\n\ndef render_rgb_tri_mesh(pts, normals, tris, rgb, prog, ctx):\n    if isinstance(pts, torch.Tensor):\n        pts = pts.detach().cpu().float().numpy()\n    if isinstance(tris, torch.Tensor):\n        tris = tris.detach().cpu().numpy()\n    if isinstance(normals, torch.Tensor):\n        normals = normals.detach().cpu().float().numpy()\n    if isinstance(rgb, torch.Tensor):\n        rgb = rgb.detach().cpu().float().numpy()\n    if pts.shape[0] == 0:\n        return\n    prog[\"point_size\"].value = 1.0\n    vbo = ctx.buffer(pts.astype(\"float32\").tobytes())\n    nbo = ctx.buffer(normals.astype(\"float32\").tobytes())\n    cbo = ctx.buffer(rgb.astype(\"float32\").tobytes())\n    ibo = ctx.buffer(tris.astype(\"int32\").tobytes())\n    vao = ctx.vertex_array(\n        prog,\n        [(vbo, \"3f\", \"in_vert\"), (nbo, \"3f\", \"in_normal\"), (cbo, \"3f\", \"in_rgb\")],\n        ibo,\n    )\n    vao.render(moderngl.TRIANGLES)\n\n    vao.release()\n    ibo.release()\n    cbo.release()\n    nbo.release()\n    vbo.release()\n\n\ndef render_scalar_field_points(\n    pts,\n    values,\n    prog,\n    ctx,\n    val_min=0.0,\n    val_max=1.0,\n    point_size=1.0,\n    alphas=None,\n):\n    assert pts.shape[-1] == 3, f\"only support 3d points 
{pts.shape}\"\n    assert pts.numel() == 3 * values.numel(), (\n        f\"pts must have 3x the numel of values {pts.numel()} {values.numel()}, {pts.shape} and {values.shape}\"\n    )\n\n    if isinstance(pts, torch.Tensor):\n        pts = pts.detach().cpu().float().numpy()\n    if isinstance(values, torch.Tensor):\n        values = values.detach().cpu().float().numpy()\n    if pts.shape[0] == 0:\n        return\n    if alphas is None:\n        alphas = np.ones_like(values)\n    elif isinstance(alphas, torch.Tensor):\n        alphas = alphas.detach().cpu().float().numpy()\n    prog[\"max_value\"].value = val_max\n    prog[\"min_value\"].value = val_min\n    prog[\"point_size\"].value = point_size\n    vbo = ctx.buffer(pts.astype(\"float32\").tobytes())\n    vbv = ctx.buffer(values.astype(\"float32\").tobytes())\n    vba = ctx.buffer(alphas.astype(\"float32\").tobytes())\n    vao = ctx.vertex_array(\n        prog, [(vbo, \"3f\", \"in_vert\"), (vbv, \"1f\", \"in_value\"), (vba, \"1f\", \"in_alpha\")]\n    )\n    vao.render(moderngl.POINTS)\n\n    vao.release()\n    vba.release()\n    vbv.release()\n    vbo.release()\n\n\ndef render_rgb_points(\n    pts,\n    rgb,\n    prog,\n    ctx,\n    point_size=1.0,\n):\n    if isinstance(pts, torch.Tensor):\n        pts = pts.detach().cpu().float().numpy()\n    if isinstance(rgb, torch.Tensor):\n        rgb = rgb.detach().cpu().float().numpy()\n    if pts.shape[0] == 0:\n        return\n    prog[\"point_size\"].value = point_size\n    vbo = ctx.buffer(pts.astype(\"float32\").tobytes())\n    cbo = ctx.buffer(rgb.astype(\"float32\").tobytes())\n    vao = ctx.vertex_array(prog, [(vbo, \"3f\", \"in_vert\"), (cbo, \"3f\", \"in_rgb\")])\n    vao.render(moderngl.POINTS)\n\n    vao.release()\n    cbo.release()\n    vbo.release()\n\n\ndef render_linestrip(pts, rgba, prog=None, ctx=None, scene=None):\n    if 
isinstance(pts, torch.Tensor):\n        pts = pts.detach().cpu().float().numpy()\n    if rgba is None:\n        rgba = (0.0, 0.0, 0.0, 1.0)\n    if pts.shape[0] == 0:\n        return\n    if scene is not None:\n        prog, ctx = scene.prog, scene.ctx\n    prog[\"global_color\"].value = rgba\n    vbo = ctx.buffer(pts.astype(\"float32\").tobytes())\n    vao = ctx.vertex_array(prog, vbo, \"in_vert\")\n    vao.render(moderngl.LINE_STRIP)\n\n    vao.release()\n    vbo.release()\n\n\ndef render_line(p0, p1, rgba, prog=None, ctx=None, scene=None):\n    if isinstance(p0, list):\n        p0 = np.array(p0)\n    if isinstance(p1, list):\n        p1 = np.array(p1)\n    if isinstance(p0, torch.Tensor):\n        p0 = p0.detach().cpu().numpy()\n    if isinstance(p1, torch.Tensor):\n        p1 = p1.detach().cpu().numpy()\n    if scene is not None:\n        prog, ctx = scene.prog, scene.ctx\n    pts = np.stack([p0, p1])\n    render_linestrip(pts, rgba=rgba, prog=prog, ctx=ctx)\n\n\ndef render_cosy(\n    T: Optional[PoseTW] = None, prog=None, ctx=None, scale: float = 0.1, scene=None\n):\n    if scene is not None:\n        prog, ctx = scene.prog, scene.ctx\n    if T is None:\n        T = PoseTW.from_Rt(torch.eye(3), torch.zeros(3))\n    T = T.cpu().detach()\n    ex = (T * torch.tensor([scale, 0.0, 0.0])).squeeze(0)\n    ey = (T * torch.tensor([0.0, scale, 0.0])).squeeze(0)\n    ez = (T * torch.tensor([0.0, 0.0, scale])).squeeze(0)\n    render_line(T.t, ex, rgba=(1.0, 0.0, 0.0, 1.0), prog=prog, ctx=ctx)\n    render_line(T.t, ey, rgba=(0.0, 1.0, 0.0, 1.0), prog=prog, ctx=ctx)\n    render_line(T.t, ez, rgba=(0.0, 0.0, 1.0, 1.0), prog=prog, ctx=ctx)\n\n\ndef render_frustum(\n    T_wr: PoseTW,\n    cam: CameraTW,\n    prog=None,\n    ctx=None,\n    rgba=(0, 0, 0, 1.0),\n    scale=0.2,\n    scene=None,\n):\n    \"\"\"\n    Draw the camera frustum of the given camera cam at the rig pose T_wr.\n    \"\"\"\n    assert T_wr.dim() == 1\n    assert cam.dim() == 1\n    cam = 
cam.cpu().detach()\n    T_wr = T_wr.cpu().detach()\n    if scene is not None:\n        prog, ctx = scene.prog, scene.ctx\n\n    def scaled_unproject(cam, pt2, scale):\n        pt3 = cam.unproject(pt2)[0]\n        pt3 = pt3 / torch.linalg.norm(pt3, dim=-1, keepdim=True)\n        return pt3 * scale\n\n    T_wc = T_wr @ cam.T_camera_rig.inverse()\n    T_wc = T_wc.detach().cpu()\n    c = cam.c\n    rs = cam.valid_radius * 0.7071  # multiply by sqrt(0.5) to get the diagonal\n    # valid get image corners\n    tl = (c + torch.FloatTensor([-rs[0], -rs[1]])).view(1, 1, -1)\n    tr = (c + torch.FloatTensor([-rs[0], rs[1]])).view(1, 1, -1)\n    br = (c + torch.FloatTensor([rs[0], rs[1]])).view(1, 1, -1)\n    bl = (c + torch.FloatTensor([rs[0], -rs[1]])).view(1, 1, -1)\n    # unproject to 3d\n    tl_w = (T_wc * scaled_unproject(cam, tl, scale)).squeeze()\n    tr_w = (T_wc * scaled_unproject(cam, tr, scale)).squeeze()\n    br_w = (T_wc * scaled_unproject(cam, br, scale)).squeeze()\n    bl_w = (T_wc * scaled_unproject(cam, bl, scale)).squeeze()\n    c_w = T_wc.t\n    # get line_strip\n    p3_w = torch.stack(\n        [tl_w, tr_w, br_w, bl_w, tl_w, c_w, tr_w, c_w, br_w, c_w, bl_w, c_w, tl_w], 0\n    )\n    return render_linestrip(p3_w.numpy(), rgba=rgba, prog=prog, ctx=ctx)\n\n\ndef render_obbs_line(\n    obbs: ObbTW,\n    prog=None,\n    ctx=None,\n    rgba=(0.0, 0.0, 0.0, 1.0),\n    colors=None,\n    color_alpha=1.0,\n    line_width=3.0,\n    draw_cosy=False,\n    scene=None,\n):\n    \"\"\"\n    Draw multiple oriented bounding boxes (obbs) each as a set of lines. 
obbs should be of shape N x C.\n    \"\"\"\n    assert obbs.dim() == 2, f\"{obbs.shape}\"\n    if scene is not None:\n        prog, ctx = scene.prog, scene.ctx\n    old_line_width = ctx.line_width\n    ctx.line_width = line_width\n    for obb in obbs:\n        sem_id = int(obb.sem_id)\n        if colors is not None and sem_id < len(colors):\n            rgb = colors[sem_id]\n            rgba = (rgb[0], rgb[1], rgb[2], color_alpha)\n        if obb.sem_id.item() >= 0:\n            render_obb_line(\n                obb,\n                prog,\n                ctx,\n                rgba=rgba,\n                draw_cosy=draw_cosy,\n            )\n    ctx.line_width = old_line_width\n\n\ndef get_color_from_id(sem_id, max_sem_id, rgba=None):\n    if sem_id:\n        rgba = (0.0, 0.0, 0.0, 1.0)\n    return rgba\n\n\ndef render_obb_line(obb: ObbTW, prog, ctx, rgba=None, draw_cosy=False):\n    \"\"\"\n    Draw line-based oriented bounding box (obb) for a single obb.\n    \"\"\"\n    assert obb.dim() == 1\n    p3_w = obb.bb3corners_world\n    if not draw_cosy:\n        # Draw with linestrip.\n        p3_w_strip = p3_w[OBB_LINE_INDS, :]\n        render_linestrip(p3_w_strip, rgba=rgba, prog=prog, ctx=ctx)\n    else:\n        # Draw lines one by one.\n        p3_w_all = p3_w[BB3D_LINE_ORDERS, :]\n        for i, p3 in enumerate(p3_w_all):\n            if i in AXIS_COLORS_GL:\n                cur_rgba = AXIS_COLORS_GL[i]\n            else:\n                cur_rgba = rgba\n            render_linestrip(p3, rgba=cur_rgba, prog=prog, ctx=ctx)\n\n\nclass SceneView:\n    \"\"\"\n    SceneView is a simple 3D scene renderer using OpenGL.\n    Simply follow the pattern:\n\n    # init the scene\n    sceneView = SceneView(...)\n\n    while something:\n        # clear render buffer\n        sceneView.clear()\n\n        # set view to camera pose\n        sceneView.set_follow_view(T_world_camera)\n        # OR set view to model view matrix (any matrix you want to)\n        
sceneView.set_view(MV)\n\n        # call any rendering functions using sceneView.ctx and sceneView.prog\n        ...\n\n        # finish the rendering and obtain the rendered image\n        img = sceneView.finish()\n        # display or save image\n\n    Here is a simple example to render a coordinate system at the origin:\n\n    ```\n    scene = SceneView(width=320, height=320)\n    scene.clear()\n    T_wc = PoseTW()\n    scene.set_default_view(T_wc, zoom_factor=6)\n    render_cosy(PoseTW(), ctx=scene.ctx, prog=scene.prog, scale=1.0)\n    img = np.array(scene.finish())\n    ```\n\n    \"\"\"\n\n    def __init__(\n        self,\n        width: int,\n        height: int,\n        z_near: float = 0.1,\n        z_far: float = 1000.0,\n        bg_color: Tuple[float, float, float] = (1.0, 1.0, 1.0),\n    ):\n        \"\"\"\n        Args:\n            width (int): width of rendered image.\n            height (int): height of rendered image.\n            z_near (float): near clipping plane.\n            z_far (float): far clipping plane.\n            bg_color (Tuple[float, float, float]): background color (0-1 range)\n        \"\"\"\n        self.width = width\n        self.height = height\n        self.z_near = z_near\n        self.z_far = z_far\n        self.bg_color = bg_color\n        self.ctx = init_egl_context()\n        if self.ctx is not None:\n            self.prog = simple_shader_program(self.ctx)\n            self.prog_scalar_field = scalar_field_shader_program(self.ctx)\n            self.prog_rgb_point_cloud = rgb_point_cloud_shader_program(self.ctx)\n            self.prog_mesh = mesh_normal_shader_program(self.ctx)\n            self.prog_mesh_rgb = mesh_rgb_shader_program(self.ctx)\n            # attach frame and depth buffer. 
Depth buffer is important to be able to\n            # do z-buffering!\n            self.fbo1 = self.ctx.framebuffer(\n                self.ctx.renderbuffer((width, height), samples=4),\n                self.ctx.depth_renderbuffer((width, height), samples=4),\n            )\n            self.fbo2 = self.ctx.framebuffer(\n                self.ctx.renderbuffer((width, height)),\n                self.ctx.depth_renderbuffer((width, height)),\n            )\n\n        # setup camera projection for rendering\n        fu, fv = self.width / 0.5, self.height / 0.5\n        self.f = min(fu, fv)\n        self.P = projection_matrix_rdf_top_left(\n            self.width,\n            self.height,\n            self.f,\n            self.f,\n            (self.width - 1.0) / 2,\n            (self.height - 1.0) / 2,\n            self.z_near,\n            self.z_far,\n        )\n\n    def valid(self):\n        return self.ctx is not None\n\n    def clear(self, bg_color: Optional[Tuple[float, float, float]] = None):\n        \"\"\"\n        clear the scene rendering buffer.\n        (call before any rendering!)\n\n        if bg_color is specified then this is used over the one specified during\n        construction.\n        \"\"\"\n        self.fbo1.use()\n        if bg_color is not None:\n            self.ctx.clear(\n                red=bg_color[0], green=bg_color[1], blue=bg_color[2], depth=1e4\n            )\n        else:\n            self.ctx.clear(\n                red=self.bg_color[0],\n                green=self.bg_color[1],\n                blue=self.bg_color[2],\n                depth=1e4,\n            )\n        # enable depth test, point size, blending, and cull backfacing mesh triangles\n        self.ctx.enable(\n            moderngl.DEPTH_TEST\n            | moderngl.PROGRAM_POINT_SIZE\n            | moderngl.BLEND\n            | moderngl.CULL_FACE\n        )\n\n    def set_default_view(self, T_world_camera: PoseTW, zoom_factor: float = 4.0):\n        \"\"\"\n        
set view to follow given T_world_camera behind and to the right of the T_wc.\n        \"\"\"\n        mv = get_mv(T_world_camera, zoom_factor=zoom_factor, position=[-1, -1, -2])\n        self.set_view(mv)\n\n    def set_follow_view(self, T_world_camera: PoseTW, zoom_factor: float = 4.0):\n        \"\"\"\n        set view to follow given T_world_camera.\n        \"\"\"\n        mv = get_mv(T_world_camera, zoom_factor=zoom_factor, position=[-1, 0, -2])\n        self.set_view(mv)\n\n    def set_birds_eye_view(self, T_world_camera: PoseTW, zoom_factor: float = 6.0):\n        \"\"\"\n        set view to a birds eye view given T_world_camera.\n        \"\"\"\n        mv = get_mv(T_world_camera, zoom_factor=zoom_factor, position=[-2, 0, -0.0001])\n        T_ahead = PoseTW.from_Rt(torch.eye(3), torch.tensor([0, -2, 0]))\n        mv = T_ahead.matrix.numpy() @ mv\n        self.set_view(mv)\n\n    def set_side_view(self, T_world_camera: PoseTW, zoom_factor: float = 6.0):\n        \"\"\"\n        set view to the left side of T_world_camera\n        \"\"\"\n        mv = get_mv(T_world_camera, zoom_factor=zoom_factor, position=[-0, 2, -0.0001])\n        T_ahead = PoseTW.from_Rt(torch.eye(3), torch.tensor([-2.5, 0, 0]))\n        mv = T_ahead.matrix.numpy() @ mv\n        self.set_view(mv)\n\n    def set_birds_eye_view_from_bb(\n        self, bb_scene_xyzxyz: torch.Tensor, zoom_factor: float = 6.0\n    ):\n        \"\"\"\n        set view to a birds eye view given bounding volume of scene\n        assumes gravity aligned coordinate system with z=up\n        \"\"\"\n        bb_min = bb_scene_xyzxyz[:3]\n        bb_max = bb_scene_xyzxyz[3:]\n        bb_diag = bb_max - bb_min\n        bb_center = (bb_max + bb_min) * 0.5\n        up = torch.tensor([0, 0, 1])\n        dz = bb_diag[0] * self.f / self.width\n        dz = max(dz, bb_diag[1] * self.f / self.height)\n        dz += bb_diag[2] * 0.5\n        dir_max = bb_diag / F.normalize(bb_diag, p=2, dim=0)\n        eye = bb_center + up * 
zoom_factor * dz\n        eye = bb_center + dir_max * zoom_factor * dz\n        eye = bb_max + up * zoom_factor * dz\n        mv = model_view_look_at_rdf(eye.numpy(), bb_center.numpy(), -up.numpy())\n        self.set_view(mv)\n\n    def set_view(self, mv: Union[PoseTW, np.array]):\n        \"\"\"\n        set view to model view matrix.\n        \"\"\"\n        if isinstance(mv, PoseTW):\n            mv = mv.matrix.numpy()\n        MVP = self.P @ mv\n        # important to transpose MVP since opengl is column-major!\n        self.prog[\"mvp\"].write(MVP.transpose().astype(\"float32\").tobytes())\n        self.prog_scalar_field[\"mvp\"].write(MVP.transpose().astype(\"float32\").tobytes())\n        self.prog_rgb_point_cloud[\"mvp\"].write(\n            MVP.transpose().astype(\"float32\").tobytes()\n        )\n        self.prog_mesh[\"mvp\"].write(MVP.transpose().astype(\"float32\").tobytes())\n        self.prog_mesh[\"mv\"].write(mv.transpose().astype(\"float32\").tobytes())\n        self.prog_mesh_rgb[\"mvp\"].write(MVP.transpose().astype(\"float32\").tobytes())\n        self.prog_mesh_rgb[\"mv\"].write(mv.transpose().astype(\"float32\").tobytes())\n\n    def finish(self):\n        \"\"\"\n        finish the scene rendering and return the rendered image as a PIL image.\n        (call after all rendering!)\n        \"\"\"\n        self.ctx.copy_framebuffer(self.fbo2, self.fbo1)\n        data = self.fbo2.read(components=3, alignment=1)\n        img = Image.frombytes(\"RGB\", self.fbo2.size, data)\n        img = img.transpose(Image.FLIP_LEFT_RIGHT)\n        return img\n\n\ndef draw_obb_scene_3d(\n    tgt,\n    T_ws,\n    Ts_sr,\n    cams,\n    frame_id=0,\n    tgt_removed=None,\n    sem_ids_to_names=None,\n    prd=None,\n    width=512,\n    height=512,\n    draw_origin=False,\n    draw_trajectory=True,\n    draw_frustum=True,\n    matcher=None,\n    prd_logits=None,\n    p3s_world=None,\n    p3s_pred_world=None,\n    depth_pred=None,\n    # optional scene object - if 
you want constant GPU memory allocate them once\n    # outside and pass them in to reuse.\n    scene: Optional[SceneView] = None,\n    z_height_clip=None,\n    render_raw_pred=True,\n    render_removed_pred=True,\n    cams_slaml=None,\n    cams_slamr=None,\n    zoom_factor=4.0,\n    bird_eye_view=False,\n    scene_mesh_v=None,\n    scene_mesh_f=None,\n    scene_mesh_n=None,\n    scene_mesh_T_wv=None,\n):\n    \"\"\"\n    Draw a 3D scene of obbs, camera trajectory and camera frustum. The scene is\n    selected via the frame_id which indexes into the snippet variables Ts_wr\n    (computed as T_ws @ Ts_sr) and cams, which are TxC.\n\n    Args:\n      tgt: target obbs whose bounding boxes are to be drawn\n      T_ws: pose of the snippet in world coordinates\n      Ts_sr: rig poses in snippet coordinates (the camera trajectory)\n      cams: camera calibrations\n      frame_id: frame index to select from Ts_wr and cams\n      tgt_removed: removed target obbs (if any), drawn in light gray\n      sem_ids_to_names: a dict mapping sem ids to names\n      prd:  predicted obbs (if any) that are to be drawn. These are optional and\n            meant to allow comparing two sets of bounding boxes.\n      width: width of figure (only needed if scene is not provided)\n      height: height of figure (only needed if scene is not provided)\n      draw_origin: if True, draw the origin of the scene\n      draw_trajectory: if True, draw the camera trajectory\n      draw_frustum: if True, draw the camera frustum\n      matcher: a function that takes tgt ObbTW, prd ObbTW and prd_logits and returns a list of matching ids to draw (HungarianMatcher)\n      prd_logits: a list of logits matching the prd ObbTWs for the matcher\n      scene: optional scene to draw into (if not provided instantiated internally)\n      z_height_clip: z clip to limit the points of the scene to below this height (remove ceilings for better viz)\n      cams_slaml: camera calibrations of slam left camera\n      cams_slamr: camera calibrations of slam right camera\n    Returns:\n      img: rendered image (PIL image) with all the drawings\n    \"\"\"\n\n    if scene is 
None:\n        scene = SceneView(width=width, height=height)\n\n    cam = cams[frame_id].cpu()\n    cam_slaml = cams_slaml[frame_id].cpu() if cams_slaml is not None else None\n    cam_slamr = cams_slamr[frame_id].cpu() if cams_slamr is not None else None\n    if p3s_world is not None:\n        p3s_world = p3s_world[frame_id].cpu() if p3s_world.ndim == 3 else p3s_world\n    if p3s_pred_world is not None:\n        p3s_pred_world = (\n            p3s_pred_world[frame_id].cpu()\n            if isinstance(p3s_pred_world, list)\n            else p3s_pred_world\n        )\n    if depth_pred is not None:\n        depth_pred = (\n            depth_pred[frame_id].cpu() if isinstance(depth_pred, list) else depth_pred\n        )\n    pose_id = frame_id\n    Ts_wr = T_ws @ Ts_sr\n    Ts_wr = Ts_wr.cpu()\n    if cams.shape[0] != Ts_wr.shape[0]:\n        pose_id = round(Ts_wr.shape[0] * (float(frame_id) / float(cams.shape[0])))\n    Ts_wc = Ts_wr @ cam.T_camera_rig.inverse()\n    T_wr = Ts_wr[pose_id]\n    T_wc = Ts_wc[pose_id]\n\n    if tgt is not None and tgt.ndim == 3:\n        tgt = tgt[frame_id]\n    tgt = tgt.cpu() if tgt is not None else None\n\n    colors = None\n    if sem_ids_to_names is not None:\n        # needs to color to be in scale [0,1]\n        colors = get_colors_from_sem_map(sem_ids_to_names, scale_to_255=False)\n\n    # setup framebuffer for rendering\n    scene.clear()\n    if not bird_eye_view:\n        scene.set_follow_view(T_wc, zoom_factor=zoom_factor)\n    else:\n        scene.set_birds_eye_view(T_wc, zoom_factor=zoom_factor)\n\n    # draw target obbs\n    if tgt is not None and tgt.shape[0] > 0:\n        render_obbs_line(\n            tgt,\n            scene.prog,\n            scene.ctx,\n            rgba=(1.0, 0.0, 0.0, 1.0),\n            colors=colors,\n        )\n    if render_removed_pred and tgt_removed is not None:\n        if tgt_removed.ndim == 3:\n            tgt_removed = tgt_removed[frame_id]\n        render_obbs_line(\n            
tgt_removed.cpu(),\n            scene.prog,\n            scene.ctx,\n            rgba=(0.75, 0.75, 0.75, 0.3),\n        )\n\n    if render_raw_pred and prd is not None:\n        if prd.ndim == 3:\n            prd = prd[frame_id]\n        # change the alpha value of the predictions when we have target obbs.\n        if tgt is not None and tgt.shape[0] > 0:\n            color_alpha = 0.3\n        else:\n            color_alpha = 1.0\n        # draw predicted obbs\n        render_obbs_line(\n            prd.cpu(),\n            scene.prog,\n            scene.ctx,\n            colors=colors,\n            color_alpha=color_alpha,\n        )\n\n    if draw_trajectory:\n        # draw rig trajectory\n        render_linestrip(\n            Ts_wr.t, rgba=(0.0, 0.0, 0.0, 1.0), prog=scene.prog, ctx=scene.ctx\n        )\n        # draw the current rig pose\n        render_cosy(T_wr, ctx=scene.ctx, prog=scene.prog, scale=0.3)\n        # draw the snippet origin\n        # In the case of frames coming from different snippets, e.g. 
T_ws is [10, 12]\n        # take the first T_ws as the snippet origin.\n        if T_ws.shape[0] > 0:\n            T_ws = T_ws[0:1]\n        render_cosy(T_ws, ctx=scene.ctx, prog=scene.prog, scale=0.3)\n\n    if p3s_world is not None:\n        if z_height_clip is not None:\n            keep = p3s_world[:, 2] < z_height_clip\n            p3s_world = p3s_world[keep]\n        # draw 3d points\n        render_points(\n            p3s_world,\n            (0.1, 0.1, 0.1, 1.0),\n            prog=scene.prog,\n            ctx=scene.ctx,\n            point_size=1.2,\n        )\n\n    if scene_mesh_v is not None:\n        verts_w = scene_mesh_T_wv * scene_mesh_v.to(scene_mesh_T_wv.device)\n        normals_w = scene_mesh_T_wv.rotate(scene_mesh_n.to(scene_mesh_T_wv.device))\n        render_tri_mesh(\n            verts_w,\n            normals_w,\n            scene_mesh_f,\n            prog=scene.prog_mesh,\n            ctx=scene.ctx,\n        )\n\n    if p3s_pred_world is not None:\n        if depth_pred is None:\n            # draw 3d points\n            render_points(\n                p3s_pred_world,\n                (0.0, 1.0, 0.0, 1.0),\n                prog=scene.prog,\n                ctx=scene.ctx,\n                point_size=2.0,\n            )\n        else:\n            # draw 3d points colored by depth\n            render_scalar_field_points(\n                p3s_pred_world,\n                depth_pred,\n                prog=scene.prog_scalar_field,\n                ctx=scene.ctx,\n                val_min=0.0,\n                val_max=3.0,\n                point_size=2.0,\n            )\n\n    if draw_frustum:\n        # draw the current frustum\n        render_frustum(\n            T_wr, cam, prog=scene.prog, ctx=scene.ctx, rgba=(0.0, 0.0, 0.0, 1.0)\n        )\n        render_line(\n            T_wr.t, T_wc.t, rgba=(0.0, 0.0, 1.0, 1.0), prog=scene.prog, ctx=scene.ctx\n        )\n\n        if draw_trajectory:\n            # Show smaller frustums along trajectory.\n    
        for twr in Ts_wr:\n                render_frustum(\n                    twr,\n                    cam,\n                    prog=scene.prog,\n                    ctx=scene.ctx,\n                    rgba=(0.0, 0.0, 0.0, 1.0),\n                    scale=0.08,\n                )\n\n        if cam_slaml is not None:\n            # draw the current frustum\n            render_frustum(\n                T_wr,\n                cam_slaml,\n                prog=scene.prog,\n                ctx=scene.ctx,\n                rgba=(0.0, 0.0, 0.0, 1.0),\n            )\n            T_wcsl = T_wr @ cam_slaml.T_camera_rig.inverse()\n            render_line(\n                T_wr.t,\n                T_wcsl.t,\n                rgba=(0.0, 0.0, 1.0, 1.0),\n                prog=scene.prog,\n                ctx=scene.ctx,\n            )\n        if cam_slamr is not None:\n            # draw the current frustum\n            render_frustum(\n                T_wr,\n                cam_slamr,\n                prog=scene.prog,\n                ctx=scene.ctx,\n                rgba=(0.0, 0.0, 0.0, 1.0),\n            )\n            T_wcsr = T_wr @ cam_slamr.T_camera_rig.inverse()\n            render_line(\n                T_wr.t,\n                T_wcsr.t,\n                rgba=(0.0, 0.0, 1.0, 1.0),\n                prog=scene.prog,\n                ctx=scene.ctx,\n            )\n\n    if draw_origin:\n        # draw the origin cosy\n        render_cosy(PoseTW(), ctx=scene.ctx, prog=scene.prog, scale=1.0)\n\n    if matcher is not None and prd is not None and prd_logits is not None:\n        # draw matches under matcher\n        tgt_sem_id = [tgt.sem_id.squeeze(-1)]\n        indices = matcher(\n            prd_logits.unsqueeze(0),\n            prd.bb3_center_world.unsqueeze(0),\n            tgt_sem_id,\n            [tgt.bb3_center_world],\n        )\n        for p, t in zip(indices[0][0], indices[0][1]):\n            pt0 = prd.bb3_center_world[p]\n            pt1 = tgt.bb3_center_world[t]\n 
           render_line(pt0, pt1, (0.0, 0.0, 0.0, 1.0), prog=scene.prog, ctx=scene.ctx)\n\n    # finish and obtain image\n    img = scene.finish()\n    return img\n\n\ndef draw_snippet_scene_3d(\n    snippet,\n    sem_ids_to_names=None,\n    width=512,\n    height=512,\n    draw_origin=False,\n    frame_id: Optional[int] = None,\n    batch_id: int = 0,\n    # optional scene object - if you want constant GPU memory allocate them once\n    # outside and pass them in to reuse.\n    scene: Optional[SceneView] = None,\n    clean_viz: bool = False,\n    viz_gt_points: bool = True,\n):\n    \"\"\"\n    Draw a 3D scene of obbs and camera trajectory.\n\n    Args:\n      snippet: a snippet dict containing all relevant information for drawing\n      sem_ids_to_names: a dict mapping sem ids to names\n      width: width of figure (only needed if scene is not provided)\n      height: height of figure (only needed if scene is not provided)\n      draw_origin: if True, draw the origin of the scene\n      frame_id: if set, only render the image for this frame.\n      batch_id: if we are passing batched inputs, select the batch with this id for rendering.\n      scene: optional scene to draw into (if not provided instantiated internally)\n      clean_viz: if True, skip the removed obbs, and skip the raw predictions when tracked obbs are available.\n      viz_gt_points: if there is ground truth depth in the batch, visualize the GT depth instead of the semi-dense points.\n    Returns:\n      imgs: list of rendered images (numpy arrays), one per rendered frame\n    \"\"\"\n\n    if scene is None:\n        scene = SceneView(width=width, height=height)\n\n    has_slaml = ARIA_CALIB[1] in snippet\n    has_slamr = ARIA_CALIB[2] in snippet\n\n    cams = snippet[ARIA_CALIB[0]].cpu()\n    if has_slaml:\n        cams_slaml = snippet[ARIA_CALIB[1]].cpu()\n    else:\n        cams_slaml = None\n    if has_slamr:\n        cams_slamr = snippet[ARIA_CALIB[2]].cpu()\n    else:\n       
 cams_slamr = None\n    T_ws = snippet[ARIA_SNIPPET_T_WORLD_SNIPPET].cpu()\n    if ARIA_IMG_T_SNIPPET_RIG[0] in snippet:\n        Ts_sr = snippet[ARIA_IMG_T_SNIPPET_RIG[0]].cpu()\n    elif ARIA_POSE_T_SNIPPET_RIG in snippet:\n        Ts_sr = snippet[ARIA_POSE_T_SNIPPET_RIG].cpu()\n        if Ts_sr.shape[0] != cams.shape[0]:\n            cam_times_ns = snippet[ARIA_CALIB_TIME_NS[0]].tolist()\n            pose_times_ns = snippet[ARIA_POSE_TIME_NS].tolist()\n            Ts_sr = sample_nearest(cam_times_ns, pose_times_ns, Ts_sr)\n    if T_ws.ndim == 3:\n        T_ws = T_ws[batch_id]\n        Ts_sr = Ts_sr[batch_id]\n        cams = cams[batch_id]\n        if has_slaml:\n            cams_slaml = cams_slaml[batch_id]\n        if has_slamr:\n            cams_slamr = cams_slamr[batch_id]\n\n    obbs, prd, uninst = None, None, None\n    if ARIA_OBB_PADDED in snippet:\n        obbs = snippet[ARIA_OBB_PADDED].cpu()\n        obbs = obbs[batch_id] if obbs.ndim == 4 else obbs\n    have_tracked = ARIA_OBB_TRACKED in snippet\n    if have_tracked:\n        obbs = snippet[ARIA_OBB_TRACKED].cpu()\n        obbs = obbs[batch_id] if obbs.ndim == 4 else obbs\n    if ARIA_OBB_PRED_VIZ in snippet:\n        prd = snippet[ARIA_OBB_PRED_VIZ].cpu()\n        prd = prd[batch_id] if prd.ndim == 4 else prd\n    if ARIA_OBB_UNINST in snippet:\n        uninst = snippet[ARIA_OBB_UNINST].cpu()\n        uninst = uninst[batch_id] if uninst.ndim == 4 else uninst\n    p3s_world = None\n\n    # If GT depth exists, visualize GT depth pointcloud instead of semi-dense points.\n    if viz_gt_points and ARIA_DISTANCE_M[0] in snippet:\n        # Note: we only visualize GT depth map of RGB images now.\n        valid_depths = snippet[ARIA_DISTANCE_M[0]].squeeze(1) > 1e-4\n        p3cs, valids = dist_im_to_point_cloud_im(\n            snippet[ARIA_DISTANCE_M[0]].squeeze(1),\n            snippet[ARIA_CALIB[0]],\n        )\n        valids = torch.logical_and(valids, valid_depths)\n        p3cs = 
p3cs.reshape(p3cs.shape[0], -1, 3)\n        T_s_c = (\n            snippet[ARIA_IMG_T_SNIPPET_RIG[0]]\n            @ snippet[ARIA_CALIB[0]].T_camera_rig.inverse()\n        )\n        T_w_c = snippet[ARIA_SNIPPET_T_WORLD_SNIPPET] @ T_s_c\n        p3ws = T_w_c * p3cs\n        p3ws = p3ws.reshape(-1, 3)\n        valids = valids.reshape(-1)\n        p3ws = p3ws[valids]\n        p3s_world = discretize_values(p3ws, precision=70)\n\n    if ARIA_POINTS_WORLD in snippet:\n        p3s_world = snippet[ARIA_POINTS_WORLD]\n        p3s_world = p3s_world[batch_id] if p3s_world.ndim == 4 else p3s_world\n\n    p3s_pred_world, depth_pred = None, None\n    if ARIA_DISTANCE_M_PRED[0] in snippet:\n        dist_m = snippet[ARIA_DISTANCE_M_PRED[0]].cpu()\n        dist_m = dist_m[batch_id] if dist_m.ndim == 4 else dist_m\n        # scale camera to fit the depth image (in case depth image is at a lower res)\n        cams_depth = cams.scale_to(dist_m)\n        Ts_wc = T_ws @ Ts_sr @ cams.T_camera_rig.inverse()\n        p3s_pred_world, depth_pred = [], []\n        for t in range(dist_m.shape[0]):\n            p3s_c, valids = dist_im_to_point_cloud_im(dist_m[t], cams_depth[t])\n            p3s_pred_world.append(Ts_wc[t] * p3s_c[valids])\n            depth_pred.append(dist_m[t][valids])\n\n    Ts_wr = T_ws @ Ts_sr\n    obbs = obbs.remove_padding() if obbs is not None else None\n    prd = prd.remove_padding() if prd is not None else None\n    uninst = uninst.remove_padding() if uninst is not None else None\n\n    # clip the point cloud 1m above the rig coordinates\n    z_height_clip = Ts_wr.t[..., 2].max() + 1.0\n\n    assert Ts_wr.shape[0] == cams.shape[0], (\n        f\"poses and cameras must have the same length but got {Ts_wr.shape[0]} and {cams.shape[0]}\"\n    )\n    if obbs is not None:\n        assert Ts_wr.shape[0] == len(obbs), (\n            f\"poses and obbs must have the same length {len(obbs)} but got {Ts_wr.shape}\"\n        )\n\n    if frame_id:\n        assert frame_id >= 0 and 
frame_id < Ts_wr.shape[0]\n    # compare against None so that frame_id=0 selects only the first frame\n    frame_ids = [frame_id] if frame_id is not None else range(Ts_wr.shape[0])\n\n    scene_mesh_v = None\n    scene_mesh_f = None\n    scene_mesh_n = None\n    scene_mesh_T_wv = None\n\n    if ARIA_MESH_VERTS_W in snippet:\n        scene_mesh_v = snippet[ARIA_MESH_VERTS_W].squeeze().cpu().detach().float()\n        scene_mesh_f = snippet[ARIA_MESH_FACES].squeeze().cpu().detach().float()\n        # flip normals to visualize better.\n        scene_mesh_n = -snippet[ARIA_MESH_VERT_NORMS_W].squeeze().cpu().detach().float()\n        scene_mesh_T_wv = PoseTW()\n\n    imgs = []\n    for t in frame_ids:\n        # transform obbs into world coordinates too\n        tgt_w = obbs[t].transform(T_ws) if obbs is not None else None\n        prd_w = prd[t].transform(T_ws) if prd is not None else None\n        uninst_w = uninst[t].transform(T_ws) if uninst is not None else None\n        img = draw_obb_scene_3d(\n            tgt=tgt_w,\n            prd=prd_w,\n            tgt_removed=uninst_w,\n            T_ws=T_ws,\n            Ts_sr=Ts_sr,\n            cams=cams,\n            cams_slaml=cams_slaml,\n            cams_slamr=cams_slamr,\n            frame_id=t,\n            p3s_world=p3s_world,\n            p3s_pred_world=p3s_pred_world,\n            depth_pred=depth_pred,\n            sem_ids_to_names=sem_ids_to_names,\n            width=width,\n            height=height,\n            draw_origin=draw_origin,\n            scene=scene,\n            z_height_clip=z_height_clip,\n            render_raw_pred=(not clean_viz) or (not have_tracked and clean_viz),\n            render_removed_pred=not clean_viz,\n            scene_mesh_v=scene_mesh_v,\n            scene_mesh_f=scene_mesh_f,\n            scene_mesh_n=scene_mesh_n,\n            scene_mesh_T_wv=scene_mesh_T_wv,\n        )\n        imgs.append(np.array(img))\n    return imgs\n\n\ndef normalize(x):\n    return x / (np.linalg.norm(x, 2) + 1e-6)\n\n\n# 
https://github.com/stevenlovegrove/Pangolin/blob/7776a567f5c7b074668b8abb2316aba3f4b8b568/components/pango_opengl/src/opengl_render_state.cpp#L621\n# e=eye is the eye location in world coordinates (camera position)\n# l=look_at is the look at direction (projects to image center)\n# u=up is the up direction\ndef model_view_look_at_rdf(e, look_at, u):\n    z = normalize(look_at - e)\n    if np.allclose(u - z, np.zeros(3), atol=1e-5):\n        # Add some tiny offset so that cross product is non-zero.\n        z[1] = z[1] + 0.001\n    x = normalize(np.cross(z, u))\n    y = normalize(np.cross(z, x))\n\n    M = np.zeros((4, 4))\n    M[0, 0] = x[0]\n    M[0, 1] = x[1]\n    M[0, 2] = x[2]\n    M[1, 0] = y[0]\n    M[1, 1] = y[1]\n    M[1, 2] = y[2]\n    M[2, 0] = z[0]\n    M[2, 1] = z[1]\n    M[2, 2] = z[2]\n    M[3, 0] = 0.0\n    M[3, 1] = 0.0\n    M[3, 2] = 0.0\n    M[0, 3] = -(M[0, 0] * e[0] + M[0, 1] * e[1] + M[0, 2] * e[2])\n    M[1, 3] = -(M[1, 0] * e[0] + M[1, 1] * e[1] + M[1, 2] * e[2])\n    M[2, 3] = -(M[2, 0] * e[0] + M[2, 1] * e[1] + M[2, 2] * e[2])\n    M[3, 3] = 1.0\n    return M\n\n\ndef get_mv(T_world_cam: PoseTW, zoom_factor: float = 3.0, position=[-1, 0, -2]):\n    \"\"\"\n    T_world_cam is the camera pose in world coordinates that the ModelView Matrix will \"follow\".\n    zoom_factor is the zoom factor for the ModelView Matrix. I.e. from how far\n    above and behind the camera pose we will render the scene. 
1.0 is very\n    close, 3.0 is medium (good default) and 6.0 is farther away.\n    \"\"\"\n    # gravity align the camera pose to make rendering videos smoother.\n    T_world_cam = gravity_align_T_world_cam(\n        T_world_cam.clone().unsqueeze(0), gravity_w=GRAVITY_DIRECTION_VIO\n    ).squeeze(0)\n    T_world_cam = T_world_cam.detach().cpu()\n    # center is the \"look at\" position; it projects to the center of the rendering\n    center = T_world_cam.t\n    # eye is the position of the camera center (translation)\n    eye = T_world_cam * (torch.FloatTensor(position) * zoom_factor)\n    # eye = T_world_cam * (torch.FloatTensor([-1,0,-1]) * zoom_factor)\n    eye = eye.squeeze(0)\n    # up is the up direction for the rendering camera. We choose it to be the\n    # negative x-axis of the camera pose, which works for our 90-deg rotated\n    # cameras on Aria.\n    up = T_world_cam.R[:, 0]\n    # model view matrix\n    mv = model_view_look_at_rdf(eye.numpy(), center.numpy(), up.numpy())\n    return mv\n\n\n# https://github.com/stevenlovegrove/Pangolin/blob/7776a567f5c7b074668b8abb2316aba3f4b8b568/components/pango_opengl/src/opengl_render_state.cpp#L462\n# Camera Axis:\n#   X - Right, Y - Down, Z - Forward\n# Image Origin:\n#   Top Left\n# Principal point specified with image origin (0,0) at top left of top-left pixel (not center)\ndef projection_matrix_rdf_top_left(w, h, fu, fv, u0, v0, zNear, zFar):\n    # http://www.songho.ca/opengl/gl_projectionmatrix.html\n    L = -(u0) * zNear / fu\n    R = +(w - u0) * zNear / fu\n    T = -(v0) * zNear / fv\n    B = +(h - v0) * zNear / fv\n\n    P = np.zeros((4, 4))\n    P[0, 0] = 2 * zNear / (R - L)\n    P[1, 1] = 2 * zNear / (T - B)\n    P[0, 2] = (R + L) / (L - R)\n    P[1, 2] = (T + B) / (B - T)\n    P[2, 2] = (zFar + zNear) / (zFar - zNear)\n    P[3, 2] = 1.0\n    P[2, 3] = (2 * zFar * zNear) / (zNear - zFar)\n    return P\n\n\ndef init_egl_context():\n    try:\n        if platform.system() == \"Darwin\":\n            ctx = 
moderngl.create_context(standalone=True)\n        else:\n            ctx = moderngl.create_context(standalone=True, backend=\"egl\")\n    except Exception as e:\n        print(f\"{e}\")\n        return None\n    return ctx\n\n\ndef simple_shader_program(ctx):\n    vertex_shader_source = \"\"\"#version 330\n    uniform mat4 mvp;\n    uniform float point_size;\n    in vec3 in_vert;\n\n    void main() {\n        gl_Position = mvp * vec4(in_vert, 1.0);\n        gl_PointSize = point_size;\n    }\"\"\"\n    fragment_shader_source = \"\"\"#version 330\n    uniform vec4 global_color;\n    out vec4 f_color;\n\n    void main() {\n        f_color = global_color;\n    }\n    \"\"\"\n    prog = ctx.program(\n        vertex_shader=vertex_shader_source, fragment_shader=fragment_shader_source\n    )\n    return prog\n\n\ndef mesh_normal_shader_program(ctx):\n    vertex_shader_source = \"\"\"#version 330\n    uniform mat4 mvp;\n    uniform mat4 mv;\n    uniform float point_size;\n    in vec3 in_vert;\n    in vec3 in_normal;\n\n    out vec3 n_c;\n\n    void main() {\n        gl_Position = mvp * vec4(in_vert, 1.0);\n        gl_PointSize = point_size;\n        n_c = transpose(inverse(mat3(mv))) * in_normal;\n    }\"\"\"\n    fragment_shader_source = \"\"\"#version 330\n    in vec3 n_c;\n    out vec4 f_color;\n\n    void main() {\n        f_color = vec4((normalize(n_c) + vec3(1.0, 1.0, 1.0)) / 2.0, 1.0f);\n    }\n    \"\"\"\n    prog = ctx.program(\n        vertex_shader=vertex_shader_source, fragment_shader=fragment_shader_source\n    )\n    return prog\n\n\ndef mesh_rgb_shader_program(ctx):\n    vertex_shader_source = \"\"\"#version 330\n    uniform mat4 mvp;\n    uniform mat4 mv;\n    uniform float point_size;\n    in vec3 in_vert;\n    in vec3 in_normal;\n    in vec3 in_rgb;\n\n    out vec3 n_c;\n    out vec3 rgb;\n\n    void main() {\n        gl_Position = mvp * vec4(in_vert, 1.0);\n        gl_PointSize = point_size;\n        n_c = transpose(inverse(mat3(mv))) * in_normal;\n       
 rgb = in_rgb;\n    }\"\"\"\n    fragment_shader_source = \"\"\"#version 330\n    in vec3 n_c;\n    in vec3 rgb;\n    out vec4 f_color;\n\n    void main() {\n        f_color = vec4(rgb * max(n_c.z, 0.0), 1.0);\n    }\n    \"\"\"\n    prog = ctx.program(\n        vertex_shader=vertex_shader_source, fragment_shader=fragment_shader_source\n    )\n    return prog\n\n\ndef rgb_point_cloud_shader_program(ctx):\n    vertex_shader_source = \"\"\"#version 330\n    uniform mat4 mvp;\n    uniform mat4 mv;\n    uniform float point_size;\n    in vec3 in_vert;\n    in vec3 in_rgb;\n\n    out vec3 rgb;\n\n    void main() {\n        gl_Position = mvp * vec4(in_vert, 1.0);\n        gl_PointSize = point_size;\n        rgb = in_rgb;\n    }\"\"\"\n    fragment_shader_source = \"\"\"#version 330\n    in vec3 rgb;\n    out vec4 f_color;\n\n    void main() {\n        f_color = vec4(rgb, 1.0f);\n    }\n    \"\"\"\n    prog = ctx.program(\n        vertex_shader=vertex_shader_source, fragment_shader=fragment_shader_source\n    )\n    return prog\n\n\ndef scalar_field_shader_program(ctx):\n    vertex_shader_source = \"\"\"#version 330\n    uniform mat4 mvp;\n    uniform float point_size;\n    uniform float max_value;\n    uniform float min_value;\n    in vec3 in_vert;\n    in float in_value;\n    in float in_alpha;\n    out vec3 frag_rgb;\n    out float frag_a;\n\n    // https://thebookofshaders.com/06/\n    vec3 hsb2rgb( in vec3 c ){\n        vec3 rgb = clamp(abs(mod(c.x*6.0+vec3(0.0,4.0,2.0),6.0)-3.0)-1.0, 0.0, 1.0);\n        rgb = rgb*rgb*(3.0-2.0*rgb);\n        return c.z * mix( vec3(1.0), rgb, c.y);\n    }\n    vec3 hsv(float v) {\n        return hsb2rgb(vec3(v, 1.0, 1.0));\n    }\n\n    // https://github.com/kbinani/colormap-shaders/tree/master\n    // The MIT License (MIT)\n    // Copyright (c) 2015 kbinani\n    // Permission is hereby granted, free of charge, to any person obtaining a copy\n    // of this software and associated documentation files (the \"Software\"), to deal\n    // 
in the Software without restriction, including without limitation the rights\n    // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n    // copies of the Software, and to permit persons to whom the Software is\n    // furnished to do so, subject to the following conditions:\n    // The above copyright notice and this permission notice shall be included in all\n    // copies or substantial portions of the Software.\n    // THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n    // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n    // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n    // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n    // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n    // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n    // SOFTWARE.\n\n    float jet_red(float x) {\n        if (x < 0.7) {\n            return 4.0 * x - 1.5;\n        } else {\n            return -4.0 * x + 4.5;\n        }\n    }\n    float jet_green(float x) {\n        if (x < 0.5) {\n            return 4.0 * x - 0.5;\n        } else {\n            return -4.0 * x + 3.5;\n        }\n    }\n    float jet_blue(float x) {\n        if (x < 0.3) {\n            return 4.0 * x + 0.5;\n        } else {\n            return -4.0 * x + 2.5;\n        }\n    }\n    vec3 jet(float x) {\n        float r = clamp(jet_red(x), 0.0, 1.0);\n        float g = clamp(jet_green(x), 0.0, 1.0);\n        float b = clamp(jet_blue(x), 0.0, 1.0);\n        return vec3(r, g, b);\n    }\n\n    void main() {\n        float f_value = (in_value - min_value) / (max_value - min_value);\n        f_value = clamp(f_value, 0.0, 1.0);\n        frag_rgb = jet(f_value);\n        frag_a = in_alpha;\n        gl_Position = mvp * vec4(in_vert, 1.0);\n        gl_PointSize = point_size;\n    }\n    \"\"\"\n    fragment_shader_source = 
\"\"\"#version 330\n    in vec3 frag_rgb;\n    in float frag_a;\n    out vec4 f_color;\n\n    void main() {\n        f_color = vec4(frag_rgb, frag_a);\n    }\n    \"\"\"\n    prog = ctx.program(\n        vertex_shader=vertex_shader_source, fragment_shader=fragment_shader_source\n    )\n    return prog\n\n\ndef semantic_color_shader_program(ctx):\n    vertex_shader_source = \"\"\"#version 330\n    uniform mat4 mvp;\n\n    in int in_sem_id;\n    in vec3 in_vert;\n    // integer varyings cannot be interpolated and must be declared flat\n    flat out int sem_id;\n    out vec3 v_vert;\n\n    void main() {\n        v_vert = in_vert;\n        gl_Position = mvp * vec4(v_vert, 1.0);\n        sem_id = in_sem_id;\n    }\"\"\"\n\n    fragment_shader_source = \"\"\"#version 330\n    uniform int sem_max;\n\n    flat in int sem_id;\n    in vec3 v_vert;\n    out vec4 f_color;\n\n    // https://thebookofshaders.com/06/\n    vec3 hsb2rgb( in vec3 c ){\n        vec3 rgb = clamp(abs(mod(c.x*6.0+vec3(0.0,4.0,2.0),6.0)-3.0)-1.0, 0.0, 1.0);\n        rgb = rgb*rgb*(3.0-2.0*rgb);\n        return c.z * mix( vec3(1.0), rgb, c.y);\n    }\n\n    void main() {\n        // convert to float before dividing to avoid integer division\n        float sem_hue = float(sem_id) / float(sem_max);\n        f_color = vec4(hsb2rgb(vec3(sem_hue, 1.0, 1.0)), 1.0);\n    }\n    \"\"\"\n    prog = ctx.program(\n        vertex_shader=vertex_shader_source, fragment_shader=fragment_shader_source\n    )\n    return prog\n"
  },
  {
    "path": "efm3d/utils/voxel.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\n\n\ndef tensor_wrap_voxel_extent(voxel_extent, B=None, device=\"cpu\"):\n    if isinstance(voxel_extent, torch.Tensor):\n        if B is not None:\n            assert voxel_extent.shape[0] == B\n        return voxel_extent\n    elif isinstance(voxel_extent, list):\n        if B is None:\n            return torch.tensor(voxel_extent, device=device)\n        else:\n            return torch.tensor(voxel_extent, device=device).view(1, 6).repeat(B, 1)\n    else:\n        raise NotImplementedError(f\"type {type(voxel_extent)} not supported\")\n\n\ndef create_voxel_grid(vW, vH, vD, voxel_extent, device=\"cpu\"):\n    \"\"\"\n    Given a bounding box range [x_min, x_max, y_min, y_max, z_min, z_max], and the\n    number of voxels in each dimension [vW, vH, vD], return the voxel center positions.\n    Note that the min and max voxel centers are not [x_min, y_min, z_min] and [x_max, y_max, z_max],\n    since the extent gives the bounding range, not the center positions.\n\n    vW: the number of voxels for x-dim\n    vH: the number of voxels for y-dim\n    vD: the number of voxels for z-dim\n    voxel_extent: the bounding box range in [x_min, x_max, y_min, y_max, z_min, z_max]\n    \"\"\"\n    x_min, x_max, y_min, y_max, z_min, z_max = voxel_extent\n    dW = (x_max - x_min) / vW\n    dH = (y_max - y_min) / vH\n    dD = (z_max - z_min) / vD\n    # take the 
center position of each voxel\n    rng_x = torch.linspace(x_min + dW / 2, x_max - dW / 2, steps=vW, device=device)\n    rng_y = torch.linspace(y_min + dH / 2, y_max - dH / 2, steps=vH, device=device)\n    rng_z = torch.linspace(z_min + dD / 2, z_max - dD / 2, steps=vD, device=device)\n    xx, yy, zz = torch.meshgrid(rng_x, rng_y, rng_z, indexing=\"ij\")\n    vox_v = torch.stack([xx, yy, zz], dim=-1)\n    return vox_v\n\n\ndef erode_voxel_mask(mask):\n    \"\"\"\n    Erode a given mask by one voxel i.e.\n    0 0 0 0 0    0 0 0 0 0\n    0 1 1 1 0    0 0 0 0 0\n    0 1 1 1 0 -> 0 0 1 0 0\n    0 1 1 1 0    0 0 0 0 0\n    0 0 0 0 0    0 0 0 0 0\n    \"\"\"\n    # B T D H W\n    assert mask.ndim in [4, 5], f\"mask dim needs to be 4 or 5 got {mask.shape}\"\n    kernel = torch.ones((1, 1, 3, 3, 3), device=mask.device)\n    mask = (\n        1.0\n        - torch.clamp(\n            torch.nn.functional.conv3d(1.0 - mask.float(), kernel, padding=\"same\"), 0, 1\n        )\n    ).bool()\n    return mask\n"
  },
  {
    "path": "efm3d/utils/voxel_sampling.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport numpy as np\nimport torch\n\n\ndef pc_to_vox(pc_v, vW, vH, vD, voxel_extent):\n    device = pc_v.device\n    if isinstance(voxel_extent, list):\n        x_min, x_max, y_min, y_max, z_min, z_max = voxel_extent\n        valid = pc_v[..., 0] > x_min\n        valid = torch.logical_and(pc_v[..., 0] < x_max, valid)\n        valid = torch.logical_and(pc_v[..., 1] > y_min, valid)\n        valid = torch.logical_and(pc_v[..., 1] < y_max, valid)\n        valid = torch.logical_and(pc_v[..., 2] > z_min, valid)\n        valid = torch.logical_and(pc_v[..., 2] < z_max, valid)\n        dW = (x_max - x_min) / vW\n        dH = (y_max - y_min) / vH\n        dD = (z_max - z_min) / vD\n        dVox = torch.tensor([dW, dH, dD]).view(1, 3).to(device)\n        vox_min = torch.tensor([x_min, y_min, z_min]).view(1, 3).to(device)\n        pc_id = (pc_v - vox_min) / dVox\n    else:\n        s = pc_v.shape[:-1]\n        B = s[0]\n        vox_min = voxel_extent[..., 0::2].view(B, 1, 3)\n        vox_max = voxel_extent[..., 1::2].view(B, 1, 3)\n        dim = (\n            torch.tensor([vW, vH, vD], device=voxel_extent.device)\n            .view(1, 1, 3)\n            .repeat(B, 1, 1)\n        )\n        dVox = (vox_max - vox_min) / dim\n        pc_v = pc_v.view(B, -1, 3)\n        valid = torch.logical_not(pc_v.isnan().any(-1))\n        valid = torch.logical_and(valid, 
(pc_v > vox_min).all(-1))\n        valid = torch.logical_and(valid, (pc_v < vox_max).all(-1))\n        pc_id = (pc_v - vox_min) / dVox\n        valid = valid.view(s)\n        pc_id = pc_id.view(list(s) + [3])\n\n    return pc_id, valid\n\n\ndef compute_factor(size):\n    return 1.0 * size / 2\n\n\ndef convert_coordinates_to_voxel(coordinates, factor):\n    return factor * (coordinates + 1.0)\n\n\ndef convert_voxel_to_coordinates(coordinates, factor):\n    return (coordinates / factor) - 1.0\n\n\ndef normalize_keypoints(kpts, depth, height, width):\n    # compute conversion factor\n    x_factor = compute_factor(width)\n    y_factor = compute_factor(height)\n    z_factor = compute_factor(depth)\n    factors = torch.tensor([x_factor, y_factor, z_factor], device=kpts.device).view(\n        [1] * (kpts.ndim - 1) + [3]\n    )\n    pts_dst = convert_voxel_to_coordinates(kpts, factors)\n    return pts_dst\n\n\ndef denormalize_keypoints(kpts, depth, height, width):\n    # compute conversion factor\n    x_factor = compute_factor(width)\n    y_factor = compute_factor(height)\n    z_factor = compute_factor(depth)\n    if isinstance(kpts, torch.Tensor):\n        pts_dst = kpts.clone()\n    elif isinstance(kpts, np.ndarray):\n        pts_dst = kpts.copy()\n    else:\n        raise TypeError(\"must be torch or numpy\")\n    factors = torch.tensor([x_factor, y_factor, z_factor], device=kpts.device).view(\n        [1] * (kpts.ndim - 1) + [3]\n    )\n    pts_dst = convert_coordinates_to_voxel(kpts, factors)\n    return pts_dst\n\n\ndef in_grid(pt_vox, depth, height, width):\n    valid = pt_vox[..., 0] >= 0.5\n    valid = torch.logical_and(pt_vox[..., 0] <= width - 0.5, valid)\n    valid = torch.logical_and(pt_vox[..., 1] >= 0.5, valid)\n    valid = torch.logical_and(pt_vox[..., 1] <= height - 0.5, valid)\n    valid = torch.logical_and(pt_vox[..., 2] >= 0.5, valid)\n    valid = torch.logical_and(pt_vox[..., 2] <= depth - 0.5, valid)\n    return valid\n\n\ndef sample_voxels(feat3d, 
pts_v, differentiable=False, interp_mode=\"bilinear\"):\n    \"\"\"\n    Sample voxel grid of features at pts_v locations.\n    Args:\n        feat3d: feature volume batches B C D H W\n        pts_v: 3d points in voxel coordinates (x in [0, W], y in [0, H], z in [0, D])\n            in shape B N 3; they are normalized to the -1 to 1 range internally\n        differentiable: set to True if sampling needs to be differentiable wrt pts_v\n        interp_mode: interpolation mode passed to grid_sample in the non-differentiable path\n    Returns:\n        voxel grid samples in shape B C N and a B N mask of points inside the grid\n    \"\"\"\n    assert feat3d.ndim == 5, f\"{feat3d.shape}\"\n    assert pts_v.ndim == 3, f\"{pts_v.shape}\"\n    B, C, D, H, W = feat3d.shape\n    valid = in_grid(pts_v, height=H, width=W, depth=D)\n    # Sample into the 3D feature maps.\n    norm_samp_pts = normalize_keypoints(pts_v.clone(), height=H, width=W, depth=D)\n    if differentiable:\n        # use differentiable implementation of 3d trilinear interpolation.\n        samp_feats = diff_grid_sample(\n            feat3d,\n            norm_samp_pts.view(B, 1, 1, -1, 3),  # B 1 1 N 3\n            align_corners=False,\n        )\n    else:\n        # if we don't need differentiability wrt sample points then we can use\n        # the default implementation.\n        samp_feats = torch.nn.functional.grid_sample(\n            feat3d,\n            norm_samp_pts.view(B, 1, 1, -1, 3),  # B 1 1 N 3\n            align_corners=False,\n            padding_mode=\"border\",\n            mode=interp_mode,  # \"bilinear\" means trilinear for 5D inputs\n        )\n    # squeeze back down the dimension of 1 we unsqueezed for norm_samp_pts to comply with interface\n    samp_feats = samp_feats.view(B, C, -1)\n    return samp_feats, valid\n\n\ndef diff_grid_sample(feature_3d, pts_norm, align_corners=False):\n    N, C, iD, iH, iW = feature_3d.shape\n    _, D, H, W, _ = pts_norm.shape\n    assert not pts_norm.isnan().any(), \"have nan values in pts_norm! 
not supported\"\n\n    ix = pts_norm[..., 0]\n    iy = pts_norm[..., 1]\n    iz = pts_norm[..., 2]\n\n    if align_corners:\n        ix = ((ix + 1.0) * 0.5) * (iW - 1)\n        iy = ((iy + 1.0) * 0.5) * (iH - 1)\n        iz = ((iz + 1.0) * 0.5) * (iD - 1)\n    else:\n        ix = ((ix + 1.0) * 0.5) * iW - 0.5\n        iy = ((iy + 1.0) * 0.5) * iH - 0.5\n        iz = ((iz + 1.0) * 0.5) * iD - 0.5\n\n    with torch.no_grad():\n        ix_tnw = torch.floor(ix)\n        iy_tnw = torch.floor(iy)\n        iz_tnw = torch.floor(iz)\n\n        ix_tne = ix_tnw + 1\n        iy_tne = iy_tnw\n        iz_tne = iz_tnw\n\n        ix_tsw = ix_tnw\n        iy_tsw = iy_tnw + 1\n        iz_tsw = iz_tnw\n\n        ix_tse = ix_tnw + 1\n        iy_tse = iy_tnw + 1\n        iz_tse = iz_tnw\n\n        ix_bnw = ix_tnw\n        iy_bnw = iy_tnw\n        iz_bnw = iz_tnw + 1\n\n        ix_bne = ix_tnw + 1\n        iy_bne = iy_tnw\n        iz_bne = iz_tnw + 1\n\n        ix_bsw = ix_tnw\n        iy_bsw = iy_tnw + 1\n        iz_bsw = iz_tnw + 1\n\n        ix_bse = ix_tnw + 1\n        iy_bse = iy_tnw + 1\n        iz_bse = iz_tnw + 1\n\n    bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse)\n    bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw)\n    bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne)\n    bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw)\n\n    tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz)\n    tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz)\n    tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz)\n    tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz)\n\n    with torch.no_grad():\n        torch.clamp(ix_bnw, 0, iW - 1, out=ix_bnw)\n        torch.clamp(iy_bnw, 0, iH - 1, out=iy_bnw)\n        torch.clamp(iz_bnw, 0, iD - 1, out=iz_bnw)\n\n        torch.clamp(ix_bne, 0, iW - 1, out=ix_bne)\n        torch.clamp(iy_bne, 0, iH - 1, out=iy_bne)\n        torch.clamp(iz_bne, 0, iD - 1, out=iz_bne)\n\n        torch.clamp(ix_bsw, 0, iW - 1, out=ix_bsw)\n        
torch.clamp(iy_bsw, 0, iH - 1, out=iy_bsw)\n        torch.clamp(iz_bsw, 0, iD - 1, out=iz_bsw)\n\n        torch.clamp(ix_bse, 0, iW - 1, out=ix_bse)\n        torch.clamp(iy_bse, 0, iH - 1, out=iy_bse)\n        torch.clamp(iz_bse, 0, iD - 1, out=iz_bse)\n\n        torch.clamp(ix_tnw, 0, iW - 1, out=ix_tnw)\n        torch.clamp(iy_tnw, 0, iH - 1, out=iy_tnw)\n        torch.clamp(iz_tnw, 0, iD - 1, out=iz_tnw)\n\n        torch.clamp(ix_tne, 0, iW - 1, out=ix_tne)\n        torch.clamp(iy_tne, 0, iH - 1, out=iy_tne)\n        torch.clamp(iz_tne, 0, iD - 1, out=iz_tne)\n\n        torch.clamp(ix_tsw, 0, iW - 1, out=ix_tsw)\n        torch.clamp(iy_tsw, 0, iH - 1, out=iy_tsw)\n        torch.clamp(iz_tsw, 0, iD - 1, out=iz_tsw)\n\n        torch.clamp(ix_tse, 0, iW - 1, out=ix_tse)\n        torch.clamp(iy_tse, 0, iH - 1, out=iy_tse)\n        torch.clamp(iz_tse, 0, iD - 1, out=iz_tse)\n\n    feature_3d = feature_3d.reshape(N, C, -1)\n\n    # D H W, z y x\n    bnw_val = torch.gather(\n        feature_3d,\n        2,\n        (iz_bnw * iH * iW + iy_bnw * iW + ix_bnw)\n        .long()\n        .view(N, 1, D * H * W)\n        .repeat(1, C, 1),\n    )\n    bne_val = torch.gather(\n        feature_3d,\n        2,\n        (iz_bne * iH * iW + iy_bne * iW + ix_bne)\n        .long()\n        .view(N, 1, D * H * W)\n        .repeat(1, C, 1),\n    )\n    bsw_val = torch.gather(\n        feature_3d,\n        2,\n        (iz_bsw * iH * iW + iy_bsw * iW + ix_bsw)\n        .long()\n        .view(N, 1, D * H * W)\n        .repeat(1, C, 1),\n    )\n    bse_val = torch.gather(\n        feature_3d,\n        2,\n        (iz_bse * iH * iW + iy_bse * iW + ix_bse)\n        .long()\n        .view(N, 1, D * H * W)\n        .repeat(1, C, 1),\n    )\n\n    tnw_val = torch.gather(\n        feature_3d,\n        2,\n        (iz_tnw * iH * iW + iy_tnw * iW + ix_tnw)\n        .long()\n        .view(N, 1, D * H * W)\n        .repeat(1, C, 1),\n    )\n    tne_val = torch.gather(\n        feature_3d,\n        
2,\n        (iz_tne * iH * iW + iy_tne * iW + ix_tne)\n        .long()\n        .view(N, 1, D * H * W)\n        .repeat(1, C, 1),\n    )\n    tsw_val = torch.gather(\n        feature_3d,\n        2,\n        (iz_tsw * iH * iW + iy_tsw * iW + ix_tsw)\n        .long()\n        .view(N, 1, D * H * W)\n        .repeat(1, C, 1),\n    )\n    tse_val = torch.gather(\n        feature_3d,\n        2,\n        (iz_tse * iH * iW + iy_tse * iW + ix_tse)\n        .long()\n        .view(N, 1, D * H * W)\n        .repeat(1, C, 1),\n    )\n\n    out_val = (\n        bnw_val.view(N, C, D, H, W) * bnw.view(N, 1, D, H, W)\n        + bne_val.view(N, C, D, H, W) * bne.view(N, 1, D, H, W)\n        + bsw_val.view(N, C, D, H, W) * bsw.view(N, 1, D, H, W)\n        + bse_val.view(N, C, D, H, W) * bse.view(N, 1, D, H, W)\n        + tnw_val.view(N, C, D, H, W) * tnw.view(N, 1, D, H, W)\n        + tne_val.view(N, C, D, H, W) * tne.view(N, 1, D, H, W)\n        + tsw_val.view(N, C, D, H, W) * tsw.view(N, 1, D, H, W)\n        + tse_val.view(N, C, D, H, W) * tse.view(N, 1, D, H, W)\n    )\n\n    return out_val\n"
  },
  {
    "path": "environment-mac.yml",
    "content": "name: efm3d\nchannels:\n  - defaults\n  - conda-forge\n  - pytorch\ndependencies:\n  - python=3.9\n  - pytorch=2.3.0\n  - torchvision=0.18.0\n  - pip\n  - pip:\n    - omegaconf==2.3.0\n    - hydra-core==1.3.2\n    - webdataset==0.2.86\n    - vrs==1.2.1\n    - fsspec==2024.6.0\n    - einops==0.8.0\n    - pandas==2.2.2\n    - pyquaternion==0.9.9\n    - opencv-python==4.10.0.84\n    - tqdm==4.66.4\n    - matplotlib==3.9.0\n    - numpy==1.26.4\n    - moderngl==5.8.2\n    - trimesh==4.4.9\n    - scikit-image==0.24.0\n    - projectaria_tools\n"
  },
  {
    "path": "environment.yml",
    "content": "name: efm3d\nchannels:\n  - nvidia/label/cuda-12.1.1\n  - pytorch\n  - nvidia\n  - conda-forge\n  - defaults\ndependencies:\n  - ninja\n  - python=3.9\n  - pip\n  - cuda\n  - anaconda::cudnn\n  - gcc=12.1\n  - gxx=12.1\n  - numpy=1.26.4\n  - pytorch-cuda=12.1\n  - pytorch=2.3.0\n  - torchvision=0.18.0\n  - torchaudio=2.3.0\n  - pip:\n    - omegaconf==2.3.0\n    - hydra-core==1.3.2\n    - webdataset==0.2.86\n    - vrs==1.2.1\n    - fsspec==2024.6.0\n    - einops==0.8.0\n    - pandas==2.2.2\n    - pyquaternion==0.9.9\n    - opencv-python==4.10.0.84\n    - tqdm==4.66.4\n    - matplotlib==3.9.0\n    - moderngl==5.8.2\n    - trimesh==4.4.9\n    - scikit-image==0.24.0\n    - projectaria_tools==1.5.5\n    - projectaria-atek==1.0.0\n    - tensorboard==2.14.0\n    - torchmetrics==0.10.1\n    - git+https://github.com/facebookresearch/pytorch3d.git@V0.7.8\n"
  },
  {
    "path": "eval.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport json\nimport os\n\nfrom efm3d.inference.eval import obb_eval_dataset\nfrom efm3d.inference.pipeline import compute_avg_metrics, run_one\n\n\nASE_DATA_PATH = \"./data/ase_eval\"\nADT_DATA_PATH = \"./data/adt\"\nAEO_DATA_PATH = \"./data/aeo\"\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Run EFM3D evaluation benchmark\")\n    parser.add_argument(\n        \"--num_seqs\",\n        type=int,\n        default=9999,\n        help=\"number of sequences to evaluate, by default evaluate all sequences\",\n    )\n    parser.add_argument(\n        \"--num_snips\",\n        type=int,\n        default=9999,\n        help=\"number of snippets per sequence, by default evaluate the full sequence\",\n    )\n    parser.add_argument(\n        \"--snip_stride\",\n        type=float,\n        default=0.1,\n        help=\"overlap between snippets in second, default to 0.1 (recommend to set it between 0.1-0.5), set it larger will make performance worse but run faster\",\n    )\n    parser.add_argument(\n        \"--voxel_res\",\n        type=float,\n        default=0.04,\n        help=\"voxel resolution in meter for volumetric fusion\",\n    )\n    parser.add_argument(\n        \"--model_ckpt\",\n        type=str,\n        default=\"./ckpt/model_release.pth\",\n        help=\"model checkpoint path\",\n   
 )\n    parser.add_argument(\n        \"--model_cfg\",\n        type=str,\n        default=\"./efm3d/config/evl_inf.yaml\",\n        help=\"model config file\",\n    )\n    parser.add_argument(\"--output_dir\", type=str, default=\"./output\", help=\"output dir\")\n    parser.add_argument(\n        \"--ase\", action=\"store_true\", help=\"Evaluate the model on ASE dataset\"\n    )\n    parser.add_argument(\n        \"--adt\", action=\"store_true\", help=\"Evaluate the model on ADT dataset\"\n    )\n    parser.add_argument(\n        \"--aeo\", action=\"store_true\", help=\"Evaluate the model on AEO dataset\"\n    )\n    args = parser.parse_args()\n\n    input_paths = []\n    if args.ase:\n        with open(\"./data/ase_splits.json\", \"r\") as f:\n            seq_ids = json.load(f)[\"test_sequences\"]\n            seq_ids = [seq.strip() for seq in seq_ids]\n        input_paths = [\n            os.path.join(ASE_DATA_PATH, seq.strip()) for seq in seq_ids[: args.num_seqs]\n        ]\n    elif args.adt:\n        with open(\"./data/adt_sequences.txt\", \"r\") as f:\n            seq_ids = f.readlines()\n            seq_ids = [seq.strip() for seq in seq_ids]\n        input_paths = [\n            os.path.join(ADT_DATA_PATH, seq.strip(), \"video.vrs\")\n            for seq in seq_ids[: args.num_seqs]\n        ]\n    elif args.aeo:\n        with open(\"./data/aeo_sequences.txt\", \"r\") as f:\n            seq_ids = f.readlines()\n            seq_ids = [seq.strip() for seq in seq_ids]\n        input_paths = [\n            os.path.join(AEO_DATA_PATH, seq.strip(), \"main.vrs\")\n            for seq in seq_ids[: args.num_seqs]\n        ]\n    else:\n        assert args.ase or args.adt or args.aeo, (\n            \"Specify eval dataset, for example, --ase\"\n        )\n\n    for input_path in input_paths:\n        run_one(\n            input_path,\n            args.model_ckpt,\n            model_cfg=args.model_cfg,\n            max_snip=args.num_snips,\n            
snip_stride=args.snip_stride,\n            voxel_res=args.voxel_res,\n            output_dir=args.output_dir,\n        )\n\n    # aggregate results\n    if len(seq_ids) > 1:\n        dirs = []\n        model_name = os.path.splitext(os.path.basename(args.model_ckpt))[0]\n        output_dir = os.path.join(args.output_dir, model_name)\n        for seq_id in seq_ids:\n            seq_output_dir = os.path.join(output_dir, seq_id)\n            dirs.append(seq_output_dir)\n\n        metrics_paths = [os.path.join(folder, \"metrics.json\") for folder in dirs]\n        metrics_paths = [p for p in metrics_paths if os.path.exists(p)]\n        if len(metrics_paths) > 0:\n            avg_ret = compute_avg_metrics(metrics_paths)\n            print(\"==> mean results\")\n            print(json.dumps(avg_ret, indent=2, sort_keys=True))\n            with open(os.path.join(output_dir, \"mean_metrics.json\"), \"w\") as f:\n                json.dump(avg_ret, f, indent=2, sort_keys=True)\n\n            # aggregate mAP for 3D object detection\n            if args.ase or args.aeo:\n                joint_map = obb_eval_dataset(output_dir)\n                print(\"==> joint mAP\")\n                print(json.dumps(joint_map, indent=2, sort_keys=True))\n                with open(\n                    os.path.join(args.output_dir, \"joint_metrics.json\"), \"w\"\n                ) as f:\n                    json.dump(joint_map, f, indent=2, sort_keys=True)\n"
  },
  {
    "path": "infer.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport argparse\n\nfrom efm3d.inference.pipeline import run_one\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\n        description=\"Run EVL model inference on Aria sequences\"\n    )\n    parser.add_argument(\"--input\", type=str, required=True, help=\"input data\")\n    parser.add_argument(\n        \"--model_ckpt\",\n        type=str,\n        default=\"./ckpt/model_release.pth\",\n        help=\"model checkpoint path\",\n    )\n    parser.add_argument(\n        \"--model_cfg\",\n        type=str,\n        default=\"./efm3d/config/evl_inf.yaml\",\n        help=\"model config file\",\n    )\n    parser.add_argument(\"--output_dir\", type=str, default=\"./output\", help=\"output dir\")\n    parser.add_argument(\n        \"--num_seqs\",\n        type=int,\n        default=9999,\n        help=\"number of sequences to evaluate, by default evaluate all sequences\",\n    )\n    parser.add_argument(\n        \"--num_snips\",\n        type=int,\n        default=9999,\n        help=\"number of snippets per sequence, by default evaluate the full sequence\",\n    )\n    parser.add_argument(\n        \"--snip_stride\",\n        type=float,\n        default=0.1,\n        help=\"overlap between snippets in second, default to 0.1 (recommend to set it between 0.1-0.5), set it larger will make performance worse but run 
faster\",\n    )\n    parser.add_argument(\n        \"--voxel_res\",\n        type=float,\n        default=0.04,\n        help=\"voxel resolution in meter for volumetric fusion\",\n    )\n    parser.add_argument(\n        \"--obb_only\",\n        action=\"store_true\",\n        help=\"only run OBB prediction, skip occupancy prediction and volume fusion for faster inference on long sequences\",\n    )\n    parser.add_argument(\n        \"--skip_video\",\n        action=\"store_true\",\n        help=\"skip video generation\",\n    )\n    parser.add_argument(\n        \"--skip_snips\",\n        type=int,\n        default=0,\n        help=\"skip the first N snippets\",\n    )\n    args = parser.parse_args()\n\n    run_one(\n        args.input,\n        args.model_ckpt,\n        model_cfg=args.model_cfg,\n        max_snip=args.num_snips,\n        snip_stride=args.snip_stride,\n        voxel_res=args.voxel_res,\n        output_dir=args.output_dir,\n        obb_only=args.obb_only,\n        skip_video=args.skip_video,\n        skip_snips=args.skip_snips,\n    )\n"
  },
  {
    "path": "prepare_inference.sh",
    "content": "#!/usr/bin/env bash\n# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nset -e\n\nif ! ls infer.py | grep -q \"infer.py\"; then\n  echo \"Error: Can't find infer.py under the current directory. Make sure to run this script under <EFM3D_DIR>\"\n  exit 1\nfi\n\n# download DinoV2 weights\nwget -O ckpt/dinov2_vitb14_reg4_pretrain.pth https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth\n\nif [ ! -f \"ckpt/evl_model_ckpt.zip\" ]; then\n  echo \"Error: File evl_model_ckpt.zip does not exist. Make sure it's put under EFM3D_DIR/ckpt\"\n  exit 1\nfi\n\n# model\ncd ckpt\nunzip evl_model_ckpt.zip\nmv evl_model_ckpt/*.pth .\nmv evl_model_ckpt/seq136_sample.zip ../data\nrmdir evl_model_ckpt\n\n# data\ncd ../data\nunzip seq136_sample.zip\nrm seq136_sample.zip\n\necho \"Done preparing for inference\"\n"
  },
  {
    "path": "requirements-extra.txt",
    "content": "projectaria-atek\ngit+https://github.com/facebookresearch/pytorch3d.git@V0.7.8\ntensorboard==2.14.0\ntorchmetrics==0.10.1\n"
  },
  {
    "path": "requirements.txt",
    "content": "torch==2.3.0\ntorchvision==0.18.0\nomegaconf==2.3.0\nhydra-core==1.3.2\nwebdataset==0.2.86\nvrs==1.2.1\nfsspec==2024.6.0\neinops==0.8.0\npandas==2.2.2\npyquaternion==0.9.9\nopencv-python==4.10.0.84\ntqdm==4.66.4\nmatplotlib==3.9.0\nnumpy==1.26.4\nmoderngl==5.8.2\ntrimesh==4.4.9\nscikit-image==0.24.0\nprojectaria_tools\n"
  },
  {
    "path": "sbatch_run.sh",
    "content": "#!/bin/bash\n# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n#SBATCH --job-name=efm3d_multinode\n#SBATCH --nodes=2\n#SBATCH --ntasks-per-node=1\n#SBATCH --gres=gpu:8\n#SBATCH --cpus-per-task=96\nexport NCCL_DEBUG=INFO\nexport PYTHONFAULTHANDLER=1\n\nnodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) )\nnodes_array=($nodes)\nhead_node=${nodes_array[0]}\nhead_node_ip=$(srun --nodes=1 --ntasks=1 -w \"$head_node\" hostname --ip-address)\n\necho Node IP: $head_node_ip\nexport LOGLEVEL=INFO\n\nsrun torchrun \\\n--nnodes 2 \\\n--nproc_per_node 8 \\\n--rdzv_id $RANDOM \\\n--rdzv_backend c10d \\\n--rdzv_endpoint $head_node_ip:29600 \\\ntrain.py\n"
  },
  {
    "path": "train.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\n# train with a single gpu\npython train.py\n\n# train with 8 gpus\ntorchrun --standalone --nproc_per_node=8 train.py\n\n# train with multi-node multi-gpu, run\nsbatch sbatch_run.sh\n\"\"\"\n\nimport math\nimport os\nimport random\nimport shutil\nimport time\nfrom datetime import datetime\n\nimport hydra\nimport omegaconf\nimport torch\nimport torch.distributed as dist\nimport tqdm\nimport webdataset as wds\nimport yaml\nfrom efm3d.aria.tensor_wrapper import custom_collate_fn\nfrom efm3d.dataset.augmentation import ColorJitter, PointDropSimple, PointJitter\nfrom efm3d.dataset.efm_model_adaptor import load_atek_wds_dataset_as_efm_train\nfrom efm3d.dataset.vrs_dataset import preprocess\nfrom efm3d.dataset.wds_dataset import get_tar_sample_num\nfrom torch.distributed import destroy_process_group, init_process_group\nfrom torch.nn.parallel import DistributedDataParallel as DDP\nfrom torch.utils.tensorboard import SummaryWriter\n\n\nDATA_PATH = \"./data/ase_train\"\nMAX_LR = 2e-4\nMIN_LR = MAX_LR * 0.1\nBATCH_SIZE = 2\nMAX_EPOCHS = 40\nMAX_SAMPLES_PER_EPOCH = 100000\nSAVE_EVERY_EPOCHS = 5  # save the model every\nLOG_STEP = 5  # print error every\n\n\ndef get_lr(it, warmup_its, max_its, max_lr, min_lr):\n    \"\"\"\n    cosine learning rate scheduler, `it` can be either step or epoch.\n    \"\"\"\n    # learning rate scheduler\n    # linear 
warmup for warmup_its iterations\n    if it < warmup_its:\n        return max_lr * (it + 1) / warmup_its\n\n    # return min_lr once it exceeds max_its\n    if it > max_its:\n        return min_lr\n\n    # cosine annealing\n    decay_ratio = (it - warmup_its) / (max_its - warmup_its)\n    assert 0 <= decay_ratio <= 1\n    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # 1.0 -> 0.0\n    return min_lr + coeff * (max_lr - min_lr)\n\n\ndef get_dataloader(\n    data_path,\n    batch_size,\n    world_size,\n    max_samples_per_epoch,\n    epoch_sample_ratio=1.0,\n    tar_yaml=\"train_tars.yaml\",\n):\n    assert 0 < epoch_sample_ratio <= 1.0, (\n        f\"epoch_sample_ratio must be in (0, 1], got {epoch_sample_ratio}\"\n    )\n\n    tar_yaml = os.path.join(data_path, tar_yaml)\n    with open(tar_yaml, \"r\") as f:\n        tar_list = yaml.safe_load(f)[\"tars\"]\n    tar_list = [os.path.join(data_path, tar_name) for tar_name in tar_list]\n\n    # check existence\n    for tar in tar_list:\n        assert os.path.exists(tar), f\"{tar} does not exist\"\n    random.shuffle(tar_list)\n    dataset = load_atek_wds_dataset_as_efm_train(\n        urls=tar_list,\n        atek_to_efm_taxonomy_mapping_file=f\"{os.path.dirname(__file__)}/efm3d/config/taxonomy/atek_to_efm.csv\",\n        batch_size=batch_size,\n        collation_fn=custom_collate_fn,\n    )\n\n    samples_per_tar = get_tar_sample_num(tar_list[0])\n    dataset_size = len(tar_list) * samples_per_tar\n    dataset_size = min(dataset_size, max_samples_per_epoch)\n    dataset_size = int(dataset_size * epoch_sample_ratio)\n\n    batches_per_epoch = int(dataset_size // (batch_size * world_size))\n    dataloader = wds.WebLoader(\n        dataset,\n        num_workers=batch_size,\n        pin_memory=True,\n        prefetch_factor=2,\n        batch_size=None,\n        shuffle=False,\n    )\n    dataloader = dataloader.with_epoch(batches_per_epoch)\n    dataloader = 
dataloader.with_length(batches_per_epoch)\n\n    return dataloader\n\n\nddp = int(os.environ.get(\"RANK\", -1)) != -1\nif ddp:\n    assert torch.cuda.is_available()\n    init_process_group(\"nccl\")\n    DDP_RANK = int(os.environ[\"RANK\"])\n    DDP_LOCAL_RANK = int(os.environ[\"LOCAL_RANK\"])\n    DDP_WORLD_SIZE = int(os.environ[\"WORLD_SIZE\"])\n    device = f\"cuda:{DDP_LOCAL_RANK}\"\n    print(f\"==> setting device to {device}\")\n    torch.cuda.set_device(device)\n    master_process = DDP_RANK == 0\nelse:\n    DDP_RANK = 0\n    DDP_LOCAL_RANK = 0\n    DDP_WORLD_SIZE = 1\n    master_process = True\n    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\ntorch.manual_seed(42)\nif torch.cuda.is_available():\n    torch.cuda.manual_seed(42)\n\nmodel_config = omegaconf.OmegaConf.load(\"efm3d/config/evl_train.yaml\")\nmodel = hydra.utils.instantiate(model_config)\nmodel.to(device)\nif ddp:\n    model = DDP(model, device_ids=[DDP_LOCAL_RANK])\nraw_model = model.module if ddp else model\n\ntrain_dataloader = get_dataloader(\n    DATA_PATH,\n    BATCH_SIZE,\n    DDP_WORLD_SIZE,\n    max_samples_per_epoch=MAX_SAMPLES_PER_EPOCH,\n    tar_yaml=\"train_tars.yaml\",\n)\nval_dataloader = get_dataloader(\n    DATA_PATH,\n    BATCH_SIZE,\n    DDP_WORLD_SIZE,\n    max_samples_per_epoch=MAX_SAMPLES_PER_EPOCH,\n    tar_yaml=\"val_tars.yaml\",\n)\noptimizer = torch.optim.AdamW(model.parameters(), lr=MAX_LR)\n\nif master_process:\n    exp_name = f\"efm3d_train_b{BATCH_SIZE}g{DDP_WORLD_SIZE}e{MAX_EPOCHS}lr{str(MAX_LR)}_{datetime.fromtimestamp(time.time()).strftime('%y-%m-%d-%H-%M-%S')}\"\n    log_dir = os.path.join(\"tb_logs\", exp_name)\n    writer = SummaryWriter(log_dir=log_dir)\n\ncolor_jitter = ColorJitter(\n    brightness=0.5,\n    contrast=0.3,\n    saturation=0.3,\n    hue=0.05,\n    sharpness=2.0,\n    snippet_jitter=True,\n)\npoint_drop = 
PointDropSimple(max_dropout_rate=0.8)\npoint_jitter = PointJitter(depth_std_scale_min=1.0, depth_std_scale_max=6.0)\naugmentations = [color_jitter, point_drop, point_jitter]\n\nstep = 0\nval_step = 0\n# main loop\nfor epoch in range(MAX_EPOCHS):\n    # train\n    model.train()\n    for batch in tqdm.tqdm(train_dataloader):\n        start = time.time()\n        optimizer.zero_grad()\n\n        batch = preprocess(batch, device, aug_funcs=augmentations)\n        output = model(batch)\n        losses, total_loss = raw_model.compute_losses(output, batch)\n\n        total_loss.backward()\n\n        # epoch-based lr scheduler\n        lr = get_lr(\n            epoch, warmup_its=5, max_its=MAX_EPOCHS, max_lr=MAX_LR, min_lr=MIN_LR\n        )\n        for param_group in optimizer.param_groups:\n            param_group[\"lr\"] = lr\n\n        if ddp:\n            dist.all_reduce(total_loss, op=dist.ReduceOp.AVG)\n        max_norm = 1.0\n        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)\n\n        optimizer.step()\n        time_per_it = time.time() - start\n\n        if master_process and step % LOG_STEP == 0:\n            print(\n                f\"E:s-{epoch}:{step} | loss {total_loss.item():.03f} | lr {lr:.06f} | norm {norm} | time {time_per_it:.02f}s/it\"\n            )\n\n            # log training\n            writer.add_scalar(\"train/loss\", total_loss.item(), step)\n            for stream in losses:\n                for loss_name in losses[stream]:\n                    writer.add_scalar(\n                        f\"train/loss/{stream}/{loss_name}\",\n                        losses[stream][loss_name].item(),\n                        step,\n                    )\n            writer.add_scalar(\"train/lr\", lr, step)\n            writer.add_scalar(\"train/iter_sec\", time_per_it, step)\n\n            # log images (log every `10xlog_step` since writing video is slow)\n            if step % (10 * LOG_STEP) == 0:\n                imgs = 
raw_model.log_single(batch, output, batch_idx=0)\n                for k, v in imgs.items():\n                    vid = torch.tensor(v.transpose((0, 3, 1, 2))).unsqueeze(0)\n                    writer.add_video(f\"train/{k}\", vid, global_step=step, fps=10)\n        step += 1\n\n    # val\n    model.eval()\n    for batch in tqdm.tqdm(val_dataloader):\n        with torch.no_grad():\n            start = time.time()\n            batch = preprocess(batch, device, aug_funcs=augmentations)\n            output = model(batch)\n            losses, total_loss = raw_model.compute_losses(output, batch)\n            if ddp:\n                dist.all_reduce(total_loss, op=dist.ReduceOp.AVG)\n            time_per_it = time.time() - start\n\n        if master_process and val_step % LOG_STEP == 0:\n            print(\n                f\"E:s-{epoch}:{val_step} | loss {total_loss.item():.03f} | time {time_per_it:.02f}s/it\"\n            )\n\n            # log val\n            if val_step % LOG_STEP == 0:\n                writer.add_scalar(\"val/loss\", total_loss.item(), val_step)\n                for stream in losses:\n                    for loss_name in losses[stream]:\n                        writer.add_scalar(\n                            f\"val/loss/{stream}/{loss_name}\",\n                            losses[stream][loss_name].item(),\n                            val_step,\n                        )\n                writer.add_scalar(\"val/iter_sec\", time_per_it, val_step)\n\n            # log images\n            if val_step % (10 * LOG_STEP) == 0:\n                imgs = raw_model.log_single(batch, output, batch_idx=0)\n                for k, v in imgs.items():\n                    vid = torch.tensor(v.transpose((0, 3, 1, 2))).unsqueeze(0)\n                    writer.add_video(f\"val/{k}\", vid, global_step=val_step, fps=10)\n        val_step += 1\n\n    # save model\n    if master_process and (epoch + 1) % SAVE_EVERY_EPOCHS == 0:\n        ckpt_path = os.path.join(\n           
 log_dir, f\"model_e{epoch}s{step}_l{total_loss.item():.02f}.pth\"\n        )\n        last_ckpt_path = os.path.join(log_dir, \"last.pth\")\n        torch.save(\n            {\"state_dict\": raw_model.state_dict(), \"optimizer\": optimizer.state_dict()},\n            ckpt_path,\n        )\n        shutil.copy(ckpt_path, last_ckpt_path)\n\nif master_process:\n    writer.close()\nif ddp:\n    destroy_process_group()\n"
  }
]