Repository: facebookresearch/efm3d
Branch: main
Commit: 07950d73d147
Files: 92
Total size: 848.9 KB

Directory structure:
gitextract_3a354cf0/

├── .github/
│   ├── CODE_OF_CONDUCT.md
│   ├── CONTRIBUTING.md
│   └── workflows/
│       └── conda-env.yaml
├── .gitignore
├── INSTALL.md
├── LICENSE
├── README.md
├── benchmark.md
├── efm3d/
│   ├── __init__.py
│   ├── aria/
│   │   ├── __init__.py
│   │   ├── aria_constants.py
│   │   ├── camera.py
│   │   ├── obb.py
│   │   ├── pose.py
│   │   ├── projection_utils.py
│   │   └── tensor_wrapper.py
│   ├── config/
│   │   ├── efm_preprocessing_conf.yaml
│   │   ├── evl_inf.yaml
│   │   ├── evl_inf_desktop.yaml
│   │   ├── evl_train.yaml
│   │   └── taxonomy/
│   │       ├── aeo_to_efm.csv
│   │       ├── ase_sem_name_to_id.csv
│   │       └── atek_to_efm.csv
│   ├── dataset/
│   │   ├── atek_vrs_dataset.py
│   │   ├── atek_wds_dataset.py
│   │   ├── augmentation.py
│   │   ├── efm_model_adaptor.py
│   │   ├── vrs_dataset.py
│   │   └── wds_dataset.py
│   ├── inference/
│   │   ├── __init__.py
│   │   ├── eval.py
│   │   ├── fuse.py
│   │   ├── model.py
│   │   ├── pipeline.py
│   │   ├── track.py
│   │   └── viz.py
│   ├── model/
│   │   ├── __init__.py
│   │   ├── cnn.py
│   │   ├── dinov2_utils.py
│   │   ├── dpt.py
│   │   ├── evl.py
│   │   ├── evl_train.py
│   │   ├── image_tokenizer.py
│   │   ├── lifter.py
│   │   └── video_backbone.py
│   ├── thirdparty/
│   │   ├── __init__.py
│   │   └── mmdetection3d/
│   │       ├── LICENSE
│   │       ├── __init__.py
│   │       ├── cuda/
│   │       │   ├── cuda_utils.h
│   │       │   ├── iou3d.cpp
│   │       │   ├── iou3d.h
│   │       │   ├── iou3d_kernel.cu
│   │       │   ├── setup.py
│   │       │   ├── sort_vert.cpp
│   │       │   ├── sort_vert.h
│   │       │   ├── sort_vert_kernel.cu
│   │       │   └── utils.h
│   │       └── iou3d.py
│   └── utils/
│       ├── __init__.py
│       ├── common.py
│       ├── depth.py
│       ├── detection_utils.py
│       ├── evl_loss.py
│       ├── file_utils.py
│       ├── gravity.py
│       ├── image.py
│       ├── image_sampling.py
│       ├── marching_cubes.py
│       ├── mesh_utils.py
│       ├── obb_csv_writer.py
│       ├── obb_io.py
│       ├── obb_matchers.py
│       ├── obb_metrics.py
│       ├── obb_trackers.py
│       ├── obb_utils.py
│       ├── pointcloud.py
│       ├── ray.py
│       ├── reconstruction.py
│       ├── render.py
│       ├── rescale.py
│       ├── viz.py
│       ├── voxel.py
│       └── voxel_sampling.py
├── environment-mac.yml
├── environment.yml
├── eval.py
├── infer.py
├── prepare_inference.sh
├── requirements-extra.txt
├── requirements.txt
├── sbatch_run.sh
└── train.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .github/CODE_OF_CONDUCT.md
================================================
# Code of Conduct

## Our Pledge

In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.

## Our Standards

Examples of behavior that contributes to creating a positive environment
include:

* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or
  advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
  address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting

## Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.

## Scope

This Code of Conduct applies within all project spaces, and it also applies when
an individual is representing the project or its community in public spaces.
Examples of representing a project or community include using an official
project e-mail address, posting via an official social media account, or acting
as an appointed representative at an online or offline event. Representation of
a project may be further defined and clarified by project maintainers.

This Code of Conduct also applies outside the project spaces when there is a
reasonable belief that an individual's behavior may have a negative impact on
the project or its community.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at <opensource-conduct@meta.com>. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq


================================================
FILE: .github/CONTRIBUTING.md
================================================
# Contributing to "efm3d"

We want to make contributing to this project as easy and transparent as
possible.

## Pull Requests

We welcome pull requests.

1. Fork the repo and create your branch from `main`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation in the code.
4. Ensure the test suite passes.
5. If you haven't already, complete the Contributor License Agreement ("CLA").

## Contributor License Agreement ("CLA")

In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Facebook's open source projects.

Complete your CLA here: <https://code.facebook.com/cla>

## Issues

We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.

Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.

## License

By contributing to "efm3d", you agree that your contributions will be licensed under
the [LICENSE](../LICENSE) file in the root directory of this source tree.


================================================
FILE: .github/workflows/conda-env.yaml
================================================
name: Conda Environment CI

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  test:
    name: Test conda env
    runs-on: "ubuntu-latest"
    defaults:
      run:
        shell: bash -el {0}
    steps:
      - uses: actions/checkout@v4
      - uses: conda-incubator/setup-miniconda@v3
        with:
          activate-environment: efm3d
          environment-file: environment-mac.yml
          python-version: 3.9
          auto-activate-base: false
      - run: |
          conda info
          conda list
          conda activate efm3d
          pip install -r requirements.txt


================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
dist/
build/
eggs/
.eggs/
*.egg-info/
lib/
lib64/

# PyTorch specific
*.pt
*.pth
*.ckpt
*.tfevents
.ipynb_checkpoints/

# Environment
.env
venv/
ENV/

# IDEs
.vscode/
.idea/

# Miscellaneous
.DS_Store
Thumbs.db

# artifacts
*.mp4

# data
*.ply
data/
tb_logs/

# model weights
ckpt/

# output dir
*.out
output/
# TensorBoard output
runs/


================================================
FILE: INSTALL.md
================================================
# Installation

We provide two ways to install the dependencies of EFM3D. We recommend using miniconda to manage the dependencies, which
also provides an easy setup for all the additional dependencies listed in `requirements.txt` and `requirements-extra.txt`.

## Install using conda (recommended)

First install [miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install),
then run the following commands in the `<EFM3D_DIR>` root directory:

```
conda env create --file=environment.yml
conda activate efm3d

cd efm3d/thirdparty/mmdetection3d/cuda/
python setup.py install
```

The commands will first create a conda environment named `efm3d`, and then build the
third-party CUDA kernel required for training.

## Install via pip

Make sure you have Python>=3.9, then install the dependencies using `pip`.
The packages in `requirements.txt` are needed for the basic functionality of
EFM3D, such as running the example model inference to see 3D object detection
and surface reconstruction on a [vrs](https://facebookresearch.github.io/vrs/)
sequence.

```
pip install -r requirements.txt
```

Additional dependencies in `requirements-extra.txt` are needed for training and eval.

```
pip install -r requirements-extra.txt
```

**Important**: For training, we also need to build a CUDA kernel from
[mmdetection3d](https://github.com/open-mmlab/mmdetection3d). Compile the CUDA
kernel of the IoU3d loss by running the following commands, which require the
[CUDA dev toolkit](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/)
to be installed:

```
cd efm3d/thirdparty/mmdetection3d/cuda/
python setup.py install
```
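
Before building, it can help to confirm the prerequisites named above. The following is a minimal sketch, not part of EFM3D: `check_efm3d_prereqs` is a hypothetical helper, and it assumes that finding `nvcc` on `PATH` is a reasonable proxy for an installed CUDA dev toolkit.

```python
import shutil
import sys


def check_efm3d_prereqs(need_training=False):
    """Report whether basic EFM3D install prerequisites appear to be met.

    Hypothetical helper. Assumption: `nvcc` on PATH stands in for the CUDA
    dev toolkit needed to build the mmdetection3d IoU3d kernel.
    """
    report = {
        # Both the conda and the pip install paths require Python >= 3.9.
        "python_ok": sys.version_info >= (3, 9),
        # conda is only needed for the recommended install path.
        "conda_found": shutil.which("conda") is not None,
    }
    if need_training:
        # `python setup.py install` for the CUDA kernel needs the nvcc compiler.
        report["nvcc_found"] = shutil.which("nvcc") is not None
    return report


if __name__ == "__main__":
    for name, ok in check_efm3d_prereqs(need_training=True).items():
        print(f"{name}: {'yes' if ok else 'no'}")
```

Passing `need_training=False` skips the `nvcc` check, matching the fact that the CUDA kernel is only required for training.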


================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
# EFM3D: A Benchmark for Measuring Progress Towards 3D Egocentric Foundation Models

[[paper](https://arxiv.org/abs/2406.10224)]
[[website](https://www.projectaria.com/research/efm3D/)]

## Intro

This is the official release for the paper EFM3D: A Benchmark for Measuring
Progress Towards 3D Egocentric Foundation Models
(https://arxiv.org/abs/2406.10224). To measure progress on what we term
Egocentric Foundation Models (EFMs) we establish EFM3D, a benchmark with two
core 3D egocentric perception tasks. EFM3D is the first benchmark for 3D object
detection and surface regression on high quality annotated egocentric data of
[Project Aria](https://www.projectaria.com/). We also propose Egocentric Voxel
Lifting (EVL), a baseline for 3D EFMs.

<img src="assets/efm3d.png">

We provide the following code and assets:

- The pretrained EVL model weights for surface reconstruction and 3D object
  detection on Aria sequences
- The datasets included in the EFM3D benchmark, including the training and
  evaluation data for Aria Synthetic Datasets (ASE), Aria Everyday Objects (AEO)
  for 3D object detection, and the eval mesh models for surface reconstruction
  evaluation.
- Distributed training code to train EVL.
- Native integration with
  [Aria Training and Evaluation Kit (ATEK)](https://github.com/facebookresearch/atek).

The following is a minimal guide to running model inference, including the
installation steps, data download instructions, and how to run the inference
code.

## Installation

**Option 1**: First navigate to the root folder. The core library is written in
PyTorch, with additional dependencies listed in `requirements.txt`. This requires
Python>=3.9:

```
pip install -r requirements.txt
```

**Option 2**: You can use conda to manage the dependencies.
We recommend [miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install) for its fast dependency solver.
The runtime dependencies can be installed by running the following (replace `environment.yml` with `environment-mac.yml` if running on macOS):

```
conda env create --file=environment.yml
conda activate efm3d
```

This should be sufficient to run EVL model inference with the pretrained model
weights. Please refer to [INSTALL.md](INSTALL.md) for a full installation, which
is required for training and eval.

## Inference

### Pretrained models

Download the pretrained model weights and sample data from the
[EFM3D](https://www.projectaria.com/research/efm3D/#download-dataset) page
(email required). We provide two model checkpoints, one for server-side GPUs
(>20GB GPU memory) and one for desktop GPUs. A sample sequence is bundled with
the model weights to make it easy to try the model. Check out the
[README.md](ckpt/README.md) for detailed instructions on how to download the
model weights.

### Run on the sample data

After downloading the model weights `evl_model_ckpt.zip`, put it under
`${EFM3D_DIR}/ckpt/`, then run the following command in `${EFM3D_DIR}`:

```
sh prepare_inference.sh
```

This unzips the file and makes sure the model weights and sample data are placed
under the right paths. To run inference on the sample sequence:

```
python infer.py --input ./data/seq136_sample/video.vrs
```

**Note**: the pretrained model requires ~20GB GPU memory. Use the following
command to run the model on a desktop GPU with ~10GB memory (tested on an
RTX 3080), at the cost of slightly lower performance:

```
python infer.py --input ./data/seq136_sample/video.vrs --model_ckpt ./ckpt/model_lite.pth --model_cfg ./efm3d/config/evl_inf_desktop.yaml --voxel_res 0.08
```

### Run on macOS

The inference demo also works on macOS. Use the following command (tested on an
Apple M1 Max with 64GB memory):

```
PYTORCH_ENABLE_MPS_FALLBACK=1 python infer.py --input ./data/seq136_sample/video.vrs --model_ckpt ./ckpt/model_lite.pth --model_cfg ./efm3d/config/evl_inf_desktop.yaml --voxel_res 0.08
```
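
The three inference commands above differ only in checkpoint, config, voxel resolution and one environment variable. As a sketch, the hypothetical helper `build_infer_command` below picks between them using the ~20GB memory guideline from this README; the threshold comes from the text, not from `infer.py`, and treating every macOS run as needing the lite checkpoint is an assumption based on the tested command.

```python
import shlex


def build_infer_command(input_path, gpu_mem_gb, on_macos=False):
    """Return (env, argv) for infer.py, mirroring the README's commands.

    Hypothetical helper: the 20GB threshold is the README's guideline for
    the default checkpoint; it is not read from infer.py.
    """
    argv = ["python", "infer.py", "--input", input_path]
    env = {}
    # Below ~20GB (or on macOS) switch to the lite desktop checkpoint.
    if on_macos or gpu_mem_gb < 20:
        argv += [
            "--model_ckpt", "./ckpt/model_lite.pth",
            "--model_cfg", "./efm3d/config/evl_inf_desktop.yaml",
            "--voxel_res", "0.08",
        ]
    if on_macos:
        # Let PyTorch fall back to CPU for ops missing on the MPS backend.
        env["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
    return env, argv


if __name__ == "__main__":
    env, argv = build_infer_command("./data/seq136_sample/video.vrs", gpu_mem_gb=10)
    prefix = " ".join(f"{k}={v}" for k, v in env.items())
    print((prefix + " " if prefix else "") + shlex.join(argv))
```

Returning the environment separately from `argv` keeps the command usable with `subprocess.run(argv, env={**os.environ, **env})` as well as for printing.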

This covers the basic usage of the EVL model. To train the model from scratch
or use the EFM3D benchmark, complete a full installation following
[INSTALL.md](INSTALL.md), then read on.

### Inference with ATEK

Inference also supports
[ATEK-format](https://github.com/facebookresearch/atek) WDS sequences as input.
First download a test ASE sequence following the `ASE eval data` section in
[README.md](data/README.md), then run:

```
python infer.py --input ./data/ase_eval/81022
```
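
`--input` accepts either a `.vrs` recording or an ATEK WDS sequence directory. The sketch below shows one way such a path could be classified; `classify_infer_input` is hypothetical, and `infer.py`'s actual dispatch logic may differ.

```python
from pathlib import Path


def classify_infer_input(path):
    """Guess which reader an --input path needs (hypothetical helper).

    Assumption drawn from the README examples: a `.vrs` file is a raw Aria
    recording, while a directory (e.g. ./data/ase_eval/81022) holds an
    ATEK-format WebDataset (WDS) sequence.
    """
    p = Path(path)
    if p.suffix == ".vrs":
        return "vrs"
    if p.is_dir():
        return "atek_wds"
    raise ValueError(f"unsupported input {path!r}: expected a .vrs file or a WDS directory")
```

For example, `classify_infer_input("./data/seq136_sample/video.vrs")` returns `"vrs"`, while an existing directory such as `./data/ase_eval/81022` would be classified as `"atek_wds"`.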

## Datasets

See [README.md](data/README.md) for instructions to work with all datasets
included in the EFM3D benchmark. There are three datasets in the EFM3D benchmark:

- [Aria Synthetic Environments (ASE)](https://www.projectaria.com/datasets/ase/):
  for training and eval on 3D object detection and surface reconstruction
- [Aria Digital Twin (ADT)](https://www.projectaria.com/datasets/adt/): for eval
  on surface reconstruction
- [Aria Everyday Objects (AEO)](https://www.projectaria.com/datasets/aeo/): for
  eval on 3D object detection.

## Train EVL

First make sure you have a full installation (see [INSTALL.md](INSTALL.md)).
Training the EVL model from scratch requires downloading the full ASE training
data. To test the training script, you can download a small subset of ASE
sequences (>10 sequences). Check out the `ASE training data` section in
[data/README.md](data/README.md). After following the instructions to prepare
the data, run one of the following commands.

- train the EVL model from scratch on a single GPU

```
python train.py
```

- train with 8 GPUs

```
torchrun --standalone --nproc_per_node=8 train.py
```

We also provide a script to train in a multi-node, multi-GPU environment via
[slurm](https://slurm.schedmd.com/documentation.html). The pretrained model was
trained on 2 nodes with 8xH100.

- train with multi-node multi-gpu using slurm

```
sbatch sbatch_run.sh
```
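
The three launch modes can be summarized as below. `build_train_command` is a hypothetical helper that assumes one training process per GPU; the actual multi-node resources are configured inside `sbatch_run.sh`, not by this function.

```python
def build_train_command(gpus_per_node=1, nodes=1):
    """Return (argv, world_size) for the README's training launch modes.

    Hypothetical helper. world_size assumes one process per GPU.
    """
    if gpus_per_node < 1 or nodes < 1:
        raise ValueError("need at least one node and one GPU per node")
    world_size = gpus_per_node * nodes  # total number of training processes
    if nodes > 1:
        # Multi-node runs go through the provided slurm script.
        argv = ["sbatch", "sbatch_run.sh"]
    elif gpus_per_node == 1:
        # Single GPU: plain Python entry point.
        argv = ["python", "train.py"]
    else:
        # Single node, several GPUs: torchrun spawns one worker per GPU.
        argv = ["torchrun", "--standalone", f"--nproc_per_node={gpus_per_node}", "train.py"]
    return argv, world_size
```

For the configuration used for the pretrained model (2 nodes with 8xH100), `build_train_command(8, 2)` returns the slurm launch and a world size of 16.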

By default the TensorBoard log is saved to `${EFM3D_DIR}/tb_logs`.

## EFM3D benchmark

Please see [benchmark.md](benchmark.md) for details.

## Citing EFM3D

If you find EFM3D useful, please consider citing

```
@article{straub2024efm3d,
  title={EFM3D: A Benchmark for Measuring Progress Towards 3D Egocentric Foundation Models},
  author={Straub, Julian and DeTone, Daniel and Shen, Tianwei and Yang, Nan and Sweeney, Chris and Newcombe, Richard},
  journal={arXiv preprint arXiv:2406.10224},
  year={2024}
}
```

If you use the Aria Digital Twin (ADT) dataset in the EFM3D benchmark, please
consider citing

```
@inproceedings{pan2023aria,
  title={Aria digital twin: A new benchmark dataset for egocentric 3d machine perception},
  author={Pan, Xiaqing and Charron, Nicholas and Yang, Yongqian and Peters, Scott and Whelan, Thomas and Kong, Chen and Parkhi, Omkar and Newcombe, Richard and Ren, Yuheng Carl},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  pages={20133--20143},
  year={2023}
}
```

If you use the Aria Synthetic Environments (ASE) dataset in the EFM3D benchmark,
please consider citing

```
@article{avetisyan2024scenescript,
  title={SceneScript: Reconstructing Scenes With An Autoregressive Structured Language Model},
  author={Avetisyan, Armen and Xie, Christopher and Howard-Jenkins, Henry and Yang, Tsun-Yi and Aroudj, Samir and Patra, Suvam and Zhang, Fuyang and Frost, Duncan and Holland, Luke and Orme, Campbell and others},
  journal={arXiv preprint arXiv:2403.13064},
  year={2024}
}
```

## How to Contribute

We welcome contributions! See [CONTRIBUTING](./.github/CONTRIBUTING.md) and
our [CODE OF CONDUCT](./.github/CODE_OF_CONDUCT.md) to get started.

## License

EFM3D is released by Meta under the [Apache 2.0 license](LICENSE).


================================================
FILE: benchmark.md
================================================
## EFM3D Benchmark

We provide three evaluation datasets for the EFM3D benchmarks. For more details on the benchmark see the [EFM3D](https://arxiv.org/abs/2406.10224) paper.

### ASE - 3D object detection and mesh reconstruction
Aria Synthetic Environments (ASE) is a synthetic dataset, created from procedurally-generated interior layouts filled with 3D objects, simulated with the sensor characteristics of Aria glasses. We use ASE for both surface reconstruction and 3D object detection evaluation.

First follow the instructions in the dataset [README.md](data/README.md) to download the ASE eval set and eval meshes, then run:

```
python eval.py --ase
```

Running the full evaluation on 100 eval sequences takes a long time (>10 hrs on a single GPU). To see eval results on a reduced set, use `--num_seqs` and `--num_snips` to specify the number of sequences and the number of snippets per sequence, for example:

```
# run eval on the first 10 sequences of ASE, each running for 100 snippets (10s)
python eval.py --ase --num_seqs 10 --num_snips 100
```

### ADT - mesh reconstruction
ADT is the benchmark dataset for surface reconstruction, containing 6 sequences.
Download the ADT data and mesh files following the data instructions, then run:

```
python eval.py --adt
```

The provided script is an end-to-end solution: it runs the EVL model with the default checkpoint,
finds the right GT mesh path for the ASE and ADT datasets, and then computes the mesh-to-mesh distance metrics.
If you have your own model that generates a PLY file, check [eval_mesh_to_mesh](efm3d/utils/mesh_utils.py) for how to evaluate against the surface GT directly.
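As a rough illustration of what point-based mesh-to-mesh metrics typically measure, the sketch below computes accuracy, completeness, precision, recall and F-score between two point clouds, for example points sampled from the predicted and GT mesh surfaces. It is a generic, hypothetical helper (`mesh_point_metrics`, `nearest_dist`), not the repo's `eval_mesh_to_mesh`; the 5 cm threshold and the brute-force nearest-neighbor search are assumptions chosen for brevity.

```python
import numpy as np


def nearest_dist(src: np.ndarray, dst: np.ndarray) -> np.ndarray:
    """For each point in src (N x 3), the distance to its nearest point in dst (M x 3)."""
    # Brute-force (N, M) distance matrix; swap in a KD-tree for large point clouds.
    d = np.linalg.norm(src[:, None, :] - dst[None, :, :], axis=-1)
    return d.min(axis=1)


def mesh_point_metrics(pred_pts: np.ndarray, gt_pts: np.ndarray, thresh: float = 0.05) -> dict:
    """Hypothetical point-based surface metrics; inputs are N x 3 points in meters."""
    d_pred_to_gt = nearest_dist(pred_pts, gt_pts)  # how close predictions are to GT
    d_gt_to_pred = nearest_dist(gt_pts, pred_pts)  # how much of the GT is covered
    precision = float(np.mean(d_pred_to_gt < thresh))
    recall = float(np.mean(d_gt_to_pred < thresh))
    denom = precision + recall
    fscore = 2.0 * precision * recall / denom if denom > 0 else 0.0
    return {
        "accuracy": float(d_pred_to_gt.mean()),
        "completeness": float(d_gt_to_pred.mean()),
        "precision": precision,
        "recall": recall,
        "fscore": fscore,
    }


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    gt = rng.random((500, 3))
    # A 1 cm offset per axis stays well under the 5 cm threshold.
    print(mesh_point_metrics(gt + 0.01, gt))
```

Accuracy and precision penalize spurious predicted surface, while completeness and recall penalize missing surface, which is why both directions of the nearest-neighbor distance are reported.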

### AEO - 3D object detection
AEO is the benchmark data for 3D object detection, with 25 sequences.
Download the AEO dataset following the data instructions, then run:

```
python eval.py --aeo
```

This will run the EVL model inference using the default model checkpoint path.
If you have your own model for inference, check [eval.py](efm3d/inference/eval.py) for how to evaluate against 3D object GT directly.
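3D detection benchmarks usually match predicted boxes to GT boxes with a 3D IoU threshold before computing metrics such as average precision. EFM3D works with oriented boxes (`ObbTW`), which require a rotated-box IoU; the hypothetical sketch below simplifies to axis-aligned boxes stored as `[xmin, ymin, zmin, xmax, ymax, zmax]` only to show the IoU computation itself. It is not the code in `eval.py`, and the helper names are illustrative.

```python
import numpy as np


def aabb_iou_3d(a, b) -> float:
    """IoU of two axis-aligned 3D boxes given as [xmin, ymin, zmin, xmax, ymax, zmax]."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    # Overlap extent along each axis, clamped at zero when the boxes are disjoint.
    lo = np.maximum(a[:3], b[:3])
    hi = np.minimum(a[3:], b[3:])
    inter = float(np.prod(np.clip(hi - lo, 0.0, None)))
    vol_a = float(np.prod(a[3:] - a[:3]))
    vol_b = float(np.prod(b[3:] - b[:3]))
    union = vol_a + vol_b - inter
    return inter / union if union > 0 else 0.0


def iou_matrix(preds: np.ndarray, gts: np.ndarray) -> np.ndarray:
    """Pairwise IoU (P x G) between predicted and GT boxes, e.g. for threshold matching."""
    rows = [[aabb_iou_3d(p, g) for g in gts] for p in preds]
    return np.array(rows, dtype=np.float64).reshape(len(preds), len(gts))


if __name__ == "__main__":
    cube = [0, 0, 0, 1, 1, 1]
    shifted = [0.5, 0, 0, 1.5, 1, 1]
    print(aabb_iou_3d(cube, shifted))  # 0.5 / 1.5 ≈ 0.333
```

A prediction would typically count as a true positive if its IoU with an unmatched GT box of the same class exceeds the chosen threshold.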


================================================
FILE: efm3d/__init__.py
================================================


================================================
FILE: efm3d/aria/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .camera import CameraTW, DEFAULT_CAM_DATA_SIZE
from .obb import ObbTW, transform_obbs
from .pose import PoseTW
from .tensor_wrapper import smart_cat, smart_stack, TensorWrapper


================================================
FILE: efm3d/aria/aria_constants.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# High level organization of the constants:
# - */time_ns is timestamp with respect to the aria clock in nanoseconds stored as torch.long()
# - */snippet_time_s is the timestamp with respect to the start of the snippet in seconds stored as torch.float32()
# - */t_A_B is a pose transformation from coordinate system B to A
# - path-like key strings designate hierarchical relationships of data. E.g.
#   rgb/img/... is all data relating to the rgb image information, rgb/calib/...
#   is all about the calibration data, and all rgb/... is data relating to the
#   rgb video stream.

# ---------------------------------------------------------------------
# sequence level information
# ---------------------------------------------------------------------
ARIA_SEQ_ID = "sequence/id"
# start of the sequence in ns relative to global Aria timestamp
ARIA_SEQ_TIME_NS = "sequence/time_ns"

# ---------------------------------------------------------------------
# snippet level information
# ---------------------------------------------------------------------
ARIA_SNIPPET_ID = "snippet/id_in_sequence"
ARIA_SNIPPET_LENGTH_S = "snippet/length_s"
# start of snippet in ns relative to global Aria timestamp (sometimes unix 0)
ARIA_SNIPPET_TIME_NS = "snippet/time_ns"
# offset of snippet coordinate system to sequence coordinate system
ARIA_SNIPPET_T_WORLD_SNIPPET = "snippet/t_world_snippet"
# Position of the snippet coordinate system (cosy) origin within the snippet,
# as a ratio of the snippet length. E.g. 0.5 for a 10 sec snippet means the
# origin is at 5 sec. Previously known as "frame_selection" in LocalCosyPreprocessor.
ARIA_SNIPPET_ORIGIN_RATIO = "snippet/origin_ratio"

# ---------------------------------------------------------------------
# streamer playback time information
# ---------------------------------------------------------------------
ARIA_PLAY_TIME_NS = "play/time_ns"
ARIA_PLAY_SEQUENCE_TIME_S = "play/sequence_time_s"
ARIA_PLAY_SNIPPET_TIME_S = "play/snippet_time_s"
ARIA_PLAY_FREQUENCY_HZ = "play/hz"

# ---------------------------------------------------------------------
# aria video stream information
# ---------------------------------------------------------------------
# frame id in the sequence
ARIA_FRAME_ID = [
    "rgb/frame_id_in_sequence",
    "slaml/frame_id_in_sequence",
    "slamr/frame_id_in_sequence",
]
# timestamp within snippet
ARIA_IMG_SNIPPET_TIME_S = [
    "rgb/img/snippet_time_s",
    "slaml/img/snippet_time_s",
    "slamr/img/snippet_time_s",
]
# timestamp within sequence
ARIA_IMG_TIME_NS = [
    "rgb/img/time_ns",
    "slaml/img/time_ns",
    "slamr/img/time_ns",
]
# poses of the rig at the time of the respective frame capture
# T x 12
ARIA_IMG_T_SNIPPET_RIG = [
    "rgb/t_snippet_rig",
    "slaml/t_snippet_rig",
    "slamr/t_snippet_rig",
]
# image tensors
ARIA_IMG = ["rgb/img", "slaml/img", "slamr/img"]
ARIA_IMG_FREQUENCY_HZ = [
    "rgb/img/hz",
    "slaml/img/hz",
    "slamr/img/hz",
]

# ---------------------------------------------------------------------
# calibration information
# ---------------------------------------------------------------------
ARIA_CALIB = [
    "rgb/calib",
    "slaml/calib",
    "slamr/calib",
]
# timestamp within the snippet
ARIA_CALIB_SNIPPET_TIME_S = [
    "rgb/calib/snippet_time_s",
    "slaml/calib/snippet_time_s",
    "slamr/calib/snippet_time_s",
]
# timestamp within the sequence
ARIA_CALIB_TIME_NS = [
    "rgb/calib/time_ns",
    "slaml/calib/time_ns",
    "slamr/calib/time_ns",
]

# ---------------------------------------------------------------------
# pose information
# ---------------------------------------------------------------------
# pose timestamp within snippet
ARIA_POSE_SNIPPET_TIME_S = "pose/snippet_time_s"
# pose timestamp within sequence
ARIA_POSE_TIME_NS = "pose/time_ns"
# transformation from rig to snippet coordinate system
ARIA_POSE_T_SNIPPET_RIG = "pose/t_snippet_rig"
# transformation from rig to world coordinate system
ARIA_POSE_T_WORLD_RIG = "pose/t_world_rig"
# frequency of poses
ARIA_POSE_FREQUENCY_HZ = "pose/hz"

# ---------------------------------------------------------------------
# semidense points information
# ---------------------------------------------------------------------
ARIA_POINTS_WORLD = "points/p3s_world"
ARIA_POINTS_TIME_NS = "points/time_ns"
ARIA_POINTS_SNIPPET_TIME_S = "points/snippet_time_s"
ARIA_POINTS_FREQUENCY_HZ = "points/hz"
ARIA_POINTS_INV_DIST_STD = "points/inv_dist_std"
ARIA_POINTS_DIST_STD = "points/dist_std"

# ---------------------------------------------------------------------
# imu information
# ---------------------------------------------------------------------
ARIA_IMU = ["imur", "imul"]
ARIA_IMU_CHANNELS = [
    ["imur/lin_acc_ms2", "imur/rot_vel_rads"],
    ["imul/lin_acc_ms2", "imul/rot_vel_rads"],
]
ARIA_IMU_SNIPPET_TIME_S = ["imur/snippet_time_s", "imul/snippet_time_s"]
ARIA_IMU_TIME_NS = ["imur/time_ns", "imul/time_ns"]
ARIA_IMU_FACTORY_CALIB = ["imur/factory_calib", "imul/factory_calib"]
ARIA_IMU_FREQUENCY_HZ = ["imur/hz", "imul/hz"]

# ---------------------------------------------------------------------
# audio data
# ---------------------------------------------------------------------
ARIA_AUDIO = "audio"
# timestamp of audio sample within the snippet
ARIA_AUDIO_SNIPPET_TIME_S = "audio/snippet_time_s"
# timestamp of audio sample in sequence
ARIA_AUDIO_TIME_NS = "audio/time_ns"
# frequency of audio sample
ARIA_AUDIO_FREQUENCY_HZ = "audio/hz"

# ---------------------------------------------------------------------
# OBB
# ---------------------------------------------------------------------
# padded ObbTW tensor for oriented object bounding boxes given in *snippet coordinate system*
ARIA_OBB_PADDED = "obbs/padded_snippet"
# mapping of semantic id of the obb to a string name
ARIA_OBB_SEM_ID_TO_NAME = "obbs/sem_id_to_name"
# timestamp within the snippet
ARIA_OBB_SNIPPET_TIME_S = "obbs/snippet_time_s"
# timestamp within the sequence
ARIA_OBB_TIME_NS = "obbs/time_ns"
# frequency of object detection information
ARIA_OBB_FREQUENCY_HZ = "obbs/hz"

# predicted ObbTW tensor for oriented object bounding boxes
ARIA_OBB_PRED = "obbs/pred"  # raw predictions from the networks.
ARIA_OBB_PRED_VIZ = "obbs/pred_viz"  # predictions for visualization (e.g. raw predictions filtered by some criteria.)
ARIA_OBB_PRED_SEM_ID_TO_NAME = "obbs/pred/sem_id_to_name"
ARIA_OBB_PRED_PROBS_FULL = "obbs/pred/probs_full"
ARIA_OBB_PRED_PROBS_FULL_VIZ = "obbs/pred/probs_ful_viz"
# tracked ObbTW tensor for oriented object bounding boxes
ARIA_OBB_TRACKED = "obbs/tracked"
ARIA_OBB_TRACKED_PROBS_FULL = "obbs/tracked/probs_full"
# tracked but not instantiated ObbTW tensor for oriented object bounding boxes
ARIA_OBB_UNINST = "obbs/uninst"

ARIA_OBB_BB2 = ["bb2s_rgb", "bb2s_slaml", "bb2s_slamr"]
ARIA_OBB_BB3 = "bb3s_object"

# ---------------------------------------------------------------------
# depth information
# ---------------------------------------------------------------------
# for depth images (z-depth) in meters
ARIA_DEPTH_M = ["rgb/depth_m", "slaml/depth_m", "slamr/depth_m"]
# for distance images (distance along ray) in meters
ARIA_DISTANCE_M = ["rgb/distance_m", "slaml/distance_m", "slamr/distance_m"]
ARIA_DEPTH_TIME_NS = [
    "rgb/depth/time_ns",
    "slaml/depth/time_ns",
    "slamr/depth/time_ns",
]
ARIA_DEPTH_SNIPPET_TIME_S = [
    "rgb/depth/snippet_time_s",
    "slaml/depth/snippet_time_s",
    "slamr/depth/snippet_time_s",
]

ARIA_DEPTH_M_PRED = ["rgb/pred/depth_m", "slaml/pred/depth_m", "slamr/pred/depth_m"]
# for distance images (distance along ray) in meters
ARIA_DISTANCE_M_PRED = [
    "rgb/pred/distance_m",
    "slaml/pred/distance_m",
    "slamr/pred/distance_m",
]

# ---------------------------------------------------------------------
# SDF information
# ---------------------------------------------------------------------
ARIA_SDF = "snippet/sdf/sdf"
ARIA_SDF_EXT = "snippet/sdf/extent"
ARIA_SDF_COSY_TIME_NS = "snippet/sdf/cosy_time_ns"
ARIA_SDF_MASK = "snippet/sdf/mask"
ARIA_SDF_T_WORLD_VOXEL = "snippet/sdf/T_world_voxel"

# ---------------------------------------------------------------------
# GT Mesh information
# ---------------------------------------------------------------------
ARIA_MESH_VERTS_W = "snippet/mesh/verts_w"
ARIA_MESH_FACES = "snippet/mesh/faces"
ARIA_MESH_VERT_NORMS_W = "snippet/mesh/v_norms_w"
ARIA_SCENE_MESH_VERTS_W = "scene/mesh/verts_w"
ARIA_SCENE_MESH_FACES = "scene/mesh/faces"
ARIA_SCENE_MESH_VERT_NORMS_W = "scene/mesh/v_norms_w"

# ---------------------------------------------------------------------
# Scene volume information (can be acquired from mesh or semidense points)
# ---------------------------------------------------------------------
ARIA_MESH_VOL_MIN = "scene/mesh/vol_min"
ARIA_MESH_VOL_MAX = "scene/mesh/vol_max"
ARIA_POINTS_VOL_MIN = "scene/points/vol_min"
ARIA_POINTS_VOL_MAX = "scene/points/vol_max"

# ---------------------------------------------------------------------
# additional image constants
# ---------------------------------------------------------------------

# Fixed mapping of resolutions, tuple has three numbers: (RGB_HW, SLAM_W, SLAM_H)
RESOLUTION_MAP = {
    0: (1408, 640, 480),
    1: (704, 640, 480),
    2: (352, 320, 240),
    # 3: there is none
    4: (176, 160, 112),  # there is some cropping in SLAM image height
    5: (480, 640, 480),
    6: (336, 448, 336),  # match typical internet image FOV (assume 70 deg)
    7: (240, 320, 240),  # match typical internet pixels e.g. ImageNet
    8: (192, 256, 192),
    9: (144, 192, 144),
    # divisible by 14 for ViTs that use patch size 14
    # similar to 0; 560x420 instead of 616x462 so that halving it gives the equivalent of 7 (see 13)
    10: (1400, 560, 420),
    11: (700, 560, 420),  # similar to 1
    12: (420, 560, 420),  # similar to 5
    13: (210, 280, 210),  # similar to 7
}
# Fixed mapping of corresponding wh_multiple_of, for each resolution
WH_MULTIPLE_OF_MAP = {
    0: 16,
    1: 16,
    2: 16,
    # 3: there is none
    4: 16,
    5: 16,
    6: 16,
    7: 16,
    8: 16,
    9: 16,
    10: 14,
    11: 14,
    12: 14,
    13: 14,
}

# Helper constants for managing valid radius of the fisheye images, valid radius
# defines a circle from the center of projection where project/unproject is valid
RGB_RADIUS_FACTOR = 760.0 / 1408.0
SLAM_RADIUS_FACTOR = 320.0 / 640.0

ARIA_RGB_WIDTH_TO_RADIUS = {
    RESOLUTION_MAP[key][0]: RESOLUTION_MAP[key][0] * RGB_RADIUS_FACTOR
    for key in RESOLUTION_MAP
}
ARIA_SLAM_WIDTH_TO_RADIUS = {
    RESOLUTION_MAP[key][1]: RESOLUTION_MAP[key][1] * SLAM_RADIUS_FACTOR
    for key in RESOLUTION_MAP
}

ARIA_RGB_SCALE_TO_WH = {
    key: [RESOLUTION_MAP[key][0], RESOLUTION_MAP[key][0]] for key in RESOLUTION_MAP
}
ARIA_SLAM_SCALE_TO_WH = {
    key: [RESOLUTION_MAP[key][1], RESOLUTION_MAP[key][2]] for key in RESOLUTION_MAP
}

ARIA_IMG_MIN_LUX = 30.0
ARIA_IMG_MAX_LUX = 150000.0
ARIA_IMG_MAX_PERC_OVEREXPOSED = 0.02
ARIA_IMG_MAX_PERC_UNDEREXPOSED = 0.0001

# ---------------------------------------------------------------------
# EFM Constants
# ---------------------------------------------------------------------
ARIA_EFM_OUTPUT = "efm/output"

ARIA_CAM_INFO = {
    "name": ["rgb", "slaml", "slamr"],
    "stream_id": [0, 1, 2],
    "name_to_stream_id": {
        "rgb": 0,
        "slaml": 1,
        "slamr": 2,
    },
    "width_height": {
        "rgb": (1408, 1408),
        "slaml": (640, 480),
        "slamr": (640, 480),
    },
    # vrs id
    "id": ["214-1", "1201-1", "1201-2"],
    "id_to_name": {
        "214-1": "rgb",
        "1201-1": "slaml",
        "1201-2": "slamr",
    },
    # display names
    "display": [
        "RGB",
        "SLAM Left",
        "SLAM Right",
    ],
    # Physical position on glasses from left to right.
    "spatial_order": [1, 0, 2],
}


================================================
FILE: efm3d/aria/camera.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Tuple, Union

import numpy as np
import torch

from .pose import get_T_rot_z, IdentityPose, PoseTW
from .projection_utils import (
    fisheye624_project,
    fisheye624_unproject,
    pinhole_project,
    pinhole_unproject,
)
from .tensor_wrapper import autocast, autoinit, smart_cat, TensorWrapper

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class DefaultCameraTWData(TensorWrapper):
    """Allows multiple input sizes."""

    def __init__(self):
        self._data = -1 * torch.ones(34)  # matches DEFAULT_CAM_DATA_SIZE

    @property
    def shape(self):
        return (torch.Size([34]), torch.Size([26]), torch.Size([22]))


class DefaultCameraTWParam(TensorWrapper):
    """Allows multiple input sizes."""

    def __init__(self):
        self._data = -1 * torch.ones(15)

    @property
    def shape(self):
        return (torch.Size([16]), torch.Size([15]), torch.Size([8]), torch.Size([4]))


class DefaultCameraTWDistParam(TensorWrapper):
    """Allows multiple input sizes."""

    def __init__(self):
        self._data = -1 * torch.ones(12)

    @property
    def shape(self):
        return (torch.Size([12]), torch.Size([4]), torch.Size([0]))


DEFAULT_CAM_DATA = DefaultCameraTWData()
DEFAULT_CAM_PARAM = DefaultCameraTWParam()
DEFAULT_CAM_DIST_PARAM = DefaultCameraTWDistParam()
DEFAULT_CAM_DATA_SIZE = 34

RGB_PARAMS = np.float32(
    [2 * 600.0, 2 * 352.0, 2 * 352.0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
)
SLAM_PARAMS = np.float32([500.0, 320.0, 240.0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

FISHEYE624_TYPE_STR = (
    "FisheyeRadTanThinPrism:f,u0,v0,k0,k1,k2,k3,k5,k5,p1,p2,s1,s2,s3,s4"
)
FISHEYE624_DF_TYPE_STR = (
    "FisheyeRadTanThinPrism:fu,fv,u0,v0,k0,k1,k2,k3,k5,k5,p1,p2,s1,s2,s3,s4"
)
PINHOLE_TYPE_STR = "Pinhole"


def is_fisheye624(inp):
    names = [
        "Fisheye624",
        "f624",
        FISHEYE624_TYPE_STR,
        FISHEYE624_DF_TYPE_STR,
        "FisheyeRadTanThinPrism",
        "CameraModelType.FISHEYE624",
    ]
    names += [name.lower() for name in names]
    return inp in names


def is_kb3(inp):
    names = ["KB:fu,fv,u0,v0,k0,k1,k2,k3", "KannalaBrandtK3", "KB3"]
    names += [name.lower() for name in names]
    return inp in names


def is_pinhole(inp):
    names = ["Pinhole", "Linear", "CameraModelType.LINEAR"]
    names += [name.lower() for name in names]
    return inp in names


def get_aria_camera(params=SLAM_PARAMS, width=640, height=480, valid_radius=None, B=1):
    type_str = FISHEYE624_TYPE_STR if params.shape[-1] == 15 else FISHEYE624_DF_TYPE_STR
    if valid_radius is None:
        cam = CameraTW.from_surreal(width, height, type_str, params)
    else:
        cam = CameraTW.from_surreal(
            width,
            height,
            type_str,
            params,
            valid_radius=valid_radius,
        )
    if B > 1:
        cam = cam.unsqueeze(0).repeat(B, 1)
    return cam


def get_pinhole_camera(params, width=640, height=480, valid_radius=None, B=1):
    type_str = PINHOLE_TYPE_STR
    if valid_radius is None:
        cam = CameraTW.from_surreal(width, height, type_str, params)
    else:
        cam = CameraTW.from_surreal(
            width,
            height,
            type_str,
            params,
            valid_radius=valid_radius,
        )
    if B > 1:
        cam = cam.unsqueeze(0).repeat(B, 1)
    return cam


def get_base_aria_rgb_camera_full_res():
    params = RGB_PARAMS * 2
    params[1:3] += 32
    return get_aria_camera(params, 2880, 2880)


def get_base_aria_rgb_camera():
    return get_aria_camera(RGB_PARAMS, 1408, 1408)


def get_base_aria_slam_camera():
    return get_aria_camera(SLAM_PARAMS, 640, 480)


class CameraTW(TensorWrapper):
    """
    Class to represent a batch of camera calibrations of the same camera type.
    """

    SIZE_IND = slice(0, 2)
    F_IND = slice(2, 4)
    C_IND = slice(4, 6)
    GAIN_IND = 6
    EXPOSURE_S_IND = 7
    VALID_RADIUS_IND = slice(8, 10)
    T_CAM_RIG_IND = slice(10, 22)
    DIST_IND = slice(22, None)

    @autocast
    @autoinit
    def __init__(
        self, data: Union[torch.Tensor, DefaultCameraTWData] = DEFAULT_CAM_DATA
    ):
        assert isinstance(data, torch.Tensor)
        assert data.shape[-1] in {22, 26, 34}
        super().__init__(data)

    @classmethod
    @autoinit
    def from_parameters(
        cls,
        width: torch.Tensor = -1 * torch.ones(1),
        height: torch.Tensor = -1 * torch.ones(1),
        fx: torch.Tensor = -1 * torch.ones(1),
        fy: torch.Tensor = -1 * torch.ones(1),
        cx: torch.Tensor = -1 * torch.ones(1),
        cy: torch.Tensor = -1 * torch.ones(1),
        gain: torch.Tensor = -1 * torch.ones(1),
        exposure_s: torch.Tensor = 1e-3 * torch.ones(1),
        valid_radiusx: torch.Tensor = 99999.0 * torch.ones(1),
        valid_radiusy: torch.Tensor = 99999.0 * torch.ones(1),
        T_camera_rig: Union[torch.Tensor, PoseTW] = IdentityPose,  # 1x12.
        dist_params: Union[
            torch.Tensor, DefaultCameraTWDistParam
        ] = DEFAULT_CAM_DIST_PARAM,
    ):
        # Concatenate into one big data tensor, handles TensorWrapper objects.
        data = smart_cat(
            [
                width,
                height,
                fx,
                fy,
                cx,
                cy,
                gain,
                exposure_s,
                valid_radiusx,
                valid_radiusy,
                T_camera_rig,
                dist_params,
            ],
            dim=-1,
        )
        return cls(data)

    @classmethod
    @autoinit
    def from_surreal(
        cls,
        width: torch.Tensor = -1 * torch.ones(1),
        height: torch.Tensor = -1 * torch.ones(1),
        type_str: str = "Fisheye624",
        params: Union[torch.Tensor, DefaultCameraTWParam] = DEFAULT_CAM_PARAM,
        gain: torch.Tensor = 1 * torch.ones(1),
        exposure_s: torch.Tensor = 1e-3 * torch.ones(1),
        valid_radius: torch.Tensor = 99999.0 * torch.ones(1),
        T_camera_rig: Union[torch.Tensor, PoseTW] = IdentityPose,  # 1x12.
    ):
        # Try to auto-determine the camera model.
        if (
            is_fisheye624(type_str) and params.shape[-1] == 16
        ):  # Fisheye624 double focals
            fx = params[..., 0].unsqueeze(-1)
            fy = params[..., 1].unsqueeze(-1)
            cx = params[..., 2].unsqueeze(-1)
            cy = params[..., 3].unsqueeze(-1)
            dist_params = params[..., 4:]
        elif (
            is_fisheye624(type_str) and params.shape[-1] == 15
        ):  # Fisheye624 single focal
            f = params[..., 0].unsqueeze(-1)
            cx = params[..., 1].unsqueeze(-1)
            cy = params[..., 2].unsqueeze(-1)
            dist_params = params[..., 3:]
            fx = fy = f
        elif is_kb3(type_str) and params.shape[-1] == 8:  # KB3.
            fx = params[..., 0].unsqueeze(-1)
            fy = params[..., 1].unsqueeze(-1)
            cx = params[..., 2].unsqueeze(-1)
            cy = params[..., 3].unsqueeze(-1)
            dist_params = params[..., 4:]
        elif is_pinhole(type_str) and params.shape[-1] == 4:  # Pinhole.
            fx = params[..., 0].unsqueeze(-1)
            fy = params[..., 1].unsqueeze(-1)
            cx = params[..., 2].unsqueeze(-1)
            cy = params[..., 3].unsqueeze(-1)
            dist_params = params[..., 4:]
        else:
            raise NotImplementedError(
                "Unknown number of params entered for camera model"
            )

        if torch.any(torch.logical_or(valid_radius > height, valid_radius > width)):
            if not is_pinhole(type_str):
                # Try to auto-determine the valid radius for fisheye cameras.
                default_radius = 99999.0
                hw_ratio = height / width
                eyevideo_camera_hw_ratio = torch.tensor(240.0 / 640.0).to(hw_ratio)
                slam_camera_hw_ratio = torch.tensor(480.0 / 640.0).to(hw_ratio)
                rgb_camera_hw_ratio = torch.tensor(2880.0 / 2880.0).to(hw_ratio)
                guess_rgb = hw_ratio == rgb_camera_hw_ratio
                guess_slam = hw_ratio == slam_camera_hw_ratio
                guess_eyevideo = hw_ratio == eyevideo_camera_hw_ratio
                valid_radius = default_radius * torch.ones_like(hw_ratio)
                valid_radius = torch.where(
                    guess_rgb, 1415 * (height / 2880), valid_radius
                )
                valid_radius = torch.where(
                    guess_slam, 330 * (height / 480), valid_radius
                )
                # This is for Eye Video Camera
                valid_radius = torch.where(
                    guess_eyevideo, 330 * (height / 480), valid_radius
                )
                if torch.any(valid_radius == default_radius):
                    raise ValueError(
                        f"Failed to auto-determine valid radius based on aspect ratios (valid_radius {valid_radius}, width {width}, height {height})"
                    )
            else:
                # Note that the valid_radius for pinhole camera is not well-defined.
                # We heuristically set the valid radius to be the half of the image diagonal.
                # Add one pixel to be sure that all pixels in the image are valid.
                valid_radius = (
                    torch.sqrt((width / 2.0) ** 2 + (height / 2.0) ** 2) + 1.0
                )

        return cls.from_parameters(
            width=width,
            height=height,
            fx=fx,
            fy=fy,
            cx=cx,
            cy=cy,
            gain=gain,
            exposure_s=exposure_s,
            valid_radiusx=valid_radius,
            valid_radiusy=valid_radius,
            T_camera_rig=T_camera_rig,
            dist_params=dist_params,
        )

    @property
    def size(self) -> torch.Tensor:
        """Size (width height) of the images, with shape (..., 2)."""
        return self._data[..., self.SIZE_IND]

    @property
    def f(self) -> torch.Tensor:
        """Focal lengths (fx, fy) with shape (..., 2)."""
        return self._data[..., self.F_IND]

    @property
    def c(self) -> torch.Tensor:
        """Principal points (cx, cy) with shape (..., 2)."""
        return self._data[..., self.C_IND]

    @property
    def K(self) -> torch.Tensor:
        """Intrinsic matrix with shape (..., 3, 3)"""
        K = torch.eye(3, device=self.device, dtype=self.dtype)
        # Make proper size of K to take care of B and T dims.
        K_view = [1] * (self.f.ndim - 1) + [3, 3]
        K_repeat = list(self.f.shape[:-1]) + [1, 1]
        K = K.view(K_view)
        K = K.repeat(K_repeat)
        K[..., 0, 0] = self.f[..., 0]
        K[..., 1, 1] = self.f[..., 1]
        K[..., 0, 2] = self.c[..., 0]
        K[..., 1, 2] = self.c[..., 1]
        return K

    @property
    def K44(self) -> torch.Tensor:
        """Intrinsic matrix with shape (..., 4, 4)"""
        K = torch.eye(4, device=self.device, dtype=self.dtype)
        # Make proper size of K to take care of B and T dims.
        K_view = [1] * (self.f.ndim - 1) + [4, 4]
        K_repeat = list(self.f.shape[:-1]) + [1, 1]
        K = K.view(K_view)
        K = K.repeat(K_repeat)
        K[..., 0, 0] = self.f[..., 0]
        K[..., 1, 1] = self.f[..., 1]
        K[..., 0, 2] = self.c[..., 0]
        K[..., 1, 2] = self.c[..., 1]
        return K

    @property
    def gain(self) -> torch.Tensor:
        """Gain of the camera, with shape (..., 1)."""
        return self._data[..., self.GAIN_IND].unsqueeze(-1)

    @property
    def exposure_s(self) -> torch.Tensor:
        """Exposure of the camera in seconds, with shape (..., 1)."""
        return self._data[..., self.EXPOSURE_S_IND].unsqueeze(-1)

    @property
    def valid_radius(self) -> torch.Tensor:
        """Radius from camera center for valid projections, with shape (..., 1)."""
        return self._data[..., self.VALID_RADIUS_IND]

    @property
    def T_camera_rig(self) -> PoseTW:
        """Pose transforming rig to camera coordinates, as a PoseTW with data shape (..., 12)."""
        return PoseTW(self._data[..., self.T_CAM_RIG_IND])

    @property
    def dist(self) -> torch.Tensor:
        """Distortion parameters, with shape (..., {0, D}), where D is number of distortion params."""
        return self._data[..., self.DIST_IND]

    @property
    def params(self) -> torch.Tensor:
        """Get the camera "params", which are defined as fx,fy,cx,cy,dist"""
        return torch.cat([self.f, self.c, self.dist], dim=-1)

    @property
    def is_fisheye624(self):
        return self.dist.shape[-1] == 12

    @property
    def is_kb3(self):
        return self.dist.shape[-1] == 4

    @property
    def is_linear(self):
        return self.dist.shape[-1] == 0

    def set_valid_radius(self, valid_radius: torch.Tensor):
        self._data[..., self.VALID_RADIUS_IND] = valid_radius

    def set_T_camera_rig(self, T_camera_rig: PoseTW):
        self._data[..., self.T_CAM_RIG_IND] = T_camera_rig._data.clone()

    def scale(self, scales: Union[float, int, Tuple[Union[float, int]]]):
        """Update the camera parameters after resizing an image."""
        if isinstance(scales, (int, float)):
            scales = (scales, scales)
        s = self._data.new_tensor(scales)
        data = torch.cat(
            [
                self.size * s,
                self.f * s,
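                # Pixel centers sit at integer coordinates, so shift by 0.5 before
                # scaling and back after (e.g. c=319.5, s=0.5 -> 159.5).
                (self.c + 0.5) * s - 0.5,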
                self.gain,
                self.exposure_s,
                self.valid_radius * s,
                self.T_camera_rig._data,
                self.dist,
            ],
            dim=-1,
        )
        return self.__class__(data)

    def scale_to_size(self, size_wh: Union[int, Tuple[int]]):
        """Scale the camera parameters to a given image size"""
        if torch.unique(self.size).numel() > 2:
            raise ValueError(f"cannot handle multiple sizes {self.size}")
        if isinstance(size_wh, int):
            size_wh = (size_wh, size_wh)
        i0w = tuple([0] * self.ndim)
        i0h = tuple([0] * (self.ndim - 1) + [1])
        scale = (
            float(size_wh[0]) / float(self.size[i0w]),
            float(size_wh[1]) / float(self.size[i0h]),
        )
        return self.scale(scale)

    def scale_to(self, im: torch.Tensor):
        """
        Scale the camera parameters to match the size of the given image. Assumes
        the PyTorch ...xHxW image tensor convention.
        """
        H, W = im.shape[-2:]
        return self.scale_to_size((W, H))

    def crop(self, left_top: Tuple[float], size: Tuple[int]):
        """Update the camera parameters after cropping an image."""
        left_top = self._data.new_tensor(left_top)
        size = self._data.new_tensor(size)

        # Broadcast size and left_top over the batch dims if this is a batch of cameras
        if len(self._data.shape) > 1:
            expand_dim = list(self._data.shape[:-1]) + [1]
            size = size.repeat(expand_dim)
            left_top = left_top.repeat(expand_dim)

        data = torch.cat(
            [
                size,
                self.f,
                self.c - left_top,
                self.gain,
                self.exposure_s,
                self.valid_radius,
                self.T_camera_rig._data,
                self.dist,
            ],
            dim=-1,
        )
        return self.__class__(data)

    @autocast
    def in_image(self, p2d: torch.Tensor):
        """Check if 2D points are within the image boundaries."""
        assert p2d.shape[-1] == 2, f"p2d shape needs to be 2d {p2d.shape}"
        # assert p2d.shape[:-2] == self.shape  # allow broadcasting
        size = self.size.unsqueeze(-2)
        valid = torch.all((p2d >= 0) & (p2d <= (size - 1)), dim=-1)
        return valid

    @autocast
    def in_radius(self, p2d: torch.Tensor):
        """Check if 2D points are within the valid fisheye radius region."""
        assert p2d.shape[-1] == 2, f"p2d shape needs to be 2d {p2d.shape}"
        dists = torch.linalg.norm(
            (p2d - self.c.unsqueeze(-2)) / self.valid_radius.unsqueeze(-2),
            dim=-1,
            ord=2,
        )
        valid = dists < 1.0
        return valid

    @autocast
    def in_radius_mask(self):
        """
        Return a mask that is True where 2D points are within the valid fisheye
        radius region.  Returned mask is of shape ... x 1 x H x W, where ... is
        the shape of the camera (BxT or B for example).
        """
        s = self.shape[:-1]
        C = self.shape[-1]
        px = pixel_grid(self.view(-1, C)[0])
        H, W, _ = px.shape
        valids = self.in_radius(px.view(-1, 2))
        s = s + (1, H, W)
        valids = valids.view(s)
        return valids

    @autocast
    def project(self, p3d: torch.Tensor) -> Tuple[torch.Tensor]:
        """Transform 3D points into 2D pixel coordinates."""

        # Explicitly promote the data types.
        promoted_type = torch.promote_types(self._data.dtype, p3d.dtype)
        self._data = self._data.to(promoted_type)
        p3d = p3d.to(promoted_type)

        # Try to auto-determine the camera model.
        if self.is_fisheye624:  # Fisheye624.
            params = torch.cat([self.f, self.c, self.dist], dim=-1)
            if params.ndim == 1:
                B = p3d.shape[0]
                params = params.unsqueeze(0).repeat(B, 1)
            p2d = fisheye624_project(p3d, params)
        elif self.is_linear:  # Pinhole.
            params = self.params
            if params.ndim == 1:
                B = p3d.shape[0]
                params = params.unsqueeze(0).repeat(B, 1)
            p2d = pinhole_project(p3d, params)
        else:
            raise ValueError(
                "only fisheye624 and pinhole implemented, kb3 not yet implemented"
            )

        in_image = self.in_image(p2d)
        in_radius = self.in_radius(p2d)
        in_front = p3d[..., -1] > 0
        valid = in_image & in_radius & in_front
        return p2d, valid

    @autocast
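    def unproject(self, p2d: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Transform 2D pixel coordinates into 3D rays in the camera frame.

        Returns (rays, valid), where valid is True for pixels inside the image
        and inside the valid fisheye radius.
        """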

        # Explicitly promote the data types.
        promoted_type = torch.promote_types(self._data.dtype, p2d.dtype)
        self._data = self._data.to(promoted_type)
        p2d = p2d.to(promoted_type)

        # Try to auto-determine the camera model.
        if self.is_fisheye624:  # Fisheye624.
            params = torch.cat([self.f, self.c, self.dist], dim=-1)
            if params.ndim == 1:
                B = p2d.shape[0]
                params = params.unsqueeze(0).repeat(B, 1)
            rays = fisheye624_unproject(p2d, params)
        elif self.is_linear:  # Pinhole.
            params = self.params
            if params.ndim == 1:
                B = p2d.shape[0]
                params = params.unsqueeze(0).repeat(B, 1)
            rays = pinhole_unproject(p2d, params)
        else:
            raise ValueError(
                "only fisheye624 and pinhole implemented, kb3 not yet implemented"
            )

        in_image = self.in_image(p2d)
        in_radius = self.in_radius(p2d)
        valid = in_image & in_radius
        return rays, valid

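    def rotate_90_cw(self):
        """Return the camera matching the image rotated 90 degrees clockwise."""
        return self.rotate_90(clock_wise=True)

    def rotate_90_ccw(self):
        """Return the camera matching the image rotated 90 degrees counter-clockwise."""
        return self.rotate_90(clock_wise=False)

    def rotate_90(self, clock_wise: bool):
        """
        Return a new CameraTW that matches the image rotated by 90 degrees:
        width/height, focal lengths and valid radii are swapped, the principal
        point is remapped, the Fisheye624 distortion parameters are permuted,
        and T_camera_rig is rotated about the camera z axis.
        """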
        dist_params = self.dist.clone()
        if self.is_fisheye624:
            # swap thin prism and tangential distortion parameters
            # {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3} to
            # {k_0 ... k_5} {p_1 p_0} {s_2 s_3 s_0 s_1}
            dist_p = self.dist[..., 6:8]
            dist_s = self.dist[..., 8:12]
            dist_params[..., 6] = dist_p[..., 1]
            dist_params[..., 7] = dist_p[..., 0]
            dist_params[..., 8:10] = dist_s[..., 2:]
            dist_params[..., 10:12] = dist_s[..., :2]
        elif self.is_linear:
            # no need to rotate distortion parameters since there are none
            pass
        elif self.is_kb3:
            raise NotImplementedError("kb3 model rotation not implemented yet")
        else:
            raise NotImplementedError(f"camera model not recognized {self}")

        # clock-wise or counter clock-wise
        DIR = 1 if clock_wise else -1
        # rotate camera extrinsics by 90 degrees about the z (optical) axis, CW or CCW
        T_rot_z = PoseTW.from_matrix3x4(get_T_rot_z(DIR * np.pi * 0.5)).to(self.device)
        if clock_wise:
            # rotate x, y of principal point (CW)
            # x_rotated = height - 1 - y_before
            # y_rotated = x_before
            rot_cx = self.size[..., 1] - self.c[..., 1] - 1
            rot_cy = self.c[..., 0].clone()
        else:
            # rotate x, y of principal point (CCW)
            # x_rotated = y_before
            # y_rotated = width - 1 - x_before
            rot_cx = self.c[..., 1].clone()
            rot_cy = self.size[..., 0] - self.c[..., 0] - 1

        return CameraTW.from_parameters(
            # swap width and height
            self.size[..., 1].clone().unsqueeze(-1),
            self.size[..., 0].clone().unsqueeze(-1),
            # swap x, y of focal lengths
            self.f[..., 1].clone().unsqueeze(-1),
            self.f[..., 0].clone().unsqueeze(-1),
            rot_cx.unsqueeze(-1),
            rot_cy.unsqueeze(-1),
            self.gain.clone(),
            self.exposure_s.clone(),
            # swap valid radius x, y
            self.valid_radius[..., 1].clone().unsqueeze(-1),
            self.valid_radius[..., 0].clone().unsqueeze(-1),
            # rotate camera extrinsics
            T_rot_z @ self.T_camera_rig,
            dist_params,
        )

    def __repr__(self):
        return f"CameraTW {self.shape} {self.dtype} {self.device}"
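

# Illustrative sketch (hypothetical helpers, not part of the original API): a
# quick check of the principal-point formulas used in CameraTW.rotate_90 on
# integer pixel positions. It assumes that torch.rot90 with k=-1 (CW) or k=1
# (CCW) over the (row, col) dims is the image rotation rotate_90 mirrors.
import torch


def _sketch_rotated_pixel(x: int, y: int, width: int, height: int, clock_wise: bool):
    """Map pixel (x, y) of a width x height image into the rotated image."""
    if clock_wise:
        # x_rotated = height - 1 - y_before, y_rotated = x_before
        return height - 1 - y, x
    # x_rotated = y_before, y_rotated = width - 1 - x_before
    return y, width - 1 - x


def _sketch_check_rotation(width: int = 3, height: int = 2) -> bool:
    """True if the formulas agree with torch.rot90 for every pixel."""
    img = torch.arange(width * height).view(height, width)  # img[y, x]
    for clock_wise, k in ((True, -1), (False, 1)):
        rot = torch.rot90(img, k, dims=(0, 1))  # shape (width, height)
        for y in range(height):
            for x in range(width):
                xr, yr = _sketch_rotated_pixel(x, y, width, height, clock_wise)
                if rot[yr, xr] != img[y, x]:
                    return False
    return True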


def grid_2d(
    width: int,
    height: int,
    output_range=(-1.0, 1.0, -1.0, 1.0),
    device="cpu",
    dtype=torch.float32,
):
    x = torch.linspace(
        output_range[0], output_range[1], width + 1, device=device, dtype=dtype
    )[:-1]
    y = torch.linspace(
        output_range[2], output_range[3], height + 1, device=device, dtype=dtype
    )[:-1]
    xx, yy = torch.meshgrid(x, y, indexing="xy")
    grid = torch.stack([xx, yy], dim=-1)
    return grid


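def pixel_grid(cam: CameraTW):
    """
    Return an H x W x 2 grid of (x, y) pixel coordinates (0..W-1, 0..H-1)
    for a single, unbatched camera.
    """
    assert cam.ndim == 1, f"Camera must be 1 dimensional {cam.shape}"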
    W, H = int(cam.size[0]), int(cam.size[1])
    return grid_2d(W, H, output_range=[0, W, 0, H], device=cam.device, dtype=cam.dtype)
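

# Illustrative sketch (hypothetical helper, not part of the original API):
# grid_2d samples n + 1 evenly spaced values over [lo, hi] and drops the last
# one, so with output_range=[0, W, 0, H] (as in pixel_grid) the step is
# exactly 1 and the grid holds integer pixel coordinates. This recomputes the
# same H x W x 2 grid with torch.arange for comparison.
import torch


def _sketch_pixel_grid_reference(width: int, height: int) -> torch.Tensor:
    xs = torch.arange(width, dtype=torch.float32)
    ys = torch.arange(height, dtype=torch.float32)
    # indexing="xy" yields tensors of shape (height, width); [..., 0] is x.
    xx, yy = torch.meshgrid(xs, ys, indexing="xy")
    return torch.stack([xx, yy], dim=-1)  # (height, width, 2)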


def scale_image_to_cam(cams: CameraTW, ims: torch.Tensor) -> torch.Tensor:
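    """
    Resize images (B x C x H x W or B x T x C x H x W) to the image size of
    the cameras, using bilinear interpolation with antialiasing. The target
    size is taken from the first camera, so all cameras are assumed to share
    the same image size.
    """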

    from torchvision.transforms import InterpolationMode, Resize

    T = None
    if ims.ndim == 5:
        B, T, C, H, W = ims.shape
        ims = ims.view(-1, C, H, W)
        Wo, Ho = cams[0, 0].size.int().tolist()
    elif ims.ndim == 4:
        B, C, H, W = ims.shape
        Wo, Ho = cams[0].size.int().tolist()
    else:
        raise ValueError(f"unusable image shape {ims.shape}, {cams.shape}")
    ims = Resize((Ho, Wo), interpolation=InterpolationMode.BILINEAR, antialias=True)(
        ims
    )
    if T is not None:
        return ims.view(B, T, C, Ho, Wo)
    return ims.view(B, C, Ho, Wo)


================================================
FILE: efm3d/aria/obb.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import List, Tuple, Union

import torch
import torch.nn.functional as F

from .camera import CameraTW
from .pose import IdentityPose, PAD_VAL, PoseTW, rotation_from_euler
from .tensor_wrapper import autocast, autoinit, smart_cat, smart_stack, TensorWrapper

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# logger.setLevel(logging.DEBUG)


# OBB corner numbering diagram for this implementation (the same as pytorch3d
# https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/ops/iou_box3d.py#L111)
#
# (4) +---------+. (5)
#     | ` .     |  ` .
#     | (0) +---+-----+ (1)
#     |     |   |     |
# (7) +-----+---+. (6)|
#     ` .   |     ` . |
#     (3) ` +---------+ (2)
#
# NOTE: Throughout this implementation, we assume that boxes
# are defined by their 8 corners exactly in the order specified in the
# diagram above for the function to give correct results. In addition
# the vertices on each plane must be coplanar.
# As an alternative to the diagram, this is a unit bounding
# box which has the correct vertex ordering:
# box_corner_vertices = [
#     [0, 0, 0],  #   (0)
#     [1, 0, 0],  #   (1)
#     [1, 1, 0],  #   (2)
#     [0, 1, 0],  #   (3)
#     [0, 0, 1],  #   (4)
#     [1, 0, 1],  #   (5)
#     [1, 1, 1],  #   (6)
#     [0, 1, 1],  #   (7)
# ]
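
# Illustrative sketch (hypothetical helper, not part of the original API):
# ObbTW.bb3corners_object gathers corners from [xmin,xmax,ymin,ymax,zmin,zmax]
# with the index list below. Applying it to the unit box [0, 1, 0, 1, 0, 1]
# reproduces box_corner_vertices from the comment above.
import torch

_SKETCH_CORNER_IDS = [0, 2, 4, 1, 2, 4, 1, 3, 4, 0, 3, 4,
                      0, 2, 5, 1, 2, 5, 1, 3, 5, 0, 3, 5]


def _sketch_corners(bb3_object: torch.Tensor) -> torch.Tensor:
    """(..., 6) box extents -> (..., 8, 3) corners, same gather as bb3corners_object."""
    c3o = bb3_object[..., _SKETCH_CORNER_IDS]
    return c3o.reshape(*c3o.shape[:-1], 8, 3)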

# triangle indices to draw an OBB mesh from bb3corners_*; each column is one
# triangle (12 triangles in total, two per box face)
OBB_MESH_TRI_INDS = [
    [7, 0, 0, 0, 4, 4, 6, 6, 4, 0, 3, 2],
    [3, 4, 1, 2, 5, 6, 5, 2, 0, 1, 6, 3],
    [0, 7, 2, 3, 6, 7, 1, 1, 5, 5, 7, 6],
]
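
# Illustrative sketch (hypothetical helper, not part of the original API):
# each column of OBB_MESH_TRI_INDS is one triangle. On the unit box (corner
# order from the diagram above), every triangle should lie on a single face,
# i.e. its three corners share one coordinate, and each of the 6 faces should
# be covered by exactly two triangles.
import torch

_SKETCH_UNIT_CORNERS = torch.tensor(
    [[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0],
     [0, 0, 1], [1, 0, 1], [1, 1, 1], [0, 1, 1]]
)


def _sketch_triangle_faces(tri_inds) -> list:
    """Return the (axis, value) box face each triangle lies on (None if none)."""
    tris = torch.tensor(tri_inds).T  # (num_tris, 3)
    faces = []
    for tri in tris:
        pts = _SKETCH_UNIT_CORNERS[tri]  # (3, 3) corner coordinates
        face = None
        for axis in range(3):
            if bool((pts[:, axis] == pts[0, axis]).all()):
                face = (axis, int(pts[0, axis]))
        faces.append(face)
    return faces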

# line indices to draw an OBB line strip frame from bb3corners_*
OBB_LINE_INDS = [0, 1, 2, 3, 0, 3, 7, 4, 0, 1, 5, 6, 5, 4, 7, 6, 2, 1, 5]
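
# Illustrative sketch (hypothetical helper, not part of the original API):
# OBB_LINE_INDS is drawn as one line strip, so consecutive index pairs form
# segments. Some edges are traversed twice, but the set of distinct segments
# should be exactly the 12 box edges listed in BB3D_LINE_ORDERS below.
def _sketch_strip_edges(strip) -> set:
    """Return the undirected segments drawn by a line strip of corner indices."""
    return {frozenset(pair) for pair in zip(strip[:-1], strip[1:])}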

# corner indices to construct all edge lines
BB3D_LINE_ORDERS = [
    [0, 1],
    [1, 2],
    [2, 3],
    [3, 0],
    [4, 5],
    [5, 6],
    [6, 7],
    [7, 4],
    [0, 4],
    [1, 5],
    [2, 6],
    [3, 7],
]
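
# Illustrative sketch (hypothetical helper, not part of the original API):
# ObbTW.bb3edge_pts_object samples points on each edge in BB3D_LINE_ORDERS by
# linear interpolation, p = (1 - alpha) * corner_a + alpha * corner_b with
# alpha = linspace(0, 1, n). This standalone version takes an (8, 3) corner
# tensor; with n == 3 it yields the corners plus every edge midpoint.
import torch


def _sketch_edge_points(corners: torch.Tensor, edges, n: int) -> torch.Tensor:
    """(8, 3) corners -> (len(edges) * n, 3) points sampled along each edge."""
    alphas = torch.linspace(0, 1, n, dtype=corners.dtype).unsqueeze(-1)  # (n, 1)
    pts = [
        corners[a].unsqueeze(0) * (1 - alphas) + corners[b].unsqueeze(0) * alphas
        for a, b in edges
    ]
    return torch.cat(pts, dim=0)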

_box_planes = [
    [0, 1, 2, 3],
    [3, 2, 6, 7],
    [0, 1, 5, 4],
    [0, 3, 7, 4],
    [1, 2, 6, 5],
    [4, 5, 6, 7],
]

DOT_EPS = 1e-3
AREA_EPS = 1e-4


class ObbTW(TensorWrapper):
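    """
    Oriented 3D Bounding Box observation in world coordinates (via
    T_world_object) for Aria headsets.

    Each entry packs 34 values along the last dimension:
      [0:6]   bb3_object      [xmin,xmax,ymin,ymax,zmin,zmax] in object frame
      [6:10]  bb2_rgb         [xmin,xmax,ymin,ymax] in the RGB image
      [10:14] bb2_slaml       [xmin,xmax,ymin,ymax] in the SLAM left image
      [14:18] bb2_slamr       [xmin,xmax,ymin,ymax] in the SLAM right image
      [18:30] T_world_object  (PoseTW, 12 values)
      [30] sem_id, [31] inst_id, [32] prob, [33] moveable
    Padded (invalid) entries have all 34 values set to PAD_VAL.
    """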

    @autocast
    @autoinit
    def __init__(self, data: torch.Tensor = PAD_VAL * torch.ones((1, 34))):
        assert isinstance(data, torch.Tensor)
        assert data.shape[-1] == 34
        super().__init__(data)

    @classmethod
    @autoinit
    def from_lmc(
        cls,
        bb3_object: torch.Tensor = PAD_VAL * torch.ones(6),
        bb2_rgb: torch.Tensor = PAD_VAL * torch.ones(4),
        bb2_slaml: torch.Tensor = PAD_VAL * torch.ones(4),
        bb2_slamr: torch.Tensor = PAD_VAL * torch.ones(4),
        T_world_object: Union[torch.Tensor, PoseTW] = IdentityPose,  # 1x12.
        sem_id: torch.Tensor = PAD_VAL * torch.ones(1),
        inst_id: torch.Tensor = PAD_VAL * torch.ones(1),
        prob: torch.Tensor = 1 * torch.ones(1),
        moveable: torch.Tensor = 0 * torch.ones(1),
    ):
        # Concatenate into one big data tensor, handles TensorWrapper objects.
        # Move all inputs to the device of bb3_object; otherwise the
        # concatenation fails, e.g. when the default IdentityPose is used.
        device = bb3_object.device
        data = smart_cat(
            [
                bb3_object,
                bb2_rgb.to(device),
                bb2_slaml.to(device),
                bb2_slamr.to(device),
                T_world_object.to(device),
                sem_id.to(device),
                inst_id.to(device),
                prob.to(device),
                moveable.to(device),
            ],
            dim=-1,
        )
        return cls(data)

    @property
    def bb3_object(self) -> torch.Tensor:
        """3D bounding box [xmin,xmax,ymin,ymax,zmin,zmax] in object coord frame, with shape (..., 6)."""
        return self._data[..., :6]

    @property
    def bb3_min_object(self) -> torch.Tensor:
        """3D bounding box minimum corner [xmin,ymin,zmin] in object coord frame, with shape (..., 3)."""
        return self._data[..., 0:6:2]

    @property
    def bb3_max_object(self) -> torch.Tensor:
        """3D bounding box maximum corner [xmax,ymax,zmax] in object coord frame, with shape (..., 3)."""
        return self._data[..., 1:6:2]

    @property
    def bb3_center_object(self) -> torch.Tensor:
        """3D bounding box center in object coord frame, with shape (..., 3)."""
        return 0.5 * (self.bb3_min_object + self.bb3_max_object)

    @property
    def bb3_center_world(self) -> torch.Tensor:
        """3D bounding box center in world coord frame, with shape (..., 3)."""
        s = self.bb3_center_object.shape
        _bb3_center_world = self.T_world_object.view(-1, 12).batch_transform(
            self.bb3_center_object.view(-1, 3)
        )
        return _bb3_center_world.view(s)

    @property
    def bb3_diagonal(self) -> torch.Tensor:
        """3D bounding box diagonal, with shape (..., 3)."""
        return self.bb3_max_object - self.bb3_min_object

    @property
    def bb3_volumes(self) -> torch.Tensor:
        """3D bounding box volumes, with shape (..., 1)."""
        diags = self.bb3_diagonal
        return diags.prod(dim=-1, keepdim=True)

    @property
    def bb2_rgb(self) -> torch.Tensor:
        """2D bounding box [xmin,xmax,ymin,ymax] as visible in RGB image, -1's if not visible, with shape (..., 4)."""
        return self._data[..., 6:10]

    def visible_bb3_ind(self, cam_id) -> torch.Tensor:
        """
        Boolean mask of the 3D bounding boxes visible in camera cam_id, i.e.
        those whose 2D bb coordinates are all > 0.
        """
        bb2_cam = self.bb2(cam_id)
        vis_ind = torch.all(bb2_cam > 0, dim=-1)
        return vis_ind

    @property
    def bb2_slaml(self) -> torch.Tensor:
        """2D bounding box [xmin,xmax,ymin,ymax] as visible in SLAM Left image, -1's if not visible, with shape (..., 4)."""
        return self._data[..., 10:14]

    @property
    def bb2_slamr(self) -> torch.Tensor:
        """2D bounding box [xmin,xmax,ymin,ymax] as visible in SLAM Right image, -1's if not visible, with shape (..., 4)."""
        return self._data[..., 14:18]

    def bb2(self, cam_id) -> torch.Tensor:
        """
        2D bounding box [xmin,xmax,ymin,ymax] as visible in camera with given
        cam_id, -1's if not visible, with shape (..., 4).
        cam_id == 0 for rgb
        cam_id == 1 for slam left
        cam_id == 2 for slam right
        """
        return self._data[..., 6 + cam_id * 4 : 10 + cam_id * 4]

    def set_bb2(self, cam_id, bb2d, use_mask=True):
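        """
        Set the 2D bounding box [xmin,xmax,ymin,ymax] in the camera with the
        given cam_id:
        cam_id == 0 for rgb
        cam_id == 1 for slam left
        cam_id == 2 for slam right
        If use_mask is True, entries that were padded before the update are
        reset to PAD_VAL.
        """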
        padding_mask = self.get_padding_mask()
        self._data[..., 6 + cam_id * 4 : 10 + cam_id * 4] = bb2d
        if use_mask:
            self._data[padding_mask] = PAD_VAL

    def set_bb3_object(self, bb3_object, use_mask=True) -> None:
        """Set the 3D bounding box [xmin,xmax,ymin,ymax,zmin,zmax] in object coord frame; bb3_object has shape (..., 6)."""
        padding_mask = self.get_padding_mask()
        self._data[..., :6] = bb3_object
        if use_mask:
            self._data[padding_mask] = PAD_VAL

    def set_prob(self, prob, use_mask=True):
        """Set probability score"""
        padding_mask = self.get_padding_mask()
        self._data[..., 32] = prob
        if use_mask:
            self._data[padding_mask] = PAD_VAL

    @property
    def T_world_object(self) -> torch.Tensor:
        """3D SE3 transform from object to world coords, with shape (..., 12)."""
        return PoseTW(self._data[..., 18:30])

    def get_padding_mask(self) -> torch.Tensor:
        """Boolean mask that is True for padded (invalid) Obbs, i.e. entries whose values are all PAD_VAL."""
        return (self._data == PAD_VAL).all(dim=-1, keepdim=False)

    def set_T_world_object(self, T_world_object: PoseTW):
        """set 3D SE3 transform from object to world coords."""
        invalid_mask = self.get_padding_mask()
        self._data[..., 18:30] = T_world_object._data
        self._data[invalid_mask] = PAD_VAL

    @property
    def sem_id(self) -> torch.Tensor:
        """semantic id, with shape (..., 1)."""
        return self._data[..., 30].unsqueeze(-1).int()

    def set_sem_id(self, sem_id: torch.Tensor):
        """set semantic id to sem_id"""
        self._data[..., 30] = sem_id.squeeze()

    @property
    def inst_id(self) -> torch.Tensor:
        """instance id, with shape (..., 1)."""
        return self._data[..., 31].unsqueeze(-1).int()

    def set_inst_id(self, inst_id: torch.Tensor):
        """set instance id to inst_id"""
        self._data[..., 31] = inst_id.squeeze()

    @property
    def prob(self) -> torch.Tensor:
        """probability of detection, with shape (..., 1)."""
        return self._data[..., 32].unsqueeze(-1)

    @property
    def moveable(self) -> torch.Tensor:
        """boolean if moveable, with shape (..., 1)."""
        return self._data[..., 33].unsqueeze(-1)

    @property
    def bb3corners_world(self) -> torch.Tensor:
        """return the 8 corners of the 3D BB in world coord frame (..., 8, 3)."""
        return self.T_world_object * self.bb3corners_object

    @property
    def bb3corners_object(self) -> torch.Tensor:
        """return the 8 corners of the 3D BB in object coord frame (..., 8, 3)."""
        ids = [0, 2, 4, 1, 2, 4, 1, 3, 4, 0, 3, 4, 0, 2, 5, 1, 2, 5, 1, 3, 5, 0, 3, 5]
        b3o = self.bb3_object
        c3o = b3o[..., ids]
        c3o = c3o.reshape(*c3o.shape[:-1], 8, 3)
        return c3o

    def bb3edge_pts_object(self, num_samples_per_edge: int = 10) -> torch.Tensor:
        """
        return the num_samples_per_edge points per 3D BB edge in object coord
        frame (..., num_samples_per_edge * 12, 3).

        num_samples_per_edge == 1 will result in a list of corners (with some duplicates)
        num_samples_per_edge == 2 will result in a list of corners (with some more duplicates)
        num_samples_per_edge == 3 will result in a list of corners and edge midpoints
        ...
        """
        bb3corners = self.bb3corners_object
        shape = bb3corners.shape
        alphas = torch.linspace(0, 1, num_samples_per_edge, device=bb3corners.device)
        alphas = alphas.view([1] * len(shape[:-2]) + [num_samples_per_edge, 1])
        alphas = alphas.repeat(list(shape[:-2]) + [1, 3])
        betas = torch.ones_like(alphas) - alphas
        bb3edge_pts = []
        for edge_ids in BB3D_LINE_ORDERS:
            bb3edge_pts.append(
                bb3corners[..., edge_ids[0], :].unsqueeze(-2) * betas
                + bb3corners[..., edge_ids[1], :].unsqueeze(-2) * alphas
            )
        return torch.cat(bb3edge_pts, dim=-2)

    def center(self):
        """
        Returns an ObbTW object where the 3D OBBs are centered in their local coordinate system.
        I.e. bb3_min_object == - bb3_max_object.
        """

        T_wo = self.T_world_object
        center_o = self.bb3_center_object
        # compute centered bb3_object and obb pose T_world_object
        centered_T_wo = PoseTW.from_Rt(T_wo.R, T_wo.batch_transform(center_o))
        centered_bb3_min_o = self.bb3_min_object - center_o
        centered_bb3_max_o = self.bb3_max_object - center_o
        centered_bb3_o = torch.stack(
            [
                centered_bb3_min_o[..., 0],
                centered_bb3_max_o[..., 0],
                centered_bb3_min_o[..., 1],
                centered_bb3_max_o[..., 1],
                centered_bb3_min_o[..., 2],
                centered_bb3_max_o[..., 2],
            ],
            dim=-1,
        )
        return ObbTW.from_lmc(
            bb3_object=centered_bb3_o,
            bb2_rgb=self.bb2_rgb,
            bb2_slaml=self.bb2_slaml,
            bb2_slamr=self.bb2_slamr,
            T_world_object=centered_T_wo,
            sem_id=self.sem_id,
            inst_id=self.inst_id,
            prob=self.prob,
            moveable=self.moveable,
        )

    def add_padding(self, max_elts: int = 1000) -> "ObbTW":
        """
        Adds padding to Obbs, useful for returning batches with a varying number
        of Obbs. E.g. if in one batch we have 4 Obbs and another one we have 2,
        setting max_elts=4 will add 2 pads (consisting of all -1s) to the second
        element in the batch.
        """
        assert self._data.ndim <= 2, "higher than order 2 add_padding not supported yet"
        elts = self._data
        num_to_pad = max_elts - len(elts)
        # All -1's denotes a pad element.
        pad_elt = PAD_VAL * self._data.new_ones(self._data.shape[-1])
        if num_to_pad > 0:
            rep_elts = torch.stack([pad_elt for _ in range(num_to_pad)], dim=0)
            elts = torch.cat([elts, rep_elts], dim=0)
        elif num_to_pad < 0:
            # log before clipping so the message reports the actual count
            logger.warning(
                f"Warning: some obbs have been clipped (actual/max {len(elts)}/{max_elts}) in ObbTW.add_padding()"
            )
            elts = elts[:max_elts]
        return self.__class__(elts)

    def remove_padding(self) -> List["ObbTW"]:
        """
        Removes any padding by finding Obbs with all -1s. Returns a list.
        """
        assert self.ndim <= 4, "higher than order 4 remove_padding not supported yet"

        if self.ndim == 1:
            return self  # Nothing to be done in this case.

        # All -1's denotes a pad element.
        pad_elt = (PAD_VAL * self._data.new_ones(self._data.shape[-1])).unsqueeze(-2)
        is_not_pad = ~torch.all(self._data == pad_elt, dim=-1)

        if self.ndim == 2:
            new_data = self.__class__(self._data[is_not_pad])
        elif self.ndim == 3:
            B = self._data.shape[0]
            new_data = []
            for b in range(B):
                new_data.append(self.__class__(self._data[b][is_not_pad[b]]))
        else:  # self.ndim == 4:
            B, T = self._data.shape[:2]
            new_data = []
            for b in range(B):
                new_data.append([])
                for t in range(T):
                    new_data[-1].append(
                        self.__class__(self._data[b, t][is_not_pad[b, t]])
                    )
        return new_data

    def _mark_invalid(self, invalid_mask: torch.Tensor) -> None:
        """
        In place, mark the obbs selected by the boolean invalid_mask as invalid
        (set to PAD_VAL). invalid_mask must have shape self.shape[:-1].
        """
        assert invalid_mask.ndim == self.ndim - 1, "invalid_mask must match ObbTW"
        assert invalid_mask.shape == self.shape[:-1], "invalid_mask must match ObbTW"
        self._data[invalid_mask] = PAD_VAL

    def _mark_invalid_ids(self, invalid_ids: torch.Tensor) -> None:
        """
        In place, mark the obbs at the given int64 indices as invalid (set to PAD_VAL).
        """
        assert self.ndim == 2, "invalid_ids only supported for 2d ObbTW"
        assert invalid_ids.ndim == 1, "invalid_ids must be 1d"
        assert invalid_ids.dtype == torch.int64, "invalid_ids must be int64"
        self._data[invalid_ids] = PAD_VAL

    def num_valid(self) -> int:
        """
        Returns the number of valid Obbs in this collection.
        """
        if self.ndim == 1:
            is_pad = torch.all(self._data == PAD_VAL, dim=-1)
            return 0 if is_pad.item() else 1
        elif self.ndim == 2:
            is_pad = torch.all(self._data == PAD_VAL, dim=-1)
            return self.shape[0] - is_pad.sum()
        elif self.ndim == 3:
            is_pad = torch.all(self._data == PAD_VAL, dim=-1)
            return self.shape[0] * self.shape[1] - is_pad.sum()
        elif self.ndim == 4:
            is_pad = torch.all(self._data == PAD_VAL, dim=-1)
            return self.shape[0] * self.shape[1] * self.shape[2] - is_pad.sum()
        else:
            raise NotImplementedError(f"{self.shape}")

    def scale_bb2(self, scale_rgb: float, scale_slam: float):
        """Update the 2d bb parameters after resizing the underlying images.
        All 2d bbs are scaled by the same scale specified for the frame of the
        2d bb (RGB vs SLAM)."""

        # Check for padded values and leave those unchanged.
        pad_rgb = (
            torch.all(self.bb2_rgb == PAD_VAL, dim=-1)
            .unsqueeze(-1)
            .expand(*self.bb2_rgb.shape)
        )
        pad_slamr = (
            torch.all(self.bb2_slamr == PAD_VAL, dim=-1)
            .unsqueeze(-1)
            .expand(*self.bb2_slamr.shape)
        )
        pad_slaml = (
            torch.all(self.bb2_slaml == PAD_VAL, dim=-1)
            .unsqueeze(-1)
            .expand(*self.bb2_slaml.shape)
        )
        sc_rgb = scale_rgb * torch.ones_like(self.bb2_rgb)
        sc_slamr = scale_slam * torch.ones_like(self.bb2_slamr)
        sc_slaml = scale_slam * torch.ones_like(self.bb2_slaml)
        # Padded boxes (mask True) are multiplied by 1, all others by the scale.
        sc_rgb = torch.where(pad_rgb, torch.ones_like(sc_rgb), sc_rgb)
        sc_slamr = torch.where(pad_slamr, torch.ones_like(sc_slamr), sc_slamr)
        sc_slaml = torch.where(pad_slaml, torch.ones_like(sc_slaml), sc_slaml)

        data = smart_cat(
            [
                self.bb3_object,
                self.bb2_rgb * sc_rgb,
                self.bb2_slaml * sc_slaml,
                self.bb2_slamr * sc_slamr,
                self.T_world_object,
                self.sem_id,
                self.inst_id,
                self.prob,
                self.moveable,
            ],
            dim=-1,
        )
        return self.__class__(data)

    def crop_bb2(self, left_top_rgb: Tuple[float], left_top_slam: Tuple[float]):
        """Update the 2d bb parameters after cropping the underlying images.
        All 2d bbs are cropped by the same crop specified for the frame of the
        2d bb (RGB vs SLAM).
        left_top_* is the (x, y) left-top corner of the crop.
        """
        # accumulate 2d bb formatting of (xmin, xmax, ymin, ymax)
        left_top_rgb = self._data.new_tensor(
            (left_top_rgb[0], left_top_rgb[0], left_top_rgb[1], left_top_rgb[1])
        )
        left_top_slam = self._data.new_tensor(
            (left_top_slam[0], left_top_slam[0], left_top_slam[1], left_top_slam[1])
        )

        # Repeat the offsets over the batch dimensions if self._data is batched
        if len(self._data.shape) > 1:
            expand_dim = list(self._data.shape[:-1]) + [1]
            left_top_rgb = left_top_rgb.repeat(expand_dim)
            left_top_slam = left_top_slam.repeat(expand_dim)

        data = smart_cat(
            [
                self.bb3_object,
                self.bb2_rgb - left_top_rgb,
                self.bb2_slaml - left_top_slam,
                self.bb2_slamr - left_top_slam,
                self.T_world_object,
                self.sem_id,
                self.inst_id,
                self.prob,
                self.moveable,
            ],
            dim=-1,
        )
        return self.__class__(data)

    def rotate_bb2_cw(self, image_sizes: List[Tuple[int]]):
        """Update the 2d bb parameters after rotating the underlying images.
        Args:
          image_sizes: List of original image sizes before the rotation.
                       The order of the images sizes should be [(w_rgb, h_rgb), (w_slaml, h_slaml), (w_slamr, h_slamr)].
        """
        # Check the input image sizes early
        assert len(image_sizes) == 3, (
            f"the image sizes of 3 video stream should be given, but only got {len(image_sizes)}"
        )
        for s in image_sizes:
            assert len(s) == 2

        # rotate the obbs stream by stream
        bb2_rgb_cw = rot_obb2_cw(self.bb2_rgb.clone(), image_sizes[0])
        bb2_slaml_cw = rot_obb2_cw(self.bb2_slaml.clone(), image_sizes[1])
        bb2_slamr_cw = rot_obb2_cw(self.bb2_slamr.clone(), image_sizes[2])

        data = smart_cat(
            [
                self.bb3_object,
                bb2_rgb_cw,
                bb2_slaml_cw,
                bb2_slamr_cw,
                self.T_world_object,
                self.sem_id,
                self.inst_id,
                self.prob,
                self.moveable,
            ],
            dim=-1,
        )
        return self.__class__(data)

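    def rectify_obb2(self, fisheye_cams: List[CameraTW], pinhole_cams: List[CameraTW]):
        """
        Map the 2D bbs from fisheye images into rectified pinhole images.
        fisheye_cams and pinhole_cams are ordered [rgb, slaml, slamr]. The four
        corners of each 2D bb are unprojected with the fisheye camera, the rays
        are projected with the pinhole camera, and the new 2D bb is the clamped
        extent of the projected corners. Boxes that are not visible or end up
        with zero area are set to PAD_VAL.
        """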
        rect_bb2s = []
        for idx, (fisheye_cam, pinhole_cam) in enumerate(
            zip(fisheye_cams, pinhole_cams)
        ):
            if idx == 0:  # rgb
                bb2 = self.bb2_rgb
            elif idx == 1:  # slaml
                bb2 = self.bb2_slaml
            else:  # slamr
                bb2 = self.bb2_slamr

            tl_points = bb2[..., [0, 2]].clone()  # top-left
            bl_points = bb2[..., [0, 3]].clone()  # bottom-left
            br_points = bb2[..., [1, 3]].clone()  # bottom-right
            tr_points = bb2[..., [1, 2]].clone()  # top-right
            visible_points = self.visible_bb3_ind(idx)

            tl_rays, _ = fisheye_cam.unproject(tl_points)
            br_rays, _ = fisheye_cam.unproject(br_points)
            bl_rays, _ = fisheye_cam.unproject(bl_points)
            tr_rays, _ = fisheye_cam.unproject(tr_points)

            rect_tl_pts, valid = pinhole_cam.project(tl_rays)
            rect_br_pts, valid = pinhole_cam.project(br_rays)
            rect_bl_pts, valid = pinhole_cam.project(bl_rays)
            rect_tr_pts, valid = pinhole_cam.project(tr_rays)
            rect_concat = torch.cat(
                [rect_tl_pts, rect_br_pts, rect_bl_pts, rect_tr_pts], dim=-1
            )
            xmin, _ = torch.min(rect_concat[..., 0::2], dim=-1, keepdim=True)
            xmax, _ = torch.max(rect_concat[..., 0::2], dim=-1, keepdim=True)
            ymin, _ = torch.min(rect_concat[..., 1::2], dim=-1, keepdim=True)
            ymax, _ = torch.max(rect_concat[..., 1::2], dim=-1, keepdim=True)

            # trim
            width = pinhole_cam.size.reshape(-1, 2)[0][0]
            height = pinhole_cam.size.reshape(-1, 2)[0][1]
            xmin = torch.clamp(xmin, min=0, max=width - 1)
            xmax = torch.clamp(xmax, min=0, max=width - 1)
            ymin = torch.clamp(ymin, min=0, max=height - 1)
            ymax = torch.clamp(ymax, min=0, max=height - 1)

            rect_bb2 = torch.cat([xmin, xmax, ymin, ymax], dim=-1)

            # remove the ones without any area
            areas = (rect_bb2[..., 1] - rect_bb2[..., 0]) * (
                rect_bb2[..., 3] - rect_bb2[..., 2]
            )
            areas = areas.unsqueeze(-1)
            areas = areas.repeat(*([1] * (areas.ndim - 1)), 4)
            rect_bb2[areas <= 0] = PAD_VAL
            rect_bb2[~visible_points] = PAD_VAL
            rect_bb2s.append(rect_bb2)

        data = smart_cat(
            [
                self.bb3_object,
                rect_bb2s[0],
                rect_bb2s[1],
                rect_bb2s[2],
                self.T_world_object,
                self.sem_id,
                self.inst_id,
                self.prob,
                self.moveable,
            ],
            dim=-1,
        )
        return self.__class__(data)

    def get_pseudo_bb2(
        self,
        cam: CameraTW,
        T_world_rig: PoseTW,
        num_samples_per_edge: int = 1,
        return_frac_valids: bool = False,
    ):
        """
        Get the 2d bbs of the projection of the 3d bbs into all given camera viewpoints.
        This is done by sampling points on the 3d bb edges (see
        bb3edge_pts_object), projecting them, and computing the 2d bbs from
        the valid projected points. The caller has to make sure the ObbTW holds
        valid 3d bb data.

        num_samples_per_edge == 1 and num_samples_per_edge == 2 are equivalent
        (in both cases we project the obb corners into the frames to compute 2d bbs)
        """
        assert self._data.shape[-2] > 0, "No valid 3d bbs data found!"
        return bb2d_from_project_bb3d(
            self, cam, T_world_rig, num_samples_per_edge, return_frac_valids
        )

    def get_bb2_heights(self, cam_id):
        bb2s = self.bb2(cam_id)
        valid_bb2s = self.visible_bb3_ind(cam_id)
        heights = bb2s[..., 3] - bb2s[..., 2]
        heights[~valid_bb2s] = -1
        return heights

    def get_bb2_widths(self, cam_id):
        bb2s = self.bb2(cam_id)
        valid_bb2s = self.visible_bb3_ind(cam_id)
        widths = bb2s[..., 1] - bb2s[..., 0]
        widths[~valid_bb2s] = -1
        return widths

    def get_bb2_areas(self, cam_id):
        bb2s = self.bb2(cam_id)
        valid_bb2s = self.visible_bb3_ind(cam_id)
        areas = (bb2s[..., 1] - bb2s[..., 0]) * (bb2s[..., 3] - bb2s[..., 2])
        areas[~valid_bb2s] = -1
        return areas

    def get_bb2_centers(self, cam_id):
        bb2s = self.bb2(cam_id)
        valid_bb2s = self.visible_bb3_ind(cam_id)
        center_x = (bb2s[..., 0:1] + bb2s[..., 1:2]) / 2.0
        center_y = (bb2s[..., 2:3] + bb2s[..., 3:4]) / 2.0
        center_2d = torch.cat([center_x, center_y], -1)
        center_2d[~valid_bb2s] = -1
        return center_2d

    def batch_points_inside_bb3(self, pts_world: torch.Tensor) -> torch.Tensor:
        """
        Checks, for each obb, whether its corresponding point lies strictly
        inside the 3d bounding box. pts_world is expected to have shape N x 3,
        with one point per obb (N is the number of obbs in self).
        """
        assert pts_world.shape == self.T_world_object.t.shape
        pts_object = self.T_world_object.inverse().batch_transform(pts_world)
        inside_min = (pts_object > self.bb3_min_object).all(-1)
        inside_max = (pts_object < self.bb3_max_object).all(-1)
        return torch.logical_and(inside_min, inside_max)

    def points_inside_bb3(
        self, pts_world: torch.Tensor, scale_obb: float = 1.0
    ) -> torch.Tensor:
        """
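        checks which of a set of world points (P x 3) lie inside the 3d bounding
        box of a single obb. scale_obb scales the box extents about the origin of
        the object frame before the check.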
        """
        assert self.ndim == 1 and pts_world.ndim == 2
        pts_object = self.T_world_object.inverse().transform(pts_world)
        inside_min = (pts_object > self.bb3_min_object * scale_obb).all(-1)
        inside_max = (pts_object < self.bb3_max_object * scale_obb).all(-1)
        return torch.logical_and(inside_min, inside_max)

    def _transform(self, T_new_world):
        """
        in place transform T_world_object as T_new_object = T_new_world @ T_world_object
        """
        T_world_object = self.T_world_object
        T_new_object = T_new_world @ T_world_object
        self.set_T_world_object(T_new_object)

    def transform(self, T_new_world):
        """
        transform T_world_object as T_new_object = T_new_world @ T_world_object
        """
        obb_new = self.clone()
        obb_new._transform(T_new_world)
        return obb_new

    def _transform_object(self, T_object_new):
        """
        in place transform T_world_object as T_world_new = T_world_object @ T_object_new
        """
        T_world_object = self.T_world_object
        T_world_new = T_world_object @ T_object_new
        self.set_T_world_object(T_world_new)

    def filter_by_sem_id(self, keep_sem_ids):
        valid = self._data.new_zeros(self.shape[:-1]).bool()
        for si in keep_sem_ids:
            valid = valid | (self.sem_id == si)[..., 0]
        self._data[~valid] = PAD_VAL
        return self

    def filter_by_prob(self, prob_thr: float):
        # since PAD_VAL is -1 this will work fine with padded entries
        invalid = self.prob.squeeze(-1) < prob_thr
        self._data[invalid] = PAD_VAL
        return self

    def filter_bb2_center_by_radius(self, calib, cam_id):
        """
        Inputs
            calib: CameraTW : shaped ... x 34, matching leading dims with self
            cam_id : int : integer corresponding to which bb2ds to use (0: rgb, 1: slaml, 2: slamr)
        """
        # Remove detections whose 2d bb center lies outside the valid radius.
        centers = self.get_bb2_centers(cam_id)
        inside = calib.in_radius(centers)
        self._data[~inside, :] = PAD_VAL
        return self

    def voxel_grid(self, vD: int, vH: int, vW: int):
        """
        Input: works on obbs shaped (B) x 34
        Output: world points at the voxel centers of a vW x vH x vD grid inside
        each obb, shaped (B) x vW*vH*vD x 3
        """
        x_min, x_max, y_min, y_max, z_min, z_max = self.bb3_object.unbind(-1)
        dW = (x_max - x_min) / vW
        dH = (y_max - y_min) / vH
        dD = (z_max - z_min) / vD
        # take the center position of each voxel
        rng_x = tensor_linspace(
            x_min + dW / 2, x_max - dW / 2, steps=vW, device=self.device
        )
        rng_y = tensor_linspace(
            y_min + dH / 2, y_max - dH / 2, steps=vH, device=self.device
        )
        rng_z = tensor_linspace(
            z_min + dD / 2, z_max - dD / 2, steps=vD, device=self.device
        )
        if self.ndim > 1:
            if self.ndim > 2:
                raise NotImplementedError
            B = self.shape[0]
            xs, ys, zs = [], [], []
            for b in range(B):
                xx, yy, zz = torch.meshgrid(rng_x[b], rng_y[b], rng_z[b], indexing="ij")
                xs.append(xx)
                ys.append(yy)
                zs.append(zz)
            xx = torch.stack(xs)
            yy = torch.stack(ys)
            zz = torch.stack(zs)
            # B x vW*vH*vD x 3
            vox_v = torch.stack([xx, yy, zz], dim=-1).reshape(B, -1, 3)
        else:
            xx, yy, zz = torch.meshgrid(rng_x, rng_y, rng_z, indexing="ij")
            # a single obb has no batch dim (B is undefined here): vW*vH*vD x 3
            vox_v = torch.stack([xx, yy, zz], dim=-1).reshape(-1, 3)
        T_wv = self.T_world_object
        vox_w = T_wv * vox_v
        return vox_w

    def __repr__(self):
        return f"ObbTW {self.shape} {self.dtype} {self.device}"
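

# Sketch (hypothetical helper, not part of efm3d): the containment test used by
# `points_inside_bb3` / `batch_points_inside_bb3`, written with plain Python
# lists. Assumptions: the pose is given as a 3x3 rotation R_world_object plus a
# translation t_world_object, and the box extents are in the object frame in
# (xmin, xmax, ymin, ymax, zmin, zmax) order. A world point is mapped into the
# object frame with p_object = R^T (p_world - t) and compared against the
# extents with strict inequalities, as the methods above do.
def _sketch_point_inside_obb(p_world, R_world_object, t_world_object, bb3_object):
    # d = p_world - t_world_object
    d = [p_world[i] - t_world_object[i] for i in range(3)]
    # p_object = R^T d (the inverse rotation applied to d)
    p_object = [sum(R_world_object[r][c] * d[r] for r in range(3)) for c in range(3)]
    xmin, xmax, ymin, ymax, zmin, zmax = bb3_object
    lo = (xmin, ymin, zmin)
    hi = (xmax, ymax, zmax)
    return all(lo[i] < p_object[i] < hi[i] for i in range(3))


# Usage: a 4 x 1 x 1 box rotated 90 degrees about z is long along world y, so
# (0, 1.5, 0) is inside it while (1.5, 0, 0) is not.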


def _single_transform_obbs(obbs_padded, Ts_other_world):
    assert obbs_padded.ndim == 3  # T x N x C
    assert Ts_other_world.ndim == 2 and Ts_other_world.shape[0] == 1  # 1 x C
    T, N, C = obbs_padded.shape
    if T == 0:
        # Directly return the input since T=0 and there are no obbs to transform.
        return obbs_padded
    obbs_transformed = []
    for t in range(T):
        # clone so that we get a new transformed obbs object.
        obbs = obbs_padded[t, ...].remove_padding().clone()
        obbs._transform(Ts_other_world)
        obbs_transformed.append(obbs.add_padding(N))
    obbs_transformed = ObbTW(smart_stack(obbs_transformed))
    return obbs_transformed


def _batched_transform_obbs(obbs_padded, Ts_other_world):
    assert obbs_padded.ndim == 4  # B x T x N x C
    assert Ts_other_world.ndim == 3  # B x 1 x C
    B, T, N, C = obbs_padded.shape
    obbs_transformed = []
    for b in range(B):
        obbs_transformed.append(
            _single_transform_obbs(obbs_padded[b], Ts_other_world[b])
        )
    obbs_transformed = ObbTW(smart_stack(obbs_transformed))
    return obbs_transformed


def transform_obbs(obbs_padded, Ts_other_world):
    """
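    transform padded obbs from the world coordinate system to an "other"
    coordinate system. obbs_padded is T x N x C or B x T x N x C, and
    Ts_other_world is 1 x 12 or B x 1 x 12, respectively.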
    """
    if obbs_padded.ndim == 4:
        return _batched_transform_obbs(obbs_padded, Ts_other_world)
    return _single_transform_obbs(obbs_padded, Ts_other_world)


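def rot_obb2_cw(bb2: torch.Tensor, size: Tuple[int, int]):
    """
    rotate 2d bbs in xxyy format (xmin, xmax, ymin, ymax) by 90 degrees clockwise,
    matching an image of size (width, height) rotated the same way. Negative
    (e.g. padded) entries are left unchanged.
    """
    bb2_ori = bb2.clone()
    # exchange (xmin, xmax, ymin, ymax) -> (ymax, ymin, xmin, xmax)
    bb2 = bb2[..., [3, 2, 0, 1]]
    # x_new = height - 1 - y_old, with size = (width, height)
    bb2[..., 0:2] = size[1] - bb2[..., 0:2] - 1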
    # bring back the invalid entries.
    bb2[bb2_ori < 0] = bb2_ori[bb2_ori < 0]
    return bb2
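

# Sketch (hypothetical helper, not part of efm3d): the index arithmetic behind
# `rot_obb2_cw` for a single box in (xmin, xmax, ymin, ymax) order. Rotating an
# image of size (width, height) by 90 degrees clockwise maps a pixel (x, y) to
# (height - 1 - y, x), so the new x range is the flipped old y range and the new
# y range is the old x range. Simplifying assumption: padded boxes are all -1
# (PAD_VAL) and are returned unchanged, whereas `rot_obb2_cw` restores negative
# entries element-wise.
def _sketch_rot_bb2_cw(bb2, size):
    _, height = size
    if all(v == -1 for v in bb2):
        return tuple(bb2)
    xmin, xmax, ymin, ymax = bb2
    return (height - 1 - ymax, height - 1 - ymin, xmin, xmax)


# Usage: in a 4 x 3 (width x height) image, _sketch_rot_bb2_cw((0, 1, 0, 2), (4, 3))
# returns (0, 2, 0, 1).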


def project_bb3d_onto_image(
    obbs: ObbTW, cam: CameraTW, T_world_rig: PoseTW, num_samples_per_edge: int = 1
):
    """
    project 3d bb edge points into snippet images defined by T_world_rig and
    camera cam. The assumption is that obbs are in the "world" coordinate system
    of T_world_rig.
    Supports batched operation.

    Args:
        obbs (ObbTW): obbs to project; shape is (Bx)(Tx)Nx34
        cam (CameraTW): camera to project to; shape is (Bx)TxC where T is the snippet dimension;
        T_world_rig (PoseTW): T_world_rig defining where the camera rig is; shape is (Bx)Tx12
        num_samples_per_edge (int): how many points to sample per edge to
            compute the 2d bb (1 and 2 both mean that only the 8 corners are used)
    Returns:
        bb3_corners_im (Tensor): projected bb3 points in the image coordinate system;
            shape is (Bx)TxNxKx2, where K = 8 when only the corners are used and
            otherwise the number of edge points from bb3edge_pts_object
        bb3_valids (Tensor): valid mask of bb3_corners_im (indicates which
            points lie within the images); shape is (Bx)TxNxK
    obb_dim = obbs.dim()
    # support 3 sets of input shapes
    if obb_dim == 2:  # Nx34
        # cam: TxC, T_world_rig: Tx12
        assert (
            cam.dim() == 2
            and T_world_rig.dim() == 2
            and cam.shape[0]
            == T_world_rig.shape[0]  # T dim should be the same for cam and T_world_rig
        ), (
            f"Unsupported input shapes: obb: {obbs.shape}, cam: {cam.shape}, T_world_rig: {T_world_rig.shape}."
        )

        # Bring the inputs to consistent shapes
        obbs = obbs.unsqueeze(0).unsqueeze(0)  # expand to B(1)xT(1)xNx34
        cam = cam[None, ...]  # expand to B(1)xTxC
        T_world_rig = T_world_rig[None, ...]  # expand to B(1)xTx12
        B, T = cam.shape[0:2]
        N = obbs.shape[-2]
        obbs = obbs.expand(B, T, *obbs.shape[-2:])  # repeat to real T: B(1)xTxNx34

    elif obb_dim == 3:  # BxNx34
        # cam: BxTxC, T_world_rig: BxTx12
        assert cam.dim() == 3 and T_world_rig.dim() == 3
        # B dim should be the same
        assert obbs.shape[0] == cam.shape[0] and obbs.shape[0] == T_world_rig.shape[0]
        # T dim of cam and pose should be the same
        assert cam.shape[1] == T_world_rig.shape[1]

        # Bring the inputs to consistent shapes
        obbs = obbs.unsqueeze(1)  # expand to BxT(1)xNx34
        B, T = cam.shape[0:2]
        obbs = obbs.expand(B, T, *obbs.shape[-2:])

    elif obb_dim == 4:  # BxTxNx34
        pass
    else:
        raise ValueError(
            f"Unsupported input shapes: obb: {obbs.shape}, cam: {cam.shape}, T_world_rig: {T_world_rig.shape}."
        )

    # check if all tensors are of correct shapes.
    assert obbs.dim() == 4 and cam.dim() == 3 and T_world_rig.dim() == 3, (
        f"The shapes of obbs, cam and T_world_rig should be BxTxNx34, BxTxC, and BxTx12, respectively. However, we got obbs: {obbs.shape}, cam: {cam.shape}, T_world_rig: {T_world_rig.shape}"
    )
    assert (
        obbs.shape[0:2] == cam.shape[0:2] and obbs.shape[0:2] == T_world_rig.shape[0:2]
    ), (
        f"The BxT dims should be the same for all tensors, but got obbs: {obbs.shape}, cam: {cam.shape}, T_world_rig: {T_world_rig.shape}"
    )

    B, T = cam.shape[0:2]
    N = obbs.shape[-2]
    assert N > 0, "obbs have to exist for this frame"
    # Get pose of camera.
    T_world_cam = T_world_rig @ cam.T_camera_rig.inverse()
    # Project the 3D BB corners into the image.
    # BxTxNx8x3 -> BxTxN*8x3
    if num_samples_per_edge <= 2:
        bb3pts_world = obbs.bb3corners_world.view(B, T, -1, 3)
    else:
        bb3pts_object = obbs.bb3edge_pts_object(num_samples_per_edge)
        bb3pts_world = obbs.T_world_object * bb3pts_object
        bb3pts_world = bb3pts_world.view(B, T, -1, 3)
    Npt = bb3pts_world.shape[2]
    T_world_cam = T_world_cam.unsqueeze(2).repeat(1, 1, Npt, 1)
    bb3pts_cam = (
        T_world_cam.inverse()
        .view(-1, 12)
        .batch_transform(bb3pts_world.view(-1, 3))
        .view(B, T, -1, 3)
    )
    bb3pts_im, bb3pts_valids = cam.project(bb3pts_cam)
    bb3pts_im = bb3pts_im.view(B, T, N, -1, 2)
    bb3pts_valids = bb3pts_valids.detach().view(B, T, N, -1)

    if obb_dim == 2:
        # remove B dim if it didn't exist before.
        bb3pts_im = bb3pts_im.squeeze(0)
        bb3pts_valids = bb3pts_valids.squeeze(0)
    return bb3pts_im, bb3pts_valids
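

# Sketch (hypothetical helper, not part of efm3d): what "project box points into
# the image with validity flags" means, using a plain pinhole model instead of
# the CameraTW models used above (e.g. fisheye). Assumptions: points are already
# in the camera frame (the code above gets there via T_world_cam.inverse()),
# fx, fy, cx, cy are pinhole intrinsics, and a point is valid if it lies in
# front of the camera and inside the image bounds.
def _sketch_project_pinhole(pts_cam, fx, fy, cx, cy, width, height, min_z=1e-6):
    uvs, valids = [], []
    for x, y, z in pts_cam:
        if z <= min_z:
            # behind (or on) the camera plane: no meaningful projection
            uvs.append((-1.0, -1.0))
            valids.append(False)
            continue
        u = fx * x / z + cx
        v = fy * y / z + cy
        uvs.append((u, v))
        valids.append(0.0 <= u <= width - 1 and 0.0 <= v <= height - 1)
    return uvs, valids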


def bb2d_from_project_bb3d(
    obbs: ObbTW,
    cam: CameraTW,
    T_world_rig: PoseTW,
    num_samples_per_edge: int = 1,
    return_frac_valids: bool = False,
):
    """
    get 2d bbs around the 3d bb corners of obbs projected into the image coordinate system
    defined by T_world_rig and camera cam. The assumption is that obbs are in the
    "world" coordinate system of T_world_rig.

    This is done by sampling points on the 3d bb edges (see bb3edge_pts_object),
    projecting them and then computing the 2d bbs from the valid projected
    points.

    Supports batched operation.

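    Args:
        obbs (ObbTW): obbs to project; shape is (Bx)(Tx)Nx34
        cam (CameraTW): camera to project to; shape is (Bx)TxC where T is the snippet dimension;
        T_world_rig (PoseTW): T_world_rig defining where the camera rig is; shape is (Bx)Tx12
        num_samples_per_edge (int): how many points to sample per edge
            (1 and 2 both mean that only the 8 corners are used)
        return_frac_valids (bool): if True, also return the visible fraction of each 2d bb
    Returns:
        bb2s (Tensor): 2d bounding boxes (xmin, xmax, ymin, ymax) in the image coordinate system; shape is (Bx)TxNx4
        bb2s_valid (Tensor): valid mask of bb2s; shape is (Bx)TxN
        frac_valid (Tensor, optional): only returned if return_frac_valids; shape is (Bx)TxN.
            For linear cameras it is the IoU between the clamped and the unclamped 2d bb,
            otherwise the fraction of projected points that are valid.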
    """
    from torchvision.ops.boxes import box_iou

    bb3corners_im, bb3corners_valids = project_bb3d_onto_image(
        obbs, cam, T_world_rig, num_samples_per_edge
    )
    # get image points that will min and max reduce correctly given the valid masks
    bb3corners_im_min = torch.where(
        bb3corners_valids.unsqueeze(-1).expand_as(bb3corners_im),
        bb3corners_im,
        999999 * torch.ones_like(bb3corners_im),
    )
    bb3corners_im_max = torch.where(
        bb3corners_valids.unsqueeze(-1).expand_as(bb3corners_im),
        bb3corners_im,
        -999999 * torch.ones_like(bb3corners_im),
    )
    # compute 2d bounding boxes
    bb2s_min = torch.min(bb3corners_im_min, dim=-2)[0]
    bb2s_max = torch.max(bb3corners_im_max, dim=-2)[0]
    bb2s = torch.stack(
        [bb2s_min[..., 0], bb2s_max[..., 0], bb2s_min[..., 1], bb2s_max[..., 1]], dim=-1
    )
    # min < max so that it's a valid box.
    non_empty_boxes = (bb2s[..., 0] < bb2s[..., 1]) & (bb2s[..., 2] < bb2s[..., 3])
    if cam.is_linear:
        bb2s_full = bb2s.clone()
        # Clamp based on the camera size for linear cameras.
        # Note that this could generate very big/loose bounding boxes if the object is badly truncated due to out of view.
        bb2s[..., 0:2] = torch.clamp(
            bb2s[..., 0:2], min=0, max=cam.size.view(-1, 2)[0, 0] - 1
        )
        bb2s[..., 2:4] = torch.clamp(
            bb2s[..., 2:4], min=0, max=cam.size.view(-1, 2)[0, 1] - 1
        )
        # filter out empty boxes.
        bb2s_valid = torch.logical_and(non_empty_boxes, bb3corners_valids.any(-1))
        if return_frac_valids:
            frac_valid = torch.zeros_like(bb2s_valid).float()
            frac_valid[non_empty_boxes] = box_iou(
                bb2_xxyy_to_xyxy(bb2s[non_empty_boxes]),
                bb2_xxyy_to_xyxy(bb2s_full[non_empty_boxes]),
            ).diagonal()
    else:
        # count number of valid points
        num_points = bb3corners_valids.count_nonzero(-1)
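        # valid 2d bbs are non-empty and have more than 2 * num_samples_per_edge
        # valid projected points (about 1/6 of the points when sampling along the
        # 12 edges; when only the 8 corners are projected, i.e.
        # num_samples_per_edge <= 2, this means more than 2 or 4 corners)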
        bb2s_valid = torch.logical_and(
            non_empty_boxes, num_points > num_samples_per_edge * 2
        )
        if return_frac_valids:
            frac_valid = num_points / bb3corners_valids.shape[-1]
            frac_valid[~non_empty_boxes] = 0.0
    if return_frac_valids:
        return bb2s, bb2s_valid, frac_valid
    return bb2s, bb2s_valid
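

# Sketch (hypothetical helper, not part of efm3d): the 2d box reduction of
# `bb2d_from_project_bb3d` for the linear-camera case, on plain lists. Invalid
# points are excluded from the min/max (the code above does this by replacing
# them with +/-999999 before reducing), the box is valid only if it is
# non-empty before clamping, and it is then clamped to the image. Because the
# clamped box lies inside the unclamped one, the IoU computed above equals
# area(clamped) / area(full), which this sketch returns as the visible fraction.
def _sketch_bb2_from_points(uvs, valids, width, height):
    pts = [uv for uv, ok in zip(uvs, valids) if ok]
    if not pts:
        return None, False, 0.0
    xs = [u for u, _ in pts]
    ys = [v for _, v in pts]
    full = (min(xs), max(xs), min(ys), max(ys))
    if not (full[0] < full[1] and full[2] < full[3]):
        return full, False, 0.0
    clamped = (
        min(max(full[0], 0.0), width - 1),
        min(max(full[1], 0.0), width - 1),
        min(max(full[2], 0.0), height - 1),
        min(max(full[3], 0.0), height - 1),
    )
    area_full = (full[1] - full[0]) * (full[3] - full[2])
    area_clamped = (clamped[1] - clamped[0]) * (clamped[3] - clamped[2])
    return clamped, True, area_clamped / area_full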


def bb2_xxyy_to_xyxy(bb2s):
    # check if the input is xxyy
    is_xxyy = torch.logical_and(
        bb2s[..., 0] <= bb2s[..., 1], bb2s[..., 2] <= bb2s[..., 3]
    )
    is_xxyy = is_xxyy.all()
    if not is_xxyy:
        logger.warning("Input 2d bbx doesn't follow xxyy convention.")
    return bb2s[..., [0, 2, 1, 3]]


def bb2_xyxy_to_xxyy(bb2s):
    # check if the input is xxyy
    is_xyxy = torch.logical_and(
        bb2s[..., 0] <= bb2s[..., 2], bb2s[..., 1] <= bb2s[..., 3]
    )
    is_xyxy = is_xyxy.all()
    if not is_xyxy:
        logger.warning("Input 2d bbx doesn't follow xyxy convention.")
    return bb2s[..., [0, 2, 1, 3]]
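

# Sketch (hypothetical helper, not part of efm3d): both conversions above use
# the same index permutation [0, 2, 1, 3], because swapping the two middle
# entries is its own inverse:
#   xxyy (x0, x1, y0, y1) -> xyxy (x0, y0, x1, y1) -> xxyy (x0, x1, y0, y1)
def _sketch_swap_middle(bb2):
    return tuple(bb2[i] for i in (0, 2, 1, 3))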


def bb3_xyzxyz_to_xxyyzz(bb3s):
    """
    take bb3 in xyzxyz format and return xxyyzz format.
    """
    return bb3s[..., [0, 3, 1, 4, 2, 5]]


def bb3_xyz_xyz_to_xxyyzz(bb3s_min, bb3s_max):
    """
    take min and max points of the bb3 and return xxyyzz format
    """
    return torch.cat([bb3s_min, bb3s_max], -1)[..., [0, 3, 1, 4, 2, 5]]


def rnd_obbs(N: int = 1, num_semcls: int = 10, bb3_min_diag=0.1, bb2_min_diag=10):
    pts3_min = torch.randn(N, 3)
    pts3_max = pts3_min + bb3_min_diag + torch.randn(N, 3).abs()
    pts2_min = torch.randn(N, 2)
    pts2_max = pts2_min + bb2_min_diag + torch.randn(N, 2).abs()

    obb = ObbTW.from_lmc(
        bb3_object=bb3_xyzxyz_to_xxyyzz(torch.cat([pts3_min, pts3_max], -1)),
        prob=torch.ones(N),
        bb2_rgb=bb2_xyxy_to_xxyy(torch.cat([pts2_min, pts2_max], -1)),
        # high is exclusive in torch.randint, so this samples ids 0 .. num_semcls - 1
        sem_id=torch.randint(low=0, high=num_semcls, size=[N]),
        T_world_object=PoseTW.from_aa(torch.randn(N, 3), 10.0 * torch.randn(N, 3)),
    )
    return obb


def obb_time_union(obbs, pad_size=128):
    """
    Take frame level ground truth shaped BxTxNxC and take the union
    over the time dimensions using the instance id to extend to snippet level
    obbs shaped BxNxC.
    """
    # T already merged somewhere else.
    if obbs.ndim == 3:
        return obbs

    assert obbs.ndim == 4, "Only B x T x N x C supported"
    new_obbs = []
    for obb in obbs:
        new_obb = []
        flat_time_obb = obb.clone().reshape(-1, 34)
        unique = flat_time_obb.inst_id.unique()
        for uni in unique:
            if uni == PAD_VAL:
                continue
            found = int(torch.argwhere(flat_time_obb.inst_id == uni)[0, 0])
            found_obb = flat_time_obb[found].clone()
            new_obb.append(found_obb)
        if len(new_obb) == 0:
            print(f"Adding empty OBB in time_union {obbs.shape}")
            new_obb.append(ObbTW().reshape(-1).to(obbs._data))
        new_obbs.append(torch.stack(new_obb).add_padding(pad_size))
    new_obbs = torch.stack(new_obbs)
    # Remove all bb2 observations since we no longer know which frame in time they came from.
    # Note: bb2s are set to a dummy value of 1 for all cameras so that the merged obbs count as
    # visible and can be used for evaluation; padded entries are reset to PAD_VAL below.
    pad_mask = new_obbs.get_padding_mask()
    new_obbs.set_bb2(cam_id=0, bb2d=1)
    new_obbs.set_bb2(cam_id=1, bb2d=1)
    new_obbs.set_bb2(cam_id=2, bb2d=1)
    new_obbs._data[pad_mask] = PAD_VAL
    return new_obbs
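

# Sketch (hypothetical helper, not part of efm3d): the core of `obb_time_union`
# on plain lists. Assumptions: each frame holds (inst_id, payload) entries and
# padding uses inst_id == -1 (PAD_VAL). For every instance id the first
# occurrence in time order is kept, and the result is sorted by id, since
# `unique()` returns sorted ids. The padding to `pad_size` and the bb2 reset
# are omitted.
def _sketch_time_union(frames, pad_val=-1):
    seen = set()
    merged = []
    for frame in frames:
        for inst_id, payload in frame:
            if inst_id == pad_val or inst_id in seen:
                continue
            seen.add(inst_id)
            merged.append((inst_id, payload))
    return sorted(merged, key=lambda entry: entry[0])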


def obb_filter_outside_volume(obbs, T_ws, T_wv, voxel_extent, border=0.1):
    """
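    Remove obbs outside a volume of size voxel_extent, e.g. from a lifter volume.
    obbs (B x N x C) are expected in the snippet frame given by T_ws and are
    transformed into the voxel frame given by T_wv. An obb is kept only if its
    center lies inside the volume shrunk by `border` on every side; removed obbs
    are set to PAD_VAL in place.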
    """
    assert obbs.ndim == 3, "Only B x N x C supported"
    T_vs = T_wv.inverse() @ T_ws
    obbs_v = obbs.transform(T_vs.unsqueeze(1))
    centers_v = obbs_v.bb3_center_world
    cx = centers_v[:, :, 0]
    cy = centers_v[:, :, 1]
    cz = centers_v[:, :, 2]
    x_min = voxel_extent[0]
    x_max = voxel_extent[1]
    y_min = voxel_extent[2]
    y_max = voxel_extent[3]
    z_min = voxel_extent[4]
    z_max = voxel_extent[5]
    valid = (obbs_v.inst_id != PAD_VAL).squeeze(-1)
    inside = (
        (cx > (x_min + border))
        & (cy > (y_min + border))
        & (cz > (z_min + border))
        & (cx < (x_max - border))
        & (cy < (y_max - border))
        & (cz < (z_max - border))
    )
    remove = valid & ~inside
    obbs._data[remove, :] = PAD_VAL
    return obbs
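

# Sketch (hypothetical helper, not part of efm3d): the per-obb test used in
# `obb_filter_outside_volume`, assuming the obb center has already been
# expressed in the voxel frame. voxel_extent is
# (x_min, x_max, y_min, y_max, z_min, z_max), and the volume is shrunk by
# `border` on every side before the strict inside check.
def _sketch_center_inside_volume(center_v, voxel_extent, border=0.1):
    x_min, x_max, y_min, y_max, z_min, z_max = voxel_extent
    lo = (x_min + border, y_min + border, z_min + border)
    hi = (x_max - border, y_max - border, z_max - border)
    return all(lo[i] < center_v[i] < hi[i] for i in range(3))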


def tensor_linspace(start, end, steps, device):
    """
    Vectorized version of torch.linspace.
    Inputs:
    - start: Tensor of any shape
    - end: Tensor of the same shape as start
    - steps: Integer
    - device: device on which the interpolation weights are created
    Returns:
    - out: Tensor of shape start.size() + (steps,), such that
      out.select(-1, 0) == start, out.select(-1, -1) == end,
      and the other elements of out linearly interpolate between
      start and end.
    """
    assert start.size() == end.size()
    view_size = start.size() + (1,)
    w_size = (1,) * start.dim() + (steps,)
    out_size = start.size() + (steps,)

    start_w = torch.linspace(1, 0, steps=steps, device=device).to(start)
    start_w = start_w.view(w_size).expand(out_size)
    end_w = torch.linspace(0, 1, steps=steps, device=device).to(start)
    end_w = end_w.view(w_size).expand(out_size)

    start = start.contiguous().view(view_size).expand(out_size)
    end = end.contiguous().view(view_size).expand(out_size)

    out = start_w * start + end_w * end
    return out
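

# Sketch (hypothetical helper, not part of efm3d): `tensor_linspace` on plain
# lists. Each output row k interpolates between start[k] and end[k] with the
# weights (1 - w, w) for w = 0, 1 / (steps - 1), ..., 1, which is what the
# start_w / end_w tensors above hold. Assumes steps >= 2.
def _sketch_linspace_rows(start, end, steps):
    ws = [i / (steps - 1) for i in range(steps)]
    return [[(1.0 - w) * s + w * e for w in ws] for s, e in zip(start, end)]


# Usage: _sketch_linspace_rows([0.0, 10.0], [1.0, 20.0], 3)
# returns [[0.0, 0.5, 1.0], [10.0, 15.0, 20.0]].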


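def make_obb(sz, position, prob=1.0, roll=0.0, pitch=0.0, yaw=0.1):
    """
    create a single obb with extents sz = (x, y, z), centered at the origin of its
    object frame, which is placed at `position` with the orientation given by the
    Euler angles (roll, pitch, yaw) passed to rotation_from_euler.
    """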
    e_angles = torch.tensor([roll, pitch, yaw]).reshape(-1, 3)
    R = rotation_from_euler(e_angles).reshape(3, 3)
    T_voxel_object = PoseTW.from_Rt(R, torch.tensor(position))
    bb3 = [
        -sz[0] / 2.0,
        sz[0] / 2.0,
        -sz[1] / 2.0,
        sz[1] / 2.0,
        -sz[2] / 2.0,
        sz[2] / 2.0,
    ]
    return ObbTW.from_lmc(
        bb3_object=torch.tensor(bb3),
        prob=[prob],
        T_world_object=T_voxel_object,
    )


# =====> Main function for 3D IoU computation. <=======
def obb_iou3d(obb1: ObbTW, obb2: ObbTW, samp_per_dim=32):
    """
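    Computes the pairwise 3D IoU between two sets of obbs, shaped B1 x 34 and
    B2 x 34, by sampling points on a uniform grid in the x, y, z dims of each
    box and counting how many fall inside the other box. Returns a B1 x B2 tensor.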

    samp_per_dim: int, number of samples per dimension, e.g. if 8, then 8x8x8
                       increase for more accuracy but less speed
                       8: fast but not so accurate
                       32: medium
                       128: most accurate but slow
    """
    assert obb1.ndim == 2
    assert obb2.ndim == 2

    B1 = obb1.shape[0]
    B2 = obb2.shape[0]
    vol1 = obb1.bb3_volumes
    vol2 = obb2.bb3_volumes

    dim = samp_per_dim
    points1_w = obb1.voxel_grid(vD=dim, vH=dim, vW=dim)
    points2_w = obb2.voxel_grid(vD=dim, vH=dim, vW=dim)
    num_samples = points1_w.shape[1]

    isin21 = is_point_inside_box(points2_w, obb1.bb3corners_world, verbose=True)
    num21 = isin21.sum(dim=-1)
    isin12 = is_point_inside_box(points1_w, obb2.bb3corners_world, verbose=True)
    num12 = isin12.sum(dim=-1)

    inters12 = vol1.view(B1, 1) * num12.view(B1, B2)
    inters21 = vol2.view(B2, 1) * num21.view(B2, B1)
    inters = (inters12 + inters21.transpose(1, 0)) / 2.0
    union = (vol1.view(B1, 1) * num_samples) + (vol2.view(1, B2) * num_samples) - inters
    iou = inters / union
    return iou
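

# Sketch (hypothetical helpers, not part of efm3d): the sampling estimator of
# `obb_iou3d`, simplified to one pair of axis-aligned boxes given as
# (xmin, xmax, ymin, ymax, zmin, zmax). As in `obb_iou3d`, voxel centers of each
# box are tested against the other box, the two intersection estimates
# vol_a * frac_a_in_b and vol_b * frac_b_in_a are averaged, and
# IoU = inter / (vol_a + vol_b - inter). Oriented boxes would additionally need
# the object-frame transform or plane tests instead of the simple extent check.
def _sketch_box_volume(box):
    return (box[1] - box[0]) * (box[3] - box[2]) * (box[5] - box[4])


def _sketch_voxel_centers(box, n):
    import itertools

    axes = []
    for lo, hi in ((box[0], box[1]), (box[2], box[3]), (box[4], box[5])):
        d = (hi - lo) / n
        # centers of the n cells along this axis
        axes.append([lo + d / 2 + i * d for i in range(n)])
    return list(itertools.product(*axes))


def _sketch_inside(p, box):
    return all(box[2 * i] <= p[i] <= box[2 * i + 1] for i in range(3))


def _sketch_iou3d_sampled(box_a, box_b, samp_per_dim=8):
    pts_a = _sketch_voxel_centers(box_a, samp_per_dim)
    pts_b = _sketch_voxel_centers(box_b, samp_per_dim)
    frac_a_in_b = sum(_sketch_inside(p, box_b) for p in pts_a) / len(pts_a)
    frac_b_in_a = sum(_sketch_inside(p, box_a) for p in pts_b) / len(pts_b)
    vol_a, vol_b = _sketch_box_volume(box_a), _sketch_box_volume(box_b)
    inter = 0.5 * (vol_a * frac_a_in_b + vol_b * frac_b_in_a)
    return inter / (vol_a + vol_b - inter)


# Usage: two unit cubes overlapping by half along x have IoU 0.5 / 1.5 = 1/3,
# which this estimator reproduces exactly with samp_per_dim=4.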


def is_point_inside_box(points: torch.Tensor, box: torch.Tensor, verbose=False):
    """
    Determines whether points are inside the boxes
    Args:
        points: tensor of shape (B1, P, 3) of the points
        box: tensor of shape (B2, 8, 3) of the corners of the boxes
    Returns:
        inside: bool tensor of whether point (row) is in box (col) shape (B1, B2, P)
    """
    device = box.device
    B1 = points.shape[0]
    B2 = box.shape[0]
    P = points.shape[1]

    normals = box_planar_dir(box)  # (B2, 6, 3)
    box_planes = get_plane_verts(box)  # (B2, 6, 4, 3)
    NP = box_planes.shape[1]  # = 6

    # a point p is inside the box if it is "inside" all planes of the box,
    # so we run the check for each plane
    ins = torch.zeros((B1, B2, P, NP), device=device, dtype=torch.bool)
    for i in range(NP):
        is_in = is_inside(points, box_planes[:, i], normals[:, i])
        ins[:, :, :, i] = is_in

    ins = ins.all(dim=-1)
    return ins
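

# Sketch (hypothetical helper, not part of efm3d): the half-space test behind
# `is_point_inside_box` / `is_inside`. Each face is given by a point on it and a
# unit normal pointing into the box; a point p is inside when
# dot(p - face_point, n_inward) >= 0 for all six faces. The faces of the unit
# cube [0, 1]^3 below are an assumed example input.
def _sketch_inside_all_faces(p, faces):
    for face_point, n_inward in faces:
        c = sum((p[i] - face_point[i]) * n_inward[i] for i in range(3))
        if c < 0.0:
            return False
    return True


# faces of the unit cube [0, 1]^3 as (point on face, inward unit normal)
_SKETCH_UNIT_CUBE_FACES = [
    ((0.0, 0.5, 0.5), (1.0, 0.0, 0.0)),
    ((1.0, 0.5, 0.5), (-1.0, 0.0, 0.0)),
    ((0.5, 0.0, 0.5), (0.0, 1.0, 0.0)),
    ((0.5, 1.0, 0.5), (0.0, -1.0, 0.0)),
    ((0.5, 0.5, 0.0), (0.0, 0.0, 1.0)),
    ((0.5, 0.5, 1.0), (0.0, 0.0, -1.0)),
]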


def box_planar_dir(
    box: torch.Tensor, dot_eps: float = DOT_EPS, area_eps: float = AREA_EPS
) -> torch.Tensor:
    """
    Finds the unit vector n which is perpendicular to each plane in the box
    and points towards the inside of the box.
    The planes are defined by `_box_planes`.
    Since the shape is convex, we define the interior to be the direction
    pointing to the center of the shape.
    Args:
       box: tensor of shape (B, 8, 3) of the vertices of the 3D box
    Returns:
       n: tensor of shape (B, 6, 3) of the unit vectors orthogonal to each face,
          pointing towards the interior of the shape
    """
    assert box.shape[1] == 8 and box.shape[2] == 3
    # center point of each box
    box_ctr = box.mean(dim=1).view(-1, 1, 3)
    # box planes
    plane_verts = get_plane_verts(box)  # (B, 6, 4, 3)
    v0, v1, v2, v3 = plane_verts.unbind(2)
    plane_ctr, n = get_plane_center_normal(plane_verts)
    # Check all verts are coplanar
    normv = F.normalize(v3 - v0, dim=-1).unsqueeze(2).reshape(-1, 1, 3)
    nn = n.unsqueeze(3).reshape(-1, 3, 1)
    dists = normv @ nn
    if not (dists.abs() < dot_eps).all().item():
        msg = "Plane vertices are not coplanar"
        raise ValueError(msg)
    # Check all faces have non zero area
    area1 = torch.cross(v1 - v0, v2 - v0, dim=-1).norm(dim=-1) / 2
    area2 = torch.cross(v3 - v0, v2 - v0, dim=-1).norm(dim=-1) / 2
    if (area1 < area_eps).any().item() or (area2 < area_eps).any().item():
        msg = "Planes have zero areas"
        raise ValueError(msg)
    # We can write:  `box_ctr = plane_ctr + a * e0 + b * e1 + c * n`, (1),
    # where e0 and e1 span the plane, so <e0, n> = 0 and <e1, n> = 0
    # (<.,.> denotes the dot product).
    """
    # Below is how one would solve for (a, b, c)
    # Solving for (a, b)
    numF = verts.shape[0]
    A = torch.ones((numF, 2, 2), dtype=torch.float32, device=device)
    B = torch.ones((numF, 2), dtype=torch.float32, device=device)
    A[:, 0, 1] = (e0 * e1).sum(-1)
    A[:, 1, 0] = (e0 * e1).sum(-1)
    B[:, 0] = ((box_ctr - plane_ctr) * e0).sum(-1)
    B[:, 1] = ((box_ctr - plane_ctr) * e1).sum(-1)
    ab = torch.linalg.solve(A, B)  # (numF, 2)
    a, b = ab.unbind(1)
    # solving for c
    c = ((box_ctr - plane_ctr - a.view(numF, 1) * e0 - b.view(numF, 1) * e1) * n).sum(-1)
    """
    # Since <e0, n> = 0 and <e1, n> = 0 (e0 and e1 are orthogonal to n),
    # taking the dot product of (1) with n gives c = <box_ctr - plane_ctr, n>.
    # Only the sign of c matters here, so normalizing the direction is fine.
    direc = F.normalize(box_ctr - plane_ctr, dim=-1)  # (B, 6, 3)
    c = (direc * n).sum(-1)
    # If c is negative, flip n so that it points towards the box interior.
    negc = c < 0.0
    n[negc] *= -1.0
    return n


def get_plane_verts(box: torch.Tensor) -> torch.Tensor:
    """
    Return the vertex coordinates forming the planes of the box.
    The computation here resembles the Meshes data structure.
    But since we only want this tiny functionality, we abstract it out.
    Args:
        box: tensor of shape (B, 8, 3)
    Returns:
        plane_verts: tensor of shape (B, 6, 4, 3)
    """
    device = box.device
    faces = torch.tensor(_box_planes, device=device, dtype=torch.int64)  # (6, 4)
    # index all boxes at once: (B, 8, 3) -> (B, 6, 4, 3)
    plane_verts = box[:, faces]
    return plane_verts


def is_inside(
    points: torch.Tensor,
    plane: torch.Tensor,
    normal: torch.Tensor,
    return_proj: bool = True,
):
    """
    Computes whether point is "inside" the plane.
    The definition of "inside" means that the point
    has a positive component in the direction of the plane normal defined by n.
    For example,
                  plane
                    |
                    |         . (A)
                    |--> n
                    |
         .(B)       |

    Point (A) is "inside" the plane, while point (B) is "outside" the plane.
    Args:
      points: tensor of shape (B1, P, 3) of coordinates of a point
      plane: tensor of shape (B2, 4, 3) of vertices of a box plane
      normal: tensor of shape (B2, 3) of the unit "inside" direction on the plane
      return_proj: bool whether to return the projected point on the plane (currently unused)
    Returns:
      is_inside: bool of shape (B1, B2, P) of whether point is inside
    """
    device = plane.device
    assert plane.ndim == 3
    assert normal.ndim == 2
    assert points.ndim == 3
    assert points.shape[2] == 3
    B1 = points.shape[0]
    B2 = plane.shape[0]
    P = points.shape[1]
    v0, v1, v2, v3 = plane.unbind(dim=1)
    plane_ctr = plane.mean(dim=1)
    e0 = F.normalize(v0 - plane_ctr, dim=1)
    e1 = F.normalize(v1 - plane_ctr, dim=1)

    dot1 = (e0.unsqueeze(1) @ normal.unsqueeze(2)).reshape(B2)
    if not torch.allclose(dot1, torch.zeros((B2,), device=device), atol=1e-2):
        raise ValueError("Input n is not perpendicular to the plane")
    dot2 = (e1.unsqueeze(1) @ normal.unsqueeze(2)).reshape(B2)
    if not torch.allclose(dot2, torch.zeros((B2,), device=device), atol=1e-2):
        raise ValueError("Input n is not perpendicular to the plane")

    # Every point p can be written as p = ctr + a e0 + b e1 + c n.
    # Since e0 and e1 lie in the plane (orthogonal to n), taking the dot
    # product with n gives c = (p - ctr).dot(n); p is inside iff c >= 0.
    pts = points.view(B1, 1, P, 3)
    ctr = plane_ctr.view(1, B2, 1, 3)
    e0 = e0.view(1, B2, 1, 3)
    e1 = e1.view(1, B2, 1, 3)
    normal = normal.view(1, B2, 1, 3)

    direc = torch.sum((pts - ctr) * normal, dim=-1)
    ins = direc >= 0.0
    return ins


def get_plane_center_normal(
    planes: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Returns the center and unit normal of planes
    Args:
        planes: tensor of shape (B, P, 4, 3) or (B, 4, 3)
    Returns:
        center: tensor of shape (B, P, 3) (or (B, 3))
        normal: tensor of shape (B, P, 3) (or (B, 3))
    """
    B = planes.shape[0]

    add_dim1 = False
    if planes.ndim == 3:
        planes = planes.unsqueeze(1)
        add_dim1 = True

    ctr = planes.mean(dim=2)  # (B, P, 3)
    normals = torch.zeros_like(ctr)

    v0, v1, v2, v3 = planes.unbind(dim=2)  # 4 x (B, P, 3)

    P = planes.shape[1]
    for t in range(P):
        # cross products of all 6 pairs of (vertex - center) vectors
        ns = torch.zeros((B, 6, 3), device=planes.device, dtype=planes.dtype)
        ns[:, 0] = torch.cross(v0[:, t] - ctr[:, t], v1[:, t] - ctr[:, t], dim=-1)
        ns[:, 1] = torch.cross(v0[:, t] - ctr[:, t], v2[:, t] - ctr[:, t], dim=-1)
        ns[:, 2] = torch.cross(v0[:, t] - ctr[:, t], v3[:, t] - ctr[:, t], dim=-1)
        ns[:, 3] = torch.cross(v1[:, t] - ctr[:, t], v2[:, t] - ctr[:, t], dim=-1)
        ns[:, 4] = torch.cross(v1[:, t] - ctr[:, t], v3[:, t] - ctr[:, t], dim=-1)
        ns[:, 5] = torch.cross(v2[:, t] - ctr[:, t], v3[:, t] - ctr[:, t], dim=-1)
        # keep the cross product with the largest norm; pairs of opposite
        # (collinear) vertices would give near-zero, unstable normals
        ii = torch.argmax(torch.norm(ns, dim=-1), dim=-1)
        normals[:, t] = ns[torch.arange(B), ii]

    if add_dim1:
        ctr = ctr[:, 0]
        normals = normals[:, 0]
    normals = F.normalize(normals, dim=-1)
    return ctr, normals
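

# Sketch (hypothetical helper, not part of efm3d): the normal estimate of
# `get_plane_center_normal` for a single face given as 4 vertices. The center is
# the mean of the vertices; among the cross products of all six pairs of
# (vertex - center) vectors, the one with the largest norm is kept and
# normalized, which avoids the near-zero products of opposite (collinear)
# vertices.
def _sketch_plane_center_normal(verts):
    import itertools

    ctr = tuple(sum(v[i] for v in verts) / len(verts) for i in range(3))
    rel = [tuple(v[i] - ctr[i] for i in range(3)) for v in verts]

    def cross(a, b):
        return (
            a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0],
        )

    def norm(a):
        return sum(x * x for x in a) ** 0.5

    # pairs in the same order as ns[:, 0..5] above; max() keeps the first maximum
    best = max((cross(a, b) for a, b in itertools.combinations(rel, 2)), key=norm)
    length = norm(best)
    return ctr, tuple(x / length for x in best)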


================================================
FILE: efm3d/aria/pose.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import math
from typing import Dict, List, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F

from .tensor_wrapper import autocast, autoinit, smart_stack, TensorWrapper

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# logger.setLevel(logging.DEBUG)

IdentityPose = torch.tensor(
    [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
).reshape(12)

PAD_VAL = -1


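def get_T_rot_z(angle: float):
    """
    return a 3x4 [R | t] transform that rotates by `angle` (radians) about the
    z axis, with zero translation.
    """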
    T_rot_z = np.array(
        [
            [np.cos(angle), -np.sin(angle), 0.0, 0.0],
            [np.sin(angle), np.cos(angle), 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
        ]
    )
    return torch.from_numpy(T_rot_z).float()


def skew_symmetric(v):
    """Create a skew-symmetric matrix from a (batched) vector of size (..., 3)."""
    z = torch.zeros_like(v[..., 0])
    M = torch.stack(
        [
            z,
            -v[..., 2],
            v[..., 1],
            v[..., 2],
            z,
            -v[..., 0],
            -v[..., 1],
            v[..., 0],
            z,
        ],
        dim=-1,
    ).reshape(v.shape[:-1] + (3, 3))
    return M


def inv_skew_symmetric(V):
    """Create a (batched) vector from a skew-symmetric matrix of size (..., 3, 3)."""
    # average the lower and upper triangular entries in case the skew-symmetric
    # matrix has numeric errors.
    VVT = 0.5 * (V - V.transpose(-2, -1))
    return torch.stack(
        [
            -VVT[..., 1, 2],
            VVT[..., 0, 2],
            -VVT[..., 0, 1],
        ],
        -1,
    )
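

# Sketch (hypothetical helpers, not part of this module): the defining property
# of `skew_symmetric`, written with plain lists: [v]_x u equals the cross
# product v x u. `inv_skew_symmetric` reads v back from the off-diagonal
# entries (2, 1), (0, 2) and (1, 0) of this matrix.
def _sketch_skew(v):
    return [
        [0.0, -v[2], v[1]],
        [v[2], 0.0, -v[0]],
        [-v[1], v[0], 0.0],
    ]


def _sketch_matvec(M, u):
    return [sum(M[r][c] * u[c] for c in range(3)) for r in range(3)]


# Usage: _sketch_matvec(_sketch_skew((1, 2, 3)), (4, 5, 6)) gives the cross
# product (1, 2, 3) x (4, 5, 6) = (-3, 6, -3).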


def so3exp_map(w, eps: float = 1e-7):
    """Compute rotation matrices from batched twists.
    Args:
        w: batched 3D axis-angle vectors of size (..., 3).
        eps: below this angle the first-order Taylor approximation I + [w]_x is used.
    Returns:
        A batch of rotation matrices of size (..., 3, 3).
    """
    theta = w.norm(p=2, dim=-1, keepdim=True)
    small = theta < eps
    div = torch.where(small, torch.ones_like(theta), theta)
    W = skew_symmetric(w / div)
    theta = theta[..., None]  # ... x 1 x 1
    res = W * torch.sin(theta) + (W @ W) * (1 - torch.cos(theta))
    res = torch.where(small[..., None], W, res)  # first-order Taylor approx
    return torch.eye(3).to(W) + res
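

# Sketch (hypothetical helper, not part of this module): the Rodrigues formula
# that `so3exp_map` implements, for a single axis-angle vector with plain lists.
# With theta = |w| and K = [w / theta]_x,
# R = I + sin(theta) K + (1 - cos(theta)) K^2; for theta below `eps` the
# first-order approximation R ~= I + [w]_x is used, as above.
import math


def _sketch_so3exp(w, eps=1e-7):
    theta = math.sqrt(sum(x * x for x in w))
    small = theta < eps
    k = list(w) if small else [x / theta for x in w]
    K = [[0.0, -k[2], k[1]], [k[2], 0.0, -k[0]], [-k[1], k[0], 0.0]]
    KK = [[sum(K[r][i] * K[i][c] for i in range(3)) for c in range(3)] for r in range(3)]
    if small:
        a, b = 1.0, 0.0
    else:
        a, b = math.sin(theta), 1.0 - math.cos(theta)
    return [
        [(1.0 if r == c else 0.0) + a * K[r][c] + b * KK[r][c] for c in range(3)]
        for r in range(3)
    ]


# Usage: a rotation of pi / 2 about z maps the x axis onto the y axis, i.e. the
# first column of _sketch_so3exp((0.0, 0.0, math.pi / 2)) is close to (0, 1, 0).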


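def so3log_map(R, eps: float = 1e-7):
    """Compute batched axis-angle vectors (..., 3) from rotation matrices (..., 3, 3).
    This is the inverse of so3exp_map for rotation angles in [0, pi).
    """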
    trace = torch.diagonal(R, dim1=-1, dim2=-2).sum(-1)
    cos = torch.clamp((trace - 1.0) * 0.5, -1, 1)
    theta = torch.acos(cos).unsqueeze(-1).unsqueeze(-1)
    ones = torch.ones_like(theta)
    small = theta < eps
    # compute factors and approximate them around 0 using second order
    # taylor expansion (from WolframAlpha)
    theta_over_sin_theta = torch.where(
        small,
        ones - (theta**2) / 6.0 + 7.0 * (theta**4) / 360.0,
        theta / torch.sin(theta),
    )
    # compute log-map W of rotation R first
    W = 0.5 * theta_over_sin_theta * (R - R.transpose(-1, -2))
    omega = inv_skew_symmetric(W)
    return omega
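

# Illustrative sketch (hypothetical helper, not part of the efm3d API): a
# single-matrix NumPy version of the log map in `so3log_map`. It recovers
# theta = arccos((trace(R) - 1) / 2) and reads the axis-angle vector off the
# skew part, (R - R^T) / 2 = sin(theta) [axis]_x. Like the torch version it
# becomes ill-conditioned near theta = pi, where sin(theta) -> 0.
import numpy as np


def _sketch_so3_log_np(R, eps=1e-7):
    R = np.asarray(R, dtype=np.float64)
    cos = np.clip((np.trace(R) - 1.0) * 0.5, -1.0, 1.0)
    theta = np.arccos(cos)
    # theta / sin(theta) ~= 1 + theta^2 / 6 for small angles
    scale = 1.0 + theta**2 / 6.0 if theta < eps else theta / np.sin(theta)
    W = 0.5 * scale * (R - R.T)
    # read (x, y, z) from the skew-symmetric entries of W
    return np.array([W[2, 1], W[0, 2], W[1, 0]])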


def interpolation_boundaries_alphas(times: torch.Tensor, interp_times: torch.Tensor):
    """
    find the ids in times tensor that bound each of the interp_times timestamps
    from below (lower_ids) and above (upper_ids).
    If interp_times are outside the interval spanned by times, upper and lower
    ids will both point to the boundary timestamps and the returned good boolean
    tensor will be False at those interpolation timestamps.

    Also return the alphas needed to interpolate a value as:
    interp_value = alpha * value[lower_id] + (1-alpha)* value[upper_id]

    Note that because the upper and lower ids are pointing to the boundary
    timestamps when the interpolation time is outside the time interval,
    applying the formula above will yield the values at the boundaries as a
    reasonable "interpolation". No extrapolation will be performed. Again the
    good values can be used to check which values are at the boundaries and
    which ones are interpolated.
    """
    times = times.unsqueeze(-2)
    interp_times = interp_times.unsqueeze(-1)
    dt = times - interp_times
    if dt.dtype == torch.long:
        dt_max = torch.iinfo(type=dt.dtype).max
    else:
        dt_max = torch.finfo(type=dt.dtype).max
    dt_upper = torch.where(dt < 0.0, torch.ones_like(dt) * dt_max, dt)
    dt_lower = torch.where(dt > 0.0, torch.ones_like(dt) * dt_max, -dt)
    upper_alpha, upper_ids = torch.min(dt_upper, dim=-1)
    lower_alpha, lower_ids = torch.min(dt_lower, dim=-1)
    good = torch.logical_and(lower_alpha < dt_max, upper_alpha < dt_max)
    upper_ids = torch.where(good, upper_ids, torch.maximum(lower_ids, upper_ids))
    lower_ids = torch.where(good, lower_ids, torch.minimum(lower_ids, upper_ids))
    assert (lower_ids <= upper_ids).all()
    # nan_to_num handles the case where time and interpolation time are the same
    # and hence this is a 0/0
    # okay to go to floats now since the critical bit is the computation of the time difference
    alpha = torch.nan_to_num(lower_alpha.float() / (lower_alpha + upper_alpha).float())
    alpha = torch.where(good, alpha, torch.zeros_like(alpha))
    return lower_ids, upper_ids, alpha, good
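

# Illustrative sketch (hypothetical helper, not part of the efm3d API): for a
# single sorted 1-D time axis and a scalar query, the bracketing computed above
# can be reproduced with np.searchsorted. alpha is the fraction of the way from
# times[lower] to times[upper], so (1 - alpha) * v[lower] + alpha * v[upper]
# interpolates linearly. Queries outside the range are clamped to the boundary
# with alpha = 0 and good = False, mirroring the torch version.
import numpy as np


def _sketch_bracket_1d(times, t):
    times = np.asarray(times, dtype=np.float64)
    if t < times[0] or t > times[-1]:
        idx = 0 if t < times[0] else len(times) - 1
        return idx, idx, 0.0, False
    upper = int(np.searchsorted(times, t, side="left"))  # first time >= t
    lower = upper if times[upper] == t else upper - 1  # last time <= t
    span = times[upper] - times[lower]
    alpha = 0.0 if span == 0 else (t - times[lower]) / span
    return lower, upper, alpha, True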


def quaternion_to_matrix(quaternions_wxyz: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as quaternions to rotation matrices. Input quaternions
    should be in wxyz format, with real part first, imaginary part last.

    The function is copied from `quaternion_to_matrix` in Pytorch3d:
    https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py

    Args:
        quaternions_wxyz: quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    r, i, j, k = torch.unbind(quaternions_wxyz, -1)
    two_s = 2.0 / (quaternions_wxyz * quaternions_wxyz).sum(-1)

    o = torch.stack(
        (
            1 - two_s * (j * j + k * k),
            two_s * (i * j - k * r),
            two_s * (i * k + j * r),
            two_s * (i * j + k * r),
            1 - two_s * (i * i + k * k),
            two_s * (j * k - i * r),
            two_s * (i * k - j * r),
            two_s * (j * k + i * r),
            1 - two_s * (i * i + j * j),
        ),
        -1,
    )
    return o.reshape(quaternions_wxyz.shape[:-1] + (3, 3))
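

# Illustrative sketch (hypothetical helper, not part of the efm3d API): the
# translational part of the SE(3) exponential computed by PoseTW.exp is
# t = V u with V = I + b W + c W @ W, where W = [omega]_x,
# b = (1 - cos theta) / theta^2 and c = (theta - sin theta) / theta^3.
# For a pure translation V = I, so t = u. For u = (1, 0, 0) and
# omega = (0, 0, pi), t = (0, 2/pi, 0): the far end of a half-circle screw
# motion. Only t is returned; small angles use the limits b -> 1/2, c -> 1/6.
import numpy as np


def _sketch_se3_exp_translation_np(u_omega, eps=1e-7):
    u_omega = np.asarray(u_omega, dtype=np.float64)
    u, omega = u_omega[:3], u_omega[3:]  # translation first, rotation last
    theta = np.linalg.norm(omega)
    if theta < eps:
        b, c = 0.5, 1.0 / 6.0
    else:
        b = (1.0 - np.cos(theta)) / theta**2
        c = (theta - np.sin(theta)) / theta**3
    x, y, z = omega
    W = np.array([[0.0, -z, y], [z, 0.0, -x], [-y, x, 0.0]])
    V = np.eye(3) + b * W + c * (W @ W)
    return V @ u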


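class PoseTW(TensorWrapper):
    """SE(3) pose stored as a (..., 12) tensor: the row-major flattened 3x3
    rotation matrix followed by the 3-vector translation."""
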
    @autocast
    @autoinit
    def __init__(self, data: torch.Tensor = IdentityPose):
        assert isinstance(data, torch.Tensor)
        assert data.shape[-1] == 12
        super().__init__(data)

    @classmethod
    @autocast
    def from_Rt(cls, R: torch.Tensor, t: torch.Tensor):
        """Pose from a rotation matrix and translation vector.
        Accepts numpy arrays or PyTorch tensors.

        Args:
            R: rotation matrix with shape (..., 3, 3).
            t: translation vector with shape (..., 3).
        """
        assert R.shape[-2:] == (3, 3)
        assert t.shape[-1] == 3
        assert R.shape[:-2] == t.shape[:-1]
        data = torch.cat([R.flatten(start_dim=-2), t], -1)
        return cls(data)

    @classmethod
    @autocast
    def from_qt(cls, quaternion_wxyz: torch.Tensor, t: torch.Tensor):
        """Pose from quaternion and translation vectors. Quaternion should
        be wxyz format, with real part first, and imaginary part last.

        Args:
            quaternion_wxyz: quaternion with shape (..., 4).
            t: translation vector with shape (..., 3).
        """
        assert quaternion_wxyz.shape[:-1] == t.shape[:-1], (
            f"quaternion shape {quaternion_wxyz.shape[:-1]} must match translation shape {t.shape[:-1]} except the last dim"
        )
        assert quaternion_wxyz.shape[-1] == 4, "quaternion must be of shape (..., 4)"
        assert t.shape[-1] == 3, "translation must be of shape (..., 3)"

        R = quaternion_to_matrix(quaternion_wxyz)
        data = torch.cat([R.flatten(start_dim=-2), t], -1)
        return cls(data)

    @classmethod
    @autocast
    def from_aa(cls, aa: torch.Tensor, t: torch.Tensor):
        """Pose from an axis-angle rotation vector and translation vector.
        Accepts numpy arrays or PyTorch tensors.

        Args:
            aa: axis-angle rotation vector with shape (..., 3).
            t: translation vector with shape (..., 3).
        """
        assert aa.shape[-1] == 3
        assert t.shape[-1] == 3
        assert aa.shape[:-1] == t.shape[:-1]
        return cls.from_Rt(so3exp_map(aa), t)

    @classmethod
    @autocast
    def from_matrix(cls, T: torch.Tensor):
        """Pose from an SE(3) transformation matrix.
        Args:
            T: transformation matrix with shape (..., 4, 4).
        """
        assert T.shape[-2:] == (4, 4)
        R, t = T[..., :3, :3], T[..., :3, 3]
        return cls.from_Rt(R, t)

    @classmethod
    @autocast
    def from_matrix3x4(cls, T_3x4: torch.Tensor):
        """Pose from an SE(3) transformation matrix.
        Args:
            T: transformation matrix with shape (..., 3, 4).
        """
        assert T_3x4.shape[-2:] == (3, 4)
        R, t = T_3x4[..., :3, :3], T_3x4[..., :3, 3]
        return cls.from_Rt(R, t)

    @classmethod
    @autocast
    def exp(cls, u_omega: torch.Tensor, eps: float = 1e-7):
        """
        Compute the SE3 exponential map from input se3 vectors u_omega [....,6] where
        the last 3 entries are the so3 entires omega and the first 3 the entries
        for translation.
        """
        # following https://www.ethaneade.com/lie.pdf and http://people.csail.mit.edu/jstraub/download/straubTransformationCookbook.pdf
        u = u_omega[..., :3]
        omega = u_omega[..., 3:]
        theta = omega.norm(p=2, dim=-1, keepdim=True).unsqueeze(-1)
        small = theta < eps
        R = so3exp_map(omega, eps)
        # compute V
        shape = [1] * len(omega.shape[:-1])
        ones = torch.ones_like(theta)
        # compute factors and approximate them around 0 using Taylor expansions
        # up to fourth order in theta
        b = torch.where(
            small,
            0.5 * ones - theta**2 / 24.0 + theta**4 / 720.0,
            (ones - torch.cos(theta)) / theta**2,
        )
        c = torch.where(
            small,
            1.0 / 6.0 * ones - theta**2 / 120.0 + theta**4 / 5040.0,
            (theta - torch.sin(theta)) / theta**3,
        )
        Identity = (
            torch.eye(3).reshape(shape + [3, 3]).repeat(shape + [1, 1]).to(u_omega)
        )
        W = skew_symmetric(omega)
        V = Identity + b * W + c * W @ W
        # compute t
        t = (V @ u.unsqueeze(-1)).squeeze(-1)
        return cls.from_Rt(R, t)

    # @classmethod
    # def from_colmap(cls, image: NamedTuple):
    #    '''Pose from a COLMAP Image.'''
    #    return cls.from_Rt(image.qvec2rotmat(), image.tvec)

    @property
    def R(self) -> torch.Tensor:
        """Underlying rotation matrix with shape (..., 3, 3)."""
        rvec = self._data[..., :9]
        return rvec.reshape(rvec.shape[:-1] + (3, 3))

    @property
    def t(self) -> torch.Tensor:
        """Underlying translation vector with shape (..., 3)."""
        return self._data[..., -3:]

    @property
    def q(self) -> torch.Tensor:
        """
        Convert rotations of shape (..., 3, 3) to a quaternion (..., 4).
        The returned quaternions have real part first, as wxyz.
        The function is adapted from `matrix_to_quaternion` in Pytorch3d:
        https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py

        The major difference to the original PyTorch3D function is that the returned
        quaternions are normalized and have a non-negative real part.
        """

        def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
            """
            Returns torch.sqrt(torch.max(0, x))
            but with a zero subgradient where x is 0.
            """
            ret = torch.zeros_like(x)
            positive_mask = x > 0
            ret[positive_mask] = torch.sqrt(x[positive_mask])
            return ret

        matrix = self.R
        batch_dim = matrix.shape[:-2]
        m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
            matrix.reshape(batch_dim + (9,)), dim=-1
        )

        q_abs = _sqrt_positive_part(
            torch.stack(
                [
                    1.0 + m00 + m11 + m22,
                    1.0 + m00 - m11 - m22,
                    1.0 - m00 + m11 - m22,
                    1.0 - m00 - m11 + m22,
                ],
                dim=-1,
            )
        )

        # we produce the desired quaternion multiplied by each of r, i, j, k
        quat_by_wxyz = torch.stack(
            [
                torch.stack(
                    [q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1
                ),
                torch.stack(
                    [m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1
                ),
                torch.stack(
                    [m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1
                ),
                torch.stack(
                    [m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1
                ),
            ],
            dim=-2,
        )

        # We floor here at 0.1 but the exact level is not important; if q_abs is small,
        # the candidate won't be picked.
        flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
        quat_candidates = quat_by_wxyz / (2.0 * q_abs[..., None].max(flr))

        # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
        # forall i; we pick the best-conditioned one (with the largest denominator)
        best_quat = quat_candidates[
            F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
        ].reshape(batch_dim + (4,))

        # normalize quaternions and make the real part to be positive for all quaternions
        best_quat = best_quat.reshape(-1, 4)
        neg_ind = torch.nonzero(best_quat[:, 0] < 0).squeeze()
        best_quat[neg_ind, :] *= -1
        best_quat = best_quat.reshape(batch_dim + (4,))
        best_quat_normalized = F.normalize(best_quat, p=2, dim=-1)
        return best_quat_normalized

    @property
    def q_xyzw(self) -> torch.Tensor:
        """
        Get the quaternion representation similar to self.q, but the real part
        of the quaternion comes last rather than first. This is a handy function to increase
        interoperability, e.g. lietorch requires xyzw quaternions.
        """
        quat_wxyz = self.q
        return torch.concat([quat_wxyz[..., 1:4], quat_wxyz[..., 0:1]], dim=-1)

    @property
    def matrix3x4(self) -> torch.Tensor:
        """Underlying transformation matrix with shape (..., 3, 4)."""
        rvec = self._data[..., :9]
        rmat = rvec.reshape(rvec.shape[:-1] + (3, 3))
        tvec = self._data[..., -3:].unsqueeze(-1)
        T = torch.cat([rmat, tvec], dim=-1)
        return T

    @property
    def matrix(self) -> torch.Tensor:
        """Underlying transformation matrix with shape (..., 4, 4)."""
        T_3x4 = self.matrix3x4
        bot_row = T_3x4.new_zeros(T_3x4.shape[:-2] + (1, 4))
        bot_row[..., 0, 3] = 1
        return torch.cat([T_3x4, bot_row], dim=-2)

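    def to_euler(self, rad=True) -> torch.Tensor:
        """Convert the rotation matrix to Euler angles using the ZYX convention,
        R = Rz(yaw) @ Ry(pitch) @ Rx(roll). Returns a (..., 3) tensor ordered as
        (roll, pitch, yaw), in radians, or in degrees if rad=False.

        Reference: http://eecs.qmul.ac.uk/~gslabaugh/publications/euler.pdf
        """
        # Assert no gimbal lock (|R[2, 0]| != 1), ignoring padded rotations whose
        # entries are all PAD_VAL.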
        is_pad = torch.all(torch.all(self.R == PAD_VAL, dim=-1), dim=-1)
        assert (~torch.abs(self.R[~is_pad][..., 2, 0]).isclose(torch.tensor(1.0))).all()
        Y_angle = -torch.asin(self.R[..., 2, 0])
        euler_angles = (
            torch.atan2(self.R[..., 2, 1], self.R[..., 2, 2]),
            Y_angle,
            torch.atan2(self.R[..., 1, 0], self.R[..., 0, 0]),
        )
        if not rad:
            # return degree
            return torch.stack(euler_angles, -1) * 180.0 / torch.pi
        return torch.stack(euler_angles, -1)

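    def to_ypr(self, rad=True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (yaw, pitch, roll) of the rotation, R = Rz(yaw) @ Ry(pitch) @ Rx(roll),
        in radians, or in degrees if rad=False."""
        # yaw, pitch, roll from rotation matrix: http://lavalle.pl/planning/node103.html
        # atan2 (rather than atan) keeps yaw and roll in the correct quadrant.
        R = self.R
        yaw = torch.atan2(R[..., 1, 0], R[..., 0, 0])
        pitch = torch.atan2(
            -R[..., 2, 0],
            torch.sqrt(R[..., 2, 1] * R[..., 2, 1] + R[..., 2, 2] * R[..., 2, 2]),
        )
        roll = torch.atan2(R[..., 2, 1], R[..., 2, 2])
        if not rad:
            yaw, pitch, roll = (a * 180.0 / math.pi for a in (yaw, pitch, roll))
        return yaw, pitch, roll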

    def inverse(self) -> "PoseTW":
        """Invert an SE(3) pose."""
        R = self.R.transpose(-1, -2)
        t = -(R @ self.t.unsqueeze(-1)).squeeze(-1)
        return self.__class__.from_Rt(R, t)

    def compose(self, other: "PoseTW") -> "PoseTW":
        """Chain two SE(3) poses: T_C_B.compose(T_B_A) -> T_C_A."""
        R = self.R @ other.R
        t = self.t + (self.R @ other.t.unsqueeze(-1)).squeeze(-1)
        return self.__class__.from_Rt(R, t)

    @autocast
    def transform(self, p3d: torch.Tensor) -> torch.Tensor:
        """Transform a set of 3D points.
        Args:
            p3d: 3D points, numpy array or PyTorch tensor with shape (..., 3).
        """
        assert p3d.shape[-1] == 3
        # use more efficient right multiply that avoids transpose of the points
        # according to the equality:
        # (Rp + t)^T = (Rp)^T + t^T = p^T R^T + t^T
        # where p^T = p3d, R = self.R and t = self.t
        return p3d @ self.R.transpose(-1, -2) + self.t.unsqueeze(-2)

    @autocast
    def batch_transform(self, p3d: torch.Tensor) -> torch.Tensor:
        """Transform a set of 3D points each by the associated (in batch
        dimensions) transform in this PoseTW.
        Args:
            p3d: 3D points, numpy array or PyTorch tensor with shape (..., 3).
        """
        assert p3d.shape == self.t.shape, f"shapes of p3d {p3d.shape}, t {self.t.shape}"
        # bmm assumes one batch dimension
        assert p3d.dim() == 2, f"{p3d.shape}"
        assert self.ndim == 2, f"{self.shape}"
        # use more efficient right multiply that avoids transpose of the points
        # according to the equality:
        # (Rp + t)^T = (Rp)^T + t^T = p^T R^T + t^T
        # where p^T = p3d, R = self.R and t = self.t
        return (
            torch.bmm(p3d.unsqueeze(-2), self.R.transpose(-1, -2)).squeeze(-2) + self.t
        )

    @autocast
    def rotate(self, p3d: torch.Tensor) -> torch.Tensor:
        """Rotate a set of 3D points. Useful for directional vectors which should not be translated.
        Args:
            p3d: 3D points, numpy array or PyTorch tensor with shape (..., 3).
        """
        assert p3d.shape[-1] == 3
        # use more efficient right multiply that avoids transpose of the points
        # according to the equality:
        # (Rp)^T = p^T R^T where p3d = p^T and self.R = R
        return p3d @ self.R.transpose(-1, -2)

    def __mul__(self, p3D: torch.Tensor) -> torch.Tensor:
        """Transform a set of 3D points: T_B_A * p3D_A -> p3D_B"""
        return self.transform(p3D)

    def __matmul__(self, other: "PoseTW") -> "PoseTW":
        """Chain two SE(3) poses: T_C_B @ T_B_A -> T_C_A."""
        return self.compose(other)

    def numpy(self) -> Tuple[np.ndarray, np.ndarray]:
        return self.R.numpy(), self.t.numpy()

    def magnitude(self, deg=True, eps=0) -> Tuple[torch.Tensor, torch.Tensor]:
        """Magnitude of the SE(3) transformation. Set `eps` > 0 when using this
        function in a training loop: it keeps the cosine strictly inside (-1, 1),
        where the gradient of acos is finite.

        Returns:
            dr: rotation angle in degrees (if deg=True) or in radians.
            dt: translation distance in meters.
        """
        trace = torch.diagonal(self.R, dim1=-1, dim2=-2).sum(-1)
        cos = torch.clamp((trace - 1) / 2, min=-1.0 + eps, max=1.0 - eps)
        dr = torch.acos(cos)
        if deg:
            dr = dr * 180.0 / math.pi
        dt = torch.norm(self.t, dim=-1)
        return dr, dt

    def so3_geodesic(self, other: "PoseTW", deg=False) -> torch.Tensor:
        """Compute the geodesic distance (rotation angle) between the rotations of this pose and another pose."""
        pose_e = self.compose(other.inverse())
        dr, _ = pose_e.magnitude(deg=deg, eps=1e-6)
        return dr

    def log(self, eps: float = 1e-6) -> torch.Tensor:
        """
        Compute the SE3 log map for these poses.
        Returns [...,6] where the last 3 entries are the so3 entires omega and
        the first 3 the entries for translation.
        """
        # following https://www.ethaneade.com/lie.pdf and http://people.csail.mit.edu/jstraub/download/straubTransformationCookbook.pdf
        R, t = self.R, self.t
        trace = torch.diagonal(self.R, dim1=-1, dim2=-2).sum(-1)
        cos = torch.clamp((trace - 1.0) * 0.5, -1, 1)
        theta = torch.acos(cos).unsqueeze(-1).unsqueeze(-1)
        ones = torch.ones_like(theta)
        small = theta < eps
        # compute factors and approximate them around 0 using second order
        # taylor expansion (from WolframAlpha)
        theta_over_sin_theta = torch.where(
            small,
            ones - (theta**2) / 6.0 + 7.0 * (theta**4) / 360.0,
            theta / torch.sin(theta),
        )
        c = torch.where(
            small,
            0.08333333 + 0.001388889 * theta**2 + 0.0000330688 * theta**4,
            (ones - ((0.5 * theta * torch.sin(theta)) / (ones - torch.cos(theta))))
            / theta**2,
        )
        # compute log-map W of rotation R first
        W = 0.5 * theta_over_sin_theta * (R - R.transpose(-1, -2))
        # compute V_inv to be able to get u
        shape = [1] * len(R.shape[:-2])
        Identity = (
            torch.eye(3).reshape(shape + [3, 3]).repeat(shape + [1, 1]).to(self._data)
        )
        V_inv = Identity - 0.5 * W + c * W @ W
        u = (V_inv @ t.unsqueeze(-1)).squeeze(-1)
        omega = inv_skew_symmetric(W)
        return torch.cat([u, omega], -1)

    def interpolate(self, times: torch.Tensor, interp_times: torch.Tensor):
        """
        Return poses at the given interpolation times interp_times based on the
        poses in this object and the provided associated timestamps times.

        If interpolation timestamps are outside the interval of times, the poses
        at the interval boundaries will be returned and the good boolean tensor
        will indicate those boundary values with a False.
        """
        assert times.shape == self._data.shape[:-1], (
            f"time stamps for the poses do not match poses shape {times.shape} vs {self._data.shape}"
        )

        assert times.dim() <= 2, (
            "The shape of the input times should be either BxT or T."
        )
        times = times.to(self.device)
        interp_times = interp_times.to(self.device)
        # find the closest timestamps above and below for each interp_times in times
        lower_ids, upper_ids, alpha, good = interpolation_boundaries_alphas(
            times, interp_times
        )
        # get the bounding poses
        upper_ids = upper_ids.unsqueeze(-1)
        upper_ids = upper_ids.expand(*upper_ids.shape[0:-1], self._data.shape[-1])
        lower_ids = lower_ids.unsqueeze(-1)
        lower_ids = lower_ids.expand(*lower_ids.shape[0:-1], self._data.shape[-1])
        T_upper = self.__class__(self._data.gather(times.dim() - 1, upper_ids))
        T_lower = self.__class__(self._data.gather(times.dim() - 1, lower_ids))
        # get se3 element connecting the lower and upper poses
        dT = T_lower.inverse() @ T_upper
        dx = dT.log()
        # interpolate on se3
        dT = self.exp(dx * alpha.unsqueeze(-1))
        return T_lower @ dT, good

    def align(self, other, self_times=None, other_times=None):
        """Align two trajectories using the method of Horn (closed-form).

        Input:
            other -- second PoseTW (Nx12) trajectory to align to

        Output:
            T_self_other -- relative SE3 transform (Nx12)
            trans_error -- translational error per point (Nx1)

        code inspired by: https://github.com/symao/vio_evaluation/blob/master/align.py#L6-L38
        """
        if self.t.ndim != 2:
            raise ValueError(
                "Only Nx12 Pose supported in alignment, given {self.shape}"
            )
        if other.t.ndim != 2:
            raise ValueError(
                "Only Nx12 Pose supported in alignment, given {other.shape}"
            )
        dtype = torch.promote_types(self.dtype, other.dtype)

        # Optionally interpolate other to match the size of self.
        if self.shape[0] != other.shape[0]:
            if self_times is None or other_times is None:
                raise ValueError(
                    f"Got different length PoseTW (self {self.shape} and other {other.shape}). Must provide timestamps to support interpolation"
                )
            # Do interpolation on temporal intersection.
            other, goods = other.interpolate(other_times, self_times)
            self2 = self.clone()[goods].to(dtype)
            other2 = other.clone()[goods].to(dtype)
        else:
            self2 = self.clone().to(dtype)
            other2 = other.clone().to(dtype)

        P = self2.t.transpose(0, 1)
        Q = other2.t.transpose(0, 1)

        if P.shape != Q.shape:
            raise ValueError("Matrices P and Q must be of the same dimensionality")

        centroids_P = torch.mean(P, dim=1)
        centroids_Q = torch.mean(Q, dim=1)
        ones = torch.ones(P.shape[1], dtype=dtype, device=P.device)
        A = P - torch.outer(centroids_P, ones)
        B = Q - torch.outer(centroids_Q, ones)
        C = A @ B.transpose(0, 1)
        # torch.linalg.svd returns (U, S, Vh) with Vh = V^T, so V here is already transposed
        U, S, V = torch.linalg.svd(C)
        R = V.transpose(0, 1) @ U.transpose(0, 1)
        L = torch.eye(3, dtype=dtype, device=P.device)
        if torch.linalg.det(R) < 0:
            L[2][2] *= -1

        R = V.transpose(0, 1) @ (L @ U.transpose(0, 1))
        t = (-R @ centroids_P) + centroids_Q
        T_self_other = PoseTW.from_Rt(R, t).inverse().to(dtype)

        other_aligned = T_self_other @ other2

        # per-pose translational error, averaged below over all aligned poses
        error = torch.linalg.norm(other_aligned.t - self2.t, dim=-1)
        mean_error = error.mean(dim=-1)

        return T_self_other, mean_error

    def fit_to_SO3(self):
        # Math used from quora post and this berkeley pdf.
        # https://qr.ae/pKQaG5
        # https://people.eecs.berkeley.edu/~wkahan/Math128/NearestQ.pdf
        assert self._data.ndim == 1
        Q = fit_to_SO3(self.R)
        return PoseTW.from_Rt(Q, self.t)

    def __repr__(self):
        return f"PoseTW: {self.shape} {self.dtype} {self.device}"


def interpolate_timed_poses(
    timed_poses: Dict[
        Union[float, int],
        Union[PoseTW, List[PoseTW], Dict[Union[float, int, str], PoseTW]],
    ],
    time: Union[float, int],
):
    """
    interpolate timed poses given as a dict[time:container[PoseTW]] to given
    time.  The poses container indexed by time can be given plain as poses, or a
    list or dict of poses.  If a list or dict of poses is given, the output will
    also be a list or dict of the interpolated poses. This allows batched
    interpolation.
    """
    ts_list = list(timed_poses.keys())
    ts = torch.from_numpy(np.array(ts_list))
    interp_time = torch.from_numpy(np.array([time]))
    lower_ids, upper_ids, _, _ = interpolation_boundaries_alphas(ts, interp_time)
    t_lower = ts_list[lower_ids[0]]
    t_upper = ts_list[upper_ids[0]]
    poses_lower, poses_upper = timed_poses[t_lower], timed_poses[t_upper]
    poses_interp = None
    times = torch.from_numpy(np.array([t_lower, t_upper])).float()
    if isinstance(poses_lower, PoseTW):
        poses = PoseTW(smart_stack([poses_lower, poses_upper]))
        if poses.dim() == 3:
            times = times.unsqueeze(-1).repeat(1, poses.shape[1])
        poses_interp = poses.interpolate(times, interp_time)[0].squeeze()
    elif isinstance(poses_lower, dict):
        keys_lower = set(poses_lower.keys())
        keys_upper = set(poses_upper.keys())
        keys = keys_lower & keys_upper
        poses_interp = {}
        for key in keys:
            poses = PoseTW(smart_stack([poses_lower[key], poses_upper[key]]))
            if poses.dim() == 3 and times.dim() == 1:
                times = times.unsqueeze(-1).repeat(1, poses.shape[1])
            poses_interp[key] = poses.interpolate(times, interp_time)[0].squeeze()
    elif isinstance(poses_lower, list):
        assert len(poses_lower) == len(poses_upper)
        poses_interp = []
        for i in range(len(poses_lower)):
            poses = PoseTW(smart_stack([poses_lower[i], poses_upper[i]]))
            if poses.dim() == 3 and times.dim() == 1:
                times = times.unsqueeze(-1).repeat(1, poses.shape[1])
            poses_interp.append(poses.interpolate(times, interp_time)[0].squeeze())
    return poses_interp


def lower_timed_poses(
    timed_poses: Dict[
        Union[float, int],
        Union[PoseTW, List[PoseTW], Dict[Union[float, int, str], PoseTW]],
    ],
    time: Union[float, int],
):
    """
    interpolate timed poses given as a dict[time:container[PoseTW]] to given
    time.  The poses container indexed by time can be given plain as poses, or a
    list or dict of poses.  If a list or dict of poses is given, the output will
    also be a list or dict of the interpolated poses. This allows batched
    interpolation.
    """
    ts_list = list(timed_poses.keys())
    ts = torch.from_numpy(np.array(ts_list))
    interp_time = torch.from_numpy(np.array([time]))
    lower_ids, _, alpha, good = interpolation_boundaries_alphas(ts, interp_time)
    t_lower = ts_list[lower_ids[0]]
    poses_lower = timed_poses[t_lower]
    return poses_lower, t_lower - time


def closest_timed_poses(
    timed_poses: Dict[
        Union[float, int],
        Union[PoseTW, List[PoseTW], Dict[Union[float, int, str], PoseTW]],
    ],
    time: Union[float, int],
):
    """
    interpolate timed poses given as a dict[time:container[PoseTW]] to given
    time.  The poses container indexed by time can be given plain as poses, or a
    list or dict of poses.  If a list or dict of poses is given, the output will
    also be a list or dict of the interpolated poses. This allows batched
    interpolation.
    """
    ts_list = list(timed_poses.keys())
    ts = torch.from_numpy(np.array(ts_list))
    interp_time = torch.from_numpy(np.array([time]))
    lower_ids, upper_ids, alpha, good = interpolation_boundaries_alphas(ts, interp_time)
    t_lower = ts_list[lower_ids[0]]
    t_upper = ts_list[upper_ids[0]]
    poses_lower, poses_upper = timed_poses[t_lower], timed_poses[t_upper]
    if time - t_lower < t_upper - time:
        return poses_lower, time - t_lower
    else:
        return poses_upper, t_upper - time


def all_rot90():
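    # construct all 24 rotations that map coordinate axes onto coordinate axes
    # (the rotation group of the cube), i.e. all compositions of 90 degree rotations
    dirs = torch.cat([torch.eye(3), -torch.eye(3)], dim=0)
    ids = torch.arange(0, 6).long()
    jds = torch.arange(0, 6).long()
    ids, jds = torch.meshgrid(ids, jds, indexing="ij")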
    ids, jds = ids.reshape(-1), jds.reshape(-1)
    a, b = dirs[ids, :], dirs[jds, :]
    c = torch.cross(a, b, -1)
    Rs = torch.cat([a.unsqueeze(2), b.unsqueeze(2), c.unsqueeze(2)], dim=2)
    # filter to valid rotations
    det = torch.linalg.det(Rs)
    Rs = Rs[det > 0.99]
    return Rs
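

# Illustrative sketch (hypothetical helper, not part of the efm3d API): the
# same 24 rotations can be enumerated as signed permutation matrices with
# determinant +1 (6 axis permutations x 8 sign patterns = 48 matrices, half of
# which are reflections with determinant -1).
import itertools

import numpy as np


def _sketch_all_rot90_np():
    rots = []
    for perm in itertools.permutations(range(3)):
        for signs in itertools.product((1.0, -1.0), repeat=3):
            M = np.zeros((3, 3))
            for row, (col, sgn) in enumerate(zip(perm, signs)):
                M[row, col] = sgn
            if np.linalg.det(M) > 0.5:  # keep proper rotations only
                rots.append(M)
    return np.stack(rots)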


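def find_r90(Ta, Tb, R90s):
    """Among the axis-aligned rotations R90s (e.g. from all_rot90()), find the
    one that, right-multiplied onto the rotation of Tb, brings it closest (in
    geodesic angle) to the rotation of Ta. Tb may be a single pose or a batch
    of N poses. Returns the re-oriented Tb and the chosen rotation(s)."""
    N = None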
    if Tb.ndim == 2:
        N = Tb.shape[0]
        # 24xNx3x3
        R90s = R90s.unsqueeze(1).repeat(1, N, 1, 1)
    Ra_inv, Rb = Ta.inverse().R.unsqueeze(0), Tb.R
    # 24x(Nx)3x3
    dR = Ra_inv @ Rb.unsqueeze(0) @ R90s
    w = so3log_map(dR)
    ang = torch.linalg.norm(w, 2, dim=-1)
    ang_min, id_min = torch.min(ang, dim=0)
    if N is None:
        R90min = R90s[id_min]
    else:
        R90min = R90s[id_min, torch.arange(N)]
    Rb = Rb @ R90min
    Tb = PoseTW.from_Rt(Rb, Tb.t)
    return Tb, R90min


def stereographic_unproject(a, axis=None):
    """
    Inverse of stereographic projection: https://en.wikipedia.org/wiki/Stereographic_projection
    This is from the paper "On the Continuity of Rotation Representations in Neural
    Networks" https://arxiv.org/pdf/1812.07035.pdf, equation [8,9],
    used in rotation_from_ortho_5d.
    """
    batch = a.shape[0]
    if axis is None:
        axis = a.shape[1]
    s2 = torch.pow(a, 2).sum(1)
    ans = torch.zeros(batch, a.shape[1] + 1).to(a)
    unproj = 2 * a / (s2 + 1).reshape(batch, 1).repeat(1, a.shape[1])
    if axis > 0:
        ans[:, :axis] = unproj[:, :axis]
    ans[:, axis] = (s2 - 1) / (s2 + 1)
    ans[:, axis + 1 :] = unproj[:, axis:]
    return ans


def rotation_from_ortho_6d(ortho6d):
    """
    Convert a 6-d rotation representation to rotation matrix

    From the paper "On the Continuity of Rotation Representations in Neural Networks"
    https://arxiv.org/pdf/1812.07035.pdf
    """
    x_raw = ortho6d[..., 0:3]
    y_raw = ortho6d[..., 3:6]

    x = F.normalize(x_raw, dim=-1, eps=1e-6)
    y = F.normalize(y_raw, dim=-1, eps=1e-6)

    z = torch.cross(x, y, -1)
    z = F.normalize(z, dim=-1, eps=1e-6)
    y = torch.cross(z, x, -1)

    x = x.reshape(-1, 3, 1)
    y = y.reshape(-1, 3, 1)
    z = z.reshape(-1, 3, 1)
    matrix = torch.cat((x, y, z), 2)
    return matrix
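

# Illustrative sketch (hypothetical helper, not part of the efm3d API): a
# single-sample NumPy version of the Gram-Schmidt step in
# rotation_from_ortho_6d. The first column is the normalized x_raw, the third
# is normalize(x cross y_raw), and the second is z cross x, so the result is a
# proper rotation even when the two input 3-vectors are not orthogonal.
import numpy as np


def _sketch_rotation_from_6d_np(ortho6d):
    a = np.asarray(ortho6d, dtype=np.float64)
    x = a[:3] / np.linalg.norm(a[:3])
    z = np.cross(x, a[3:6])
    z = z / np.linalg.norm(z)
    y = np.cross(z, x)
    return np.stack([x, y, z], axis=1)  # columns are x, y, z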


def rotation_from_ortho_5d(ortho5d):
    """
    Convert a 5-d rotation representation to rotation matrix

    From the paper "On the Continuity of Rotation Representations in Neural Networks"
    https://arxiv.org/pdf/1812.07035.pdf
    """
    batch = ortho5d.shape[0]
    proj_scale_np = np.array([np.sqrt(2) + 1, np.sqrt(2) + 1, np.sqrt(2)])
    proj_scale = (
        torch.from_numpy(proj_scale_np)
        .to(ortho5d)
        .reshape(1, 3)
        .repeat(batch, 1)
    )

    u = stereographic_unproject(ortho5d[:, 2:5] * proj_scale, axis=0)
    norm = torch.sqrt(torch.pow(u[:, 1:], 2).sum(1))
    u = u / norm.reshape(batch, 1).repeat(1, u.shape[1])
    b = torch.cat((ortho5d[:, 0:2], u), 1)
    matrix = rotation_from_ortho_6d(b)
    return matrix


def rotation_from_euler(euler):
    """
    Convert a 3-d Euler angle representation to rotation matrix
    """
    batch = euler.shape[0]

    c1 = torch.cos(euler[:, 0]).reshape(batch, 1)
    s1 = torch.sin(euler[:, 0]).reshape(batch, 1)
    c2 = torch.cos(euler[:, 2]).reshape(batch, 1)
    s2 = torch.sin(euler[:, 2]).reshape(batch, 1)
    c3 = torch.cos(euler[:, 1]).reshape(batch, 1)
    s3 = torch.sin(euler[:, 1]).reshape(batch, 1)

    row1 = torch.cat((c2 * c3, -s2, c2 * s3), 1).reshape(-1, 1, 3)
    row2 = torch.cat(
        (c1 * s2 * c3 + s1 * s3, c1 * c2, c1 * s2 * s3 - s1 * c3), 1
    ).reshape(-1, 1, 3)
    row3 = torch.cat(
        (s1 * s2 * c3 - c1 * s3, s1 * c2, s1 * s2 * s3 + c1 * c3), 1
    ).reshape(-1, 1, 3)

    matrix = torch.cat((row1, row2, row3), 1)
    return matrix


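def fit_to_SO3(R):
    # Math used from quora post and this berkeley pdf.
    # https://qr.ae/pKQaG5
    # https://people.eecs.berkeley.edu/~wkahan/Math128/NearestQ.pdf
    #
    # Computes Q = B (B^T B)^(-1/2) using the truncated binomial series
    #   (I + Y)^(-1/2) ~= I - Y/2 + 3Y^2/8 - 5Y^3/16 + 35Y^4/128,  with Y = B^T B - I,
    # which is accurate when R is already close to orthogonal.
    #
    # Input:
    #   R - torch 3x3 rotation matrix that is not quite orthogonal
    # Output:
    #   Q - torch 3x3 nearest orthogonal matrix (a valid rotation when R is close to one)
    assert R.ndim == 2
    assert R.shape[0] == 3 and R.shape[1] == 3
    B = R
    # Create the identity with R's dtype and device so float64 / CUDA inputs also work.
    I = torch.eye(3, dtype=R.dtype, device=R.device)
    Y = B.transpose(-2, -1) @ B - I
    # Q = B (I - Y/2 + 3Y^2/8 - 5Y^3/16 + 35Y^4/128), factored as B - B Y (...)
    Q = B - B @ Y @ (
        I / 2.0 - (3.0 * Y) / 8.0 + (5.0 * Y @ Y) / 16.0 - (35.0 * Y @ Y @ Y) / 128.0
    )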
    return Q
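For a quick hand check of `rotation_from_ortho_6d`, the plain-Python sketch below (no torch; helper names such as `rotation_from_6d` and `det3` are illustrative only) repeats the same construction for a single 6-vector: normalize both 3-vectors, take z = x × y and normalize it, recompute y = z × x, and use x, y, z as the matrix columns. The result should be orthonormal with determinant +1.

```python
import math


def _normalize(v, eps=1e-6):
    n = math.sqrt(sum(c * c for c in v))
    n = max(n, eps)  # mirrors the eps clamp used by F.normalize
    return [c / n for c in v]


def _cross(a, b):
    return [
        a[1] * b[2] - a[2] * b[1],
        a[2] * b[0] - a[0] * b[2],
        a[0] * b[1] - a[1] * b[0],
    ]


def rotation_from_6d(sixd):
    """Single-sample 6D -> 3x3 rotation, returned as a row-major nested list."""
    x = _normalize(sixd[0:3])
    y = _normalize(sixd[3:6])
    z = _normalize(_cross(x, y))
    y = _cross(z, x)  # unit length because z and x are orthonormal
    # x, y, z are the columns, matching torch.cat((x, y, z), 2) in the torch version.
    return [[x[i], y[i], z[i]] for i in range(3)]


def det3(m):
    return (
        m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
        - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
        + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0])
    )


R = rotation_from_6d([2.0, 0.1, 0.0, 0.3, 1.5, 0.2])
print(det3(R))
```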

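Expanding the row products in `rotation_from_euler` shows that it builds R = Rx(euler[0]) · Rz(euler[2]) · Ry(euler[1]); note that the second and third angles are not applied in index order. The sketch below checks that reading numerically in plain Python; `rx`, `ry`, `rz` and `euler_closed_form` are illustrative helpers, not part of efm3d.

```python
import math


def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


def rx(t):
    c, s = math.cos(t), math.sin(t)
    return [[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]]


def ry(t):
    c, s = math.cos(t), math.sin(t)
    return [[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]]


def rz(t):
    c, s = math.cos(t), math.sin(t)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]


def euler_closed_form(e0, e1, e2):
    """Same rows as rotation_from_euler for one sample (c1/s1 <- e0, c2/s2 <- e2, c3/s3 <- e1)."""
    c1, s1 = math.cos(e0), math.sin(e0)
    c2, s2 = math.cos(e2), math.sin(e2)
    c3, s3 = math.cos(e1), math.sin(e1)
    return [
        [c2 * c3, -s2, c2 * s3],
        [c1 * s2 * c3 + s1 * s3, c1 * c2, c1 * s2 * s3 - s1 * c3],
        [s1 * s2 * c3 - c1 * s3, s1 * c2, s1 * s2 * s3 + c1 * c3],
    ]


e = (0.3, -1.1, 0.7)
closed = euler_closed_form(*e)
composed = matmul(rx(e[0]), matmul(rz(e[2]), ry(e[1])))
max_err = max(abs(closed[i][j] - composed[i][j]) for i in range(3) for j in range(3))
print(f"max |closed - Rx Rz Ry| = {max_err:.2e}")
```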

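The comments in `fit_to_SO3` point to Kahan's note on the nearest orthogonal matrix, Q = B (B^T B)^(-1/2), expanded with the binomial series (I + Y)^(-1/2) ≈ I - Y/2 + 3Y^2/8 - 5Y^3/16 + 35Y^4/128 where Y = B^T B - I. A plain-Python sketch of that truncated series, applied to a rotation with a small made-up perturbation, shows how much closer to orthogonal the result is. The helpers are illustrative, and the sketch assumes the input is already close to a rotation, since the truncated series is only accurate when Y is small.

```python
import math

IDENT = [[1.0 if i == j else 0.0 for j in range(3)] for i in range(3)]


def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


def transpose(a):
    return [[a[j][i] for j in range(3)] for i in range(3)]


def lincomb(*terms):
    """Sum of weight * matrix terms, e.g. lincomb((1.0, A), (-0.5, B))."""
    return [[sum(w * m[i][j] for w, m in terms) for j in range(3)] for i in range(3)]


def orthogonality_error(m):
    """Largest absolute entry of M^T M - I."""
    d = lincomb((1.0, matmul(transpose(m), m)), (-1.0, IDENT))
    return max(abs(v) for row in d for v in row)


def nearest_rotation_series(b):
    """Q = B - B Y (I/2 - 3Y/8 + 5Y^2/16 - 35Y^3/128), with Y = B^T B - I."""
    y = lincomb((1.0, matmul(transpose(b), b)), (-1.0, IDENT))
    y2 = matmul(y, y)
    y3 = matmul(y2, y)
    inner = lincomb(
        (0.5, IDENT), (-3.0 / 8.0, y), (5.0 / 16.0, y2), (-35.0 / 128.0, y3)
    )
    return lincomb((1.0, b), (-1.0, matmul(matmul(b, y), inner)))


# A rotation about z plus a small, made-up perturbation.
c, s = math.cos(0.4), math.sin(0.4)
rot = [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]
noise = [[0.01, -0.02, 0.005], [0.0, 0.015, -0.01], [0.02, 0.0, -0.005]]
B = lincomb((1.0, rot), (1.0, noise))
Q = nearest_rotation_series(B)
print(orthogonality_error(B), orthogonality_error(Q))
```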
================================================
FILE: efm3d/aria/projection_utils.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


def sign_plus(x):
    """
    Return +1 for positive values and for 0.0 in x, and -1 for negative values.
    Unlike torch.sign, this never returns 0.0, which matters for z values that
    must never be 0.0 before we divide by them.
    """
    sgn = torch.ones_like(x)
    sgn[x < 0.0] = -1.0
    return sgn


@torch.jit.script
def fisheye624_project(xyz, params):
    """
    Batched implementation of the FisheyeRadTanThinPrism (aka Fisheye624) camera
    model project() function.

    Inputs:
        xyz: Bx(T)xNx3 tensor of 3D points to be projected
        params: Bx(T)x16 tensor of Fisheye624 parameters formatted like this:
                [f_u f_v c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}]
                or Bx(T)x15 tensor of Fisheye624 parameters formatted like this:
                [f c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}]
    Outputs:
        uv: Bx(T)xNx2 tensor of 2D projections of xyz in image plane

    Model for fisheye cameras with radial, tangential, and thin-prism distortion.
    This model allows fu != fv.
    Specifically, the model is:
    uvDistorted = [x_r]  + tangentialDistortion  + thinPrismDistortion
                  [y_r]
    proj = diag(fu,fv) * uvDistorted + [cu;cv];
    where:
      a = x/z, b = y/z, r = (a^2+b^2)^(1/2)
      th = atan(r)
      cosPhi = a/r, sinPhi = b/r
      [x_r]  = (th+ k0 * th^3 + k1* th^5 + ...) [cosPhi]
      [y_r]                                     [sinPhi]
      here six radial coefficients k_0 ... k_5 are used, so the series ends at k5 * th^13.
      tangentialDistortion = [(2 x_r^2 + rd^2)*p_0 + 2*x_r*y_r*p_1]
                             [(2 y_r^2 + rd^2)*p_1 + 2*x_r*y_r*p_0]
      where rd^2 = x_r^2 + y_r^2
      thinPrismDistortion = [s0 * rd^2 + s1 rd^4]
                            [s2 * rd^2 + s3 rd^4]
    """

    assert (xyz.ndim == 3 and params.ndim == 2) or (
        xyz.ndim == 4 and params.ndim == 3
    ), f"point dim {xyz.shape} does not match cam parameter dim {params.shape}"
    assert xyz.shape[-1] == 3
    assert params.shape[-1] == 16 or params.shape[-1] == 15, (
        "Expected 16 (f_u, f_v) or 15 (single f) Fisheye624 parameters"
    )
    assert xyz.dtype == params.dtype, "data type must match"

    eps = 1e-9
    T = -1
    if xyz.ndim == 4:
        # has T dim
        T, N = xyz.shape[1], xyz.shape[2]
        xyz = xyz.reshape(-1, N, 3)  # (BxT)xNx3
        params = params.reshape(-1, params.shape[-1])  #  (BxT)x16

    B, N = xyz.shape[0], xyz.shape[1]

    # Radial correction.
    z = xyz[:, :, 2].reshape(B, N, 1)
    # Do not use torch.sign(z): it returns 0.0 when z == 0.0, which leads to a
    # nan when we compute xy/z
    z = torch.where(torch.abs(z) < eps, eps * sign_plus(z), z)
    ab = xyz[:, :, :2] / z
    # make sure the entries of ab are not too small or 0, otherwise gradients are nan
    ab = torch.where(torch.abs(ab) < eps, eps * sign_plus(ab), ab)
    r = torch.norm(ab, dim=-1, p=2, keepdim=True)
    th = torch.atan(r)
    th_divr = torch.where(r < eps, torch.ones_like(ab), ab / r)
    th_k = th.reshape(B, N, 1).clone()
    for i in range(6):
        th_k = th_k + params[:, -12 + i].reshape(B, 1, 1) * torch.pow(th, 3 + i * 2)
    xr_yr = th_k * th_divr
    uv_dist = xr_yr

    # Tangential correction.
    p0 = params[:, -6].reshape(B, 1)
    p1 = params[:, -5].reshape(B, 1)
    xr = xr_yr[:, :, 0].reshape(B, N)
    yr = xr_yr[:, :, 1].reshape(B, N)
    xr_yr_sq = torch.square(xr_yr)
    xr_sq = xr_yr_sq[:, :, 0].reshape(B, N)
    yr_sq = xr_yr_sq[:, :, 1].reshape(B, N)
    rd_sq = xr_sq + yr_sq
    uv_dist_tu = uv_dist[:, :, 0] + ((2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1)
    uv_dist_tv = uv_dist[:, :, 1] + ((2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0)
    uv_dist = torch.stack(
        [uv_dist_tu, uv_dist_tv], dim=-1
    )  # Avoids in-place complaint.

    # Thin Prism correction.
    s0 = params[:, -4].reshape(B, 1)
    s1 = params[:, -3].reshape(B, 1)
    s2 = params[:, -2].reshape(B, 1)
    s3 = params[:, -1].reshape(B, 1)
    rd_4 = torch.square(rd_sq)
    uv_dist[:, :, 0] = uv_dist[:, :, 0] + (s0 * rd_sq + s1 * rd_4)
    uv_dist[:, :, 1] = uv_dist[:, :, 1] + (s2 * rd_sq + s3 * rd_4)

    # Finally, apply standard terms: focal length and camera centers.
    if params.shape[-1] == 15:
        fx_fy = params[:, 0].reshape(B, 1, 1)
        cx_cy = params[:, 1:3].reshape(B, 1, 2)
    else:
        fx_fy = params[:, 0:2].reshape(B, 1, 2)
        cx_cy = params[:, 2:4].reshape(B, 1, 2)
    result = uv_dist * fx_fy + cx_cy

    if T > 0:
        result = result.reshape(B // T, T, N, 2)

    assert result.ndim == 4 or result.ndim == 3
    assert result.shape[-1] == 2

    return result


@torch.jit.script
def fisheye624_unproject(uv, params, max_iters: int = 5):
    """
    Batched implementation of the FisheyeRadTanThinPrism (aka Fisheye624) camera
    model. There is no analytical solution for the inverse of the project()
    function so this solves an optimization problem using Newton's method to get
    the inverse.

    Inputs:
        uv: Bx(T)xNx2 tensor of 2D pixels to be unprojected
        params: Bx(T)x16 tensor of Fisheye624 parameters formatted like this:
                [f_u f_v c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}]
                or Bx(T)x15 tensor of Fisheye624 parameters formatted like this:
                [f c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}]
    Outputs:
        xyz: Bx(T)xNx3 tensor of 3D rays of uv points with z = 1.

    Model for fisheye cameras with radial, tangential, and thin-prism distortion.
    This model allows fu != fv. Up to the convergence of the Newton iterations,
    this unproject function satisfies:

    X / z = unproject(project(X))   [for X=(x,y,z) in R^3, z>0]

    and

    x = project(s * unproject(x))   [for s!=0 and x=(u,v) in R^2]
    """

    assert uv.ndim == 3 or uv.ndim == 4, "Expected batched input shaped Bx(T)xNx2"
    assert uv.shape[-1] == 2
    assert params.ndim == 2 or params.ndim == 3, (
        "Expected batched input shaped Bx(T)x16 or Bx(T)x15"
    )
    assert params.shape[-1] == 16 or params.shape[-1] == 15, (
        "Expected 16 (f_u, f_v) or 15 (single f) Fisheye624 parameters"
    )
    assert uv.dtype == params.dtype, "data type must match"
    eps = 1e-6

    T = -1
    if uv.ndim == 4:
        # has T dim
        T, N = uv.shape[1], uv.shape[2]
        uv = uv.reshape(-1, N, 2)  # (BxT)xNx2
        params = params.reshape(-1, params.shape[-1])  #  (BxT)x16

    B, N = uv.shape[0], uv.shape[1]

    if params.shape[-1] == 15:
        fx_fy = params[:, 0].reshape(B, 1, 1)
        cx_cy = params[:, 1:3].reshape(B, 1, 2)
    else:
        fx_fy = params[:, 0:2].reshape(B, 1, 2)
        cx_cy = params[:, 2:4].reshape(B, 1, 2)

    uv_dist = (uv - cx_cy) / fx_fy

    # Compute xr_yr using Newton's method.
    xr_yr = uv_dist.clone()  # Initial guess.
    for _ in range(max_iters):
        uv_dist_est = xr_yr.clone()
        # Tangential terms.
        p0 = params[:, -6].reshape(B, 1)
        p1 = params[:, -5].reshape(B, 1)
        xr = xr_yr[:, :, 0].reshape(B, N)
        yr = xr_yr[:, :, 1].reshape(B, N)
        xr_yr_sq = torch.square(xr_yr)
        xr_sq = xr_yr_sq[:, :, 0].reshape(B, N)
        yr_sq = xr_yr_sq[:, :, 1].reshape(B, N)
        rd_sq = xr_sq + yr_sq
        uv_dist_est[:, :, 0] = uv_dist_est[:, :, 0] + (
            (2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1
        )
        uv_dist_est[:, :, 1] = uv_dist_est[:, :, 1] + (
            (2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0
        )
        # Thin Prism terms.
        s0 = params[:, -4].reshape(B, 1)
        s1 = params[:, -3].reshape(B, 1)
        s2 = params[:, -2].reshape(B, 1)
        s3 = params[:, -1].reshape(B, 1)
        rd_4 = torch.square(rd_sq)
        uv_dist_est[:, :, 0] = uv_dist_est[:, :, 0] + (s0 * rd_sq + s1 * rd_4)
        uv_dist_est[:, :, 1] = uv_dist_est[:, :, 1] + (s2 * rd_sq + s3 * rd_4)
        # Compute the derivative of uv_dist w.r.t. xr_yr.
        duv_dist_dxr_yr = uv.new_ones(B, N, 2, 2)
        duv_dist_dxr_yr[:, :, 0, 0] = (
            1.0 + 6.0 * xr_yr[:, :, 0] * p0 + 2.0 * xr_yr[:, :, 1] * p1
        )
        offdiag = 2.0 * (xr_yr[:, :, 0] * p1 + xr_yr[:, :, 1] * p0)
        duv_dist_dxr_yr[:, :, 0, 1] = offdiag
        duv_dist_dxr_yr[:, :, 1, 0] = offdiag
        duv_dist_dxr_yr[:, :, 1, 1] = (
            1.0 + 6.0 * xr_yr[:, :, 1] * p1 + 2.0 * xr_yr[:, :, 0] * p0
        )
        xr_yr_sq_norm = xr_yr_sq[:, :, 0] + xr_yr_sq[:, :, 1]
        temp1 = 2.0 * (s0 + 2.0 * s1 * xr_yr_sq_norm)
        duv_dist_dxr_yr[:, :, 0, 0] = duv_dist_dxr_yr[:, :, 0, 0] + (
            xr_yr[:, :, 0] * temp1
        )
        duv_dist_dxr_yr[:, :, 0, 1] = duv_dist_dxr_yr[:, :, 0, 1] + (
            xr_yr[:, :, 1] * temp1
        )
        temp2 = 2.0 * (s2 + 2.0 * s3 * xr_yr_sq_norm)
        duv_dist_dxr_yr[:, :, 1, 0] = duv_dist_dxr_yr[:, :, 1, 0] + (
            xr_yr[:, :, 0] * temp2
        )
        duv_dist_dxr_yr[:, :, 1, 1] = duv_dist_dxr_yr[:, :, 1, 1] + (
            xr_yr[:, :, 1] * temp2
        )
        # Compute the 2x2 inverse manually here since torch.inverse() is very slow:
        # using inv = duv_dist_dxr_yr.inverse() instead is roughly 10x slower.
        mat = duv_dist_dxr_yr.reshape(-1, 2, 2)
        a = mat[:, 0, 0].reshape(-1, 1, 1)
        b = mat[:, 0, 1].reshape(-1, 1, 1)
        c = mat[:, 1, 0].reshape(-1, 1, 1)
        d = mat[:, 1, 1].reshape(-1, 1, 1)
        det = 1.0 / ((a * d) - (b * c))
        top = torch.cat([d, -b], dim=2)
        bot = torch.cat([-c, a], dim=2)
        inv = det * torch.cat([top, bot], dim=1)
        inv = inv.reshape(B, N, 2, 2)
        # Manually compute 2x2 @ 2x1 matrix multiply.
        # Because this is slow: step = (inv @ (uv_dist - uv_dist_est)[..., None])[..., 0]
        diff = uv_dist - uv_dist_est
        a = inv[:, :, 0, 0]
        b = inv[:, :, 0, 1]
        c = inv[:, :, 1, 0]
        d = inv[:, :, 1, 1]
        e = diff[:, :, 0]
        f = diff[:, :, 1]
        step = torch.stack([a * e + b * f, c * e + d * f], dim=-1)
        # Newton step.
        xr_yr = xr_yr + step

    # Compute theta using Newton's method.
    xr_yr_norm = xr_yr.norm(p=2, dim=2).reshape(B, N, 1)
    th = xr_yr_norm.clone()
    for _ in range(max_iters):
        th_radial = uv.new_ones(B, N, 1)
        dthd_th = uv.new_ones(B, N, 1)
        for k in range(6):
            r_k = params[:, -12 + k].reshape(B, 1, 1)
            th_radial = th_radial + (r_k * torch.pow(th, 2 + k * 2))
            dthd_th = dthd_th + ((3.0 + 2.0 * k) * r_k * torch.pow(th, 2 + k * 2))
        th_radial = th_radial * th
        step = (xr_yr_norm - th_radial) / dthd_th
        # handle dthd_th close to 0.
        step = torch.where(dthd_th.abs() > eps, step, sign_plus(step) * eps * 10.0)
        th = th + step
    # Compute the ray direction using theta and xr_yr.
    close_to_zero = torch.logical_and(th.abs() < eps, xr_yr_norm.abs() < eps)
    ray_dir = torch.where(close_to_zero, xr_yr, torch.tan(th) / xr_yr_norm * xr_yr)
    ray = torch.cat([ray_dir, uv.new_ones(B, N, 1)], dim=2)
    assert ray.shape[-1] == 3

    if T > 0:
        ray = ray.reshape(B // T, T, N, 3)

    return ray


def pinhole_project(xyz, params):
    """
    Batched implementation of the Pinhole (aka Linear) camera
    model project() function.

    Inputs:
        xyz: Bx(T)xNx3 tensor of 3D points to be projected
        params: Bx(T)x4 tensor of Pinhole parameters formatted like this:
                [f_u f_v c_u c_v]
    Outputs:
        uv: Bx(T)xNx2 tensor of 2D projections of xyz in image plane
    """

    assert (xyz.ndim == 3 and params.ndim == 2) or (xyz.ndim == 4 and params.ndim == 3)
    assert params.shape[-1] == 4
    eps = 1e-9

    # Focal length and principal point
    fx_fy = params[..., 0:2].reshape(*xyz.shape[:-2], 1, 2)
    cx_cy = params[..., 2:4].reshape(*xyz.shape[:-2], 1, 2)
    # Make sure depth is not too close to zero.
    z = xyz[..., 2:]
    # Do not use torch.sign(z): it returns 0.0 when z == 0.0, which leads to a
    # nan when we compute xy/z
    z = torch.where(torch.abs(z) < eps, eps * sign_plus(z), z)
    uv = (xyz[..., :2] / z) * fx_fy + cx_cy
    return uv


def pinhole_unproject(uv, params, max_iters: int = 5):
    """
    Batched implementation of the Pinhole (aka Linear) camera model
    unproject() function.

    Inputs:
        uv: Bx(T)xNx2 tensor of 2D pixels to be unprojected
        params: Bx(T)x4 tensor of Pinhole parameters formatted like this:
                [f_u f_v c_u c_v]
        max_iters: unused; the linear model is inverted in closed form
    Outputs:
        xyz: Bx(T)xNx3 tensor of 3D rays of uv points with z = 1.
    """
    assert uv.ndim == 3 or uv.ndim == 4, "Expected batched input shaped Bx(T)xNx2"
    assert params.ndim == 2 or params.ndim == 3
    assert params.shape[-1] == 4
    assert uv.shape[-1] == 2

    # Focal length and principal point
    fx_fy = params[..., 0:2].reshape(*uv.shape[:-2], 1, 2)
    cx_cy = params[..., 2:4].reshape(*uv.shape[:-2], 1, 2)

    uv_dist = (uv - cx_cy) / fx_fy

    ray = torch.cat([uv_dist, uv.new_ones(*uv.shape[:-1], 1)], dim=-1)
    return ray
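To trace the Fisheye624 model from the `fisheye624_project` docstring for one point, here is a plain-Python sketch. It assumes the 16-parameter layout `[f_u f_v c_u c_v k_0..k_5 p_0 p_1 s_0..s_3]` and z > 0, and it omits the eps guards of the batched torch version; `fisheye624_project_point` is a hypothetical helper. With all distortion coefficients at zero it reduces to the equidistant fisheye mapping u = f_u * atan(r) * cos(phi) + c_u.

```python
import math


def fisheye624_project_point(xyz, params):
    """Project one 3D point (z > 0) with the 16-parameter Fisheye624 model."""
    fu, fv, cu, cv = params[0:4]
    k = params[4:10]  # radial k_0..k_5
    p0, p1 = params[10:12]  # tangential
    s0, s1, s2, s3 = params[12:16]  # thin prism
    x, y, z = xyz
    a, b = x / z, y / z
    r = math.hypot(a, b)
    th = math.atan(r)
    # Radial series: th + k0*th^3 + k1*th^5 + ... + k5*th^13
    th_k = th + sum(k[i] * th ** (3 + 2 * i) for i in range(6))
    cos_phi, sin_phi = (a / r, b / r) if r > 0.0 else (1.0, 0.0)
    xr, yr = th_k * cos_phi, th_k * sin_phi
    rd_sq = xr * xr + yr * yr
    # Tangential distortion.
    ud = xr + (2.0 * xr * xr + rd_sq) * p0 + 2.0 * xr * yr * p1
    vd = yr + (2.0 * yr * yr + rd_sq) * p1 + 2.0 * xr * yr * p0
    # Thin-prism distortion.
    ud += s0 * rd_sq + s1 * rd_sq**2
    vd += s2 * rd_sq + s3 * rd_sq**2
    # Focal lengths and principal point.
    return fu * ud + cu, fv * vd + cv


params = [300.0, 300.0, 320.0, 240.0] + [0.0] * 12  # illustrative, no distortion
print(fisheye624_project_point((1.0, 0.0, 1.0), params))
```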

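The second Newton loop in `fisheye624_unproject` inverts the radial polynomial th_d = th + k_0 th^3 + ... + k_5 th^13 starting from th = th_d, and the ray is then rebuilt as tan(th) * (x_r, y_r) / |(x_r, y_r)| with z = 1. The scalar sketch below mirrors those two steps for the radial-only case (tangential and thin-prism terms assumed zero); the coefficients are made up and the helper names are illustrative.

```python
import math


def radial(th, k):
    """th_d = th + k0*th^3 + k1*th^5 + ... + k5*th^13."""
    return th + sum(k[i] * th ** (3 + 2 * i) for i in range(6))


def radial_deriv(th, k):
    """d th_d / d th = 1 + 3*k0*th^2 + 5*k1*th^4 + ... + 13*k5*th^12."""
    return 1.0 + sum((3 + 2 * i) * k[i] * th ** (2 + 2 * i) for i in range(6))


def invert_radial(th_d, k, max_iters=5):
    """Newton's method for th given th_d, starting from th = th_d."""
    th = th_d
    for _ in range(max_iters):
        th = th + (th_d - radial(th, k)) / radial_deriv(th, k)
    return th


def ray_from_distorted(xr, yr, k):
    """Ray (x, y, 1) from radially distorted coords (xr, yr); radial terms only."""
    norm = math.hypot(xr, yr)
    if norm == 0.0:
        return (0.0, 0.0, 1.0)
    th = invert_radial(norm, k)
    scale = math.tan(th) / norm
    return (xr * scale, yr * scale, 1.0)


k = [0.03, -0.01, 0.002, 0.0, 0.0, 0.0]  # made-up coefficients
theta_true = 0.9
theta_rec = invert_radial(radial(theta_true, k), k)
print(theta_true, theta_rec)
```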

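`pinhole_project` and `pinhole_unproject` invert each other up to scale along the ray: projecting X = (x, y, z) and unprojecting the pixel gives (x/z, y/z, 1). The per-point sketch below (illustrative intrinsics, hypothetical helper names) also clamps near-zero z to +eps for z >= 0 and to -eps otherwise, which is the behavior the `sign_plus` docstring describes.

```python
def pinhole_project_point(xyz, params, eps=1e-9):
    """u = f_u * x / z + c_u, v = f_v * y / z + c_v, with near-zero z clamped."""
    fu, fv, cu, cv = params
    x, y, z = xyz
    if abs(z) < eps:
        # +eps for z >= 0 (including 0.0), -eps for negative z
        z = eps if z >= 0.0 else -eps
    return fu * x / z + cu, fv * y / z + cv


def pinhole_unproject_point(uv, params):
    """Ray with z = 1 through pixel (u, v)."""
    fu, fv, cu, cv = params
    u, v = uv
    return (u - cu) / fu, (v - cv) / fv, 1.0


params = (500.0, 480.0, 320.0, 240.0)  # illustrative intrinsics
uv = pinhole_project_point((0.4, -0.2, 2.0), params)
ray = pinhole_unproject_point(uv, params)
print(uv, ray)  # ray is (x/z, y/z, 1) up to floating-point rounding
```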
================================================
FILE: efm3d/aria/tensor_wrapper.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import inspect
import logging
from typing import List

import numpy as np
import torch
from torch.utils.data._utils.collate import (
    collate,
    collate_tensor_fn,
    default_collate_fn_map,
)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# logger.setLevel(logging.DEBUG)


def smart_cat(inp_arr, dim=-1):
    devices = set()
    for i, inp in enumerate(inp_arr):
        if isinstance(inp, TensorWrapper):
            inp_arr[i] = inp._data
        else:
            inp_arr[i] = inp
        devices.add(inp_arr[i].device)
    if len(devices) > 1:
        raise RuntimeError(f"More than one device found! {devices}")
    return torch.cat(inp_arr, dim=dim)


def smart_stack(inp_arr, dim: int = 0):
    devices = set()
    for i, inp in enumerate(inp_arr):
        if isinstance(inp, TensorWrapper):
            inp_arr[i] = inp._data
        else:
            inp_arr[i] = inp
        devices.add(inp_arr[i].device)
    if len(devices) > 1:
        raise RuntimeError(f"More than one device found! {devices}")
    return torch.stack(inp_arr, dim=dim)


def get_default_args(func):
    signature = inspect.signature(func)
    return {
        k: v.default
        for k, v in signature.parameters.items()
        if v.default is not inspect.Parameter.empty
    }


def get_nonempty_arg_names(func):
    spec = inspect.getfullargspec(func)
    signature = inspect.signature(func)
    return [
        k
        for k in spec.args
        if signature.parameters[k].default is not inspect.Parameter.empty
    ]


def autocast(func):
    """Cast the inputs of a TensorWrapper method to PyTorch tensors
    if they are numpy arrays. Use the device and dtype of the wrapper.
    """

    @functools.wraps(func)
    def wrap(self, *args):
        device = torch.device("cpu")
        dtype = None
        if isinstance(self, TensorWrapper):
            if self._data is not None:
                device = self.device
                dtype = self.dtype
        elif not inspect.isclass(self) or not issubclass(self, TensorWrapper):
            raise ValueError(self)

        cast_args = []
        for arg in args:
            if isinstance(arg, np.ndarray):
                arg = torch.from_numpy(arg)
                arg = arg.to(device=device, dtype=dtype)
            cast_args.append(arg)

        return func(self, *cast_args)

    return wrap


def autoinit(func):
    """
    Helps with initialization. Will auto-reshape and auto-expand input arguments
    to match the first argument, as well as check shapes based on default tensor sizes.
    """

    @functools.wraps(func)
    def wrap(self, *args, **kwargs):
        # Combine args and kwargs.
        arg_names = get_nonempty_arg_names(func)
        all_args = {}
        for i, arg in enumerate(args):
            all_args[arg_names[i]] = arg
        for arg_name in kwargs:
            all_args[arg_name] = kwargs[arg_name]

        # Add default values to all_args if unspecified inputs.
        default_args = get_default_args(func)
        extra_args = {}
        for arg_name in default_args:
            default_arg = default_args[arg_name]
            if not isinstance(default_arg, (TensorWrapper, torch.Tensor)):
                # If not TW or torch tensor, pass it through unperturbed, falling
                # back to the default when the caller did not provide it.
                extra_args[arg_name] = all_args.pop(arg_name, default_arg)
            else:
                if arg_name not in all_args or all_args[arg_name] is None:
                    all_args[arg_name] = default_arg

        # Auto convert numpy,lists,floats to torch, check that shapes are good.
        for arg_name in all_args:
            arg = all_args[arg_name]
            if isinstance(arg, (torch.Tensor, TensorWrapper)):
                pass
            elif isinstance(arg, (int, float)):
                arg = torch.tensor(arg).reshape(1)
            elif isinstance(arg, List):
                arg = torch.tensor(arg)
            elif isinstance(arg, np.ndarray):
                arg = torch.from_numpy(arg)
            else:
                raise ValueError("Unsupported initialization type")
            assert isinstance(arg, (torch.Tensor, TensorWrapper))

            default_arg = default_args[arg_name]
            if isinstance(default_arg, TensorWrapper):
                # Convert list of torch.Size to tuple of ints.
                default_dims = tuple([da[0] for da in default_arg.shape])
            else:
                default_dims = (default_arg.shape[-1],)
            if arg.shape[-1] not in default_dims:
                # probably need a more general solution here to handle single dim inputs.
                if default_dims[0] == 1:
                    arg = arg.unsqueeze(-1)
                if arg.shape[-1] not in default_dims:
                    raise ValueError(
                        "Bad shape of %d for %s, should be in %s"
                        % (arg.shape[-1], arg_name, default_dims)
                    )

            all_args[arg_name] = arg

        # Shape of all inputs is determined by first arg.
        first_arg_name = arg_names[0]
        batch_shape = all_args[first_arg_name].shape[:-1]

        has_cuda_tensor = False

        for arg_name in all_args:
            arg = all_args[arg_name]
            # Try to trim any extra dimensions at the beginning of arg shape.
            while True:
                if arg.ndim > len(batch_shape) and arg.shape[0] == 1 and arg.ndim > 1:
                    arg = arg.squeeze(0)
                else:
                    break
            arg = arg.expand(*batch_shape, arg.shape[-1])
            all_args[arg_name] = arg

            if (
                isinstance(all_args[arg_name], (torch.Tensor, TensorWrapper))
                and all_args[arg_name].is_cuda
            ):
                has_cuda_tensor = True

        if has_cuda_tensor:
            for arg_name in all_args:
                if (
                    isinstance(all_args[arg_name], (torch.Tensor, TensorWrapper))
                    and not all_args[arg_name].is_cuda
                ):
                    all_args[arg_name] = all_args[arg_name].cuda()

        # Add the unperturbed args back to all args.
        all_args.update(extra_args)

        return func(self, **all_args)

    return wrap


def tensor_wrapper_collate(batch, *, collate_fn_map=None):
    """Simply call stack for TensorWrapper"""
    return torch.stack(batch, 0)


def float_collate(batch, *, collate_fn_map=None):
    """Auto convert float to float32"""
    return torch.tensor(batch, dtype=torch.float32)


def list_dict_collate(batch, *, collate_fn_map=None):
    """collate lists; handles the case where the lists in the batch are
    expressing a dict via List[Tuple[key, value]] and returns a Dict[key, value]
    in that case."""
    if len(batch) > 0:
        list_0 = batch[0]
        if len(list_0) > 0:
            elem_0 = list_0[0]
            if isinstance(elem_0, tuple) and len(elem_0) == 2:
                # the lists in each batch sample are (key, value) pairs and we hence return a dictionary
                for i in range(len(batch)):
                    batch[i] = {k: v for k, v in batch[i]}
    return batch


def tensor_wrapper_collate_cat(batch, *, collate_fn_map=None):
    """Simply call cat for TensorWrapper"""
    return torch.cat(batch, 0)


def tensor_collate_cat(batch, *, collate_fn_map=None):
    """identical to "collate_tensor_fn" but replace torch.stack with torch.cat"""
    elem = batch[0]
    out = None
    if torch.utils.data.get_worker_info() is not None:
        # If we're in a background process, concatenate directly into a
        # shared memory tensor to avoid an extra copy
        numel = sum(x.numel() for x in batch)
        # Note: pytorch 1.12 doesn't have the _typed_storage() interface. Need to use storage() instead.
        # storage = elem._typed_storage()._new_shared(numel, device=elem.device)
        storage = elem.storage()._new_shared(numel, device=elem.device)

        # since we are using torch.cat, we don't add a new dimension here; the
        # leading dimension is the total number of rows across the batch
        dims_from_one = list(elem.size())[1:]
        out = elem.new(storage).resize_(sum(x.shape[0] for x in batch), *dims_from_one)
    return torch.cat(batch, 0, out=out)  # concatenate instead of stack


def custom_collate_fn(batch):
    """Custom collate function for batches containing TensorWrapper objects."""
    # Get the common keys between samples. This is required when we train with
    # multiple datasets with samples having different keys.
    if isinstance(batch, list) and isinstance(batch[0], dict):
        common_keys = set(batch[0].keys())

        for sample in batch[1:]:
            common_keys &= set(sample.keys())

        # update the batch with new samples with only the common keys
        new_batch = []
        for sample in batch:
            new_sample = {k: v for k, v in sample.items() if k in common_keys}
            new_batch.append(new_sample)
        batch = new_batch

    default_collate_fn_map[TensorWrapper] = tensor_wrapper_collate
    default_collate_fn_map[float] = float_collate
    default_collate_fn_map[list] = list_dict_collate
    default_collate_fn_map[torch.Tensor] = collate_tensor_fn
    if "already_collated" in batch[0]:
        # Use torch.cat instead of torch.stack
        default_collate_fn_map[torch.Tensor] = tensor_collate_cat
        default_collate_fn_map[TensorWrapper] = tensor_wrapper_collate_cat
    batch = collate(batch, collate_fn_map=default_collate_fn_map)
    return batch


class TensorWrapper:
    """Base class for making "smart" tensor objects that behave like pytorch tensors.
    Inspired by Paul-Edouard Sarlin's code in pixloc:
    https://github.com/cvg/pixloc/blob/master/pixloc/pixlib/geometry/wrappers.py
    Adapted and modified by Daniel DeTone.
    """

    _data = None

    @autocast
    def __init__(self, data: torch.Tensor):
        self._data = data

    @property
    def shape(self):
        return self._data.shape

    @property
    def device(self):
        return self._data.device

    @property
    def dtype(self):
        return self._data.dtype

    @property
    def ndim(self):
        return self._data.ndim

    def dim(self):
        return self._data.dim()

    def nelement(self):
        return self._data.nelement()

    def numel(self):
        return self._data.numel()

    @property
    def collate_fn(self):
        return custom_collate_fn

    @property
    def is_cuda(self):
        return self._data.is_cuda

    def is_contiguous(self):
        return self._data.is_contiguous()

    @property
    def requires_grad(self):
        return self._data.requires_grad

    @property
    def grad(self):
        return self._data.grad

    @property
    def grad_fn(self):
        return self._data.grad_fn

    def requires_grad_(self, requires_grad: bool = True):
        # Mirror torch.Tensor.requires_grad_, which returns self for chaining.
        self._data.requires_grad_(requires_grad)
        return self

    def __getitem__(self, index):
        return self.__class__(self._data[index])

    def __setitem__(self, index, item):
        self._data[index] = item._data

    def to(self, *args, **kwargs):
        return self.__class__(self._data.to(*args, **kwargs))

    def reshape(self, *args, **kwargs):
        return self.__class__(self._data.reshape(*args, **kwargs))

    def repeat(self, *args, **kwargs):
        return self.__class__(self._data.repeat(*args, **kwargs))

    def expand(self, *args, **kwargs):
        return self.__class__(self._data.expand(*args, **kwargs))

    def clone(self):
        return self.__class__(self._data.clone())

    def cpu(self):
        return self.__class__(self._data.cpu())

    def cuda(self, gpu_id=0):
        return self.__class__(self._data.cuda(gpu_id))

    def contiguous(self):
        return self.__class__(self._data.contiguous())

    def pin_memory(self):
        return self.__class__(self._data.pin_memory())

    def float(self):
        return self.__class__(self._data.float())

    def double(self):
        return self.__class__(self._data.double())

    def detach(self):
        return self.__class__(self._data.detach())

    def numpy(self):
        return self._data.numpy()

    def tensor(self):
        return self._data

    def tolist(self):
        return self._data.tolist()

    def squeeze(self, dim=None):
        assert dim != -1 and dim != self._data.dim() - 1
        if dim is None:
            return self.__class__(self._data.squeeze())
        return self.__class__(self._data.squeeze(dim=dim))

    def unsqueeze(self, dim=None):
        assert dim != -1 and dim != self._data.dim()
        return self.__class__(self._data.unsqueeze(dim=dim))

    def view(self, *shape):
        assert shape[-1] == -1 or shape[-1] == self._data.shape[-1]
        return self.__class__(self._data.view(*shape))

    def __len__(self):
        return self._data.shape[0]

    @classmethod
    def stack(cls, objects: List, dim=0, *, out=None):
        data = torch.stack([obj._data for obj in objects], dim=dim, out=out)
        return cls(data)

    @classmethod
    def cat(cls, objects: List, dim=0, *, out=None):
        data = torch.cat([obj._data for obj in objects], dim=dim, out=out)
        return cls(data)

    @classmethod
    def allclose(
        cls,
        input: torch.Tensor,
        other: torch.Tensor,
        rtol=1e-5,
        atol=1e-8,
        equal_nan=False,
    ):
        return torch.allclose(
            input._data, other._data, rtol=rtol, atol=atol, equal_nan=equal_nan
        )

    @classmethod
    def take_along_dim(cls, obj, indices, dim, *, out=None):
        data = torch.take_along_dim(obj._data, indices, dim, out=out)
        return cls(data)

    @classmethod
    def flatten(cls, obj, start_dim=0, end_dim=-1):
        data = torch.flatten(obj._data, start_dim=start_dim, end_dim=end_dim)
        return cls(data)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func is torch.stack:
            return cls.stack(*args, **kwargs)
        elif func is torch.cat:
            return cls.cat(*args, **kwargs)
        elif func is torch.allclose:
            return cls.allclose(*args, **kwargs)
        elif func is torch.take_along_dim:
            return cls.take_along_dim(*args, **kwargs)
        elif func is torch.flatten:
            return cls.flatten(*args, **kwargs)
        else:
            return NotImplemented
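Before handing off to PyTorch's `collate`, `custom_collate_fn` keeps only the keys shared by every sample, and `list_dict_collate` turns per-sample lists of `(key, value)` pairs into dicts. The torch-free sketch below reproduces just those two steps so they can be checked in isolation; the function names and sample keys are hypothetical, and unlike `list_dict_collate` it returns new dicts instead of mutating the batch in place.

```python
def keep_common_keys(batch):
    """Keep only the keys shared by every sample dict (first step of custom_collate_fn)."""
    if not batch or not isinstance(batch[0], dict):
        return batch
    common = set(batch[0].keys())
    for sample in batch[1:]:
        common &= set(sample.keys())
    return [{k: v for k, v in sample.items() if k in common} for sample in batch]


def pairs_to_dicts(batch):
    """Turn per-sample lists of (key, value) pairs into dicts, as list_dict_collate does."""
    if batch and batch[0] and isinstance(batch[0][0], tuple) and len(batch[0][0]) == 2:
        return [dict(sample) for sample in batch]
    return batch


samples = [
    {"rgb": 1, "depth": 2, "obbs": 3},  # hypothetical keys from one dataset
    {"rgb": 4, "obbs": 5},  # a second dataset without depth
]
print(keep_common_keys(samples))
print(pairs_to_dicts([[("a", 1), ("b", 2)], [("a", 3)]]))
```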


================================================
FILE: efm3d/config/efm_preprocessing_conf.yaml
================================================
atek_config_name: "efm"
camera_temporal_subsampler:
  main_camera_label: "camera-rgb"
  time_domain: "DEVICE_TIME"
  main_camera_target_freq_hz: 10.0
  sample_length_in_num_frames: 20
  stride_length_in_num_frames: 10
processors:
  rgb:
    selected: true
    sensor_label: "camera-rgb"
    time_domain: "DEVICE_TIME"
    tolerance_ns: 10_000_000
    undistort_to_linear_cam: false  # if set, undistort to a linear camera model
    target_camera_resolution: [240, 240] # if set, rescale to [image_width, image_height]
    rescale_antialias: false # to be consistent with cv2
    rotate_image_cw90deg: false # if set, rotate image by 90 degrees clockwise
  slam_left:
    selected: true
    sensor_label: "camera-slam-left"
    tolerance_ns: 10_000_000
    time_domain: "DEVICE_TIME"
    target_camera_resolution: [320, 240] # if set, rescale to [image_width, image_height]
    rescale_antialias: false # to be consistent with cv2
  slam_right:
    selected: true
    sensor_label: "camera-slam-right"
    tolerance_ns: 10_000_000
    time_domain: "DEVICE_TIME"
    target_camera_resolution: [320, 240] # if set, rescale to [image_width, image_height]
    rescale_antialias: false # to be consistent with cv2
  mps_traj:
    selected: true
    tolerance_ns: 10_000_000
  mps_semidense:
    selected: true
    tolerance_ns: 10_000_000
  rgb_depth:
    selected: true
    sensor_stream_id: "345-1" # 345-1 for ADT data, 214-8 for ASE data
    tolerance_ns: 10_000_000
    time_domain: "DEVICE_TIME"
    convert_zdepth_to_dist: false
  efm_gt:
    selected: true
    tolerance_ns: 10_000_000
    category_mapping_field_name: category # {prototype_name, category}
wds_writer:
  prefix_string: ""
  max_samples_per_shard: 8
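Reading `camera_temporal_subsampler` as frame counts of the main RGB stream after resampling to `main_camera_target_freq_hz` (our interpretation of the field names), each sample spans 20 frames at 10 Hz, i.e. 2 s, and a new sample starts every 10 frames, i.e. 1 s, so consecutive samples overlap by half. If `tolerance_ns` is the allowed timestamp mismatch when matching other streams to those frames (also an assumption), 10 ms is 10% of the 100 ms frame period. The arithmetic as a small Python sketch with a hypothetical helper:

```python
def window_timing(freq_hz, length_frames, stride_frames):
    """Duration (s), stride (s) and overlap fraction of a frame-count sliding window."""
    duration_s = length_frames / freq_hz
    stride_s = stride_frames / freq_hz
    overlap = max(0.0, 1.0 - stride_frames / length_frames)
    return duration_s, stride_s, overlap


# Values from camera_temporal_subsampler above.
duration_s, stride_s, overlap = window_timing(10.0, 20, 10)

# tolerance_ns relative to the 10 Hz frame period (100 ms).
frame_period_ns = 1e9 / 10.0
tolerance_fraction = 10_000_000 / frame_period_ns
print(duration_s, stride_s, overlap, tolerance_fraction)
```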


================================================
FILE: efm3d/config/evl_inf.yaml
================================================
_target_: efm3d.model.evl.EVL
neck_hidden_dims: [128, 256, 512]
head_hidden_dim: 256
head_layers: 2
taxonomy_file: efm3d/config/taxonomy/ase_sem_name_to_id.csv

video_backbone:
  _target_: efm3d.model.video_backbone.VideoBackboneDinov2
  freeze_encoder: true
  image_tokenizer:
    _target_: efm3d.model.image_tokenizer.ImageToDinoV2Tokens
    dinov2_name: vit_base_v25
    freeze: true
    handle_rotated_data: true
    dim_out: 768
    add_lin_layer: false
    multilayer_output: true
    ckpt_path: ckpt/dinov2_vitb14_reg4_pretrain.pth
  video_streams: [rgb]
  correct_vignette: false
  optimize_vignette: false
video_backbone3d:
  _target_: efm3d.model.lifter.Lifter
  in_dim: 768
  out_dim: 64
  patch_size: 16
  voxel_size: [96,96,96]
  voxel_extent: [-2.0, 2.0, 0.0, 4.0, -2.0, 2.0]
  head_type: dpt_ori
  streams: [rgb]
  joint_slam_streams: false
  joint_streams: false

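In `video_backbone3d`, `voxel_extent` reads as `[x_min, x_max, y_min, y_max, z_min, z_max]` in metres and `voxel_size` as the number of voxels per axis. That interpretation is an assumption drawn from the values, not from the `Lifter` source; since all three axes use the same count, the axis order of `voxel_size` does not change the result here. Under it, the 4 m cube is split into 96 cells per axis, about 4.2 cm per voxel. The sketch computes per-axis voxel edge lengths and cell-centre coordinates; both helpers are hypothetical.

```python
from typing import List, Sequence


def voxel_edges(extent: Sequence[float], counts: Sequence[int]) -> List[float]:
    """Edge length per axis, pairing extent [min0, max0, min1, max1, ...] with counts."""
    return [(extent[2 * a + 1] - extent[2 * a]) / n for a, n in enumerate(counts)]


def voxel_centers_1d(lo: float, hi: float, n: int) -> List[float]:
    """Centre coordinate of each of the n cells evenly splitting [lo, hi]."""
    step = (hi - lo) / n
    return [lo + (i + 0.5) * step for i in range(n)]


extent = [-2.0, 2.0, 0.0, 4.0, -2.0, 2.0]  # voxel_extent from evl_inf.yaml
print(voxel_edges(extent, [96, 96, 96]))  # ~0.0417 m per voxel on every axis
print(voxel_edges(extent, [48, 48, 48]))  # ~0.0833 m (evl_inf_desktop.yaml)
print(voxel_centers_1d(0.0, 4.0, 96)[:2])  # first two centres along the 0..4 axis
```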

================================================
FILE: efm3d/config/evl_inf_desktop.yaml
================================================
_target_: efm3d.model.evl.EVL
neck_hidden_dims: [32, 64, 128]
head_hidden_dim: 256
head_layers: 2
taxonomy_file: efm3d/config/taxonomy/ase_sem_name_to_id.csv

video_backbone:
  _target_: efm3d.model.video_backbone.VideoBackboneDinov2
  freeze_encoder: true
  image_tokenizer:
    _target_: efm3d.model.image_tokenizer.ImageToDinoV2Tokens
    dinov2_name: vit_base_v25
    freeze: true
    handle_rotated_data: true
    dim_out: 768
    add_lin_layer: false
    multilayer_output: true
    ckpt_path: ckpt/dinov2_vitb14_reg4_pretrain.pth
  video_streams: [rgb]
  correct_vignette: false
  optimize_vignette: false
video_backbone3d:
  _target_: efm3d.model.lifter.Lifter
  in_dim: 768
  out_dim: 32
  patch_size: 16
  voxel_size: [48,48,48]
  voxel_extent: [-2.0, 2.0, 0.0, 4.0, -2.0, 2.0]
  head_type: dpt_ori
  streams: [rgb]
  joint_slam_streams: false
  joint_streams: false

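Compared with `evl_inf.yaml`, the desktop config halves `voxel_size` (48 instead of 96 per axis), halves the `Lifter` `out_dim` (32 instead of 64), and shrinks `neck_hidden_dims` to `[32, 64, 128]`. Assuming the lifted volume stores `out_dim` float32 channels per voxel (an assumption; the `Lifter` may append extra channels), the arithmetic below shows that volume shrinking 16-fold: 8x from the voxel count and 2x from the channels. It covers one feature volume per sample only, not total model memory.

```python
def volume_bytes(voxels_per_axis: int, channels: int, bytes_per_value: int = 4) -> int:
    """Bytes for a dense cubic feature volume of float32 values (assumed layout)."""
    return voxels_per_axis**3 * channels * bytes_per_value


full = volume_bytes(96, 64)     # evl_inf.yaml
desktop = volume_bytes(48, 32)  # evl_inf_desktop.yaml
print(full, desktop, full // desktop)
# 226492416 14155776 16
```

That is roughly 226 MB versus 14 MB for a single volume.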

================================================
FILE: efm3d/config/evl_train.yaml
================================================
_target_: efm3d.model.evl_train.EVLTrain
neck_hidden_dims: [128, 256, 512]
head_hidden_dim: 256
head_layers: 2
taxonomy_file: efm3d/config/taxonomy/ase_sem_name_to_id.csv

video_backbone:
  _target_: efm3d.model.video_backbone.VideoBackboneDinov2
  freeze_encoder: true
  image_tokenizer:
    _target_: efm3d.model.image_tokenizer.ImageToDinoV2Tokens
    dinov2_name: vit_base_v25
    freeze: true
    handle_rotated_data: true
    dim_out: 768
    add_lin_layer: false
    multilayer_output: true
    ckpt_path: ckpt/dinov2_vitb14_reg4_pretrain.pth
  video_streams: [rgb]
  correct_vignette: false
  optimize_vignette: false
video_backbone3d:
  _target_: efm3d.model.lifter.Lifter
  in_dim: 768
  out_dim: 64
  patch_size: 16
  voxel_size: [96,96,96]
  voxel_extent: [-2.0, 2.0, 0.0, 4.0, -2.0, 2.0]
  head_type: dpt_ori
  streams: [rgb]
  joint_slam_streams: false
  joint_streams: false

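`evl_train.yaml` is identical to `evl_inf.yaml` except for the top-level `_target_` (`EVLTrain` instead of `EVL`). The `_target_` keys appear to follow the Hydra convention, in which `hydra.utils.instantiate` imports the dotted path and calls it with the remaining keys as keyword arguments, building nested configs first. The stdlib-only sketch below mimics that recursion to show the mechanism. `build` is a hypothetical stand-in, not Hydra, and it ignores Hydra features such as `_partial_`, `_recursive_` and interpolation.

```python
import importlib
from typing import Any


def build(cfg: Any) -> Any:
    """Recursively instantiate dicts that carry a `_target_` dotted path (Hydra-like sketch)."""
    if isinstance(cfg, list):
        return [build(v) for v in cfg]
    if not isinstance(cfg, dict):
        return cfg
    # Build children first so nested `_target_` configs arrive as objects.
    kwargs = {k: build(v) for k, v in cfg.items() if k != "_target_"}
    if "_target_" not in cfg:
        return kwargs
    module_name, _, attr = cfg["_target_"].rpartition(".")
    target = getattr(importlib.import_module(module_name), attr)
    return target(**kwargs)


# Toy config shaped like evl_inf.yaml, using stdlib classes instead of efm3d modules.
cfg = {
    "_target_": "types.SimpleNamespace",
    "head_layers": 2,
    "voxel_edge_m": {"_target_": "fractions.Fraction", "numerator": 4, "denominator": 96},
}
obj = build(cfg)
print(obj.head_layers, obj.voxel_edge_m)  # 2 1/24
```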

================================================
FILE: efm3d/config/taxonomy/aeo_to_efm.csv
================================================
AEO Category Name,EFM Category Name,EFM Category Id
Chair,chair,3
Couch,sofa,1
Table,table,0
Bed,bed,4
WallArt,picture_frame,21
Plant,flower_pot,13
Window,window,28
Mirror,mirror,22
Lamp,lamp,16

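`aeo_to_efm.csv` covers only nine AEO categories, each mapped to an EFM name and the matching id from `ase_sem_name_to_id.csv`. Below is a minimal sketch of loading it with the standard `csv` module and remapping AEO labels. It assumes unmapped categories should be dropped, which the file itself does not specify, and `load_aeo_to_efm`/`remap_labels` are hypothetical helpers. The CSV text is inlined so the example runs on its own.

```python
import csv
import io
from typing import Dict, List, Tuple

AEO_TO_EFM_CSV = """AEO Category Name,EFM Category Name,EFM Category Id
Chair,chair,3
Couch,sofa,1
Table,table,0
Lamp,lamp,16
"""  # excerpt of efm3d/config/taxonomy/aeo_to_efm.csv


def load_aeo_to_efm(text: str) -> Dict[str, Tuple[str, int]]:
    """Map each AEO category name to (EFM category name, EFM category id)."""
    reader = csv.DictReader(io.StringIO(text))
    return {
        row["AEO Category Name"]: (row["EFM Category Name"], int(row["EFM Category Id"]))
        for row in reader
    }


def remap_labels(aeo_labels: List[str], mapping: Dict[str, Tuple[str, int]]) -> List[int]:
    """Convert AEO labels to EFM ids, dropping categories absent from the mapping."""
    return [mapping[name][1] for name in aeo_labels if name in mapping]


mapping = load_aeo_to_efm(AEO_TO_EFM_CSV)
print(remap_labels(["Couch", "Bicycle", "Lamp"], mapping))  # [1, 16]
```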

================================================
FILE: efm3d/config/taxonomy/ase_sem_name_to_id.csv
================================================
sem_name,sem_id
table,0
sofa,1
shelf,2
chair,3
bed,4
floor_mat,5
exercise_weight,6
cutlery,7
container,8
clock,9
cart,10
vase,11
tent,12
flower_pot,13
pillow,14
mount,15
lamp,16
ladder,17
fan,18
cabinet,19
jar,20
picture_frame,21
mirror,22
electronic_device,23
dresser,24
clothes_rack,25
battery_charger,26
air_conditioner,27
window,28

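`ase_sem_name_to_id.csv` defines 29 classes with ids 0 to 28. Code that uses these ids as one-hot or logit indices typically relies on them being unique and contiguous; the sketch below builds both lookup directions and checks that property. `load_taxonomy` is a hypothetical helper, and the check is an illustration rather than something efm3d is known to run.

```python
import csv
import io
from typing import Dict, Tuple


def load_taxonomy(text: str) -> Tuple[Dict[str, int], Dict[int, str]]:
    """Return (name -> id, id -> name) and verify ids are unique and contiguous from 0."""
    name_to_id = {
        row["sem_name"]: int(row["sem_id"]) for row in csv.DictReader(io.StringIO(text))
    }
    id_to_name = {i: n for n, i in name_to_id.items()}
    if sorted(id_to_name) != list(range(len(name_to_id))):
        raise ValueError("semantic ids must be unique and contiguous starting at 0")
    return name_to_id, id_to_name


SEM_CSV = "sem_name,sem_id\ntable,0\nsofa,1\nshelf,2\nchair,3\n"  # first rows of the file
name_to_id, id_to_name = load_taxonomy(SEM_CSV)
print(name_to_id["chair"], id_to_name[1])  # 3 sofa
```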

================================================
FILE: efm3d/config/taxonomy/atek_to_efm.csv
================================================
ATEK Category Name,EFM version ASE Category Name,EFM version ASE Category Id
table,table,0
sofa,sofa,1
shelves,shelf,2
chair,chair,3
bed,bed,4
floor mat,floor_mat,5
exercise_weight,exercise_weight,6
cutlery,cutlery,7
container,container,8
clock,clock,9
cart,cart,10
vase,vase,11
tent,tent,12
plant,flower_pot,13
pillow,pillow,14
mount,mount,15
lamp,lamp,16
ladder,ladder,17
fan,fan,18
cabinet,cabinet,19
jar,jar,20
picture,picture_frame,21
mirror,mirror,22
electronic_device,electronic_device,23
dresser,dresser,24
clothes_rack,clothes_rack,25
battery_charger,battery_charger,26
air_conditioner,air_conditioner,27
window,window,28

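The last two columns of `atek_to_efm.csv` should agree with `ase_sem_name_to_id.csv`, even where the ATEK name differs from its EFM counterpart (`shelves` to `shelf`, `floor mat` to `floor_mat`, `plant` to `flower_pot`, `picture` to `picture_frame`). A small consistency check, sketched here with the standard `csv` module and the hypothetical helper `find_mismatches`, catches a mapping row whose id drifts from the canonical taxonomy. The excerpts are inlined for self-containment.

```python
import csv
import io
from typing import Dict, List

ATEK_CSV = (
    "ATEK Category Name,EFM version ASE Category Name,EFM version ASE Category Id\n"
    "shelves,shelf,2\nplant,flower_pot,13\npicture,picture_frame,21\n"
)
SEM_CSV = "sem_name,sem_id\nshelf,2\nflower_pot,13\npicture_frame,21\n"


def find_mismatches(atek_text: str, sem_text: str) -> List[str]:
    """Return ATEK names whose EFM name is unknown or whose id disagrees with the taxonomy."""
    canonical: Dict[str, int] = {
        row["sem_name"]: int(row["sem_id"]) for row in csv.DictReader(io.StringIO(sem_text))
    }
    bad = []
    for row in csv.DictReader(io.StringIO(atek_text)):
        efm_name = row["EFM version ASE Category Name"]
        if canonical.get(efm_name) != int(row["EFM version ASE Category Id"]):
            bad.append(row["ATEK Category Name"])
    return bad


print(find_mismatches(ATEK_CSV, SEM_CSV))  # [] for these excerpts
print(find_mismatches(ATEK_CSV, SEM_CSV.replace("flower_pot,13", "flower_pot,14")))  # ['plant']
```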

================================================
FILE: efm3d/dataset/atek_vrs_dataset.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pyre-strict

import logging
import os
from typing import Dict, List, Optional

from atek.data_loaders.atek_wds_dataloader import select_and_remap_dict_keys
from atek.data_preprocess.atek_data_sample import AtekDataSample
from atek.data_preprocess.sample_builders.atek_data_paths_provider import (
    AtekDataPathsProvider,
)
from atek.data_preprocess.sample_builders.efm_sample_builder import EfmSampleBuilder
from atek.data_preprocess.subsampling_lib.temporal_subsampler import (
    CameraTemporalSubsampler,
)
from efm3d.dataset.efm_model_adaptor import EfmModelAdaptor
from omegaconf.omegaconf import DictConfig, OmegaConf

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class AtekRawDataloaderAsEfm:
    def __init__(
        self,
        vrs_file: str,
        mps_files: Dict[str, str],
        gt_files: Dict[str, str],
        conf: DictConfig,
        freq_hz: int,
        snippet_length_s: float,
        semidense_points_pad_to_num: int = 50000,
        max_snippets: int = 9999,
    ) -> None:
        self.max_snippets = max_snippets

        # initialize the sample builder
        self.sample_builder = EfmSampleBuilder(
            conf=conf.processors,
            vrs_file=vrs_file,
            mps_files=mps_files,
            gt_files=gt_files,
            depth_vrs_file="",
            sequence_name=os.path.basename(vrs_file),
        )

        self.subsampler = CameraTemporalSubsampler(
            vrs_file=vrs_file,
            conf=conf.camera_temporal_subsampler,
        )

        # Create an EFM model adaptor
        self.model_adaptor = EfmModelAdaptor(
            freq=freq_hz,
            snippet_length_s=snippet_length_s,
            semidense_points_pad_to_num=semidense_points_pad_to_num,
            atek_to_efm_taxonomy_mapping_file=f"{os.path.dirname(__file__)}/../config/taxonomy/atek_to_efm.csv",
        )

    def __len__(self):
        return min(self.subsampler.get_total_num_samples(), self.max_snippets)

    def get_timestamps_by_sample_index(self, index: int) -> List[int]:
        return self.subsampler.get_timestamps_by_sample_index(index)

    def get_atek_sample_at_timestamps_ns(
        self, timestamps_ns: List[int]
    ) -> Optional[AtekDataSample]:
        return self.sample_builder.get_sample_by_timestamps_ns(timestamps_ns)

    def get_model_specific_sample_at_timestamps_ns(
        self, timestamps_ns: List[int]
    ) -> Optional[Dict]:
        atek_sample = self.get_atek_sample_at_timestamps_ns(timestamps_ns)
        if atek_sample is None:
            logger.warning(
                f"Cannot retrieve valid atek sample at timestamp {timestamps_ns}"
            )
            return None

        # Flatten to dict
        atek_sample_dict = atek_sample.to_flatten_dict()

        # key remapping
        remapped_data_dict = select_and_remap_dict_keys(
            sample_dict=atek_sample_dict,
            key_mapping=self.model_adaptor.get_dict_key_mapping_all(),
        )

        # Convert the remapped ATEK sample into the EFM model format
        model_specific_sample_gen = self.model_adaptor.atek_to_efm([remapped_data_dict])

        # atek_to_efm returns a generator; take the sample it yields for our single input
        model_specific_sample = next(model_specific_sample_gen)

        return model_specific_sample

    def __getitem__(self, index: int) -> Optional[Dict]:
        if index >= len(self):
            raise IndexError(f"snippet index {index} out of range ({len(self)} snippets)")

        timestamps = self.get_timestamps_by_sample_index(index)
        maybe_sample = self.get_model_specific_sample_at_timestamps_ns(timestamps)

        return maybe_sample


def create_atek_raw_data_loader_from_vrs_path(
    vrs_path: str,
    freq_hz: int,
    snippet_length_s: float,
    stride_length_s: float,
    skip_begin_seconds: float = 0.0,
    skip_end_seconds: float = 0.0,
    semidense_points_pad_to_num: int = 50000,
    max_snippets: int = 9999,
) -> AtekRawDataloaderAsEfm:
    vrs_dir = os.path.dirname(vrs_path)
    data_path_provider = AtekDataPathsProvider(data_root_path=vrs_dir)
    atek_data_paths = data_path_provider.get_data_paths()

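    # Resolve the config relative to this file so loading does not depend on the working directory
    conf = OmegaConf.load(
        os.path.join(os.path.dirname(__file__), "../config/efm_preprocessing_conf.yaml")
    )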

    # Update snippet / stride length
    conf.camera_temporal_subsampler.main_camera_target_freq_hz = float(freq_hz)
    conf.camera_temporal_subsampler.sample_length_in_num_frames = int(
        freq_hz * snippet_length_s
    )
    conf.camera_temporal_subsampler.stride_length_in_num_frames = int(
        freq_hz * stride_length_s
    )
    conf.camera_temporal_subsampler.update(
        {
            "skip_begin_seconds": skip_begin_seconds,
            "skip_end_seconds": skip_end_seconds,
        }
    )

    data_loader = AtekRawDataloaderAsEfm(
        vrs_file=atek_data_paths["video_vrs_file"],
        mps_files={
            "mps_closedloop_traj_file": atek_data_paths["mps_closedloop_traj_file"],
            "mps_semidense_points_file": atek_data_paths["mps_semidense_points_file"],
            "mps_semidense_observations_file": atek_data_paths[
                "mps_semidense_observations_file"
            ],
        },
        gt_files={
            "obb3_file": atek_data_paths["gt_obb3_file"],
            "obb3_traj_file": atek_data_paths["gt_obb3_traj_file"],
            "obb2_file": atek_data_paths["gt_obb2_file"],
            "instance_json_file": atek_data_paths["gt_instance_json_file"],
        },
        conf=conf,
        freq_hz=freq_hz,
        snippet_length_s=snippet_length_s,
        semidense_points_pad_to_num=semidense_points_pad_to_num,
        max_snippets=max_snippets,
    )

    return data_loader

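`create_atek_raw_data_loader_from_vrs_path` converts the snippet and stride lengths from seconds to frame counts with `int(freq_hz * length_s)`, which truncates toward zero. The hypothetical helper `seconds_to_frames` below reproduces that conversion for the default values (10 Hz, 2 s snippets, 1 s stride). One caveat: because of binary floating point, a product that should be a whole number can land just below it (`100 * 0.29` evaluates to `28.999999999999996`), and `int()` then drops a frame; `round()` would avoid that when the lengths are meant to be whole frame counts.

```python
def seconds_to_frames(freq_hz: float, length_s: float, use_round: bool = False) -> int:
    """Frame count for a duration; truncates like the loader unless use_round is set."""
    frames = freq_hz * length_s
    return round(frames) if use_round else int(frames)


# Values matching efm_preprocessing_conf.yaml: 10 Hz, 2 s snippets, 1 s stride.
print(seconds_to_frames(10, 2.0), seconds_to_frames(10, 1.0))  # 20 10
# Floating-point edge case: 100 * 0.29 == 28.999999999999996
print(seconds_to_frames(100, 0.29), seconds_to_frames(100, 0.29, use_round=True))  # 28 29
```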

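`AtekRawDataloaderAsEfm` defines `__len__` and `__getitem__` but no `__iter__`, so a plain `for` loop falls back to Python's sequence protocol: it calls `__getitem__(0)`, `__getitem__(1)`, and so on until `IndexError` is raised (CPython's sequence iterator also treats `StopIteration` as the end). Because the loader returns `None` when no valid ATEK sample exists at a timestamp, callers should skip those entries. The toy class below is a hypothetical stand-in that needs no ATEK data and demonstrates both behaviours.

```python
from typing import Dict, List, Optional


class ToySnippetLoader:
    """Hypothetical stand-in: every third snippet is missing and comes back as None."""

    def __init__(self, num_snippets: int) -> None:
        self.num_snippets = num_snippets

    def __len__(self) -> int:
        return self.num_snippets

    def __getitem__(self, index: int) -> Optional[Dict[str, int]]:
        if index >= len(self):
            raise IndexError(index)  # ends a `for` loop under the sequence protocol
        return None if index % 3 == 2 else {"snippet_index": index}


def valid_snippets(loader) -> List[Dict[str, int]]:
    """Iterate via __getitem__ and drop snippets that could not be built."""
    return [s for s in loader if s is not None]


print([s["snippet_index"] for s in valid_snippets(ToySnippetLoader(7))])  # [0, 1, 3, 4, 6]
```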
================================================
FILE: efm3d/dataset/atek_wds_dataset.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Lice
SYMBOL INDEX (785 symbols across 54 files)

FILE: efm3d/aria/camera.py
  class DefaultCameraTWData (line 34) | class DefaultCameraTWData(TensorWrapper):
    method __init__ (line 37) | def __init__(self):
    method shape (line 41) | def shape(self):
  class DefaultCameraTWParam (line 45) | class DefaultCameraTWParam(TensorWrapper):
    method __init__ (line 48) | def __init__(self):
    method shape (line 52) | def shape(self):
  class DefaultCameraTWDistParam (line 56) | class DefaultCameraTWDistParam(TensorWrapper):
    method __init__ (line 59) | def __init__(self):
    method shape (line 63) | def shape(self):
  function is_fisheye624 (line 86) | def is_fisheye624(inp):
  function is_kb3 (line 99) | def is_kb3(inp):
  function is_pinhole (line 105) | def is_pinhole(inp):
  function get_aria_camera (line 111) | def get_aria_camera(params=SLAM_PARAMS, width=640, height=480, valid_rad...
  function get_pinhole_camera (line 128) | def get_pinhole_camera(params, width=640, height=480, valid_radius=None,...
  function get_base_aria_rgb_camera_full_res (line 145) | def get_base_aria_rgb_camera_full_res():
  function get_base_aria_rgb_camera (line 151) | def get_base_aria_rgb_camera():
  function get_base_aria_slam_camera (line 155) | def get_base_aria_slam_camera():
  class CameraTW (line 159) | class CameraTW(TensorWrapper):
    method __init__ (line 175) | def __init__(
    method from_parameters (line 184) | def from_parameters(
    method from_surreal (line 223) | def from_surreal(
    method size (line 318) | def size(self) -> torch.Tensor:
    method f (line 323) | def f(self) -> torch.Tensor:
    method c (line 328) | def c(self) -> torch.Tensor:
    method K (line 333) | def K(self) -> torch.Tensor:
    method K44 (line 348) | def K44(self) -> torch.Tensor:
    method gain (line 363) | def gain(self) -> torch.Tensor:
    method exposure_s (line 368) | def exposure_s(self) -> torch.Tensor:
    method valid_radius (line 373) | def valid_radius(self) -> torch.Tensor:
    method T_camera_rig (line 378) | def T_camera_rig(self) -> torch.Tensor:
    method dist (line 383) | def dist(self) -> torch.Tensor:
    method params (line 388) | def params(self) -> torch.Tensor:
    method is_fisheye624 (line 393) | def is_fisheye624(self):
    method is_kb3 (line 397) | def is_kb3(self):
    method is_linear (line 401) | def is_linear(self):
    method set_valid_radius (line 404) | def set_valid_radius(self, valid_radius: torch.Tensor):
    method set_T_camera_rig (line 407) | def set_T_camera_rig(self, T_camera_rig: PoseTW):
    method scale (line 410) | def scale(self, scales: Union[float, int, Tuple[Union[float, int]]]):
    method scale_to_size (line 430) | def scale_to_size(self, size_wh: Union[int, Tuple[int]]):
    method scale_to (line 444) | def scale_to(self, im: torch.Tensor):
    method crop (line 452) | def crop(self, left_top: Tuple[float], size: Tuple[int]):
    method in_image (line 479) | def in_image(self, p2d: torch.Tensor):
    method in_radius (line 488) | def in_radius(self, p2d: torch.Tensor):
    method in_radius_mask (line 500) | def in_radius_mask(self):
    method project (line 516) | def project(self, p3d: torch.Tensor) -> Tuple[torch.Tensor]:
    method unproject (line 549) | def unproject(self, p2d: torch.Tensor) -> Tuple[torch.Tensor]:
    method rotate_90_cw (line 580) | def rotate_90_cw(self):
    method rotate_90_ccw (line 583) | def rotate_90_ccw(self):
    method rotate_90 (line 586) | def rotate_90(self, clock_wise: bool):
    method __repr__ (line 639) | def __repr__(self):
  function grid_2d (line 643) | def grid_2d(
  function pixel_grid (line 661) | def pixel_grid(cam: CameraTW):
  function scale_image_to_cam (line 667) | def scale_image_to_cam(cams: CameraTW, ims: torch.Tensor) -> torch.Tensor:

FILE: efm3d/aria/obb.py
  class ObbTW (line 97) | class ObbTW(TensorWrapper):
    method __init__ (line 105) | def __init__(self, data: torch.Tensor = PAD_VAL * torch.ones((1, 34))):
    method from_lmc (line 112) | def from_lmc(
    method bb3_object (line 144) | def bb3_object(self) -> torch.Tensor:
    method bb3_min_object (line 149) | def bb3_min_object(self) -> torch.Tensor:
    method bb3_max_object (line 154) | def bb3_max_object(self) -> torch.Tensor:
    method bb3_center_object (line 159) | def bb3_center_object(self) -> torch.Tensor:
    method bb3_center_world (line 164) | def bb3_center_world(self) -> torch.Tensor:
    method bb3_diagonal (line 173) | def bb3_diagonal(self) -> torch.Tensor:
    method bb3_volumes (line 178) | def bb3_volumes(self) -> torch.Tensor:
    method bb2_rgb (line 184) | def bb2_rgb(self) -> torch.Tensor:
    method visible_bb3_ind (line 188) | def visible_bb3_ind(self, cam_id) -> torch.Tensor:
    method bb2_slaml (line 195) | def bb2_slaml(self) -> torch.Tensor:
    method bb2_slamr (line 200) | def bb2_slamr(self) -> torch.Tensor:
    method bb2 (line 204) | def bb2(self, cam_id) -> torch.Tensor:
    method set_bb2 (line 214) | def set_bb2(self, cam_id, bb2d, use_mask=True):
    method set_bb3_object (line 226) | def set_bb3_object(self, bb3_object, use_mask=True) -> torch.Tensor:
    method set_prob (line 233) | def set_prob(self, prob, use_mask=True):
    method T_world_object (line 241) | def T_world_object(self) -> torch.Tensor:
    method get_padding_mask (line 245) | def get_padding_mask(self) -> torch.Tensor:
    method set_T_world_object (line 249) | def set_T_world_object(self, T_world_object: PoseTW):
    method sem_id (line 256) | def sem_id(self) -> torch.Tensor:
    method set_sem_id (line 260) | def set_sem_id(self, sem_id: torch.Tensor):
    method inst_id (line 265) | def inst_id(self) -> torch.Tensor:
    method set_inst_id (line 269) | def set_inst_id(self, inst_id: torch.Tensor):
    method prob (line 274) | def prob(self) -> torch.Tensor:
    method moveable (line 279) | def moveable(self) -> torch.Tensor:
    method bb3corners_world (line 284) | def bb3corners_world(self) -> torch.Tensor:
    method bb3corners_object (line 288) | def bb3corners_object(self) -> torch.Tensor:
    method bb3edge_pts_object (line 296) | def bb3edge_pts_object(self, num_samples_per_edge: int = 10) -> torch....
    method center (line 320) | def center(self):
    method add_padding (line 355) | def add_padding(self, max_elts: int = 1000) -> "ObbTW":
    method remove_padding (line 377) | def remove_padding(self) -> List["ObbTW"]:
    method _mark_invalid (line 408) | def _mark_invalid(self, invalid_mask: torch.Tensor) -> "ObbTW":
    method _mark_invalid_ids (line 418) | def _mark_invalid_ids(self, invalid_ids: torch.Tensor) -> "ObbTW":
    method num_valid (line 427) | def num_valid(self) -> int:
    method scale_bb2 (line 446) | def scale_bb2(self, scale_rgb: float, scale_slam: float):
    method crop_bb2 (line 491) | def crop_bb2(self, left_top_rgb: Tuple[float], left_top_slam: Tuple[fl...
    method rotate_bb2_cw (line 527) | def rotate_bb2_cw(self, image_sizes: List[Tuple[int]]):
    method rectify_obb2 (line 561) | def rectify_obb2(self, fisheye_cams: List[CameraTW], pinhole_cams: Lis...
    method get_pseudo_bb2 (line 632) | def get_pseudo_bb2(
    method get_bb2_heights (line 654) | def get_bb2_heights(self, cam_id):
    method get_bb2_widths (line 661) | def get_bb2_widths(self, cam_id):
    method get_bb2_areas (line 668) | def get_bb2_areas(self, cam_id):
    method get_bb2_centers (line 675) | def get_bb2_centers(self, cam_id):
    method batch_points_inside_bb3 (line 684) | def batch_points_inside_bb3(self, pts_world: torch.Tensor) -> torch.Te...
    method points_inside_bb3 (line 696) | def points_inside_bb3(
    method _transform (line 708) | def _transform(self, T_new_world):
    method transform (line 716) | def transform(self, T_new_world):
    method _transform_object (line 724) | def _transform_object(self, T_object_new):
    method filter_by_sem_id (line 732) | def filter_by_sem_id(self, keep_sem_ids):
    method filter_by_prob (line 739) | def filter_by_prob(self, prob_thr: float):
    method filter_bb2_center_by_radius (line 745) | def filter_bb2_center_by_radius(self, calib, cam_id):
    method voxel_grid (line 757) | def voxel_grid(self, vD: int, vH: int, vW: int):
    method __repr__ (line 798) | def __repr__(self):
  function _single_transform_obbs (line 802) | def _single_transform_obbs(obbs_padded, Ts_other_world):
  function _batched_transform_obbs (line 819) | def _batched_transform_obbs(obbs_padded, Ts_other_world):
  function transform_obbs (line 832) | def transform_obbs(obbs_padded, Ts_other_world):
  function rot_obb2_cw (line 842) | def rot_obb2_cw(bb2: torch.Tensor, size: Tuple[int]):
  function project_bb3d_onto_image (line 853) | def project_bb3d_onto_image(
  function bb2d_from_project_bb3d (line 956) | def bb2d_from_project_bb3d(
  function bb2_xxyy_to_xyxy (line 1040) | def bb2_xxyy_to_xyxy(bb2s):
  function bb2_xyxy_to_xxyy (line 1051) | def bb2_xyxy_to_xxyy(bb2s):
  function bb3_xyzxyz_to_xxyyzz (line 1062) | def bb3_xyzxyz_to_xxyyzz(bb3s):
  function bb3_xyz_xyz_to_xxyyzz (line 1069) | def bb3_xyz_xyz_to_xxyyzz(bb3s_min, bb3s_max):
  function rnd_obbs (line 1076) | def rnd_obbs(N: int = 1, num_semcls: int = 10, bb3_min_diag=0.1, bb2_min...
  function obb_time_union (line 1092) | def obb_time_union(obbs, pad_size=128):
  function obb_filter_outside_volume (line 1129) | def obb_filter_outside_volume(obbs, T_ws, T_wv, voxel_extent, border=0.1):
  function tensor_linspace (line 1162) | def tensor_linspace(start, end, steps, device):
  function make_obb (line 1192) | def make_obb(sz, position, prob=1.0, roll=0.0, pitch=0.0, yaw=0.1):
  function obb_iou3d (line 1212) | def obb_iou3d(obb1: ObbTW, obb2: ObbTW, samp_per_dim=32):
  function is_point_inside_box (line 1249) | def is_point_inside_box(points: torch.Tensor, box: torch.Tensor, verbose...
  function box_planar_dir (line 1281) | def box_planar_dir(
  function get_plane_verts (line 1346) | def get_plane_verts(box: torch.Tensor) -> torch.Tensor:
  function is_inside (line 1363) | def is_inside(
  function get_plane_center_normal (line 1424) | def get_plane_center_normal(planes: torch.Tensor) -> torch.Tensor:

FILE: efm3d/aria/pose.py
  function get_T_rot_z (line 36) | def get_T_rot_z(angle: float):
  function skew_symmetric (line 47) | def skew_symmetric(v):
  function inv_skew_symmetric (line 67) | def inv_skew_symmetric(V):
  function so3exp_map (line 82) | def so3exp_map(w, eps: float = 1e-7):
  function so3log_map (line 99) | def so3log_map(R, eps: float = 1e-7):
  function interpolation_boundaries_alphas (line 118) | def interpolation_boundaries_alphas(times: torch.Tensor, interp_times: t...
  function quaternion_to_matrix (line 159) | def quaternion_to_matrix(quaternions_wxyz: torch.Tensor) -> torch.Tensor:
  class PoseTW (line 194) | class PoseTW(TensorWrapper):
    method __init__ (line 197) | def __init__(self, data: torch.Tensor = IdentityPose):
    method from_Rt (line 204) | def from_Rt(cls, R: torch.Tensor, t: torch.Tensor):
    method from_qt (line 220) | def from_qt(cls, quaternion_wxyz: torch.Tensor, t: torch.Tensor):
    method from_aa (line 240) | def from_aa(cls, aa: torch.Tensor, t: torch.Tensor):
    method from_matrix (line 255) | def from_matrix(cls, T: torch.Tensor):
    method from_matrix3x4 (line 266) | def from_matrix3x4(cls, T_3x4: torch.Tensor):
    method exp (line 277) | def exp(cls, u_omega: torch.Tensor, eps: float = 1e-7):
    method R (line 319) | def R(self) -> torch.Tensor:
    method t (line 325) | def t(self) -> torch.Tensor:
    method q (line 330) | def q(self) -> torch.Tensor:
    method q_xyzw (line 408) | def q_xyzw(self) -> torch.Tensor:
    method matrix3x4 (line 418) | def matrix3x4(self) -> torch.Tensor:
    method matrix (line 427) | def matrix(self) -> torch.Tensor:
    method to_euler (line 434) | def to_euler(self, rad=True) -> torch.Tensor:
    method to_ypr (line 451) | def to_ypr(self, rad=True) -> torch.Tensor:
    method inverse (line 462) | def inverse(self) -> "PoseTW":
    method compose (line 468) | def compose(self, other: "PoseTW") -> "PoseTW":
    method transform (line 475) | def transform(self, p3d: torch.Tensor) -> torch.Tensor:
    method batch_transform (line 488) | def batch_transform(self, p3d: torch.Tensor) -> torch.Tensor:
    method rotate (line 507) | def rotate(self, p3d: torch.Tensor) -> torch.Tensor:
    method __mul__ (line 518) | def __mul__(self, p3D: torch.Tensor) -> torch.Tensor:
    method __matmul__ (line 522) | def __matmul__(self, other: "PoseTW") -> "PoseTW":
    method numpy (line 526) | def numpy(self) -> Tuple[np.ndarray]:
    method magnitude (line 529) | def magnitude(self, deg=True, eps=0) -> Tuple[torch.Tensor]:
    method so3_geodesic (line 545) | def so3_geodesic(self, other: "PoseTW", deg=False) -> "PoseTW":
    method log (line 551) | def log(self, eps: float = 1e-6) -> torch.Tensor:
    method interpolate (line 589) | def interpolate(self, times: torch.Tensor, interp_times: torch.Tensor):
    method align (line 625) | def align(self, other, self_times=None, other_times=None):
    method fit_to_SO3 (line 689) | def fit_to_SO3(self):
    method __repr__ (line 697) | def __repr__(self):
  function interpolate_timed_poses (line 701) | def interpolate_timed_poses(
  function lower_timed_poses (line 750) | def lower_timed_poses(
  function closest_timed_poses (line 773) | def closest_timed_poses(
  function all_rot90 (line 800) | def all_rot90():
  function find_r90 (line 816) | def find_r90(Ta, Tb, R90s):
  function stereographic_unproject (line 837) | def stereographic_unproject(a, axis=None):
  function rotation_from_ortho_6d (line 857) | def rotation_from_ortho_6d(ortho6d):
  function rotation_from_ortho_5d (line 881) | def rotation_from_ortho_5d(ortho5d):
  function rotation_from_euler (line 904) | def rotation_from_euler(euler):
  function fit_to_SO3 (line 929) | def fit_to_SO3(R):

FILE: efm3d/aria/projection_utils.py
  function sign_plus (line 18) | def sign_plus(x):
  function fisheye624_project (line 29) | def fisheye624_project(xyz, params):
  function fisheye624_unproject (line 142) | def fisheye624_unproject(uv, params, max_iters: int = 5):
  function pinhole_project (line 303) | def pinhole_project(xyz, params):
  function pinhole_unproject (line 332) | def pinhole_unproject(uv, params, max_iters: int = 5):

FILE: efm3d/aria/tensor_wrapper.py
  function smart_cat (line 33) | def smart_cat(inp_arr, dim=-1):
  function smart_stack (line 46) | def smart_stack(inp_arr, dim: int = 0):
  function get_default_args (line 59) | def get_default_args(func):
  function get_nonempty_arg_names (line 68) | def get_nonempty_arg_names(func):
  function autocast (line 78) | def autocast(func):
  function autoinit (line 106) | def autoinit(func):
  function tensor_wrapper_collate (line 206) | def tensor_wrapper_collate(batch, *, collate_fn_map=None):
  function float_collate (line 211) | def float_collate(batch, *, collate_fn_map=None):
  function list_dict_collate (line 216) | def list_dict_collate(batch, *, collate_fn_map=None):
  function tensor_wrapper_collate_cat (line 231) | def tensor_wrapper_collate_cat(batch, *, collate_fn_map=None):
  function tensor_collate_cat (line 236) | def tensor_collate_cat(batch, *, collate_fn_map=None):
  function custom_collate_fn (line 254) | def custom_collate_fn(batch):
  class TensorWrapper (line 283) | class TensorWrapper:
    method __init__ (line 293) | def __init__(self, data: torch.Tensor):
    method shape (line 297) | def shape(self):
    method device (line 301) | def device(self):
    method dtype (line 305) | def dtype(self):
    method ndim (line 309) | def ndim(self):
    method dim (line 312) | def dim(self):
    method nelement (line 315) | def nelement(self):
    method numel (line 318) | def numel(self):
    method collate_fn (line 322) | def collate_fn(self):
    method is_cuda (line 326) | def is_cuda(self):
    method is_contiguous (line 330) | def is_contiguous(self):
    method requires_grad (line 334) | def requires_grad(self):
    method grad (line 338) | def grad(self):
    method grad_fn (line 342) | def grad_fn(self):
    method requires_grad_ (line 345) | def requires_grad_(self, requires_grad: bool = True):
    method __getitem__ (line 348) | def __getitem__(self, index):
    method __setitem__ (line 351) | def __setitem__(self, index, item):
    method to (line 354) | def to(self, *args, **kwargs):
    method reshape (line 357) | def reshape(self, *args, **kwargs):
    method repeat (line 360) | def repeat(self, *args, **kwargs):
    method expand (line 363) | def expand(self, *args, **kwargs):
    method clone (line 366) | def clone(self):
    method cpu (line 369) | def cpu(self):
    method cuda (line 372) | def cuda(self, gpu_id=0):
    method contiguous (line 375) | def contiguous(self):
    method pin_memory (line 378) | def pin_memory(self):
    method float (line 381) | def float(self):
    method double (line 384) | def double(self):
    method detach (line 387) | def detach(self):
    method numpy (line 390) | def numpy(self):
    method tensor (line 393) | def tensor(self):
    method tolist (line 396) | def tolist(self):
    method squeeze (line 399) | def squeeze(self, dim=None):
    method unsqueeze (line 405) | def unsqueeze(self, dim=None):
    method view (line 409) | def view(self, *shape):
    method __len__ (line 413) | def __len__(self):
    method stack (line 417) | def stack(cls, objects: List, dim=0, *, out=None):
    method cat (line 422) | def cat(cls, objects: List, dim=0, *, out=None):
    method allclose (line 427) | def allclose(
    method take_along_dim (line 440) | def take_along_dim(cls, obj, indices, dim, *, out=None):
    method flatten (line 445) | def flatten(cls, obj, start_dim=0, end_dim=-1):
    method __torch_function__ (line 450) | def __torch_function__(self, func, types, args=(), kwargs=None):

FILE: efm3d/dataset/atek_vrs_dataset.py
  class AtekRawDataloaderAsEfm (line 37) | class AtekRawDataloaderAsEfm:
    method __init__ (line 38) | def __init__(
    method __len__ (line 74) | def __len__(self):
    method get_timestamps_by_sample_index (line 77) | def get_timestamps_by_sample_index(self, index: int) -> List[int]:
    method get_atek_sample_at_timestamps_ns (line 80) | def get_atek_sample_at_timestamps_ns(
    method get_model_specific_sample_at_timestamps_ns (line 85) | def get_model_specific_sample_at_timestamps_ns(
    method __getitem__ (line 112) | def __getitem__(self, index):
  function create_atek_raw_data_loader_from_vrs_path (line 122) | def create_atek_raw_data_loader_from_vrs_path(

FILE: efm3d/dataset/atek_wds_dataset.py
  function batchify (line 37) | def batchify(datum, device=None):
  function unbatchify (line 49) | def unbatchify(datum):
  class AtekWdsStreamDataset (line 57) | class AtekWdsStreamDataset:
    method __init__ (line 60) | def __init__(
    method __len__ (line 113) | def __len__(self):
    method sample_snippet_ (line 116) | def sample_snippet_(self, snippet, start, end):
    method __iter__ (line 130) | def __iter__(self):
    method if_get_next_ (line 133) | def if_get_next_(self):
    method __next__ (line 142) | def __next__(self):

FILE: efm3d/dataset/augmentation.py
  class ColorJitter (line 36) | class ColorJitter:
    method __init__ (line 41) | def __init__(
    method rnd_sharpen (line 68) | def rnd_sharpen(self, im):
    method apply (line 73) | def apply(self, im):
    method __call__ (line 78) | def __call__(self, batch: Dict):
  class PointDrop (line 90) | class PointDrop:
    method __init__ (line 101) | def __init__(
    method __call__ (line 118) | def __call__(self, batch: Dict):
  class PointDropSimple (line 166) | class PointDropSimple:
    method __init__ (line 171) | def __init__(
    method __call__ (line 178) | def __call__(self, batch: Dict):
  class PointJitter (line 195) | class PointJitter:
    method __init__ (line 200) | def __init__(
    method __call__ (line 213) | def __call__(self, batch: Dict):

FILE: efm3d/dataset/efm_model_adaptor.py
  function get_local_pose_helper (line 48) | def get_local_pose_helper(snippet_origin_time_s, batch, local_coordinate):
  function run_local_cosy (line 86) | def run_local_cosy(
  function get_snippet_cosy_from_rig (line 148) | def get_snippet_cosy_from_rig(
  function get_snippet_cosy_from_cam_rgb (line 166) | def get_snippet_cosy_from_cam_rgb(
  class EfmModelAdaptor (line 195) | class EfmModelAdaptor:
    method __init__ (line 205) | def __init__(
    method get_dict_key_mapping_for_camera (line 228) | def get_dict_key_mapping_for_camera(atek_camera_label: str, efm_camera...
    method get_dict_key_mapping_all (line 242) | def get_dict_key_mapping_all():
    method _get_pose_to_align_gravity (line 272) | def _get_pose_to_align_gravity(self, sample_dict: Dict) -> Optional[Po...
    method _load_taxonomy_mapping_file (line 297) | def _load_taxonomy_mapping_file(self, filename: str) -> Dict:
    method _fill_dict_with_freq (line 316) | def _fill_dict_with_freq(self, sample_dict: Dict) -> Dict:
    method _convert_to_batched_camera_tw (line 333) | def _convert_to_batched_camera_tw(
    method _update_efm_obb_gt (line 370) | def _update_efm_obb_gt(self, atek_gt_dict: Dict) -> Dict:
    method _pad_semidense_data (line 492) | def _pad_semidense_data(self, sample_dict: Dict) -> Dict:
    method _pad_over_frames (line 520) | def _pad_over_frames(self, sample_dict: Dict, fields_to_pad: List[str]...
    method _split_pose_over_snippet (line 533) | def _split_pose_over_snippet(self, sample_dict: Dict) -> Dict:
    method _split_timestamps_over_snippet (line 579) | def _split_timestamps_over_snippet(self, sample_dict: Dict) -> Dict:
    method atek_to_efm (line 602) | def atek_to_efm(self, data, train=False):
  function load_atek_wds_dataset_as_efm (line 706) | def load_atek_wds_dataset_as_efm(
  function load_atek_wds_dataset_as_efm_train (line 733) | def load_atek_wds_dataset_as_efm_train(

FILE: efm3d/dataset/vrs_dataset.py
  function is_adt (line 74) | def is_adt(vrs_path):
  function is_aeo (line 84) | def is_aeo(vrs_path):
  function get_transform_to_vio_gravity_convention (line 88) | def get_transform_to_vio_gravity_convention(gravity_direction: np.array):
  function compute_time_intersection (line 121) | def compute_time_intersection(time_lists):
  function preprocess_inference (line 137) | def preprocess_inference(batch):
  function preprocess (line 157) | def preprocess(
  function tensor_unify (line 205) | def tensor_unify(tensor, dim_size: int, dim: int = 0):
  function run_sensor_poses (line 254) | def run_sensor_poses(batch, num_notified=-1, max_notified=10):
  class VrsSequenceDataset (line 283) | class VrsSequenceDataset(Dataset):
    method __init__ (line 284) | def __init__(
    method load_objects (line 366) | def load_objects(self):
    method load_semidense (line 439) | def load_semidense(self, vrs_path, max_inv_depth_std=0.005, max_depth_...
    method load_poses (line 532) | def load_poses(self, vrs_path, subsample):
    method load_snippet_pose (line 571) | def load_snippet_pose(self, start, end):
    method load_snippet_semidense (line 583) | def load_snippet_semidense(self, start, end, max_size=20000):
    method load_snippet_objects (line 606) | def load_snippet_objects(self, start, end):
    method __len__ (line 696) | def __len__(self):
    method __getitem__ (line 699) | def __getitem__(self, index):

FILE: efm3d/dataset/wds_dataset.py
  function convert_to_aria_multimodal_dataset (line 36) | def convert_to_aria_multimodal_dataset(sample):
  function batchify (line 122) | def batchify(datum, device=None):
  function unbatchify (line 134) | def unbatchify(datum):
  function get_tar_sample_num (line 142) | def get_tar_sample_num(tar_file):
  class WdsStreamDataset (line 150) | class WdsStreamDataset:
    method __init__ (line 153) | def __init__(
    method __len__ (line 201) | def __len__(self):
    method sample_snippet_ (line 204) | def sample_snippet_(self, snippet, start, end):
    method __iter__ (line 218) | def __iter__(self):
    method if_get_next_ (line 221) | def if_get_next_(self):
    method __next__ (line 230) | def __next__(self):

FILE: efm3d/inference/eval.py
  function check_sem_id_conflict (line 29) | def check_sem_id_conflict(ids_pred, ids_gt):
  function evaluate_obb_csv (line 42) | def evaluate_obb_csv(
  function obb_eval_dataset (line 140) | def obb_eval_dataset(input_folder: str, iou: float = 0.2):
  function main (line 227) | def main():

FILE: efm3d/inference/fuse.py
  function set_boundary_value (line 35) | def set_boundary_value(x, val, thickness):
  function load_tensor (line 47) | def load_tensor(fname, device):
  class VolumeFusion (line 54) | class VolumeFusion:
    method __init__ (line 55) | def __init__(
    method set_boundary_mask (line 98) | def set_boundary_mask(self, mask):
    method fuse (line 112) | def fuse(
    method get_volume (line 168) | def get_volume(self, reshape=True):
    method get_weights (line 174) | def get_weights(self, reshape=True):
    method get_mask (line 180) | def get_mask(self, reshape=True):
    method get_trimesh (line 187) | def get_trimesh(self, iso_level=0.5):
  class VolumetricFusion (line 202) | class VolumetricFusion:
    method __init__ (line 203) | def __init__(
    method reinit (line 248) | def reinit(self):
    method init_from_range (line 254) | def init_from_range(self, xyz_min, xyz_max):
    method get_trimesh (line 283) | def get_trimesh(self):
    method run_step (line 286) | def run_step(self, i):
    method run (line 303) | def run(self):

FILE: efm3d/inference/model.py
  class EfmInference (line 41) | class EfmInference:
    method __init__ (line 42) | def __init__(self, streamer, model, output_dir, device, zip, obb_only=...
    method __del__ (line 63) | def __del__(self):
    method save_tensor (line 75) | def save_tensor(self, tensor, key, idx=None, output_dir=""):
    method save_output (line 82) | def save_output(self, data, idx, output_dir):
    method run (line 148) | def run(self):

FILE: efm3d/inference/pipeline.py
  function get_gt_mesh_ply (line 33) | def get_gt_mesh_ply(data_path):
  function compute_avg_metrics (line 51) | def compute_avg_metrics(paths):
  function create_streamer (line 71) | def create_streamer(
  function create_output_dir (line 122) | def create_output_dir(output_dir, model_ckpt, data_path):
  function run_one (line 137) | def run_one(

FILE: efm3d/inference/track.py
  function track_obbs (line 27) | def track_obbs(input_path, prob_inst_thr=0.3, prob_assoc_thr=0.25):
  function main (line 83) | def main():

FILE: efm3d/inference/viz.py
  function find_nearest (line 52) | def find_nearest(array, value):
  function fill_obbs_to_snippet (line 66) | def fill_obbs_to_snippet(obbs, rgb_ts, T_ws):
  function compose_views (line 86) | def compose_views(view_dict, keys, vertical=True):
  function draw_scene_with_mesh_and_obbs (line 104) | def draw_scene_with_mesh_and_obbs(
  function render_views (line 141) | def render_views(snippet, h, w, pred_sem_ids_to_names, gt_sem_ids_to_nam...
  function generate_video (line 230) | def generate_video(

FILE: efm3d/model/cnn.py
  function cnn_weight_initialization (line 23) | def cnn_weight_initialization(modules):
  class GELU (line 37) | class GELU(nn.Module):
    method forward (line 38) | def forward(self, x):
  class LayerNorm2d (line 42) | class LayerNorm2d(nn.LayerNorm):
    method __init__ (line 47) | def __init__(self, num_channels, eps=1e-6, affine=True):
    method forward (line 50) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class UpsampleCNN (line 60) | class UpsampleCNN(nn.Module):
    method __init__ (line 61) | def __init__(
    method forward (line 120) | def forward(self, x, force_hw=None):
  class LayerNorm3d (line 159) | class LayerNorm3d(nn.LayerNorm):
    method __init__ (line 164) | def __init__(self, num_channels, eps=1e-6, affine=True):
    method forward (line 167) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class UpConv3d (line 177) | class UpConv3d(torch.nn.Module):
    method __init__ (line 178) | def __init__(self, dim_in, dim_out):
    method forward (line 191) | def forward(self, x_up):
  class FpnUpConv3d (line 197) | class FpnUpConv3d(torch.nn.Module):
    method __init__ (line 198) | def __init__(self, dim_in, dim_out):
    method forward (line 212) | def forward(self, x_up, x_lat):
  class InvBottleNeck3d (line 220) | class InvBottleNeck3d(torch.nn.Module):
    method __init__ (line 221) | def __init__(self, dim_in, dim_out, stride: int = 1, expansion: float ...
    method forward (line 238) | def forward(self, x):
  class InvResnetBlock3d (line 247) | class InvResnetBlock3d(torch.nn.Module):
    method __init__ (line 248) | def __init__(
    method forward (line 260) | def forward(self, x):
  class InvResnetFpn3d (line 266) | class InvResnetFpn3d(torch.nn.Module):
    method __init__ (line 267) | def __init__(self, dims, num_bottles, strides, expansions, freeze=False):
    method forward (line 297) | def forward(self, x):
  class VolumeCNN (line 312) | class VolumeCNN(nn.Module):
    method __init__ (line 318) | def __init__(self, hidden_dims, conv3=nn.Conv3d, freeze=False):
    method forward (line 346) | def forward(self, x):
  class VolumeCNNHead (line 362) | class VolumeCNNHead(nn.Module):
    method __init__ (line 363) | def __init__(
    method forward (line 419) | def forward(self, x):
  class ResidualConvUnit3d (line 434) | class ResidualConvUnit3d(nn.Module):
    method __init__ (line 437) | def __init__(self, features, kernel_size):
    method forward (line 448) | def forward(self, x):
  class FeatureFusionBlock3d (line 452) | class FeatureFusionBlock3d(nn.Module):
    method __init__ (line 455) | def __init__(self, features, kernel_size, with_skip=True):
    method forward (line 463) | def forward(self, x, skip_x=None):
  class VolumeResnet (line 472) | class VolumeResnet(nn.Module):
    method __init__ (line 473) | def __init__(self, hidden_dims, conv3=nn.Conv3d, freeze=False):
    method forward (line 507) | def forward(self, x):

FILE: efm3d/model/dinov2_utils.py
  class Attention (line 54) | class Attention(nn.Module):
    method __init__ (line 55) | def __init__(
    method forward (line 74) | def forward(self, x: Tensor) -> Tensor:
  class MemEffAttention (line 94) | class MemEffAttention(Attention):
    method forward (line 95) | def forward(self, x: Tensor, attn_bias=None) -> Tensor:
  class CrossAttention (line 113) | class CrossAttention(nn.Module):
    method __init__ (line 114) | def __init__(
    method forward (line 134) | def forward(self, x_kv: Tensor, x_q: Tensor) -> Tensor:
  class MemEffCrossAttention (line 160) | class MemEffCrossAttention(CrossAttention):
    method forward (line 161) | def forward(self, x_kv: Tensor, x_q: Tensor, attn_bias=None) -> Tensor:
  class Mlp (line 181) | class Mlp(nn.Module):
    method __init__ (line 182) | def __init__(
    method forward (line 199) | def forward(self, x: Tensor) -> Tensor:
  class Block (line 208) | class Block(nn.Module):
    method __init__ (line 209) | def __init__(
    method forward (line 257) | def forward(self, x: Tensor) -> Tensor:
  class CrossBlock (line 285) | class CrossBlock(nn.Module):
    method __init__ (line 286) | def __init__(
    method forward (line 335) | def forward(self, x_kv: Tensor, x_q: Tensor) -> Tensor:
  function drop_add_residual_stochastic_depth (line 351) | def drop_add_residual_stochastic_depth(
  function get_branges_scales (line 377) | def get_branges_scales(x, sample_drop_ratio=0.0):
  function add_residual (line 385) | def add_residual(x, brange, residual, residual_scale_factor, scaling_vec...
  function get_attn_bias_and_cat (line 406) | def get_attn_bias_and_cat(x_list, branges=None):
  function drop_add_residual_stochastic_depth_list (line 436) | def drop_add_residual_stochastic_depth_list(
  class NestedTensorBlock (line 467) | class NestedTensorBlock(Block):
    method forward_nested (line 468) | def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
    method forward (line 512) | def forward(self, x_or_x_list):
  class DINOHead (line 524) | class DINOHead(nn.Module):
    method __init__ (line 525) | def __init__(
    method _init_weights (line 549) | def _init_weights(self, m):
    method forward (line 555) | def forward(self, x):
  function _build_mlp (line 563) | def _build_mlp(
  function drop_path (line 582) | def drop_path(x, drop_prob: float = 0.0, training: bool = False):
  class DropPath (line 596) | class DropPath(nn.Module):
    method __init__ (line 599) | def __init__(self, drop_prob=None):
    method forward (line 603) | def forward(self, x):
  class LayerScale (line 607) | class LayerScale(nn.Module):
    method __init__ (line 608) | def __init__(
    method forward (line 618) | def forward(self, x: Tensor) -> Tensor:
  class Mlp (line 622) | class Mlp(nn.Module):
    method __init__ (line 623) | def __init__(
    method forward (line 640) | def forward(self, x: Tensor) -> Tensor:
  function make_2tuple (line 649) | def make_2tuple(x):
  class PatchEmbed (line 658) | class PatchEmbed(nn.Module):
    method __init__ (line 670) | def __init__(
    method forward (line 703) | def forward(self, x: Tensor) -> Tensor:
    method flops (line 722) | def flops(self) -> float:
  class SwiGLUFFN (line 736) | class SwiGLUFFN(nn.Module):
    method __init__ (line 737) | def __init__(
    method forward (line 752) | def forward(self, x: Tensor) -> Tensor:
  class SwiGLUFFNFused (line 768) | class SwiGLUFFNFused(SwiGLU):
    method __init__ (line 769) | def __init__(
  function named_apply (line 789) | def named_apply(
  class BlockChunk (line 808) | class BlockChunk(nn.ModuleList):
    method forward (line 809) | def forward(self, x):
  class DinoVisionTransformer (line 815) | class DinoVisionTransformer(nn.Module):
    method __init__ (line 816) | def __init__(
    method init_weights (line 958) | def init_weights(self):
    method interpolate_pos_encoding (line 965) | def interpolate_pos_encoding(self, x, w, h):
    method prepare_tokens_with_masks (line 999) | def prepare_tokens_with_masks(self, x, masks=None):
    method forward_features_list (line 1022) | def forward_features_list(self, x_list, masks_list):
    method forward_features (line 1045) | def forward_features(self, x, masks=None):
    method forward_features_multi (line 1063) | def forward_features_multi(self, x, masks=None):
    method _get_intermediate_layers_not_chunked (line 1098) | def _get_intermediate_layers_not_chunked(self, x, n=1):
    method _get_intermediate_layers_chunked (line 1114) | def _get_intermediate_layers_chunked(self, x, n=1):
    method get_intermediate_layers (line 1132) | def get_intermediate_layers(
    method forward (line 1160) | def forward(self, *args, is_training=False, **kwargs):
  function init_weights_vit_timm (line 1168) | def init_weights_vit_timm(module: nn.Module, name: str = ""):
  function vit_small (line 1176) | def vit_small(patch_size, **kwargs):
  function vit_small_reg (line 1190) | def vit_small_reg(patch_size, **kwargs):
  function vit_base (line 1207) | def vit_base(patch_size, **kwargs):
  function vit_base_reg (line 1220) | def vit_base_reg(patch_size, **kwargs):
  function vit_large (line 1236) | def vit_large(patch_size, **kwargs):
  function vit_large_reg (line 1249) | def vit_large_reg(patch_size, **kwargs):
  function vit_giant2 (line 1265) | def vit_giant2(patch_size, **kwargs):
  class DinoV2Wrapper (line 1322) | class DinoV2Wrapper(torch.nn.Module):
    method __init__ (line 1327) | def __init__(
    method forward (line 1376) | def forward(self, img):

FILE: efm3d/model/dpt.py
  class ResidualConvUnit (line 22) | class ResidualConvUnit(nn.Module):
    method __init__ (line 25) | def __init__(self, features, kernel_size):
    method forward (line 36) | def forward(self, x):
  class FeatureFusionBlock (line 40) | class FeatureFusionBlock(nn.Module):
    method __init__ (line 43) | def __init__(self, features, kernel_size, with_skip=True):
    method forward (line 51) | def forward(self, x, skip_x=None):
  class Interpolate (line 63) | class Interpolate(nn.Module):
    method __init__ (line 68) | def __init__(self, scale_factor, mode, align_corners=False):
    method forward (line 82) | def forward(self, x):
  class DPTOri (line 101) | class DPTOri(nn.Module):
    method __init__ (line 106) | def __init__(self, input_dim, hidden_dim=256, output_dim=256, depth=Fa...
    method forward (line 163) | def forward(self, feats):

FILE: efm3d/model/evl.py
  class EVL (line 35) | class EVL(torch.nn.Module):
    method __init__ (line 36) | def __init__(
    method post_process (line 137) | def post_process(self, batch, out):
    method forward (line 172) | def forward(self, batch, obb_only=False):

FILE: efm3d/model/evl_train.py
  class EVLTrain (line 59) | class EVLTrain(EVL):
    method __init__ (line 60) | def __init__(
    method compute_losses (line 82) | def compute_losses(self, outputs, batch):
    method render2d (line 122) | def render2d(self, imgs, obbs, Ts_wr, cams):
    method log_single_obb (line 141) | def log_single_obb(self, batch, outputs, batch_idx):
    method log_single (line 242) | def log_single(self, batch, outputs, batch_idx):
    method render3d_mesh (line 368) | def render3d_mesh(
    method render3d_points (line 457) | def render3d_points(
    method render3d_obb (line 543) | def render3d_obb(
    method render3d_occ (line 611) | def render3d_occ(
    method get_log_images (line 689) | def get_log_images(self, batch, outputs):
    method reset_metrics (line 697) | def reset_metrics(self):
    method update_metrics (line 710) | def update_metrics(self, outputs, batch):
    method compute_metrics (line 732) | def compute_metrics(self):

FILE: efm3d/model/image_tokenizer.py
  class ImageToDinoV2Tokens (line 29) | class ImageToDinoV2Tokens(torch.nn.Module):
    method __init__ (line 34) | def __init__(
    method feat_dim (line 76) | def feat_dim(self):
    method patch_size (line 79) | def patch_size(self):
    method post_process (line 82) | def post_process(self, feats, B, T, out_size=None):
    method forward_resize (line 105) | def forward_resize(self, img: torch.Tensor) -> torch.Tensor:
    method forward (line 123) | def forward(self, img: torch.Tensor) -> torch.Tensor:

FILE: efm3d/model/lifter.py
  class VideoBackbone3d (line 44) | class VideoBackbone3d(torch.nn.Module, ABC):
    method __init__ (line 49) | def __init__(
    method feat_dim (line 61) | def feat_dim(self):
    method forward_impl (line 64) | def forward_impl(self, batch):
    method forward (line 67) | def forward(self, batch):
  class Lifter (line 86) | class Lifter(VideoBackbone3d):
    method __init__ (line 91) | def __init__(
    method output_dim (line 165) | def output_dim(self):
    method get_freespace_world (line 178) | def get_freespace_world(self, batch, batch_idx, T_wv, vW, vH, vD, S=1):
    method get_points_world (line 230) | def get_points_world(self, batch, batch_idx, keep_T=False):
    method get_freespace_counts (line 260) | def get_freespace_counts(
    method get_points_counts (line 295) | def get_points_counts(
    method get_voxelgrid_pose (line 343) | def get_voxelgrid_pose(self, cams, T_ws, Ts_sr):
    method lift (line 360) | def lift(self, feats2d, vox_w, cam, Ts_wr, vD, vH, vW):
    method aggregate (line 377) | def aggregate(self, vox_feats, vox_valid):
    method lift_aggregate_centers (line 396) | def lift_aggregate_centers(self, batch, feats2d, vox_w, Ts_wr, T_wv=No...
    method forward (line 458) | def forward(self, batch):

FILE: efm3d/model/video_backbone.py
  class VideoBackbone (line 31) | class VideoBackbone(torch.nn.Module, ABC):
    method __init__ (line 37) | def __init__(
    method feat_dim (line 76) | def feat_dim(self):
    method feat_dim (line 80) | def feat_dim(self, _feat_dim: int):
    method patch_size (line 84) | def patch_size(self):
    method forward_impl (line 87) | def forward_impl(self, img, stream) -> Dict[str, torch.Tensor]:
    method forward (line 97) | def forward(self, batch):
  class VideoBackboneDinov2 (line 129) | class VideoBackboneDinov2(VideoBackbone):
    method __init__ (line 134) | def __init__(
    method patch_size (line 161) | def patch_size(self):
    method forward_impl (line 164) | def forward_impl(self, img, stream):

FILE: efm3d/thirdparty/mmdetection3d/cuda/cuda_utils.h
  function opt_n_thread (line 15) | inline int opt_n_thread(int work_size) {
  function dim3 (line 20) | inline dim3 opt_block_config(int x, int y) {

FILE: efm3d/thirdparty/mmdetection3d/cuda/iou3d.cpp
  function gpuAssert (line 27) | inline void
  function boxes_overlap_bev_gpu (line 62) | int boxes_overlap_bev_gpu(
  function boxes_iou_bev_gpu (line 87) | int boxes_iou_bev_gpu(
  function nms_gpu (line 111) | int nms_gpu(
  function nms_normal_gpu (line 171) | int nms_normal_gpu(

FILE: efm3d/thirdparty/mmdetection3d/cuda/sort_vert.cpp
  function sort_vertices (line 20) | at::Tensor
  function PYBIND11_MODULE (line 54) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

FILE: efm3d/thirdparty/mmdetection3d/iou3d.py
  class SortVertices (line 17) | class SortVertices(Function):
    method forward (line 19) | def forward(ctx, vertices, mask, num_valid):
    method backward (line 25) | def backward(ctx, gradout):
  function box_intersection (line 29) | def box_intersection(corners1: Tensor, corners2: Tensor) -> Tuple[Tensor...
  function box1_in_box2 (line 70) | def box1_in_box2(corners1: Tensor, corners2: Tensor) -> Tensor:
  function box_in_box (line 101) | def box_in_box(corners1: Tensor, corners2: Tensor) -> Tuple[Tensor, Tens...
  function build_vertices (line 118) | def build_vertices(
  function sort_indices (line 153) | def sort_indices(vertices: Tensor, mask: Tensor) -> Tensor:
  function calculate_area (line 180) | def calculate_area(idx_sorted: Tensor, vertices: Tensor) -> Tuple[Tensor...
  function oriented_box_intersection_2d (line 203) | def oriented_box_intersection_2d(
  function box2corners (line 226) | def box2corners(box: Tensor) -> Tensor:
  function diff_iou_rotated_2d (line 254) | def diff_iou_rotated_2d(box1: Tensor, box2: Tensor) -> Tensor:
  function diff_iou_rotated_3d (line 274) | def diff_iou_rotated_3d(box3d1: Tensor, box3d2: Tensor) -> Tensor:
  function rotated_iou_3d_loss (line 301) | def rotated_iou_3d_loss(pred, target):
  class RotatedIoU3DLoss (line 318) | class RotatedIoU3DLoss(torch.nn.Module):
    method __init__ (line 325) | def __init__(self, loss_weight=1.0):
    method forward (line 329) | def forward(
  function boxes_iou_bev (line 352) | def boxes_iou_bev(boxes_a, boxes_b):
  function nms_gpu (line 369) | def nms_gpu(boxes, scores, thresh, pre_maxsize=None, post_max_size=None):
  function nms_normal_gpu (line 397) | def nms_normal_gpu(boxes, scores, thresh):

FILE: efm3d/utils/common.py
  function sample_nearest (line 19) | def sample_nearest(value_a, value_b, array_b):
  function find_nearest (line 27) | def find_nearest(array, value, return_index=False):

FILE: efm3d/utils/depth.py
  function dist_im_to_point_cloud_im (line 19) | def dist_im_to_point_cloud_im(dist_m, cams):

FILE: efm3d/utils/detection_utils.py
  function norm2ind (line 22) | def norm2ind(norm_xyz, vD, vH, vW):
  function ind2norm (line 49) | def ind2norm(inds_dhw, vD, vH, vW):
  function normalize_coord3d (line 62) | def normalize_coord3d(xyz, extent):
  function unnormalize_coord3d (line 74) | def unnormalize_coord3d(xyz_n, extent):
  function create_heatmap_gt (line 86) | def create_heatmap_gt(mu_xy, H, W, valid=None):
  function simple_nms (line 126) | def simple_nms(scores, nms_radius: int):
  function simple_nms3d (line 146) | def simple_nms3d(scores, nms_radius: int):
  function heatmap2obb (line 166) | def heatmap2obb(scores, threshold=0.3, size=20, max_elts=1000):
  function compute_focal_loss (line 192) | def compute_focal_loss(pred, gt, focal_gamma=2, focal_alpha=0.25):
  function compute_chamfer_loss (line 220) | def compute_chamfer_loss(vals, target):
  function obb2voxel (line 232) | def obb2voxel(obb_v, vD, vH, vW, voxel_extent, num_class, splat_sigma=2):
  function voxel2obb (line 370) | def voxel2obb(

FILE: efm3d/utils/evl_loss.py
  function get_gt_obbs (line 34) | def get_gt_obbs(batch, voxel_extent, T_wv=None):
  function obbs_to_7d (line 58) | def obbs_to_7d(obbs):
  function iou_3d_loss (line 74) | def iou_3d_loss(obbs_pr, obbs_gt, cent_pr, cent_gt, valid_gt):
  function compute_obb_losses (line 97) | def compute_obb_losses(
  function compute_occ_losses (line 195) | def compute_occ_losses(

FILE: efm3d/utils/file_utils.py
  function load_gt_calibration (line 38) | def load_gt_calibration(
  function get_image_info (line 102) | def get_image_info(image_reader: SyncVRSReader) -> Tuple[Dict, Dict]:
  function load_factory_calib (line 124) | def load_factory_calib(
  function load_2d_bounding_boxes (line 168) | def load_2d_bounding_boxes(bb2d_path, time_in_secs=False):
  function load_2d_bounding_boxes_adt (line 224) | def load_2d_bounding_boxes_adt(bb2d_path):
  function remove_invalid_2d_bbs (line 278) | def remove_invalid_2d_bbs(timed_bb2s, filter_bb2_area=-1):
  function load_instances (line 296) | def load_instances(instances_path):
  function load_instances_adt (line 314) | def load_instances_adt(instances_path):
  function load_3d_bounding_box_transforms (line 331) | def load_3d_bounding_box_transforms(scene_path, time_in_secs=False, load...
  function load_3d_bounding_box_local_extents (line 377) | def load_3d_bounding_box_local_extents(bb3d_path, load_torch=False):
  function load_obbs_gt (line 408) | def load_obbs_gt(
  function load_trajectory_adt (line 532) | def load_trajectory_adt(
  function load_trajectory_aeo (line 585) | def load_trajectory_aeo(
  function load_trajectory (line 652) | def load_trajectory(
  function parse_global_name_to_id_csv (line 714) | def parse_global_name_to_id_csv(csv_path: str, verbose: bool = True) -> ...
  function exists_nonzero_path (line 741) | def exists_nonzero_path(path: Union[str, list]) -> Optional[str]:
  function get_timestamp_list_ns (line 767) | def get_timestamp_list_ns(reader, stream_id=None):
  function sample_times (line 780) | def sample_times(time_list: List, start_time: int, end_time: int) -> Tup...
  function sample_from_range (line 819) | def sample_from_range(
  function read_image_from_vrs (line 881) | def read_image_from_vrs(
  function read_image_snippet_from_vrs (line 945) | def read_image_snippet_from_vrs(
  function load_global_points_csv (line 1007) | def load_global_points_csv(
  function load_semidense_observations (line 1066) | def load_semidense_observations(path: str):

FILE: efm3d/utils/gravity.py
  function get_transform_to_vio_gravity_convention (line 25) | def get_transform_to_vio_gravity_convention(gravity_direction: np.array):
  function correct_adt_mesh_gravity (line 58) | def correct_adt_mesh_gravity(mesh):
  function reject_vector_a_from_b (line 69) | def reject_vector_a_from_b(a, b):
  function gravity_align_T_world_cam (line 79) | def gravity_align_T_world_cam(

FILE: efm3d/utils/image.py
  function string2color (line 34) | def string2color(string):
  function normalize (line 50) | def normalize(img, robust=0.0, eps=1e-6):
  function put_text (line 73) | def put_text(
  function rotate_image90 (line 123) | def rotate_image90(image: np.ndarray, k: int = 3):
  function smart_resize (line 134) | def smart_resize(
  function torch2cv2 (line 174) | def torch2cv2(
  function numpy2mp4 (line 243) | def numpy2mp4(imgs, output_path, fps=10):

FILE: efm3d/utils/image_sampling.py
  function compute_factor (line 21) | def compute_factor(size):
  function convert_pixel_to_coordinates (line 25) | def convert_pixel_to_coordinates(coordinates, factor):
  function normalize_keypoints (line 29) | def normalize_keypoints(kpts, height, width):
  function sample_images (line 39) | def sample_images(

FILE: efm3d/utils/marching_cubes.py
  function marching_cubes_scaled (line 24) | def marching_cubes_scaled(values, isolevel, voxel_extent, voxel_mask):

FILE: efm3d/utils/mesh_utils.py
  function point_to_closest_vertex_dist (line 25) | def point_to_closest_vertex_dist(pts, verts, tris):
  function point_to_closest_tri_dist (line 45) | def point_to_closest_tri_dist(pts, verts, tris):
  function compute_pts_to_mesh_dist (line 95) | def compute_pts_to_mesh_dist(pts, faces, verts, step):
  function eval_mesh_to_mesh (line 118) | def eval_mesh_to_mesh(

FILE: efm3d/utils/obb_csv_writer.py
  class ObbCsvReader (line 25) | class ObbCsvReader:
    method __init__ (line 26) | def __init__(self, file_name):
    method parse_row (line 37) | def parse_row(self, row):
    method __iter__ (line 108) | def __iter__(self):
    method __next__ (line 111) | def __next__(self):
    method obbs (line 132) | def obbs(self):
  class ObbCsvWriter (line 143) | class ObbCsvWriter:
    method __init__ (line 144) | def __init__(self, file_name=""):
    method write_rows (line 158) | def write_rows(self):
    method write (line 164) | def write(
    method flush (line 206) | def flush(self):
    method __del__ (line 209) | def __del__(self):

FILE: efm3d/utils/obb_io.py
  function bb2extent (line 22) | def bb2extent(bb):
  function extent2bb (line 35) | def extent2bb(extent):
  function get_all_Ts_world_object_for_time (line 80) | def get_all_Ts_world_object_for_time(
  function get_inst_id_in_camera (line 126) | def get_inst_id_in_camera(
  function get_instance_id_in_frameset (line 152) | def get_instance_id_in_frameset(
  function get_bb2s_for_instances (line 206) | def get_bb2s_for_instances(obs, time, inst_ids, cam_names, cam_scales=No...
  function next_obb_observations (line 243) | def next_obb_observations(

FILE: efm3d/utils/obb_matchers.py
  class HungarianMatcher2d3d (line 29) | class HungarianMatcher2d3d(torch.nn.Module):
    method __init__ (line 36) | def __init__(
    method forward_obbs (line 65) | def forward_obbs(
    method forward (line 102) | def forward(

FILE: efm3d/utils/obb_metrics.py
  class ObbMetrics (line 33) | class ObbMetrics(torch.nn.Module):
    method __init__ (line 39) | def __init__(
    method update (line 118) | def update(self, prediction: ObbTW, target: ObbTW, cam: Optional[Camer...
    method forward (line 151) | def forward(self, prediction: ObbTW, target: ObbTW):
    method update_3d (line 155) | def update_3d(
    method update_2d (line 184) | def update_2d(
    method update_2d_instances (line 211) | def update_2d_instances(
    method compute (line 227) | def compute(self):
    method reset (line 251) | def reset(self):

FILE: efm3d/utils/obb_trackers.py
  function nms_3d (line 32) | def nms_3d(
  function nms_2d (line 97) | def nms_2d(obbs, nms_iou2_thr: float, verbose: bool = False):
  class ObbTracker (line 128) | class ObbTracker:
    method __init__ (line 134) | def __init__(
    method reset (line 226) | def reset(self):
    method set_hz (line 234) | def set_hz(self, hz: float):
    method obbs_world (line 239) | def obbs_world(self):
    method track (line 294) | def track(
    method update_last_obs_time (line 514) | def update_last_obs_time(self, cam, T_world_rig):
    method add_new_obbs (line 534) | def add_new_obbs(self, new_obbs, new_probs_full):
    method nms_3d (line 567) | def nms_3d(self, obbs):
    method nms_2d (line 571) | def nms_2d(self, obbs):
    method set_2d_bbs (line 574) | def set_2d_bbs(self, obbs_w: ObbTW, cam: CameraTW, T_world_rig: PoseTW):
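
The `nms_3d` / `nms_2d` entries above suggest greedy non-maximum suppression controlled by an IoU threshold (`nms_iou2_thr`). The sketch below is an illustration only. It runs greedy NMS over axis-aligned 2D boxes given as `(x0, y0, x1, y1)` tuples with per-box scores. This is a simplification, because efm3d operates on `ObbTW` oriented boxes. The helper names `iou_2d` and `greedy_nms` are hypothetical.

```python
from typing import List, Sequence, Tuple

Box = Tuple[float, float, float, float]  # (x0, y0, x1, y1), axis-aligned (assumed format)


def iou_2d(a: Box, b: Box) -> float:
    """Intersection-over-union of two axis-aligned boxes."""
    ix0, iy0 = max(a[0], b[0]), max(a[1], b[1])
    ix1, iy1 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix1 - ix0) * max(0.0, iy1 - iy0)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(boxes: Sequence[Box], scores: Sequence[float], iou_thr: float) -> List[int]:
    """Return indices of kept boxes, highest score first."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep: List[int] = []
    for i in order:
        # Keep box i only if it does not overlap any already-kept box too much.
        if all(iou_2d(boxes[i], boxes[k]) <= iou_thr for k in keep):
            keep.append(i)
    return keep


boxes = [(0, 0, 10, 10), (1, 1, 10, 10), (20, 20, 30, 30)]
print(greedy_nms(boxes, [0.9, 0.8, 0.7], iou_thr=0.5))  # [0, 2]
```

Boxes are visited in descending score order. A box survives only if its IoU with every box already kept is at most the threshold, so the second box (IoU 0.81 with the first) is suppressed.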

FILE: efm3d/utils/obb_utils.py
  class IouOutputs (line 44) | class IouOutputs:
  function input_validator_box3d (line 49) | def input_validator_box3d(  # noqa
  class MAPMetricResults3D (line 103) | class MAPMetricResults3D(BaseMetricResults):
  function box3d_volume (line 116) | def box3d_volume(boxes: Tensor) -> Tensor:
  function box3d_convert (line 136) | def box3d_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor:
  class MeanAveragePrecision3D (line 145) | class MeanAveragePrecision3D(MeanAveragePrecision):
    method __init__ (line 146) | def __init__(
    method update (line 208) | def update(
    method _compute_iou (line 290) | def _compute_iou(self, id: int, class_id: int, max_det: int) -> Tensor:
    method _evaluate_image (line 330) | def _evaluate_image(
    method _summarize_results (line 443) | def _summarize_results(
    method compute (line 497) | def compute(self, sem_id_to_name_mapping: Optional[Dict[int, str]] = N...
  function coplanar_mask (line 534) | def coplanar_mask(boxes: torch.Tensor, eps: float = 1e-4) -> None:
  function nonzero_area_mask (line 555) | def nonzero_area_mask(boxes: torch.Tensor, eps: float = 1e-4) -> None:
  function bb3_valid (line 571) | def bb3_valid(boxes: torch.Tensor, eps: float = 1e-4) -> None:
  function box3d_overlap_wrapper (line 579) | def box3d_overlap_wrapper(
  function remove_invalid_box3d (line 607) | def remove_invalid_box3d(obbs: ObbTW, mark_in_place: bool = False) -> to...
  function prec_recall_bb3 (line 627) | def prec_recall_bb3(
  function prec_recall_curve (line 752) | def prec_recall_curve(
  function draw_prec_recall_curve (line 810) | def draw_prec_recall_curve(

FILE: efm3d/utils/pointcloud.py
  function get_points_world (line 33) | def get_points_world(batch, batch_idx=None, dist_std0=0.04, prefer_point...
  function get_freespace_world (line 73) | def get_freespace_world(
  function collapse_pointcloud_time (line 172) | def collapse_pointcloud_time(pc_w):
  function pointcloud_to_voxel_ids (line 183) | def pointcloud_to_voxel_ids(pc_v, vW, vH, vD, voxel_extent):
  function pointcloud_to_occupancy_snippet (line 216) | def pointcloud_to_occupancy_snippet(
  function pointcloud_occupancy_samples (line 292) | def pointcloud_occupancy_samples(
  function pointcloud_to_occupancy (line 357) | def pointcloud_to_occupancy(
  function pointcloud_to_voxel_counts (line 409) | def pointcloud_to_voxel_counts(points_v, voxel_extent, vW, vH, vD):
  function get_points_counts (line 449) | def get_points_counts(
  function get_freespace_counts (line 497) | def get_freespace_counts(
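
`pointcloud_to_voxel_ids(pc_v, vW, vH, vD, voxel_extent)` maps points in the voxel frame to voxel indices. The following is a minimal per-point sketch of that mapping. It makes three assumptions, none of which is confirmed from the source: `voxel_extent = [x_min, x_max, y_min, y_max, z_min, z_max]`, `vW`, `vH` and `vD` count voxels along x, y and z, and the flattening is z-major. `point_to_voxel_id` is a hypothetical name.

```python
import math
from typing import Optional, Sequence


def point_to_voxel_id(
    p: Sequence[float], vW: int, vH: int, vD: int, voxel_extent: Sequence[float]
) -> Optional[int]:
    """Flattened voxel index of point p, or None if p lies outside the grid."""
    # Assumed layout: [x_min, x_max, y_min, y_max, z_min, z_max].
    x_min, x_max, y_min, y_max, z_min, z_max = voxel_extent
    # Integer voxel coordinates along each axis (floor keeps negatives negative).
    ix = math.floor((p[0] - x_min) / (x_max - x_min) * vW)
    iy = math.floor((p[1] - y_min) / (y_max - y_min) * vH)
    iz = math.floor((p[2] - z_min) / (z_max - z_min) * vD)
    if not (0 <= ix < vW and 0 <= iy < vH and 0 <= iz < vD):
        return None
    # Assumed z-major flattening: id = (iz * vH + iy) * vW + ix
    return (iz * vH + iy) * vW + ix


extent = [0.0, 4.0, 0.0, 4.0, 0.0, 4.0]  # a 4-unit cube split into unit voxels below
print(point_to_voxel_id([0.5, 1.5, 0.5], 4, 4, 4, extent))  # 4
print(point_to_voxel_id([4.0, 0.5, 0.5], 4, 4, 4, extent))  # None (on the max face)
```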

FILE: efm3d/utils/ray.py
  function grid_ray (line 24) | def grid_ray(pixel_grid, camera):
  function ray_grid (line 72) | def ray_grid(cam: CameraTW):
  function transform_rays (line 87) | def transform_rays(rays_old: torch.Tensor, T_new_old):
  function ray_obb_intersection (line 99) | def ray_obb_intersection(
  function sample_depths_in_grid (line 178) | def sample_depths_in_grid(
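
`ray_obb_intersection` intersects rays with oriented boxes. One standard approach, shown here as a sketch and not as the efm3d implementation, has two steps. First, express the ray in the box frame: rigidly transform the origin and only rotate the direction. Then run the slab test against the box's now axis-aligned extent. The box pose is assumed to be a row-major 3x3 rotation `R_world_box` plus a translation `t_world_box`. All names are hypothetical.

```python
import math
from typing import List, Optional, Sequence, Tuple

Vec3 = Sequence[float]


def world_ray_to_box_frame(
    origin: Vec3, direction: Vec3, R_world_box: Sequence[Vec3], t_world_box: Vec3
) -> Tuple[List[float], List[float]]:
    """Express a world-frame ray in the box frame: p_box = R^T (p_world - t)."""
    diff = [origin[i] - t_world_box[i] for i in range(3)]
    o_box = [sum(R_world_box[r][c] * diff[r] for r in range(3)) for c in range(3)]
    # Directions are only rotated, never translated.
    d_box = [sum(R_world_box[r][c] * direction[r] for r in range(3)) for c in range(3)]
    return o_box, d_box


def ray_aabb_slab(
    origin: Vec3, direction: Vec3, box_min: Vec3, box_max: Vec3
) -> Optional[Tuple[float, float]]:
    """Slab test: return (t_near, t_far) along the ray, or None on a miss."""
    t_near, t_far = -math.inf, math.inf
    for o, d, lo, hi in zip(origin, direction, box_min, box_max):
        if abs(d) < 1e-12:
            # Ray parallel to this slab: it must already lie between the two planes.
            if o < lo or o > hi:
                return None
            continue
        t0, t1 = (lo - o) / d, (hi - o) / d
        if t0 > t1:
            t0, t1 = t1, t0
        t_near, t_far = max(t_near, t0), min(t_far, t1)
        if t_near > t_far:
            return None
    if t_far < 0:  # the box lies entirely behind the ray origin
        return None
    return t_near, t_far


identity = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
o, d = world_ray_to_box_frame([5, 0, 0], [1, 0, 0], identity, [10, 0, 0])
print(ray_aabb_slab(o, d, [-1, -1, -1], [1, 1, 1]))  # (4.0, 6.0)
```

A negative `t_near` with a non-negative `t_far` means the ray origin is inside the box.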

FILE: efm3d/utils/reconstruction.py
  function build_gt_occupancy (line 36) | def build_gt_occupancy(occ, visible, p3s_w, Ts_wc, cams, T_wv, voxel_ext...
  function get_fused_gt_feat (line 62) | def get_fused_gt_feat(
  function get_feats_world (line 111) | def get_feats_world(batch, tgt_feats):
  function compute_tv_loss (line 144) | def compute_tv_loss(occ):
  function compute_occupancy_loss_subvoxel (line 153) | def compute_occupancy_loss_subvoxel(
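
`compute_tv_loss(occ)` presumably regularizes an occupancy volume with a total-variation penalty. The sketch below implements an anisotropic L1 variant on a `D x H x W` grid. It averages absolute differences over all pairs of face-adjacent voxels. Whether efm3d uses L1 or squared differences, and how it averages, is not shown in this index. `tv_loss_3d` is a hypothetical name.

```python
from typing import Sequence

Grid3 = Sequence[Sequence[Sequence[float]]]


def tv_loss_3d(occ: Grid3) -> float:
    """Mean |difference| over all pairs of face-adjacent voxels in a D x H x W grid."""
    D, H, W = len(occ), len(occ[0]), len(occ[0][0])
    total, pairs = 0.0, 0
    for d in range(D):
        for h in range(H):
            for w in range(W):
                v = occ[d][h][w]
                # Forward differences along depth, height and width.
                if d + 1 < D:
                    total += abs(occ[d + 1][h][w] - v)
                    pairs += 1
                if h + 1 < H:
                    total += abs(occ[d][h + 1][w] - v)
                    pairs += 1
                if w + 1 < W:
                    total += abs(occ[d][h][w + 1] - v)
                    pairs += 1
    return total / pairs if pairs else 0.0


# An empty slab next to a full slab: 4 of the 12 neighbour pairs differ by 1.
step = [[[0.0, 0.0], [0.0, 0.0]], [[1.0, 1.0], [1.0, 1.0]]]
print(tv_loss_3d(step))
```

A constant grid scores 0, and sharp occupied/free boundaries are penalized in proportion to their area.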

FILE: efm3d/utils/render.py
  function get_colors (line 32) | def get_colors(num_colors: int, scale_to_255: bool = False):
  function get_colors_from_sem_map (line 80) | def get_colors_from_sem_map(
  function draw_bb2s (line 116) | def draw_bb2s(
  function draw_bb3_lines (line 194) | def draw_bb3_lines(
  function draw_bb3s (line 248) | def draw_bb3s(
  function draw_obbs_image (line 355) | def draw_obbs_image(
  function draw_obbs_snippet (line 455) | def draw_obbs_snippet(
  function discretize_values (line 521) | def discretize_values(values: torch.Tensor, precision: int):

FILE: efm3d/utils/rescale.py
  function get_crops_scale (line 24) | def get_crops_scale(
  function rescale_camera_tw (line 91) | def rescale_camera_tw(
  function rescale_calib (line 135) | def rescale_calib(
  function rescale_image (line 170) | def rescale_image(
  function rescale_image_tensor (line 199) | def rescale_image_tensor(
  function rescale_depth_img (line 245) | def rescale_depth_img(
  function rescale_obb_tw (line 279) | def rescale_obb_tw(
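
When an image is resized, the camera intrinsics must be rescaled with it (compare `rescale_camera_tw` and `rescale_calib`). The sketch below handles the pinhole parameters under one common convention, in which integer pixel coordinates refer to pixel centers so the continuous image edge sits at -0.5. It ignores cropping and distortion parameters. It is an assumption that efm3d uses this half-pixel convention, and `rescale_intrinsics` is a hypothetical name.

```python
from typing import Tuple


def rescale_intrinsics(
    fx: float, fy: float, cx: float, cy: float, sx: float, sy: float
) -> Tuple[float, float, float, float]:
    """Scale pinhole intrinsics for an image resized by (sx, sy).

    Assumes integer pixel coordinates are pixel centers (image edge at -0.5),
    so the principal point is shifted by half a pixel before and after scaling.
    """
    return (
        fx * sx,
        fy * sy,
        (cx + 0.5) * sx - 0.5,
        (cy + 0.5) * sy - 0.5,
    )


# Halving a 640x480 image whose principal point is the exact image center.
print(rescale_intrinsics(600.0, 600.0, 319.5, 239.5, 0.5, 0.5))
# (300.0, 300.0, 159.5, 119.5)
```

Scaling `cx` directly (319.5 * 0.5 = 159.75) would shift the principal point by a quarter pixel. The half-pixel offset keeps the image center mapped to the image center.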

FILE: efm3d/utils/viz.py
  function render_points (line 57) | def render_points(pts, rgba, prog=None, ctx=None, point_size=1.0, scene=...
  function render_cubes (line 74) | def render_cubes(centers, bb3_halfdiag, prog, ctx, rgb=None):
  function render_tri_mesh (line 110) | def render_tri_mesh(pts, normals, tris, prog, ctx):
  function render_rgb_tri_mesh (line 134) | def render_rgb_tri_mesh(pts, normals, tris, rgb, prog, ctx):
  function render_scalar_field_points (line 164) | def render_scalar_field_points(
  function render_rgb_points (line 209) | def render_rgb_points(
  function render_linestrip (line 233) | def render_linestrip(pts, rgba, prog=None, ctx=None, scene=None):
  function render_line (line 251) | def render_line(p0, p1, rgba, prog=None, ctx=None, scene=None):
  function render_cosy (line 266) | def render_cosy(
  function render_frustum (line 282) | def render_frustum(
  function render_obbs_line (line 328) | def render_obbs_line(
  function get_color_from_id (line 363) | def get_color_from_id(sem_id, max_sem_id, rgba=None):
  function render_obb_line (line 369) | def render_obb_line(obb: ObbTW, prog, ctx, rgba=None, draw_cosy=False):
  class SceneView (line 390) | class SceneView:
    method __init__ (line 427) | def __init__(
    method valid (line 480) | def valid(self):
    method clear (line 483) | def clear(self, bg_color: Optional[Tuple[float, float, float]] = None):
    method set_default_view (line 511) | def set_default_view(self, T_world_camera: PoseTW, zoom_factor: float ...
    method set_follow_view (line 518) | def set_follow_view(self, T_world_camera: PoseTW, zoom_factor: float =...
    method set_birds_eye_view (line 525) | def set_birds_eye_view(self, T_world_camera: PoseTW, zoom_factor: floa...
    method set_side_view (line 534) | def set_side_view(self, T_world_camera: PoseTW, zoom_factor: float = 6...
    method set_birds_eye_view_from_bb (line 543) | def set_birds_eye_view_from_bb(
    method set_view (line 565) | def set_view(self, mv: Union[PoseTW, np.array]):
    method finish (line 583) | def finish(self):
  function draw_obb_scene_3d (line 595) | def draw_obb_scene_3d(
  function draw_snippet_scene_3d (line 876) | def draw_snippet_scene_3d(
  function normalize (line 1064) | def normalize(x):
  function model_view_look_at_rdf (line 1072) | def model_view_look_at_rdf(e, look_at, u):
  function get_mv (line 1100) | def get_mv(T_world_cam: PoseTW, zoom_factor: float = 3.0, position=[-1, ...
  function projection_matrix_rdf_top_left (line 1133) | def projection_matrix_rdf_top_left(w, h, fu, fv, u0, v0, zNear, zFar):
  function init_egl_context (line 1151) | def init_egl_context():
  function simple_shader_program (line 1163) | def simple_shader_program(ctx):
  function mesh_normal_shader_program (line 1187) | def mesh_normal_shader_program(ctx):
  function mesh_rgb_shader_program (line 1216) | def mesh_rgb_shader_program(ctx):
  function rgb_point_cloud_shader_program (line 1249) | def rgb_point_cloud_shader_program(ctx):
  function scalar_field_shader_program (line 1278) | def scalar_field_shader_program(ctx):
  function semantic_color_shader_program (line 1371) | def semantic_color_shader_program(ctx):

FILE: efm3d/utils/voxel.py
  function tensor_wrap_voxel_extent (line 18) | def tensor_wrap_voxel_extent(voxel_extent, B=None, device="cpu"):
  function create_voxel_grid (line 32) | def create_voxel_grid(vW, vH, vD, voxel_extent, device="cpu"):
  function erode_voxel_mask (line 57) | def erode_voxel_mask(mask):

FILE: efm3d/utils/voxel_sampling.py
  function pc_to_vox (line 19) | def pc_to_vox(pc_v, vW, vH, vD, voxel_extent):
  function compute_factor (line 57) | def compute_factor(size):
  function convert_coordinates_to_voxel (line 61) | def convert_coordinates_to_voxel(coordinates, factor):
  function convert_voxel_to_coordinates (line 65) | def convert_voxel_to_coordinates(coordinates, factor):
  function normalize_keypoints (line 69) | def normalize_keypoints(kpts, depth, height, width):
  function denormalize_keypoints (line 81) | def denormalize_keypoints(kpts, depth, height, width):
  function in_grid (line 99) | def in_grid(pt_vox, depth, height, width):
  function sample_voxels (line 109) | def sample_voxels(feat3d, pts_v, differentiable=False, interp_mode="bili...
  function diff_grid_sample (line 147) | def diff_grid_sample(feature_3d, pts_norm, align_corners=False):
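
`normalize_keypoints` / `denormalize_keypoints` and `diff_grid_sample` relate to the coordinate convention of `torch.nn.functional.grid_sample`, which expects sample locations in `[-1, 1]`. The per-axis mapping depends on `align_corners`. The functions below reproduce PyTorch's documented convention in plain Python for a single axis. The names are hypothetical, and which `align_corners` setting efm3d uses is not shown here. Note also that `grid_sample` orders the last grid dimension as (x, y, z), i.e. width first.

```python
def normalize_coord(x: float, size: int, align_corners: bool) -> float:
    """Map a pixel/voxel coordinate along an axis of length `size` to [-1, 1]."""
    if align_corners:
        # -1 and 1 are the centers of the first and last cells.
        return 2.0 * x / (size - 1) - 1.0
    # -1 and 1 are the outer edges of the first and last cells.
    return (2.0 * x + 1.0) / size - 1.0


def denormalize_coord(u: float, size: int, align_corners: bool) -> float:
    """Inverse of normalize_coord."""
    if align_corners:
        return (u + 1.0) / 2.0 * (size - 1)
    return ((u + 1.0) * size - 1.0) / 2.0


print(normalize_coord(4, 5, align_corners=True))     # 1.0  (center of last cell)
print(normalize_coord(4.5, 5, align_corners=False))  # 1.0  (far edge of last cell)
```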

FILE: train.py
  function get_lr (line 60) | def get_lr(it, warmup_its, max_its, max_lr, min_lr):
  function get_dataloader (line 80) | def get_dataloader(
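
`get_lr(it, warmup_its, max_its, max_lr, min_lr)` in `train.py` takes the arguments of a warmup-then-decay learning-rate schedule. The sketch below assumes linear warmup followed by cosine decay to `min_lr`. This is a common choice, not one confirmed from the source. It requires `max_its > warmup_its`.

```python
import math


def get_lr(it: int, warmup_its: int, max_its: int, max_lr: float, min_lr: float) -> float:
    """Assumed schedule: linear warmup, then cosine decay, then a constant floor."""
    # 1) Linear warmup up to max_lr over the first warmup_its iterations.
    if it < warmup_its:
        return max_lr * (it + 1) / warmup_its
    # 2) Past max_its, hold the learning rate at its floor.
    if it > max_its:
        return min_lr
    # 3) In between, cosine-decay from max_lr down to min_lr.
    decay_ratio = (it - warmup_its) / (max_its - warmup_its)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # goes from 1 to 0
    return min_lr + coeff * (max_lr - min_lr)


print(get_lr(9, 10, 100, 1.0, 0.1))    # 1.0  (end of warmup)
print(get_lr(100, 10, 100, 1.0, 0.1))  # 0.1  (end of decay)
print(get_lr(150, 10, 100, 1.0, 0.1))  # 0.1  (held at the floor)
```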

Condensed preview: 92 files, each showing its path, character count, and a content snippet (full structured content: 905K chars).
[
  {
    "path": ".github/CODE_OF_CONDUCT.md",
    "chars": 3543,
    "preview": "# Code of Conduct\n\n## Our Pledge\n\nIn the interest of fostering an open and welcoming environment, we as\ncontributors and"
  },
  {
    "path": ".github/CONTRIBUTING.md",
    "chars": 1236,
    "preview": "# Contributing to \"efm3d\"\n\nWe want to make contributing to this project as easy and transparent as\npossible.\n\n## Pull Re"
  },
  {
    "path": ".github/workflows/conda-env.yaml",
    "chars": 622,
    "preview": "name: Conda Environment CI\n\non:\n  push:\n    branches:\n      - main\n  pull_request:\n    branches:\n      - main\n\njobs:\n  t"
  },
  {
    "path": ".gitignore",
    "chars": 459,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
  },
  {
    "path": "INSTALL.md",
    "chars": 1647,
    "preview": "# Installation\n\nWe provide two ways to install the dependencies of EFM3D. We recommend using miniconda to manage the dep"
  },
  {
    "path": "LICENSE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "README.md",
    "chars": 7598,
    "preview": "# EFM3D: A Benchmark for Measuring Progress Towards 3D Egocentric Foundation Models\n\n[[paper](https://arxiv.org/abs/2406"
  },
  {
    "path": "benchmark.md",
    "chars": 2086,
    "preview": "## EFM3D Benchmark\n\nWe provide three evaluation datasets for the EFM3D benchmarks. For more details on the benchmark see"
  },
  {
    "path": "efm3d/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "efm3d/aria/__init__.py",
    "chars": 782,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/aria/aria_constants.py",
    "chars": 12475,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/aria/camera.py",
    "chars": 23708,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/aria/obb.py",
    "chars": 54908,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/aria/pose.py",
    "chars": 36008,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/aria/projection_utils.py",
    "chars": 13531,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/aria/tensor_wrapper.py",
    "chars": 15029,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/config/efm_preprocessing_conf.yaml",
    "chars": 1719,
    "preview": "atek_config_name: \"efm\"\ncamera_temporal_subsampler:\n  main_camera_label: \"camera-rgb\"\n  time_domain: \"DEVICE_TIME\"\n  mai"
  },
  {
    "path": "efm3d/config/evl_inf.yaml",
    "chars": 879,
    "preview": "_target_: efm3d.model.evl.EVL\nneck_hidden_dims: [128, 256, 512]\nhead_hidden_dim: 256\nhead_layers: 2\ntaxonomy_file: efm3d"
  },
  {
    "path": "efm3d/config/evl_inf_desktop.yaml",
    "chars": 877,
    "preview": "_target_: efm3d.model.evl.EVL\nneck_hidden_dims: [32, 64, 128]\nhead_hidden_dim: 256\nhead_layers: 2\ntaxonomy_file: efm3d/c"
  },
  {
    "path": "efm3d/config/evl_train.yaml",
    "chars": 890,
    "preview": "_target_: efm3d.model.evl_train.EVLTrain\nneck_hidden_dims: [128, 256, 512]\nhead_hidden_dim: 256\nhead_layers: 2\ntaxonomy_"
  },
  {
    "path": "efm3d/config/taxonomy/aeo_to_efm.csv",
    "chars": 195,
    "preview": "AEO Category Name,EFM Category Name,EFM Category Id\nChair,chair,3\nCouch,sofa,1\nTable,table,0\nBed,bed,4\nWallArt,picture_f"
  },
  {
    "path": "efm3d/config/taxonomy/ase_sem_name_to_id.csv",
    "chars": 366,
    "preview": "sem_name,sem_id\r\ntable,0\r\nsofa,1\r\nshelf,2\r\nchair,3\r\nbed,4\r\nfloor_mat,5\r\nexercise_weight,6\r\ncutlery,7\r\ncontainer,8\r\nclock"
  },
  {
    "path": "efm3d/config/taxonomy/atek_to_efm.csv",
    "chars": 661,
    "preview": "ATEK Category Name,EFM version ASE Category Name,EFM version ASE Category Id\r\ntable,table,0\r\nsofa,sofa,1\r\nshelves,shelf,"
  },
  {
    "path": "efm3d/dataset/atek_vrs_dataset.py",
    "chars": 6073,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/dataset/atek_wds_dataset.py",
    "chars": 5107,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/dataset/augmentation.py",
    "chars": 7615,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/dataset/efm_model_adaptor.py",
    "chars": 30352,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/dataset/vrs_dataset.py",
    "chars": 29955,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/dataset/wds_dataset.py",
    "chars": 8119,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/inference/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "efm3d/inference/eval.py",
    "chars": 9198,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/inference/fuse.py",
    "chars": 10633,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/inference/model.py",
    "chars": 6796,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/inference/pipeline.py",
    "chars": 8603,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/inference/track.py",
    "chars": 3873,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/inference/viz.py",
    "chars": 11081,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/model/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "efm3d/model/cnn.py",
    "chars": 19014,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/model/dinov2_utils.py",
    "chars": 45677,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/model/dpt.py",
    "chars": 8334,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/model/evl.py",
    "chars": 8334,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/model/evl_train.py",
    "chars": 25973,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/model/image_tokenizer.py",
    "chars": 5966,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/model/lifter.py",
    "chars": 21035,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/model/video_backbone.py",
    "chars": 6695,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/thirdparty/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "efm3d/thirdparty/mmdetection3d/LICENSE",
    "chars": 11400,
    "preview": "Copyright 2018-2019 Open-MMLab. All rights reserved.\n\n                                 Apache License\n                  "
  },
  {
    "path": "efm3d/thirdparty/mmdetection3d/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "efm3d/thirdparty/mmdetection3d/cuda/cuda_utils.h",
    "chars": 1440,
    "preview": "// @lint-ignore-every LICENSELINT\n\n#ifndef _CUDA_UTILS_H\n#define _CUDA_UTILS_H\n\n#include <ATen/ATen.h>\n#include <ATen/cu"
  },
  {
    "path": "efm3d/thirdparty/mmdetection3d/cuda/iou3d.cpp",
    "chars": 6322,
    "preview": "// @lint-ignore-every LICENSELINT\n\n// Modified from\n// https://github.com/open-mmlab/OpenPCDet/blob/master/pcdet/ops/iou"
  },
  {
    "path": "efm3d/thirdparty/mmdetection3d/cuda/iou3d.h",
    "chars": 730,
    "preview": "// @lint-ignore-every LICENSELINT\n\n// Modified from\n// https://github.com/open-mmlab/OpenPCDet/blob/master/pcdet/ops/iou"
  },
  {
    "path": "efm3d/thirdparty/mmdetection3d/cuda/iou3d_kernel.cu",
    "chars": 14943,
    "preview": "// @lint-ignore-every LICENSELINT\n\n// Modified from\n// https://github.com/open-mmlab/OpenPCDet/blob/master/pcdet/ops/iou"
  },
  {
    "path": "efm3d/thirdparty/mmdetection3d/cuda/setup.py",
    "chars": 565,
    "preview": "# @lint-ignore-every LICENSELINT\n\nfrom setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUD"
  },
  {
    "path": "efm3d/thirdparty/mmdetection3d/cuda/sort_vert.cpp",
    "chars": 1721,
    "preview": "// @lint-ignore-every LICENSELINT\n\n// Modified from\n// https://github.com/lilanxiao/Rotated_IoU/blob/master/cuda_op/sort"
  },
  {
    "path": "efm3d/thirdparty/mmdetection3d/cuda/sort_vert.h",
    "chars": 388,
    "preview": "// @lint-ignore-every LICENSELINT\n\n// Modified from\n// https://github.com/lilanxiao/Rotated_IoU/blob/master/cuda_op/sort"
  },
  {
    "path": "efm3d/thirdparty/mmdetection3d/cuda/sort_vert_kernel.cu",
    "chars": 4639,
    "preview": "// @lint-ignore-every LICENSELINT\n\n// Modified from\n// https://github.com/lilanxiao/Rotated_IoU/blob/master/cuda_op/sort"
  },
  {
    "path": "efm3d/thirdparty/mmdetection3d/cuda/utils.h",
    "chars": 1713,
    "preview": "// @lint-ignore-every LICENSELINT\n\n// Modified from\n// https://github.com/lilanxiao/Rotated_IoU/blob/master/cuda_op/util"
  },
  {
    "path": "efm3d/thirdparty/mmdetection3d/iou3d.py",
    "chars": 14702,
    "preview": "# Copyright (c) OpenMMLab. All rights reserved.\n# Adapted from https://github.com/lilanxiao/Rotated_IoU/blob/master/box_"
  },
  {
    "path": "efm3d/utils/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "efm3d/utils/common.py",
    "chars": 1069,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/utils/depth.py",
    "chars": 1614,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/utils/detection_utils.py",
    "chars": 17858,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/utils/evl_loss.py",
    "chars": 8146,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/utils/file_utils.py",
    "chars": 39640,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/utils/gravity.py",
    "chars": 4652,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/utils/image.py",
    "chars": 8680,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/utils/image_sampling.py",
    "chars": 4950,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/utils/marching_cubes.py",
    "chars": 3176,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/utils/mesh_utils.py",
    "chars": 8244,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/utils/obb_csv_writer.py",
    "chars": 7016,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/utils/obb_io.py",
    "chars": 9983,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/utils/obb_matchers.py",
    "chars": 9601,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/utils/obb_metrics.py",
    "chars": 9167,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/utils/obb_trackers.py",
    "chars": 24285,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/utils/obb_utils.py",
    "chars": 32107,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/utils/pointcloud.py",
    "chars": 19428,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/utils/ray.py",
    "chars": 9011,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/utils/reconstruction.py",
    "chars": 10464,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/utils/render.py",
    "chars": 18804,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/utils/rescale.py",
    "chars": 10933,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/utils/viz.py",
    "chars": 47196,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/utils/voxel.py",
    "chars": 2889,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "efm3d/utils/voxel_sampling.py",
    "chars": 10507,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "environment-mac.yml",
    "chars": 516,
    "preview": "name: efm3d\nchannels:\n  - defaults\n  - conda-forge\n  - pytorch\ndependencies:\n  - python=3.9\n  - pytorch=2.3.0\n  - torchv"
  },
  {
    "path": "environment.yml",
    "chars": 818,
    "preview": "name: efm3d\nchannels:\n  - nvidia/label/cuda-12.1.1\n  - pytorch\n  - nvidia\n  - conda-forge\n  - defaults\ndependencies:\n  -"
  },
  {
    "path": "eval.py",
    "chars": 5176,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "infer.py",
    "chars": 2890,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  },
  {
    "path": "prepare_inference.sh",
    "chars": 1339,
    "preview": "#!/usr/bin/env bash\n# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version "
  },
  {
    "path": "requirements-extra.txt",
    "chars": 119,
    "preview": "projectaria-atek\ngit+https://github.com/facebookresearch/pytorch3d.git@V0.7.8\ntensorboard==2.14.0\ntorchmetrics==0.10.1\n"
  },
  {
    "path": "requirements.txt",
    "chars": 303,
    "preview": "torch==2.3.0\ntorchvision==0.18.0\nomegaconf==2.3.0\nhydra-core==1.3.2\nwebdataset==0.2.86\nvrs==1.2.1\nfsspec==2024.6.0\neinop"
  },
  {
    "path": "sbatch_run.sh",
    "chars": 1171,
    "preview": "#!/bin/bash\n# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the"
  },
  {
    "path": "train.py",
    "chars": 9908,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");"
  }
]

About this extraction

This page contains the full source code of the facebookresearch/efm3d GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 92 files (848.9 KB), approximately 239.5k tokens, and a symbol index with 785 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.