[
  {
    "path": ".gitignore",
    "content": "# debugging files\ndebug/\nSMPLX_NEUTRAL.npz\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\ndatasets\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# vscode\n.vscode\n*.code-workspace\n/pyrightconfig.json\nwandb/\n\n# others\nout\ntmr_out\n.ruff_cache\noutputs\n/debug\n/batch*.sh\ncheckpoints/**/test/*\nnohup.out\n\n*.swp\n*.swo\n*.txt~*\n*.un~\n*~\ntrain_done\n.aider*\nonelogger.err\n\n# deploy files\n/helm-library\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "repos:\n  # code formatting\n  - repo: https://github.com/astral-sh/ruff-pre-commit\n    rev: v0.6.4\n    hooks:\n      - id: ruff\n        name: sort imports with ruff\n        args: [--select, I, --fix]\n      - id: ruff-format\n        name: format with ruff\n\n  # docstring formatting\n  - repo: https://github.com/PyCQA/docformatter\n    rev: v1.7.7\n    hooks:\n      - id: docformatter\n        args:\n          [\n            --in-place,\n            --wrap-summaries=100,\n            --wrap-descriptions=100,\n            --style=sphinx,\n          ]\n\n  # yaml formatting\n  - repo: https://github.com/pre-commit/mirrors-prettier\n    rev: v3.0.0-alpha.6\n    hooks:\n      - id: prettier\n        types: [yaml]\n        exclude: |\n          (?x)^(\n            environment\\.yaml$ |\n            \\.gitlab-ci\\.yml$ |\n            \\.k8s/.*\\.(ya?ml)$\n          )\n\n  - repo: https://github.com/pre-commit/pre-commit-hooks\n    rev: v4.0.1\n    hooks:\n      - id: trailing-whitespace # Trims trailing whitespace.\n      - id: end-of-file-fixer # Makes sure files end in a newline and only a newline.\n      - id: check-yaml # Attempts to load all yaml files to verify syntax.\n        exclude: |\n          (?x)^(\n            \\.gitlab-ci\\.yml$ |\n            \\.k8s/.*\\.(ya?ml)$\n          )\n\nexclude: \"checkpoints/.*\"\n"
  },
  {
    "path": "ATTRIBUTIONS.MD",
    "content": "LLM2Vec MIT License https://github.com/McGill-NLP/llm2vec Copyright (c) 2024 McGill NLP\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n\nUnitree mujoco BSD 3-Clause License https://github.com/unitreerobotics/unitree_mujoco/blob/main/LICENSE\nCopyright (c) 2016-2024 HangZhou YuShu TECHNOLOGY CO.,LTD. 
(\"Unitree Robotics\")\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are met:\n\n* Redistributions of source code must retain the above copyright notice, this\n  list of conditions and the following disclaimer.\n\n* Redistributions in binary form must reproduce the above copyright notice,\n  this list of conditions and the following disclaimer in the documentation\n  and/or other materials provided with the distribution.\n\n* Neither the name of the copyright holder nor the names of its\n  contributors may be used to endorse or promote products derived from\n  this software without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\nAND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\nIMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\nDISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\nFOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\nDAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\nSERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\nCAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\nOR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\nOF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n"
  },
  {
    "path": "CHANGELOG.md",
    "content": "# Changelog\n\nAll notable changes to this project will be documented in this file.\n\nThe format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).\n\n## [2026-05-03]\n\n### Fixed\n- `benchmark/parse_folder.py` now averages each metric only over the testcases that actually report it. Previously, sparse constraint metrics (`constraint_root2d_acc`, `constraint_root2d_err`, `constraint_root2d_err_p95`, `constraint_fullbody_keyframe`, `constraint_end_effector`) were divided by the total motion count of the (split, category), including testcases of other constraint kinds that did not report the metric. This silently scaled values by `motions_with_metric / total_motions` (e.g. `constraint_root2d_acc` displayed as ~0.57 when the true value was ~0.93). Both the printed table and `summary_rows.json` are affected, including the combined constraints row that merges `constraints_withtext` and `constraints_notext`. Text-following metrics, foot-skate/contact metrics, and TMR metrics are unchanged.\n- Updated Kimodo benchmark results in the documentation with this fix applied.\n\n## [2026-04-24]\n\n### Added\n- Support for `TEXT_ENCODER_DEVICE` environment variable to force LLM2Vec to use the CPU instead of GPU. 
Setting `TEXT_ENCODER_DEVICE=cpu` reduces VRAM usage to <3 GB with a fairly small speed hit.\n- `--save_example_dir` argument to `kimodo_gen` to save outputs to an example directory that can be directly loaded into `kimodo_demo`\n\n### Fixed\n- Bug in post-processing that was incorrectly making the smoothed root the target for the root in full-body constraints\n- Modified how transitions are handled in multi-prompt generation to improve smoothness\n\n### Removed\n- `share_transition` and `percentage_transition_override` options from Python API for multi-prompt generation\n\n## [2026-04-13]\n\n### Added\n- Option `--bvh_standard_tpose` to use standard T-pose for BVH file saved from `generate.py`\n- Option to use standard T-pose for BVH file saved or downloaded from demo\n- Option to input/output BVH files that use standard T-pose with `motion_convert.py`\n- Added BVH file containing the standard Kimodo T-pose to `kimodo/assets/skeletons/somaskel77/somaskel77_standard_tpose.bvh`\n- Updated documentation with these new options\n\n## [2026-04-10]\n\n### Added\n- [Kimodo-SOMA-RP-v1.1](https://huggingface.co/nvidia/Kimodo-SOMA-RP-v1.1) and [Kimodo-SOMA-SEED-v1.1](https://huggingface.co/nvidia/Kimodo-SOMA-SEED-v1.1) models and added support in the codebase. If not specified, the latest version of the models will be used automatically with the demo and CLI.\n- [Kimodo Motion Generation Benchmark](https://huggingface.co/datasets/nvidia/Kimodo-Motion-Gen-Benchmark) for standardized evaluation of motion generation models trained on the BONES-SEED dataset.\n- Scripts to construct the full benchmark, generate motions for test cases, and compute evaluation metrics. 
\n- Documentation explaining the benchmark and how to use the evaluation pipeline.\n- [TMR-SOMA-RP-v1](https://huggingface.co/nvidia/TMR-SOMA-RP-v1) motion-text embedding model to be used for evaluation metrics.\n- Added option to load LLM2Vec text encoder in fp32 precision.\n\n### Fixed\n- Always use batch size 1 with LLM2Vec to avoid unexpected behavior of different embeddings based on batch size.\n- Load LLM2Vec directly onto the GPU, if available.\n- Updated documentation on constraints with more details.\n\n## [2026-04-01]\n\n### Fixed\n- Fix unnecessary text encoder reload when switching between models in the interactive demo (if not using the text encoder server API).\n\n## [2026-03-31]\n\n### Added\n- New `kimodo_convert` CLI tool for converting generated motions between formats (NPZ, BVH, MuJoCo CSV, AMASS NPZ).\n- Support for loading and saving BVH, CSV, and NPZ motion files in the interactive demo.\n\n## [2026-03-27]\n\n### Fixed\n- Bug fix for foot contact visualization in the interactive demo.\n- Patch bug with BVH export for SOMA models.\n\n## [2026-03-19]\n\n### Changed\n- **Breaking:** Model inputs/outputs now use the SOMA 77-joint skeleton (`somaskel77`). 
This affects saved motion formats and constraint files from previous versions.\n\n### Added\n- Released timeline annotations for the BONES-SEED dataset on HuggingFace.\n\n## [2026-03-16] - Initial Release\n\n### Added\n- Open-source release of Kimodo codebase under Apache-2.0 license.\n- Five model variants: Kimodo-SOMA-RP-v1, Kimodo-G1-RP-v1, Kimodo-SOMA-SEED-v1, Kimodo-G1-SEED-v1, Kimodo-SMPLX-RP-v1.\n- Command-line interface (`kimodo_gen`) for motion generation with text prompts and kinematic constraints.\n- Interactive web-based motion authoring demo (`kimodo_demo`) with timeline editor, constraint tracks, and 3D visualization.\n- Support for multiple output formats: default NPZ, MuJoCo qpos CSV (G1), AMASS NPZ (SMPL-X).\n- Documentation site with quick start guide, installation instructions, CLI reference, and API docs.\n- Compatibility with downstream tools: ProtoMotions (physics-based policy training) and GMR (motion retargeting).\n"
  },
  {
    "path": "CONTRIBUTING.MD",
    "content": "# How to Contribute\n\n## Code Reviews\n\nAll submissions require review. We use GitHub pull requests for this purpose. Consult\n[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more information on using pull requests.\n\n## Signing Your Work\n\n* We require that all contributors \"sign-off\" on their commits. This certifies that the contribution is your original work, or you have rights to submit it under the same license, or a compatible license.\n\n  * Any contribution which contains commits that are not Signed-Off will not be accepted.\n\n* To sign off on a commit you simply use the `--signoff` (or `-s`) option when committing your changes:\n  ```bash\n  $ git commit -s -m \"Add cool feature.\"\n  ```\n  This will append the following to your commit message:\n  ```\n  Signed-off-by: Your Name <your@email.com>\n  ```\n\n* Full text of the DCO:\n\n  ```\n    Developer Certificate of Origin\n    Version 1.1\n\n    Copyright (C) 2004, 2006 The Linux Foundation and its contributors.\n    1 Letterman Drive\n    Suite D4700\n    San Francisco, CA, 94129\n\n    Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed.\n  ```\n\n  ```\n    Developer's Certificate of Origin 1.1\n\n    By making a contribution to this project, I certify that:\n\n    (a) The contribution was created in whole or in part by me and I have the right to submit it under the open source license indicated in the file; or\n\n    (b) The contribution is based upon previous work that, to the best of my knowledge, is covered under an appropriate open source license and I have the right under that license to submit that work with modifications, whether created in whole or in part by me, under the same open source license (unless I am permitted to submit under a different license), as indicated in the file; or\n\n    (c) The contribution was provided directly to me by some other person who certified (a), 
(b) or (c) and I have not modified it.\n\n    (d) I understand and agree that this project and the contribution are public and that a record of the contribution (including all personal information I submit with it, including my sign-off) is maintained indefinitely and may be redistributed consistent with this project or the open source license(s) involved.\n  ```\n"
  },
  {
    "path": "Dockerfile",
    "content": "FROM nvcr.io/nvidia/pytorch:24.10-py3\n\n# Avoid some interactive prompts + make pip quieter/reproducible-ish\nENV DEBIAN_FRONTEND=noninteractive \\\n    PIP_DISABLE_PIP_VERSION_CHECK=1 \\\n    PYTHONDONTWRITEBYTECODE=1 \\\n    PYTHONUNBUFFERED=1\n\n# Where your code will live inside the container\nWORKDIR /workspace\n\n# System deps\nRUN apt-get update && apt-get install -y --no-install-recommends \\\n      git curl ca-certificates \\\n      cmake build-essential \\\n      gosu \\\n    && rm -rf /var/lib/apt/lists/*\n\n# Some base images ship a broken `/usr/local/bin/cmake` shim (from a partial pip install),\n# which shadows `/usr/bin/cmake` and breaks builds that invoke `cmake` (e.g. MotionCorrection).\n# Prefer the system cmake.\nRUN rm -f /usr/local/bin/cmake || true\n\n# Install from docker_requirements.txt: kimodo editable (-e .),\n# but MotionCorrection non-editable (./MotionCorrection). The -e . line ensures [project.scripts]\n# from pyproject.toml are installed (kimodo_gen, kimodo_demo, kimodo_textencoder).\n# SKIP_MOTION_CORRECTION_IN_SETUP=1 so setup.py does not bundle motion_correction; it is\n# installed separately from ./MotionCorrection in the requirements file (non-editable).\nCOPY docker_requirements.txt /workspace/docker_requirements.txt\nCOPY setup.py /workspace/setup.py\nCOPY pyproject.toml /workspace/pyproject.toml\nCOPY kimodo /workspace/kimodo\nCOPY kimodo-viser /workspace/kimodo-viser\nCOPY MotionCorrection /workspace/MotionCorrection\n\nRUN --mount=type=cache,target=/root/.cache/pip \\\n    python -m pip install --upgrade pip \\\n && SKIP_MOTION_CORRECTION_IN_SETUP=1 python -m pip install -r docker_requirements.txt\n\n# Use the docker-entrypoint script, to allow the docker to run as the actual user instead of root\nCOPY kimodo/scripts/docker-entrypoint.sh /usr/local/bin/docker-entrypoint\nRUN chmod +x /usr/local/bin/docker-entrypoint\n\n# Default command (change to your entrypoint if you have one)\nENTRYPOINT 
[\"docker-entrypoint\"]\nCMD [\"bash\"]\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "MANIFEST.in",
    "content": "include setup.py\nrecursive-include kimodo/assets *\nrecursive-include MotionCorrection/src *.cpp *.h *.inl\nrecursive-include MotionCorrection/python *.py *.dll\ninclude MotionCorrection/CMakeLists.txt\ninclude MotionCorrection/test_example.py\n"
  },
  {
    "path": "MotionCorrection/.gitignore",
    "content": "# Python\r\n__pycache__/\r\n*.py[cod]\r\n*$py.class\r\n*.so\r\n*.egg\r\n*.egg-info/\r\ndist/\r\nbuild/\r\n*.whl\r\n.Python\r\ndevelop-eggs/\r\n.installed.cfg\r\npip-log.txt\r\npip-delete-this-directory.txt\r\n.pytest_cache/\r\n.coverage\r\nhtmlcov/\r\n.tox/\r\n.venv\r\nvenv/\r\nENV/\r\nenv/\r\n\r\n# C/C++\r\n*.o\r\n*.obj\r\n*.exe\r\n*.out\r\n*.app\r\n*.dll\r\n*.dylib\r\n*.lib\r\n*.a\r\n*.la\r\n*.lo\r\n*.slo\r\n*.ko\r\n*.elf\r\n*.ilk\r\n*.map\r\n*.exp\r\n*.gch\r\n*.pch\r\n*.idb\r\n*.pdb\r\n*.mod\r\n*.smod\r\n*.lai\r\n\r\n# CMake\r\nCMakeCache.txt\r\nCMakeFiles/\r\nCMakeScripts/\r\ncmake_install.cmake\r\ninstall_manifest.txt\r\nCTestTestfile.cmake\r\n_deps/\r\ncmake-build-*/\r\nCMakeUserPresets.json\r\n\r\n# IDE\r\n.vscode/\r\n.idea/\r\n*.swp\r\n*.swo\r\n*~\r\n.DS_Store\r\n*.iml\r\n.project\r\n.cproject\r\n.settings/\r\n\r\n# Visual Studio\r\n.vs/\r\n*.user\r\n*.suo\r\n*.userosscache\r\n*.sln.docstates\r\n*.VC.db\r\n*.VC.opendb\r\n\r\n# Build directories\r\nbuild/\r\nBuild/\r\nout/\r\ndist/\r\ntemp/\r\n\r\n# Logs\r\n*.log\r\n"
  },
  {
    "path": "MotionCorrection/CMakeLists.txt",
    "content": "cmake_minimum_required(VERSION 3.15)\nproject(motion_correction)\n\nset(CMAKE_CXX_STANDARD 17)\nset(CMAKE_CXX_STANDARD_REQUIRED ON)\n\n# Find Python\nfind_package(Python3 COMPONENTS Interpreter Development REQUIRED)\n\n# Find or fetch pybind11\nfind_package(pybind11 CONFIG QUIET)\nif(NOT pybind11_FOUND)\n    message(STATUS \"pybind11 not found, fetching from GitHub...\")\n    include(FetchContent)\n    FetchContent_Declare(\n        pybind11\n        GIT_REPOSITORY https://github.com/pybind/pybind11.git\n        GIT_TAG v2.11.1\n    )\n    FetchContent_MakeAvailable(pybind11)\nendif()\n\n# Find or fetch Eigen\nfind_package(Eigen3 3.3 CONFIG QUIET)\nif(NOT Eigen3_FOUND)\n    message(STATUS \"Eigen3 not found, fetching from GitLab...\")\n    include(FetchContent)\n    FetchContent_Declare(\n        Eigen\n        GIT_REPOSITORY https://gitlab.com/libeigen/eigen.git\n        GIT_TAG 3.4.0\n    )\n    set(EIGEN_BUILD_DOC OFF CACHE BOOL \"\" FORCE)\n    set(BUILD_TESTING OFF CACHE BOOL \"\" FORCE)\n    set(EIGEN_BUILD_PKGCONFIG OFF CACHE BOOL \"\" FORCE)\n    FetchContent_MakeAvailable(Eigen)\nendif()\n\n# Source files\nset(MATH_SOURCES\n    src/cpp/Math/Matrix.cpp\n    src/cpp/Math/Quaternion.cpp\n    src/cpp/Math/Transform.cpp\n    src/cpp/Math/Types.cpp\n    src/cpp/Math/Vector.cpp\n)\n\nset(ANIM_SOURCES\n    src/cpp/AnimProcessing/InverseKinematics.cpp\n    src/cpp/AnimProcessing/TrajectoryCorrector.cpp\n    src/cpp/AnimProcessing/Utility.cpp\n)\n\n# Create static library for the core functionality\nadd_library(motion_correction_cpp_base STATIC ${MATH_SOURCES} ${ANIM_SOURCES})\n\n# Enable Position Independent Code (required for linking into shared library)\nset_target_properties(motion_correction_cpp_base PROPERTIES POSITION_INDEPENDENT_CODE ON)\n\ntarget_include_directories(motion_correction_cpp_base PUBLIC\n    ${CMAKE_CURRENT_SOURCE_DIR}/src/cpp\n)\n\nif(TARGET Eigen3::Eigen)\n    target_link_libraries(motion_correction_cpp_base PUBLIC 
Eigen3::Eigen)\nelse()\n    target_link_libraries(motion_correction_cpp_base PUBLIC eigen)\nendif()\n\ntarget_compile_definitions(motion_correction_cpp_base PUBLIC EIGEN_MPL2_ONLY)\n\n# Compiler-specific settings\nif(MSVC)\n    # MSVC-specific flags\n    target_compile_options(motion_correction_cpp_base PRIVATE /W4 /arch:AVX)\nelse()\n    # GCC/Clang flags (also applies to MinGW on Windows)\n    # Enable SSE4.1 and AVX instructions for SIMD operations\n    target_compile_options(motion_correction_cpp_base PRIVATE -Wall -Wextra -msse4.1 -mavx)\nendif()\n\n# Python bindings\npybind11_add_module(_motion_correction src/cpp/BindingsPython.cpp)\n\ntarget_link_libraries(_motion_correction PRIVATE motion_correction_cpp_base)\n\ntarget_include_directories(_motion_correction PRIVATE\n    ${CMAKE_CURRENT_SOURCE_DIR}/src/cpp\n)\n\n\n# Install the Python module\ninstall(TARGETS _motion_correction LIBRARY DESTINATION python/motion_correction)\ninstall(FILES python/motion_correction/__init__.py DESTINATION python/motion_correction)\ninstall(FILES python/motion_correction/motion_postprocess.py DESTINATION python/motion_correction)\n"
  },
  {
    "path": "MotionCorrection/MANIFEST.in",
    "content": "include CMakeLists.txt\ninclude test_example.py\nrecursive-include src *.cpp *.h *.inl\nrecursive-include python *.py *.dll\n"
  },
  {
    "path": "MotionCorrection/README.md",
    "content": "# motion_correction\r\n\r\nStandalone `correct_motion` implementation packaged as a small C++ motion trajectory correction library with Python bindings.\r\n\r\n## Installation Guide\r\n\r\n### Prerequisites\r\n\r\nEnsure you have a C++17-compatible compiler (GCC 7.0+, Clang 5.0+, or MSVC 2017+) and CMake 3.15+. On Windows, install MinGW-w64 or Visual Studio with C++ tools. On Linux, install `build-essential` and `cmake`.\r\n\r\nThis project will download and install additional third-party open source software projects. Review the license terms of these open source projects before use. Specifically, if pybind11 or Eigen3 (3.3 or newer) is not found on the system, the CMake build fetches pybind11 v2.11.1 from GitHub or Eigen 3.4.0 from GitLab, respectively; `pip` also installs the Python dependencies `torch` and `numpy`.\r\n\r\n### Build & Install\r\n\r\n#### Standard Installation\r\n```bash\r\npip install .\r\n```\r\n\r\n#### Development Installation\r\n```bash\r\npip install -e .\r\n```\r\n\r\n### Verify Installation\r\n\r\n```python\r\nimport motion_correction\r\nprint(\"Installation successful!\")\r\n```\r\n\r\nYou can also run `python run_test.py` for a simple test.\r\n"
  },
  {
    "path": "MotionCorrection/python/motion_correction/__init__.py",
    "content": "from ._motion_correction import *\n"
  },
  {
    "path": "MotionCorrection/python/motion_correction/motion_postprocess.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\nimport os\nimport pickle\n\nimport numpy as np\nimport torch\n\nimport motion_correction\n\n\ndef correct_motion(\n    hipTranslations,\n    jointRotations,\n    contacts,\n    hipTranslationsInput,\n    rotationsInput,\n    constraint_masks,\n    contact_threshold,\n    root_margin,\n    working_rig,\n    has_double_ankle_joints=False,\n):\n    joint_names = [x.name for x in working_rig]\n    joint_parents = [\n        joint_names.index(working_rig[i].parent) if working_rig[i].parent in joint_names else -1\n        for i in range(len(working_rig))\n    ]\n    joint_ref_translations = [list(x.t_pose_translation) for x in working_rig]\n    joint_ref_rotations = [list(x.t_pose_rotation) for x in working_rig]\n\n    left_hand_idx = [i for i in range(len(joint_names)) if working_rig[i].retarget_tag == \"LeftHand\"]\n    if len(left_hand_idx) != 1:\n        raise RuntimeError(f\"correct_motion: Expected exactly one joint with LeftHand tag\")\n    left_hand_idx = left_hand_idx[0]\n\n    right_hand_idx = [i for i in range(len(joint_names)) if working_rig[i].retarget_tag == \"RightHand\"]\n    if len(right_hand_idx) != 1:\n        raise RuntimeError(f\"correct_motion: Expected exactly one joint with RightHand tag\")\n    right_hand_idx = right_hand_idx[0]\n\n    left_foot_idx = [i for i in range(len(joint_names)) if working_rig[i].retarget_tag == \"LeftFoot\"]\n    if len(left_foot_idx) != 1:\n        raise RuntimeError(f\"correct_motion: Expected exactly one joint with LeftFoot tag\")\n    left_foot_idx = left_foot_idx[0]\n\n    right_foot_idx = [i for i in range(len(joint_names)) if working_rig[i].retarget_tag == \"RightFoot\"]\n    if len(right_foot_idx) != 1:\n        raise RuntimeError(f\"correct_motion: Expected exactly one joint with RightFoot tag\")\n    right_foot_idx = right_foot_idx[0]\n\n    end_frame = 
hipTranslations.shape[1]\n\n    default_mask = torch.zeros(hipTranslations.shape[1], dtype=torch.float32)\n    root_mask = constraint_masks.get(\"Root\", default_mask)\n    full_body_mask = constraint_masks.get(\"FullBody\", default_mask)\n    left_hand_mask = constraint_masks.get(\"LeftHand\", default_mask)\n    right_hand_mask = constraint_masks.get(\"RightHand\", default_mask)\n    left_foot_mask = constraint_masks.get(\"LeftFoot\", default_mask)\n    right_foot_mask = constraint_masks.get(\"RightFoot\", default_mask)\n\n    batch_size = hipTranslations.shape[0]\n\n    for b in range(batch_size):\n        hipTranslationsCorrected = hipTranslations[b, :end_frame].detach().cpu().flatten().numpy().astype(np.float32)\n        rotationsCorrected = jointRotations[b, :end_frame].detach().cpu().flatten().numpy().astype(np.float32)\n\n        hipTranslationsInput_flat = hipTranslationsInput.detach().cpu().flatten().numpy().astype(np.float32)\n        rotationsInput_flat = rotationsInput.detach().cpu().flatten().numpy().astype(np.float32)\n        ctcs = contacts[b].detach().cpu().flatten().numpy().astype(np.float32)\n\n        motion_correction.correct_motion(\n            hipTranslationsCorrected,\n            rotationsCorrected,\n            hipTranslationsInput_flat,\n            rotationsInput_flat,\n            full_body_mask,\n            left_hand_mask,\n            right_hand_mask,\n            left_foot_mask,\n            right_foot_mask,\n            root_mask,\n            np.array(ctcs, dtype=np.float32),\n            joint_parents,\n            joint_ref_translations,\n            joint_ref_rotations,\n            left_hand_idx,\n            right_hand_idx,\n            left_foot_idx,\n            right_foot_idx,\n            contact_threshold,\n            root_margin,\n            has_double_ankle_joints,\n        )\n\n        hipTranslations[b, :end_frame] = torch.from_numpy(\n            hipTranslationsCorrected.reshape(*hipTranslations[b, 
:end_frame].shape)\n        )\n        jointRotations[b, :end_frame] = torch.from_numpy(\n            rotationsCorrected.reshape(*jointRotations[b, :end_frame].shape)\n        )\n"
  },
  {
    "path": "MotionCorrection/run_test.py",
    "content": "#!/usr/bin/env python3\n\n# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\nimport torch\nfrom motion_correction.motion_postprocess import correct_motion\n\n\nclass Joint:\n    def __init__(self, name, parent, t_pose_translation, t_pose_rotation, retarget_tag=\"\"):\n        self.name = name\n        self.parent = parent\n        self.t_pose_translation = t_pose_translation\n        self.t_pose_rotation = t_pose_rotation\n        self.retarget_tag = retarget_tag\n\n\ndef create_test_rig():\n    return [\n        Joint(\"Hips\", None, [0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0], \"Root\"),\n        Joint(\"Spine\", \"Hips\", [0.0, 0.1, 0.0], [0.0, 0.0, 0.0, 1.0]),\n        Joint(\"LeftUpLeg\", \"Hips\", [-0.1, -0.05, 0.0], [0.0, 0.0, 0.0, 1.0]),\n        Joint(\"LeftLeg\", \"LeftUpLeg\", [0.0, -0.4, 0.0], [0.0, 0.0, 0.0, 1.0]),\n        Joint(\"LeftFoot\", \"LeftLeg\", [0.0, -0.4, 0.0], [0.0, 0.0, 0.0, 1.0], \"LeftFoot\"),\n        Joint(\"RightUpLeg\", \"Hips\", [0.1, -0.05, 0.0], [0.0, 0.0, 0.0, 1.0]),\n        Joint(\"RightLeg\", \"RightUpLeg\", [0.0, -0.4, 0.0], [0.0, 0.0, 0.0, 1.0]),\n        Joint(\"RightFoot\", \"RightLeg\", [0.0, -0.4, 0.0], [0.0, 0.0, 0.0, 1.0], \"RightFoot\"),\n        Joint(\"LeftArm\", \"Spine\", [-0.3, 0.3, 0.0], [0.0, 0.0, 0.0, 1.0]),\n        Joint(\"LeftHand\", \"LeftArm\", [-0.3, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0], \"LeftHand\"),\n        Joint(\"RightArm\", \"Spine\", [0.3, 0.3, 0.0], [0.0, 0.0, 0.0, 1.0]),\n        Joint(\"RightHand\", \"RightArm\", [0.3, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0], \"RightHand\"),\n    ]\n\n\nif __name__ == \"__main__\":\n    # Test data\n    batch_size, num_frames, num_joints = 1, 60, 12\n\n    hipTranslations = torch.randn(batch_size, num_frames, 3)\n    jointRotations = torch.randn(batch_size, num_frames, num_joints, 4)\n    jointRotations = jointRotations / jointRotations.norm(dim=-1, keepdim=True)\n\n    
contacts = torch.rand(batch_size, num_frames, 4)\n    hipTranslationsInput = hipTranslations.clone()\n    rotationsInput = jointRotations.clone()\n\n    constraint_masks = {\n        \"Root\": torch.zeros(num_frames),\n        \"FullBody\": torch.zeros(num_frames),\n        \"LeftHand\": torch.zeros(num_frames),\n        \"RightHand\": torch.zeros(num_frames),\n        \"LeftFoot\": torch.zeros(num_frames),\n        \"RightFoot\": torch.zeros(num_frames),\n    }\n\n    working_rig = create_test_rig()\n\n    # Run correction\n    correct_motion(\n        hipTranslations=hipTranslations,\n        jointRotations=jointRotations,\n        contacts=contacts,\n        hipTranslationsInput=hipTranslationsInput,\n        rotationsInput=rotationsInput,\n        constraint_masks=constraint_masks,\n        contact_threshold=0.5,\n        root_margin=0.01,\n        working_rig=working_rig,\n    )\n\n    print(\"Test completed successfully\")\n"
  },
  {
    "path": "MotionCorrection/setup.py",
    "content": "#!/usr/bin/env python3\n\n# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Setup script for correct_motion standalone package.\"\"\"\n\nimport os\nimport shutil\nimport subprocess\nimport sys\nfrom pathlib import Path\n\nfrom setuptools import Extension, setup\nfrom setuptools.command.build_ext import build_ext\n\n\nclass CMakeExtension(Extension):\n    def __init__(self, name, sourcedir=\"\"):\n        Extension.__init__(self, name, sources=[])\n        self.sourcedir = os.path.abspath(sourcedir)\n\n\nclass CMakeBuild(build_ext):\n    def run(self):\n        try:\n            subprocess.check_output([\"cmake\", \"--version\"])\n        except OSError:\n            raise RuntimeError(\"CMake must be installed to build this package\")\n\n        for ext in self.extensions:\n            self.build_extension(ext)\n\n    def build_extension(self, ext):\n        # import pdb; pdb.set_trace()  # Debug build process\n\n        extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))\n        cmake_args = [\n            f\"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}\",\n            f\"-DPYTHON_EXECUTABLE={sys.executable}\",\n        ]\n\n        cfg = \"Debug\" if self.debug else \"Release\"\n        build_args = [\"--config\", cfg]\n\n        cmake_args += [f\"-DCMAKE_BUILD_TYPE={cfg}\"]\n\n        use_mingw = False\n        mingw_bin = None\n\n        if sys.platform == \"win32\":\n            generator = os.environ.get(\"CMAKE_GENERATOR\", \"\")\n            if generator:\n                cmake_args = [\"-G\", generator] + cmake_args\n                if \"mingw\" in generator.lower():\n                    use_mingw = True\n                else:\n                    cmake_args += [f\"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}\"]\n            else:\n                # Try MinGW Makefiles as default on Windows\n                try:\n   
                 subprocess.check_output([\"g++\", \"--version\"], stderr=subprocess.STDOUT)\n                    use_mingw = True\n                    cmake_args = [\"-G\", \"MinGW Makefiles\"] + cmake_args\n                    build_args = []  # MinGW Makefiles do not accept --config\n                except (OSError, subprocess.CalledProcessError):\n                    # If g++ is not found, let CMake use its default (Visual Studio)\n                    cmake_args += [f\"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}\"]\n\n            if use_mingw:\n                gxx_path = shutil.which(\"g++\")\n                if gxx_path:\n                    mingw_bin = Path(gxx_path).parent\n        else:\n            build_args += [\"--\", \"-j4\"]\n\n        env = os.environ.copy()\n        env[\"CXXFLAGS\"] = f'{env.get(\"CXXFLAGS\", \"\")} -DVERSION_INFO=\\\\\"{self.distribution.get_version()}\\\\\"'\n\n        if not os.path.exists(self.build_temp):\n            os.makedirs(self.build_temp)\n\n        subprocess.check_call([\"cmake\", ext.sourcedir] + cmake_args, cwd=self.build_temp, env=env)\n        subprocess.check_call([\"cmake\", \"--build\", \".\"] + build_args, cwd=self.build_temp)\n\n        if use_mingw and mingw_bin is not None:\n            runtime_libs = [\n                \"libstdc++-6.dll\",\n                \"libgcc_s_seh-1.dll\",\n                \"libwinpthread-1.dll\",\n            ]\n            extdir_path = Path(extdir)\n            extdir_path.mkdir(parents=True, exist_ok=True)\n            for lib_name in runtime_libs:\n                src_path = mingw_bin / lib_name\n                if src_path.exists():\n                    shutil.copy2(src_path, extdir_path / lib_name)\n                else:\n                    self.announce(\n                        f\"Warning: Expected MinGW runtime DLL '{lib_name}' not found next to g++ (looked in {mingw_bin}). 
\"\n                        \"The built extension may fail to import if the DLL is not on PATH.\",\n                        level=3,\n                    )\n\n\nsetup(\n    name=\"motion_correction\",\n    version=\"1.0.0\",\n    author=\"NVIDIA\",\n    description=\"Standalone correct_motion function\",\n    long_description=\"\",\n    packages=[\"motion_correction\"],\n    package_dir={\"\": \"python\"},\n    ext_modules=[CMakeExtension(\"motion_correction._motion_correction\")],\n    cmdclass={\"build_ext\": CMakeBuild},\n    zip_safe=False,\n    python_requires=\">=3.8\",\n    install_requires=[\n        \"torch>=1.10.0\",\n        \"numpy>=1.19.0\",\n        # 'cmake' # can install this via pip if the windows system does not have it. But need to run this by yourself before build, not in here.\n    ],\n)\n"
  },
  {
    "path": "MotionCorrection/src/cpp/AnimProcessing/InverseKinematics.cpp",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: Apache-2.0\n */\n\n#include \"InverseKinematics.h\"\n#include \"Math/Scalar.h\"\n#include <iostream>\n\n\nusing namespace IK;\n\nnamespace\n{\n\nfloat getAngleWithTwoSideVectors(const Math::Vector& vecLeft, const Math::Vector& vecRight)\n{\n    auto lNorm = vecLeft.GetNormalized3();\n    auto rNorm = vecRight.GetNormalized3();\n\n    float cosine = lNorm.GetDot3(rNorm);\n    float sine = lNorm.Cross3(rNorm).GetLength3();\n\n    return atan2f(sine, cosine);  // in radian\n}\n\nfloat getAngleWithCosineRule (const float lSideLeft, const float lSideRight, const float lSideAcross)\n{\n    float val =\n        (lSideRight * lSideRight + lSideLeft * lSideLeft - lSideAcross * lSideAcross) /\n            (2.0f * lSideLeft * lSideRight);\n    val = Math::Clamp(val, -1.0f, 1.0f);  // numerical stability. also avoid impossible trangulars\n    return acosf(val);  // in radian\n}\n\n}\n\n\nvoid IK::TwoBoneIk(\n    Pose& pose,\n    const Math::Transform& rootTransform,\n    uint32_t cIdx,\n    float weight,\n    const Math::Vector& target,\n    const std::vector<int>& joint_parents_vec,\n    const Math::Vector& hintOffset\n)\n{\n    weight = Math::Clamp(weight, 0.0f, 1.0f);\n    if (!(weight > 0.0f))\n        return;\n\n    // Two bone IK: joints are represented as \"a\", \"b\", \"c\" in the below comments:\n    //  1. stage 1, bend joint a and joint b, so that |ac| = |at|, while vec_ac maintain the same direction\n    //  2. 
stage 2, rotate start joint a so that c and t are in the same place\n\n    //  a                   a                   a             |\n    //  |\\                  |\\                  |\\            |\n    //  | \\                 |  \\                | \\           |\n    //  |  \\  (stage 1 ->)  |   \\  (stage 2 ->) |  \\          |\n    //  |   b               |    b              |   b         |\n    //  |    \\              |    |              |  /          |\n    //  |     \\             |     |             | /           |\n    //  t      c            t      c            t(c)          |\n    //  (a is the root joint, b is the middle joint and c is the end joint)\n    //\n\n    int32_t bIdx = joint_parents_vec[cIdx];\n    if (bIdx < 0)\n    {\n        return;\n    }\n    int32_t aIdx = joint_parents_vec[bIdx];\n    if (aIdx < 0)\n    {\n        return;\n    }\n\n    // Find the parent world transform of joint a:\n    Math::Transform aParentWorldTransform = Math::Transform::Identity;\n    int32_t idx = joint_parents_vec[aIdx];\n    while (idx >= 0)\n    {\n        aParentWorldTransform = aParentWorldTransform * pose[idx];\n        idx = joint_parents_vec[idx];\n    }\n    aParentWorldTransform = aParentWorldTransform * rootTransform;\n\n    // Compute world space transforms of a, b and c:\n    Math::Transform aWorld = pose[aIdx] * aParentWorldTransform;\n    Math::Transform bWorld = pose[bIdx] * aWorld;\n    Math::Transform cWorld = pose[cIdx] * bWorld;\n\n    auto a = aWorld.GetTranslation();\n    auto b = bWorld.GetTranslation();\n    auto c = cWorld.GetTranslation();\n    auto t = Math::Vector::Lerp(c, target, weight);\n\n    // step 1 (stage 1): extend / contract the joint chain to match the distance\n    float eps = 0.0001f;  // numerical stability\n    float l_ab = (b - a).Length3().GetX();\n    float l_bc = (c - b).Length3().GetX();\n    float l_at = (a - t).Length3().GetX();\n    l_at = Math::Clamp(l_at, eps, (l_ab + l_bc) * 0.999f); // when not 
reachable, replace with maximum reachable length\n\n    // get the current angles\n    float theta_bac_current = getAngleWithTwoSideVectors(a - b, a - c);\n    float theta_abc_current = getAngleWithTwoSideVectors(b - a, b - c);\n    // get the desired angles\n    if (l_ab < eps || l_bc < eps || l_at < eps)\n    {\n        return;  // the length is too small. rejecting potentially numerically unstable requests.\n    }\n    float theta_bac_desired = getAngleWithCosineRule(l_ab, l_at, l_bc);\n    float theta_abc_desired = getAngleWithCosineRule(l_ab, l_bc, l_at);\n\n    // in joint[0]'s parent's space\n    Math::Vector rotationAxis = Math::Vector::Cross3(c - a, bWorld.TransformPoint(hintOffset) - a);\n    float l = rotationAxis.GetLength3();\n    if (l == 0)\n    {\n        rotationAxis = Math::Vector(0,0,1);\n    }\n    else\n    {\n        rotationAxis /= l;\n    }\n\n    // get the rotation with axis in the local space of joint a and joint b\n    Math::Vector rotationAxisLocalInBSpace = bWorld.GetRotation().RotateVectorInverse(rotationAxis);\n    Math::Transform rotateInB(\n        Math::Quaternion(rotationAxisLocalInBSpace,\n            (theta_abc_desired - theta_abc_current)), Math::Vector::Zero);\n\n    pose[bIdx] = rotateInB * pose[bIdx];\n\n    Math::Vector rotationAxisLocalInASpace = aWorld.GetRotation().RotateVectorInverse(rotationAxis);\n    Math::Transform rotateInA(\n        Math::Quaternion(rotationAxisLocalInASpace,\n            (theta_bac_desired - theta_bac_current)), Math::Vector::Zero);\n\n    pose[aIdx] = rotateInA * pose[aIdx];\n\n    // recompute a's world space transform as we're going to need it:\n    aWorld = pose[aIdx] * aParentWorldTransform;\n\n    // step 2 (stage 2): rotate joint a so that the target and the end joint c matches\n    auto acLocal = aWorld.GetRotation().RotateVectorInverse(\n        c - a);\n    auto atLocal = aWorld.GetRotation().RotateVectorInverse(\n        target - a);\n    Math::Transform rotateStageTwo(\n        
Math::Quaternion::FromRotationBetweenVectors(acLocal, atLocal), Math::Vector::Zero\n    );\n\n    pose[aIdx] = rotateStageTwo * pose[aIdx];\n\n}\n\nvoid IK::OneBoneIk(\n    Pose& pose,\n    const Math::Transform& rootTransform,\n    uint32_t bIdx,\n    float weight,\n    const Math::Vector& target,\n    const std::vector<int>& joint_parents_vec\n)\n{\n    weight = Math::Clamp(weight, 0.0f, 1.0f);\n    if (!(weight > 0.0f))\n        return;\n\n    int32_t aIdx = joint_parents_vec[bIdx];\n    if (aIdx < 0)\n    {\n        return;\n    }\n\n    // Find the parent world transform of joint a:\n    Math::Transform aParentWorldTransform = Math::Transform::Identity;\n    int32_t idx = joint_parents_vec[aIdx];\n    while (idx >= 0)\n    {\n        aParentWorldTransform = aParentWorldTransform * pose[idx];\n        idx = joint_parents_vec[idx];\n    }\n    aParentWorldTransform = aParentWorldTransform * rootTransform;\n\n    // Compute world space transforms of a, b and c:\n    Math::Transform aWorld = pose[aIdx] * aParentWorldTransform;\n    Math::Transform bWorld = pose[bIdx] * aWorld;\n\n    auto abLocal = aWorld.GetRotation().RotateVectorInverse(\n        bWorld.GetTranslation() - aWorld.GetTranslation());\n    auto atLocal = aWorld.GetRotation().RotateVectorInverse(\n        target - aWorld.GetTranslation());\n\n    auto deltaRLocal = Math::Quaternion::NLerp(Math::Quaternion::Identity, Math::Quaternion::FromRotationBetweenVectors(abLocal, atLocal), weight);\n    pose[aIdx].SetRotation(deltaRLocal * pose[aIdx].GetRotation());\n}\n"
  },
  {
    "path": "MotionCorrection/src/cpp/AnimProcessing/InverseKinematics.h",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: Apache-2.0\n */\n\n#pragma once\n\n#include \"Math/Transform.h\"\n\n#include <vector>\n\nusing Pose = std::vector<Math::Transform>;\n\nnamespace IK {\n\n    void TwoBoneIk(\n        Pose& pose,\n        const Math::Transform& rootTransform,\n        uint32_t jointIdx,\n        float weight,\n        const Math::Vector& target,\n        const std::vector<int>& joint_parents_vec,\n        const Math::Vector& hintOffset = Math::Vector::Zero\n    );\n\n    void OneBoneIk(\n        Pose& pose,\n        const Math::Transform& rootTransform,\n        uint32_t jointIdx,\n        float weight,\n        const Math::Vector& target,\n        const std::vector<int>& joint_parents_vec\n    );\n\n}\n"
  },
  {
    "path": "MotionCorrection/src/cpp/AnimProcessing/TrajectoryCorrector.cpp",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: Apache-2.0\n */\n\n#include \"TrajectoryCorrector.h\"\n#include <iostream>\n\nstatic void removeRows(\n    Eigen::SparseMatrix<double>& M,\n    Eigen::MatrixXd *v,\n    int minCoeffs)\n{\n    Eigen::SparseMatrix<double, Eigen::RowMajor> rowMajorMat = M;\n    rowMajorMat.makeCompressed(); // Ensure compressed format\n\n    std::vector<Eigen::Triplet<double>> triplets;\n    triplets.reserve(rowMajorMat.nonZeros());\n\n    int newRow = 0;\n    for (int i = 0; i < rowMajorMat.outerSize(); ++i) {\n        // Get nonzero count via outerIndexPtr (compressed format only)\n        int nnz = rowMajorMat.outerIndexPtr()[i + 1] - rowMajorMat.outerIndexPtr()[i];\n\n        if (nnz >= minCoeffs) {\n            // Iterate through nonzeros in this row\n            for (Eigen::SparseMatrix<double, Eigen::RowMajor>::InnerIterator it(rowMajorMat, i); it; ++it) {\n                triplets.emplace_back(newRow, it.col(), it.value());\n            }\n            if (v)\n            {\n                v->row(newRow) = v->row(i);\n            }\n            newRow++;\n        }\n    }\n\n    M = Eigen::SparseMatrix<double>(newRow, M.cols());\n    M.setFromTriplets(triplets.begin(), triplets.end());\n    if (v)\n    {\n        v->conservativeResize(newRow, v->cols());\n    }\n}\n\nstatic void multVelWeights(\n    Eigen::SparseMatrix<double>& V,\n    Eigen::MatrixXd* v_rhs,\n    const Eigen::VectorXd& velocityWeights\n)\n{\n    Eigen::SparseMatrix<double, Eigen::RowMajor> rowMajorMat = V;\n    rowMajorMat.makeCompressed(); // Ensure compressed format\n\n    std::vector<Eigen::Triplet<double>> triplets;\n    triplets.reserve(rowMajorMat.nonZeros());\n\n    for (int i = 0; i < rowMajorMat.outerSize(); ++i) {\n        // Iterate through nonzeros in this row\n        Eigen::SparseMatrix<double, Eigen::RowMajor>::InnerIterator it(rowMajorMat, i);\n   
     double vel_weight = velocityWeights[it.col()];\n        for(; it; ++it)\n        {\n            triplets.emplace_back(i, it.col(), it.value() * vel_weight);\n        }\n        if (v_rhs)\n        {\n            (*v_rhs).row(i) = (*v_rhs).row(i) * vel_weight;\n        }\n    }\n}\n\n\nvoid TrajectoryCorrector::computeDiffMats(\n    Eigen::SparseMatrix<double>& V,\n    Eigen::SparseMatrix<double>& A,\n    uint32_t N,\n    const Eigen::VectorXd& velocityWeights,\n    Eigen::MatrixXd* v_rhs,\n    Eigen::MatrixXd* a_rhs)\n{\n    std::vector<Eigen::Triplet<double>> tripletList;\n\n    // Identity matrix\"\n    tripletList.clear();\n    Eigen::SparseMatrix<double> I(N, N);\n    for (uint32_t i = 0; i < N; ++i)\n    {\n        tripletList.emplace_back(i, i, 1);\n    }\n    I.setFromTriplets(tripletList.begin(), tripletList.end());\n\n    // urr, a time translation operator? Gives you the value on the next frame.\n    // Leave the last row blank because that's the end of the timeline.\n    tripletList.clear();\n    Eigen::SparseMatrix<double> T(N, N);\n    Eigen::MatrixXd t_rhs;\n    for(uint32_t i = 0; i < N-1; ++i)\n    {\n        // next frame is\n        tripletList.emplace_back(i, i+1, 1.0);\n    }\n    T.setFromTriplets(tripletList.begin(), tripletList.end());\n\n    // v = Tx + t_rhs - x;\n    // v = (T - I)x + t_rhs;\n    V = T - I;\n    if (v_rhs)\n    {\n        *v_rhs = t_rhs;\n    }\n    removeRows(V, v_rhs, 2);\n\n    // a = -x + 2 (T x + t_rhs) - (T (T x + t_rhs) + t_rhs)\n    // a = (-I + 2 T - T^2) x + t_rhs - T t_rhs\n    A = 2 * T - I - T * T;\n    if (a_rhs)\n    {\n        *a_rhs = t_rhs - T * t_rhs;\n    }\n    removeRows(A, a_rhs, 3);\n\n    if (velocityWeights.size() > 0)\n    {\n        multVelWeights(V, v_rhs, velocityWeights);\n    }\n}\n\nTrajectoryCorrector::TrajectoryCorrector(\n    const Eigen::VectorXd& margins,\n    float pos_weight,\n    float vel_weight,\n    float acc_weight,\n    const Eigen::VectorXd& velocityWeights,\n    uint32_t 
admm_iters ) :\n    m_admm_iters(admm_iters)\n{\n\n    // This class is used to modify a trajectory to hit specific values at\n    // specific frames, while respecting the following soft constraints:\n\n    // * Preserve the original positions\n    // * Preserve the original velocities\n    // * Preserve the original accelerations\n\n    // The weights of these soft constraints are specified in \"pos_weight\" etc.\n\n    // This is posed as a minimization problem:\n\n    // E(x) = pos_weight * |x - x_orig|^2 + vel_weight * |V x - V x_orig|^2 + acc_weight * |A x - A x_orig|^2\n\n    // where you minimize E(x) subject to specified values at indices where \"mask\"\n    // is equal to 1. V is a matrix that computes the N-1 velocities between frame n-1 and frame n,\n    // and A computes the N-2 accelerations associated with frames n-1, n and n+1.\n\n    // In addition to this, there are constraints where the trajectory is allowed to\n    // deviate from the target points by a maximum margin. The \"margins\" input to this\n    // constructor specifies what type of constraint is active on a particular frame:\n\n    // margins[i] < 0   ==> unconstrained\n    // margins[i] == 0  ==> pinned on this frame\n    // margins[i] > 0   ==> can deviate within the margin\n\n    // I'm solving the optimization problem using ADMM, i.e. following equations\n    // 8, 9 and 10 in this paper:\n\n    // https://mattoverby.net/files/admm-pd-overby17.pdf\n\n    uint32_t N = uint32_t(margins.rows());\n    for(uint32_t i = 0; i < N; ++i)\n    {\n        if( margins[i] > 0 )\n        {\n            m_margin_locs.push_back(i);\n            m_margin_vals.push_back(margins[i]);\n        }\n\n        if(margins[i] == 0)\n        {\n            m_constrained_locs.push_back(i);\n        }\n        else\n        {\n            m_unconstrained_locs.push_back(i);\n        }\n    }\n\n    Eigen::SparseMatrix<double> V, A;\n    computeDiffMats(\n        V,  A,\n        N, velocityWeights\n    );\n\n    // build an 
identity matrix:\n    std::vector<Eigen::Triplet<double>> tripletList;\n    Eigen::SparseMatrix<double> I(N, N);\n    for (uint32_t i = 0; i < N; ++i)\n    {\n        tripletList.emplace_back(i, i, 1.0f);\n    }\n    I.setFromTriplets(tripletList.begin(), tripletList.end());\n\n    /*\n    self.N = (\n            self.pos_weight * torch.diag_embed(torch.full_like(interp_mask, 1)) +\n            self.vel_weight * torch.matmul(self.V.T, self.V) +\n            self.acc_weight * torch.matmul(self.A.T, self.A)\n        )\n    */\n\n    m_N = pos_weight * I + vel_weight * (V.transpose() * V) + acc_weight * (A.transpose() * A);\n\n    double diagMax = 0;\n    for (uint32_t i = 0; i < N; ++i)\n    {\n        diagMax = std::max(m_N.coeff(i,i), diagMax);\n    }\n    m_admm_stepsize = 0.5f * sqrtf(float(diagMax));\n\n    /*\n    M = (\n        self.N +\n        self.admm_stepsize * torch.matmul(self.S.T, self.S)\n    )\n    */\n    tripletList.clear();\n    Eigen::SparseMatrix<double> M(N, N);\n    for( auto i : m_margin_locs)\n    {\n        tripletList.emplace_back(i, i, m_admm_stepsize);\n    }\n    M.setFromTriplets(tripletList.begin(), tripletList.end());\n    M += m_N;\n\n    /*\n    self.lhsmat = torch.matmul(self.U.T, torch.matmul(self.M, self.U))\n    self.lhsmat_inv = torch.inverse(self.lhsmat)\n    */\n    tripletList.clear();\n    Eigen::SparseMatrix<double> S(m_unconstrained_locs.size(), N);\n    for (uint32_t i = 0; i < m_unconstrained_locs.size(); ++i)\n    {\n        uint32_t ifull = m_unconstrained_locs[i];\n        tripletList.emplace_back(i, ifull, 1.0f);\n    }\n    S.setFromTriplets(tripletList.begin(), tripletList.end());\n    M = S * M * S.transpose();\n\n    if(m_unconstrained_locs.size())\n    {\n        m_system_lu.compute(M);\n    }\n}\n\n\nvoid TrajectoryCorrector::Interpolate(\n    Eigen::MatrixXd& x,\n    const Eigen::MatrixXd& observations,\n    const Eigen::MatrixXd& ref_positions\n) const\n{\n    if(\n        m_constrained_locs.empty() &&\n    
    m_margin_locs.empty()\n    )\n    {\n        x = ref_positions;\n        return;\n    }\n\n    uint32_t numCols = uint32_t(x.cols());\n    if(m_margin_locs.empty())\n    {\n        x_update(\n            x,\n            Eigen::MatrixXd(0, numCols),\n            Eigen::MatrixXd(0, numCols),\n            ref_positions,\n            observations\n        );\n    }\n    else\n    {\n        x = ref_positions;\n        Eigen::MatrixXd z(m_margin_locs.size(), numCols);\n        Eigen::MatrixXd z_t(m_margin_locs.size(), numCols);\n        Eigen::MatrixXd u(m_margin_locs.size(), numCols);\n        for( uint32_t i = 0; i < m_margin_locs.size(); ++i)\n        {\n            for(uint32_t j = 0; j < numCols; ++j)\n            {\n                z_t(i, j) = observations(m_margin_locs[i], j);\n                z(i, j) = ref_positions(m_margin_locs[i], j);\n                u(i, j) =0;\n            }\n        }\n\n        for(uint32_t i = 0; i < m_admm_iters; ++i)\n        {\n            x_update(\n                x,\n                z,\n                u,\n                ref_positions,\n                observations\n            );\n            z_update(z, x, z_t, u);\n            u_update(u, x, z);\n        }\n    }\n\n}\n\nvoid TrajectoryCorrector::x_update(\n    Eigen::MatrixXd &x,\n    const Eigen::MatrixXd &z,\n    const Eigen::MatrixXd &u,\n    const Eigen::MatrixXd &x_t, // reference positions - defines the original shape of the curve that we want to preserve\n    const Eigen::MatrixXd &x_o  // target positions for constraints\n) const\n{\n\n    uint32_t numRows = uint32_t(x.rows());\n    uint32_t numCols = uint32_t(x.cols());\n\n    // Here's what we're minimizing with ADMM:\n    // min f(x) + g(z)\n    // s.t A x + B z = c\n\n    // Make these choices so that z = S x:\n    // A = S, B = -I, c = 0\n    //\n    // g(z) = infinity when it's too far away from z_target, zero otherwise\n    //\n    // f(x) penalizes deviations in position, velocity and acceleration\n    // 
from a reference trajectory:\n    //\n    // f(x) = 1/2(\n    //    kx |I x - x_t|^2 +\n    //    kv |V x - v_t|^2 +\n    //    ka |A x - a_t|^2\n    // )\n    //\n    // It's also infinite when components of x deviate from their target\n    // values where they're pinned...\n\n    // Substituting the matrices into the standard admm update rules gives us this:\n    // x{n+1} = argmin(f(x) + ρ/2 |S x - z{n} + u{n}|^2)\n    // z{n+1} = argmin(g(z) + ρ/2 |S x{n+1} - z + u{n}|^2)\n    // u{n+1} = u{n} + (S x{n+1} - z{n+1})\n    //\n\n    // x update:\n    //\n    // x{n+1} = argmin  1/2 (\n    //     kx |I x - x_t|^2 +\n    //     kv |V x - v_t|^2 +\n    //     ka |A x - a_t|^2 +\n    //     ρ  |S x - d|^2\n    // )\n    // d = (z{n} - u{n})\n\n    // Rewrite in a friendlier way:\n    // |A x - b|^2 = x^T A^T A x - 2 x^T A^T b + C\n    // 1/2 (\n    //     kx (x^T x - 2 x^T x_t) +\n    //     kv (x^T V^T V x - 2 x^T V^T v_t) +\n    //     ka (x^T A^T A x - 2 x^T A^T a_t) +\n    //     ρ  (x^T S^T S x - 2 x^T S^T d)\n    // ) + C\n    //\n    // 1/2 x^T (kx I + kv V^T V + ka A^T A + ρ S^T S) x\n    //   - x^T (kx x_t + kv V^T v_t + ka A^T a_t + ρ S^T d)\n    // + C\n    //\n    // voila:\n    // M = kx I + kv V^T V + ka A^T A + ρ S^T S\n    // r = kx x_t + kv V^T v_t + ka A^T a_t + ρ S^T d\n    // E = 1/2 x^T M x - x^T r + C\n\n    /*\n    r = (\n        torch.matmul(self.N, x_t - x_o_filtered) +\n        self.admm_stepsize * torch.matmul(self.S.T, - u + z)\n    )\n    */\n    Eigen::MatrixXd x_diffs(x_t);\n    for(auto i : m_constrained_locs)\n    {\n        for(uint32_t j = 0; j < numCols; ++j)\n        {\n            x_diffs(i, j) = x_diffs(i,j) - x_o(i,j);\n        }\n    }\n\n    Eigen::MatrixXd r = m_N * x_diffs;\n\n    for(uint32_t i = 0; i < m_margin_locs.size(); ++i)\n    {\n        uint32_t ifull = m_margin_locs[i];\n        for(uint32_t j = 0; j < numCols; ++j)\n        {\n            r(ifull, j) = r(ifull, j) + m_admm_stepsize * (z(i,j) - u(i,j));\n        
}\n    }\n\n    // Solve with respect to pin constraints:\n    // x = U x_r + x_o\n    // E = 1/2 (U x_r + x_o)^T M (U x_r + x_o) - (U x_r + x_o)^T r + C\n    // E = 1/2 (x_r^T U^T + x_o^T) M (U x_r + x_o) - (x_r^T U^T + x_o^T) r + C\n    // E = 1/2 (x_r^T U^T M (U x_r + x_o) + x_o^T M (U x_r + x_o)) - x_r^T U^T r - x_o^T r + C\n    // E = 1/2 (x_r^T U^T M U x_r) + x_r^T U^T (M x_o - r) + C\n\n    // minimized when x_r solves this equation:\n    // U^T M U x_r + U^T (M x_o - r) = 0\n    // x_r = (U^T M U)^-1 U^T (r - M x_o)\n\n    // collapse r down to unconstrained variable set:\n    // rhs = torch.matmul(self.U.T, r)\n\n    uint32_t numRows_reduced = m_unconstrained_locs.size();\n    Eigen::MatrixXd r_reduced(numRows_reduced, numCols);\n    for(uint32_t i = 0; i < numRows_reduced; ++i)\n    {\n        uint32_t ifull = m_unconstrained_locs[i];\n        for(uint32_t j = 0; j < numCols; ++j)\n        {\n            r_reduced(i,j) = r(ifull, j);\n        }\n    }\n\n    // solve system:\n    // x_r = torch.matmul(self.lhsmat_inv, rhs)\n    r_reduced.conservativeResize(r_reduced.rows(), r_reduced.cols());\n\n    Eigen::MatrixXd result;\n    if(m_unconstrained_locs.size())\n    {\n        result = m_system_lu.solve(r_reduced);\n    }\n\n    // map back to full variable set:\n    // return torch.matmul(self.U, x_r) + x_o_filtered\n    for(uint32_t i = 0; i < numRows_reduced; ++i)\n    {\n        uint32_t ifull = m_unconstrained_locs[i];\n        for(uint32_t j = 0; j < numCols; ++j)\n        {\n            x(ifull, j) = result(i, j);\n        }\n    }\n    for(auto i : m_constrained_locs)\n    {\n        for(uint32_t j = 0; j < numCols; ++j)\n        {\n            x(i, j) = x_o(i, j);\n        }\n    }\n}\n\nvoid TrajectoryCorrector::z_update(\n    Eigen::MatrixXd &z,\n    const Eigen::MatrixXd &x,\n    const Eigen::MatrixXd &z_t,\n    const Eigen::MatrixXd &u\n) const\n{\n    uint32_t numCols = uint32_t(z.cols());\n\n    for(uint32_t i = 0; i < m_margin_locs.size(); 
++i)\n    {\n\n        // z_diffs = S x + u - z_t\n        uint32_t ifull = m_margin_locs[i];\n        for(uint32_t j = 0; j < numCols; ++j)\n        {\n            z(i, j) = x(ifull, j) + u(i, j) - z_t(i, j);\n        }\n\n        // find the norm of the current z diff vector:\n        double z_diff_norm = 0.0;\n        for(uint32_t j = 0; j < numCols; ++j)\n        {\n            double z_diff = z(i, j);\n            z_diff_norm += z_diff * z_diff;\n        }\n        z_diff_norm = sqrt(z_diff_norm);\n\n        // if the norm is greater than the margin size, we need to rescale\n        // the diff:\n        if( z_diff_norm > m_margin_vals[i] )\n        {\n            for(uint32_t j = 0; j < numCols; ++j)\n            {\n                z(i, j) = z(i, j) * m_margin_vals[i] / z_diff_norm;\n            }\n        }\n\n        // add the diff back on to the target:\n        for(uint32_t j = 0; j < numCols; ++j)\n        {\n            z(i, j) = z_t(i, j) + z(i, j);\n        }\n    }\n}\n\nvoid TrajectoryCorrector::u_update(\n    Eigen::MatrixXd &u,\n    const Eigen::MatrixXd &x,\n    const Eigen::MatrixXd &z\n) const\n{\n    uint32_t numCols = uint32_t(z.cols());\n\n    // u += S x - z\n    for(uint32_t i = 0; i < m_margin_locs.size(); ++i)\n    {\n        uint32_t ifull = m_margin_locs[i];\n        for(uint32_t j = 0; j < numCols; ++j)\n        {\n            u(i,j) += x(ifull, j) - z(i,j);\n        }\n    }\n}\n"
  },
  {
    "path": "MotionCorrection/src/cpp/AnimProcessing/TrajectoryCorrector.h",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: Apache-2.0\n */\n\n#pragma once\n\n#include <Eigen/Sparse>\n\nclass TrajectoryCorrector\n{\npublic:\n\n\tstatic void computeDiffMats(\n\t\tEigen::SparseMatrix<double>& V,\n\t\tEigen::SparseMatrix<double>& A,\n\t\tuint32_t N,\n\t\tconst Eigen::VectorXd& velocityWeights = Eigen::VectorXd(),\n\t\tEigen::MatrixXd* v_rhs = nullptr,\n\t\tEigen::MatrixXd* a_rhs = nullptr\n\t);\n\n\tTrajectoryCorrector(\n\t\tconst Eigen::VectorXd& margins,\n        float pos_weight,\n        float vel_weight,\n        float acc_weight,\n\t\tconst Eigen::VectorXd& velocityWeights = Eigen::VectorXd(),\n\t\tuint32_t admm_iters=100 );\n\n\tvoid Interpolate(\n\t\tEigen::MatrixXd& ret,\n\t\tconst Eigen::MatrixXd& observations,\n\t\tconst Eigen::MatrixXd& ref_positions\n\t) const;\n\n\tvoid x_update(\n\t\tEigen::MatrixXd& x,\n\t\tconst Eigen::MatrixXd& z,\n\t\tconst Eigen::MatrixXd& u,\n\t\tconst Eigen::MatrixXd& x_t,\n\t\tconst Eigen::MatrixXd& x_o\n\t) const;\n\n\tvoid z_update(\n\t\tEigen::MatrixXd& z,\n\t\tconst Eigen::MatrixXd& x,\n\t\tconst Eigen::MatrixXd& z_t,\n\t\tconst Eigen::MatrixXd& u\n\t) const;\n\n\tvoid u_update(\n\t\tEigen::MatrixXd& u,\n\t\tconst Eigen::MatrixXd& x,\n\t\tconst Eigen::MatrixXd& z\n\t) const;\n\n\tfloat admm_stepsize() const { return m_admm_stepsize; }\n\n\tconst std::vector<uint32_t>& margin_locs() { return m_margin_locs; }\n\nprivate:\n\n\tEigen::SparseMatrix<double> m_N;\n\tEigen::SparseLU<Eigen::SparseMatrix<double>> m_system_lu;\n\n\tuint32_t m_admm_iters;\n\n\tstd::vector<uint32_t> m_margin_locs;\n\tstd::vector<double> m_margin_vals;\n\n\tstd::vector<uint32_t> m_unconstrained_locs;\n\tstd::vector<uint32_t> m_constrained_locs;\n\n\tfloat m_admm_stepsize;\n\n};\n"
  },
  {
    "path": "MotionCorrection/src/cpp/AnimProcessing/Utility.cpp",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: Apache-2.0\n */\n\n#include \"TrajectoryCorrector.h\"\n#include \"InverseKinematics.h\"\n\n#include \"Utility.h\"\n\n#include <map>\n#include <array>\n#include <cmath>\n#include <cstdlib>\n#include <iostream>\nusing Pose = std::vector<Math::Transform>;\n\nstatic const float pos_weight = 0.001f;\nstatic const float vel_weight = 1.0f;\nstatic const float acc_weight = 10.0f;\n\n\nnamespace {\n\n    // Enable with: MOTIONCORRECTION_DEBUG_INTERVALS=1\n    // Default: off (no Interval printing).\n    bool DebugPrintIntervalsEnabled()\n    {\n        const char* v = std::getenv(\"MOTIONCORRECTION_DEBUG_INTERVALS\");\n        if (v == nullptr || v[0] == '\\0')\n        {\n            return false;\n        }\n        // Treat \"0\" as false; any other non-empty value enables.\n        return v[0] != '0';\n    }\n\n\n    void FilterContactIntervals(\n        std::vector<std::pair<int, int>>& contactIntervals,\n        const std::vector<float>& mask,\n        bool one_bone_contact = false)\n    {\n        std::vector<uint32_t> keepIntervals;\n        for (size_t i = 0; i < contactIntervals.size(); ++i)\n        {\n            const auto& interval = contactIntervals[i];\n\n            bool startConstrained = (interval.first != 0 && mask[interval.first - 1]);\n            bool endConstrained;\n\n            endConstrained = (interval.second != mask.size() && mask[interval.second]);\n\n            if (one_bone_contact)\n            {\n                if (startConstrained || endConstrained)\n                {\n                    continue;\n                }\n            }\n            else\n            {\n                // If both the start and end of the contact interval are masked,\n                // there's no way we can correct the contact without popping, so\n                // let's filter these out:\n                if 
(startConstrained && endConstrained)\n                {\n                    continue;\n                }\n            }\n\n            keepIntervals.push_back(i);\n        }\n\n        for (size_t i = 0; i < keepIntervals.size(); ++i)\n        {\n            contactIntervals[i] = contactIntervals[keepIntervals[i]];\n        }\n        contactIntervals.resize(keepIntervals.size());\n    }\n\n    std::vector<std::pair<int, int>> ComputeContactIntervals(\n        const std::vector<float>& contacts,\n        const std::vector<float>& mask,\n        float contactThreshold)\n    {\n        // turn off the contacts for all frames that are constrained/masked:\n        std::vector<float> contactsNoMask = contacts;\n        for (size_t i = 0; i < mask.size(); ++i)\n        {\n            if (mask[i])\n            {\n                contactsNoMask[i] = 0;\n            }\n        }\n\n        // Find intervals that are in contact:\n        std::vector<std::pair<int, int>> contactIntervals;\n        int start = -1;\n        for (int frame = 0; frame < mask.size(); ++frame)\n        {\n            bool isContact = contactsNoMask[frame] > contactThreshold;\n            if (isContact && start == -1)\n            {\n                start = frame;\n            }\n            else if (!isContact && start != -1)\n            {\n                contactIntervals.emplace_back(start, frame);\n                start = -1;\n            }\n        }\n\n        // Close the final interval if needed:\n        if (start != -1)\n        {\n            contactIntervals.emplace_back(start, mask.size());\n        }\n        return contactIntervals;\n    }\n\n    void FindContactPoints(\n        std::vector<Math::Vector> &points,\n        std::vector<int> &inContact,\n        const std::vector<int>& joint_parents_vec,\n        int32_t jointIndex,\n        const std::vector<Pose> &poses,\n        const std::vector<std::pair<int, int>>& contactIntervals,\n        const std::vector<float>& mask,\n      
  size_t frameCount,\n        float minHeight)\n    {\n        // Find a representative frame for each interval.\n        // If the interval starts after a masked frame, use the start\n        // of the interval, if it ends before a mask use the end,\n        // otherwise use the middle frame.\n        inContact.clear();\n        inContact.resize(frameCount, 0);\n        points.clear();\n        points.resize(frameCount);\n        for (size_t i = 0; i < contactIntervals.size(); ++i)\n        {\n            const auto& interval = contactIntervals[i];\n            int frame = -1;\n            bool startConstrained = (interval.first != 0 && mask[interval.first - 1]);\n            bool endConstrained;\n\n            endConstrained = (interval.second != mask.size() && mask[interval.second]);\n\n            // Debug output (opt-in via env var)\n            if (DebugPrintIntervalsEnabled())\n            {\n                std::cout << \"Interval \" << i << \": start=\" << interval.first\n                          << \", end=\" << interval.second\n                          << \", startConstrained=\" << startConstrained\n                          << \", endConstrained=\" << endConstrained << std::endl;\n            }\n\n            if(startConstrained)\n            {\n                // If the interval starts on a constraint, use the constrained frame\n                // as a target (no modulo wrap is applied: interval.first != 0 here, so first - 1 is in range)\n                frame = interval.first - 1;\n            }\n            else if (endConstrained)\n            {\n                // If the interval ends on a constraint, use the constrained frame\n                // as a target:\n                frame = interval.second;\n            }\n            else\n            {\n                // Otherwise use the midpoint of the interval:\n                frame = (interval.first + interval.second) / 2;\n            }\n\n            // get the target point:\n            Math::Vector target = 
Animation::JointLocalToGlobal(joint_parents_vec, jointIndex, poses[frame]).GetTranslation();\n            for(int i = interval.first; i < interval.second; ++i)\n            {\n                Math::Vector framePt = Animation::JointLocalToGlobal(joint_parents_vec, jointIndex, poses[i]).GetTranslation();\n                inContact[i] = 1;\n                points[i] = target;\n                if (!startConstrained && !endConstrained)\n                {\n                    points[i].SetY(std::max(framePt.GetY(), minHeight));\n                    // std::cout << \"  Frame \" << i << \": SetY with framePt.GetY()=\" << framePt.GetY()\n                    //           << \", minHeight=\" << minHeight << std::endl;\n                }\n            }\n        }\n    }\n\n    float TargetReachFalloff(\n        const std::vector<int>& joint_parents_vec,\n        const Pose& defaultPose,\n        int32_t jointIndex,\n        Animation::IKType ikType,\n        const Math::Vector& target,\n        const Pose& pose,\n        const Math::Transform& rootTx = Math::Transform::Identity)\n    {\n        float maxReach = defaultPose[jointIndex].GetTranslation().GetLength3();\n        if (ikType == Animation::IKType::kTwoBone)\n        {\n            jointIndex = joint_parents_vec[jointIndex];\n            ASSERT(jointIndex > -1);\n            maxReach += defaultPose[jointIndex].GetTranslation().GetLength3();\n        }\n        // Get base joint world Tx\n        jointIndex = joint_parents_vec[jointIndex];\n        ASSERT(jointIndex > -1);\n        const auto worldTx = Animation::JointLocalToGlobal(joint_parents_vec, jointIndex, pose, rootTx);\n\n        // Gaussian falloff\n        float targetDist = target.GetDistance3(worldTx.GetTranslation());\n        float tmp = Math::Max(targetDist / maxReach - 0.99f, 0.f) / 0.01f;\n        tmp = tmp * tmp;\n        return std::exp(-2.f * tmp * tmp);\n    }\n\n    void CorrectHipsY(\n        std::vector<Pose>& poses,\n        const 
std::vector<Pose>& targetPoses,\n        const std::vector<float>& fullBodyMask,\n        const std::vector<Animation::ContactInfo>& contacts,\n        float contactThreshold\n    )\n    {\n        // Correct the y coordinates of the root.\n        auto N = poses.size();\n        Eigen::MatrixXd x(N, 1);\n        Eigen::MatrixXd observations(N, 1);\n        Eigen::MatrixXd xfixed(N, 1);\n\n        // Fill in the initial trajectory (x) and the values we want to hit when we\n        // warp it (observations):\n        Eigen::VectorXd yCorrectMargins(N);\n        for(size_t frame = 0; frame < N; ++frame)\n        {\n            yCorrectMargins[frame] = fullBodyMask[frame] ? 0.0f : -1.0f;\n            x(frame, 0) = ((float*)&poses[frame][0].GetTranslation())[1];\n            observations(frame, 0) = ((float*)&targetPoses[frame][0].GetTranslation())[1];\n        }\n\n        TrajectoryCorrector ycorrector(\n            yCorrectMargins,\n            pos_weight * 10,\n            vel_weight,\n            acc_weight * 0.1f\n        );\n        ycorrector.Interpolate(\n            xfixed,\n            observations,\n            x\n        );\n\n        // fill channel again:\n        for (uint32_t frame = 0; frame < N; ++frame)\n        {\n            ((float*)&poses[frame][0].GetTranslation())[1] = float(xfixed(frame, 0));\n        }\n    }\n\n    void SmoothChannels(\n        Eigen::MatrixXd &x,\n        const std::vector<float>& mask\n    )\n    {\n        for( uint32_t i=0; i < mask.size(); ++i)\n        {\n            uint32_t i_prev = i == 0 ? 
0 : i-1;\n            uint32_t i_next = std::min(uint32_t(i+1), uint32_t(mask.size()-1));\n            if(i > 0 && mask[i] > 0 && mask[i_prev] == 0)\n            {\n                // if the previous frame is unconstrained and the current frame is constrained,\n                // replace the current frame with the average of its neighbors:\n                for(long j=0; j < x.cols(); ++j)\n                {\n                    x(i, j) = 0.5f * (x(i_prev, j) + x(i_next, j));\n                }\n            }\n            if(mask[i] > 0 && mask[i_next] == 0)\n            {\n                // if the next frame is unconstrained and the current frame is constrained,\n                // replace the current frame with the average of its neighbors:\n                for(long j=0; j < x.cols(); ++j)\n                {\n                    x(i, j) = 0.5f * (x(i_prev, j) + x(i_next, j));\n                }\n            }\n        }\n    }\n\n\n    void CorrectHipsXZ(\n        std::vector<Pose>& poses,\n        const std::vector<Pose>& targetPoses,\n        const std::vector<float>& fullBodyMask,\n        const std::vector<float>& rootMask,\n        const std::vector<Animation::ContactInfo>& endEffectorPins,\n        const Eigen::VectorXd& velocity_weights,\n        float root_margin\n    )\n    {\n        auto N = poses.size();\n        Eigen::VectorXd margins(N);\n        for( size_t i = 0; i < N; ++i )\n        {\n            margins[i] = fullBodyMask[i] ? 
0.0f : -1.0f;\n        }\n\n        std::vector<float> rootCombinedMask(N, 0.0f);\n        for(size_t i = 0; i < N; ++i)\n        {\n            rootCombinedMask[i] = (fullBodyMask[i] > 0) || (rootMask[i] > 0);\n            if(rootMask[i] > 0 && margins[i] != 0)\n            {\n                margins[i] = root_margin;\n            }\n            for (auto& c : endEffectorPins)\n            {\n                if (c.contactMask[i] && margins[i] != 0)\n                {\n                    margins[i] = root_margin;\n                }\n            }\n        }\n        TrajectoryCorrector xzcorrector(\n            margins,\n            pos_weight,\n            vel_weight,\n            acc_weight,\n            velocity_weights\n        );\n\n        // Enforce pose constraints on root xz trajectory:\n        Eigen::MatrixXd x(N, 2);\n        Eigen::MatrixXd observations(N, 2);\n        Eigen::MatrixXd x_fixed(N, 2);\n\n        observations.setZero();\n        for (uint32_t frame = 0; frame < N; ++frame)\n        {\n            x(frame, 0) = ((float*)&poses[frame][0].GetTranslation())[0];\n            x(frame, 1) = ((float*)&poses[frame][0].GetTranslation())[2];\n\n            observations(frame, 0) = ((float*)&targetPoses[frame][0].GetTranslation())[0];\n            observations(frame, 1) = ((float*)&targetPoses[frame][0].GetTranslation())[2];\n        }\n\n        SmoothChannels(x, rootCombinedMask);\n\n        xzcorrector.Interpolate(\n            x_fixed,\n            observations,\n            x\n        );\n\n        // fill channels again:\n        for (uint32_t frame = 0; frame < N; ++frame)\n        {\n            ((float*)&poses[frame][0].GetTranslation())[0] = float(x_fixed(frame, 0));\n            ((float*)&poses[frame][0].GetTranslation())[2] = float(x_fixed(frame, 1));\n        }\n    }\n\n    void CorrectRotationsForBone(\n        std::vector<Pose>& poses,\n        const std::vector<Pose>& targetPoses,\n        const std::vector<float>& mask,\n        
const TrajectoryCorrector& corrector,\n        int boneIdx,\n        bool performChannelSmoothing)\n    {\n        auto N = poses.size();\n        Eigen::MatrixXd x(N, 1);\n        Eigen::MatrixXd observations(N, 1);\n        observations.setZero();\n        Eigen::MatrixXd x_fixed(N, 1);\n\n        // Quaternion components can flip when they pass through 180 degree\n        // rotations, so let's convert all the quaternions in this channel to\n        // the forward/up vector representation, modify them, then convert back\n        // to quaternions:\n\n        // convert time series to 6d forward/up:\n        std::vector<float> forwardUp(6 * N);\n        std::vector<float> targetForwardUp(6 * N);\n        for (uint32_t frame = 0; frame < N; ++frame)\n        {\n            auto q = poses[frame][boneIdx].GetRotation();\n            auto forward = q.ZAxis();\n            auto up = q.YAxis();\n            forwardUp[N * 0 + frame] = forward.GetX();\n            forwardUp[N * 1 + frame] = forward.GetY();\n            forwardUp[N * 2 + frame] = forward.GetZ();\n            forwardUp[N * 3 + frame] = up.GetX();\n            forwardUp[N * 4 + frame] = up.GetY();\n            forwardUp[N * 5 + frame] = up.GetZ();\n\n            q = targetPoses[frame][boneIdx].GetRotation();\n            forward = q.ZAxis();\n            up = q.YAxis();\n            targetForwardUp[N * 0 + frame] = forward.GetX();\n            targetForwardUp[N * 1 + frame] = forward.GetY();\n            targetForwardUp[N * 2 + frame] = forward.GetZ();\n            targetForwardUp[N * 3 + frame] = up.GetX();\n            targetForwardUp[N * 4 + frame] = up.GetY();\n            targetForwardUp[N * 5 + frame] = up.GetZ();\n        }\n\n        // correct trajectories:\n        for (uint32_t dim = 0; dim < 6; ++dim)\n        {\n            for (uint32_t frame = 0; frame < N; ++frame)\n            {\n                x(frame, 0) = forwardUp[N * dim + frame];\n                observations(frame, 0) = mask[frame] 
* targetForwardUp[N * dim + frame];\n            }\n\n            if (performChannelSmoothing)\n            {\n                SmoothChannels(x, mask);\n            }\n\n            corrector.Interpolate(\n                x_fixed,\n                observations,\n                x\n            );\n\n            // fill channel again:\n            for (uint32_t frame = 0; frame < N; ++frame)\n            {\n                forwardUp[N * dim + frame] = float(x_fixed(frame, 0));\n            }\n        }\n\n        for (uint32_t frame = 0; frame < N; ++frame)\n        {\n            Math::Vector forward = { forwardUp[N * 0 + frame] ,forwardUp[N * 1 + frame] ,forwardUp[N * 2 + frame] };\n            Math::Vector up = { forwardUp[N * 3 + frame] ,forwardUp[N * 4 + frame] ,forwardUp[N * 5 + frame] };\n\n            forward.Normalize3();\n            up.Normalize3();\n\n            poses[frame][boneIdx].SetRotation(Math::Quaternion::LookRotation(forward, up));\n        }\n    }\n\n    void CorrectJointRotations(\n        std::vector<Pose>& poses,\n        const std::vector<Pose>& targetPoses,\n        const std::vector<float>& fullBodyMask,\n        const Eigen::VectorXd& velocity_weights\n    )\n    {\n        auto N = poses.size();\n\n        // Create a trajectory corrector for fixing the full body fullBodyMask positions:\n        Eigen::VectorXd margins(N);\n        for( size_t i = 0; i < N; ++i )\n        {\n            margins[i] = fullBodyMask[i] ? 
0.0f : -1.0f;\n        }\n        TrajectoryCorrector corrector(\n            margins,\n            pos_weight * 10,\n            vel_weight,\n            acc_weight,\n            velocity_weights\n        );\n\n        for (uint32_t boneIdx = 0; boneIdx < poses[0].size(); ++boneIdx)\n        {\n            CorrectRotationsForBone(\n                poses,\n                targetPoses,\n                fullBodyMask,\n                corrector,\n                boneIdx,\n                true\n            );\n        }\n    }\n\n    void DoEffectorIK(\n        std::vector<Pose>& poses,\n        const std::vector<Pose>& targetPoses,\n        const std::vector<float>& fullBodyMask,\n        const std::vector<Animation::ContactInfo>& endEffectorPins,\n        const std::vector<int>& joint_parents_vec,\n        const std::vector<Math::Transform>& defaultPose\n    )\n    {\n        // Apply IK for effector pins\n        auto N = poses.size();\n        std::map<uint32_t, std::vector<float>> jointCorrectionMasks;\n        std::vector<Pose> ikFixedPoses = poses;\n        for (auto& c : endEffectorPins)\n        {\n            auto jointIdx = c.jointIndex;\n\n            if(jointCorrectionMasks[jointIdx].empty())\n            {\n                // initialize to the full body constraint mask because we\n                // want to constrain that anyway:\n                jointCorrectionMasks[jointIdx] = fullBodyMask;\n            }\n\n            // Add a trajectory correction mask for the parent joint:\n            auto parentIdx = joint_parents_vec[jointIdx];\n            if(jointCorrectionMasks[parentIdx].empty())\n            {\n                // initialize to the full body constraint mask because we\n                // want to constrain that anyway:\n                jointCorrectionMasks[parentIdx] = fullBodyMask;\n            }\n\n            // Add a trajectory correction mask for its parent if this is\n            // 2 bone IK:\n            auto parentParentIdx = 
joint_parents_vec[parentIdx];\n            if(c.contactType == Animation::kTwoBone)\n            {\n                if(jointCorrectionMasks[parentParentIdx].empty())\n                {\n                    // initialize to the full body constraint mask because we\n                    // want to constrain that anyway:\n                    jointCorrectionMasks[parentParentIdx] = fullBodyMask;\n                }\n            }\n\n            for (uint32_t fixFrame = 0; fixFrame < fullBodyMask.size(); ++fixFrame)\n            {\n                if (c.contactMask[fixFrame])\n                {\n                    const auto targetGlobalTransform = Animation::JointLocalToGlobal(joint_parents_vec, jointIdx, targetPoses[fixFrame]);\n\n                    // flag the parent joint as fixed in its correction mask:\n                    jointCorrectionMasks[parentIdx][fixFrame] = 1;\n                    switch(c.contactType)\n                    {\n                        case Animation::kOneBone:\n                        {\n                            IK::OneBoneIk(\n                                ikFixedPoses[fixFrame],\n                                Math::Transform::Identity,\n                                jointIdx,\n                                1.0,\n                                targetGlobalTransform.GetTranslation(),\n                                joint_parents_vec\n                            );\n                            break;\n                        }\n                        case Animation::kTwoBone:\n                        {\n                            // flag the parent parent joint as fixed in its correction mask:\n                            jointCorrectionMasks[parentParentIdx][fixFrame] = 1;\n                            IK::TwoBoneIk(\n                                ikFixedPoses[fixFrame],\n                                Math::Transform::Identity,\n                                jointIdx,\n                                1.0,\n               
                 targetGlobalTransform.GetTranslation(),\n                                joint_parents_vec,\n                                c.hintOffset\n                            );\n                            break;\n                        }\n                    }\n\n                    // now we need to fix things so the global rotation of the joint\n                    // matches the input:\n                    jointCorrectionMasks[jointIdx][fixFrame] = 1;\n                    auto parentGlobalTransform = Animation::JointLocalToGlobal(joint_parents_vec, parentIdx, ikFixedPoses[fixFrame]);\n                    ikFixedPoses[fixFrame][jointIdx].SetRotation(\n                        targetGlobalTransform.GetRotation() * parentGlobalTransform.GetRotation().GetConjugate()\n                    );\n\n                }\n            }\n        }\n\n        // Applying the effector pin IK introduces popping into the animation,\n        // so let's apply the interpolator to all the joints we modified so as to\n        // line the trajectory up properly again:\n        Eigen::VectorXd margins(N);\n        for( auto &kv : jointCorrectionMasks)\n        {\n            for( size_t i = 0; i < N; ++i )\n            {\n                margins[i] = kv.second[i] ? 
0.0f : -1.0f;\n            }\n            TrajectoryCorrector corrector(margins, pos_weight * 10, vel_weight, acc_weight);\n\n            CorrectRotationsForBone(\n                poses,\n                ikFixedPoses,\n                kv.second,\n                corrector,\n                kv.first,\n                false\n            );\n        }\n    }\n\n    void DoContactIK(\n        std::vector<Pose>& poses,\n        const std::vector<float>& fullBodyMask,\n        const std::vector<Animation::ContactInfo>& contacts,\n        const std::vector<Animation::ContactInfo>& endEffectorPins,\n        const std::vector<int>& joint_parents_vec,\n        const std::vector<Math::Transform>& defaultPose,\n        float contactThreshold,\n        bool has_double_ankle_joints\n    )\n    {\n        auto N = poses.size();\n        Eigen::VectorXd margins = Eigen::VectorXd::Zero(N);\n\n        // Apply IK to stabilize limbs on contacts\n        std::map<uint32_t, std::vector<float>> jointCorrectionMasks;\n        std::vector<Pose> ikFixedPoses = poses;\n\n        // Save original poses before any modifications (for double ankle correction later)\n        const std::vector<Pose> originalPoses = poses;\n\n        // Track which frames were corrected for each 2-bone contact (for double ankle correction later)\n        std::map<uint32_t, std::vector<bool>> twoBoneContactFrames;\n\n        auto addEndEffectorMask = [&](uint32_t jointIdx, uint32_t parentIdx, std::vector<float>& jointMask)\n        {\n            auto it = std::find_if(\n                endEffectorPins.begin(), endEffectorPins.end(),\n                [&](const auto &c)\n                {\n                    if(jointIdx == c.jointIndex)\n                    {\n                        return true;\n                    }\n                    return false;\n                }\n            );\n            if(it == endEffectorPins.end())\n            {\n                // We could be correcting the toe joint, in which 
case we need to use\n                // the parent joint instead:\n                it = std::find_if(\n                    endEffectorPins.begin(), endEffectorPins.end(),\n                    [&](const auto &c)\n                    {\n                        if(parentIdx == c.jointIndex)\n                        {\n                            return true;\n                        }\n                        return false;\n                    }\n                );\n            }\n            if(it != endEffectorPins.end())\n            {\n                const auto &msk = it->contactMask;\n                for(size_t i=0; i < msk.size(); ++i)\n                {\n                    if(msk[i])\n                    {\n                        jointMask[i] = 1.0f;\n                    }\n                }\n            }\n        };\n\n        // Process two bone contacts first:\n        for (auto& c : contacts)\n        {\n            if(c.contactType != Animation::kTwoBone)\n            {\n                continue;\n            }\n            const auto jointIdx = c.jointIndex;\n            auto parentIdx = joint_parents_vec[jointIdx];\n            auto parentParentIdx = joint_parents_vec[parentIdx];\n\n            auto jointMask = fullBodyMask;\n            addEndEffectorMask(jointIdx, parentIdx, jointMask);\n\n            // We'll actually be modifying 3 joints here:\n            // * The two joints immediately up in the hierarchy because of the 2 bone IK\n            // * The joint itself because we restore its original global rotation\n            if(jointCorrectionMasks[parentIdx].empty())\n            {\n                jointCorrectionMasks[parentIdx] = jointMask;\n            }\n            if(jointCorrectionMasks[parentParentIdx].empty())\n            {\n                jointCorrectionMasks[parentParentIdx] = jointMask;\n            }\n            if(jointCorrectionMasks[jointIdx].empty())\n            {\n                jointCorrectionMasks[jointIdx] = 
jointMask;\n            }\n\n            // Compute the intervals in which the joint is in contact with the floor:\n            auto contactIntervals = ComputeContactIntervals(c.contactMask, jointMask, contactThreshold);\n            FilterContactIntervals(contactIntervals, jointMask);\n\n            std::vector<Math::Vector> contactPoints;\n            std::vector<int> inContact;\n            FindContactPoints(\n                contactPoints,\n                inContact,\n                joint_parents_vec,\n                jointIdx,\n                poses,\n                contactIntervals,\n                jointMask,\n                c.contactMask.size(),\n                c.minHeight\n            );\n\n            for (uint32_t fixFrame = 0; fixFrame < fullBodyMask.size(); ++fixFrame)\n            {\n                if (inContact[fixFrame])\n                {\n                    auto target = contactPoints[fixFrame];\n                    jointCorrectionMasks[parentIdx][fixFrame] = 1.0f;\n                    jointCorrectionMasks[parentParentIdx][fixFrame] = 1.0f;\n                    jointCorrectionMasks[jointIdx][fixFrame] = 1.0f;\n\n                    // Track this frame for double ankle correction later\n                    if (has_double_ankle_joints)\n                    {\n                        if (twoBoneContactFrames[jointIdx].empty())\n                            twoBoneContactFrames[jointIdx].resize(fullBodyMask.size(), false);\n                        twoBoneContactFrames[jointIdx][fixFrame] = true;\n                    }\n\n                    // save the original global rotation of the joint:\n                    auto jointGlobalRotation = Animation::JointLocalToGlobal(\n                        joint_parents_vec,\n                        jointIdx,\n                        ikFixedPoses[fixFrame]\n                    ).GetRotation();\n\n                    const float w = TargetReachFalloff(\n                        joint_parents_vec,\n               
         defaultPose,\n                        jointIdx,\n                        c.contactType,\n                        target,\n                        ikFixedPoses[fixFrame]\n                    );\n                    // std::cout << \"Frame \" << fixFrame << \": w=\" << w << std::endl;\n\n                    // apply the 2 bone IK:\n                    auto origParentRotation = ikFixedPoses[fixFrame][parentIdx].GetRotation();\n                    auto origParentParentRotation = ikFixedPoses[fixFrame][parentParentIdx].GetRotation();\n                    IK::TwoBoneIk(\n                        ikFixedPoses[fixFrame],\n                        Math::Transform::Identity,\n                        jointIdx,\n                        1.0f,\n                        target,\n                        joint_parents_vec,\n                        c.hintOffset\n                    );\n                    ikFixedPoses[fixFrame][parentIdx].SetRotation(Math::Quaternion::SLerp(origParentRotation, ikFixedPoses[fixFrame][parentIdx].GetRotation(), w));\n                    ikFixedPoses[fixFrame][parentParentIdx].SetRotation(Math::Quaternion::SLerp(origParentParentRotation, ikFixedPoses[fixFrame][parentParentIdx].GetRotation(), w));\n\n                    // restore previous global rotation of this joint:\n                    auto parentGloblalRotation = Animation::JointLocalToGlobal(\n                        joint_parents_vec,\n                        parentIdx,\n                        ikFixedPoses[fixFrame]\n                    ).GetRotation();\n\n                    jointCorrectionMasks[jointIdx][fixFrame] = 1.0f;\n                    ikFixedPoses[fixFrame][jointIdx].SetRotation(\n                        jointGlobalRotation * parentGloblalRotation.GetConjugate()\n                    );\n\n                    auto result = Animation::JointLocalToGlobal(\n                        joint_parents_vec,\n                        jointIdx,\n                        ikFixedPoses[fixFrame]\n  
                  ).GetTranslation();\n                }\n            }\n\n        }\n\n        for( auto &kv : jointCorrectionMasks)\n        {\n            for( size_t i = 0; i < N; ++i )\n            {\n                margins[i] = kv.second[i] ? 0.0f : -1.0f;\n            }\n            TrajectoryCorrector corrector(margins, pos_weight * 10, vel_weight, acc_weight);\n            CorrectRotationsForBone(\n                poses,\n                ikFixedPoses,\n                kv.second,\n                corrector,\n                kv.first,\n                false\n            );\n        }\n        jointCorrectionMasks.clear();\n\n        // Then process one bone contacts:\n        for(auto &c : contacts)\n        {\n            if(c.contactType != Animation::kOneBone)\n            {\n                continue;\n            }\n            const auto jointIdx = c.jointIndex;\n            auto parentIdx = joint_parents_vec[jointIdx];\n\n            // We can't touch frames that have been constrained with full body constraints\n            // or the end effector constraints for this joint, so let's combine fullBodyMask\n            // with the end effector mask for this joint if it exists so we can use that\n            // information later:\n            auto jointMask = fullBodyMask;\n            addEndEffectorMask(jointIdx, parentIdx, jointMask);\n\n            // Add a trajectory correction mask for the parent joint:\n            if(jointCorrectionMasks[parentIdx].empty())\n            {\n                jointCorrectionMasks[parentIdx] = jointMask;\n            }\n\n            // Compute the intervals in which the joint is in contact with the floor:\n            auto contactIntervals = ComputeContactIntervals(c.contactMask, jointMask, contactThreshold);\n            FilterContactIntervals(contactIntervals, jointMask, true);\n            for(const auto &interval : contactIntervals)\n            {\n                for (int fixFrame = interval.first; fixFrame < 
interval.second; ++fixFrame)\n                {\n                    // All we're going to do here is stick the joint to the floor -\n                    // we're going to allow it to slide from side to side.\n\n                    // Find a target position that lies on the floor by iteratively\n                    // projecting the joint to the floor (pure laziness really, this could\n                    // be done analytically):\n                    Math::Vector parentPos = Animation::JointLocalToGlobal(joint_parents_vec, parentIdx, ikFixedPoses[fixFrame]).GetTranslation();\n                    Math::Vector target = Animation::JointLocalToGlobal(joint_parents_vec, jointIdx, ikFixedPoses[fixFrame]).GetTranslation();\n                    float jointLength = (target - parentPos).GetLength3();\n                    for(int32_t i = 0; i < 10; ++i)\n                    {\n                        target.SetY(c.minHeight);\n                        auto dir = (target - parentPos).GetNormalized3();\n                        target = parentPos + dir * jointLength;\n                    }\n\n                    IK::OneBoneIk(\n                        ikFixedPoses[fixFrame],\n                        Math::Transform::Identity,\n                        jointIdx,\n                        1.0f,\n                        target,\n                        joint_parents_vec\n                    );\n                    jointCorrectionMasks[parentIdx][fixFrame] = 1.0f;\n                }\n            }\n\n        }\n\n        // Fixing the contacts with IK will introduce popping into the animation,\n        // so let's apply the interpolator to all the joints we modified so as to\n        // line the trajectory up properly again:\n        for( auto &kv : jointCorrectionMasks)\n        {\n            for( size_t i = 0; i < N; ++i )\n            {\n                margins[i] = kv.second[i] ? 
0.0f : -1.0f;\n            }\n            TrajectoryCorrector corrector(margins, pos_weight * 10, vel_weight, acc_weight);\n            CorrectRotationsForBone(\n                poses,\n                ikFixedPoses,\n                kv.second,\n                corrector,\n                kv.first,\n                false\n            );\n        }\n\n        if (has_double_ankle_joints)\n        {\n            // Maps to save target positions BEFORE 2-bone IK modifies them\n            std::map<uint32_t, std::map<uint32_t, Math::Vector>> savedFirstAnkleTargets;  // [firstAnkleIdx][frame] -> position\n            std::map<uint32_t, std::map<uint32_t, Math::Vector>> savedToeTargets;         // [firstAnkleIdx][frame] -> position\n            std::map<uint32_t, uint32_t> contactToToeIdx;  // firstAnkleIdx -> toeIdx\n\n            // Find toe joints for each leg\n            for (const auto& tc : contacts)\n            {\n                if (tc.contactType == Animation::kOneBone)\n                {\n                    // The parent of the toe is the 1st ankle\n                    int parentIdx = joint_parents_vec[tc.jointIndex];\n                    if (parentIdx >= 0)\n                    {\n                        contactToToeIdx[parentIdx] = tc.jointIndex;\n                    }\n                }\n            }\n\n            // For each 2-bone contact, correct the parent (2nd ankle) joint\n            for (auto& c : contacts)\n            {\n                if (c.contactType != Animation::kTwoBone)\n                    continue;\n\n                const auto firstAnkleIdx = c.jointIndex;\n                const auto secondAnkleIdx = joint_parents_vec[firstAnkleIdx];\n                const auto kneeIdx = joint_parents_vec[secondAnkleIdx];\n                const auto hipIdx = joint_parents_vec[kneeIdx];\n\n                if (hipIdx < 0) continue;  // safety check\n\n                // Get saved contact frames for this ankle\n                auto it = 
twoBoneContactFrames.find(firstAnkleIdx);\n                if (it == twoBoneContactFrames.end())\n                    continue;\n                const auto& contactFrames = it->second;\n\n                // Add correction mask for knee and hip\n                auto jointMask = fullBodyMask;\n                addEndEffectorMask(firstAnkleIdx, secondAnkleIdx, jointMask);\n\n                if (jointCorrectionMasks[kneeIdx].empty())\n                    jointCorrectionMasks[kneeIdx] = jointMask;\n                if (jointCorrectionMasks[hipIdx].empty())\n                    jointCorrectionMasks[hipIdx] = jointMask;\n\n                for (uint32_t fixFrame = 0; fixFrame < fullBodyMask.size(); ++fixFrame)\n                {\n                    // Only correct frames where the 1st ankle was corrected\n                    if (!contactFrames[fixFrame])\n                        continue;\n\n                    // *** SAVE TARGET POSITIONS BEFORE 2-BONE IK ***\n                    savedFirstAnkleTargets[firstAnkleIdx][fixFrame] = Animation::JointLocalToGlobal(\n                        joint_parents_vec, firstAnkleIdx, ikFixedPoses[fixFrame]).GetTranslation();\n\n                    if (contactToToeIdx.count(firstAnkleIdx))\n                    {\n                        savedToeTargets[firstAnkleIdx][fixFrame] = Animation::JointLocalToGlobal(\n                            joint_parents_vec, contactToToeIdx[firstAnkleIdx], ikFixedPoses[fixFrame]).GetTranslation();\n                    }\n\n                    // Get original global transforms (before any IK corrections)\n                    auto originalFirstAnkleGlobal = Animation::JointLocalToGlobal(\n                        joint_parents_vec, firstAnkleIdx, originalPoses[fixFrame]);\n                    auto originalSecondAnkleGlobal = Animation::JointLocalToGlobal(\n                        joint_parents_vec, secondAnkleIdx, originalPoses[fixFrame]);\n\n                    // Compute delta from 1st ankle to 2nd ankle in 
original animation\n                    auto deltaFirstToSecond = originalFirstAnkleGlobal.GetDeltaToOther(originalSecondAnkleGlobal);\n\n                    // Get corrected 1st ankle global transform\n                    auto correctedFirstAnkleGlobal = Animation::JointLocalToGlobal(\n                        joint_parents_vec, firstAnkleIdx, ikFixedPoses[fixFrame]);\n\n                    // Apply the original delta to the corrected 1st ankle to get target for 2nd ankle\n                    auto target = (deltaFirstToSecond * correctedFirstAnkleGlobal).GetTranslation();\n\n                    // print current and target second ankle positions\n                    auto currPos = Animation::JointLocalToGlobal(\n                        joint_parents_vec, secondAnkleIdx, ikFixedPoses[fixFrame]).GetTranslation();\n\n                    // Apply 2-bone IK: Hip -> Knee -> 2nd Ankle\n                    IK::TwoBoneIk(\n                        ikFixedPoses[fixFrame],\n                        Math::Transform::Identity,\n                        secondAnkleIdx,\n                        1.0f,\n                        target,\n                        joint_parents_vec,\n                        c.hintOffset\n                    );\n\n                    // auto correctedPos = Animation::JointLocalToGlobal(\n                    //     joint_parents_vec, secondAnkleIdx, ikFixedPoses[fixFrame]).GetTranslation();\n                    // std::cout << \"Frame \" << fixFrame << \": target second ankle=(\" << target.GetX() << \", \" << target.GetY() << \", \" << target.GetZ() << \"), corrected second ankle position=(\" << correctedPos.GetX() << \", \" << correctedPos.GetY() << \", \" << correctedPos.GetZ() << \")\" << std::endl;\n\n                    jointCorrectionMasks[kneeIdx][fixFrame] = 1.0f;\n                    jointCorrectionMasks[hipIdx][fixFrame] = 1.0f;\n                }\n            }\n\n            // Smooth the corrected joints\n            for (auto& kv : 
jointCorrectionMasks)\n            {\n                for (size_t i = 0; i < N; ++i)\n                    margins[i] = kv.second[i] ? 0.0f : -1.0f;\n\n                TrajectoryCorrector corrector(margins, pos_weight * 10, vel_weight, acc_weight);\n                CorrectRotationsForBone(poses, ikFixedPoses, kv.second, corrector, kv.first, false);\n            }\n\n            // *** PHASE 2: 1-bone IKs to restore 1st ankle and toe ***\n            jointCorrectionMasks.clear();\n\n            for (auto& c : contacts)\n            {\n                if (c.contactType != Animation::kTwoBone)\n                    continue;\n\n                const auto firstAnkleIdx = c.jointIndex;\n                const auto secondAnkleIdx = joint_parents_vec[firstAnkleIdx];\n\n                auto it = twoBoneContactFrames.find(firstAnkleIdx);\n                if (it == twoBoneContactFrames.end())\n                    continue;\n\n                // Setup correction masks\n                auto jointMask = fullBodyMask;\n                addEndEffectorMask(firstAnkleIdx, secondAnkleIdx, jointMask);\n\n                if (jointCorrectionMasks[secondAnkleIdx].empty())\n                    jointCorrectionMasks[secondAnkleIdx] = jointMask;\n                if (jointCorrectionMasks[firstAnkleIdx].empty())\n                    jointCorrectionMasks[firstAnkleIdx] = jointMask;\n\n                for (uint32_t fixFrame = 0; fixFrame < fullBodyMask.size(); ++fixFrame)\n                {\n                    if (!it->second[fixFrame])\n                        continue;\n\n                    // 1-bone IK: Rotate 2nd ankle so 1st ankle reaches saved target\n                    IK::OneBoneIk(\n                        ikFixedPoses[fixFrame],\n                        Math::Transform::Identity,\n                        firstAnkleIdx,\n                        1.0f,\n                        savedFirstAnkleTargets[firstAnkleIdx][fixFrame],\n                        joint_parents_vec\n                    
);\n                    jointCorrectionMasks[secondAnkleIdx][fixFrame] = 1.0f;\n\n                    // auto target = savedFirstAnkleTargets[firstAnkleIdx][fixFrame];\n                    // auto corrected = Animation::JointLocalToGlobal(\n                    //     joint_parents_vec, firstAnkleIdx, ikFixedPoses[fixFrame]).GetTranslation();\n                    // std::cout << \"Frame \" << fixFrame << \": target first ankle=(\" << target.GetX() << \", \" << target.GetY() << \", \" << target.GetZ() << \"), corrected first ankle=(\" << corrected.GetX() << \", \" << corrected.GetY() << \", \" << corrected.GetZ() << \")\" << std::endl;\n\n                    // 1-bone IK: Rotate 1st ankle so toe reaches saved target\n                    if (contactToToeIdx.count(firstAnkleIdx) && savedToeTargets[firstAnkleIdx].count(fixFrame))\n                    {\n                        IK::OneBoneIk(\n                            ikFixedPoses[fixFrame],\n                            Math::Transform::Identity,\n                            contactToToeIdx[firstAnkleIdx],\n                            1.0f,\n                            savedToeTargets[firstAnkleIdx][fixFrame],\n                            joint_parents_vec\n                        );\n                        jointCorrectionMasks[firstAnkleIdx][fixFrame] = 1.0f;\n                    }\n\n                    // target = savedToeTargets[firstAnkleIdx][fixFrame];\n                    // corrected = Animation::JointLocalToGlobal(\n                    //     joint_parents_vec, contactToToeIdx[firstAnkleIdx], ikFixedPoses[fixFrame]).GetTranslation();\n                    // std::cout << \"Frame \" << fixFrame << \": target toe=(\" << target.GetX() << \", \" << target.GetY() << \", \" << target.GetZ() << \"), corrected toe=(\" << corrected.GetX() << \", \" << corrected.GetY() << \", \" << corrected.GetZ() << \")\" << std::endl;\n                }\n            }\n\n            // Smooth 2nd ankle and 1st ankle\n            for 
(auto& kv : jointCorrectionMasks)\n            {\n                for (size_t i = 0; i < N; ++i)\n                    margins[i] = kv.second[i] ? 0.0f : -1.0f;\n\n                TrajectoryCorrector corrector(margins, pos_weight * 10, vel_weight, acc_weight);\n                CorrectRotationsForBone(poses, ikFixedPoses, kv.second, corrector, kv.first, false);\n            }\n        }\n    }\n\n}\n\n\nMath::Transform Animation::JointLocalToGlobal(\n    const std::vector<int>& joint_parents_vec,\n    int32_t index,\n    const Pose& localPose,\n    const Math::Transform& rootTx)\n{\n    Math::Transform worldTx = Math::Transform::Identity;\n    while (index > -1)\n    {\n        worldTx = worldTx * localPose[index];\n        index = joint_parents_vec[index];\n    }\n\n    return worldTx * rootTx;\n}\n\nvoid Animation::CorrectMotion(\n    std::vector<Pose>& poses,\n    const std::vector<Pose>& targetPoses,\n    const std::vector<float>& fullBodyMask,\n    const std::vector<float>& rootMask,\n    const std::vector<ContactInfo>& contacts,\n    const std::vector<ContactInfo>& endEffectorPins,\n    const std::vector<int>& joint_parents_vec,\n    const std::vector<Math::Transform>& defaultPose,\n    float contactThreshold,\n    float root_margin,\n    bool has_double_ankle_joints\n)\n{\n\n    // Calculate some weights so we can preserve velocities more strongly on frames where\n    // the root velocity is low\n    const uint32_t N = poses.size();\n    Eigen::VectorXd velocity_weights(N);\n    for (uint32_t frame = 1; frame < N; ++frame)\n    {\n        // work out xz velocity for this frame:\n        float xdiff = poses[frame][0].GetTranslation()[0] - poses[frame - 1][0].GetTranslation()[0];\n        float zdiff = poses[frame][0].GetTranslation()[2] - poses[frame - 1][0].GetTranslation()[2];\n\n        // find velocity magnitude, divided by a typical walking speed:\n        float v_mag = sqrtf(xdiff*xdiff + zdiff*zdiff) / 0.05f;\n\n        // weight lower velocities higher 
so that the corrector doesn't make the character drift around\n        // when it's supposed to stand still:\n        v_mag = std::max(v_mag, 1.0f/1000.0f);\n        velocity_weights(frame) = 1.0f / v_mag;\n    }\n    velocity_weights[0] = velocity_weights[1];\n\n    // Correct root y coordinates.\n    // This will warp the root y coordinates in \"poses\" so they match the root y coordinates\n    // in \"targetPoses\", on frames where the root y coordinates are constrained, ie the frames\n    // where fullBodyMask = 1.\n    // In addition to this, it preserves the root y coordinates in \"pose\" on frames where foot\n    // contacts are active, to avoid mushiness when characters are jumping.\n    CorrectHipsY(\n        poses,\n        targetPoses,\n        fullBodyMask,\n        contacts,\n        contactThreshold\n    );\n\n    // Correct root xz coordinates:\n    // This will warp the root xz coordinates in \"poses\" so they match the xz coordinates\n    // in \"targetPoses\" on frames where fullBodyMask = 1, and warp them so they're within\n    // \"root_margin\" units of targetPoses on frames where rootMask = 1.\n    CorrectHipsXZ(\n        poses,\n        targetPoses,\n        fullBodyMask,\n        rootMask,\n        endEffectorPins,\n        velocity_weights,\n        root_margin\n    );\n\n    // Correct joint rotations by warping the rotations so they match targetPoses on frames\n    // where fullBodyMask = 1:\n    CorrectJointRotations(\n        poses,\n        targetPoses,\n        fullBodyMask,\n        velocity_weights\n    );\n\n    // Apply IK for end effector pins\n    DoEffectorIK(\n        poses,\n        targetPoses,\n        fullBodyMask,\n        endEffectorPins,\n        joint_parents_vec,\n        defaultPose\n    );\n\n    // Apply IK to stabilize limbs on contacts\n    DoContactIK(\n        poses,\n        fullBodyMask,\n        contacts,\n        endEffectorPins,\n        joint_parents_vec,\n        defaultPose,\n        contactThreshold,\n 
       has_double_ankle_joints\n    );\n    // std::cout << \"Running post processing.\" << std::endl;\n}\n"
  },
  {
    "path": "MotionCorrection/src/cpp/AnimProcessing/Utility.h",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: Apache-2.0\n */\n\n#pragma once\n\n#include \"Math/Transform.h\"\n\n#include <string>\n#include <vector>\n\nnamespace Animation\n{\n    enum IKType {\n        kOneBone,\n        kTwoBone\n    };\n\n    Math::Transform JointLocalToGlobal(\n        const std::vector<int>& joint_parents_vec,\n        int32_t index,\n        const std::vector<Math::Transform>& localPose,\n        const Math::Transform& rootTx = Math::Transform::Identity\n    );\n\n    struct ContactInfo {\n        // index IK contact joint:\n        int jointIndex;\n        // mask indicating which frames are in contact:\n        std::vector<float> contactMask;\n        // contact type:\n        IKType contactType = kTwoBone;\n\n        // Extra info for TwoBoneIK\n        Math::Vector hintOffset = Math::Vector::Zero;\n\n        float minHeight = 0.0f;\n    };\n\n    void CorrectMotion(\n        std::vector< std::vector<Math::Transform> >& poses,\n        const std::vector< std::vector<Math::Transform> >& targetPoses,\n        const std::vector<float>& mask,\n        const std::vector<float>& rootMask,\n        const std::vector<ContactInfo>& contacts,\n        const std::vector<ContactInfo>& endEffectorPins,\n        const std::vector<int>& joint_parents_vec,\n        const std::vector<Math::Transform>& defaultPose,\n        float contactThreshold,\n        float root_margin,\n        bool has_double_ankle_joints\n    );\n}\n"
  },
  {
    "path": "MotionCorrection/src/cpp/BindingsPython.cpp",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: Apache-2.0\n */\n\n#include \"AnimProcessing/Utility.h\"\n\n#ifdef _WIN32\n#pragma warning(push)\n#pragma warning(disable : 4623 4191 4686 4868 5219 4191 4355)\n#endif\n#include <pybind11/pybind11.h>\n#include <pybind11/numpy.h>\n#include <pybind11/stl.h>\n#ifdef _WIN32\n#pragma warning(pop)\n#endif\n\nnamespace py = pybind11;\n\nfloat strip_nan_inf(float x) noexcept\n{\n    if (std::isnan(x)) return 0;\n    if (std::isinf(x)) return 0;\n    return x;\n}\n\nvoid correct_motion(\n    py::array_t<float> &rootTranslations,\n    py::array_t<float> &jointRotations,\n    const py::array_t<float>& rootTranslationsTarget,\n    const py::array_t<float>& jointRotationsTarget,\n    const py::array_t<float>& fullPoseMask,\n    const py::array_t<float>& leftHandMask,\n    const py::array_t<float>& rightHandMask,\n    const py::array_t<float>& leftFootMask,\n    const py::array_t<float>& rightFootMask,\n    const py::array_t<float>& rootMask,\n    const py::array_t<float>& contacts,\n    const py::list& joint_parents,\n    const py::list& joint_ref_translations,\n    const py::list& joint_ref_rotations,\n    int left_hand_idx,\n    int right_hand_idx,\n    int left_foot_idx,\n    int right_foot_idx,\n    float contact_threshold,\n    float root_margin,\n    bool has_double_ankle_joints\n)\n{\n    if(joint_parents.size() != joint_ref_translations.size())\n    {\n        throw std::runtime_error(\"correct_motion python bindings: joint_parents and joint_ref_translations must have the same size\");\n    }\n    if(joint_parents.size() != joint_ref_rotations.size())\n    {\n        throw std::runtime_error(\"correct_motion python bindings: joint_parents and joint_ref_rotations must have the same size\");\n    }\n    if(left_hand_idx < 0 || right_hand_idx < 0 || left_foot_idx < 0 || right_foot_idx < 0)\n    {\n        throw 
std::runtime_error(\"correct_motion python bindings: left_hand_idx, right_hand_idx, left_foot_idx, and right_foot_idx must be non-negative\");\n    }\n    if(left_hand_idx >= joint_parents.size() || right_hand_idx >= joint_parents.size() || left_foot_idx >= joint_parents.size() || right_foot_idx >= joint_parents.size())\n    {\n        throw std::runtime_error(\"correct_motion python bindings: left_hand_idx, right_hand_idx, left_foot_idx, and right_foot_idx must be less than the number of joints\");\n    }\n\n    std::vector<Math::Transform> defaultPose(joint_parents.size());\n    for (size_t i = 0; i < joint_ref_translations.size(); ++i)\n    {\n        if (!py::isinstance<py::list>(joint_ref_translations[i]))\n        {\n            throw std::runtime_error(\"correct_motion python bindings: Expected joint_ref_translations to be a list of lists\");\n        }\n        py::list inner_list = joint_ref_translations[i].cast<py::list>();\n        if (inner_list.size() != 3) {\n            throw std::runtime_error(\"correct_motion python bindings: Expected joint_ref_translations to be a list of lists, length 3\");\n        }\n\n        if (\n            !py::isinstance<py::float_>(inner_list[0]) ||\n            !py::isinstance<py::float_>(inner_list[1]) ||\n            !py::isinstance<py::float_>(inner_list[2])\n        )\n        {\n            throw std::runtime_error(\"correct_motion python bindings: Expected joint_ref_translations to be a list of lists, length 3, float values\");\n        }\n\n\n        if (!py::isinstance<py::list>(joint_ref_rotations[i]))\n        {\n            throw std::runtime_error(\"correct_motion python bindings: Expected joint_ref_rotations to be a list of lists\");\n        }\n        py::list inner_list_rot = joint_ref_rotations[i].cast<py::list>();\n        if (inner_list_rot.size() != 4) {\n            throw std::runtime_error(\"correct_motion python bindings: Expected joint_ref_rotations to be a list of lists, length 4\");\n        
}\n\n        if (\n            !py::isinstance<py::float_>(inner_list_rot[0]) ||\n            !py::isinstance<py::float_>(inner_list_rot[1]) ||\n            !py::isinstance<py::float_>(inner_list_rot[2]) ||\n            !py::isinstance<py::float_>(inner_list_rot[3])\n        )\n        {\n            throw std::runtime_error(\"correct_motion python bindings: Expected joint_ref_rotations to be a list of lists, length 4, float values\");\n        }\n\n        defaultPose[i].SetTranslation(Math::Vector(\n            inner_list[0].cast<float>(),\n            inner_list[1].cast<float>(),\n            inner_list[2].cast<float>()));\n        defaultPose[i].SetRotation(Math::Quaternion(\n            inner_list_rot[0].cast<float>(),\n            inner_list_rot[1].cast<float>(),\n            inner_list_rot[2].cast<float>(),\n            inner_list_rot[3].cast<float>()\n        ));\n    }\n\n    std::vector<int> joint_parents_vec(joint_parents.size());\n    for (size_t i = 0; i < joint_parents.size(); ++i)\n    {\n        if (!py::isinstance<py::int_>(joint_parents[i]))\n        {\n            throw std::runtime_error(\"correct_motion python bindings: Expected joint_parents to be a list of ints\");\n        }\n        joint_parents_vec[i] = joint_parents[i].cast<int>();\n        if (joint_parents_vec[i] >= (int)joint_parents.size())\n        {\n            throw std::runtime_error(\"correct_motion python bindings: joint_parents must be a list of ints, and all values must be less than the number of joints\");\n        }\n    }\n\n    size_t num_joints = defaultPose.size();\n    size_t gen_length = fullPoseMask.size();\n\n    if(\n        leftHandMask.size() != (int)gen_length ||\n        rightHandMask.size() != (int)gen_length ||\n        leftFootMask.size() != (int)gen_length ||\n        rightFootMask.size() != (int)gen_length ||\n        rootMask.size() != (int)gen_length\n    )\n    {\n        throw std::runtime_error(\"correct_motion python bindings: all masks must have 
the same size\");\n    }\n\n    if(rootTranslations.size() != 3 * (int)gen_length)\n    {\n        throw std::runtime_error(\"correct_motion python bindings: rootTranslations has the wrong size\");\n    }\n    if(jointRotations.size() != 4 * (int)num_joints * (int)gen_length)\n    {\n        throw std::runtime_error(\"correct_motion python bindings: jointRotations has the wrong size\");\n    }\n\n    if(rootTranslationsTarget.size() != 3 * (int)gen_length)\n    {\n        throw std::runtime_error(\"correct_motion python bindings: rootTranslationsTarget has the wrong size\");\n    }\n    if(jointRotationsTarget.size() != 4 * (int)num_joints * (int)gen_length)\n    {\n        throw std::runtime_error(\"correct_motion python bindings: jointRotationsTarget has the wrong size\");\n    }\n\n    std::vector<Animation::ContactInfo> endEffectorPins(4);\n    endEffectorPins[0].jointIndex = left_hand_idx;\n    endEffectorPins[0].hintOffset = Math::Vector(0.0f, 0.0f, -0.1f);\n\n    endEffectorPins[1].jointIndex = right_hand_idx;\n    endEffectorPins[1].hintOffset = Math::Vector(0.0f, 0.0f, -0.1f);\n\n    endEffectorPins[2].jointIndex = left_foot_idx;\n    endEffectorPins[2].hintOffset = Math::Vector(0.0f, 0.0f, 0.1f);\n\n    endEffectorPins[3].jointIndex = right_foot_idx;\n    endEffectorPins[3].hintOffset = Math::Vector(0.0f, 0.0f, 0.1f);\n\n    endEffectorPins[0].contactMask.reserve(gen_length);\n    endEffectorPins[1].contactMask.reserve(gen_length);\n    endEffectorPins[2].contactMask.reserve(gen_length);\n    endEffectorPins[3].contactMask.reserve(gen_length);\n    for(size_t i = 0; i < gen_length; ++i)\n    {\n        endEffectorPins[0].contactMask.push_back((1.0f - fullPoseMask.at(i)) * leftHandMask.at(i));\n        endEffectorPins[1].contactMask.push_back((1.0f - fullPoseMask.at(i)) * rightHandMask.at(i));\n        endEffectorPins[2].contactMask.push_back((1.0f - fullPoseMask.at(i)) * leftFootMask.at(i));\n        endEffectorPins[3].contactMask.push_back((1.0f - 
fullPoseMask.at(i)) * rightFootMask.at(i));\n    }\n\n    std::vector<Animation::ContactInfo> contactInfo(2);\n\n    auto footTranslation = Animation::JointLocalToGlobal(\n        joint_parents_vec,\n        right_foot_idx,\n        defaultPose\n    ).GetTranslation();\n\n    contactInfo[0].jointIndex = right_foot_idx;\n    contactInfo[0].hintOffset = Math::Vector(0.0f, 0.0f, 0.1f);\n    contactInfo[0].minHeight = footTranslation.GetY();\n\n    footTranslation = Animation::JointLocalToGlobal(\n        joint_parents_vec,\n        left_foot_idx,\n        defaultPose\n    ).GetTranslation();\n\n    contactInfo[1].jointIndex = left_foot_idx;\n    contactInfo[1].hintOffset = Math::Vector(0.0f, 0.0f, 0.1f);\n    contactInfo[1].minHeight = footTranslation.GetY();\n\n    auto& rContacts = contactInfo[0].contactMask;\n    auto& lContacts = contactInfo[1].contactMask;\n\n    rContacts.resize(fullPoseMask.size());\n    lContacts.resize(fullPoseMask.size());\n    for (int i = 0; i < fullPoseMask.size(); ++i)\n    {\n        // don't flag it as a contact if it's been masked:\n        rContacts[i] = rightFootMask.at(i) ? 0 : contacts.at(4 * i + 2);\n        lContacts[i] = leftFootMask.at(i) ? 0 : contacts.at(4 * i + 0);\n\n        // Flag the heel as a contact if the toe is a contact:\n        rContacts[i] = std::min((rightFootMask.at(i) ? 0 : contacts.at(4 * i + 3)) + rContacts[i], 1.0f);\n        lContacts[i] = std::min((leftFootMask.at(i) ? 
0 : contacts.at(4 * i + 1)) + lContacts[i], 1.0f);\n    }\n\n    int left_toe_idx = -1;\n    int right_toe_idx = -1;\n    for(int i = 0; i < num_joints; ++i)\n    {\n        if(joint_parents_vec[i] == left_foot_idx)\n        {\n            left_toe_idx = i;\n        }\n        if(joint_parents_vec[i] == right_foot_idx)\n        {\n            right_toe_idx = i;\n        }\n    }\n\n    if(left_toe_idx != -1 && right_toe_idx != -1)\n    {\n        auto toeTranslation = Animation::JointLocalToGlobal(\n            joint_parents_vec,\n            right_toe_idx,\n            defaultPose\n        ).GetTranslation();\n\n        contactInfo.resize(4);\n        contactInfo[2].jointIndex = right_toe_idx;\n        contactInfo[2].contactType = Animation::kOneBone;\n        contactInfo[2].minHeight = toeTranslation.GetY();\n\n        contactInfo[3].jointIndex = left_toe_idx;\n        contactInfo[3].contactType = Animation::kOneBone;\n        contactInfo[3].minHeight = toeTranslation.GetY();\n\n        auto& rToeContacts = contactInfo[2].contactMask;\n        auto& lToeContacts = contactInfo[3].contactMask;\n\n        // fill up the ankle contacts:\n        rToeContacts.resize(fullPoseMask.size());\n        lToeContacts.resize(fullPoseMask.size());\n\n        for (int i = 0; i < fullPoseMask.size(); ++i)\n        {\n            // don't flag it as a contact if it's been masked:\n            rToeContacts[i] = rightFootMask.at(i) ? 0 : contacts.at(4 * i + 3);\n            lToeContacts[i] = leftFootMask.at(i) ? 
0 : contacts.at(4 * i + 1);\n        }\n    }\n\n\n    auto setTransforms = [gen_length, num_joints](\n        std::vector< std::vector<Math::Transform> > &poses,\n        const py::array_t<float> &rootTranslations,\n        const py::array_t<float> &jointRotations\n    )\n    {\n        for (size_t f = 0; f < gen_length; ++f)\n        {\n            poses[f][0].SetTranslation({\n                strip_nan_inf(rootTranslations.at(3*f+0)),\n                strip_nan_inf(rootTranslations.at(3*f+1)),\n                strip_nan_inf(rootTranslations.at(3*f+2))\n            });\n        }\n\n        for (size_t f = 0; f < gen_length; ++f)\n        {\n            for (size_t j = 0; j < num_joints; ++j)\n            {\n                // x y z w order:\n                Math::Quaternion q(\n                    strip_nan_inf(jointRotations.at(4 * (num_joints * f + j) + 1)),\n                    strip_nan_inf(jointRotations.at(4 * (num_joints * f + j) + 2)),\n                    strip_nan_inf(jointRotations.at(4 * (num_joints * f + j) + 3)),\n                    strip_nan_inf(jointRotations.at(4 * (num_joints * f + j) + 0))\n                );\n                q.Normalize();\n                poses[f][j].SetRotation(q);\n            }\n        }\n    };\n\n    std::vector< std::vector<Math::Transform> > posesFixed(gen_length, defaultPose);\n    setTransforms(posesFixed, rootTranslations, jointRotations);\n\n    std::vector< std::vector<Math::Transform> > posesTarget(gen_length, defaultPose);\n    setTransforms(posesTarget, rootTranslationsTarget, jointRotationsTarget);\n\n    std::vector<float> fullPoseMask_vec;\n    std::vector<float> rootMask_vec;\n    for (size_t f = 0; f < gen_length; ++f)\n    {\n        fullPoseMask_vec.push_back(fullPoseMask.at(f));\n        rootMask_vec.push_back(rootMask.at(f));\n    }\n\n    Animation::CorrectMotion(\n        posesFixed,\n        posesTarget,\n        fullPoseMask_vec,\n        rootMask_vec,\n        contactInfo,\n        
endEffectorPins,\n        joint_parents_vec,\n        defaultPose,\n        contact_threshold,\n        root_margin,\n        has_double_ankle_joints\n    );\n\n    for (size_t f = 0; f < gen_length; ++f)\n    {\n        auto t = posesFixed[f][0].GetTranslation();\n        rootTranslations.mutable_at(3*f+0) = t.GetX();\n        rootTranslations.mutable_at(3*f+1) = t.GetY();\n        rootTranslations.mutable_at(3*f+2) = t.GetZ();\n    }\n\n    for (size_t f = 0; f < gen_length; ++f)\n    {\n        for (size_t j = 0; j < num_joints; ++j)\n        {\n            auto q = posesFixed[f][j].GetRotation();\n            // w x y z order\n            jointRotations.mutable_at(4 * (num_joints * f + j) + 0) = ((float*)&q)[3];\n            jointRotations.mutable_at(4 * (num_joints * f + j) + 1) = ((float*)&q)[0];\n            jointRotations.mutable_at(4 * (num_joints * f + j) + 2) = ((float*)&q)[1];\n            jointRotations.mutable_at(4 * (num_joints * f + j) + 3) = ((float*)&q)[2];\n        }\n    }\n\n}\n\nPYBIND11_MODULE(_motion_correction, m) {\n    m.doc() = \"Motion Correction Python bindings\";\n    m.def(\"correct_motion\", &correct_motion);\n}\n"
  },
  {
    "path": "MotionCorrection/src/cpp/Compiler.h",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: Apache-2.0\n */\n\n#pragma once\n\n// Compiler-specific defines\n\n// Detects the compiler type (exactly one COMPILER_* macro is defined).\n#if defined(__clang__)\n#    define COMPILER_CLANG\n#elif defined(__GNUC__) // Check after Clang, as Clang defines this too\n#    define COMPILER_GNUC\n#elif defined(_MSC_VER) // Check after Clang, since we could be building with either within VS\n#    define COMPILER_MSVC\n#else\n#    pragma error \"Unknown compiler. \"\n#endif\n\n// NOTE: FORCE_INLINE is only defined for COMPILER_MSVC and COMPILER_GNUC; it is left undefined for COMPILER_CLANG.\n#if defined(COMPILER_MSVC)\n\t#define FORCE_INLINE __forceinline\n#elif defined(COMPILER_GNUC)\n\t#define FORCE_INLINE inline __attribute__((always_inline))\n#endif\n"
  },
  {
    "path": "MotionCorrection/src/cpp/Debug.h",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: Apache-2.0\n */\n\n#pragma once\n\n#include \"Platform.h\"\n\n#define ASSERT( cond ) do { if( !(cond) ) { DEBUG_BREAK(); } } while( 0 )\n#define HALT() { DEBUG_BREAK(); }\n#define UNIMPLEMENTED_FUNCTION() { DEBUG_BREAK(); }\n#define UNREACHABLE_CODE() { DEBUG_BREAK(); }\n"
  },
  {
    "path": "MotionCorrection/src/cpp/Math/Constants.h",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: Apache-2.0\n */\n\n#pragma once\n\n#include <limits>\n\n// Mathematical constants\n\nnamespace Math\n{\n    static constexpr float const Epsilon = 1.0e-06f;\n    static constexpr float const LargeEpsilon = 1.0e-04f;\n    static constexpr float const HugeEpsilon = 1.0e-02f;\n    static constexpr float const Pi = 3.141592654f;\n    static constexpr float const TwoPi = 6.283185307f;\n    static constexpr float const OneDivPi = 0.318309886f;\n    static constexpr float const OneDivTwoPi = 0.159154943f;\n    static constexpr float const PiDivTwo = 1.570796327f;\n    static constexpr float const PiDivFour = 0.785398163f;\n\n    static constexpr float const SqrtTwo = 1.4142135623730950488016887242097f;\n    static constexpr float const OneDivSqrtTwo = 1.0f / SqrtTwo;\n\n    static constexpr float const DegreesToRadians = 0.0174532925f;\n    static constexpr float const RadiansToDegrees = 57.2957795f;\n\n    static constexpr float const Infinity = std::numeric_limits<float>::infinity();\n    static constexpr float const QNaN = std::numeric_limits<float>::quiet_NaN();\n}\n"
  },
  {
    "path": "MotionCorrection/src/cpp/Math/Matrix.cpp",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: Apache-2.0\n */\n\n#include \"Matrix.h\"\n\n#include <cfloat>\n\nusing namespace Math;\n\nnamespace\n{\n    static bool CheckForZeroScaleInRow(float scale, const Vector& row)\n    {\n        float const absScale = Math::Abs(scale);\n\n        for (int i = 0; i < 3; i++)\n        {\n            if (absScale < 1 && Math::Abs(row[i]) >= FLT_MAX * absScale)\n            {\n                return false;\n            }\n        }\n\n        return true;\n    }\n\n    static bool ExtractAndRemoveScalingAndShear(Matrix& matrix, Vector& scale, Vector& shear)\n    {\n        scale = Vector::Zero;\n        shear = Vector::Zero;\n\n        Float3 scaleValues = Float3::Zero;\n        Float3 shearValues = Float3::Zero;\n\n        // This implementation follows the technique described in the paper by\n        // Spencer W. Thomas in the Graphics Gems II article: \"Decomposing a\n        // Matrix into Simple Transformations\", p. 
320.\n\n        Vector row[3];\n        row[0] = Vector(matrix[0][0], matrix[0][1], matrix[0][2]);\n        row[1] = Vector(matrix[1][0], matrix[1][1], matrix[1][2]);\n        row[2] = Vector(matrix[2][0], matrix[2][1], matrix[2][2]);\n\n        float maxVal = 0;\n        for (int i = 0; i < 3; i++)\n        {\n            for (int j = 0; j < 3; j++)\n            {\n                if (Math::Abs(row[i][j]) > maxVal)\n                {\n                    maxVal = Math::Abs(row[i][j]);\n                }\n            }\n        }\n\n        // We normalize the 3x3 matrix here.\n        // It was noticed that this can improve numerical stability significantly,\n        // especially when many of the upper 3x3 matrix's coefficients are very\n        // close to zero; we correct for this step at the end by multiplying the\n        // scaling factors by maxVal at the end (shear and rotation are not\n        // affected by the normalization).\n\n        if (maxVal != 0)\n        {\n            for (int i = 0; i < 3; i++)\n            {\n                if (!CheckForZeroScaleInRow(maxVal, row[i]))\n                {\n                    return false;\n                }\n                else\n                {\n                    row[i] /= maxVal;\n                }\n            }\n        }\n\n        // Compute X scale factor.\n        scaleValues.m_x = row[0].Length3().ToFloat();\n        if (!CheckForZeroScaleInRow(scaleValues.m_x, row[0]))\n        {\n            return false;\n        }\n\n        // Normalize first row.\n        row[0] /= scaleValues.m_x;\n\n        // An XY shear factor will shear the X coord. as the Y coord. 
changes.\n        // There are 6 combinations (XY, XZ, YZ, YX, ZX, ZY), although we only\n        // extract the first 3 because we can effect the last 3 by shearing in\n        // XY, XZ, YZ combined rotations and scales.\n        //\n        // shear matrix <   1,  YX,  ZX,  0,\n        //                 XY,   1,  ZY,  0,\n        //                 XZ,  YZ,   1,  0,\n        //                  0,   0,   0,  1 >\n\n        // Compute XY shear factor and make 2nd row orthogonal to 1st.\n        shearValues[0] = Vector::Dot3(row[0], row[1]).ToFloat();\n        row[1] -= row[0] * shearValues[0];\n\n        // Now, compute Y scale.\n        scaleValues.m_y = row[1].Length3().ToFloat();\n        if (!CheckForZeroScaleInRow(scaleValues.m_y, row[1]))\n        {\n            return false;\n        }\n\n        // Normalize 2nd row and correct the XY shear factor for Y scaling.\n        row[1] /= scaleValues.m_y;\n        shearValues[0] /= scaleValues.m_y;\n\n        // Compute XZ and YZ shears, orthogonalize 3rd row.\n        shearValues[1] = Vector::Dot3(row[0], row[2]).ToFloat();\n        row[2] -= row[0] * shearValues[1];\n        shearValues[2] = Vector::Dot3(row[1], row[2]).ToFloat();\n        row[2] -= row[1] * shearValues[2];\n\n        // Next, get Z scale.\n        scaleValues.m_z = row[2].Length3().ToFloat();\n        if (!CheckForZeroScaleInRow(scaleValues.m_z, row[2]))\n        {\n            return false;\n        }\n\n        // Normalize 3rd row and correct the XZ and YZ shear factors for Z scaling.\n        row[2] /= scaleValues.m_z;\n        shearValues[1] /= scaleValues.m_z;\n        shearValues[2] /= scaleValues.m_z;\n\n        // At this point, the upper 3x3 matrix in mat is orthonormal.\n        // Check for a coordinate system flip. 
If the determinant\n        // is less than zero, then negate the matrix and the scaling factors.\n        if (Vector::Dot3(row[0], Vector::Cross3(row[1], row[2])).ToFloat() < 0)\n        {\n            for (int i = 0; i < 3; i++)\n            {\n                scaleValues[i] *= -1;\n                row[i] *= -1;\n            }\n        }\n\n        // Copy over the orthonormal rows into the returned matrix.\n        // The upper 3x3 matrix in mat is now a rotation matrix.\n        for (int i = 0; i < 3; i++)\n        {\n            matrix[i].SetX(row[i][0]);\n            matrix[i].SetY(row[i][1]);\n            matrix[i].SetZ(row[i][2]);\n        }\n\n        // Correct the scaling factors for the normalization step that we\n        // performed above; shear and rotation are not affected by the\n        // normalization.\n        scaleValues *= maxVal;\n\n        scale = Vector(scaleValues);\n        shear = Vector(shearValues);\n\n        return true;\n    }\n}\n\nnamespace Math\n{\n    Matrix const Matrix::Identity(1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1);\n\n    Matrix::Matrix(float v00, float v01, float v02, float v03, float v10, float v11, float v12, float v13, float v20, float v21, float v22, float v23, float v30, float v31, float v32, float v33)\n    {\n        m_rows[0] = Vector(v00, v01, v02, v03);\n        m_rows[1] = Vector(v10, v11, v12, v13);\n        m_rows[2] = Vector(v20, v21, v22, v23);\n        m_rows[3] = Vector(v30, v31, v32, v33);\n    }\n\n    Matrix::Matrix(float values[16])\n    {\n        m_rows[0] = Vector(values[0], values[1], values[2], values[3]);\n        m_rows[1] = Vector(values[4], values[5], values[6], values[7]);\n        m_rows[2] = Vector(values[8], values[9], values[10], values[11]);\n        m_rows[3] = Vector(values[12], values[13], values[14], values[15]);\n    }\n\n    Matrix::Matrix(const Vector& xAxis, const Vector& yAxis, const Vector& zAxis)\n    {\n        m_rows[0] = xAxis;\n        m_rows[1] = yAxis;\n        
m_rows[2] = zAxis;\n        m_rows[3] = Vector::UnitW;\n    }\n\n    Matrix::Matrix(const Vector& xAxis, const Vector& yAxis, const Vector& zAxis, const Vector& translation)\n    {\n        m_rows[0] = xAxis;\n        m_rows[1] = yAxis;\n        m_rows[2] = zAxis;\n        m_rows[3] = translation.GetWithW1();\n    }\n\n    Matrix::Matrix(const EulerAngles& eulerAngles, const Vector translation)\n    {\n        float cx, cy, cz, sx, sy, sz, czsx, cxcz, sysz;\n\n        sx = sinf((float)eulerAngles.m_x); cx = cosf((float)eulerAngles.m_x);\n        sy = sinf((float)eulerAngles.m_y); cy = cosf((float)eulerAngles.m_y);\n        sz = sinf((float)eulerAngles.m_z); cz = cosf((float)eulerAngles.m_z);\n\n        czsx = cz * sx;\n        cxcz = cx * cz;\n        sysz = sy * sz;\n\n        // Order is XYZ\n        m_values[0][0] = cy * cz;\n        m_values[0][1] = cy * sz;\n        m_values[0][2] = -sy;\n        m_values[1][0] = czsx * sy - cx * sz;\n        m_values[1][1] = cxcz + sx * sysz;\n        m_values[1][2] = cy * sx;\n        m_values[2][0] = cxcz * sy + sx * sz;\n        m_values[2][1] = -czsx + cx * sysz;\n        m_values[2][2] = cx * cy;\n        m_values[0][3] = 0.0f;\n        m_values[1][3] = 0.0f;\n        m_values[2][3] = 0.0f;\n\n        // Translation\n        m_rows[3] = translation.GetWithW1();\n    }\n\n    EulerAngles Matrix::ToEulerAngles() const\n    {\n        EulerAngles result;\n\n        result.m_x = Radians(Math::ATan2(m_values[1][2], m_values[2][2]));\n\n        float const c2 = Math::Sqrt((m_values[0][0] * m_values[0][0]) + (m_values[0][1] * m_values[0][1]));\n        result.m_y = Radians(Math::ATan2(-m_values[0][2], c2));\n\n        float const s1 = Math::Sin((float)result.m_x);\n        float const c1 = Math::Cos((float)result.m_x);\n        result.m_z = Radians(Math::ATan2((s1 * m_values[2][0]) - (c1 * m_values[1][0]), (c1 * m_values[1][1]) - (s1 * m_values[2][1])));\n\n        return result;\n    }\n\n    bool Matrix::Decompose(Quaternion& 
outRotation, Vector& outTranslation, Vector& outScale) const\n    {\n        Matrix copy = *this;\n        Vector shr = Vector::Zero;\n        outScale = Vector::Zero;\n\n        // Extract and remove scale and shear from matrix\n        if (ExtractAndRemoveScalingAndShear(copy, outScale, shr))\n        {\n            // Extract rotation and translation from unscaled matrix\n            outRotation = copy.GetRotation();\n            outTranslation = copy.GetTranslation().GetWithW0();\n            return true;\n        }\n\n        return false;\n    }\n\n    Vector Matrix::GetScale() const\n    {\n        Matrix copy = *this;\n        Vector scale = Vector::Zero, shear;\n        if (!ExtractAndRemoveScalingAndShear(copy, scale, shear))\n        {\n            float const lengthX = m_rows[0].Length3().ToFloat();\n            float const lengthY = m_rows[1].Length3().ToFloat();\n            float const lengthZ = m_rows[2].Length3().ToFloat();\n            scale = Vector(lengthX, lengthY, lengthZ, 0.0f);\n        }\n\n        return scale;\n    }\n\n    Matrix& Matrix::SetScale(const Vector& newScale)\n    {\n        Vector scale, shear;\n        bool result = ExtractAndRemoveScalingAndShear(*this, scale, shear);\n\n        // Cannot set scale on matrix that contains zero-scale\n        ASSERT(result);\n\n        m_rows[0] = m_rows[0] * newScale.GetSplatX();\n        m_rows[1] = m_rows[1] * newScale.GetSplatY();\n        m_rows[2] = m_rows[2] * newScale.GetSplatZ();\n        return *this;\n    }\n\n    Matrix& Matrix::RemoveScale()\n    {\n        Vector scale, shear;\n        bool result = ExtractAndRemoveScalingAndShear(*this, scale, shear);\n\n        // Cannot remove zero scale from matrix\n        ASSERT(result);\n\n        return *this;\n    }\n}\n"
  },
  {
    "path": "MotionCorrection/src/cpp/Math/Matrix.h",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: Apache-2.0\n */\n\n#pragma once\n\n#include \"Vector.h\"\n#include \"Quaternion.h\"\n\nenum class CoordinateSpace : uint8_t\n{\n    World,\n    Local,\n};\n\n//\n// Matrices are Row-Major\n// Multiplication order is right to left\n// ObjectWorldTransform = LocalObjectTransform * WorldTransform\n//\n\nnamespace Math\n{\n    class alignas(16) Matrix\n    {\n    public:\n\n        static Matrix const Identity;\n\n    public:\n\n        static Matrix FromRotation(const Quaternion& rotation);\n        static Matrix FromTranslation(const Vector& translation);\n        static Matrix FromScale(const Vector& scale);\n        static Matrix FromUniformScale(float uniformScale);\n        static Matrix FromTranslationAndScale(const Vector& translation, const Vector& scale);\n        static Matrix FromRotationBetweenVectors(const Vector sourceVector, const Vector targetVector);\n\n    public:\n\n        explicit Matrix();\n        explicit Matrix(NoInit_t);\n        explicit Matrix(ZeroInit_t);\n        explicit Matrix(float v00, float v01, float v02, float v03,\n                        float v10, float v11, float v12, float v13,\n                        float v20, float v21, float v22, float v23,\n                        float v30, float v31, float v32, float v33);\n        explicit Matrix(float values[16]);\n        explicit Matrix(Vector const& xAxis, Vector const& yAxis, Vector const& zAxis);\n        explicit Matrix(Vector const& xAxis, Vector const& yAxis, Vector const& zAxis, Vector const& translation);\n\n        Matrix(const Vector axis, Radians angleRadians);\n        Matrix(const AxisAngle axisAngle);\n\n        explicit Matrix(const Quaternion& rotation);\n        explicit Matrix(const Quaternion& rotation, const Vector& translation, const Vector& scale = Vector::One);\n        explicit Matrix(const Quaternion& rotation, 
const Vector& translation, float scale = 1.0f);\n        explicit Matrix(const EulerAngles& eulerAngles, const Vector translation = Vector::UnitW);\n\n        EulerAngles ToEulerAngles() const;\n\n        float* AsFloatArray();\n        const float* AsFloatArray() const;\n        const Vector& GetRow(uint32_t row) const;\n\n        const Vector& GetAxisX() const;\n        const Vector& GetAxisY() const;\n        const Vector& GetAxisZ() const;\n\n        void SetAxisX(const Vector& xAxis);\n        void SetAxisY(const Vector& yAxis);\n        void SetAxisZ(const Vector& zAxis);\n\n        Float3 GetForwardVector() const;\n        Float3 GetRightVector() const;\n        Float3 GetUpVector() const;\n\n        Vector GetUnitAxisX() const;\n        Vector GetUnitAxisY() const;\n        Vector GetUnitAxisZ() const;\n\n        bool IsIdentity() const;\n        bool IsOrthogonal() const;\n        bool IsOrthonormal() const;\n\n        bool Decompose(Quaternion& outRotation, Vector& outTranslation, Vector& outScale) const;\n\n        Matrix& Transpose();\n        Matrix GetTransposed() const;\n\n        Matrix& Invert();\n        Matrix GetInverse() const;\n\n        Vector GetDeterminant() const;\n        float GetDeterminantAsFloat() const;\n\n        Vector GetTranslation() const;\n        const Vector& GetTranslationWithW() const;\n        Matrix& SetTranslation(Vector const& v);\n        Matrix& SetTranslation(Float3 const& v);\n        Matrix& SetTranslation(Float4 const& v);\n\n        Quaternion GetRotation() const;\n\n        Matrix& SetRotation(const Matrix& rotation);\n        Matrix& SetRotation(const Quaternion& rotation);\n\n        Matrix& SetRotationMaintainingScale(const Matrix& rotation);\n        Matrix& SetRotationMaintainingScale(const Quaternion& rotation);\n\n        Vector GetScale() const;\n\n        Matrix& RemoveScale();\n        Matrix& SetScale(const Vector& scale);\n        Matrix& SetScale(float uniformScale);\n\n        Matrix& 
RemoveScaleFast();\n        Matrix& SetScaleFast(const Vector& scale);\n        Matrix& SetScaleFast(float uniformScale);\n\n        //\n        // Operators\n        //\n\n        // Applies rotation and scale to a vector and returns a result with the W = 0\n        Vector RotateVector(const Vector& vector) const;\n\n        // Applies rotation and scale to a vector and returns a result with the W = 0\n        Vector TransformNormal(const Vector& vector) const;\n\n        // Applies the transformation to a given point and ensures the resulting W = 1\n        Vector TransformPoint(const Vector& point) const;\n\n        // Applies the transformation to a vector ignoring the W value.\n        // Same as TransformPoint with the result W left unchanged\n        Vector TransformVector3(const Vector& vector) const;\n\n        // Applies the transformation to a given vector with the result W left unchanged\n        Vector TransformVector4(const Vector& vector) const;\n\n        Vector& operator[](uint32_t i);\n        const Vector operator[](uint32_t i) const;\n\n        Matrix operator*(const Matrix& rhs) const;\n        Matrix& operator*=(const Matrix& rhs);\n\n        Matrix operator*(const Quaternion& rhs) const;\n        Matrix operator*=(const Quaternion& rhs);\n\n        bool operator==(const Matrix& rhs) const;\n\n    public:\n\n        union\n        {\n            Vector      m_rows[4];\n            float       m_values[4][4];\n        };\n    };\n}\n\n#include \"Matrix.inl\"\n"
  },
  {
    "path": "MotionCorrection/src/cpp/Math/Matrix.inl",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: Apache-2.0\n */\n\n#pragma once\n\n#include <cstring>\n\n#include \"Matrix.h\"\n\nnamespace Math\n{\n    inline Matrix Matrix::FromRotation(const Quaternion& rotation)\n    {\n        return Matrix(rotation);\n    }\n\n    inline Matrix Matrix::FromTranslation(const Vector& translation)\n    {\n        Matrix M;\n        M.m_rows[0] = Vector::UnitX;\n        M.m_rows[1] = Vector::UnitY;\n        M.m_rows[2] = Vector::UnitZ;\n        M.m_rows[3] = translation.GetWithW1();\n        return M;\n    }\n\n    inline Matrix Matrix::FromScale(const Vector& scale)\n    {\n        Matrix M;\n        M.m_rows[0] = _mm_and_ps(scale, SIMD::g_maskX000);\n        M.m_rows[1] = _mm_and_ps(scale, SIMD::g_mask0Y00);\n        M.m_rows[2] = _mm_and_ps(scale, SIMD::g_mask00Z0);\n        M.m_rows[3] = Vector::UnitW;\n        return M;\n    }\n\n    inline Matrix Matrix::FromUniformScale(float uniformScale)\n    {\n        Matrix M;\n        M.m_rows[0] = _mm_set_ps(0, 0, 0, uniformScale);\n        M.m_rows[1] = _mm_set_ps(0, 0, uniformScale, 0);\n        M.m_rows[2] = _mm_set_ps(0, uniformScale, 0, 0);\n        M.m_rows[3] = Vector::UnitW;\n        return M;\n    }\n\n    inline Matrix Matrix::FromTranslationAndScale(const Vector& translation, const Vector& scale)\n    {\n        Matrix M;\n        M.m_rows[0] = _mm_and_ps(scale, SIMD::g_maskX000);\n        M.m_rows[1] = _mm_and_ps(scale, SIMD::g_mask0Y00);\n        M.m_rows[2] = _mm_and_ps(scale, SIMD::g_mask00Z0);\n        M.m_rows[3] = translation.GetWithW1();\n        return M;\n    }\n\n    inline Matrix Matrix::FromRotationBetweenVectors(Vector const sourceVector, Vector const targetVector)\n    {\n        return Matrix(Quaternion::FromRotationBetweenNormalizedVectors(sourceVector, targetVector));\n    }\n\n    inline Matrix::Matrix()\n    {\n        memcpy(this, &Matrix::Identity, 
sizeof(Matrix));\n    }\n\n    inline Matrix::Matrix(NoInit_t)\n    {\n    }\n\n    inline Matrix::Matrix(ZeroInit_t)\n    {\n        memset(this, 0, sizeof(Matrix));\n    }\n\n    inline Matrix::Matrix(const Vector axis, Radians angleRadians)\n    {\n        Vector normal = axis.GetNormalized3();\n\n        Vector C0, C1;\n        Vector::SinCos(C0, C1, Vector((float)angleRadians));\n        Vector C2 = Vector::One - C1;\n\n        __m128 N0 = _mm_shuffle_ps(normal, normal, _MM_SHUFFLE(3, 0, 2, 1));\n        __m128 N1 = _mm_shuffle_ps(normal, normal, _MM_SHUFFLE(3, 1, 0, 2));\n\n        __m128 V0 = _mm_mul_ps(C2, N0);\n        V0 = _mm_mul_ps(V0, N1);\n\n        __m128 R0 = _mm_mul_ps(C2, normal);\n        R0 = _mm_mul_ps(R0, normal);\n        R0 = _mm_add_ps(R0, C1);\n\n        __m128 R1 = _mm_mul_ps(C0, normal);\n        R1 = _mm_add_ps(R1, V0);\n        __m128 R2 = _mm_mul_ps(C0, normal);\n        R2 = _mm_sub_ps(V0, R2);\n\n        V0 = _mm_and_ps(R0, SIMD::g_maskXYZ0);\n        __m128 V1 = _mm_shuffle_ps(R1, R2, _MM_SHUFFLE(2, 1, 2, 0));\n        V1 = _mm_shuffle_ps(V1, V1, _MM_SHUFFLE(0, 3, 2, 1));\n        __m128 V2 = _mm_shuffle_ps(R1, R2, _MM_SHUFFLE(0, 0, 1, 1));\n        V2 = _mm_shuffle_ps(V2, V2, _MM_SHUFFLE(2, 0, 2, 0));\n\n        R2 = _mm_shuffle_ps(V0, V1, _MM_SHUFFLE(1, 0, 3, 0));\n        R2 = _mm_shuffle_ps(R2, R2, _MM_SHUFFLE(1, 3, 2, 0));\n\n        m_rows[0] = R2;\n\n        R2 = _mm_shuffle_ps(V0, V1, _MM_SHUFFLE(3, 2, 3, 1));\n        R2 = _mm_shuffle_ps(R2, R2, _MM_SHUFFLE(1, 3, 0, 2));\n        m_rows[1] = R2;\n\n        V2 = _mm_shuffle_ps(V2, V0, _MM_SHUFFLE(3, 2, 1, 0));\n        m_rows[2] = V2;\n        m_rows[3] = Vector::UnitW;\n    }\n\n    inline Matrix::Matrix(const AxisAngle axisAngle)\n        : Matrix(Vector(axisAngle.m_axis), axisAngle.m_angle)\n    {\n    }\n\n    inline Matrix::Matrix(const Quaternion& rotation)\n    {\n        SetRotation(rotation);\n        m_rows[3] = Vector::UnitW;\n    }\n\n    inline 
Matrix::Matrix(const Quaternion& rotation, const Vector& translation, const Vector& scale)\n    {\n        SetRotation(rotation);\n        m_rows[0] = m_rows[0] * scale.GetSplatX();\n        m_rows[1] = m_rows[1] * scale.GetSplatY();\n        m_rows[2] = m_rows[2] * scale.GetSplatZ();\n        m_rows[3] = translation.GetWithW1();\n    }\n\n    inline Matrix::Matrix(const Quaternion& rotation, const Vector& translation, float scale)\n        : Matrix(rotation, translation, Vector(scale))\n    {\n    }\n\n    inline float* Matrix::AsFloatArray()\n    {\n        return &m_values[0][0];\n    }\n\n    inline const float* Matrix::AsFloatArray() const\n    {\n        return &m_values[0][0];\n    }\n\n    inline const Vector& Matrix::GetRow(uint32_t row) const\n    {\n        return m_rows[row];\n    }\n\n    inline const Vector& Matrix::GetAxisX() const\n    {\n        return m_rows[0];\n    }\n\n    inline const Vector& Matrix::GetAxisY() const\n    {\n        return m_rows[1];\n    }\n\n    inline const Vector& Matrix::GetAxisZ() const\n    {\n        return m_rows[2];\n    }\n\n    inline void Matrix::SetAxisX(const Vector& xAxis)\n    {\n        m_rows[0] = xAxis;\n    }\n\n    inline void Matrix::SetAxisY(const Vector& yAxis)\n    {\n        m_rows[1] = yAxis;\n    }\n\n    inline void Matrix::SetAxisZ(const Vector& zAxis)\n    {\n        m_rows[2] = zAxis;\n    }\n\n\n    inline Float3 Matrix::GetForwardVector() const\n    {\n        return GetAxisZ();\n    }\n\n    inline Float3 Matrix::GetRightVector() const\n    {\n        return GetAxisX();\n    }\n\n    inline Float3 Matrix::GetUpVector() const\n    {\n        return GetAxisY();\n    }\n\n    inline Vector Matrix::GetUnitAxisX() const\n    {\n        return m_rows[0].GetNormalized3();\n    }\n\n    inline Vector Matrix::GetUnitAxisY() const\n    {\n        return m_rows[1].GetNormalized3();\n    }\n\n    inline Vector Matrix::GetUnitAxisZ() const\n    {\n        return m_rows[2].GetNormalized3();\n    }\n\n    
inline bool Matrix::IsIdentity() const\n    {\n        __m128 vTemp1 = _mm_cmpeq_ps(m_rows[0], Vector::UnitX);\n        __m128 vTemp2 = _mm_cmpeq_ps(m_rows[1], Vector::UnitY);\n        __m128 vTemp3 = _mm_cmpeq_ps(m_rows[2], Vector::UnitZ);\n        __m128 vTemp4 = _mm_cmpeq_ps(m_rows[3], Vector::UnitW);\n        vTemp1 = _mm_and_ps(vTemp1, vTemp2);\n        vTemp3 = _mm_and_ps(vTemp3, vTemp4);\n        vTemp1 = _mm_and_ps(vTemp1, vTemp3);\n        return (_mm_movemask_ps(vTemp1) == 0x0f);\n    }\n\n    inline bool Matrix::IsOrthogonal() const\n    {\n        Matrix const transpose = GetTransposed();\n        Matrix result = *this * transpose;\n        return result.IsIdentity();\n    }\n\n    inline bool Matrix::IsOrthonormal() const\n    {\n        static const Vector three(3);\n        auto dotCheck = Vector::Dot3(m_rows[0], m_rows[1]) + Vector::Dot3(m_rows[0], m_rows[2]) + Vector::Dot3(m_rows[1], m_rows[2]);\n        auto magnitudeCheck = m_rows[0].LengthSquared3() + m_rows[1].LengthSquared3() + m_rows[2].LengthSquared3();\n        auto result = dotCheck + magnitudeCheck;\n        return result.IsNearEqual3(three);\n    }\n\n    inline Matrix& Matrix::Transpose()\n    {\n        __m128 vTemp1 = _mm_shuffle_ps(m_rows[0], m_rows[1], _MM_SHUFFLE(1, 0, 1, 0));\n        __m128 vTemp3 = _mm_shuffle_ps(m_rows[0], m_rows[1], _MM_SHUFFLE(3, 2, 3, 2));\n        __m128 vTemp2 = _mm_shuffle_ps(m_rows[2], m_rows[3], _MM_SHUFFLE(1, 0, 1, 0));\n        __m128 vTemp4 = _mm_shuffle_ps(m_rows[2], m_rows[3], _MM_SHUFFLE(3, 2, 3, 2));\n        m_rows[0] = _mm_shuffle_ps(vTemp1, vTemp2, _MM_SHUFFLE(2, 0, 2, 0));\n        m_rows[1] = _mm_shuffle_ps(vTemp1, vTemp2, _MM_SHUFFLE(3, 1, 3, 1));\n        m_rows[2] = _mm_shuffle_ps(vTemp3, vTemp4, _MM_SHUFFLE(2, 0, 2, 0));\n        m_rows[3] = _mm_shuffle_ps(vTemp3, vTemp4, _MM_SHUFFLE(3, 1, 3, 1));\n        return *this;\n    }\n\n    inline Matrix Matrix::GetTransposed() const\n    {\n        Matrix m = *this;\n        m.Transpose();\n   
     return m;\n    }\n\n    inline Matrix& Matrix::Invert()\n    {\n        Matrix MT = GetTransposed();\n        __m128 V00 = _mm_shuffle_ps(MT.m_rows[2], MT.m_rows[2], _MM_SHUFFLE(1, 1, 0, 0));\n        __m128 V10 = _mm_shuffle_ps(MT.m_rows[3], MT.m_rows[3], _MM_SHUFFLE(3, 2, 3, 2));\n        __m128 V01 = _mm_shuffle_ps(MT.m_rows[0], MT.m_rows[0], _MM_SHUFFLE(1, 1, 0, 0));\n        __m128 V11 = _mm_shuffle_ps(MT.m_rows[1], MT.m_rows[1], _MM_SHUFFLE(3, 2, 3, 2));\n        __m128 V02 = _mm_shuffle_ps(MT.m_rows[2], MT.m_rows[0], _MM_SHUFFLE(2, 0, 2, 0));\n        __m128 V12 = _mm_shuffle_ps(MT.m_rows[3], MT.m_rows[1], _MM_SHUFFLE(3, 1, 3, 1));\n\n        __m128 D0 = _mm_mul_ps(V00, V10);\n        __m128 D1 = _mm_mul_ps(V01, V11);\n        __m128 D2 = _mm_mul_ps(V02, V12);\n\n        V00 = _mm_shuffle_ps(MT.m_rows[2], MT.m_rows[2], _MM_SHUFFLE(3, 2, 3, 2));\n        V10 = _mm_shuffle_ps(MT.m_rows[3], MT.m_rows[3], _MM_SHUFFLE(1, 1, 0, 0));\n        V01 = _mm_shuffle_ps(MT.m_rows[0], MT.m_rows[0], _MM_SHUFFLE(3, 2, 3, 2));\n        V11 = _mm_shuffle_ps(MT.m_rows[1], MT.m_rows[1], _MM_SHUFFLE(1, 1, 0, 0));\n        V02 = _mm_shuffle_ps(MT.m_rows[2], MT.m_rows[0], _MM_SHUFFLE(3, 1, 3, 1));\n        V12 = _mm_shuffle_ps(MT.m_rows[3], MT.m_rows[1], _MM_SHUFFLE(2, 0, 2, 0));\n\n        V00 = _mm_mul_ps(V00, V10);\n        V01 = _mm_mul_ps(V01, V11);\n        V02 = _mm_mul_ps(V02, V12);\n        D0 = _mm_sub_ps(D0, V00);\n        D1 = _mm_sub_ps(D1, V01);\n        D2 = _mm_sub_ps(D2, V02);\n        // V11 = D0Y,D0W,D2Y,D2Y\n        V11 = _mm_shuffle_ps(D0, D2, _MM_SHUFFLE(1, 1, 3, 1));\n        V00 = _mm_shuffle_ps(MT.m_rows[1], MT.m_rows[1], _MM_SHUFFLE(1, 0, 2, 1));\n        V10 = _mm_shuffle_ps(V11, D0, _MM_SHUFFLE(0, 3, 0, 2));\n        V01 = _mm_shuffle_ps(MT.m_rows[0], MT.m_rows[0], _MM_SHUFFLE(0, 1, 0, 2));\n        V11 = _mm_shuffle_ps(V11, D0, _MM_SHUFFLE(2, 1, 2, 1));\n        // V13 = D1Y,D1W,D2W,D2W\n        __m128 V13 = _mm_shuffle_ps(D1, D2, _MM_SHUFFLE(3, 3, 
3, 1));\n        V02 = _mm_shuffle_ps(MT.m_rows[3], MT.m_rows[3], _MM_SHUFFLE(1, 0, 2, 1));\n        V12 = _mm_shuffle_ps(V13, D1, _MM_SHUFFLE(0, 3, 0, 2));\n        __m128 V03 = _mm_shuffle_ps(MT.m_rows[2], MT.m_rows[2], _MM_SHUFFLE(0, 1, 0, 2));\n        V13 = _mm_shuffle_ps(V13, D1, _MM_SHUFFLE(2, 1, 2, 1));\n\n        __m128 C0 = _mm_mul_ps(V00, V10);\n        __m128 C2 = _mm_mul_ps(V01, V11);\n        __m128 C4 = _mm_mul_ps(V02, V12);\n        __m128 C6 = _mm_mul_ps(V03, V13);\n\n        // V11 = D0X,D0Y,D2X,D2X\n        V11 = _mm_shuffle_ps(D0, D2, _MM_SHUFFLE(0, 0, 1, 0));\n        V00 = _mm_shuffle_ps(MT.m_rows[1], MT.m_rows[1], _MM_SHUFFLE(2, 1, 3, 2));\n        V10 = _mm_shuffle_ps(D0, V11, _MM_SHUFFLE(2, 1, 0, 3));\n        V01 = _mm_shuffle_ps(MT.m_rows[0], MT.m_rows[0], _MM_SHUFFLE(1, 3, 2, 3));\n        V11 = _mm_shuffle_ps(D0, V11, _MM_SHUFFLE(0, 2, 1, 2));\n        // V13 = D1X,D1Y,D2Z,D2Z\n        V13 = _mm_shuffle_ps(D1, D2, _MM_SHUFFLE(2, 2, 1, 0));\n        V02 = _mm_shuffle_ps(MT.m_rows[3], MT.m_rows[3], _MM_SHUFFLE(2, 1, 3, 2));\n        V12 = _mm_shuffle_ps(D1, V13, _MM_SHUFFLE(2, 1, 0, 3));\n        V03 = _mm_shuffle_ps(MT.m_rows[2], MT.m_rows[2], _MM_SHUFFLE(1, 3, 2, 3));\n        V13 = _mm_shuffle_ps(D1, V13, _MM_SHUFFLE(0, 2, 1, 2));\n\n        V00 = _mm_mul_ps(V00, V10);\n        V01 = _mm_mul_ps(V01, V11);\n        V02 = _mm_mul_ps(V02, V12);\n        V03 = _mm_mul_ps(V03, V13);\n        C0 = _mm_sub_ps(C0, V00);\n        C2 = _mm_sub_ps(C2, V01);\n        C4 = _mm_sub_ps(C4, V02);\n        C6 = _mm_sub_ps(C6, V03);\n\n        V00 = _mm_shuffle_ps(MT.m_rows[1], MT.m_rows[1], _MM_SHUFFLE(0, 3, 0, 3));\n        // V10 = D0Z,D0Z,D2X,D2Y\n        V10 = _mm_shuffle_ps(D0, D2, _MM_SHUFFLE(1, 0, 2, 2));\n        V10 = _mm_shuffle_ps(V10, V10, _MM_SHUFFLE(0, 2, 3, 0));\n        V01 = _mm_shuffle_ps(MT.m_rows[0], MT.m_rows[0], _MM_SHUFFLE(2, 0, 3, 1));\n        // V11 = D0X,D0W,D2X,D2Y\n        V11 = _mm_shuffle_ps(D0, D2, _MM_SHUFFLE(1, 0, 3, 
0));\n        V11 = _mm_shuffle_ps(V11, V11, _MM_SHUFFLE(2, 1, 0, 3));\n        V02 = _mm_shuffle_ps(MT.m_rows[3], MT.m_rows[3], _MM_SHUFFLE(0, 3, 0, 3));\n        // V12 = D1Z,D1Z,D2Z,D2W\n        V12 = _mm_shuffle_ps(D1, D2, _MM_SHUFFLE(3, 2, 2, 2));\n        V12 = _mm_shuffle_ps(V12, V12, _MM_SHUFFLE(0, 2, 3, 0));\n        V03 = _mm_shuffle_ps(MT.m_rows[2], MT.m_rows[2], _MM_SHUFFLE(2, 0, 3, 1));\n        // V13 = D1X,D1W,D2Z,D2W\n        V13 = _mm_shuffle_ps(D1, D2, _MM_SHUFFLE(3, 2, 3, 0));\n        V13 = _mm_shuffle_ps(V13, V13, _MM_SHUFFLE(2, 1, 0, 3));\n\n        V00 = _mm_mul_ps(V00, V10);\n        V01 = _mm_mul_ps(V01, V11);\n        V02 = _mm_mul_ps(V02, V12);\n        V03 = _mm_mul_ps(V03, V13);\n        __m128 C1 = _mm_sub_ps(C0, V00);\n        C0 = _mm_add_ps(C0, V00);\n        __m128 C3 = _mm_add_ps(C2, V01);\n        C2 = _mm_sub_ps(C2, V01);\n        __m128 C5 = _mm_sub_ps(C4, V02);\n        C4 = _mm_add_ps(C4, V02);\n        __m128 C7 = _mm_add_ps(C6, V03);\n        C6 = _mm_sub_ps(C6, V03);\n\n        C0 = _mm_shuffle_ps(C0, C1, _MM_SHUFFLE(3, 1, 2, 0));\n        C2 = _mm_shuffle_ps(C2, C3, _MM_SHUFFLE(3, 1, 2, 0));\n        C4 = _mm_shuffle_ps(C4, C5, _MM_SHUFFLE(3, 1, 2, 0));\n        C6 = _mm_shuffle_ps(C6, C7, _MM_SHUFFLE(3, 1, 2, 0));\n        C0 = _mm_shuffle_ps(C0, C0, _MM_SHUFFLE(3, 1, 2, 0));\n        C2 = _mm_shuffle_ps(C2, C2, _MM_SHUFFLE(3, 1, 2, 0));\n        C4 = _mm_shuffle_ps(C4, C4, _MM_SHUFFLE(3, 1, 2, 0));\n        C6 = _mm_shuffle_ps(C6, C6, _MM_SHUFFLE(3, 1, 2, 0));\n\n        __m128 vTemp = Vector::Dot4(C0, MT.m_rows[0]);\n        vTemp = _mm_div_ps(Vector::One, vTemp);\n        m_rows[0] = _mm_mul_ps(C0, vTemp);\n        m_rows[1] = _mm_mul_ps(C2, vTemp);\n        m_rows[2] = _mm_mul_ps(C4, vTemp);\n        m_rows[3] = _mm_mul_ps(C6, vTemp);\n        return *this;\n    }\n\n    inline Matrix Matrix::GetInverse() const\n    {\n        Matrix m = *this;\n        m.Invert();\n        return m;\n    }\n\n    inline Vector 
Matrix::GetDeterminant() const\n    {\n        Vector V0 = m_rows[2].Shuffle(1, 0, 0, 0);\n        Vector V1 = m_rows[3].Shuffle(2, 2, 1, 1);\n        Vector V2 = m_rows[2].Shuffle(1, 0, 0, 0);\n        Vector V3 = m_rows[3].Shuffle(3, 3, 3, 2);\n        Vector V4 = m_rows[2].Shuffle(2, 2, 1, 1);\n        Vector V5 = m_rows[3].Shuffle(3, 3, 3, 2);\n\n        Vector P0 = V0 * V1;\n        Vector P1 = V2 * V3;\n        Vector P2 = V4 * V5;\n\n        V0 = m_rows[2].Shuffle(2, 2, 1, 1);\n        V1 = m_rows[3].Shuffle(1, 0, 0, 0);\n        V2 = m_rows[2].Shuffle(3, 3, 3, 2);\n        V3 = m_rows[3].Shuffle(1, 0, 0, 0);\n        V4 = m_rows[2].Shuffle(3, 3, 3, 2);\n        V5 = m_rows[3].Shuffle(2, 2, 1, 1);\n\n        P0 = Vector::NegativeMultiplySubtract(V0, V1, P0);\n        P1 = Vector::NegativeMultiplySubtract(V2, V3, P1);\n        P2 = Vector::NegativeMultiplySubtract(V4, V5, P2);\n\n        V0 = m_rows[1].Shuffle(3, 3, 3, 2);\n        V1 = m_rows[1].Shuffle(2, 2, 1, 1);\n        V2 = m_rows[1].Shuffle(1, 0, 0, 0);\n\n        static Vector const Sign(1.0f, -1.0f, 1.0f, -1.0f);\n        Vector S = m_rows[0] * Sign;\n        Vector R = V0 * P0;\n        R = Vector::NegativeMultiplySubtract(V1, P1, R);\n        R = Vector::MultiplyAdd(V2, P2, R);\n\n        return Vector::Dot4(S, R);\n    }\n\n    inline float Matrix::GetDeterminantAsFloat() const\n    {\n        return GetDeterminant().GetX();\n    }\n\n    inline Vector Matrix::GetTranslation() const\n    {\n        return m_rows[3].GetWithW0();\n    }\n\n    inline const Vector& Matrix::GetTranslationWithW() const\n    {\n        return m_rows[3];\n    }\n\n    inline Matrix& Matrix::SetTranslation(const Vector& v)\n    {\n        m_rows[3] = v.GetWithW1();\n        return *this;\n    }\n\n    inline Matrix& Matrix::SetTranslation(const Float3& v)\n    {\n        m_rows[3] = Vector(v, 1.0f);\n        return *this;\n    }\n\n    inline Matrix& Matrix::SetTranslation(const Float4& v)\n    {\n        m_rows[3] = 
Vector(v.m_x, v.m_y, v.m_z, 1.0f);\n        return *this;\n    }\n\n    inline Quaternion Matrix::GetRotation() const\n    {\n        // based on RTM: https://github.com/nfrechette/rtm\n\n        const Vector& axisX = m_rows[0];\n        const Vector& axisY = m_rows[1];\n        const Vector& axisZ = m_rows[2];\n\n        // Zero scale is not supported\n        if (axisX.IsNearZero4() || axisY.IsNearZero4() || axisZ.IsNearZero4())\n        {\n            HALT();\n        }\n\n        float const axisX_X = axisX.GetX();\n        float const axisY_Y = axisY.GetY();\n        float const axisZ_Z = axisZ.GetZ();\n\n        float const mtx_trace = axisX_X + axisY_Y + axisZ_Z;\n        if (mtx_trace > 0.0)\n        {\n            float const axisX_y = axisX.GetY();\n            float const axisX_z = axisX.GetZ();\n\n            float const axisY_x = axisY.GetX();\n            float const axisY_z = axisY.GetZ();\n\n            float const axisZ_x = axisZ.GetX();\n            float const axisZ_y = axisZ.GetY();\n\n            float const inv_trace = Math::Reciprocal(Math::Sqrt(mtx_trace + 1.0f));\n            float const half_inv_trace = inv_trace * 0.5f;\n\n            float const m_x = (axisY_z - axisZ_y) * half_inv_trace;\n            float const m_y = (axisZ_x - axisX_z) * half_inv_trace;\n            float const m_z = (axisX_y - axisY_x) * half_inv_trace;\n            float const m_w = Math::Reciprocal(inv_trace) * 0.5f;\n\n            return Quaternion(m_x, m_y, m_z, m_w).GetNormalized();\n        }\n        else\n        {\n            // Find the axis with the highest diagonal value\n            int32_t axisIdx0 = 0;\n            if (axisY_Y > axisX_X)\n            {\n                axisIdx0 = 1;\n            }\n\n            if (axisZ_Z > m_rows[axisIdx0][axisIdx0])\n            {\n                axisIdx0 = 2;\n            }\n\n            int32_t const axisIdx1 = (axisIdx0 + 1) % 3;\n            int32_t const axisIdx2 = (axisIdx1 + 1) % 3;\n\n            float 
const pseudoTrace = 1.0f + m_rows[axisIdx0][axisIdx0] - m_rows[axisIdx1][axisIdx1] - m_rows[axisIdx2][axisIdx2];\n            float const inversePseudoTrace = Math::Reciprocal(Math::Sqrt(pseudoTrace));\n            float const halfInversePseudoTrace = inversePseudoTrace * 0.5f;\n\n            Float4 rawQuatValues;\n            rawQuatValues[axisIdx0] = Math::Reciprocal(inversePseudoTrace) * 0.5f;\n            rawQuatValues[axisIdx1] = halfInversePseudoTrace * (m_rows[axisIdx0][axisIdx1] + m_rows[axisIdx1][axisIdx0]);\n            rawQuatValues[axisIdx2] = halfInversePseudoTrace * (m_rows[axisIdx0][axisIdx2] + m_rows[axisIdx2][axisIdx0]);\n            rawQuatValues[3] = halfInversePseudoTrace * (m_rows[axisIdx1][axisIdx2] - m_rows[axisIdx2][axisIdx1]);\n            return Quaternion(rawQuatValues).GetNormalized();\n        }\n    }\n\n    inline Matrix& Matrix::SetRotation(const Matrix& rotation)\n    {\n        ASSERT(Math::Abs(rotation.GetDeterminant().GetX()) == 1.0f);\n        m_rows[0] = rotation.m_rows[0];\n        m_rows[1] = rotation.m_rows[1];\n        m_rows[2] = rotation.m_rows[2];\n        return *this;\n    }\n\n    inline Matrix& Matrix::SetRotation(const Quaternion& rotation)\n    {\n        static __m128 const constant1110 = { 1.0f, 1.0f, 1.0f, 0.0f };\n\n        __m128 Q0 = _mm_add_ps(rotation, rotation);\n        __m128 Q1 = _mm_mul_ps(rotation, Q0);\n\n        __m128 V0 = _mm_shuffle_ps(Q1, Q1, _MM_SHUFFLE(3, 0, 0, 1));\n        V0 = _mm_and_ps(V0, SIMD::g_maskXYZ0);\n        __m128 V1 = _mm_shuffle_ps(Q1, Q1, _MM_SHUFFLE(3, 1, 2, 2));\n        V1 = _mm_and_ps(V1, SIMD::g_maskXYZ0);\n        __m128 R0 = _mm_sub_ps(constant1110, V0);\n        R0 = _mm_sub_ps(R0, V1);\n\n        V0 = _mm_shuffle_ps(rotation, rotation, _MM_SHUFFLE(3, 1, 0, 0));\n        V1 = _mm_shuffle_ps(Q0, Q0, _MM_SHUFFLE(3, 2, 1, 2));\n        V0 = _mm_mul_ps(V0, V1);\n\n        V1 = _mm_shuffle_ps(rotation, rotation, _MM_SHUFFLE(3, 3, 3, 3));\n        __m128 V2 = 
_mm_shuffle_ps(Q0, Q0, _MM_SHUFFLE(3, 0, 2, 1));\n        V1 = _mm_mul_ps(V1, V2);\n\n        __m128 R1 = _mm_add_ps(V0, V1);\n        __m128 R2 = _mm_sub_ps(V0, V1);\n\n        V0 = _mm_shuffle_ps(R1, R2, _MM_SHUFFLE(1, 0, 2, 1));\n        V0 = _mm_shuffle_ps(V0, V0, _MM_SHUFFLE(1, 3, 2, 0));\n        V1 = _mm_shuffle_ps(R1, R2, _MM_SHUFFLE(2, 2, 0, 0));\n        V1 = _mm_shuffle_ps(V1, V1, _MM_SHUFFLE(2, 0, 2, 0));\n\n        Q1 = _mm_shuffle_ps(R0, V0, _MM_SHUFFLE(1, 0, 3, 0));\n        Q1 = _mm_shuffle_ps(Q1, Q1, _MM_SHUFFLE(1, 3, 2, 0));\n\n        m_rows[0] = Q1;\n\n        Q1 = _mm_shuffle_ps(R0, V0, _MM_SHUFFLE(3, 2, 3, 1));\n        Q1 = _mm_shuffle_ps(Q1, Q1, _MM_SHUFFLE(1, 3, 0, 2));\n        m_rows[1] = Q1;\n\n        Q1 = _mm_shuffle_ps(V1, R0, _MM_SHUFFLE(3, 2, 1, 0));\n        m_rows[2] = Q1;\n        return *this;\n    }\n\n    inline Matrix& Matrix::SetRotationMaintainingScale(const Matrix& rotation)\n    {\n        Vector const scale = GetScale();\n        SetRotation(rotation);\n        return SetScale(scale);\n    }\n\n    inline Matrix& Matrix::SetRotationMaintainingScale(const Quaternion& rotation)\n    {\n        Vector const scale = GetScale();\n        SetRotation(rotation);\n        return SetScale(scale);\n    }\n\n    inline Matrix& Matrix::SetScale(float uniformScale)\n    {\n        SetScale(Vector(uniformScale));\n        return *this;\n    }\n\n    inline Matrix& Matrix::RemoveScaleFast()\n    {\n        m_rows[0] = m_rows[0].GetNormalized4();\n        m_rows[1] = m_rows[1].GetNormalized4();\n        m_rows[2] = m_rows[2].GetNormalized4();\n        return *this;\n    }\n\n    inline Matrix& Matrix::SetScaleFast(const Vector& scale)\n    {\n        m_rows[0] = m_rows[0].GetNormalized3() * scale.GetSplatX();\n        m_rows[1] = m_rows[1].GetNormalized3() * scale.GetSplatY();\n        m_rows[2] = m_rows[2].GetNormalized3() * scale.GetSplatZ();\n        return *this;\n    }\n\n    inline Matrix& Matrix::SetScaleFast(float 
uniformScale)\n    {\n        SetScaleFast(Vector(uniformScale));\n        return *this;\n    }\n\n    inline Vector Matrix::RotateVector(const Vector& vector) const\n    {\n        Vector const X = vector.GetSplatX();\n        Vector const Y = vector.GetSplatY();\n        Vector const Z = vector.GetSplatZ();\n\n        Vector Result = Z * m_rows[2];\n        Result = Vector::MultiplyAdd(Y, m_rows[1], Result);\n        Result = Vector::MultiplyAdd(X, m_rows[0], Result);\n\n        return Result;\n    }\n\n    inline Vector Matrix::TransformNormal(const Vector& vector) const\n    {\n        return RotateVector(vector);\n    }\n\n    inline Vector Matrix::TransformPoint(const Vector& point) const\n    {\n        Vector const X = point.GetSplatX();\n        Vector const Y = point.GetSplatY();\n        Vector const Z = point.GetSplatZ();\n\n        Vector result = Vector::MultiplyAdd(Z, m_rows[2], m_rows[3]);\n        result = Vector::MultiplyAdd(Y, m_rows[1], result);\n        result = Vector::MultiplyAdd(X, m_rows[0], result);\n\n        Vector const W = result.GetSplatW();\n        return result / W;\n    }\n\n    inline Vector Matrix::TransformVector3(const Vector& V) const\n    {\n        Vector const X = V.GetSplatX();\n        Vector const Y = V.GetSplatY();\n        Vector const Z = V.GetSplatZ();\n\n        Vector result = Vector::MultiplyAdd(Z, m_rows[2], m_rows[3]);\n        result = Vector::MultiplyAdd(Y, m_rows[1], result);\n        result = Vector::MultiplyAdd(X, m_rows[0], result);\n\n        return result;\n    }\n\n    inline Vector Matrix::TransformVector4(const Vector& V) const\n    {\n        // Splat m_x,m_y,m_z and m_w\n        Vector vTempX = V.GetSplatX();\n        Vector vTempY = V.GetSplatY();\n        Vector vTempZ = V.GetSplatZ();\n        Vector vTempW = V.GetSplatW();\n\n        // Mul by the matrix\n        vTempX = _mm_mul_ps(vTempX, m_rows[0]);\n        vTempY = _mm_mul_ps(vTempY, m_rows[1]);\n        vTempZ = _mm_mul_ps(vTempZ, 
m_rows[2]);\n        vTempW = _mm_mul_ps(vTempW, m_rows[3]);\n\n        // Add them all together\n        vTempX = _mm_add_ps(vTempX, vTempY);\n        vTempZ = _mm_add_ps(vTempZ, vTempW);\n        vTempX = _mm_add_ps(vTempX, vTempZ);\n\n        return vTempX;\n    }\n\n    inline Vector& Matrix::operator[](uint32_t i)\n    {\n        ASSERT(i < 4);\n        return m_rows[i];\n    }\n\n    inline const Vector Matrix::operator[](uint32_t i) const\n    {\n        ASSERT(i < 4);\n        return m_rows[i];\n    }\n\n    inline Matrix Matrix::operator*(const Matrix& rhs) const\n    {\n        Matrix result = *this;\n        result *= rhs;\n        return result;\n    }\n\n    inline Matrix& Matrix::operator*= (const Matrix& rhs)\n    {\n        Vector vX, vY, vZ, vW;\n\n        // Use vW to hold the original row\n        vW = m_rows[0];\n        vX = _mm_shuffle_ps(vW, vW, _MM_SHUFFLE(0, 0, 0, 0));\n        vY = _mm_shuffle_ps(vW, vW, _MM_SHUFFLE(1, 1, 1, 1));\n        vZ = _mm_shuffle_ps(vW, vW, _MM_SHUFFLE(2, 2, 2, 2));\n        vW = _mm_shuffle_ps(vW, vW, _MM_SHUFFLE(3, 3, 3, 3));\n        vX = _mm_mul_ps(vX, rhs.m_rows[0]);\n        vY = _mm_mul_ps(vY, rhs.m_rows[1]);\n        vZ = _mm_mul_ps(vZ, rhs.m_rows[2]);\n        vW = _mm_mul_ps(vW, rhs.m_rows[3]);\n        vX = _mm_add_ps(vX, vZ);\n        vY = _mm_add_ps(vY, vW);\n        vX = _mm_add_ps(vX, vY);\n        m_rows[0] = vX;\n\n        // Repeat for the other 3 rows\n        vW = m_rows[1];\n        vX = _mm_shuffle_ps(vW, vW, _MM_SHUFFLE(0, 0, 0, 0));\n        vY = _mm_shuffle_ps(vW, vW, _MM_SHUFFLE(1, 1, 1, 1));\n        vZ = _mm_shuffle_ps(vW, vW, _MM_SHUFFLE(2, 2, 2, 2));\n        vW = _mm_shuffle_ps(vW, vW, _MM_SHUFFLE(3, 3, 3, 3));\n        vX = _mm_mul_ps(vX, rhs.m_rows[0]);\n        vY = _mm_mul_ps(vY, rhs.m_rows[1]);\n        vZ = _mm_mul_ps(vZ, rhs.m_rows[2]);\n        vW = _mm_mul_ps(vW, rhs.m_rows[3]);\n        vX = _mm_add_ps(vX, vZ);\n        vY = _mm_add_ps(vY, vW);\n        vX = _mm_add_ps(vX, 
vY);\n        m_rows[1] = vX;\n\n        vW = m_rows[2];\n        vX = _mm_shuffle_ps(vW, vW, _MM_SHUFFLE(0, 0, 0, 0));\n        vY = _mm_shuffle_ps(vW, vW, _MM_SHUFFLE(1, 1, 1, 1));\n        vZ = _mm_shuffle_ps(vW, vW, _MM_SHUFFLE(2, 2, 2, 2));\n        vW = _mm_shuffle_ps(vW, vW, _MM_SHUFFLE(3, 3, 3, 3));\n        vX = _mm_mul_ps(vX, rhs.m_rows[0]);\n        vY = _mm_mul_ps(vY, rhs.m_rows[1]);\n        vZ = _mm_mul_ps(vZ, rhs.m_rows[2]);\n        vW = _mm_mul_ps(vW, rhs.m_rows[3]);\n        vX = _mm_add_ps(vX, vZ);\n        vY = _mm_add_ps(vY, vW);\n        vX = _mm_add_ps(vX, vY);\n        m_rows[2] = vX;\n\n        vW = m_rows[3];\n        vX = _mm_shuffle_ps(vW, vW, _MM_SHUFFLE(0, 0, 0, 0));\n        vY = _mm_shuffle_ps(vW, vW, _MM_SHUFFLE(1, 1, 1, 1));\n        vZ = _mm_shuffle_ps(vW, vW, _MM_SHUFFLE(2, 2, 2, 2));\n        vW = _mm_shuffle_ps(vW, vW, _MM_SHUFFLE(3, 3, 3, 3));\n        vX = _mm_mul_ps(vX, rhs.m_rows[0]);\n        vY = _mm_mul_ps(vY, rhs.m_rows[1]);\n        vZ = _mm_mul_ps(vZ, rhs.m_rows[2]);\n        vW = _mm_mul_ps(vW, rhs.m_rows[3]);\n        vX = _mm_add_ps(vX, vZ);\n        vY = _mm_add_ps(vY, vW);\n        vX = _mm_add_ps(vX, vY);\n        m_rows[3] = vX;\n        return *this;\n    }\n\n    inline Matrix Matrix::operator*(const Quaternion& rhs) const\n    {\n        return operator*(Matrix(rhs));\n    }\n\n    inline Matrix Matrix::operator*=(const Quaternion& rhs)\n    {\n        return operator*=(Matrix(rhs));\n    }\n\n    inline bool Matrix::operator==(const Matrix& rhs) const\n    {\n        for (auto i = 0; i < 4; i++)\n        {\n            for (auto j = 0; j < 4; j++)\n            {\n                if (m_values[i][j] != rhs.m_values[i][j])\n                {\n                    return false;\n                }\n            }\n        }\n\n        return true;\n    }\n}\n"
  },
  {
    "path": "MotionCorrection/src/cpp/Math/Quaternion.cpp",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: Apache-2.0\n */\n\n#include \"Quaternion.h\"\n#include \"Matrix.h\"\n\nnamespace Math\n{\n    Quaternion const Quaternion::Identity(0, 0, 0, 1);\n\n    // Rotation order is XYZ\n    EulerAngles Quaternion::ToEulerAngles() const\n    {\n        return Matrix(*this).ToEulerAngles();\n    }\n\n    Quaternion Quaternion::LookRotation(const Vector& forward, const Vector& up)\n    {\n        const Vector t = Vector::Cross3(up, forward).Normalize3();\n        return Matrix(t, Vector::Cross3(forward, t), forward).GetRotation();\n    }\n}\n"
  },
  {
    "path": "MotionCorrection/src/cpp/Math/Quaternion.h",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: Apache-2.0\n */\n\n#pragma once\n\n#include \"Vector.h\"\n\nnamespace Math\n{\n    class alignas(16) Quaternion\n    {\n    public:\n\n        static Quaternion const Identity;\n\n        // Calculate the rotation required to align the source vector to the target vector (shortest path)\n        static Quaternion FromRotationBetweenNormalizedVectors(const Vector& sourceVector, const Vector& targetVector);\n\n        // Calculate the rotation required to align one vector onto another but also taking account a fallback rotation axis for opposite parallel vectors\n        static Quaternion FromRotationBetweenNormalizedVectors(const Vector& sourceVector, const Vector& targetVector, const Vector& fallbackRotationAxis);\n\n        // Calculate the rotation required to align the source vector to the target vector (shortest path)\n        static Quaternion FromRotationBetweenVectors(const Vector& sourceVector, const Vector& targetVector);\n\n        // Normalized LERP - not accurate - only use for really short distances\n        static Quaternion NLerp(const Quaternion& from, const Quaternion& to, float t);\n\n        // Standard and accurate Spherical LERP - based on DirectX Math\n        static Quaternion SLerp(const Quaternion& from, const Quaternion& to, float t);\n\n        // Fast approximation of a Spherical LERP - based on \"A fast and accurate estimate for SLERP\" by David Eberly\n        static Quaternion FastSLerp(const Quaternion& from, const Quaternion& to, float t);\n\n        // Spherical quadrangle/cubic interpolation for quaternions\n        static Quaternion SQuad(const Quaternion& q0, const Quaternion& q1, const Quaternion& q2, const Quaternion& q3, float t);\n\n        // Calculate the shortest delta quaternion needed to rotate 'from' onto 'to'\n        static Quaternion Delta(const Quaternion& from, const 
Quaternion& to);\n\n        // Simple vector dot product between two quaternions\n        static Vector Dot(const Quaternion& q0, const Quaternion& q1);\n\n        // Calculate the angular distance between two quaternions\n        static Radians Distance(const Quaternion& q0, const Quaternion& q1);\n\n        // Calculate look rotation given forward and up vectors\n        static Quaternion LookRotation(const Vector& forward, const Vector& up);\n\n    public:\n\n        Quaternion() = default;\n        explicit Quaternion(NoInit_t);\n        explicit Quaternion(IdentityInit_t);\n        explicit Quaternion(const Vector& v);\n        explicit Quaternion(float ix, float iy, float iz, float iw);\n        explicit Quaternion(const Float4& v);\n\n        explicit Quaternion(const Vector& axis, Radians angle);\n        explicit Quaternion(AxisAngle axisAngle);\n\n        explicit Quaternion(const EulerAngles& eulerAngles);\n        explicit Quaternion(Radians rotX, Radians rotY, Radians rotZ);\n\n        operator __m128& ();\n        operator const __m128& () const;\n\n        Float4 ToFloat4() const;\n        Vector ToVector() const;\n\n        Vector Length();\n        float GetLength() const;\n\n        // Get the angle this rotation represents around the specified axis\n        Radians GetAngle() const;\n\n        AxisAngle ToAxisAngle() const;\n        EulerAngles ToEulerAngles() const;\n\n        Vector RotateVector(const Vector& vector) const;\n        Vector RotateVectorInverse(const Vector& vector) const;\n\n        Quaternion& Conjugate();\n        Quaternion GetConjugate() const;\n\n        Quaternion& Negate();\n        Quaternion GetNegated() const;\n\n        Quaternion& Invert();\n        Quaternion GetInverse() const;\n\n        Quaternion& Normalize();\n        Quaternion GetNormalized() const;\n\n        Vector XAxis() const noexcept;\n        Vector YAxis() const noexcept;\n        Vector ZAxis() const noexcept;\n\n        // Ensure that this rotation 
is the shortest in terms of the angle (e.g. -5 degrees instead of 355 degrees)\n        Quaternion& MakeShortestPath();\n\n        // Ensure that this rotation is the shortest in terms of the angle (e.g. -5 degrees instead of 355 degrees)\n        Quaternion GetShortestPath() const;\n\n        // This function will return the estimated normalized quaternion, this is not super accurate but a lot faster (use with care)\n        Quaternion& NormalizeInaccurate();\n\n        // This function will return the estimated normalized quaternion, this is not super accurate but a lot faster (use with care)\n        Quaternion GetNormalizedInaccurate() const;\n\n        bool IsNormalized() const;\n        bool IsIdentity() const;\n\n        // Concatenate the rotation of this onto rhs and return the result i.e. first rotate by rhs then by this\n        // This means order of rotation is right-to-left: child-rotation * parent-rotation\n        Quaternion operator*(const Quaternion& rhs) const;\n        Quaternion& operator*=(const Quaternion& rhs);\n\n        // Is the distance between this quaternion and another one under the threshold?\n        bool IsNearEqual(const Quaternion& rhs, Radians const threshold = Math::DegreesToRadians) const;\n\n        // Exact equality\n        bool operator==(const Quaternion& rhs) const;\n\n        // Exact inequality\n        bool operator!=(const Quaternion& rhs) const;\n\n    private:\n\n        Vector GetSplatW() const;\n        float GetW() const;\n\n        Quaternion& operator=(const Vector& v) = delete;\n\n    public:\n\n        __m128 m_data;\n    };\n\n    static_assert(sizeof(Quaternion) == 16, \"Quaternion size must be 16 bytes!\");\n}\n\n#include \"Quaternion.inl\"\n"
  },
  {
    "path": "MotionCorrection/src/cpp/Math/Quaternion.inl",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: Apache-2.0\n */\n\n#pragma once\n\n#include \"Quaternion.h\"\n\nnamespace Math\n{\n    inline Quaternion Quaternion::FromRotationBetweenNormalizedVectors(const Vector& from, const Vector& to)\n    {\n        ASSERT(from.IsNormalized3() && to.IsNormalized3());\n\n        Quaternion result;\n\n        // Parallel vectors - return zero rotation\n        Vector const dot = Vector::Dot3(from, to);\n        if (dot.IsGreaterThanEqual4(Vector::OneMinusEpsilon))\n        {\n            result = Quaternion::Identity;\n        }\n        // Opposite vectors - return 180 rotation around any orthogonal axis\n        else if (dot.IsLessThanEqual4(Vector::EpsilonMinusOne))\n        {\n            Float4 const fromValues = from.ToFloat4();\n            result = Quaternion(-fromValues.m_z, fromValues.m_y, fromValues.m_x, 0);\n            result.Normalize();\n        }\n        else // Calculate quaternion rotation\n        {\n            Vector const cross = Vector::Cross3(from, to);\n            Vector Q = Vector::Select(cross, dot, Vector::Select0001);\n            Q += Vector::Select(Vector::Zero, Q.Length4(), Vector::Select0001);\n            result = Quaternion(Q);\n            result.Normalize();\n        }\n\n        return result;\n    }\n\n    inline Quaternion Quaternion::FromRotationBetweenNormalizedVectors(const Vector& from, const Vector& to, const Vector& fallbackRotationAxis)\n    {\n        ASSERT(from.IsNormalized3() && to.IsNormalized3());\n\n        Quaternion Q(NoInit);\n\n        Vector rotationAxis = from.Cross3(to).GetNormalized3();\n        if (rotationAxis.GetLengthSquared3() == 0)\n        {\n            rotationAxis = fallbackRotationAxis;\n        }\n\n        float const dot = from.GetDot3(to);\n        if (dot >= (1.0f - Math::Epsilon))\n        {\n            Q = Quaternion::Identity;\n        }\n        
else\n        {\n            float const angle = Math::ACos(dot);\n            Q = Quaternion(rotationAxis, angle);\n        }\n\n        return Q;\n    }\n\n    inline Quaternion Quaternion::FromRotationBetweenVectors(const Vector& sourceVector, const Vector& targetVector)\n    {\n        return FromRotationBetweenNormalizedVectors(\n            sourceVector.GetNormalized3(),\n                targetVector.GetNormalized3());\n    }\n\n    inline Quaternion Quaternion::NLerp(const Quaternion& from, const Quaternion& to, float T)\n    {\n        ASSERT(T >= 0.0f && T <= 1.0f);\n\n        Quaternion adjustedFrom(from);\n\n        // Ensure that the rotations are in the same direction\n        if (Quaternion::Dot(from, to).IsLessThan4(Vector::Zero))\n        {\n            adjustedFrom.Negate();\n        }\n\n        Quaternion result(Vector::Lerp(adjustedFrom.ToVector(), to.ToVector(), T));\n        result.Normalize();\n        return result;\n    }\n\n    inline Quaternion Quaternion::SLerp(const Quaternion& from, const Quaternion& to, float T)\n    {\n        ASSERT(T >= 0.0f && T <= 1.0f);\n\n        static SIMD::UIntMask const maskSign = { 0x80000000,0x00000000,0x00000000,0x00000000 };\n        static __m128 const oneMinusEpsilon = { 1.0f - 0.00001f, 1.0f - 0.00001f, 1.0f - 0.00001f, 1.0f - 0.00001f };\n\n        Vector const VecT(T);\n\n        Vector cosOmega = Quaternion::Dot(from, to);\n\n        Vector control = cosOmega.LessThan(Vector::Zero);\n        Vector sign = Vector::Select(Vector::One, Vector::NegativeOne, control);\n\n        cosOmega = _mm_mul_ps(cosOmega, sign);\n        control = cosOmega.LessThan(oneMinusEpsilon);\n\n        Vector sinOmega = _mm_mul_ps(cosOmega, cosOmega);\n        sinOmega = _mm_sub_ps(Vector::One, sinOmega);\n        sinOmega = _mm_sqrt_ps(sinOmega);\n\n        Vector omega = Vector::ATan2(sinOmega, cosOmega);\n\n        Vector V01 = _mm_shuffle_ps(VecT, VecT, _MM_SHUFFLE(2, 3, 0, 1));\n        V01 = _mm_and_ps(V01, 
SIMD::g_maskXY00);\n        V01 = _mm_xor_ps(V01, maskSign);\n        V01 = _mm_add_ps(Vector::UnitX, V01);\n\n        Vector S0 = _mm_mul_ps(V01, omega);\n        S0 = Vector::Sin(S0);\n        S0 = _mm_div_ps(S0, sinOmega);\n        S0 = Vector::Select(V01, S0, control);\n\n        Vector S1 = S0.GetSplatY();\n        S0 = S0.GetSplatX();\n\n        S1 = _mm_mul_ps(S1, sign);\n        Vector result = _mm_mul_ps(from, S0);\n        S1 = _mm_mul_ps(S1, to);\n        result = _mm_add_ps(result, S1);\n\n        return Quaternion(result);\n    }\n\n    inline Quaternion Quaternion::FastSLerp(const Quaternion& q0, const Quaternion& q1, float t)\n    {\n        // Precomputed constants\n        constexpr float const mu = 1.85298109240830f;\n        static Vector const u0123 = _mm_setr_ps(1.f / (1 * 3), 1.f / (2 * 5), 1.f / (3 * 7), 1.f / (4 * 9));\n        static Vector const u4567 = _mm_setr_ps(1.f / (5 * 11), 1.f / (6 * 13), 1.f / (7 * 15), mu / (8 * 17));\n        static Vector const v0123 = _mm_setr_ps(1.f / 3, 2.f / 5, 3.f / 7, 4.f / 9);\n        static Vector const v4567 = _mm_setr_ps(5.f / 11, 6.f / 13, 7.f / 15, mu * 8 / 17);\n        static Vector const vSignMask = _mm_set1_ps(-0.f);\n\n        // Common code for computing the scalar coefficients of SLERP\n        auto CalculateCoefficient = [](Vector vT, Vector xm1)\n        {\n            Vector const vTSquared = vT * vT;\n\n            // ( b4, b5, b6, b7 ) = ( x-1 ) * ( u4 * t^2 - v4, u5 * t^2 - v5, u6 * t^2 - v6, u7 * t^2 - v7 )\n            Vector b4567 = Vector::MultiplySubtract(u4567, vTSquared, v4567);\n            b4567 *= xm1;\n\n            // ( b7, b7, b7, b7 )\n            Vector b = b4567.GetSplatW();\n            Vector c = b + Vector::One;\n\n            // ( b6, b6, b6, b6 )\n            b = b4567.GetSplatZ();\n            c = Vector::MultiplyAdd(b, c, Vector::One);\n\n            // ( b5, b5, b5, b5 )\n            b = b4567.GetSplatY();\n            c = Vector::MultiplyAdd(b, c, 
Vector::One);\n\n            // ( b4, b4, b4, b4 )\n            b = b4567.GetSplatX();\n            c = Vector::MultiplyAdd(b, c, Vector::One);\n\n            // ( b0, b1, b2, b3 ) =\n            // ( x-1)*(u0* t^2-v0, u1 * t^2 -v1, u2* t^2-v2, u3* t^2-v3 )\n            Vector b0123 = Vector::MultiplySubtract(u0123, vTSquared, v0123);\n            b0123 *= xm1;\n\n            // ( b3, b3, b3, b3 )\n            b = b0123.GetSplatW();\n            c = Vector::MultiplyAdd(b, c, Vector::One);\n\n            // ( b2, b2, b2, b2 )\n            b = b0123.GetSplatZ();\n            c = Vector::MultiplyAdd(b, c, Vector::One);\n\n            // ( b1, b1, b1, b1 )\n            b = b0123.GetSplatY();\n            c = Vector::MultiplyAdd(b, c, Vector::One);\n\n            // ( b0, b0, b0, b0 )\n            b = b0123.GetSplatX();\n            c = Vector::MultiplyAdd(b, c, Vector::One);\n            c *= vT;\n\n            return c;\n        };\n\n        Vector x = Vector::Dot4(q0.m_data, q1.m_data); // cos ( theta ) in all components\n\n        Vector sign = _mm_and_ps(vSignMask, x);\n        x = _mm_xor_ps(sign, x);\n        Vector localQ1 = _mm_xor_ps(sign, q1);\n\n        Vector xm1 = x - Vector::One;\n\n        Vector cT = CalculateCoefficient(Vector(t), xm1);\n        Vector cD = CalculateCoefficient(Vector(1.0f - t), xm1);\n        cT = cT * localQ1;\n\n        Quaternion result(Vector::MultiplyAdd(cD, q0.m_data, cT));\n        return result;\n    }\n\n    inline Quaternion Quaternion::SQuad(const Quaternion& q0, const Quaternion& q1, const Quaternion& q2, const Quaternion& q3, float t)\n    {\n        ASSERT(t >= 0.0f && t <= 1.0f);\n\n        Quaternion const q03 = Quaternion::SLerp(q0, q3, t);\n        Quaternion const q12 = Quaternion::SLerp(q1, q2, t);\n        t = (t - (t * t)) * 2;\n        Quaternion const result = Quaternion::SLerp(q03, q12, t);\n        return result;\n    }\n\n    inline Quaternion Quaternion::Delta(const Quaternion& from, const Quaternion& 
to)\n    {\n        return to * from.GetInverse();\n    }\n\n    inline Vector Quaternion::Dot(const Quaternion& q0, const Quaternion& q1)\n    {\n        return Vector::Dot4(q0.m_data, q1.m_data);\n    }\n\n    inline Radians Quaternion::Distance(const Quaternion& q0, const Quaternion& q1)\n    {\n        float const dot = Math::Clamp(Dot(q0, q1).ToFloat(), -1.0f, 1.0f);\n        return Radians(2 * Math::ACos(Math::Abs(dot)));\n    }\n\n    inline Quaternion::Quaternion(NoInit_t)\n    {\n    }\n\n    inline Quaternion::Quaternion(IdentityInit_t)\n        : m_data(Vector::UnitW.m_data)\n    {\n    }\n\n    inline Quaternion::Quaternion(const Vector& v)\n        : m_data(v.m_data)\n    {\n    }\n\n    inline Quaternion::Quaternion(float ix, float iy, float iz, float iw)\n    {\n        m_data = _mm_set_ps(iw, iz, iy, ix);\n    }\n\n    inline Quaternion::Quaternion(const Float4& v)\n        : Quaternion(v.m_x, v.m_y, v.m_z, v.m_w)\n    {\n    }\n\n    inline Quaternion::Quaternion(const Vector& axis, Radians angle)\n    {\n        ASSERT(axis.IsNormalized3());\n\n        auto N = _mm_and_ps(axis, SIMD::g_maskXYZ0);\n        N = _mm_or_ps(N, Vector::UnitW);\n        auto scale = _mm_set_ps1(0.5f * (float)angle);\n\n        Vector sine, cosine;\n        Vector::SinCos(sine, cosine, scale);\n\n        scale = _mm_and_ps(sine, SIMD::g_maskXYZ0);\n        cosine = _mm_and_ps(cosine, SIMD::g_mask000W);\n        scale = _mm_or_ps(scale, cosine);\n\n        N = _mm_mul_ps(N, scale);\n        m_data = N;\n    }\n\n    inline Quaternion::Quaternion(AxisAngle axisAngle)\n        : Quaternion(Vector(axisAngle.m_axis), axisAngle.m_angle)\n    {\n    }\n\n    inline Quaternion::Quaternion(const EulerAngles& eulerAngles)\n    {\n        auto const rotationX = Quaternion(Vector::UnitX, eulerAngles.m_x);\n        auto const rotationY = Quaternion(Vector::UnitY, eulerAngles.m_y);\n        auto const rotationZ = Quaternion(Vector::UnitZ, eulerAngles.m_z);\n\n        // Rotation order 
is XYZ - all in global space, hence the order is reversed\n        m_data = (rotationX * rotationY * rotationZ).GetNormalized().m_data;\n    }\n\n    inline Quaternion::Quaternion(Radians rotX, Radians rotY, Radians rotZ)\n        : Quaternion(EulerAngles(rotX, rotY, rotZ))\n    {\n    }\n\n    inline Quaternion::operator __m128& ()\n    {\n        return m_data;\n    }\n\n    inline Quaternion::operator const __m128& () const\n    {\n        return m_data;\n    }\n\n    inline Float4 Quaternion::ToFloat4() const\n    {\n        Float4 v;\n        _mm_storeu_ps(&v.m_x, m_data);\n        return v;\n    }\n\n    inline Vector Quaternion::ToVector() const\n    {\n        return Vector(m_data);\n    }\n\n    inline Vector Quaternion::Length()\n    {\n        return ToVector().Length4();\n    }\n\n    inline float Quaternion::GetLength() const\n    {\n        return ToVector().GetLength4();\n    }\n\n    inline Radians Quaternion::GetAngle() const\n    {\n        return Radians(2.0f * Math::ACos(GetW()));\n    }\n\n    inline AxisAngle Quaternion::ToAxisAngle() const\n    {\n        return AxisAngle(ToVector(), Radians(2.0f * Math::ACos(GetW())));\n    }\n\n    inline Vector Quaternion::RotateVector(const Vector& vector) const\n    {\n        Quaternion const A(Vector::Select(Vector::Select1110, vector, Vector::Select1110));\n        Quaternion const result = GetConjugate() * A;\n        return (result * *this).ToVector();\n    }\n\n    inline Vector Quaternion::RotateVectorInverse(const Vector& vector) const\n    {\n        Quaternion const A(Vector::Select(Vector::Select1110, vector, Vector::Select1110));\n        Quaternion const result = *this * A;\n        return (result * GetConjugate()).ToVector();\n    }\n\n    inline Quaternion& Quaternion::Conjugate()\n    {\n        static __m128 const conj = { -1.0f, -1.0f, -1.0f, 1.0f };\n        m_data = _mm_mul_ps(*this, conj);\n        return *this;\n    }\n\n    inline Quaternion Quaternion::GetConjugate() const\n    
{\n        Quaternion q = *this;\n        q.Conjugate();\n        return q;\n    }\n    inline Quaternion& Quaternion::Negate()\n    {\n        m_data = _mm_mul_ps(*this, Vector::NegativeOne);\n        return *this;\n    }\n\n    inline Quaternion Quaternion::GetNegated() const\n    {\n        Quaternion q = *this;\n        q.Negate();\n        return q;\n    }\n\n    inline Quaternion& Quaternion::Invert()\n    {\n        Vector const conjugate(GetConjugate().m_data);\n        Vector const length = ToVector().Length4();\n        Vector const mask = length.LessThanEqual(Vector::Epsilon);\n        Vector const result = conjugate / length;\n        m_data = result.Select(result, Vector::Zero, mask);\n        return *this;\n    }\n\n    inline Quaternion Quaternion::GetInverse() const\n    {\n        Quaternion q = *this;\n        q.Invert();\n        return q;\n    }\n\n    inline Quaternion& Quaternion::Normalize()\n    {\n        m_data = ToVector().GetNormalized4().m_data;\n        return *this;\n    }\n\n    inline Quaternion Quaternion::GetNormalized() const\n    {\n        Quaternion q = *this;\n        q.Normalize();\n        return q;\n    }\n\n    inline Vector Quaternion::XAxis() const noexcept\n    {\n        const float x = _mm_cvtss_f32(m_data);\n        const float y = _mm_cvtss_f32(\n            _mm_shuffle_ps(m_data, m_data,\n                _MM_SHUFFLE(1, 1, 1, 1)));\n        const float z = _mm_cvtss_f32(\n            _mm_shuffle_ps(m_data, m_data,\n                _MM_SHUFFLE(2, 2, 2, 2)));\n        const float w = _mm_cvtss_f32(\n            _mm_shuffle_ps(m_data, m_data,\n                _MM_SHUFFLE(3, 3, 3, 3)));\n\n        const float s = 2.0f * w;\n        const float x2 = 2.0f * x;\n\n        return Vector(\n            x2 * x + s * w - 1.0f,\n                x2 * y + s * z,\n                    x2 * z + s * -y);\n    }\n\n    inline Vector Quaternion::YAxis() const noexcept\n    {\n        const float x = _mm_cvtss_f32(m_data);\n        
const float y = _mm_cvtss_f32(\n            _mm_shuffle_ps(m_data, m_data,\n                _MM_SHUFFLE(1, 1, 1, 1)));\n        const float z = _mm_cvtss_f32(\n            _mm_shuffle_ps(m_data, m_data,\n                _MM_SHUFFLE(2, 2, 2, 2)));\n        const float w = _mm_cvtss_f32(\n            _mm_shuffle_ps(m_data, m_data,\n                _MM_SHUFFLE(3, 3, 3, 3)));\n\n        const float s = 2.0f * w;\n        const float y2 = 2.0f * y;\n\n        return Vector(\n            y2 * x + s * -z,\n                y2 * y + s * w - 1.0f,\n                    y2 * z + s * x);\n    }\n\n    inline Vector Quaternion::ZAxis() const noexcept\n    {\n        const float x = _mm_cvtss_f32(m_data);\n        const float y = _mm_cvtss_f32(\n            _mm_shuffle_ps(m_data, m_data,\n                _MM_SHUFFLE(1, 1, 1, 1)));\n        const float z = _mm_cvtss_f32(\n            _mm_shuffle_ps(m_data, m_data,\n                _MM_SHUFFLE(2, 2, 2, 2)));\n        const float w = _mm_cvtss_f32(\n            _mm_shuffle_ps(m_data, m_data,\n                _MM_SHUFFLE(3, 3, 3, 3)));\n\n        const float s = 2.0f * w;\n        const float z2 = 2.0f * z;\n\n        return Vector(\n            x * z2 + s * y,\n                y * z2 + s * -x,\n                    z * z2 + s * w - 1.0f);\n    }\n\n    inline Quaternion& Quaternion::MakeShortestPath()\n    {\n        // If we have a > 180 angle, negate\n        // w < 0.0f is the same as dot( identity, q ) < 0\n        if (GetW() < 0.0f)\n        {\n            Negate();\n        }\n\n        return *this;\n    }\n\n    inline Quaternion Quaternion::GetShortestPath() const\n    {\n        Quaternion sp = *this;\n        sp.MakeShortestPath();\n        return sp;\n    }\n\n    inline Quaternion& Quaternion::NormalizeInaccurate()\n    {\n        *this = GetNormalizedInaccurate();\n        return *this;\n    }\n\n    inline Quaternion Quaternion::GetNormalizedInaccurate() const\n    {\n        __m128 vLengthSq = _mm_mul_ps(m_data, 
m_data);\n        __m128 vTemp = _mm_shuffle_ps(vLengthSq, vLengthSq, _MM_SHUFFLE(3, 2, 3, 2));\n        vLengthSq = _mm_add_ps(vLengthSq, vTemp);\n        vLengthSq = _mm_shuffle_ps(vLengthSq, vLengthSq, _MM_SHUFFLE(1, 0, 0, 0));\n        vTemp = _mm_shuffle_ps(vTemp, vLengthSq, _MM_SHUFFLE(3, 3, 0, 0));\n        vLengthSq = _mm_add_ps(vLengthSq, vTemp);\n        vLengthSq = _mm_shuffle_ps(vLengthSq, vLengthSq, _MM_SHUFFLE(2, 2, 2, 2));\n\n        // Get the reciprocal and mul to perform the normalization\n        Quaternion result;\n        result.m_data = _mm_rsqrt_ps(vLengthSq);\n        result.m_data = _mm_mul_ps(result.m_data, m_data);\n        return result;\n    }\n\n    inline bool Quaternion::IsNormalized() const\n    {\n        return ToVector().IsNormalized4();\n    }\n\n    inline bool Quaternion::IsIdentity() const\n    {\n        return ToVector().IsEqual3(Vector::UnitW);\n    }\n\n    inline Quaternion Quaternion::operator*(const Quaternion& rhs) const\n    {\n        static const __m128 controlWZYX = { 1.0f,-1.0f, 1.0f,-1.0f };\n        static const __m128 controlZWXY = { 1.0f, 1.0f,-1.0f,-1.0f };\n        static const __m128 controlYXWZ = { -1.0f, 1.0f, 1.0f,-1.0f };\n\n        // Copy to SSE registers and use as few as possible for x86\n        __m128 Q2X = rhs;\n        __m128 Q2Y = rhs;\n        __m128 Q2Z = rhs;\n        __m128 vResult = rhs;\n        // Splat with one instruction\n        vResult = _mm_shuffle_ps(vResult, vResult, _MM_SHUFFLE(3, 3, 3, 3));\n        Q2X = _mm_shuffle_ps(Q2X, Q2X, _MM_SHUFFLE(0, 0, 0, 0));\n        Q2Y = _mm_shuffle_ps(Q2Y, Q2Y, _MM_SHUFFLE(1, 1, 1, 1));\n        Q2Z = _mm_shuffle_ps(Q2Z, Q2Z, _MM_SHUFFLE(2, 2, 2, 2));\n        // Retire Q1 and perform Q1*Q2W\n        vResult = _mm_mul_ps(vResult, *this);\n        __m128 Q1Shuffle = *this;\n        // Shuffle the copies of Q1\n        Q1Shuffle = _mm_shuffle_ps(Q1Shuffle, Q1Shuffle, _MM_SHUFFLE(0, 1, 2, 3));\n        // Mul by Q1WZYX\n        Q2X = 
_mm_mul_ps(Q2X, Q1Shuffle);\n        Q1Shuffle = _mm_shuffle_ps(Q1Shuffle, Q1Shuffle, _MM_SHUFFLE(2, 3, 0, 1));\n        // Flip the signs on m_y and m_z\n        Q2X = _mm_mul_ps(Q2X, controlWZYX);\n        // Mul by Q1ZWXY\n        Q2Y = _mm_mul_ps(Q2Y, Q1Shuffle);\n        Q1Shuffle = _mm_shuffle_ps(Q1Shuffle, Q1Shuffle, _MM_SHUFFLE(0, 1, 2, 3));\n        // Flip the signs on m_z and m_w\n        Q2Y = _mm_mul_ps(Q2Y, controlZWXY);\n        // Mul by Q1YXWZ\n        Q2Z = _mm_mul_ps(Q2Z, Q1Shuffle);\n        vResult = _mm_add_ps(vResult, Q2X);\n        // Flip the signs on m_x and m_w\n        Q2Z = _mm_mul_ps(Q2Z, controlYXWZ);\n        Q2Y = _mm_add_ps(Q2Y, Q2Z);\n        vResult = _mm_add_ps(vResult, Q2Y);\n\n        return Quaternion(vResult);\n    }\n\n    inline Quaternion& Quaternion::operator*=(const Quaternion& rhs)\n    {\n        *this = *this * rhs;\n        return *this;\n    }\n\n    inline bool Quaternion::IsNearEqual(const Quaternion& rhs, Radians const threshold) const\n    {\n        return Quaternion::Distance(*this, rhs) <= threshold;\n    }\n\n    inline bool Quaternion::operator==(const Quaternion& rhs) const\n    {\n        return ToVector() == rhs.ToVector();\n    }\n\n    inline bool Quaternion::operator!=(const Quaternion& rhs) const\n    {\n        return !operator==(rhs);\n    }\n\n    inline Vector Quaternion::GetSplatW() const\n    {\n        return _mm_shuffle_ps(m_data, m_data, _MM_SHUFFLE(3, 3, 3, 3));\n    }\n\n    inline float Quaternion::GetW() const\n    {\n        auto vTemp = GetSplatW();\n        return _mm_cvtss_f32(vTemp);\n    }\n}\n"
  },
  {
    "path": "MotionCorrection/src/cpp/Math/SIMD.h",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: Apache-2.0\n */\n\n#pragma once\n\n#include <stdint.h>\n#include <immintrin.h>\n\nnamespace SIMD\n{\n    struct alignas( 16 ) IntMask\n    {\n        inline operator __m128( ) const { return reinterpret_cast<__m128 const&>( *this ); }\n        inline operator __m128i( ) const { return _mm_castps_si128( *this ); }\n        inline operator __m128d( ) const { return _mm_castps_pd( *this ); }\n\n        int32_t i[4];\n    };\n\n    struct alignas( 16 ) UIntMask\n    {\n        inline operator __m128( ) const { return reinterpret_cast<__m128 const&>( *this ); }\n        inline operator __m128i( ) const { return _mm_castps_si128( *this ); }\n        inline operator __m128d( ) const { return _mm_castps_pd( *this ); }\n\n        uint32_t v[4];\n    };\n\n    struct alignas( 16 ) FloatMask\n    {\n        inline operator __m128() const { return reinterpret_cast<__m128 const&>( *this ); }\n        inline operator __m128i() const { return _mm_castps_si128( *this ); }\n        inline operator __m128d() const { return _mm_castps_pd( *this ); }\n\n        float v[4];\n    };\n\n    // Int Operations\n    //-------------------------------------------------------------------------\n\n    namespace Int\n    {\n        FORCE_INLINE bool Equal( __m128 V1, __m128 V2 )\n        {\n            __m128i vTemp = _mm_cmpeq_epi32( _mm_castps_si128( V1 ), _mm_castps_si128( V2 ) );\n            return ( ( ( _mm_movemask_ps( _mm_castsi128_ps( vTemp ) ) & 7 ) == 7 ) != 0 );\n        }\n\n        FORCE_INLINE bool NotEqual( __m128 V1, __m128 V2 )\n        {\n            __m128i vTemp = _mm_cmpeq_epi32( _mm_castps_si128( V1 ), _mm_castps_si128( V2 ) );\n            return ( ( _mm_movemask_ps( _mm_castsi128_ps( vTemp ) ) != 0xF ) != 0 );\n        }\n\n        FORCE_INLINE __m128 And( __m128 V1, __m128 V2 )\n        {\n            return _mm_and_ps( V1, 
V2 );\n        }\n\n        FORCE_INLINE __m128 Or( __m128 V1, __m128 V2 )\n        {\n            __m128i V = _mm_or_si128( _mm_castps_si128( V1 ), _mm_castps_si128( V2 ) );\n            return _mm_castsi128_ps( V );\n        }\n    }\n\n    //-------------------------------------------------------------------------\n\n    static __m128 const g_sinCoefficients0 = { -0.16666667f, +0.0083333310f, -0.00019840874f, +2.7525562e-06f };\n    static __m128 const g_sinCoefficients1 = { -2.3889859e-08f, -0.16665852f, +0.0083139502f, -0.00018524670f };\n    static __m128 const g_cosCoefficients0 = { -0.5f, +0.041666638f, -0.0013888378f, +2.4760495e-05f };\n    static __m128 const g_cosCoefficients1 = { -2.6051615e-07f, -0.49992746f, +0.041493919f, -0.0012712436f };\n    static __m128 const g_tanCoefficients0 = { 1.0f, 0.333333333f, 0.133333333f, 5.396825397e-2f };\n    static __m128 const g_tanCoefficients1 = { 2.186948854e-2f, 8.863235530e-3f, 3.592128167e-3f, 1.455834485e-3f };\n    static __m128 const g_tanCoefficients2 = { 5.900274264e-4f, 2.391290764e-4f, 9.691537707e-5f, 3.927832950e-5f };\n    static __m128 const g_arcCoefficients0 = { +1.5707963050f, -0.2145988016f, +0.0889789874f, -0.0501743046f };\n    static __m128 const g_arcCoefficients1 = { +0.0308918810f, -0.0170881256f, +0.0066700901f, -0.0012624911f };\n    static __m128 const g_aTanCoefficients0 = { -0.3333314528f, +0.1999355085f, -0.1420889944f, +0.1065626393f };\n    static __m128 const g_aTanCoefficients1 = { -0.0752896400f, +0.0429096138f, -0.0161657367f, +0.0028662257f };\n    static __m128 const g_aTanEstCoefficients0 = { +0.999866f, +0.999866f, +0.999866f, +0.999866f };\n    static __m128 const g_aTanEstCoefficients1 = { -0.3302995f, +0.180141f, -0.085133f, +0.0208351f };\n    static __m128 const g_tanEstCoefficients = { 2.484f, -1.954923183e-1f, 2.467401101f, Math::OneDivPi };\n    static __m128 const g_arcEstCoefficients = { +1.5707288f,-0.2121144f,+0.0742610f,-0.0187293f };\n    static __m128 
const g_aTan2Constants = { Math::Pi, Math::PiDivTwo, Math::PiDivFour, 2.3561944905f /* 3/4 Pi */ };\n\n    //-------------------------------------------------------------------------\n\n    static FloatMask const g_noFraction = { 8388608.0f,8388608.0f,8388608.0f,8388608.0f };\n    static IntMask const g_absMask = { 0x7FFFFFFF, 0x7FFFFFFF, 0x7FFFFFFF, 0x7FFFFFFF };\n    static UIntMask const g_trueMask = { 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF };\n    static UIntMask const g_signMask = { 0x80000000, 0x80000000, 0x80000000, 0x80000000 };\n    static UIntMask const g_maskX000 = { 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000 };\n    static UIntMask const g_mask0Y00 = { 0x00000000, 0xFFFFFFFF, 0x00000000, 0x00000000 };\n    static UIntMask const g_mask00Z0 = { 0x00000000, 0x00000000, 0xFFFFFFFF, 0x00000000 };\n    static UIntMask const g_mask000W = { 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF };\n    static UIntMask const g_maskXY00 = { 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000 };\n    static UIntMask const g_maskXYZ0 = { 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000 };\n}\n"
  },
  {
    "path": "MotionCorrection/src/cpp/Math/Scalar.h",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: Apache-2.0\n */\n\n#pragma once\n\n#include \"Compiler.h\"\n#include \"Debug.h\"\n\n#include \"Constants.h\"\n\n#include <math.h>\n#include <stdint.h>\n\n//\n// Scalar related methods\n//\n\nnamespace Math\n{\n    FORCE_INLINE float Sin( float value ) { return sinf( value ); }\n    FORCE_INLINE float Cos( float value ) { return cosf( value ); }\n    FORCE_INLINE float Tan( float value ) { return tanf( value ); }\n\n    FORCE_INLINE float ASin( float value ) { return asinf( value ); }\n    FORCE_INLINE float ACos( float value ) { return acosf( value ); }\n    FORCE_INLINE float ATan( float value ) { return atanf( value ); }\n    FORCE_INLINE float ATan2( float y, float x ) { return atan2f( y, x ); }\n\n    FORCE_INLINE float Cosec( float value ) { return 1.0f / sinf( value ); }\n    FORCE_INLINE float Sec( float value ) { return 1.0f / cosf( value ); }\n    FORCE_INLINE float Cot( float value ) { return 1.0f / tanf( PiDivTwo - value ); }\n\n    FORCE_INLINE float Pow( float x, float y ) { return powf( x, y ); }\n    FORCE_INLINE float Sqr( float value ) { return value * value; }\n    FORCE_INLINE float Sqrt( float value ) { return sqrtf( value ); }\n\n    FORCE_INLINE float Log( float value ) { return logf( value ); }\n    FORCE_INLINE float Log2f( float value ) { return log2f( value ); }\n\n    FORCE_INLINE float AddToMovingAverage( float currentAverage, uint64_t numCurrentSamples, float newValue )\n    {\n        return currentAverage + ( ( newValue - currentAverage ) / float( numCurrentSamples + 1 ) );\n    }\n\n    FORCE_INLINE float Abs( float a ) { return fabsf( a ); }\n    FORCE_INLINE double Abs( double a ) { return fabs( a ); }\n    FORCE_INLINE int8_t Abs( int8_t a ) { return (int8_t) abs( a ); }\n    FORCE_INLINE int16_t Abs( int16_t a ) { return (int16_t) abs( a ); }\n    FORCE_INLINE int32_t Abs( int32_t a 
) { return labs( a ); }\n    FORCE_INLINE int64_t Abs( int64_t a ) { return llabs( a ); }\n\n    FORCE_INLINE float Reciprocal( float r ) { return 1.0f / r; }\n    FORCE_INLINE double Reciprocal( double r ) { return 1.0 / r; }\n\n    template<typename T>\n    FORCE_INLINE T Min( T a, T b ) { return a <= b ? a : b; }\n\n    template<typename T>\n    FORCE_INLINE T Max( T a, T b ) { return a >= b ? a : b; }\n\n    template<typename T>\n    FORCE_INLINE T AbsMin( T a, T b ) { return Abs( a ) <= Abs( b ) ? a : b; }\n\n    template<typename T>\n    FORCE_INLINE T AbsMax( T a, T b ) { return Abs( a ) >= Abs( b ) ? a : b; }\n\n    template<typename T>\n    FORCE_INLINE T Sqrt( T a ) { return sqrt( a ); }\n\n    template<typename T>\n    FORCE_INLINE T Clamp( T value, T lowerBound, T upperBound )\n    {\n        ASSERT( lowerBound <= upperBound );\n        return Min( Max( value, lowerBound ), upperBound );\n    }\n\n    template<typename T>\n    FORCE_INLINE bool IsInRangeInclusive( T value, T lowerBound, T upperBound )\n    {\n        ASSERT( lowerBound < upperBound );\n        return value >= lowerBound && value <= upperBound;\n    }\n\n    template<typename T>\n    FORCE_INLINE bool IsInRangeExclusive( T value, T lowerBound, T upperBound )\n    {\n        ASSERT( lowerBound < upperBound );\n        return value > lowerBound && value < upperBound;\n    }\n\n    // Decomposes a float into integer and remainder portions, remainder is returned and the integer result is stored in the integer portion\n    FORCE_INLINE float ModF( float value, float& integerPortion )\n    {\n        return modff( value, &integerPortion );\n    }\n\n    // Returns the floating point remainder of x/y\n    FORCE_INLINE float FModF( float x, float y )\n    {\n        return fmodf( x, y );\n    }\n\n    template<typename T>\n    FORCE_INLINE T Lerp( T A, T B, float t )\n    {\n        return A + ( B - A ) * t;\n    }\n\n    FORCE_INLINE float PercentageThroughRange( float value, float lowerBound, 
float upperBound )\n    {\n        ASSERT( lowerBound < upperBound );\n        return Clamp( value, lowerBound, upperBound ) / ( upperBound - lowerBound );\n    }\n\n    FORCE_INLINE bool IsNearEqual( float value, float comparand, float epsilon = Epsilon )\n    {\n        return fabsf( value - comparand ) <= epsilon;\n    }\n\n    FORCE_INLINE bool IsNearZero( float value, float epsilon = Epsilon )\n    {\n        return fabsf( value ) <= epsilon;\n    }\n\n    FORCE_INLINE bool IsNearEqual( double value, double comparand, double epsilon = Epsilon )\n    {\n        return fabs( value - comparand ) <= epsilon;\n    }\n\n    FORCE_INLINE bool IsNearZero( double value, double epsilon = Epsilon )\n    {\n        return fabs( value ) <= epsilon;\n    }\n\n    FORCE_INLINE float Ceiling( float value )\n    {\n        return ceilf( value );\n    }\n\n    FORCE_INLINE int32_t CeilingToInt( float value )\n    {\n        return (int32_t) ceilf( value );\n    }\n\n    FORCE_INLINE float Floor( float value )\n    {\n        return floorf( value );\n    }\n\n    FORCE_INLINE int32_t FloorToInt( float value )\n    {\n        return (int32_t) floorf( value );\n    }\n\n    FORCE_INLINE float Round( float value )\n    {\n        return roundf( value );\n    }\n\n    FORCE_INLINE int32_t RoundToInt( float value )\n    {\n        return (int32_t) roundf( value );\n    }\n\n    inline float RemapRange( float value, float fromRangeBegin, float fromRangeEnd, float toRangeBegin, float toRangeEnd )\n    {\n        float const fromRangeLength = fromRangeEnd - fromRangeBegin;\n        float const percentageThroughFromRange = Clamp( ( value - fromRangeBegin ) / fromRangeLength, 0.0f, 1.0f );\n        float const toRangeLength = toRangeEnd - toRangeBegin;\n        float const result = toRangeBegin + ( percentageThroughFromRange * toRangeLength );\n\n        return result;\n    }\n\n    FORCE_INLINE float Square( float value )\n    {\n        return value * value;\n    }\n\n    FORCE_INLINE 
float SmoothStep01( float value )\n    {\n        value = Clamp( value, 0.0f, 1.0f );\n        return value * value * ( 3.0f - 2.0f * value );\n    }\n}\n"
  },
  {
    "path": "MotionCorrection/src/cpp/Math/Transform.cpp",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: Apache-2.0\n */\n\n#include \"Transform.h\"\n\nnamespace Math\n{\n    Transform const Transform::Identity = Transform(Quaternion(0, 0, 0, 1), Vector(0, 0, 0, 1), 1.0f);\n\n    void Transform::SanitizeScaleValue()\n    {\n        if (Math::IsNearEqual(GetScale(), 1.0f, Math::LargeEpsilon))\n        {\n            SetScale(1.0f);\n        }\n    }\n}\n"
  },
  {
    "path": "MotionCorrection/src/cpp/Math/Transform.h",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: Apache-2.0\n */\n\n#pragma once\n\n#include \"Matrix.h\"\n\nnamespace Math\n{\n    //\n    // VQS Transform\n    //\n\n    class Transform\n    {\n    public:\n\n        static Transform const Identity;\n\n        static Transform FromRotation(const Quaternion& rotation);\n        static Transform FromTranslation(const Vector& translation);\n        static Transform FromScale(float uniformScale);\n        static Transform FromTranslationAndScale(const Vector& translation, float uniformScale);\n        static Transform FromRotationBetweenVectors(const Vector sourceVector, const Vector targetVector);\n\n        // Linearly interpolate between two transforms - uses NLerp for rotations\n        static Transform Lerp(const Transform& from, const Transform& to, float t);\n\n        // Spherically interpolate between two transforms - uses SLerp for rotations\n        static Transform Slerp(const Transform& from, const Transform& to, float t);\n\n        // Spherically interpolate between two transforms - uses FastSLerp (SLerp approximation) for rotations\n        static Transform FastSlerp(const Transform& from, const Transform& to, float t);\n\n        // Calculate a delta transform that you can concatenate to the\n        // 'from' transform to get the 'to' transform. 
Properly handles the non-uniform scaling case.\n        static Transform Delta(const Transform& from, const Transform& to);\n\n        // Calculates a delta transform that you can concatenate to the\n        // 'from' transform to get the 'to' transform (ignoring scale)\n        static Transform DeltaNoScale(const Transform& from, const Transform& to);\n\n        static void DirectlySetRotation(Transform& transform, Quaternion&& rotation);\n        static void DirectlySetRotation(Transform& transform, const Quaternion& rotation);\n        static void DirectlySetTranslationScale(Transform& transform, Vector&& translationScale);\n        static void DirectlySetTranslationScale(Transform& transform, const Vector& translationScale);\n\n    public:\n\n        Transform() = default;\n\n        explicit Transform(NoInit_t);\n        explicit Transform(const Matrix& m);\n        explicit Transform(\n            const Quaternion& rotation,\n                const Vector& translation = Vector(0, 0, 0, 0),\n                    float scale = 1.0f);\n        explicit Transform(const AxisAngle& rotation);\n\n        Matrix ToMatrix() const;\n        Matrix ToMatrixNoScale() const;\n        EulerAngles ToEulerAngles() const;\n\n        Vector GetAxisX() const;\n        Vector GetAxisY() const;\n        Vector GetAxisZ() const;\n\n        Vector GetRightVector() const;\n        Vector GetForwardVector() const;\n        Vector GetUpVector() const;\n\n        bool IsIdentity() const;\n        bool IsRigidTransform() const;\n        void MakeRigidTransform();\n\n        //\n        // Inverse and Deltas\n        //\n\n        // Invert this transform.\n        // If you want a delta transform that you can\n        // concatenate, then you should use the 'Delta' functions\n        Transform& Inverse();\n\n        // Get the inverse of this transform.\n        // If you want a delta transform that you can\n        // concatenate, then you should use the 'Delta' functions\n        
Transform GetInverse() const;\n\n        // Return the delta required to a given target\n        // transform (i.e., what do we need to add to reach that transform)\n        Transform GetDeltaToOther(const Transform& targetTransform) const;\n\n        // Return the delta relative from a given start\n        // transform (i.e., how much do we differ from it)\n        Transform GetDeltaFromOther(const Transform& startTransform) const;\n\n        //\n        // Rotation\n        //\n\n        const Quaternion& GetRotation() const;\n        void SetRotation(const Quaternion& rotation);\n        void AddRotation(const Quaternion& delta);\n\n        //\n        // Translation\n        //\n\n        // Get the translation for this transform\n        // NOTE: you cannot rely on the W value as that will be the scale\n        const Vector& GetTranslation() const;\n\n        // Get the translation and scale for this transform\n        const Vector& GetTranslationAndScale() const;\n\n        // Set the translation\n        void SetTranslation(const Vector& newTranslation);\n\n        // Set the translation and scale simultaneously\n        void SetTranslationAndScale(const Vector& newTranslationScale);\n\n        // Add an offset to the current translation\n        void AddTranslation(const Vector& translationDelta);\n\n        // Get the translation as a homogeneous coordinates' vector (W=0)\n        Vector GetTranslationAsVector() const;\n\n        // Get the translation as a homogeneous coordinates' point (W=1)\n        Vector GetTranslationAsPoint() const;\n\n        //\n        // Scale\n        //\n\n        float GetScale() const;\n        Vector GetScaleVector() const;\n        Vector GetInverseScaleVector() const;\n        void SetScale(float uniformScale);\n        bool HasScale() const;\n        bool HasNegativeScale() const;\n\n        // This function will sanitize the scale values to remove any\n        // trailing values from scale factors i.e. 
1.000000012 will be converted to 1\n        // This is primarily needed in import steps where scale values\n        // might be sampled from curves or have multiple conversions applied resulting in variance.\n        void SanitizeScaleValue();\n\n        //\n        // Transformations\n        //\n\n        Vector TranslateVector(const Vector& vector) const;\n        Vector ScaleVector(const Vector& vector) const;\n        Vector TransformPoint(const Vector& vector) const;\n        Vector TransformPointNoScale(const Vector& vector) const;\n\n        // Rotate a vector (same as TransformVectorNoScale)\n        Vector RotateVector(const Vector& vector) const;\n\n        // Rotate a vector (same as TransformVectorNoScale)\n        Vector TransformNormal(const Vector& vector) const;\n\n        // Unrotate a vector (same as InverseTransformVectorNoScale)\n        Vector InverseRotateVector(const Vector& vector) const;\n\n        // Invert the operation order when doing inverse transformation: first translation then rotation then scale\n        Vector InverseTransformPoint(const Vector& point) const;\n\n        // Invert the operation order when doing inverse transformation: first translation then rotation\n        Vector InverseTransformPointNoScale(const Vector& point) const;\n\n        // Applies scale and rotation to a vector (no translation)\n        Vector TransformVector(const Vector& vector) const;\n\n        // Rotate a vector\n        Vector TransformVectorNoScale(const Vector& vector) const;\n\n        // Invert the operation order when performing inverse transformation: first rotation then scale\n        Vector InverseTransformVector(const Vector& vector) const;\n\n        // Unrotate a vector\n        Vector InverseTransformVectorNoScale(const Vector& vector) const;\n\n        // WARNING: The results from multiplying transforms with shear or skew is ill-defined\n        Transform operator*(const Transform& rhs) const;\n\n        // WARNING: The results from 
multiplying transforms with shear or skew is ill-defined\n        Transform& operator*=(const Transform& rhs);\n\n        //\n        // Operators\n        //\n\n        bool IsNearEqual(\n            const Transform& rhs,\n                const Radians angleThreshold = Math::DegreesToRadians,\n                    float translationScaleThreshold = Math::Epsilon) const;\n\n        // Exact equality\n        bool operator==(const Transform& rhs) const;\n\n        bool operator!=(const Transform& rhs) const;\n\n    private:\n\n        Quaternion m_rotation = Quaternion(0, 0, 0, 1);\n        Vector m_translationScale = Vector(0, 0, 0, 1);\n    };\n}\n\n#include \"Transform.inl\"\n"
  },
  {
    "path": "MotionCorrection/src/cpp/Math/Transform.inl",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: Apache-2.0\n */\n\n#pragma once\n\n#include \"Transform.h\"\n\nnamespace Math\n{\n    inline Transform Transform::FromRotation(const Quaternion& rotation)\n    {\n        return Transform(rotation);\n    }\n\n    inline Transform Transform::FromTranslation(const Vector& translation)\n    {\n        return Transform(Quaternion::Identity, translation);\n    }\n\n    inline Transform Transform::FromScale(float uniformScale)\n    {\n        return Transform(Quaternion::Identity, Vector::Zero, uniformScale);\n    }\n\n    inline Transform Transform::FromTranslationAndScale(const Vector& translation, float uniformScale)\n    {\n        return Transform(Quaternion::Identity, translation, uniformScale);\n    }\n\n    inline Transform Transform::FromRotationBetweenVectors(const Vector sourceVector, const Vector targetVector)\n    {\n        return Transform(Quaternion::FromRotationBetweenNormalizedVectors(sourceVector, targetVector));\n    }\n\n    inline Transform Transform::Lerp(const Transform& from, const Transform& to, float t)\n    {\n        Quaternion const rotation = Quaternion::NLerp(Quaternion(from.m_rotation), Quaternion(to.m_rotation), t);\n        Vector const translationAndScale = Vector::Lerp(from.m_translationScale, to.m_translationScale, t);\n\n        Transform lerped(NoInit);\n        lerped.m_rotation = rotation;\n        lerped.m_translationScale = translationAndScale;\n        return lerped;\n    }\n\n    inline Transform Transform::Slerp(const Transform& from, const Transform& to, float t)\n    {\n        Quaternion const rotation = Quaternion::SLerp(Quaternion(from.m_rotation), Quaternion(to.m_rotation), t);\n        Vector const translationAndScale = Vector::Lerp(Vector(from.m_translationScale), Vector(to.m_translationScale), t);\n\n        Transform lerped(NoInit);\n        lerped.m_rotation = 
rotation;\n        lerped.m_translationScale = translationAndScale;\n        return lerped;\n    }\n\n    inline Transform Transform::FastSlerp(const Transform& from, const Transform& to, float t)\n    {\n        Quaternion const rotation = Quaternion::FastSLerp(Quaternion(from.m_rotation), Quaternion(to.m_rotation), t);\n        Vector const translationAndScale = Vector::Lerp(Vector(from.m_translationScale), Vector(to.m_translationScale), t);\n\n        Transform lerped(NoInit);\n        lerped.m_rotation = rotation;\n        lerped.m_translationScale = translationAndScale;\n        return lerped;\n    }\n\n    inline Transform Transform::Delta(const Transform& from, const Transform& to)\n    {\n        ASSERT(from.m_rotation.IsNormalized() && to.m_rotation.IsNormalized());\n        ASSERT(!from.m_translationScale.IsW0() && !to.m_translationScale.IsW0());\n\n        Transform result;\n\n        Vector const inverseScale = from.GetInverseScaleVector();\n        Vector const deltaScale = to.GetScaleVector() * inverseScale;\n\n        // If we have negative scaling, we need to use matrices to calculate the deltas\n        Vector const minScale = Vector::Min(from.m_translationScale.GetSplatW(), to.m_translationScale.GetSplatW());\n        if (minScale.IsAnyLessThan(Vector::Zero))\n        {\n            // Multiply the transforms using matrices to get the correct rotation and then remove the scale;\n            Matrix const toMtx = to.ToMatrix();\n            Matrix const fromMtx = from.ToMatrix();\n            Matrix resultMtx = toMtx * fromMtx.GetInverse();\n            resultMtx.RemoveScaleFast();\n\n            // Apply back the signs from the final scale\n            Vector const sign = deltaScale.GetSign();\n            resultMtx[0] *= sign.GetSplatX();\n            resultMtx[1] *= sign.GetSplatY();\n            resultMtx[2] *= sign.GetSplatZ();\n\n            result.m_rotation = resultMtx.GetRotation();\n            ASSERT(result.m_rotation.IsNormalized());\n   
         result.m_translationScale = Vector::Select(resultMtx.GetTranslation(), deltaScale, Vector::Select0001);\n        }\n        else\n        {\n            Quaternion const fromInverseRotation = from.m_rotation.GetInverse();\n            result.m_rotation = to.m_rotation * fromInverseRotation;\n\n            Vector const deltaTranslation = to.m_translationScale - from.m_translationScale;\n            Vector const translation = fromInverseRotation.RotateVector(deltaTranslation) * inverseScale;\n            result.m_translationScale = Vector::Select(translation, deltaScale, Vector::Select0001);\n        }\n\n        return result;\n    }\n\n    inline Transform Transform::DeltaNoScale(const Transform& from, const Transform& to)\n    {\n        Quaternion const inverseFromRotation = from.m_rotation.GetInverse();\n        Vector const deltaTranslation = to.GetTranslation() - from.GetTranslation();\n\n        Transform delta;\n        delta.m_rotation = to.m_rotation * inverseFromRotation;\n        delta.m_translationScale = inverseFromRotation.RotateVector(deltaTranslation).GetWithW1();\n        return delta;\n    }\n\n    inline void Transform::DirectlySetRotation(Transform& transform, Quaternion&& rotation)\n    {\n        transform.m_rotation = rotation;\n    }\n\n    inline void Transform::DirectlySetRotation(Transform& transform, const Quaternion& rotation)\n    {\n        transform.m_rotation = rotation;\n    }\n\n    inline void Transform::DirectlySetTranslationScale(Transform& transform, Vector&& translationScale)\n    {\n        transform.m_translationScale = translationScale;\n    }\n\n    inline void Transform::DirectlySetTranslationScale(Transform& transform, const Vector& translationScale)\n    {\n        transform.m_translationScale = translationScale;\n    }\n\n    inline Transform::Transform(NoInit_t)\n        : m_rotation(NoInit)\n        , m_translationScale(NoInit)\n    {\n    }\n\n    inline Transform::Transform(const Matrix& m)\n    {\n       
 Vector mTranslation, mScale;\n        m.Decompose(m_rotation, mTranslation, mScale);\n        ASSERT(Math::IsNearEqual(mScale.GetX(), mScale.GetY()) && Math::IsNearEqual(mScale.GetY(),mScale.GetZ()));\n        m_translationScale = Vector::Select(mTranslation, mScale, Vector::Select0001);\n    }\n\n    inline Transform::Transform(const Quaternion& rotation, const Vector& translation, float scale)\n        : m_rotation(rotation)\n        , m_translationScale(Vector::Select(translation, Vector(scale), Vector::Select0001))\n    {\n    }\n\n    inline Transform::Transform(const AxisAngle& rotation)\n        : m_rotation(rotation)\n        , m_translationScale(Vector::UnitW)\n    {\n    }\n\n    inline Matrix Transform::ToMatrix() const\n    {\n        return Matrix(m_rotation, m_translationScale.GetWithW1(), m_translationScale.GetSplatW());\n    }\n\n    inline Matrix Transform::ToMatrixNoScale() const\n    {\n        return Matrix(m_rotation, m_translationScale.GetWithW1(), Vector::One);\n    }\n\n    inline EulerAngles Transform::ToEulerAngles() const\n    {\n        return m_rotation.ToEulerAngles();\n    }\n\n    inline Vector Transform::GetAxisX() const\n    {\n        return m_rotation.RotateVector(Vector::UnitX);\n    }\n\n    inline Vector Transform::GetAxisY() const\n    {\n        return m_rotation.RotateVector(Vector::UnitY);\n    }\n\n    inline Vector Transform::GetAxisZ() const\n    {\n        return m_rotation.RotateVector(Vector::UnitZ);\n    }\n\n    inline Vector Transform::GetRightVector() const\n    {\n        return m_rotation.RotateVector(Vector::WorldRight);\n    }\n\n    inline Vector Transform::GetForwardVector() const\n    {\n        return m_rotation.RotateVector(Vector::WorldForward);\n    }\n\n    inline Vector Transform::GetUpVector() const\n    {\n        return m_rotation.RotateVector(Vector::WorldUp);\n    }\n\n    inline bool Transform::IsIdentity() const\n    {\n        return m_rotation.IsIdentity() && 
m_translationScale.IsEqual4(Vector::UnitW);\n    }\n\n    inline bool Transform::IsRigidTransform() const\n    {\n        return GetScale() == 1.0f;\n    }\n\n    inline void Transform::MakeRigidTransform()\n    {\n        SetScale(1.0f);\n    }\n\n    inline Transform& Transform::Inverse()\n    {\n        ASSERT(!m_translationScale.IsW0());\n\n        Quaternion const inverseRotation = m_rotation.GetInverse();\n        m_rotation = inverseRotation;\n\n        Vector const inverseScale = GetInverseScaleVector();\n        Vector const inverselyScaledTranslation = inverseScale * m_translationScale.GetWithW0();\n        Vector const inverselyRotatedTranslation = inverseRotation.RotateVector(inverselyScaledTranslation);\n        Vector const inverseTranslation = inverselyRotatedTranslation.GetNegated().SetW0();\n\n        m_translationScale = Vector::Select(inverseTranslation, inverseScale, Vector::Select0001);\n\n        return *this;\n    }\n\n    inline Transform Transform::GetInverse() const\n    {\n        Transform inverse = *this;\n        return inverse.Inverse();\n    }\n\n    inline Transform Transform::GetDeltaToOther(const Transform& targetTransform) const\n    {\n        return Transform::Delta(*this, targetTransform);\n    }\n\n    inline Transform Transform::GetDeltaFromOther(const Transform& startTransform) const\n    {\n        return Transform::Delta(startTransform, *this);\n    }\n\n    inline const Quaternion& Transform::GetRotation() const\n    {\n        return m_rotation;\n    }\n\n    inline void Transform::SetRotation(const Quaternion& rotation)\n    {\n        ASSERT(rotation.IsNormalized());\n        m_rotation = rotation;\n    }\n\n    inline void Transform::AddRotation(const Quaternion& delta)\n    {\n        ASSERT(delta.IsNormalized());\n        m_rotation = delta * m_rotation;\n    }\n\n    inline const Vector& Transform::GetTranslation() const\n    {\n        return m_translationScale;\n    }\n\n    inline const Vector& 
Transform::GetTranslationAndScale() const\n    {\n        return m_translationScale;\n    }\n\n    inline void Transform::SetTranslation(const Vector& newTranslation)\n    {\n        m_translationScale = Vector::Select(newTranslation, m_translationScale, Vector::Select0001);\n    }\n\n    inline void Transform::SetTranslationAndScale(const Vector& newTranslationScale)\n    {\n        ASSERT(newTranslationScale.GetW() != 0.0f);\n        m_translationScale = newTranslationScale;\n    }\n\n    inline void Transform::AddTranslation(const Vector& translationDelta)\n    {\n        m_translationScale += translationDelta.GetWithW0();\n    }\n\n    inline Vector Transform::GetTranslationAsVector() const\n    {\n        return m_translationScale.GetWithW0();\n    }\n\n    inline Vector Transform::GetTranslationAsPoint() const\n    {\n        return m_translationScale.GetWithW1();\n    }\n\n    inline float Transform::GetScale() const\n    {\n        return m_translationScale.GetW();\n    }\n\n    inline Vector Transform::GetScaleVector() const\n    {\n        return m_translationScale.GetSplatW();\n    }\n\n    inline Vector Transform::GetInverseScaleVector() const\n    {\n        return m_translationScale.GetSplatW().GetInverse();\n    }\n\n    inline void Transform::SetScale(float uniformScale)\n    {\n        m_translationScale.SetW(uniformScale);\n    }\n\n    inline bool Transform::HasScale() const\n    {\n        return m_translationScale.GetW() != 1.0f;\n    }\n\n    inline bool Transform::HasNegativeScale() const\n    {\n        return m_translationScale.GetW() < 0.0f;\n    }\n\n    inline Vector Transform::TranslateVector(const Vector& vector) const\n    {\n        return vector + m_translationScale.GetWithW0();\n    }\n\n    inline Vector Transform::ScaleVector(const Vector& vector) const\n    {\n        return vector * GetScaleVector();\n    }\n\n    inline Vector Transform::TransformPoint(const Vector& point) const\n    {\n        
ASSERT(!m_translationScale.IsW0());\n        Vector transformedPoint = point * m_translationScale.GetSplatW();\n        transformedPoint = (m_translationScale + m_rotation.RotateVector(transformedPoint)).GetWithW0();\n        return transformedPoint;\n    }\n\n    inline Vector Transform::TransformPointNoScale(const Vector& point) const\n    {\n        Vector transformedPoint = (m_translationScale + m_rotation.RotateVector(point)).GetWithW0();;\n        return transformedPoint;\n    }\n\n    inline Vector Transform::RotateVector(const Vector& vector) const\n    {\n        return m_rotation.RotateVector(vector);\n    }\n\n    inline Vector Transform::TransformNormal(const Vector& vector) const\n    {\n        return RotateVector(vector);\n    }\n\n    inline Vector Transform::InverseRotateVector(const Vector& vector) const\n    {\n        return m_rotation.RotateVectorInverse(vector);\n    }\n\n    inline Vector Transform::InverseTransformPoint(const Vector& point) const\n    {\n        ASSERT(!m_translationScale.IsW0());\n        Vector const shiftedPoint = point - m_translationScale;\n        Vector const unrotatedShiftedPoint = m_rotation.RotateVectorInverse(shiftedPoint);\n        Vector const inverseScale = GetInverseScaleVector();\n        Vector const result = unrotatedShiftedPoint * inverseScale;\n        return result;\n    }\n\n    inline Vector Transform::InverseTransformPointNoScale(const Vector& point) const\n    {\n        Vector const shiftedPoint = point - m_translationScale;\n        Vector const unrotatedShiftedPoint = m_rotation.RotateVectorInverse(shiftedPoint);\n        return unrotatedShiftedPoint;\n    }\n\n    inline Vector Transform::TransformVector(const Vector& vector) const\n    {\n        ASSERT(!m_translationScale.IsW0());\n        Vector transformedVector = vector * GetScaleVector();\n        transformedVector = m_rotation.RotateVector(transformedVector);\n        return transformedVector;\n    }\n\n    inline Vector 
Transform::TransformVectorNoScale(const Vector& vector) const\n    {\n        return RotateVector(vector);\n    }\n\n    inline Vector Transform::InverseTransformVector(const Vector& vector) const\n    {\n        ASSERT(!m_translationScale.IsW0());\n        Vector const unrotatedVector = m_rotation.RotateVectorInverse(vector);\n        Vector const inverseScale = GetInverseScaleVector();\n        Vector const result = unrotatedVector * inverseScale;\n        return result;\n    }\n\n    inline Vector Transform::InverseTransformVectorNoScale(const Vector& vector) const\n    {\n        return m_rotation.RotateVectorInverse(vector);\n    }\n\n    inline Transform Transform::operator*(const Transform& rhs) const\n    {\n        Transform transform = *this;\n        transform *= rhs;\n        return transform;\n    }\n\n    inline Transform& Transform::operator*=(const Transform& rhs)\n    {\n        Vector const scale = GetScaleVector();\n        Vector const rhsScale = rhs.GetScaleVector();\n        Vector const minScale = Vector::Min(scale, rhsScale);\n        Vector const finalScale = scale * rhsScale;\n\n        if (minScale.IsAnyLessThan(Vector::Zero))\n        {\n            // Multiply the transforms using matrices to\n            // get the correct rotation and then remove the scale;\n            Matrix const lhsMtx = ToMatrix();\n            Matrix const rhsMtx = rhs.ToMatrix();\n            Matrix resultMtx = lhsMtx * rhsMtx;\n            resultMtx.RemoveScaleFast();\n\n            // Apply back the signs from the final scale\n            Vector const sign = finalScale.GetSign();\n            resultMtx[0] *= sign.GetSplatX();\n            resultMtx[1] *= sign.GetSplatY();\n            resultMtx[2] *= sign.GetSplatZ();\n\n            m_rotation = resultMtx.GetRotation();\n            ASSERT(m_rotation.IsNormalized());\n            m_translationScale = Vector::Select(resultMtx.GetTranslation(), finalScale, Vector::Select0001);\n        }\n        else\n        
{\n            // Normal case\n            m_rotation = m_rotation * rhs.m_rotation;\n            m_rotation.Normalize();\n            Vector const translation = rhs.m_rotation.RotateVector(m_translationScale * rhsScale) + rhs.m_translationScale;\n            m_translationScale = Vector::Select(translation, finalScale, Vector::Select0001);\n        }\n\n        return *this;\n    }\n\n    inline bool Transform::IsNearEqual(const Transform& rhs, const Radians angleThreshold, float translationScaleThreshold) const\n    {\n        if (!m_rotation.IsNearEqual(rhs.m_rotation, angleThreshold))\n        {\n            return false;\n        }\n\n        if (!m_translationScale.IsNearEqual4(rhs.m_translationScale, translationScaleThreshold))\n        {\n            return false;\n        }\n\n        return true;\n    }\n\n    inline bool Transform::operator==(const Transform& rhs) const\n    {\n        if (m_translationScale != rhs.m_translationScale)\n        {\n            return false;\n        }\n\n        if (m_rotation != rhs.m_rotation)\n        {\n            return false;\n        }\n\n        return true;\n    }\n\n    inline bool Transform::operator!=(const Transform& rhs) const\n    {\n        return !operator==(rhs);\n    }\n}\n"
  },
  {
    "path": "MotionCorrection/src/cpp/Math/Types.cpp",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: Apache-2.0\n */\n\n#include \"Types.h\"\n\nInt2 const Int2::Zero = Int2( 0, 0 );\n\nInt4 const Int4::Zero = Int4( 0, 0, 0, 0 );\nInt4 const Int4::MinusOne = Int4( -1, -1, -1, -1 );\n\nFloat2 const Float2::Zero = Float2( 0, 0 );\nFloat2 const Float2::One = Float2( 1, 1 );\nFloat2 const Float2::UnitX = Float2( 1, 0 );\nFloat2 const Float2::UnitY = Float2( 0, 1 );\n\nFloat3 const Float3::Zero = Float3( 0, 0, 0 );\nFloat3 const Float3::One = Float3( 1, 1, 1 );\nFloat3 const Float3::UnitX = Float3( 1, 0, 0 );\nFloat3 const Float3::UnitY = Float3( 0, 1, 0 );\nFloat3 const Float3::UnitZ = Float3( 0, 0, 1 );\n\nFloat3 const Float3::WorldForward = Float3( 0, -1, 0 );\nFloat3 const Float3::WorldUp = Float3( 0, 0, 1 );\nFloat3 const Float3::WorldRight = Float3( -1, 0, 0 );\n\nFloat4 const Float4::Zero = Float4( 0, 0, 0, 0 );\nFloat4 const Float4::One = Float4( 1, 1, 1, 1 );\nFloat4 const Float4::UnitX = Float4( 1, 0, 0, 0 );\nFloat4 const Float4::UnitY = Float4( 0, 1, 0, 0 );\nFloat4 const Float4::UnitZ = Float4( 0, 0, 1, 0 );\nFloat4 const Float4::UnitW = Float4( 0, 0, 0, 1 );\n\nFloat4 const Float4::WorldForward = Float4( 0, -1, 0, 0 );\nFloat4 const Float4::WorldUp = Float4( 0, 0, 1, 0 );\nFloat4 const Float4::WorldRight = Float4( -1, 0, 0, 0 );\n\nRadians const Radians::Pi = Radians( Math::Pi );\nRadians const Radians::TwoPi = Radians( Math::TwoPi );\nRadians const Radians::OneDivPi = Radians( Math::OneDivPi );\nRadians const Radians::OneDivTwoPi = Radians( Math::OneDivTwoPi );\nRadians const Radians::PiDivTwo = Radians( Math::PiDivTwo );\nRadians const Radians::PiDivFour = Radians( Math::PiDivFour );\n"
  },
  {
    "path": "MotionCorrection/src/cpp/Math/Types.h",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: Apache-2.0\n */\n\n#pragma once\n\n#include \"Scalar.h\"\n\nenum NoInit_t { NoInit };\nenum ZeroInit_t { ZeroInit };\nenum IdentityInit_t { IdentityInit };\n\nenum class Axis : uint8_t\n{\n    X = 0,\n    Y,\n    Z,\n    NegX,\n    NegY,\n    NegZ\n};\n\nstruct Float2;\nstruct Float3;\nstruct Float4;\n\nstruct Int2\n{\n    static Int2 const Zero;\n\npublic:\n\n    inline Int2() {}\n    inline Int2( ZeroInit_t ) : m_x( 0 ), m_y( 0 ) {}\n    inline Int2( Float2 const& v );\n    inline explicit Int2( int32_t v ) : m_x( v ), m_y( v ) {}\n    inline explicit Int2( int32_t ix, int32_t iy ) : m_x( ix ), m_y( iy ) {}\n\n    inline bool IsZero() const { return *this == Zero; }\n\n    inline int32_t& operator[]( uint32_t i ) { return ( (int32_t*) this )[i]; }\n    inline int32_t const& operator[]( uint32_t i ) const { return ( (int32_t*) this )[i]; }\n\n    inline bool operator==( Int2 const rhs ) const { return m_x == rhs.m_x && m_y == rhs.m_y; }\n    inline bool operator!=( Int2 const rhs ) const { return m_x != rhs.m_x || m_y != rhs.m_y; }\n\n    inline Int2 operator+( Int2 const& rhs ) const { return Int2( m_x + rhs.m_x, m_y + rhs.m_y ); }\n    inline Int2 operator-( Int2 const& rhs ) const { return Int2( m_x - rhs.m_x, m_y - rhs.m_y ); }\n    inline Int2 operator*( Int2 const& rhs ) const { return Int2( m_x * rhs.m_x, m_y * rhs.m_y ); }\n    inline Int2 operator/( Int2 const& rhs ) const { return Int2( m_x / rhs.m_x, m_y / rhs.m_y ); }\n\n    inline Int2& operator+=( int32_t const& rhs ) { m_x += rhs; m_y += rhs; return *this; }\n    inline Int2& operator-=( int32_t const& rhs ) { m_x -= rhs; m_y -= rhs; return *this; }\n    inline Int2& operator*=( int32_t const& rhs ) { m_x *= rhs; m_y *= rhs; return *this; }\n    inline Int2& operator/=( int32_t const& rhs ) { m_x /= rhs; m_y /= rhs; return *this; }\n\n    // Component 
wise operation\n    inline Int2 operator+( int32_t const& rhs ) const { return Int2( m_x + rhs, m_y + rhs ); }\n    inline Int2 operator-( int32_t const& rhs ) const { return Int2( m_x - rhs, m_y - rhs ); }\n    inline Int2 operator*( int32_t const& rhs ) const { return Int2( m_x * rhs, m_y * rhs ); }\n    inline Int2 operator/( int32_t const& rhs ) const { return Int2( m_x / rhs, m_y / rhs ); }\n\n    inline Int2& operator+=( Int2 const& rhs ) { m_x += rhs.m_x; m_y += rhs.m_y; return *this; }\n    inline Int2& operator-=( Int2 const& rhs ) { m_x -= rhs.m_x; m_y -= rhs.m_y; return *this; }\n    inline Int2& operator*=( Int2 const& rhs ) { m_x *= rhs.m_x; m_y *= rhs.m_y; return *this; }\n    inline Int2& operator/=( Int2 const& rhs ) { m_x /= rhs.m_x; m_y /= rhs.m_y; return *this; }\n\npublic:\n\n    int32_t m_x, m_y;\n};\n\nstruct Int3\n{\n    static Int3 const Zero;\n\npublic:\n\n    inline Int3() {}\n    inline Int3( ZeroInit_t ) : m_x( 0 ), m_y( 0 ), m_z( 0 ) {}\n    inline Int3( Float3 const& v );\n    inline explicit Int3( int32_t v ) : m_x( v ), m_y( v ), m_z( v ) {}\n    inline explicit Int3( int32_t ix, int32_t iy, int32_t iz ) : m_x( ix ), m_y( iy ), m_z( iz ) {}\n\n    inline bool IsZero() const { return *this == Zero; }\n\n    inline int32_t& operator[]( uint32_t i ) { return ( (int32_t*) this )[i]; }\n    inline int32_t const& operator[]( uint32_t i ) const { return ( (int32_t*) this )[i]; }\n\n    inline bool operator==( Int3 const rhs ) const { return m_x == rhs.m_x && m_y == rhs.m_y && m_z == rhs.m_z; }\n    inline bool operator!=( Int3 const rhs ) const { return m_x != rhs.m_x || m_y != rhs.m_y || m_z != rhs.m_z; }\n\n    inline Int3 operator+( Int3 const& rhs ) const { return Int3( m_x + rhs.m_x, m_y + rhs.m_y, m_z + rhs.m_z ); }\n    inline Int3 operator-( Int3 const& rhs ) const { return Int3( m_x - rhs.m_x, m_y - rhs.m_y, m_z - rhs.m_z ); }\n    inline Int3 operator*( Int3 const& rhs ) const { return Int3( m_x * rhs.m_x, m_y * rhs.m_y, m_z * 
rhs.m_z ); }\n    inline Int3 operator/( Int3 const& rhs ) const { return Int3( m_x / rhs.m_x, m_y / rhs.m_y, m_z / rhs.m_z ); }\n\n    inline Int3& operator+=( int32_t const& rhs ) { m_x += rhs; m_y += rhs; m_z += rhs; return *this; }\n    inline Int3& operator-=( int32_t const& rhs ) { m_x -= rhs; m_y -= rhs; m_z -= rhs; return *this; }\n    inline Int3& operator*=( int32_t const& rhs ) { m_x *= rhs; m_y *= rhs; m_z *= rhs; return *this; }\n    inline Int3& operator/=( int32_t const& rhs ) { m_x /= rhs; m_y /= rhs; m_z /= rhs; return *this; }\n\n    // Component wise operation\n    inline Int3 operator+( int32_t const& rhs ) const { return Int3( m_x + rhs, m_y + rhs, m_z + rhs ); }\n    inline Int3 operator-( int32_t const& rhs ) const { return Int3( m_x - rhs, m_y - rhs, m_z - rhs ); }\n    inline Int3 operator*( int32_t const& rhs ) const { return Int3( m_x * rhs, m_y * rhs, m_z * rhs ); }\n    inline Int3 operator/( int32_t const& rhs ) const { return Int3( m_x / rhs, m_y / rhs, m_z / rhs ); }\n\n    inline Int3& operator+=( Int3 const& rhs ) { m_x += rhs.m_x; m_y += rhs.m_y; m_z += rhs.m_z; return *this; }\n    inline Int3& operator-=( Int3 const& rhs ) { m_x -= rhs.m_x; m_y -= rhs.m_y; m_z -= rhs.m_z; return *this; }\n    inline Int3& operator*=( Int3 const& rhs ) { m_x *= rhs.m_x; m_y *= rhs.m_y; m_z *= rhs.m_z; return *this; }\n    inline Int3& operator/=( Int3 const& rhs ) { m_x /= rhs.m_x; m_y /= rhs.m_y; m_z /= rhs.m_z; return *this; }\n\npublic:\n\n    int32_t m_x, m_y, m_z;\n};\n\nstruct Int4\n{\n    static Int4 const Zero;\n    static Int4 const MinusOne;\n\npublic:\n\n    inline Int4() {}\n    inline Int4( ZeroInit_t ) : m_x( 0 ), m_y( 0 ), m_z( 0 ), m_w( 0 ) {}\n    inline explicit Int4( int32_t v ) : m_x( v ), m_y( v ), m_z( v ), m_w( v ) {}\n    inline explicit Int4( int32_t ix, int32_t iy, int32_t iz, int32_t iw ) : m_x( ix ), m_y( iy ), m_z( iz ), m_w( iw ) {}\n\n    inline bool IsZero() const { return *this == Zero; }\n\n    inline int32_t& 
operator[]( uint32_t i ) { return ( (int32_t*) this )[i]; }\n    inline int32_t const& operator[]( uint32_t i ) const { return ( (int32_t*) this )[i]; }\n\n    inline bool operator==( Int4 const rhs ) const { return m_x == rhs.m_x && m_y == rhs.m_y && m_z == rhs.m_z && m_w == rhs.m_w; }\n    inline bool operator!=( Int4 const rhs ) const { return m_x != rhs.m_x || m_y != rhs.m_y || m_z != rhs.m_z || m_w != rhs.m_w; }\n\n    inline Int4 operator+( int32_t const& rhs ) const { return Int4( m_x + rhs, m_y + rhs, m_z + rhs, m_w + rhs ); }\n    inline Int4 operator-( int32_t const& rhs ) const { return Int4( m_x - rhs, m_y - rhs, m_z - rhs, m_w - rhs ); }\n    inline Int4 operator*( int32_t const& rhs ) const { return Int4( m_x * rhs, m_y * rhs, m_z * rhs, m_w * rhs ); }\n    inline Int4 operator/( int32_t const& rhs ) const { return Int4( m_x / rhs, m_y / rhs, m_z / rhs, m_w / rhs ); }\n\n    inline Int4& operator+=( int32_t const& rhs ) { m_x += rhs; m_y += rhs; m_z += rhs; m_w += rhs; return *this; }\n    inline Int4& operator-=( int32_t const& rhs ) { m_x -= rhs; m_y -= rhs; m_z -= rhs; m_w -= rhs; return *this; }\n    inline Int4& operator*=( int32_t const& rhs ) { m_x *= rhs; m_y *= rhs; m_z *= rhs; m_w *= rhs; return *this; }\n    inline Int4& operator/=( int32_t const& rhs ) { m_x /= rhs; m_y /= rhs; m_z /= rhs; m_w /= rhs; return *this; }\n\n    // Component wise operation\n    inline Int4 operator+( Int4 const& rhs ) const { return Int4( m_x + rhs.m_x, m_y + rhs.m_y, m_z + rhs.m_z, m_w + rhs.m_w ); }\n    inline Int4 operator-( Int4 const& rhs ) const { return Int4( m_x - rhs.m_x, m_y - rhs.m_y, m_z - rhs.m_z, m_w - rhs.m_w ); }\n    inline Int4 operator*( Int4 const& rhs ) const { return Int4( m_x * rhs.m_x, m_y * rhs.m_y, m_z * rhs.m_z, m_w * rhs.m_w ); }\n    inline Int4 operator/( Int4 const& rhs ) const { return Int4( m_x / rhs.m_x, m_y / rhs.m_y, m_z / rhs.m_z, m_w / rhs.m_w ); }\n\n    inline Int4& operator+=( Int4 const& rhs ) { m_x += rhs.m_x; m_y += 
rhs.m_y; m_z += rhs.m_z; m_w += rhs.m_w; return *this; }\n    inline Int4& operator-=( Int4 const& rhs ) { m_x -= rhs.m_x; m_y -= rhs.m_y; m_z -= rhs.m_z; m_w -= rhs.m_w; return *this; }\n    inline Int4& operator*=( Int4 const& rhs ) { m_x *= rhs.m_x; m_y *= rhs.m_y; m_z *= rhs.m_z; m_w *= rhs.m_w; return *this; }\n    inline Int4& operator/=( Int4 const& rhs ) { m_x /= rhs.m_x; m_y /= rhs.m_y; m_z /= rhs.m_z; m_w /= rhs.m_w; return *this; }\n\npublic:\n\n    int32_t m_x, m_y, m_z, m_w;\n};\n\nstruct Float2\n{\n    static Float2 const Zero;\n    static Float2 const One;\n    static Float2 const UnitX;\n    static Float2 const UnitY;\n\npublic:\n\n    inline Float2() {}\n    FORCE_INLINE Float2( ZeroInit_t ) : m_x( 0 ), m_y( 0 ) {}\n    FORCE_INLINE explicit Float2( float v ) : m_x( v ), m_y( v ) {}\n    FORCE_INLINE explicit Float2( float ix, float iy ) : m_x( ix ), m_y( iy ) {}\n    FORCE_INLINE explicit Float2( int32_t ix, int32_t iy ) : m_x( (float) ix ), m_y( (float) iy ) {}\n    inline explicit Float2( Int2 const& v ) : m_x( (float) v.m_x ), m_y( (float) v.m_y ) {}\n    inline explicit Float2( Float3 const& v );\n    inline explicit Float2( Float4 const& v );\n\n    inline bool IsZero() const { return *this == Zero; }\n\n    inline float& operator[]( uint32_t i ) { return ( (float*) this )[i]; }\n    inline float const& operator[]( uint32_t i ) const { return ( (float*) this )[i]; }\n\n    FORCE_INLINE Float2 operator-() const { return Float2( -m_x, -m_y ); }\n\n    inline bool operator==( Float2 const rhs ) const { return m_x == rhs.m_x && m_y == rhs.m_y; }\n    inline bool operator!=( Float2 const rhs ) const { return m_x != rhs.m_x || m_y != rhs.m_y; }\n\n    inline Float2 operator+( Float2 const& rhs ) const { return Float2( m_x + rhs.m_x, m_y + rhs.m_y ); }\n    inline Float2 operator-( Float2 const& rhs ) const { return Float2( m_x - rhs.m_x, m_y - rhs.m_y ); }\n    inline Float2 operator*( Float2 const& rhs ) const { return Float2( m_x * rhs.m_x, m_y * 
rhs.m_y ); }\n    inline Float2 operator/( Float2 const& rhs ) const { return Float2( m_x / rhs.m_x, m_y / rhs.m_y ); }\n\n    inline Float2 operator+( float const& rhs ) const { return Float2( m_x + rhs, m_y + rhs ); }\n    inline Float2 operator-( float const& rhs ) const { return Float2( m_x - rhs, m_y - rhs ); }\n    inline Float2 operator*( float const& rhs ) const { return Float2( m_x * rhs, m_y * rhs ); }\n    inline Float2 operator/( float const& rhs ) const { return Float2( m_x / rhs, m_y / rhs ); }\n\n    inline Float2& operator+=( Float2 const& rhs ) { m_x += rhs.m_x; m_y += rhs.m_y; return *this; }\n    inline Float2& operator-=( Float2 const& rhs ) { m_x -= rhs.m_x; m_y -= rhs.m_y; return *this; }\n    inline Float2& operator*=( Float2 const& rhs ) { m_x *= rhs.m_x; m_y *= rhs.m_y; return *this; }\n    inline Float2& operator/=( Float2 const& rhs ) { m_x /= rhs.m_x; m_y /= rhs.m_y; return *this; }\n\n    inline Float2& operator+=( float const& rhs ) { m_x += rhs; m_y += rhs; return *this; }\n    inline Float2& operator-=( float const& rhs ) { m_x -= rhs; m_y -= rhs; return *this; }\n    inline Float2& operator*=( float const& rhs ) { m_x *= rhs; m_y *= rhs; return *this; }\n    inline Float2& operator/=( float const& rhs ) { m_x /= rhs; m_y /= rhs; return *this; }\n\n    float m_x, m_y;\n};\n\nstruct Float3\n{\n    static Float3 const Zero;\n    static Float3 const One;\n    static Float3 const UnitX;\n    static Float3 const UnitY;\n    static Float3 const UnitZ;\n\n    static Float3 const WorldForward;\n    static Float3 const WorldUp;\n    static Float3 const WorldRight;\n\npublic:\n\n    inline Float3() {}\n    FORCE_INLINE Float3( ZeroInit_t ) : m_x( 0 ), m_y( 0 ), m_z( 0 ) {}\n    FORCE_INLINE explicit Float3( float v ) : m_x( v ), m_y( v ), m_z( v ) {}\n    FORCE_INLINE explicit Float3( float ix, float iy, float iz ) : m_x( ix ), m_y( iy ), m_z( iz ) {}\n    inline explicit Float3( Float2 const& v, float iz = 0.0f ) : m_x( v.m_x ), m_y( v.m_y ), 
m_z( iz ) {}\n    inline explicit Float3( Float4 const& v );\n\n    inline bool IsZero() const { return *this == Zero; }\n\n    inline float& operator[]( uint32_t i ) { return ( (float*) this )[i]; }\n    inline float const& operator[]( uint32_t i ) const { return ( (float*) this )[i]; }\n\n    FORCE_INLINE Float3 operator-() const { return Float3( -m_x, -m_y, -m_z ); }\n\n    inline bool operator==( Float3 const rhs ) const { return m_x == rhs.m_x && m_y == rhs.m_y && m_z == rhs.m_z; }\n    inline bool operator!=( Float3 const rhs ) const { return m_x != rhs.m_x || m_y != rhs.m_y || m_z != rhs.m_z; }\n\n    inline operator Float2() const { return Float2( m_x, m_y ); }\n\n    inline Float3 operator+( Float3 const& rhs ) const { return Float3( m_x + rhs.m_x, m_y + rhs.m_y, m_z + rhs.m_z ); }\n    inline Float3 operator-( Float3 const& rhs ) const { return Float3( m_x - rhs.m_x, m_y - rhs.m_y, m_z - rhs.m_z ); }\n    inline Float3 operator*( Float3 const& rhs ) const { return Float3( m_x * rhs.m_x, m_y * rhs.m_y, m_z * rhs.m_z ); }\n    inline Float3 operator/( Float3 const& rhs ) const { return Float3( m_x / rhs.m_x, m_y / rhs.m_y, m_z / rhs.m_z ); }\n\n    inline Float3 operator+( float const& rhs ) const { return Float3( m_x + rhs, m_y + rhs, m_z + rhs ); }\n    inline Float3 operator-( float const& rhs ) const { return Float3( m_x - rhs, m_y - rhs, m_z - rhs ); }\n    inline Float3 operator*( float const& rhs ) const { return Float3( m_x * rhs, m_y * rhs, m_z * rhs ); }\n    inline Float3 operator/( float const& rhs ) const { return Float3( m_x / rhs, m_y / rhs, m_z / rhs ); }\n\n    inline Float3& operator+=( Float3 const& rhs ) { m_x += rhs.m_x; m_y += rhs.m_y; m_z += rhs.m_z; return *this; }\n    inline Float3& operator-=( Float3 const& rhs ) { m_x -= rhs.m_x; m_y -= rhs.m_y; m_z -= rhs.m_z; return *this; }\n    inline Float3& operator*=( Float3 const& rhs ) { m_x *= rhs.m_x; m_y *= rhs.m_y; m_z *= rhs.m_z; return *this; }\n    inline Float3& operator/=( 
Float3 const& rhs ) { m_x /= rhs.m_x; m_y /= rhs.m_y; m_z /= rhs.m_z; return *this; }\n\n    inline Float3& operator+=( float const& rhs ) { m_x += rhs; m_y += rhs; m_z += rhs; return *this; }\n    inline Float3& operator-=( float const& rhs ) { m_x -= rhs; m_y -= rhs; m_z -= rhs; return *this; }\n    inline Float3& operator*=( float const& rhs ) { m_x *= rhs; m_y *= rhs; m_z *= rhs; return *this; }\n    inline Float3& operator/=( float const& rhs ) { m_x /= rhs; m_y /= rhs; m_z /= rhs; return *this; }\n\n    float m_x, m_y, m_z;\n};\n\nstruct Float4\n{\n    static Float4 const Zero;\n    static Float4 const One;\n    static Float4 const UnitX;\n    static Float4 const UnitY;\n    static Float4 const UnitZ;\n    static Float4 const UnitW;\n\n    static Float4 const WorldForward;\n    static Float4 const WorldUp;\n    static Float4 const WorldRight;\n\npublic:\n\n    Float4() {}\n    FORCE_INLINE Float4( ZeroInit_t ) : m_x( 0 ), m_y( 0 ), m_z( 0 ), m_w( 0 ) {}\n    FORCE_INLINE explicit Float4( float v ) : m_x( v ), m_y( v ), m_z( v ), m_w( v ) {}\n    FORCE_INLINE explicit Float4( float ix, float iy, float iz, float iw ) : m_x( ix ), m_y( iy ), m_z( iz ), m_w( iw ) {}\n    explicit Float4( Float2 const& v, float iz = 0.0f, float iw = 0.0f ) : m_x( v.m_x ), m_y( v.m_y ), m_z( iz ), m_w( iw ) {}\n    explicit Float4( Float3 const& v, float iw = 0.0f ) : m_x( v.m_x ), m_y( v.m_y ), m_z( v.m_z ), m_w( iw ) {}\n\n    inline bool IsZero() const { return *this == Zero; }\n\n    float& operator[]( uint32_t i ) { return ( (float*) this )[i]; }\n    float const& operator[]( uint32_t i ) const { return ( (float*) this )[i]; }\n\n    FORCE_INLINE Float4 operator-() const { return Float4( -m_x, -m_y, -m_z, -m_w ); }\n\n    bool operator==( Float4 const rhs ) const { return m_x == rhs.m_x && m_y == rhs.m_y && m_z == rhs.m_z && m_w == rhs.m_w; }\n    bool operator!=( Float4 const rhs ) const { return m_x != rhs.m_x || m_y != rhs.m_y || m_z != rhs.m_z || m_w != rhs.m_w; }\n\n    
inline operator Float2() const { return Float2( m_x, m_y ); }\n    inline operator Float3() const { return Float3( m_x, m_y, m_z ); }\n\n    inline Float4 operator+( Float4 const& rhs ) const { return Float4( m_x + rhs.m_x, m_y + rhs.m_y, m_z + rhs.m_z, m_w + rhs.m_w ); }\n    inline Float4 operator-( Float4 const& rhs ) const { return Float4( m_x - rhs.m_x, m_y - rhs.m_y, m_z - rhs.m_z, m_w - rhs.m_w ); }\n    inline Float4 operator*( Float4 const& rhs ) const { return Float4( m_x * rhs.m_x, m_y * rhs.m_y, m_z * rhs.m_z, m_w * rhs.m_w ); }\n    inline Float4 operator/( Float4 const& rhs ) const { return Float4( m_x / rhs.m_x, m_y / rhs.m_y, m_z / rhs.m_z, m_w / rhs.m_w ); }\n\n    inline Float4 operator+( float const& rhs ) const { return Float4( m_x + rhs, m_y + rhs, m_z + rhs, m_w + rhs ); }\n    inline Float4 operator-( float const& rhs ) const { return Float4( m_x - rhs, m_y - rhs, m_z - rhs, m_w - rhs ); }\n    inline Float4 operator*( float const& rhs ) const { return Float4( m_x * rhs, m_y * rhs, m_z * rhs, m_w * rhs ); }\n    inline Float4 operator/( float const& rhs ) const { return Float4( m_x / rhs, m_y / rhs, m_z / rhs, m_w / rhs ); }\n\n    inline Float4& operator+=( Float4 const& rhs ) { m_x += rhs.m_x; m_y += rhs.m_y; m_z += rhs.m_z; m_w += rhs.m_w; return *this; }\n    inline Float4& operator-=( Float4 const& rhs ) { m_x -= rhs.m_x; m_y -= rhs.m_y; m_z -= rhs.m_z; m_w -= rhs.m_w; return *this; }\n    inline Float4& operator*=( Float4 const& rhs ) { m_x *= rhs.m_x; m_y *= rhs.m_y; m_z *= rhs.m_z; m_w *= rhs.m_w; return *this; }\n    inline Float4& operator/=( Float4 const& rhs ) { m_x /= rhs.m_x; m_y /= rhs.m_y; m_z /= rhs.m_z; m_w /= rhs.m_w; return *this; }\n\n    inline Float4& operator+=( float const& rhs ) { m_x += rhs; m_y += rhs; m_z += rhs; m_w += rhs; return *this; }\n    inline Float4& operator-=( float const& rhs ) { m_x -= rhs; m_y -= rhs; m_z -= rhs; m_w -= rhs; return *this; }\n    inline Float4& operator*=( float const& rhs ) { m_x *= 
rhs; m_y *= rhs; m_z *= rhs; m_w *= rhs; return *this; }\n    inline Float4& operator/=( float const& rhs ) { m_x /= rhs; m_y /= rhs; m_z /= rhs; m_w /= rhs; return *this; }\n\n    float m_x, m_y, m_z, m_w;\n};\n\ninline Int2::Int2( Float2 const& v )\n    : m_x( (int32_t) v.m_x )\n    , m_y( (int32_t) v.m_y )\n{\n}\n\ninline Int3::Int3( Float3 const& v )\n    : m_x( (int32_t) v.m_x )\n    , m_y( (int32_t) v.m_y )\n    , m_z( (int32_t) v.m_z )\n{\n}\n\ninline Float2::Float2( Float3 const& v )\n    : m_x( v.m_x )\n    , m_y( v.m_y )\n{\n}\n\ninline Float2::Float2( Float4 const& v )\n    : m_x( v.m_x )\n    , m_y( v.m_y )\n{\n}\n\ninline Float3::Float3( Float4 const& v )\n    : m_x( v.m_x )\n    , m_y( v.m_y )\n    , m_z( v.m_z )\n{\n}\n\nstruct Radians;\nstruct Degrees;\n\nstruct Degrees\n{\npublic:\n\n    inline Degrees() = default;\n    inline Degrees( float degrees ) : m_value( degrees ) {}\n    inline explicit Degrees( Radians const& radians );\n\n    FORCE_INLINE explicit operator float() const { return m_value; }\n    FORCE_INLINE operator Radians() const;\n    FORCE_INLINE float ToFloat() const { return m_value; }\n    FORCE_INLINE Radians ToRadians() const;\n\n    inline Degrees operator-() const { return Degrees( -m_value ); }\n\n    inline Degrees operator+( Degrees const& rhs ) const { return Degrees( m_value + rhs.m_value ); }\n    inline Degrees operator-( Degrees const& rhs ) const { return Degrees( m_value - rhs.m_value ); }\n    inline Degrees operator*( Degrees const& rhs ) const { return Degrees( m_value * rhs.m_value ); }\n    inline Degrees operator/( Degrees const& rhs ) const { return Degrees( m_value / rhs.m_value ); }\n\n    inline Degrees& operator+=( Degrees const& rhs ) { m_value += rhs.m_value; return *this; }\n    inline Degrees& operator-=( Degrees const& rhs ) { m_value -= rhs.m_value; return *this; }\n    inline Degrees& operator*=( Degrees const& rhs ) { m_value *= rhs.m_value; return *this; }\n    inline Degrees& operator/=( Degrees 
const& rhs ) { m_value /= rhs.m_value; return *this; }\n\n    inline Degrees operator+( float const& rhs ) const { return Degrees( m_value + rhs ); }\n    inline Degrees operator-( float const& rhs ) const { return Degrees( m_value - rhs ); }\n    inline Degrees operator*( float const& rhs ) const { return Degrees( m_value * rhs ); }\n    inline Degrees operator/( float const& rhs ) const { return Degrees( m_value / rhs ); }\n\n    inline Degrees& operator+=( float const& rhs ) { m_value += rhs; return *this; }\n    inline Degrees& operator-=( float const& rhs ) { m_value -= rhs; return *this; }\n    inline Degrees& operator*=( float const& rhs ) { m_value *= rhs; return *this; }\n    inline Degrees& operator/=( float const& rhs ) { m_value /= rhs; return *this; }\n\n    inline Degrees operator+( int32_t const& rhs ) const { return Degrees( m_value + rhs ); }\n    inline Degrees operator-( int32_t const& rhs ) const { return Degrees( m_value - rhs ); }\n    inline Degrees operator*( int32_t const& rhs ) const { return Degrees( m_value * rhs ); }\n    inline Degrees operator/( int32_t const& rhs ) const { return Degrees( m_value / rhs ); }\n\n    inline Degrees& operator+=( int32_t const& rhs ) { m_value += rhs; return *this; }\n    inline Degrees& operator-=( int32_t const& rhs ) { m_value -= rhs; return *this; }\n    inline Degrees& operator*=( int32_t const& rhs ) { m_value *= rhs; return *this; }\n    inline Degrees& operator/=( int32_t const& rhs ) { m_value /= rhs; return *this; }\n\n    inline Degrees operator+( uint32_t const& rhs ) const { return Degrees( m_value + rhs ); }\n    inline Degrees operator-( uint32_t const& rhs ) const { return Degrees( m_value - rhs ); }\n    inline Degrees operator*( uint32_t const& rhs ) const { return Degrees( m_value * rhs ); }\n    inline Degrees operator/( uint32_t const& rhs ) const { return Degrees( m_value / rhs ); }\n\n    inline Degrees& operator+=( uint32_t const& rhs ) { m_value += rhs; return *this; }\n    inline 
Degrees& operator-=( uint32_t const& rhs ) { m_value -= rhs; return *this; }\n    inline Degrees& operator*=( uint32_t const& rhs ) { m_value *= rhs; return *this; }\n    inline Degrees& operator/=( uint32_t const& rhs ) { m_value /= rhs; return *this; }\n\n    inline bool operator>( float const& rhs ) const { return m_value > rhs; };\n    inline bool operator<( float const& rhs ) const { return m_value < rhs; }\n    inline bool operator>=( float const& rhs ) const { return m_value >= rhs; }\n    inline bool operator<=( float const& rhs ) const { return m_value <= rhs; }\n\n    inline bool operator>( Degrees const& rhs ) const { return m_value > rhs.m_value; }\n    inline bool operator<( Degrees const& rhs ) const { return m_value < rhs.m_value; }\n    inline bool operator>=( Degrees const& rhs ) const { return m_value >= rhs.m_value; }\n    inline bool operator<=( Degrees const& rhs ) const { return m_value <= rhs.m_value; }\n\n    inline bool operator>( Radians const& rhs ) const;\n    inline bool operator<( Radians const& rhs ) const;\n    inline bool operator>=( Radians const& rhs ) const;\n    inline bool operator<=( Radians const& rhs ) const;\n\n    inline bool operator==( float const& v ) const { return Math::IsNearEqual( m_value, v ); }\n    inline bool operator!=( float const& v ) const { return !Math::IsNearEqual( m_value, v ); }\n\n    inline bool operator==( Degrees const& rhs ) const  { return m_value == rhs.m_value; }\n    inline bool operator!=( Degrees const& rhs ) const  { return m_value != rhs.m_value; }\n\n    inline bool operator==( Radians const& rhs ) const;\n    inline bool operator!=( Radians const& rhs ) const;\n\n    inline void Clamp( Degrees min, Degrees max )\n    {\n        m_value = Math::Clamp( m_value, min.m_value, max.m_value );\n    }\n\n    // Clamps between -360 and 360\n    inline void Clamp360()\n    {\n        m_value -= ( int32_t( m_value / 360.0f ) * 360.0f );\n    }\n\n    // Clamps between -360 and 360\n    inline 
Degrees GetClamped360() const\n    {\n        Degrees d( m_value );\n        d.Clamp360();\n        return d;\n    }\n\n    // Clamps to -180 to 180\n    inline void Clamp180()\n    {\n        Clamp360();\n\n        float delta = 180 - Math::Abs( m_value );\n        if ( delta < 0 )\n        {\n            delta += 180;\n            m_value = ( m_value < 0 ) ? delta : -delta;\n        }\n    }\n\n    // Clamps to -180 to 180\n    inline Degrees GetClamped180() const\n    {\n        Degrees r( m_value );\n        r.Clamp180();\n        return r;\n    }\n\n    // Clamps between 0 to 360\n    inline Degrees& ClampPositive360()\n    {\n        Clamp360();\n        if ( m_value < 0 )\n        {\n            m_value += 360;\n        }\n        return *this;\n    }\n\n    // Clamps between 0 to 360\n    inline Degrees GetClampedPositive360() const\n    {\n        Degrees d( m_value );\n        d.ClampPositive360();\n        return d;\n    }\n\nprivate:\n\n    float m_value = 0;\n};\n\nstruct Radians\n{\n    static Radians const Pi;\n    static Radians const TwoPi;\n    static Radians const OneDivPi;\n    static Radians const OneDivTwoPi;\n    static Radians const PiDivTwo;\n    static Radians const PiDivFour;\n\npublic:\n\n    inline Radians() = default;\n    inline Radians( float radians ) : m_value( radians ) {}\n    inline explicit Radians( Degrees const& degrees );\n\n    FORCE_INLINE explicit operator float() const { return m_value; }\n    FORCE_INLINE operator Degrees() const { return ToDegrees(); }\n    FORCE_INLINE float ToFloat() const { return m_value; }\n    FORCE_INLINE Degrees ToDegrees() const { return Degrees( m_value * Math::RadiansToDegrees ); }\n\n    inline Radians operator-() const { return Radians( -m_value ); }\n\n    inline Radians operator+( Radians const& rhs ) const { return Radians( m_value + rhs.m_value ); }\n    inline Radians operator-( Radians const& rhs ) const { return Radians( m_value - rhs.m_value ); }\n    inline Radians operator*( 
Radians const& rhs ) const { return Radians( m_value * rhs.m_value ); }\n    inline Radians operator/( Radians const& rhs ) const { return Radians( m_value / rhs.m_value ); }\n\n    inline Radians& operator+=( Radians const& rhs ) { m_value += rhs.m_value; return *this; }\n    inline Radians& operator-=( Radians const& rhs ) { m_value -= rhs.m_value; return *this; }\n    inline Radians& operator*=( Radians const& rhs ) { m_value *= rhs.m_value; return *this; }\n    inline Radians& operator/=( Radians const& rhs ) { m_value /= rhs.m_value; return *this; }\n\n    inline Radians operator+( float const& rhs ) const { return Radians( m_value + rhs ); }\n    inline Radians operator-( float const& rhs ) const { return Radians( m_value - rhs ); }\n    inline Radians operator*( float const& rhs ) const { return Radians( m_value * rhs ); }\n    inline Radians operator/( float const& rhs ) const { return Radians( m_value / rhs ); }\n\n    inline Radians& operator+=( float const& rhs ) { m_value += rhs; return *this; }\n    inline Radians& operator-=( float const& rhs ) { m_value -= rhs; return *this; }\n    inline Radians& operator*=( float const& rhs ) { m_value *= rhs; return *this; }\n    inline Radians& operator/=( float const& rhs ) { m_value /= rhs; return *this; }\n\n    inline Radians operator+( int32_t const& rhs ) const { return Radians( m_value + rhs ); }\n    inline Radians operator-( int32_t const& rhs ) const { return Radians( m_value - rhs ); }\n    inline Radians operator*( int32_t const& rhs ) const { return Radians( m_value * rhs ); }\n    inline Radians operator/( int32_t const& rhs ) const { return Radians( m_value / rhs ); }\n\n    inline Radians& operator+=( int32_t const& rhs ) { m_value += rhs; return *this; }\n    inline Radians& operator-=( int32_t const& rhs ) { m_value -= rhs; return *this; }\n    inline Radians& operator*=( int32_t const& rhs ) { m_value *= rhs; return *this; }\n    inline Radians& operator/=( int32_t const& rhs ) { m_value /= 
rhs; return *this; }\n\n    inline Radians operator+( uint32_t const& rhs ) const { return Radians( m_value + rhs ); }\n    inline Radians operator-( uint32_t const& rhs ) const { return Radians( m_value - rhs ); }\n    inline Radians operator*( uint32_t const& rhs ) const { return Radians( m_value * rhs ); }\n    inline Radians operator/( uint32_t const& rhs ) const { return Radians( m_value / rhs ); }\n\n    inline Radians& operator+=( uint32_t const& rhs ) { m_value += rhs; return *this; }\n    inline Radians& operator-=( uint32_t const& rhs ) { m_value -= rhs; return *this; }\n    inline Radians& operator*=( uint32_t const& rhs ) { m_value *= rhs; return *this; }\n    inline Radians& operator/=( uint32_t const& rhs ) { m_value /= rhs; return *this; }\n\n    inline bool operator>( float const& rhs ) const { return m_value > rhs; };\n    inline bool operator<( float const& rhs ) const { return m_value < rhs; }\n    inline bool operator>=( float const& rhs ) const { return m_value >= rhs; }\n    inline bool operator<=( float const& rhs ) const { return m_value <= rhs; }\n\n    inline bool operator>( Radians const& rhs ) const { return m_value > rhs.m_value; }\n    inline bool operator<( Radians const& rhs ) const { return m_value < rhs.m_value; }\n    inline bool operator>=( Radians const& rhs ) const { return m_value >= rhs.m_value; }\n    inline bool operator<=( Radians const& rhs ) const { return m_value <= rhs.m_value; }\n\n    inline bool operator>( Degrees const& rhs ) const;\n    inline bool operator<( Degrees const& rhs ) const;\n    inline bool operator>=( Degrees const& rhs ) const;\n    inline bool operator<=( Degrees const& rhs ) const;\n\n    inline bool operator==( float const& v ) const { return Math::IsNearEqual( m_value, v ); }\n    inline bool operator!=( float const& v ) const { return !Math::IsNearEqual( m_value, v ); }\n\n    inline bool operator==( Radians const& rhs ) const { return m_value == rhs.m_value; }\n    inline bool operator!=( 
Radians const& rhs ) const { return m_value != rhs.m_value; }\n\n    inline bool operator==( Degrees const& rhs ) const;\n    inline bool operator!=( Degrees const& rhs ) const;\n\n    inline void Clamp( Radians min, Radians max )\n    {\n        m_value = Math::Clamp( m_value, min.m_value, max.m_value );\n    }\n\n    // Clamps between -2Pi to 2Pi\n    inline void Clamp360()\n    {\n        m_value -= int32_t( m_value / Math::TwoPi ) * Math::TwoPi;\n    }\n\n    // Clamps between -2Pi to 2Pi\n    inline Radians GetClamped360() const\n    {\n        Radians r( m_value );\n        r.Clamp360();\n        return r;\n    }\n\n    // Clamps between 0 to 2Pi\n    inline void ClampPositive360()\n    {\n        Clamp360();\n        if( m_value < 0 )\n        {\n            m_value += Math::TwoPi;\n        }\n    }\n\n    // Clamps between 0 to 2Pi\n    inline Radians GetClampedToPositive360() const\n    {\n        Radians r( m_value );\n        r.ClampPositive360();\n        return r;\n    }\n\n    // Clamps to -Pi to Pi\n    inline void Clamp180()\n    {\n        Clamp360();\n\n        float delta = Math::Pi - Math::Abs( m_value );\n        if ( delta < 0 )\n        {\n            delta += Math::Pi;\n            m_value = ( m_value < 0 ) ? delta : -delta;\n        }\n    }\n\n    // Clamps to -Pi to Pi\n    inline Radians GetClamped180() const\n    {\n        Radians r( m_value );\n        r.Clamp180();\n        return r;\n    }\n\n    // Inverts angle between [0;2Pi] and [-2Pi;0]\n    inline void Invert()\n    {\n        Clamp360();\n        float const delta = Math::TwoPi - Math::Abs( m_value );\n        m_value = ( m_value < 0 ) ? delta : -delta;\n    }\n\n    // Inverts angle between [0;2Pi] and [-2Pi;0]\n    inline Radians GetInverse() const\n    {\n        Radians r( m_value );\n        r.Invert();\n        return r;\n    }\n\n    // Flips the front and rear 180 degree arc i.e. 
135 becomes -45, -90 becomes 90, etc.\n    inline void Flip()\n    {\n        Clamp180();\n        float const delta = Math::Pi - Math::Abs( m_value );\n        m_value = ( m_value < 0 ) ? delta : -delta;\n    }\n\n    // Flips the front and rear 180 degree arc i.e. 135 becomes -45, -90 becomes 90, etc.\n    inline Radians GetFlipped() const\n    {\n        Radians r( m_value );\n        r.Flip();\n        return r;\n    }\n\nprivate:\n\n    float m_value = 0;\n};\n\ninline Degrees::Degrees( Radians const& radians )\n    : m_value( radians.ToDegrees() )\n{}\n\ninline Radians Degrees::ToRadians() const\n{\n    return Radians( m_value * Math::DegreesToRadians );\n}\n\ninline Degrees::operator Radians() const\n{\n    return ToRadians();\n}\n\ninline bool Degrees::operator>( Radians const& rhs ) const { return m_value > rhs.ToDegrees().m_value; }\ninline bool Degrees::operator<( Radians const& rhs ) const { return m_value < rhs.ToDegrees().m_value; }\ninline bool Degrees::operator>=( Radians const& rhs ) const { return m_value >= rhs.ToDegrees().m_value; }\ninline bool Degrees::operator<=( Radians const& rhs ) const { return m_value <= rhs.ToDegrees().m_value; }\n\ninline bool Degrees::operator==( Radians const& rhs ) const { return Math::IsNearEqual( m_value, rhs.ToDegrees().m_value ); }\ninline bool Degrees::operator!=( Radians const& rhs ) const { return !Math::IsNearEqual( m_value, rhs.ToDegrees().m_value ); }\n\ninline Radians::Radians( Degrees const& degrees )\n    : m_value( degrees.ToRadians() )\n{}\n\ninline bool Radians::operator>( Degrees const& rhs ) const { return m_value > rhs.ToRadians().m_value; }\ninline bool Radians::operator<( Degrees const& rhs ) const { return m_value < rhs.ToRadians().m_value; }\ninline bool Radians::operator>=( Degrees const& rhs ) const { return m_value >= rhs.ToRadians().m_value; }\ninline bool Radians::operator<=( Degrees const& rhs ) const { return m_value <= rhs.ToRadians().m_value; }\n\ninline bool Radians::operator==( 
Degrees const& rhs ) const { return Math::IsNearEqual( m_value, rhs.ToRadians().m_value ); }\ninline bool Radians::operator!=( Degrees const& rhs ) const { return !Math::IsNearEqual( m_value, rhs.ToRadians().m_value ); }\n\nstruct EulerAngles\n{\npublic:\n\n    EulerAngles() = default;\n\n    inline explicit EulerAngles( Degrees inX, Degrees inY, Degrees inZ )\n        : m_x( inX )\n        , m_y( inY )\n        , m_z( inZ )\n    {}\n\n    inline explicit EulerAngles( Radians inX, Radians inY, Radians inZ )\n        : m_x( inX )\n        , m_y( inY )\n        , m_z( inZ )\n    {}\n\n    inline explicit EulerAngles( float inDegreesX, float inDegreesY, float inDegreesZ )\n        : m_x( Math::DegreesToRadians * inDegreesX )\n        , m_y( Math::DegreesToRadians * inDegreesY )\n        , m_z( Math::DegreesToRadians * inDegreesZ )\n    {}\n\n    inline EulerAngles( Float3 const& anglesInDegrees )\n        : m_x( Math::DegreesToRadians * anglesInDegrees.m_x )\n        , m_y( Math::DegreesToRadians * anglesInDegrees.m_y )\n        , m_z( Math::DegreesToRadians * anglesInDegrees.m_z )\n    {}\n\n    inline void Clamp()\n    {\n        m_x.Clamp360();\n        m_y.Clamp360();\n        m_z.Clamp360();\n    }\n\n    inline EulerAngles GetClamped() const\n    {\n        EulerAngles clamped = *this;\n        clamped.Clamp();\n        return clamped;\n    }\n\n    inline Radians GetYaw() const { return m_z; }\n    inline Radians GetPitch() const { return m_x; }\n    inline Radians GetRoll() const { return m_y; }\n\n    inline Float3 GetAsRadians() const { return Float3( m_x.ToFloat(), m_y.ToFloat(), m_z.ToFloat() ); }\n    inline Float3 GetAsDegrees() const { return Float3( m_x.ToDegrees().ToFloat(), m_y.ToDegrees().ToFloat(), m_z.ToDegrees().ToFloat() ); }\n\n    inline bool operator==( EulerAngles const& other ) const { return m_x == other.m_x && m_y == other.m_y && m_z == other.m_z; }\n    inline bool operator!=( EulerAngles const& other ) const { return m_x != other.m_x || 
m_y != other.m_y || m_z != other.m_z; }\n\n    inline Radians& operator[]( uint32_t i ) { return ( (Radians*) this )[i]; }\n    inline Radians const& operator[]( uint32_t i ) const { return ( (Radians*) this )[i]; }\n    // in degrees\n    inline Float3 ToFloat3() const { return Float3( Math::RadiansToDegrees * m_x.ToFloat(), Math::RadiansToDegrees * m_y.ToFloat(), Math::RadiansToDegrees * m_z.ToFloat() ); }\n\npublic:\n\n    Radians m_x = 0.0f;\n    Radians m_y = 0.0f;\n    Radians m_z = 0.0f;\n};\n\nstruct AxisAngle\n{\npublic:\n\n    inline AxisAngle() = default;\n    inline explicit AxisAngle( Float3 axis, Radians angle ) : m_axis( axis ), m_angle( angle ) {}\n    inline explicit AxisAngle( Float3 axis, Degrees angle ) : m_axis( axis ), m_angle( angle.ToRadians() ) {}\n\n    inline bool IsValid() const\n    {\n        float const lengthSq = m_axis.m_x * m_axis.m_x + m_axis.m_y * m_axis.m_y + m_axis.m_z * m_axis.m_z;\n        return Math::Abs( lengthSq - 1.0f ) < Math::Epsilon;\n    }\n\npublic:\n\n    Float3      m_axis = Float3::Zero;\n    Radians     m_angle = Radians( 0.0f );\n};\n"
  },
  {
    "path": "MotionCorrection/src/cpp/Math/Vector.cpp",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: Apache-2.0\n */\n\n#include \"Vector.h\"\n#include \"Quaternion.h\"\n\n#include <limits>\n\nnamespace Math\n{\n    Vector const Vector::UnitX = { 1, 0, 0, 0 };\n    Vector const Vector::UnitY = { 0, 1, 0, 0 };\n    Vector const Vector::UnitZ = { 0, 0, 1, 0 };\n    Vector const Vector::UnitW = { 0, 0, 0, 1 };\n\n    Vector const Vector::Origin = { 0, 0, 0, 1 };\n    Vector const Vector::WorldForward = { 0, -1, 0, 0 };\n    Vector const Vector::WorldBackward = { 0, 1, 0, 0 };\n    Vector const Vector::WorldUp = { 0, 0, 1, 0 };\n    Vector const Vector::WorldDown = { 0, 0, -1, 0 };\n    Vector const Vector::WorldLeft = { 1, 0, 0, 0 };\n    Vector const Vector::WorldRight = { -1, 0, 0, 0 };\n\n    // Use the real IEEE-754 special values: integer literals such as 0x7F800000 would be converted by value to\n    // large finite floats by the four-float constructor rather than being reinterpreted as bit patterns.\n    Vector const Vector::Infinity = { std::numeric_limits<float>::infinity(), std::numeric_limits<float>::infinity(), std::numeric_limits<float>::infinity(), std::numeric_limits<float>::infinity() };\n    Vector const Vector::QNaN = { std::numeric_limits<float>::quiet_NaN(), std::numeric_limits<float>::quiet_NaN(), std::numeric_limits<float>::quiet_NaN(), std::numeric_limits<float>::quiet_NaN() };\n\n    Vector const Vector::NegativeOne(-1.0f);\n    Vector const Vector::Zero(0.0f);\n    Vector const Vector::Half(0.5f);\n    Vector const Vector::One(1.0f);\n\n    Vector const Vector::Epsilon(Math::Epsilon);\n    Vector const Vector::LargeEpsilon(Math::LargeEpsilon);\n    Vector const Vector::OneMinusEpsilon(1.0f - Math::Epsilon);\n    Vector const Vector::EpsilonMinusOne(Math::Epsilon - 1.0f);\n    Vector const Vector::NormalizeCheckThreshold(0.01f); // Squared Error\n\n    Vector const Vector::Pi(Math::Pi);\n    Vector const Vector::PiDivTwo(Math::PiDivTwo);\n    Vector const Vector::TwoPi(Math::TwoPi);\n    Vector const Vector::OneDivTwoPi(Math::OneDivTwoPi);\n\n    Vector const Vector::Select0000(0, 0, 0, 0);\n    Vector const Vector::Select0001(0, 0, 0, 1);\n    Vector const Vector::Select0010(0, 0, 1, 0);\n    Vector const Vector::Select0011(0, 0, 1, 1);\n    Vector const Vector::Select0100(0, 1, 0, 0);\n    Vector const Vector::Select0101(0, 1, 0, 1);\n    Vector const Vector::Select0110(0, 1, 1, 0);\n    Vector const Vector::Select0111(0, 1, 1, 1);\n    Vector const Vector::Select1000(1, 0, 0, 0);\n    Vector const Vector::Select1001(1, 0, 0, 1);\n    Vector const Vector::Select1010(1, 0, 1, 0);\n    Vector const Vector::Select1011(1, 0, 1, 1);\n    Vector const Vector::Select1100(1, 1, 0, 0);\n    Vector const Vector::Select1101(1, 1, 0, 1);\n    Vector const Vector::Select1110(1, 1, 1, 0);\n    Vector const Vector::Select1111(1, 1, 1, 1);\n\n    Vector const Vector::BoxCorners[8] =\n    {\n        { -1.0f, -1.0f,  1.0f, 0.0f },\n        {  1.0f, -1.0f,  1.0f, 0.0f },\n        {  1.0f,  1.0f,  1.0f, 0.0f },\n        { -1.0f,  1.0f,  1.0f, 0.0f },\n        { -1.0f, -1.0f, -1.0f, 0.0f },\n        {  1.0f, -1.0f, -1.0f, 0.0f },\n        {  1.0f,  1.0f, -1.0f, 0.0f },\n        { -1.0f,  1.0f, -1.0f, 0.0f },\n    };\n\n    // Spherical interpolation between two vectors: the direction is rotated from the first input towards the second\n    // by t times the angle between them, while the length is linearly interpolated between the two 3D lengths.\n    // t must lie in [0, 1] (asserted). Falls back to Lerp when either input's squared 3D length tests below Epsilon.\n    Vector Vector::SLerp(const Vector& from, const Vector& to, float t)\n    {\n        ASSERT(t >= 0.0f && t <= 1.0f);\n        if (from.LengthSquared3().IsLessThan4(Epsilon) || to.LengthSquared3().IsLessThan4(Epsilon))\n        {\n            return Lerp(from, to, t);\n        }\n\n        // Calculate the final length\n        const Vector fromLength = from.Length3();\n        const Vector toLength = to.Length3();\n        const Vector finalLength = Lerp(fromLength, toLength, t);\n\n        // Normalize vectors\n        const Vector normalizedFrom = from / fromLength;\n        const Vector normalizedTo = to / toLength;\n\n        // Handle parallel vector\n        Vector result;\n        if (normalizedFrom.IsParallelTo(normalizedTo))\n        {\n            // There is no rotation to interpolate, but the interpolated length must still be applied\n            // so that this branch agrees with the rotating branch below.\n            result = normalizedFrom * finalLength;\n        }\n        else\n        {\n            // Interpolate the rotation between the vectors\n            const Vector dot = Dot3(normalizedFrom, normalizedTo);\n            const Vector angle = ACos(dot);\n            const Vector axis = Cross3(normalizedFrom, normalizedTo).Normalize3();\n            const Vector interpolatedAngle = Lerp(Zero, angle, t);\n\n            const Quaternion rotation(axis, Radians(interpolatedAngle.ToFloat()));\n            const Vector finalDirection = rotation.RotateVector(normalizedFrom);\n            result = finalDirection.GetNormalized3() * finalLength;\n        }\n\n        return result;\n    }\n}\n"
  },
  {
    "path": "MotionCorrection/src/cpp/Math/Vector.h",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: Apache-2.0\n */\n\n#pragma once\n\n#include \"Compiler.h\"\n\n#include \"Types.h\"\n#include \"Constants.h\"\n#include \"SIMD.h\"\n\nnamespace Math\n{\n    class alignas(16) Vector\n    {\n    public:\n\n        static Vector const UnitX;\n        static Vector const UnitY;\n        static Vector const UnitZ;\n        static Vector const UnitW;\n\n        static Vector const Origin;\n        static Vector const WorldForward;\n        static Vector const WorldBackward;\n        static Vector const WorldUp;\n        static Vector const WorldDown;\n        static Vector const WorldLeft;\n        static Vector const WorldRight;\n\n        static Vector const NegativeOne;\n        static Vector const Zero;\n        static Vector const Half;\n        static Vector const One;\n        static Vector const Epsilon;\n        static Vector const LargeEpsilon;\n        static Vector const OneMinusEpsilon;\n        static Vector const EpsilonMinusOne;\n        static Vector const NormalizeCheckThreshold;\n        static Vector const Pi;\n        static Vector const PiDivTwo;\n        static Vector const TwoPi;\n        static Vector const OneDivTwoPi;\n\n        static Vector const Select0000;\n        static Vector const Select0001;\n        static Vector const Select0010;\n        static Vector const Select0011;\n        static Vector const Select0100;\n        static Vector const Select0101;\n        static Vector const Select0110;\n        static Vector const Select0111;\n        static Vector const Select1000;\n        static Vector const Select1001;\n        static Vector const Select1010;\n        static Vector const Select1011;\n        static Vector const Select1100;\n        static Vector const Select1101;\n        static Vector const Select1110;\n        static Vector const Select1111;\n\n        static Vector const 
Infinity;\n        static Vector const QNaN;\n\n        static Vector const BoxCorners[8];\n\n        //\n        // Utils\n        //\n\n        static Vector Cross2(const Vector& v0, const Vector& v1);\n        static Vector Cross3(const Vector& v0, const Vector& v1);\n        static Vector Dot2(const Vector& v0, const Vector& v1);\n        static Vector Dot3(const Vector& v0, const Vector& v1);\n        static Vector Dot4(const Vector& v0, const Vector& v1);\n        static Vector Average2(const Vector& v0, const Vector& v1);\n        static Vector Average3(const Vector& v0, const Vector& v1);\n        static Vector Average4(const Vector& v0, const Vector& v1);\n        static Vector Min(const Vector& v0, const Vector& v1);\n        static Vector Max(const Vector& v0, const Vector& v1);\n        static float Min(const Vector& v);\n        static float Max(const Vector& v);\n        static Vector Clamp(const Vector& v, const Vector& min, const Vector& max);\n        static Vector Xor(const Vector& vec0, const Vector& vec1);\n\n        // Add the multiplied results to a vector: ( vec * mul ) + addend\n        static Vector MultiplyAdd(const Vector& vec, const Vector& multiplier, const Vector& addend);\n\n        // Subtract a vector from the multiplied result: (vec * mul ) - subtrahend\n        static Vector MultiplySubtract(const Vector& vec, const Vector& multiplier, const Vector& subtrahend);\n\n        // Subtract the multiplied result from a vector: minuend - (vec * mul )\n        static Vector NegativeMultiplySubtract(const Vector& vec, const Vector& multiplier, const Vector& minuend);\n\n        // Sum up scaled versions of two vectors\n        static Vector LinearCombination(const Vector& v0, const Vector& v1, float scale0, float scale1);\n\n        // Linear interpolation of one vector to another\n        static Vector Lerp(const Vector& from, const Vector& to, float t);\n\n        // Normalized linear interpolation of one vector to another\n        
static Vector NLerp(const Vector& from, const Vector& to, float t);\n\n        // Spherical interpolation of one vector to another\n        static Vector SLerp(const Vector& from, const Vector& to, float t);\n\n        // Combine the two vectors based on the control: 0 means select from v0, 1 means select from v1. E.G. To select XY from v0 and ZW from v1, control = Vector( 0, 0, 1, 1 )\n        static Vector Select(const Vector& v0, const Vector& v1, const Vector& control);\n\n        // Get a permutation of two vectors, each template argument represents the element index to select ( v0: 0-3, v1: 4-7 );\n        template<uint32_t PermuteX, uint32_t PermuteY, uint32_t PermuteZ, uint32_t PermuteW>\n        static Vector Permute(const Vector& v0, const Vector& v1);\n\n        //\n        // Trigonometry\n        //\n\n        static Vector Sin(const Vector& vec);\n        static Vector Cos(const Vector& vec);\n        static Vector Tan(const Vector& vec);\n        static Vector ASin(const Vector& vec);\n        static Vector ACos(const Vector& vec);\n        static Vector ATan(const Vector& vec);\n        static Vector ATan2(const Vector& vec0, const Vector& vec1);\n\n        static Vector SinEst(const Vector& vec);\n        static Vector CosEst(const Vector& vec);\n        static Vector TanEst(const Vector& vec);\n        static Vector ASinEst(const Vector& vec);\n        static Vector ACosEst(const Vector& vec);\n        static Vector ATanEst(const Vector& vec);\n        static Vector ATan2Est(const Vector& vec0, const Vector& vec1);\n\n        static void SinCos(Vector& sin, Vector& cos, float angle);\n        static void SinCos(Vector& sin, Vector& cos, const Vector& angle);\n\n        static Vector AngleMod2Pi(const Vector& angles);\n\n    public:\n\n        operator __m128& ();\n        operator const __m128& () const;\n\n        Vector();\n        explicit Vector(Axis axis);\n        explicit Vector(ZeroInit_t);\n        explicit Vector(float v);\n        
Vector(__m128 v);\n        Vector(float ix, float iy, float iz, float iw = 1.0f);\n\n        Vector(const Float2& v, float iz = 0.0f, float iw = 0.0f);\n        Vector(const Float3& v, float iw = 1.0f);\n        Vector(const Float4& v);\n        Vector(const float* pValues);\n\n        bool IsValid() const;\n\n        void Store(float* pValues) const;\n        void StoreFloat(float& value) const;\n        void StoreFloat2(Float2& value) const;\n        void StoreFloat3(Float3& value) const;\n        void StoreFloat4(Float4& value) const;\n\n        float ToFloat() const;\n        Float2 ToFloat2() const;\n        Float3 ToFloat3() const;\n        Float4 ToFloat4() const;\n\n        operator Float2() const;\n        operator Float3() const;\n        operator Float4() const;\n\n        //\n        // Element accessors\n        //\n\n        float GetX() const;\n        float GetY() const;\n        float GetZ() const;\n        float GetW() const;\n\n        void SetX(float x);\n        void SetY(float y);\n        void SetZ(float z);\n        void SetW(float w);\n\n        float operator[](uint32_t i) const;\n\n        //\n        // W component operations\n        //\n\n        bool IsW1() const;\n        bool IsW0() const;\n        Vector& SetW0();\n        Vector& SetW1();\n        Vector GetWithW0() const;\n        Vector GetWithW1() const;\n\n        //\n        // Dimensional Getters\n        //\n\n        // Returns only the first two components, z=w=0\n        Vector Get2D() const;\n\n        // Returns only the first three components, w = 0\n        Vector Get3D() const;\n\n        //\n        // Algebraic operators\n        //\n\n        Vector operator+(const Vector& v) const;\n        Vector& operator+=(const Vector& v);\n        Vector operator-(const Vector& v) const;\n        Vector& operator-=(const Vector& v);\n        Vector operator*(const Vector& v) const;\n        Vector& operator*=(const Vector& v);\n        Vector operator/(const Vector& v) 
const;\n        Vector& operator/=(const Vector& v);\n\n        Vector operator*(float const f) const;\n        Vector& operator*=(float const f);\n        Vector operator/(float const f) const;\n        Vector& operator/=(float const f);\n\n        Vector operator-() const;\n\n        Vector Orthogonal2D() const;\n        Vector Cross2(const Vector& other) const;\n        Vector Cross3(const Vector& other) const;\n        Vector Dot2(const Vector& other) const;\n        Vector Dot3(const Vector& other) const;\n        Vector Dot4(const Vector& other) const;\n        float GetDot2(const Vector& other) const;\n        float GetDot3(const Vector& other) const;\n        float GetDot4(const Vector& other) const;\n\n        Vector ScalarProjection(const Vector& other) const;\n        float GetScalarProjection(const Vector& other) const;\n        Vector VectorProjection(const Vector& other) const;\n\n        //\n        // Transformations\n        //\n\n        Vector& Invert();\n        Vector GetInverse() const;\n        Vector GetReciprocal() const;\n\n        Vector& InvertEst();\n        Vector GetInverseEst() const;\n\n        Vector& Negate();\n        Vector GetNegated() const;\n\n        Vector& Abs();\n        Vector GetAbs() const;\n\n        Vector& Sqrt();\n        Vector GetSqrt();\n\n        Vector& ReciprocalSqrt();\n        Vector GetReciprocalSqrt();\n\n        Vector& EstimatedReciprocalSqrt();\n        Vector GetEstimatedReciprocalSqrt();\n\n        Vector& Normalize2();\n        Vector& Normalize3();\n        Vector& Normalize4();\n\n        Vector GetNormalized2() const;\n        Vector GetNormalized3() const;\n        Vector GetNormalized4() const;\n\n        Vector& Floor();\n        Vector GetFloor() const;\n        Vector& Ceil();\n        Vector GetCeil() const;\n        Vector& Round();\n        Vector GetRound() const;\n\n        Vector GetSign() const;\n\n        //\n        // Permutations\n        //\n\n        Vector GetSplatX() const;\n  
      Vector GetSplatY() const;\n        Vector GetSplatZ() const;\n        Vector GetSplatW() const;\n\n        // Get a shuffled version of this vector, each argument represents the element index in the original vector\n        template<uint32_t xIdx, uint32_t yIdx, uint32_t zIdx, uint32_t wIdx>\n        Vector Swizzle() const;\n\n        // Get a shuffled version of this vector, each argument represents the element index in the original vector\n        Vector Swizzle(uint32_t xIdx, uint32_t yIdx, uint32_t zIdx, uint32_t wIdx) const;\n\n        // Get a shuffled version of this vector, each argument represents the element index in the original vector\n        Vector Shuffle(uint32_t xIdx, uint32_t yIdx, uint32_t zIdx, uint32_t wIdx) const;\n\n        // Get a shuffled version of this vector, each argument represents the element index in the original vector\n        template<uint32_t xIdx, uint32_t yIdx, uint32_t zIdx, uint32_t wIdx>\n        Vector Shuffle() const;\n\n        //\n        // Queries\n        //\n\n        Vector Length2() const;\n        Vector Length3() const;\n        Vector Length4() const;\n\n        float GetLength2() const;\n        float GetLength3() const;\n        float GetLength4() const;\n\n        Vector InverseLength2() const;\n        Vector InverseLength3() const;\n        Vector InverseLength4() const;\n\n        float GetInverseLength2() const;\n        float GetInverseLength3() const;\n        float GetInverseLength4() const;\n\n        Vector LengthSquared2() const;\n        Vector LengthSquared3() const;\n        Vector LengthSquared4() const;\n\n        float GetLengthSquared2() const;\n        float GetLengthSquared3() const;\n        float GetLengthSquared4() const;\n\n        Vector Distance2(const Vector& to) const;\n        Vector Distance3(const Vector& to) const;\n        Vector Distance4(const Vector& to) const;\n\n        float GetDistance2(const Vector& to) const;\n        float GetDistance3(const Vector& to) 
const;\n        float GetDistance4(const Vector& to) const;\n\n        Vector DistanceSquared2(const Vector& to) const;\n        Vector DistanceSquared3(const Vector& to) const;\n        Vector DistanceSquared4(const Vector& to) const;\n\n        float GetDistanceSquared2(const Vector& to) const;\n        float GetDistanceSquared3(const Vector& to) const;\n        float GetDistanceSquared4(const Vector& to) const;\n\n        bool IsNormalized2() const;\n        bool IsNormalized3() const;\n        bool IsNormalized4() const;\n\n        // Is this vector within the range [-bounds, bounds]\n        Vector InBounds(const Vector& bounds) const;\n\n        bool IsInBounds2(const Vector& bounds) const;\n        bool IsInBounds3(const Vector& bounds) const;\n        bool IsInBounds4(const Vector& bounds) const;\n\n        Vector Equal(const Vector& v) const;\n\n        bool IsEqual2(const Vector& v) const;\n        bool IsEqual3(const Vector& v) const;\n        bool IsEqual4(const Vector& v) const;\n\n        Vector NearEqual(const Vector& v, const Vector& epsilon) const;\n\n        bool IsNearEqual2(const Vector& v, float epsilon) const;\n        bool IsNearEqual3(const Vector& v, float epsilon) const;\n        bool IsNearEqual4(const Vector& v, float epsilon) const;\n\n        bool IsNearEqual2(const Vector& v, const Vector& epsilon = Vector::Epsilon) const;\n        bool IsNearEqual3(const Vector& v, const Vector& epsilon = Vector::Epsilon) const;\n        bool IsNearEqual4(const Vector& v, const Vector& epsilon = Vector::Epsilon) const;\n\n        Vector GreaterThan(const Vector& v) const;\n        bool IsAnyGreaterThan(const Vector& v) const;\n\n        bool IsGreaterThan2(const Vector& v) const;\n        bool IsGreaterThan3(const Vector& v) const;\n        bool IsGreaterThan4(const Vector& v) const;\n\n        Vector GreaterThanEqual(const Vector& v) const;\n        bool IsAnyGreaterThanEqual(const Vector& v) const;\n\n        bool IsGreaterThanEqual2(const Vector& 
v) const;\n        bool IsGreaterThanEqual3(const Vector& v) const;\n        bool IsGreaterThanEqual4(const Vector& v) const;\n\n        Vector LessThan(const Vector& v) const;\n        bool IsAnyLessThan(const Vector& v) const;\n\n        bool IsLessThan2(const Vector& v) const;\n        bool IsLessThan3(const Vector& v) const;\n        bool IsLessThan4(const Vector& v) const;\n\n        Vector LessThanEqual(const Vector& v) const;\n        bool IsAnyLessThanEqual(const Vector& v) const;\n\n        bool IsLessThanEqual2(const Vector& v) const;\n        bool IsLessThanEqual3(const Vector& v) const;\n        bool IsLessThanEqual4(const Vector& v) const;\n\n        Vector EqualsZero() const;\n        bool IsAnyEqualToZero2() const;\n        bool IsAnyEqualToZero3() const;\n        bool IsAnyEqualToZero4() const;\n\n        bool IsZero2() const;\n        bool IsZero3() const;\n        bool IsZero4() const;\n\n        Vector NearEqualsZero(float epsilon = Math::Epsilon) const;\n\n        bool IsNearZero2(float epsilon = Math::Epsilon) const;\n        bool IsNearZero3(float epsilon = Math::Epsilon) const;\n        bool IsNearZero4(float epsilon = Math::Epsilon) const;\n\n        Vector EqualsInfinity() const;\n\n        bool IsInfinite2() const;\n        bool IsInfinite3() const;\n        bool IsInfinite4() const;\n\n        Vector EqualsNaN() const;\n\n        bool IsNaN2() const;\n        bool IsNaN3() const;\n        bool IsNaN4() const;\n\n        bool IsParallelTo(const Vector& v) const;\n\n        void ToDirectionAndLength2(Vector& direction, float& length) const;\n        void ToDirectionAndLength3(Vector& direction, float& length) const;\n\n        bool operator==(const Vector& rhs) const;\n        bool operator!=(const Vector& rhs) const;\n\n    public:\n\n        __m128 m_data;\n    };\n\n    static_assert(sizeof(Vector) == 16, \"Vector size must be 16 bytes!\");\n}\n\n#include \"Vector.inl\"\n"
  },
  {
    "path": "MotionCorrection/src/cpp/Math/Vector.inl",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: Apache-2.0\n */\n\n#pragma once\n\n#include <cstring>\n\n#include \"Vector.h\"\n\nnamespace Math\n{\n    FORCE_INLINE Vector Vector::Cross2(const Vector& v0, const Vector& v1)\n    {\n        return v0.Cross2(v1);\n    }\n\n    FORCE_INLINE Vector Vector::Cross3(const Vector& v0, const Vector& v1)\n    {\n        return v0.Cross3(v1);\n    }\n\n    FORCE_INLINE Vector Vector::Dot2(const Vector& v0, const Vector& v1)\n    {\n        return v0.Dot2(v1);\n    }\n\n    FORCE_INLINE Vector Vector::Dot3(const Vector& v0, const Vector& v1)\n    {\n        return v0.Dot3(v1);\n    }\n\n    FORCE_INLINE Vector Vector::Dot4(const Vector& v0, const Vector& v1)\n    {\n        return v0.Dot4(v1);\n    }\n\n    FORCE_INLINE Vector Vector::Average2(const Vector& v0, const Vector& v1)\n    {\n        auto avg4 = Average4(v0, v1);\n        return Vector::Select(avg4, Vector::Zero, Vector(0, 0, 1, 1));\n    }\n\n    FORCE_INLINE Vector Vector::Average3(const Vector& v0, const Vector& v1)\n    {\n        auto avg4 = Average4(v0, v1);\n        return Vector::Select(avg4, Vector::Zero, Vector(0, 0, 0, 1));\n    }\n\n    FORCE_INLINE Vector Vector::Average4(const Vector& v0, const Vector& v1)\n    {\n        return (v0 + v1) * Vector::Half;\n    }\n\n    FORCE_INLINE Vector Vector::Min(const Vector& v0, const Vector& v1)\n    {\n        Vector result;\n        result = _mm_min_ps(v0, v1);\n        return result;\n    }\n\n    FORCE_INLINE Vector Vector::Max(const Vector& v0, const Vector& v1)\n    {\n        Vector result;\n        result = _mm_max_ps(v0, v1);\n        return result;\n    }\n\n    FORCE_INLINE float Vector::Min(const Vector& v)\n    {\n        __m128 shufReg, sumsReg;\n        shufReg = _mm_movehdup_ps(v);\n        sumsReg = _mm_min_ps(v, shufReg);\n        shufReg = _mm_movehl_ps(shufReg, sumsReg);\n        sumsReg = 
_mm_min_ss(sumsReg, shufReg);\n        return _mm_cvtss_f32(sumsReg);\n    }\n\n    FORCE_INLINE float Vector::Max(const Vector& v)\n    {\n        __m128 shufReg, sumsReg;\n        shufReg = _mm_movehdup_ps(v);\n        sumsReg = _mm_max_ps(v, shufReg);\n        shufReg = _mm_movehl_ps(shufReg, sumsReg);\n        sumsReg = _mm_max_ss(sumsReg, shufReg);\n        return _mm_cvtss_f32(sumsReg);\n    }\n\n    FORCE_INLINE Vector Vector::Clamp(const Vector& v, const Vector& min, const Vector& max)\n    {\n        Vector result;\n        result = _mm_max_ps(min, v);\n        result = _mm_min_ps(result, max);\n        return result;\n    }\n\n    FORCE_INLINE Vector Vector::Xor(const Vector& v0, const Vector& v1)\n    {\n        __m128i V = _mm_xor_si128(_mm_castps_si128(v0), _mm_castps_si128(v1));\n\n        Vector result;\n        result = _mm_castsi128_ps(V);\n        return result;\n    }\n\n    FORCE_INLINE Vector Vector::MultiplyAdd(const Vector& v, const Vector& multiplier, const Vector& addend)\n    {\n        // result = addend + ( vec * multiplier )\n        Vector result;\n        result = _mm_mul_ps(v, multiplier);\n        result = _mm_add_ps(result, addend);\n        return result;\n    }\n\n    FORCE_INLINE Vector Vector::MultiplySubtract(const Vector& vec, const Vector& multiplier, const Vector& subtrahend)\n    {\n        // result = ( vec * multiplier ) - subtrahend\n        auto r = _mm_mul_ps(vec, multiplier);\n        return _mm_sub_ps(r, subtrahend);\n    }\n\n    FORCE_INLINE Vector Vector::NegativeMultiplySubtract(const Vector& vec, const Vector& multiplier, const Vector& minuend)\n    {\n        // result = minuend - ( vec * multiplier )\n        auto r = _mm_mul_ps(vec, multiplier);\n        return _mm_sub_ps(minuend, r);\n    }\n\n    FORCE_INLINE Vector Vector::LinearCombination(const Vector& v0, const Vector& v1, float scale0, float scale1)\n    {\n        return (v0 * scale0) + (v1 * scale1);\n    }\n\n    FORCE_INLINE Vector 
Vector::Lerp(const Vector& from, const Vector& to, float t)\n    {\n        ASSERT(t >= 0.0f && t <= 1.0f);\n\n        Vector L = _mm_sub_ps(to, from);\n        Vector S = _mm_set_ps1(t);\n\n        Vector result;\n        result = _mm_mul_ps(L, S);\n        result = _mm_add_ps(result, from);\n        return result;\n    }\n\n    FORCE_INLINE Vector Vector::NLerp(const Vector& from, const Vector& to, float t)\n    {\n        ASSERT(t >= 0.0f && t <= 1.0f);\n\n        // Calculate the final length\n        auto const fromLength = from.Length3();\n        auto const toLength = to.Length3();\n        auto const finalLength = Vector::Lerp(fromLength, toLength, t);\n\n        // Normalize vectors\n        Vector const normalizedFrom = from / fromLength;\n        Vector const normalizedTo = to / toLength;\n\n        // LERP\n        auto const finalDirection = Lerp(normalizedFrom, normalizedTo, t);\n        auto result = finalDirection.GetNormalized3() * finalLength;\n        return result;\n    }\n\n    FORCE_INLINE Vector Vector::Select(const Vector& v0, const Vector& v1, const Vector& control)\n    {\n        auto const ctrl = _mm_cmpneq_ps(control, Vector::Zero);\n\n        Vector result;\n        auto vTemp1 = _mm_andnot_ps(ctrl, v0);\n        auto vTemp2 = _mm_and_ps(v1, ctrl);\n        result = _mm_or_ps(vTemp1, vTemp2);\n        return result;\n    }\n\n    template<uint32_t PermuteX, uint32_t PermuteY, uint32_t PermuteZ, uint32_t PermuteW>\n    FORCE_INLINE Vector Vector::Permute(const Vector& v0, const Vector& v1)\n    {\n        static_assert(PermuteX <= 7, \"Element index parameter out of range\");\n        static_assert(PermuteY <= 7, \"Element index parameter out of range\");\n        static_assert(PermuteZ <= 7, \"Element index parameter out of range\");\n        static_assert(PermuteW <= 7, \"Element index parameter out of range\");\n\n        uint32_t const shuffle = _MM_SHUFFLE(PermuteW & 3, PermuteZ & 3, PermuteY & 3, PermuteX & 3);\n        bool const 
whichX = PermuteX > 3;\n        bool const whichY = PermuteY > 3;\n        bool const whichZ = PermuteZ > 3;\n        bool const whichW = PermuteW > 3;\n\n        static SIMD::UIntMask const selectMask = { whichX ? 0xFFFFFFFF : 0, whichY ? 0xFFFFFFFF : 0, whichZ ? 0xFFFFFFFF : 0, whichW ? 0xFFFFFFFF : 0 };\n        __m128 shuffled1 = _mm_shuffle_ps(v0, v0, shuffle);\n        __m128 shuffled2 = _mm_shuffle_ps(v1, v1, shuffle);\n        __m128 masked1 = _mm_andnot_ps(selectMask, shuffled1);\n        __m128 masked2 = _mm_and_ps(selectMask, shuffled2);\n        return _mm_or_ps(masked1, masked2);\n    }\n\n    FORCE_INLINE Vector Vector::Sin(const Vector& vec)\n    {\n        // Force the value within the bounds of pi\n        auto m_x = Vector::AngleMod2Pi(vec);\n\n        // Map in [-pi/2,pi/2] with sin(m_y) = sin(m_x).\n        __m128 sign = _mm_and_ps(m_x, SIMD::g_signMask);\n        __m128 c = _mm_or_ps(Vector::Pi, sign);  // pi when m_x >= 0, -pi when m_x < 0\n        __m128 absx = _mm_andnot_ps(sign, m_x);  // |m_x|\n        __m128 rflx = _mm_sub_ps(c, m_x);\n        __m128 comp = _mm_cmple_ps(absx, Vector::PiDivTwo);\n        __m128 select0 = _mm_and_ps(comp, m_x);\n        __m128 select1 = _mm_andnot_ps(comp, rflx);\n        m_x = _mm_or_ps(select0, select1);\n\n        __m128 x2 = _mm_mul_ps(m_x, m_x);\n\n        // Compute polynomial approximation\n        const auto SC1 = SIMD::g_sinCoefficients1;\n        auto vConstants = _mm_shuffle_ps(SC1, SC1, _MM_SHUFFLE(0, 0, 0, 0));\n        __m128 Result = _mm_mul_ps(vConstants, x2);\n\n        const auto SC0 = SIMD::g_sinCoefficients0;\n        vConstants = _mm_shuffle_ps(SC0, SC0, _MM_SHUFFLE(3, 3, 3, 3));\n        Result = _mm_add_ps(Result, vConstants);\n        Result = _mm_mul_ps(Result, x2);\n\n        vConstants = _mm_shuffle_ps(SC0, SC0, _MM_SHUFFLE(2, 2, 2, 2));\n        Result = _mm_add_ps(Result, vConstants);\n        Result = _mm_mul_ps(Result, x2);\n\n        vConstants = _mm_shuffle_ps(SC0, SC0, 
_MM_SHUFFLE(1, 1, 1, 1));\n        Result = _mm_add_ps(Result, vConstants);\n        Result = _mm_mul_ps(Result, x2);\n\n        vConstants = _mm_shuffle_ps(SC0, SC0, _MM_SHUFFLE(0, 0, 0, 0));\n        Result = _mm_add_ps(Result, vConstants);\n        Result = _mm_mul_ps(Result, x2);\n        Result = _mm_add_ps(Result, Vector::One);\n        Result = _mm_mul_ps(Result, m_x);\n        return Result;\n    }\n\n    FORCE_INLINE Vector Vector::Cos(const Vector& vec)\n    {\n        // Map V to m_x in [-pi,pi].\n        auto m_x = Vector::AngleMod2Pi(vec);\n\n        // Map in [-pi/2,pi/2] with cos(m_y) = sign*cos(m_x).\n        auto sign = _mm_and_ps(m_x, SIMD::g_signMask);\n        __m128 c = _mm_or_ps(Vector::Pi, sign);  // pi when m_x >= 0, -pi when m_x < 0\n        __m128 absx = _mm_andnot_ps(sign, m_x);  // |m_x|\n        __m128 rflx = _mm_sub_ps(c, m_x);\n        __m128 comp = _mm_cmple_ps(absx, Vector::PiDivTwo);\n        __m128 select0 = _mm_and_ps(comp, m_x);\n        __m128 select1 = _mm_andnot_ps(comp, rflx);\n        m_x = _mm_or_ps(select0, select1);\n        select0 = _mm_and_ps(comp, Vector::One);\n        select1 = _mm_andnot_ps(comp, Vector::NegativeOne);\n        sign = _mm_or_ps(select0, select1);\n\n        __m128 x2 = _mm_mul_ps(m_x, m_x);\n\n        // Compute polynomial approximation\n        const auto CC1 = SIMD::g_cosCoefficients1;\n        auto vConstants = _mm_shuffle_ps(CC1, CC1, _MM_SHUFFLE(0, 0, 0, 0));\n        __m128 Result = _mm_mul_ps(vConstants, x2);\n\n        const auto CC0 = SIMD::g_cosCoefficients0;\n        vConstants = _mm_shuffle_ps(CC0, CC0, _MM_SHUFFLE(3, 3, 3, 3));\n        Result = _mm_add_ps(Result, vConstants);\n        Result = _mm_mul_ps(Result, x2);\n\n        vConstants = _mm_shuffle_ps(CC0, CC0, _MM_SHUFFLE(2, 2, 2, 2));\n        Result = _mm_add_ps(Result, vConstants);\n        Result = _mm_mul_ps(Result, x2);\n\n        vConstants = _mm_shuffle_ps(CC0, CC0, _MM_SHUFFLE(1, 1, 1, 1));\n        Result = 
_mm_add_ps(Result, vConstants);\n        Result = _mm_mul_ps(Result, x2);\n\n        vConstants = _mm_shuffle_ps(CC0, CC0, _MM_SHUFFLE(0, 0, 0, 0));\n        Result = _mm_add_ps(Result, vConstants);\n        Result = _mm_mul_ps(Result, x2);\n        Result = _mm_add_ps(Result, Vector::One);\n        Result = _mm_mul_ps(Result, sign);\n        return Result;\n    }\n\n    FORCE_INLINE Vector Vector::Tan(const Vector& vec)\n    {\n        static const Vector tanCoefficients0 = { 1.0f, -4.667168334e-1f, 2.566383229e-2f, -3.118153191e-4f };\n        static const Vector tanCoefficients1 = { 4.981943399e-7f, -1.333835001e-1f, 3.424887824e-3f, -1.786170734e-5f };\n        static const Vector tanConstants = { 1.570796371f, 6.077100628e-11f, 0.000244140625f, 0.63661977228f /*2 / Pi*/ };\n        static const SIMD::UIntMask mask = { 0x1, 0x1, 0x1, 0x1 };\n\n        Vector TwoDivPi = tanConstants.GetSplatW();\n        Vector C0 = tanConstants.GetSplatX();\n        Vector C1 = tanConstants.GetSplatY();\n        Vector vEpsilon = tanConstants.GetSplatZ();\n\n        Vector VA = (vec * TwoDivPi).Round();\n        Vector VC = Vector::NegativeMultiplySubtract(VA, C0, vec);\n        Vector VB = VA.GetAbs();\n        VC = Vector::NegativeMultiplySubtract(VA, C1, VC);\n        reinterpret_cast<__m128i*>(&VB)[0] = _mm_cvttps_epi32(VB);\n\n        Vector VC2 = VC * VC;\n        Vector T7 = tanCoefficients1.GetSplatW();\n        Vector T6 = tanCoefficients1.GetSplatZ();\n        Vector T4 = tanCoefficients1.GetSplatX();\n        Vector T3 = tanCoefficients0.GetSplatW();\n        Vector T5 = tanCoefficients1.GetSplatY();\n        Vector T2 = tanCoefficients0.GetSplatZ();\n        Vector T1 = tanCoefficients0.GetSplatY();\n        Vector T0 = tanCoefficients0.GetSplatX();\n\n        Vector VBIsEven = _mm_and_ps(VB, mask);\n        VBIsEven = _mm_castsi128_ps(_mm_cmpeq_epi32(_mm_castps_si128(VBIsEven), _mm_castps_si128(Vector::Zero)));\n\n        Vector N = Vector::MultiplyAdd(VC2, T7, 
T6);\n        Vector D = Vector::MultiplyAdd(VC2, T4, T3);\n        N = Vector::MultiplyAdd(VC2, N, T5);\n        D = Vector::MultiplyAdd(VC2, D, T2);\n        N = VC2 * N;\n        D = Vector::MultiplyAdd(VC2, D, T1);\n        N = Vector::MultiplyAdd(VC, N, VC);\n        Vector VCNearZero = VC.InBounds(vEpsilon);\n        D = Vector::MultiplyAdd(VC2, D, T0);\n\n        N = Vector::Select(N, VC, VCNearZero);\n        D = Vector::Select(D, Vector::One, VCNearZero);\n\n        Vector R0 = N.GetNegated();\n        Vector R1 = N / D;\n        R0 = D / R0;\n\n        Vector VIsZero = vec.EqualsZero();\n        Vector Result = Vector::Select(R0, R1, VBIsEven);\n        Result = Vector::Select(Result, Zero, VIsZero);\n\n        return Result;\n    }\n\n    FORCE_INLINE Vector Vector::ASin(const Vector& vec)\n    {\n        __m128 nonnegative = _mm_cmpge_ps(vec, Vector::Zero);\n        __m128 mvalue = _mm_sub_ps(Vector::Zero, vec);\n        __m128 m_x = _mm_max_ps(vec, mvalue);  // |vec|\n\n        // Compute (1-|vec|), clamp to zero to avoid sqrt of negative number.\n        __m128 oneMValue = _mm_sub_ps(Vector::One, m_x);\n        __m128 clampOneMValue = _mm_max_ps(Vector::Zero, oneMValue);\n        __m128 root = _mm_sqrt_ps(clampOneMValue);  // sqrt(1-|vec|)\n\n        // Compute polynomial approximation\n        const auto AC1 = SIMD::g_arcCoefficients1;\n        auto vConstants = _mm_shuffle_ps(AC1, AC1, _MM_SHUFFLE(3, 3, 3, 3));\n        __m128 t0 = _mm_mul_ps(vConstants, m_x);\n\n        vConstants = _mm_shuffle_ps(AC1, AC1, _MM_SHUFFLE(2, 2, 2, 2));\n        t0 = _mm_add_ps(t0, vConstants);\n        t0 = _mm_mul_ps(t0, m_x);\n\n        vConstants = _mm_shuffle_ps(AC1, AC1, _MM_SHUFFLE(1, 1, 1, 1));\n        t0 = _mm_add_ps(t0, vConstants);\n        t0 = _mm_mul_ps(t0, m_x);\n\n        vConstants = _mm_shuffle_ps(AC1, AC1, _MM_SHUFFLE(0, 0, 0, 0));\n        t0 = _mm_add_ps(t0, vConstants);\n        t0 = _mm_mul_ps(t0, m_x);\n\n        const auto AC0 = 
SIMD::g_arcCoefficients0;\n        vConstants = _mm_shuffle_ps(AC0, AC0, _MM_SHUFFLE(3, 3, 3, 3));\n        t0 = _mm_add_ps(t0, vConstants);\n        t0 = _mm_mul_ps(t0, m_x);\n\n        vConstants = _mm_shuffle_ps(AC0, AC0, _MM_SHUFFLE(2, 2, 2, 2));\n        t0 = _mm_add_ps(t0, vConstants);\n        t0 = _mm_mul_ps(t0, m_x);\n\n        vConstants = _mm_shuffle_ps(AC0, AC0, _MM_SHUFFLE(1, 1, 1, 1));\n        t0 = _mm_add_ps(t0, vConstants);\n        t0 = _mm_mul_ps(t0, m_x);\n\n        vConstants = _mm_shuffle_ps(AC0, AC0, _MM_SHUFFLE(0, 0, 0, 0));\n        t0 = _mm_add_ps(t0, vConstants);\n        t0 = _mm_mul_ps(t0, root);\n\n        __m128 t1 = _mm_sub_ps(Vector::Pi, t0);\n        t0 = _mm_and_ps(nonnegative, t0);\n        t1 = _mm_andnot_ps(nonnegative, t1);\n        t0 = _mm_or_ps(t0, t1);\n        t0 = _mm_sub_ps(Vector::PiDivTwo, t0);\n        return t0;\n    }\n\n    FORCE_INLINE Vector Vector::ACos(const Vector& vec)\n    {\n        __m128 nonnegative = _mm_cmpge_ps(vec, Vector::Zero);\n        __m128 mvalue = _mm_sub_ps(Vector::Zero, vec);\n        __m128 m_x = _mm_max_ps(vec, mvalue);  // |vec|\n\n        // Compute (1-|vec|), clamp to zero to avoid sqrt of negative number.\n        __m128 oneMValue = _mm_sub_ps(Vector::One, m_x);\n        __m128 clampOneMValue = _mm_max_ps(Vector::Zero, oneMValue);\n        __m128 root = _mm_sqrt_ps(clampOneMValue);  // sqrt(1-|vec|)\n\n        // Compute polynomial approximation\n        const auto AC1 = SIMD::g_arcCoefficients1;\n        auto vConstants = _mm_shuffle_ps(AC1, AC1, _MM_SHUFFLE(3, 3, 3, 3));\n        __m128 t0 = _mm_mul_ps(vConstants, m_x);\n\n        vConstants = _mm_shuffle_ps(AC1, AC1, _MM_SHUFFLE(2, 2, 2, 2));\n        t0 = _mm_add_ps(t0, vConstants);\n        t0 = _mm_mul_ps(t0, m_x);\n\n        vConstants = _mm_shuffle_ps(AC1, AC1, _MM_SHUFFLE(1, 1, 1, 1));\n        t0 = _mm_add_ps(t0, vConstants);\n        t0 = _mm_mul_ps(t0, m_x);\n\n        vConstants = _mm_shuffle_ps(AC1, AC1, _MM_SHUFFLE(0, 0, 
0, 0));\n        t0 = _mm_add_ps(t0, vConstants);\n        t0 = _mm_mul_ps(t0, m_x);\n\n        const auto AC0 = SIMD::g_arcCoefficients0;\n        vConstants = _mm_shuffle_ps(AC0, AC0, _MM_SHUFFLE(3, 3, 3, 3));\n        t0 = _mm_add_ps(t0, vConstants);\n        t0 = _mm_mul_ps(t0, m_x);\n\n        vConstants = _mm_shuffle_ps(AC0, AC0, _MM_SHUFFLE(2, 2, 2, 2));\n        t0 = _mm_add_ps(t0, vConstants);\n        t0 = _mm_mul_ps(t0, m_x);\n\n        vConstants = _mm_shuffle_ps(AC0, AC0, _MM_SHUFFLE(1, 1, 1, 1));\n        t0 = _mm_add_ps(t0, vConstants);\n        t0 = _mm_mul_ps(t0, m_x);\n\n        vConstants = _mm_shuffle_ps(AC0, AC0, _MM_SHUFFLE(0, 0, 0, 0));\n        t0 = _mm_add_ps(t0, vConstants);\n        t0 = _mm_mul_ps(t0, root);\n\n        __m128 t1 = _mm_sub_ps(Vector::Pi, t0);\n        t0 = _mm_and_ps(nonnegative, t0);\n        t1 = _mm_andnot_ps(nonnegative, t1);\n        t0 = _mm_or_ps(t0, t1);\n        return t0;\n    }\n\n    FORCE_INLINE Vector Vector::ATan(const Vector& vec)\n    {\n        __m128 absV = vec.GetAbs();\n        __m128 invV = _mm_div_ps(Vector::One, vec);\n        __m128 comp = _mm_cmpgt_ps(vec, Vector::One);\n        __m128 select0 = _mm_and_ps(comp, Vector::One);\n        __m128 select1 = _mm_andnot_ps(comp, Vector::NegativeOne);\n        __m128 sign = _mm_or_ps(select0, select1);\n        comp = _mm_cmple_ps(absV, Vector::One);\n        select0 = _mm_and_ps(comp, Vector::Zero);\n        select1 = _mm_andnot_ps(comp, sign);\n        sign = _mm_or_ps(select0, select1);\n        select0 = _mm_and_ps(comp, vec);\n        select1 = _mm_andnot_ps(comp, invV);\n        __m128 m_x = _mm_or_ps(select0, select1);\n\n        __m128 x2 = _mm_mul_ps(m_x, m_x);\n\n        // Compute polynomial approximation\n        Vector const TC1 = SIMD::g_aTanCoefficients1;\n        Vector vConstants = _mm_shuffle_ps(TC1, TC1, _MM_SHUFFLE(3, 3, 3, 3));\n        __m128 Result = _mm_mul_ps(vConstants, x2);\n\n        vConstants = _mm_shuffle_ps(TC1, TC1, 
_MM_SHUFFLE(2, 2, 2, 2));\n        Result = _mm_add_ps(Result, vConstants);\n        Result = _mm_mul_ps(Result, x2);\n\n        vConstants = _mm_shuffle_ps(TC1, TC1, _MM_SHUFFLE(1, 1, 1, 1));\n        Result = _mm_add_ps(Result, vConstants);\n        Result = _mm_mul_ps(Result, x2);\n\n        vConstants = _mm_shuffle_ps(TC1, TC1, _MM_SHUFFLE(0, 0, 0, 0));\n        Result = _mm_add_ps(Result, vConstants);\n        Result = _mm_mul_ps(Result, x2);\n\n        Vector const TC0 = SIMD::g_aTanCoefficients0;\n        vConstants = _mm_shuffle_ps(TC0, TC0, _MM_SHUFFLE(3, 3, 3, 3));\n        Result = _mm_add_ps(Result, vConstants);\n        Result = _mm_mul_ps(Result, x2);\n\n        vConstants = _mm_shuffle_ps(TC0, TC0, _MM_SHUFFLE(2, 2, 2, 2));\n        Result = _mm_add_ps(Result, vConstants);\n        Result = _mm_mul_ps(Result, x2);\n\n        vConstants = _mm_shuffle_ps(TC0, TC0, _MM_SHUFFLE(1, 1, 1, 1));\n        Result = _mm_add_ps(Result, vConstants);\n        Result = _mm_mul_ps(Result, x2);\n\n        vConstants = _mm_shuffle_ps(TC0, TC0, _MM_SHUFFLE(0, 0, 0, 0));\n        Result = _mm_add_ps(Result, vConstants);\n        Result = _mm_mul_ps(Result, x2);\n        Result = _mm_add_ps(Result, Vector::One);\n        Result = _mm_mul_ps(Result, m_x);\n        __m128 result1 = _mm_mul_ps(sign, Vector::PiDivTwo);\n        result1 = _mm_sub_ps(result1, Result);\n\n        comp = _mm_cmpeq_ps(sign, Vector::Zero);\n        select0 = _mm_and_ps(comp, Result);\n        select1 = _mm_andnot_ps(comp, result1);\n        Result = _mm_or_ps(select0, select1);\n        return Result;\n    }\n\n    FORCE_INLINE Vector Vector::ATan2(const Vector& Y, const Vector& X)\n    {\n        Vector ATanResultValid = Vector(SIMD::g_trueMask);\n\n        Vector vPi = Vector(SIMD::g_aTan2Constants).GetSplatX();\n        Vector vPiOverTwo = Vector(SIMD::g_aTan2Constants).GetSplatY();\n        Vector vPiOverFour = Vector(SIMD::g_aTan2Constants).GetSplatZ();\n        Vector vThreePiOverFour = 
Vector(SIMD::g_aTan2Constants).GetSplatW();\n\n        Vector YEqualsZero = Y.EqualsZero();\n        Vector XEqualsZero = X.EqualsZero();\n        Vector XIsPositive = _mm_and_ps(X, SIMD::g_signMask);\n        XIsPositive = _mm_castsi128_ps(_mm_cmpeq_epi32(_mm_castps_si128(XIsPositive), _mm_castps_si128(Vector::Zero)));\n        Vector YEqualsInfinity = Y.EqualsInfinity();\n        Vector XEqualsInfinity = X.EqualsInfinity();\n\n        Vector YSign = _mm_and_ps(Y, SIMD::g_signMask);\n        vPi = _mm_castsi128_ps(_mm_or_si128(_mm_castps_si128(vPi), _mm_castps_si128(YSign)));\n        vPiOverTwo = _mm_castsi128_ps(_mm_or_si128(_mm_castps_si128(vPiOverTwo), _mm_castps_si128(YSign)));\n        vPiOverFour = _mm_castsi128_ps(_mm_or_si128(_mm_castps_si128(vPiOverFour), _mm_castps_si128(YSign)));\n        vThreePiOverFour = _mm_castsi128_ps(_mm_or_si128(_mm_castps_si128(vThreePiOverFour), _mm_castps_si128(YSign)));\n\n        Vector R1 = Vector::Select(vPi, YSign, XIsPositive);\n        Vector R2 = Vector::Select(ATanResultValid, vPiOverTwo, XEqualsZero);\n        Vector R3 = Vector::Select(R2, R1, YEqualsZero);\n        Vector R4 = Vector::Select(vThreePiOverFour, vPiOverFour, XIsPositive);\n        Vector R5 = Vector::Select(vPiOverTwo, R4, XEqualsInfinity);\n        Vector Result = Vector::Select(R3, R5, YEqualsInfinity);\n        ATanResultValid = _mm_castsi128_ps(_mm_cmpeq_epi32(_mm_castps_si128(Result), _mm_castps_si128(ATanResultValid)));\n\n        Vector V = Y / X;\n        Vector R0 = Vector::ATan(V);\n        R1 = Vector::Select(vPi, Vector(SIMD::g_signMask), XIsPositive);\n        R2 = R0 + R1;\n\n        return Vector::Select(Result, R2, ATanResultValid);\n    }\n\n    FORCE_INLINE Vector Vector::SinEst(const Vector& vec)\n    {\n        // Force the value within the bounds of pi\n        auto m_x = Vector::AngleMod2Pi(vec);\n\n        // Map in [-pi/2,pi/2] with sin(m_y) = sin(m_x).\n        __m128 sign = _mm_and_ps(m_x, SIMD::g_signMask);\n        __m128 
c = _mm_or_ps(Vector::Pi, sign);  // pi when m_x >= 0, -pi when m_x < 0\n        __m128 absx = _mm_andnot_ps(sign, m_x);  // |m_x|\n        __m128 rflx = _mm_sub_ps(c, m_x);\n        __m128 comp = _mm_cmple_ps(absx, Vector::PiDivTwo);\n        __m128 select0 = _mm_and_ps(comp, m_x);\n        __m128 select1 = _mm_andnot_ps(comp, rflx);\n        m_x = _mm_or_ps(select0, select1);\n\n        __m128 x2 = _mm_mul_ps(m_x, m_x);\n\n        // Compute polynomial approximation\n        const auto SEC = SIMD::g_sinCoefficients1;\n        auto vConstants = _mm_shuffle_ps(SEC, SEC, _MM_SHUFFLE(3, 3, 3, 3));\n        __m128 Result = _mm_mul_ps(vConstants, x2);\n\n        vConstants = _mm_shuffle_ps(SEC, SEC, _MM_SHUFFLE(2, 2, 2, 2));\n        Result = _mm_add_ps(Result, vConstants);\n        Result = _mm_mul_ps(Result, x2);\n\n        vConstants = _mm_shuffle_ps(SEC, SEC, _MM_SHUFFLE(1, 1, 1, 1));\n        Result = _mm_add_ps(Result, vConstants);\n        Result = _mm_mul_ps(Result, x2);\n\n        Result = _mm_add_ps(Result, Vector::One);\n        Result = _mm_mul_ps(Result, m_x);\n        return Result;\n    }\n\n    FORCE_INLINE Vector Vector::CosEst(const Vector& vec)\n    {\n        // Map V to m_x in [-pi,pi].\n        auto m_x = Vector::AngleMod2Pi(vec);\n\n        // Map in [-pi/2,pi/2] with cos(m_y) = sign*cos(m_x).\n        auto sign = _mm_and_ps(m_x, SIMD::g_signMask);\n        __m128 c = _mm_or_ps(Vector::Pi, sign);  // pi when m_x >= 0, -pi when m_x < 0\n        __m128 absx = _mm_andnot_ps(sign, m_x);  // |m_x|\n        __m128 rflx = _mm_sub_ps(c, m_x);\n        __m128 comp = _mm_cmple_ps(absx, Vector::PiDivTwo);\n        __m128 select0 = _mm_and_ps(comp, m_x);\n        __m128 select1 = _mm_andnot_ps(comp, rflx);\n        m_x = _mm_or_ps(select0, select1);\n        select0 = _mm_and_ps(comp, Vector::One);\n        select1 = _mm_andnot_ps(comp, Vector::NegativeOne);\n        sign = _mm_or_ps(select0, select1);\n\n        __m128 x2 = _mm_mul_ps(m_x, m_x);\n\n        
// Compute polynomial approximation\n        const auto CEC = SIMD::g_cosCoefficients1;\n        auto vConstants = _mm_shuffle_ps(CEC, CEC, _MM_SHUFFLE(3, 3, 3, 3));\n        __m128 Result = _mm_mul_ps(vConstants, x2);\n\n        vConstants = _mm_shuffle_ps(CEC, CEC, _MM_SHUFFLE(2, 2, 2, 2));\n        Result = _mm_add_ps(Result, vConstants);\n        Result = _mm_mul_ps(Result, x2);\n\n        vConstants = _mm_shuffle_ps(CEC, CEC, _MM_SHUFFLE(1, 1, 1, 1));\n        Result = _mm_add_ps(Result, vConstants);\n        Result = _mm_mul_ps(Result, x2);\n\n        Result = _mm_add_ps(Result, Vector::One);\n        Result = _mm_mul_ps(Result, sign);\n        return Result;\n    }\n\n    FORCE_INLINE Vector Vector::TanEst(const Vector& vec)\n    {\n        Vector W = Vector(SIMD::g_tanEstCoefficients).GetSplatW();\n        Vector V1 = (vec * W).Round();\n        V1 = Vector::NegativeMultiplySubtract(Vector::Pi, V1, vec);\n\n        Vector const T0 = Vector(SIMD::g_tanEstCoefficients).GetSplatX();\n        Vector const T1 = Vector(SIMD::g_tanEstCoefficients).GetSplatY();\n        Vector const T2 = Vector(SIMD::g_tanEstCoefficients).GetSplatZ();\n\n        auto V2T2 = Vector::NegativeMultiplySubtract(V1, V1, T2);\n        auto V2 = V1 * V1;\n        auto V1T0 = V1 * T0;\n        auto V1T1 = V1 * T1;\n\n        auto N = Vector::MultiplyAdd(V2, V1T1, V1T0);\n        auto D = V2T2.GetInverseEst();\n        return N * D;\n    }\n\n    FORCE_INLINE Vector Vector::ASinEst(const Vector& vec)\n    {\n        __m128 nonnegative = _mm_cmpge_ps(vec, Vector::Zero);\n        __m128 mvalue = _mm_sub_ps(Vector::Zero, vec);\n        __m128 m_x = _mm_max_ps(vec, mvalue);  // |vec|\n\n        // Compute (1-|vec|), clamp to zero to avoid sqrt of negative number.\n        __m128 oneMValue = _mm_sub_ps(Vector::One, m_x);\n        __m128 clampOneMValue = _mm_max_ps(Vector::Zero, oneMValue);\n        __m128 root = _mm_sqrt_ps(clampOneMValue);  // sqrt(1-|vec|)\n\n        // Compute polynomial 
approximation\n        const auto AEC = SIMD::g_arcEstCoefficients;\n        auto vConstants = _mm_shuffle_ps(AEC, AEC, _MM_SHUFFLE(3, 3, 3, 3));\n        __m128 t0 = _mm_mul_ps(vConstants, m_x);\n\n        vConstants = _mm_shuffle_ps(AEC, AEC, _MM_SHUFFLE(2, 2, 2, 2));\n        t0 = _mm_add_ps(t0, vConstants);\n        t0 = _mm_mul_ps(t0, m_x);\n\n        vConstants = _mm_shuffle_ps(AEC, AEC, _MM_SHUFFLE(1, 1, 1, 1));\n        t0 = _mm_add_ps(t0, vConstants);\n        t0 = _mm_mul_ps(t0, m_x);\n\n        vConstants = _mm_shuffle_ps(AEC, AEC, _MM_SHUFFLE(0, 0, 0, 0));\n        t0 = _mm_add_ps(t0, vConstants);\n        t0 = _mm_mul_ps(t0, root);\n\n        __m128 t1 = _mm_sub_ps(Vector::Pi, t0);\n        t0 = _mm_and_ps(nonnegative, t0);\n        t1 = _mm_andnot_ps(nonnegative, t1);\n        t0 = _mm_or_ps(t0, t1);\n        t0 = _mm_sub_ps(Vector::PiDivTwo, t0);\n        return t0;\n    }\n\n    FORCE_INLINE Vector Vector::ACosEst(const Vector& vec)\n    {\n        __m128 nonnegative = _mm_cmpge_ps(vec, Vector::Zero);\n        __m128 mvalue = _mm_sub_ps(Vector::Zero, vec);\n        __m128 m_x = _mm_max_ps(vec, mvalue);  // |vec|\n\n        // Compute (1-|vec|), clamp to zero to avoid sqrt of negative number.\n        __m128 oneMValue = _mm_sub_ps(Vector::One, m_x);\n        __m128 clampOneMValue = _mm_max_ps(Vector::Zero, oneMValue);\n        __m128 root = _mm_sqrt_ps(clampOneMValue);  // sqrt(1-|vec|)\n\n        // Compute polynomial approximation\n        auto vConstants = _mm_shuffle_ps(SIMD::g_arcEstCoefficients, SIMD::g_arcEstCoefficients, _MM_SHUFFLE(3, 3, 3, 3));\n        __m128 t0 = _mm_mul_ps(vConstants, m_x);\n\n        vConstants = _mm_shuffle_ps(SIMD::g_arcEstCoefficients, SIMD::g_arcEstCoefficients, _MM_SHUFFLE(2, 2, 2, 2));\n        t0 = _mm_add_ps(t0, vConstants);\n        t0 = _mm_mul_ps(t0, m_x);\n\n        vConstants = _mm_shuffle_ps(SIMD::g_arcEstCoefficients, SIMD::g_arcEstCoefficients, _MM_SHUFFLE(1, 1, 1, 1));\n        t0 = _mm_add_ps(t0, 
vConstants);\n        t0 = _mm_mul_ps(t0, m_x);\n\n        vConstants = _mm_shuffle_ps(SIMD::g_arcEstCoefficients, SIMD::g_arcEstCoefficients, _MM_SHUFFLE(0, 0, 0, 0));\n        t0 = _mm_add_ps(t0, vConstants);\n        t0 = _mm_mul_ps(t0, root);\n\n        __m128 t1 = _mm_sub_ps(Vector::Pi, t0);\n        t0 = _mm_and_ps(nonnegative, t0);\n        t1 = _mm_andnot_ps(nonnegative, t1);\n        t0 = _mm_or_ps(t0, t1);\n        return t0;\n    }\n\n    FORCE_INLINE Vector Vector::ATanEst(const Vector& vec)\n    {\n        __m128 absV = vec.GetAbs();\n        __m128 invV = _mm_div_ps(Vector::One, vec);\n        __m128 comp = _mm_cmpgt_ps(vec, Vector::One);\n        __m128 select0 = _mm_and_ps(comp, Vector::One);\n        __m128 select1 = _mm_andnot_ps(comp, Vector::NegativeOne);\n        __m128 sign = _mm_or_ps(select0, select1);\n        comp = _mm_cmple_ps(absV, Vector::One);\n        select0 = _mm_and_ps(comp, Vector::Zero);\n        select1 = _mm_andnot_ps(comp, sign);\n        sign = _mm_or_ps(select0, select1);\n        select0 = _mm_and_ps(comp, vec);\n        select1 = _mm_andnot_ps(comp, invV);\n        __m128 m_x = _mm_or_ps(select0, select1);\n\n        __m128 x2 = _mm_mul_ps(m_x, m_x);\n\n        // Compute polynomial approximation\n        Vector const AEC = SIMD::g_aTanEstCoefficients1;\n        Vector vConstants = _mm_shuffle_ps(AEC, AEC, _MM_SHUFFLE(3, 3, 3, 3));\n        __m128 Result = _mm_mul_ps(vConstants, x2);\n\n        vConstants = _mm_shuffle_ps(AEC, AEC, _MM_SHUFFLE(2, 2, 2, 2));\n        Result = _mm_add_ps(Result, vConstants);\n        Result = _mm_mul_ps(Result, x2);\n\n        vConstants = _mm_shuffle_ps(AEC, AEC, _MM_SHUFFLE(1, 1, 1, 1));\n        Result = _mm_add_ps(Result, vConstants);\n        Result = _mm_mul_ps(Result, x2);\n\n        vConstants = _mm_shuffle_ps(AEC, AEC, _MM_SHUFFLE(0, 0, 0, 0));\n        Result = _mm_add_ps(Result, vConstants);\n        Result = _mm_mul_ps(Result, x2);\n\n        // ATanEstCoefficients0 is already 
splatted\n        Result = _mm_add_ps(Result, SIMD::g_aTanEstCoefficients0);\n        Result = _mm_mul_ps(Result, m_x);\n        __m128 result1 = _mm_mul_ps(sign, Vector::PiDivTwo);\n        result1 = _mm_sub_ps(result1, Result);\n\n        comp = _mm_cmpeq_ps(sign, Vector::Zero);\n        select0 = _mm_and_ps(comp, Result);\n        select1 = _mm_andnot_ps(comp, result1);\n        Result = _mm_or_ps(select0, select1);\n        return Result;\n    }\n\n    FORCE_INLINE Vector Vector::ATan2Est(const Vector& X, const Vector& Y)\n    {\n        Vector ATanResultValid = Vector(SIMD::g_trueMask);\n\n        Vector vPi = Vector(SIMD::g_aTan2Constants).GetSplatX();\n        Vector vPiOverTwo = Vector(SIMD::g_aTan2Constants).GetSplatY();\n        Vector vPiOverFour = Vector(SIMD::g_aTan2Constants).GetSplatZ();\n        Vector vThreePiOverFour = Vector(SIMD::g_aTan2Constants).GetSplatW();\n\n        Vector YEqualsZero = Y.EqualsZero();\n        Vector XEqualsZero = X.EqualsZero();\n        Vector XIsPositive = _mm_and_ps(X, SIMD::g_signMask);\n        XIsPositive = _mm_castsi128_ps(_mm_cmpeq_epi32(_mm_castps_si128(XIsPositive), _mm_castps_si128(Vector::Zero)));\n        Vector YEqualsInfinity = Y.EqualsInfinity();\n        Vector XEqualsInfinity = X.EqualsInfinity();\n\n        Vector YSign = _mm_and_ps(Y, SIMD::g_signMask);\n        vPi = _mm_castsi128_ps(_mm_or_si128(_mm_castps_si128(vPi), _mm_castps_si128(YSign)));\n        vPiOverTwo = _mm_castsi128_ps(_mm_or_si128(_mm_castps_si128(vPiOverTwo), _mm_castps_si128(YSign)));\n        vPiOverFour = _mm_castsi128_ps(_mm_or_si128(_mm_castps_si128(vPiOverFour), _mm_castps_si128(YSign)));\n        vThreePiOverFour = _mm_castsi128_ps(_mm_or_si128(_mm_castps_si128(vThreePiOverFour), _mm_castps_si128(YSign)));\n\n        Vector R1 = Vector::Select(vPi, YSign, XIsPositive);\n        Vector R2 = Vector::Select(ATanResultValid, vPiOverTwo, XEqualsZero);\n        Vector R3 = Vector::Select(R2, R1, YEqualsZero);\n        Vector R4 = 
Vector::Select(vThreePiOverFour, vPiOverFour, XIsPositive);\n        Vector R5 = Vector::Select(vPiOverTwo, R4, XEqualsInfinity);\n        Vector Result = Vector::Select(R3, R5, YEqualsInfinity);\n        ATanResultValid = _mm_castsi128_ps(_mm_cmpeq_epi32(_mm_castps_si128(Result), _mm_castps_si128(ATanResultValid)));\n\n        Vector Reciprocal = X.GetInverseEst();\n        Vector V = Y * Reciprocal;\n        Vector R0 = Vector::ATanEst(V);\n\n        R1 = Vector::Select(vPi, Vector(SIMD::g_signMask), XIsPositive);\n        R2 = R0 + R1;\n        Result = Vector::Select(Result, R2, ATanResultValid);\n\n        return Result;\n    }\n\n    FORCE_INLINE void Vector::SinCos(Vector& sin, Vector& cos, float angle)\n    {\n        return SinCos(sin, cos, Vector(angle));\n    }\n\n    FORCE_INLINE void Vector::SinCos(Vector& sin, Vector& cos, const Vector& angle)\n    {\n        // Force the value within the bounds of pi\n        auto m_x = Vector::AngleMod2Pi(angle);\n\n        // Map in [-pi/2,pi/2] with sin(m_y) = sin(m_x), cos(m_y) = sign*cos(m_x).\n        auto sign = _mm_and_ps(m_x, SIMD::g_signMask);\n        __m128 c = _mm_or_ps(Vector::Pi, sign);  // pi when m_x >= 0, -pi when m_x < 0\n        __m128 absx = _mm_andnot_ps(sign, m_x);  // |m_x|\n        __m128 rflx = _mm_sub_ps(c, m_x);\n        __m128 comp = _mm_cmple_ps(absx, Vector::PiDivTwo);\n        __m128 select0 = _mm_and_ps(comp, m_x);\n        __m128 select1 = _mm_andnot_ps(comp, rflx);\n        m_x = _mm_or_ps(select0, select1);\n        select0 = _mm_and_ps(comp, Vector::One);\n        select1 = _mm_andnot_ps(comp, Vector::NegativeOne);\n        sign = _mm_or_ps(select0, select1);\n\n        __m128 x2 = _mm_mul_ps(m_x, m_x);\n\n        // Compute polynomial approximation of sine\n        const auto SC1 = SIMD::g_sinCoefficients1;\n        auto vConstants = _mm_shuffle_ps(SC1, SC1, _MM_SHUFFLE(0, 0, 0, 0));\n        __m128 Result = _mm_mul_ps(vConstants, x2);\n\n        const auto SC0 = 
SIMD::g_sinCoefficients0;\n        vConstants = _mm_shuffle_ps(SC0, SC0, _MM_SHUFFLE(3, 3, 3, 3));\n        Result = _mm_add_ps(Result, vConstants);\n        Result = _mm_mul_ps(Result, x2);\n\n        vConstants = _mm_shuffle_ps(SC0, SC0, _MM_SHUFFLE(2, 2, 2, 2));\n        Result = _mm_add_ps(Result, vConstants);\n        Result = _mm_mul_ps(Result, x2);\n\n        vConstants = _mm_shuffle_ps(SC0, SC0, _MM_SHUFFLE(1, 1, 1, 1));\n        Result = _mm_add_ps(Result, vConstants);\n        Result = _mm_mul_ps(Result, x2);\n\n        vConstants = _mm_shuffle_ps(SC0, SC0, _MM_SHUFFLE(0, 0, 0, 0));\n        Result = _mm_add_ps(Result, vConstants);\n        Result = _mm_mul_ps(Result, x2);\n        Result = _mm_add_ps(Result, Vector::One);\n        Result = _mm_mul_ps(Result, m_x);\n        sin = Result;\n\n        // Compute polynomial approximation of cosine\n        const auto CC1 = SIMD::g_cosCoefficients1;\n        vConstants = _mm_shuffle_ps(CC1, CC1, _MM_SHUFFLE(0, 0, 0, 0));\n        Result = _mm_mul_ps(vConstants, x2);\n\n        const auto CC0 = SIMD::g_cosCoefficients0;\n        vConstants = _mm_shuffle_ps(CC0, CC0, _MM_SHUFFLE(3, 3, 3, 3));\n        Result = _mm_add_ps(Result, vConstants);\n        Result = _mm_mul_ps(Result, x2);\n\n        vConstants = _mm_shuffle_ps(CC0, CC0, _MM_SHUFFLE(2, 2, 2, 2));\n        Result = _mm_add_ps(Result, vConstants);\n        Result = _mm_mul_ps(Result, x2);\n\n        vConstants = _mm_shuffle_ps(CC0, CC0, _MM_SHUFFLE(1, 1, 1, 1));\n        Result = _mm_add_ps(Result, vConstants);\n        Result = _mm_mul_ps(Result, x2);\n\n        vConstants = _mm_shuffle_ps(CC0, CC0, _MM_SHUFFLE(0, 0, 0, 0));\n        Result = _mm_add_ps(Result, vConstants);\n        Result = _mm_mul_ps(Result, x2);\n        Result = _mm_add_ps(Result, Vector::One);\n        Result = _mm_mul_ps(Result, sign);\n        cos = Result;\n    }\n\n    FORCE_INLINE Vector Vector::AngleMod2Pi(const Vector& angles)\n    {\n        // Modulo the range of the given 
angles such that -Pi <= Angles < Pi\n        Vector result = _mm_mul_ps(angles, Vector::OneDivTwoPi);\n        result.Round();\n        result = _mm_mul_ps(result, Vector::TwoPi);\n        result = _mm_sub_ps(angles, result);\n        return result;\n    }\n\n    FORCE_INLINE Vector::operator __m128& ()\n    {\n        return m_data;\n    }\n\n    FORCE_INLINE Vector::operator const __m128& () const\n    {\n        return m_data;\n    }\n\n    FORCE_INLINE Vector::Vector()\n    {\n    }\n\n    FORCE_INLINE Vector::Vector(Axis axis)\n    {\n        switch (axis)\n        {\n        case Axis::X: *this = Vector::UnitX; break;\n        case Axis::Y: *this = Vector::UnitY; break;\n        case Axis::Z: *this = Vector::UnitZ; break;\n        default: HALT(); break;\n        }\n    }\n\n    FORCE_INLINE Vector::Vector(ZeroInit_t)\n    {\n        memset(this, 0, sizeof(Vector));\n    }\n\n    FORCE_INLINE Vector::Vector(float v)\n    {\n        m_data = _mm_set1_ps(v);\n    }\n\n    FORCE_INLINE Vector::Vector(__m128 v)\n        : m_data(v)\n    {\n    }\n\n    FORCE_INLINE Vector::Vector(float ix, float iy, float iz, float iw)\n    {\n        m_data = _mm_set_ps(iw, iz, iy, ix);\n    }\n\n    FORCE_INLINE Vector::Vector(const Float2& v, float iz, float iw)\n    {\n        m_data = _mm_set_ps(iw, iz, v.m_y, v.m_x);\n    }\n\n    FORCE_INLINE Vector::Vector(const Float3& v, float iw)\n    {\n        m_data = _mm_set_ps(iw, v.m_z, v.m_y, v.m_x);\n    }\n\n    FORCE_INLINE Vector::Vector(const Float4& v)\n    {\n        m_data = _mm_loadu_ps(&v.m_x);\n    }\n\n    FORCE_INLINE Vector::Vector(const float* pValues)\n    {\n        m_data = _mm_loadu_ps(pValues);\n    }\n\n    FORCE_INLINE bool Vector::IsValid() const\n    {\n        return !IsNaN4() && !IsInfinite4();\n    }\n\n    FORCE_INLINE void Vector::Store(float* pValues) const\n    {\n        _mm_storeu_ps(pValues, m_data);\n    }\n\n    FORCE_INLINE void Vector::StoreFloat(float& value) const\n    {\n        
_mm_store_ss(&value, m_data);\n    }\n\n    FORCE_INLINE void Vector::StoreFloat2(Float2& value) const\n    {\n        auto yVec = _mm_shuffle_ps(m_data, m_data, _MM_SHUFFLE(1, 1, 1, 1));\n        _mm_store_ss(&value.m_x, m_data);\n        _mm_store_ss(&value.m_y, yVec);\n    }\n\n    FORCE_INLINE void Vector::StoreFloat3(Float3& value) const\n    {\n        auto yVec = _mm_shuffle_ps(m_data, m_data, _MM_SHUFFLE(1, 1, 1, 1));\n        auto zVec = _mm_shuffle_ps(m_data, m_data, _MM_SHUFFLE(2, 2, 2, 2));\n        _mm_store_ss(&value.m_x, m_data);\n        _mm_store_ss(&value.m_y, yVec);\n        _mm_store_ss(&value.m_z, zVec);\n    }\n\n    FORCE_INLINE void Vector::StoreFloat4(Float4& value) const\n    {\n        _mm_storeu_ps(&value.m_x, m_data);\n    }\n\n    FORCE_INLINE float Vector::ToFloat() const\n    {\n        float v;\n        StoreFloat(v);\n        return v;\n    }\n\n    FORCE_INLINE Float2 Vector::ToFloat2() const\n    {\n        Float2 v;\n        StoreFloat2(v);\n        return v;\n    }\n\n    FORCE_INLINE Float3 Vector::ToFloat3() const\n    {\n        Float3 v;\n        StoreFloat3(v);\n        return v;\n    }\n\n    FORCE_INLINE Float4 Vector::ToFloat4() const\n    {\n        Float4 v;\n        StoreFloat4(v);\n        return v;\n    }\n\n    FORCE_INLINE Vector::operator Float2() const\n    {\n        return ToFloat2();\n    }\n\n    FORCE_INLINE Vector::operator Float3() const\n    {\n        return ToFloat3();\n    }\n\n    FORCE_INLINE Vector::operator Float4() const\n    {\n        return ToFloat4();\n    }\n\n    FORCE_INLINE float Vector::GetX() const\n    {\n        return _mm_cvtss_f32(m_data);\n    }\n\n    FORCE_INLINE float Vector::GetY() const\n    {\n        auto vTemp = GetSplatY();\n        return _mm_cvtss_f32(vTemp);\n    }\n\n    FORCE_INLINE float Vector::GetZ() const\n    {\n        auto vTemp = GetSplatZ();\n        return _mm_cvtss_f32(vTemp);\n    }\n\n    FORCE_INLINE float Vector::GetW() const\n    {\n        auto vTemp 
= GetSplatW();\n        return _mm_cvtss_f32(vTemp);\n    }\n\n    FORCE_INLINE void Vector::SetX(float x)\n    {\n        m_data = _mm_move_ss(m_data, _mm_set_ss(x));\n    }\n\n    FORCE_INLINE void Vector::SetY(float y)\n    {\n        m_data = _mm_insert_ps(m_data, _mm_set_ss(y), 0x10);\n    }\n\n    FORCE_INLINE void Vector::SetZ(float z)\n    {\n        m_data = _mm_insert_ps(m_data, _mm_set_ss(z), 0x20);\n    }\n\n    FORCE_INLINE void Vector::SetW(float w)\n    {\n        m_data = _mm_insert_ps(m_data, _mm_set_ss(w), 0x30);\n    }\n\n    FORCE_INLINE float Vector::operator[](uint32_t i) const\n    {\n        ASSERT(i < 4);\n\n        switch (i)\n        {\n        case 0: return GetX(); break;\n        case 1: return GetY(); break;\n        case 2: return GetZ(); break;\n        case 3: return GetW(); break;\n        }\n\n        UNREACHABLE_CODE();\n        return 0.0f;\n    }\n\n    FORCE_INLINE bool Vector::IsW1() const\n    {\n        return GetSplatW().IsEqual4(Vector::One);\n    }\n\n    FORCE_INLINE bool Vector::IsW0() const\n    {\n        return GetSplatW().IsZero4();\n    }\n\n    FORCE_INLINE Vector& Vector::SetW0()\n    {\n        SetW(0.0f);\n        return *this;\n    }\n\n    FORCE_INLINE Vector& Vector::SetW1()\n    {\n        SetW(1.0f);\n        return *this;\n    }\n\n    FORCE_INLINE Vector Vector::GetWithW0() const\n    {\n        Vector v = *this;\n        v.SetW0();\n        return v;\n    }\n\n    FORCE_INLINE Vector Vector::GetWithW1() const\n    {\n        Vector v = *this;\n        v.SetW1();\n        return v;\n    }\n\n    FORCE_INLINE Vector Vector::Get2D() const\n    {\n        return Vector::Select(*this, Vector::Zero, Vector::Select0011);\n    }\n\n    FORCE_INLINE Vector Vector::Get3D() const\n    {\n        return Vector::Select(*this, Vector::Zero, Vector::Select0001);\n    }\n\n    FORCE_INLINE Vector Vector::operator+(const Vector& v) const\n    {\n        return _mm_add_ps(m_data, v);\n    }\n\n    FORCE_INLINE Vector& 
Vector::operator+=(const Vector& v)\n    {\n        m_data = _mm_add_ps(m_data, v);\n        return *this;\n    }\n\n    FORCE_INLINE Vector Vector::operator-(const Vector& v) const\n    {\n        return _mm_sub_ps(m_data, v);\n    }\n\n    FORCE_INLINE Vector& Vector::operator-=(const Vector& v)\n    {\n        m_data = _mm_sub_ps(m_data, v);\n        return *this;\n    }\n\n    FORCE_INLINE Vector Vector::operator*(const Vector& v) const\n    {\n        return _mm_mul_ps(m_data, v);\n    }\n\n    FORCE_INLINE Vector& Vector::operator*=(const Vector& v)\n    {\n        m_data = _mm_mul_ps(m_data, v);\n        return *this;\n    }\n\n    FORCE_INLINE Vector Vector::operator/(const Vector& v) const\n    {\n        return _mm_div_ps(m_data, v);\n    }\n\n    FORCE_INLINE Vector& Vector::operator/=(const Vector& v)\n    {\n        m_data = _mm_div_ps(m_data, v);\n        return *this;\n    }\n\n    FORCE_INLINE Vector Vector::operator*(float const f) const\n    {\n        return operator*(Vector(f));\n    }\n\n    FORCE_INLINE Vector& Vector::operator*=(float const f)\n    {\n        return operator*=(Vector(f));\n    }\n\n    FORCE_INLINE Vector Vector::operator/(float const f) const\n    {\n        return operator/(Vector(f));\n    }\n\n    FORCE_INLINE Vector& Vector::operator/=(float const f)\n    {\n        return operator/=(Vector(f));\n    }\n\n    FORCE_INLINE Vector Vector::operator-() const\n    {\n        return GetNegated();\n    }\n\n    FORCE_INLINE Vector Vector::Orthogonal2D() const\n    {\n        static Vector const negX(-1.0f, 1.0f, 1.0f, 1.0f);\n\n        Vector result;\n        result = _mm_shuffle_ps(*this, *this, _MM_SHUFFLE(3, 2, 0, 1));\n        result = _mm_mul_ps(result, negX);\n        return result;\n    }\n\n    FORCE_INLINE Vector Vector::Cross2(const Vector& other) const\n    {\n        Vector vResult = _mm_shuffle_ps(other.m_data, other.m_data, _MM_SHUFFLE(0, 1, 0, 1));\n        vResult = _mm_mul_ps(vResult, m_data);\n        Vector 
vTemp = vResult.GetSplatY();\n        vResult = _mm_sub_ss(vResult, vTemp);\n        vResult = vResult.GetSplatX();\n        return vResult;\n    }\n\n    FORCE_INLINE Vector Vector::Cross3(const Vector& other) const\n    {\n        auto vTemp1 = _mm_shuffle_ps(m_data, m_data, _MM_SHUFFLE(3, 0, 2, 1));\n        auto vTemp2 = _mm_shuffle_ps(other, other, _MM_SHUFFLE(3, 1, 0, 2));\n        Vector result = _mm_mul_ps(vTemp1, vTemp2);\n        vTemp1 = _mm_shuffle_ps(vTemp1, vTemp1, _MM_SHUFFLE(3, 0, 2, 1));\n        vTemp2 = _mm_shuffle_ps(vTemp2, vTemp2, _MM_SHUFFLE(3, 1, 0, 2));\n        vTemp1 = _mm_mul_ps(vTemp1, vTemp2);\n        result = _mm_sub_ps(result, vTemp1);\n        result = _mm_and_ps(result, SIMD::g_maskXYZ0);\n        return result;\n    }\n\n    FORCE_INLINE Vector Vector::Dot2(const Vector& other) const\n    {\n        // Perform the dot product on m_x and m_y\n        Vector result = _mm_mul_ps(m_data, other);\n        // vTemp has m_y splatted\n        auto vTemp = _mm_shuffle_ps(result, result, _MM_SHUFFLE(1, 1, 1, 1));\n        // m_x+m_y\n        result = _mm_add_ss(result, vTemp);\n        result = _mm_shuffle_ps(result, result, _MM_SHUFFLE(0, 0, 0, 0));\n        return result;\n    }\n\n    FORCE_INLINE Vector Vector::Dot3(const Vector& vOther) const\n    {\n        // Perform the dot product\n        auto vDot = _mm_mul_ps(m_data, vOther);\n        // m_x=Dot.vector4_f32[1], m_y=Dot.vector4_f32[2]\n        auto vTemp = _mm_shuffle_ps(vDot, vDot, _MM_SHUFFLE(2, 1, 2, 1));\n        // Result.vector4_f32[0] = m_x+m_y\n        vDot = _mm_add_ss(vDot, vTemp);\n        // m_x=Dot.vector4_f32[2]\n        vTemp = _mm_shuffle_ps(vTemp, vTemp, _MM_SHUFFLE(1, 1, 1, 1));\n        // Result.vector4_f32[0] = (m_x+m_y)+m_z\n        vDot = _mm_add_ss(vDot, vTemp);\n        // Splat m_x\n        Vector result = _mm_shuffle_ps(vDot, vDot, _MM_SHUFFLE(0, 0, 0, 0));\n        return result;\n    }\n\n    FORCE_INLINE Vector Vector::Dot4(const Vector& other) 
const\n    {\n        auto vTemp2 = other;\n        auto vTemp = _mm_mul_ps(m_data, vTemp2);\n        vTemp2 = _mm_shuffle_ps(vTemp2, vTemp, _MM_SHUFFLE(1, 0, 0, 0)); // Copy X to the Z position and Y to the W position\n        vTemp2 = _mm_add_ps(vTemp2, vTemp); // Add Z = X+Z; W = Y+W;\n        vTemp = _mm_shuffle_ps(vTemp, vTemp2, _MM_SHUFFLE(0, 3, 0, 0));  // Copy W to the Z position\n        vTemp = _mm_add_ps(vTemp, vTemp2); // Add Z and W together\n        return _mm_shuffle_ps(vTemp, vTemp, _MM_SHUFFLE(2, 2, 2, 2)); // Splat Z and return\n    }\n\n    FORCE_INLINE float Vector::GetDot2(const Vector& other) const\n    {\n        return Dot2(other).ToFloat();\n    }\n\n    FORCE_INLINE float Vector::GetDot3(const Vector& other) const\n    {\n        return Dot3(other).ToFloat();\n    }\n\n    FORCE_INLINE float Vector::GetDot4(const Vector& other) const\n    {\n        return Dot4(other).ToFloat();\n    }\n\n    FORCE_INLINE Vector Vector::ScalarProjection(const Vector& other) const\n    {\n        Vector const normalizedThis = GetNormalized3();\n        Vector const projection = other.Dot3(normalizedThis);\n        return projection;\n    }\n\n    FORCE_INLINE float Vector::GetScalarProjection(const Vector& other) const\n    {\n        return ScalarProjection(other).ToFloat();\n    }\n\n    FORCE_INLINE Vector Vector::VectorProjection(const Vector& other) const\n    {\n        Vector const normalizedThis = GetNormalized3();\n        Vector const dotOther = other.Dot3(normalizedThis);\n        Vector const projection = normalizedThis * dotOther;\n        return projection;\n    }\n\n    FORCE_INLINE Vector& Vector::Invert()\n    {\n        m_data = _mm_div_ps(Vector::One, m_data);\n        return *this;\n    }\n\n    FORCE_INLINE Vector Vector::GetInverse() const\n    {\n        return _mm_div_ps(Vector::One, m_data);\n    }\n\n    FORCE_INLINE Vector Vector::GetReciprocal() const\n    {\n        return GetInverse();\n    }\n\n    FORCE_INLINE Vector& 
Vector::InvertEst()\n    {\n        m_data = _mm_rcp_ps(m_data);\n        return *this;\n    }\n\n    FORCE_INLINE Vector Vector::GetInverseEst() const\n    {\n        return _mm_rcp_ps(m_data);\n    }\n\n    FORCE_INLINE Vector& Vector::Negate()\n    {\n        m_data = _mm_sub_ps(Vector::Zero, m_data);\n        return *this;\n    }\n\n    FORCE_INLINE Vector Vector::GetNegated() const\n    {\n        return _mm_sub_ps(Vector::Zero, m_data);\n    }\n\n    FORCE_INLINE Vector& Vector::Abs()\n    {\n        m_data = _mm_max_ps(_mm_sub_ps(Vector::Zero, m_data), m_data);\n        return *this;\n    }\n\n    FORCE_INLINE Vector Vector::GetAbs() const\n    {\n        return _mm_max_ps(_mm_sub_ps(Vector::Zero, m_data), m_data);\n    }\n\n    FORCE_INLINE Vector& Vector::Sqrt()\n    {\n        m_data = _mm_sqrt_ps(m_data);\n        return *this;\n    }\n\n    FORCE_INLINE Vector Vector::GetSqrt()\n    {\n        return _mm_sqrt_ps(m_data);\n    }\n\n    FORCE_INLINE Vector& Vector::ReciprocalSqrt()\n    {\n        m_data = _mm_div_ps(Vector::One, _mm_sqrt_ps(m_data));\n        return *this;\n    }\n\n    FORCE_INLINE Vector Vector::GetReciprocalSqrt()\n    {\n        return _mm_div_ps(Vector::One, _mm_sqrt_ps(m_data));\n    }\n\n    FORCE_INLINE Vector& Vector::EstimatedReciprocalSqrt()\n    {\n        m_data = _mm_rsqrt_ps(m_data);\n        return *this;\n    }\n\n    FORCE_INLINE Vector Vector::GetEstimatedReciprocalSqrt()\n    {\n        return _mm_rsqrt_ps(m_data);\n    }\n\n    FORCE_INLINE Vector& Vector::Normalize2()\n    {\n        // Perform the dot product on m_x and m_y only\n        auto vLengthSq = _mm_mul_ps(m_data, m_data);\n        auto vTemp = _mm_shuffle_ps(vLengthSq, vLengthSq, _MM_SHUFFLE(1, 1, 1, 1));\n        vLengthSq = _mm_add_ss(vLengthSq, vTemp);\n        vLengthSq = _mm_shuffle_ps(vLengthSq, vLengthSq, _MM_SHUFFLE(0, 0, 0, 0));\n        // Prepare for the division\n        auto vResult = _mm_sqrt_ps(vLengthSq);\n        // Create zero with a 
single instruction\n        auto vZeroMask = _mm_setzero_ps();\n        // Test for a divide by zero (Must be FP to detect -0.0)\n        vZeroMask = _mm_cmpneq_ps(vZeroMask, vResult);\n        // Failsafe on zero (Or epsilon) length planes\n        // If the length is infinity, set the elements to zero\n        vLengthSq = _mm_cmpneq_ps(vLengthSq, Vector::Infinity);\n        // Divide to perform the normalization\n        vResult = _mm_div_ps(m_data, vResult);\n        // Any that are infinity, set to zero\n        vResult = _mm_and_ps(vResult, vZeroMask);\n        // Select qnan or result based on infinite length\n        auto vTemp1 = _mm_andnot_ps(vLengthSq, Vector::QNaN);\n        auto vTemp2 = _mm_and_ps(vResult, vLengthSq);\n        m_data = _mm_or_ps(vTemp1, vTemp2);\n\n        *this = Select(*this, Vector::Zero, Select0011);\n\n        return *this;\n    }\n\n    FORCE_INLINE Vector& Vector::Normalize3()\n    {\n        // Perform the dot product on m_x,m_y and m_z only\n        auto vLengthSq = _mm_mul_ps(m_data, m_data);\n        auto vTemp = _mm_shuffle_ps(vLengthSq, vLengthSq, _MM_SHUFFLE(2, 1, 2, 1));\n        vLengthSq = _mm_add_ss(vLengthSq, vTemp);\n        vTemp = _mm_shuffle_ps(vTemp, vTemp, _MM_SHUFFLE(1, 1, 1, 1));\n        vLengthSq = _mm_add_ss(vLengthSq, vTemp);\n        vLengthSq = _mm_shuffle_ps(vLengthSq, vLengthSq, _MM_SHUFFLE(0, 0, 0, 0));\n        // Prepare for the division\n        auto vResult = _mm_sqrt_ps(vLengthSq);\n        // Create zero with a single instruction\n        auto vZeroMask = _mm_setzero_ps();\n        // Test for a divide by zero (Must be FP to detect -0.0)\n        vZeroMask = _mm_cmpneq_ps(vZeroMask, vResult);\n        // Failsafe on zero (Or epsilon) length planes\n        // If the length is infinity, set the elements to zero\n        vLengthSq = _mm_cmpneq_ps(vLengthSq, Vector::Infinity);\n        // Divide to perform the normalization\n        vResult = _mm_div_ps(m_data, vResult);\n        // Any that are 
infinity, set to zero\n        vResult = _mm_and_ps(vResult, vZeroMask);\n        // Select qnan or result based on infinite length\n        auto vTemp1 = _mm_andnot_ps(vLengthSq, Vector::QNaN);\n        auto vTemp2 = _mm_and_ps(vResult, vLengthSq);\n        m_data = _mm_or_ps(vTemp1, vTemp2);\n\n        *this = Select(*this, Vector::Zero, Select0001);\n\n        return *this;\n    }\n\n    FORCE_INLINE Vector& Vector::Normalize4()\n    {\n        // Perform the dot product on m_x,m_y,m_z and m_w\n        auto vLengthSq = _mm_mul_ps(m_data, m_data);\n        // vTemp has m_z and m_w\n        auto vTemp = _mm_shuffle_ps(vLengthSq, vLengthSq, _MM_SHUFFLE(3, 2, 3, 2));\n        // m_x+m_z, m_y+m_w\n        vLengthSq = _mm_add_ps(vLengthSq, vTemp);\n        // m_x+m_z,m_x+m_z,m_x+m_z,m_y+m_w\n        vLengthSq = _mm_shuffle_ps(vLengthSq, vLengthSq, _MM_SHUFFLE(1, 0, 0, 0));\n        // ??,??,m_y+m_w,m_y+m_w\n        vTemp = _mm_shuffle_ps(vTemp, vLengthSq, _MM_SHUFFLE(3, 3, 0, 0));\n        // ??,??,m_x+m_z+m_y+m_w,??\n        vLengthSq = _mm_add_ps(vLengthSq, vTemp);\n        // Splat the length\n        vLengthSq = _mm_shuffle_ps(vLengthSq, vLengthSq, _MM_SHUFFLE(2, 2, 2, 2));\n        // Prepare for the division\n        auto vResult = _mm_sqrt_ps(vLengthSq);\n        // Create zero with a single instruction\n        auto vZeroMask = _mm_setzero_ps();\n        // Test for a divide by zero (Must be FP to detect -0.0)\n        vZeroMask = _mm_cmpneq_ps(vZeroMask, vResult);\n        // Failsafe on zero (Or epsilon) length planes\n        // If the length is infinity, set the elements to zero\n        vLengthSq = _mm_cmpneq_ps(vLengthSq, Vector::Infinity);\n        // Divide to perform the normalization\n        vResult = _mm_div_ps(m_data, vResult);\n        // Any that are infinity, set to zero\n        vResult = _mm_and_ps(vResult, vZeroMask);\n        // Select qnan or result based on infinite length\n        auto vTemp1 = _mm_andnot_ps(vLengthSq, Vector::QNaN);\n   
     auto vTemp2 = _mm_and_ps(vResult, vLengthSq);\n        m_data = _mm_or_ps(vTemp1, vTemp2);\n\n        return *this;\n    }\n\n    FORCE_INLINE Vector Vector::GetNormalized2() const\n    {\n        Vector v = *this;\n        v.Normalize2();\n        return v;\n    }\n\n    FORCE_INLINE Vector Vector::GetNormalized3() const\n    {\n        Vector v = *this;\n        v.Normalize3();\n        return v;\n    }\n\n    FORCE_INLINE Vector Vector::GetNormalized4() const\n    {\n        Vector v = *this;\n        v.Normalize4();\n        return v;\n    }\n\n    FORCE_INLINE Vector& Vector::Floor()\n    {\n        Vector result;\n\n        // To handle NAN, INF and numbers greater than 8388608, use masking\n        __m128i vTest = _mm_and_si128(_mm_castps_si128(m_data), SIMD::g_absMask);\n        vTest = _mm_cmplt_epi32(vTest, SIMD::g_noFraction);\n        // Truncate\n        __m128i vInt = _mm_cvttps_epi32(m_data);\n        result = _mm_cvtepi32_ps(vInt);\n        __m128 vLarger = _mm_cmpgt_ps(result, m_data);\n        // 0 -> 0, 0xffffffff -> -1.0f\n        vLarger = _mm_cvtepi32_ps(_mm_castps_si128(vLarger));\n        result = _mm_add_ps(result, vLarger);\n        // All numbers less than 8388608 will use the round to int\n        result = _mm_and_ps(result, _mm_castsi128_ps(vTest));\n        // All others, use the ORIGINAL value\n        vTest = _mm_andnot_si128(vTest, _mm_castps_si128(m_data));\n        result = _mm_or_ps(result, _mm_castsi128_ps(vTest));\n\n        m_data = result;\n        return *this;\n    }\n\n    FORCE_INLINE Vector Vector::GetFloor() const\n    {\n        Vector v = *this;\n        v.Floor();\n        return v;\n    }\n\n    FORCE_INLINE Vector& Vector::Ceil()\n    {\n        Vector result;\n\n        // To handle NAN, INF and numbers greater than 8388608, use masking\n        __m128i vTest = _mm_and_si128(_mm_castps_si128(m_data), SIMD::g_absMask);\n        vTest = _mm_cmplt_epi32(vTest, SIMD::g_noFraction);\n        // Truncate\n        
__m128i vInt = _mm_cvttps_epi32(m_data);\n        result = _mm_cvtepi32_ps(vInt);\n        __m128 vSmaller = _mm_cmplt_ps(result, m_data);\n        // 0 -> 0, 0xffffffff -> -1.0f\n        vSmaller = _mm_cvtepi32_ps(_mm_castps_si128(vSmaller));\n        result = _mm_sub_ps(result, vSmaller);\n        // All numbers less than 8388608 will use the round to int\n        result = _mm_and_ps(result, _mm_castsi128_ps(vTest));\n        // All others, use the ORIGINAL value\n        vTest = _mm_andnot_si128(vTest, _mm_castps_si128(m_data));\n        result = _mm_or_ps(result, _mm_castsi128_ps(vTest));\n\n        m_data = result;\n        return *this;\n    }\n\n    FORCE_INLINE Vector Vector::GetCeil() const\n    {\n        Vector v = *this;\n        v.Ceil();\n        return v;\n    }\n\n    FORCE_INLINE Vector& Vector::Round()\n    {\n        __m128 sign = _mm_and_ps(m_data, SIMD::g_signMask);\n        __m128 sMagic = _mm_or_ps(SIMD::g_noFraction, sign);\n        __m128 R1 = _mm_add_ps(m_data, sMagic);\n        R1 = _mm_sub_ps(R1, sMagic);\n        __m128 R2 = _mm_and_ps(m_data, SIMD::g_absMask);\n        __m128 mask = _mm_cmple_ps(R2, SIMD::g_noFraction);\n        R2 = _mm_andnot_ps(mask, m_data);\n        R1 = _mm_and_ps(R1, mask);\n        m_data = _mm_xor_ps(R1, R2);\n        return  *this;\n    }\n\n    FORCE_INLINE Vector Vector::GetRound() const\n    {\n        Vector v = *this;\n        v.Round();\n        return v;\n    }\n\n    FORCE_INLINE Vector Vector::GetSign() const\n    {\n        Vector const selectMask = GreaterThanEqual(Vector::Zero);\n        return Vector::Select(Vector::NegativeOne, Vector::One, selectMask);\n    }\n\n    FORCE_INLINE Vector Vector::GetSplatX() const\n    {\n        return _mm_shuffle_ps(m_data, m_data, _MM_SHUFFLE(0, 0, 0, 0));\n    }\n\n    FORCE_INLINE Vector Vector::GetSplatY() const\n    {\n        return _mm_shuffle_ps(m_data, m_data, _MM_SHUFFLE(1, 1, 1, 1));\n    }\n\n    FORCE_INLINE Vector Vector::GetSplatZ() const\n    {\n 
       return _mm_shuffle_ps(m_data, m_data, _MM_SHUFFLE(2, 2, 2, 2));\n    }\n\n    FORCE_INLINE Vector Vector::GetSplatW() const\n    {\n        return _mm_shuffle_ps(m_data, m_data, _MM_SHUFFLE(3, 3, 3, 3));\n    }\n\n    template<uint32_t xIdx, uint32_t yIdx, uint32_t zIdx, uint32_t wIdx>\n    FORCE_INLINE Vector Vector::Swizzle() const\n    {\n        static_assert(xIdx < 4, \"Element index parameter out of range\");\n        static_assert(yIdx < 4, \"Element index parameter out of range\");\n        static_assert(zIdx < 4, \"Element index parameter out of range\");\n        static_assert(wIdx < 4, \"Element index parameter out of range\");\n        return _mm_shuffle_ps(m_data, m_data, _MM_SHUFFLE(wIdx, zIdx, yIdx, xIdx));\n    }\n\n    FORCE_INLINE Vector Vector::Swizzle(uint32_t xIdx, uint32_t yIdx, uint32_t zIdx, uint32_t wIdx) const\n    {\n        ASSERT(xIdx < 4 && yIdx < 4 && zIdx < 4 && wIdx < 4);\n        uint32_t const elem[4] = { xIdx, yIdx, zIdx, wIdx };\n        __m128i vControl = _mm_loadu_si128(reinterpret_cast<const __m128i*>(&elem[0]));\n        return _mm_permutevar_ps(m_data, vControl);\n    }\n\n    FORCE_INLINE Vector Vector::Shuffle(uint32_t xIdx, uint32_t yIdx, uint32_t zIdx, uint32_t wIdx) const\n    {\n        return Swizzle(xIdx, yIdx, zIdx, wIdx);\n    }\n\n    template<uint32_t xIdx, uint32_t yIdx, uint32_t zIdx, uint32_t wIdx>\n    FORCE_INLINE Vector Vector::Shuffle() const\n    {\n        return Swizzle<xIdx, yIdx, zIdx, wIdx>();\n    }\n\n    FORCE_INLINE Vector Vector::Length2() const\n    {\n        Vector result;\n\n        result = _mm_mul_ps(m_data, m_data);\n        auto vTemp = _mm_shuffle_ps(result, result, _MM_SHUFFLE(1, 1, 1, 1));\n        // m_x+m_y\n        result = _mm_add_ss(result, vTemp);\n        result = _mm_shuffle_ps(result, result, _MM_SHUFFLE(0, 0, 0, 0));\n        result = _mm_sqrt_ps(result);\n        return result;\n    }\n\n    FORCE_INLINE Vector Vector::Length3() const\n    {\n        Vector 
result;\n\n        // Perform the dot product on m_x,m_y and m_z\n        result = _mm_mul_ps(m_data, m_data);\n        // vTemp has m_z and m_y\n        auto vTemp = _mm_shuffle_ps(result, result, _MM_SHUFFLE(1, 2, 1, 2));\n        // m_x+m_z, m_y\n        result = _mm_add_ss(result, vTemp);\n        // m_y,m_y,m_y,m_y\n        vTemp = _mm_shuffle_ps(vTemp, vTemp, _MM_SHUFFLE(1, 1, 1, 1));\n        // m_x+m_z+m_y,??,??,??\n        result = _mm_add_ss(result, vTemp);\n        // Splat the length squared\n        result = _mm_shuffle_ps(result, result, _MM_SHUFFLE(0, 0, 0, 0));\n        // Get the length\n        result = _mm_sqrt_ps(result);\n\n        return result;\n    }\n\n    FORCE_INLINE Vector Vector::Length4() const\n    {\n        Vector result;\n\n        // Perform the dot product on m_x,m_y,m_z and m_w\n        result = _mm_mul_ps(m_data, m_data);\n        // vTemp has m_z and m_w\n        auto vTemp = _mm_shuffle_ps(result, result, _MM_SHUFFLE(3, 2, 3, 2));\n        // m_x+m_z, m_y+m_w\n        result = _mm_add_ps(result, vTemp);\n        // m_x+m_z,m_x+m_z,m_x+m_z,m_y+m_w\n        result = _mm_shuffle_ps(result, result, _MM_SHUFFLE(1, 0, 0, 0));\n        // ??,??,m_y+m_w,m_y+m_w\n        vTemp = _mm_shuffle_ps(vTemp, result, _MM_SHUFFLE(3, 3, 0, 0));\n        // ??,??,m_x+m_z+m_y+m_w,??\n        result = _mm_add_ps(result, vTemp);\n        // Splat the length\n        result = _mm_shuffle_ps(result, result, _MM_SHUFFLE(2, 2, 2, 2));\n        // Get the length\n        result = _mm_sqrt_ps(result);\n\n        return result;\n    }\n\n    FORCE_INLINE float Vector::GetLength2() const\n    {\n        return Length2().GetX();\n    }\n\n    FORCE_INLINE float Vector::GetLength3() const\n    {\n        return Length3().GetX();\n    }\n\n    FORCE_INLINE float Vector::GetLength4() const\n    {\n        return Length4().GetX();\n    }\n\n    FORCE_INLINE Vector Vector::InverseLength2() const\n    {\n        // Perform the dot product on m_x and m_y\n        
auto vLengthSq = _mm_mul_ps(m_data, m_data);\n        // vTemp has m_y splatted\n        auto vTemp = _mm_shuffle_ps(vLengthSq, vLengthSq, _MM_SHUFFLE(1, 1, 1, 1));\n        // m_x+m_y\n        vLengthSq = _mm_add_ss(vLengthSq, vTemp);\n        vLengthSq = _mm_sqrt_ss(vLengthSq);\n        vLengthSq = _mm_div_ss(Vector::One, vLengthSq);\n        vLengthSq = _mm_shuffle_ps(vLengthSq, vLengthSq, _MM_SHUFFLE(0, 0, 0, 0));\n        return vLengthSq;\n    }\n\n    FORCE_INLINE Vector Vector::InverseLength3() const\n    {\n        // Perform the dot product\n        auto vDot = _mm_mul_ps(m_data, m_data);\n        // m_x=Dot.m_y, m_y=Dot.m_z\n        auto vTemp = _mm_shuffle_ps(vDot, vDot, _MM_SHUFFLE(2, 1, 2, 1));\n        // Result.m_x = m_x+m_y\n        vDot = _mm_add_ss(vDot, vTemp);\n        // m_x=Dot.m_z\n        vTemp = _mm_shuffle_ps(vTemp, vTemp, _MM_SHUFFLE(1, 1, 1, 1));\n        // Result.m_x = (m_x+m_y)+m_z\n        vDot = _mm_add_ss(vDot, vTemp);\n        // Splat m_x\n        vDot = _mm_shuffle_ps(vDot, vDot, _MM_SHUFFLE(0, 0, 0, 0));\n        // Get the reciprocal\n        vDot = _mm_sqrt_ps(vDot);\n        // Get the reciprocal\n        vDot = _mm_div_ps(Vector::One, vDot);\n        return vDot;\n    }\n\n    FORCE_INLINE Vector Vector::InverseLength4() const\n    {\n        // Perform the dot product on m_x,m_y,m_z and m_w\n        auto vLengthSq = _mm_mul_ps(m_data, m_data);\n        // vTemp has m_z and m_w\n        auto vTemp = _mm_shuffle_ps(vLengthSq, vLengthSq, _MM_SHUFFLE(3, 2, 3, 2));\n        // m_x+m_z, m_y+m_w\n        vLengthSq = _mm_add_ps(vLengthSq, vTemp);\n        // m_x+m_z,m_x+m_z,m_x+m_z,m_y+m_w\n        vLengthSq = _mm_shuffle_ps(vLengthSq, vLengthSq, _MM_SHUFFLE(1, 0, 0, 0));\n        // ??,??,m_y+m_w,m_y+m_w\n        vTemp = _mm_shuffle_ps(vTemp, vLengthSq, _MM_SHUFFLE(3, 3, 0, 0));\n        // ??,??,m_x+m_z+m_y+m_w,??\n        vLengthSq = _mm_add_ps(vLengthSq, vTemp);\n        // Splat the length\n        vLengthSq = 
_mm_shuffle_ps(vLengthSq, vLengthSq, _MM_SHUFFLE(2, 2, 2, 2));\n        // Get the reciprocal\n        vLengthSq = _mm_sqrt_ps(vLengthSq);\n        // Accurate!\n        vLengthSq = _mm_div_ps(Vector::One, vLengthSq);\n        return vLengthSq;\n    }\n\n    FORCE_INLINE float Vector::GetInverseLength2() const\n    {\n        return InverseLength2().GetX();\n    }\n\n    FORCE_INLINE float Vector::GetInverseLength3() const\n    {\n        return InverseLength3().GetX();\n    }\n\n    FORCE_INLINE float Vector::GetInverseLength4() const\n    {\n        return InverseLength4().GetX();\n    }\n\n    FORCE_INLINE Vector Vector::LengthSquared2() const\n    {\n        return Vector::Dot2(m_data, m_data);\n    }\n\n    FORCE_INLINE Vector Vector::LengthSquared3() const\n    {\n        return Vector::Dot3(m_data, m_data);\n    }\n\n    FORCE_INLINE Vector Vector::LengthSquared4() const\n    {\n        return Vector::Dot4(m_data, m_data);\n    }\n\n    FORCE_INLINE float Vector::GetLengthSquared2() const\n    {\n        return LengthSquared2().GetX();\n    }\n\n    FORCE_INLINE float Vector::GetLengthSquared3() const\n    {\n        return LengthSquared3().GetX();\n    }\n\n    FORCE_INLINE float Vector::GetLengthSquared4() const\n    {\n        return LengthSquared4().GetX();\n    }\n\n    FORCE_INLINE Vector Vector::Distance2(const Vector& to) const\n    {\n        return (to - *this).Length2();\n    }\n\n    FORCE_INLINE Vector Vector::Distance3(const Vector& to) const\n    {\n        return (to - *this).Length3();\n    }\n\n    FORCE_INLINE Vector Vector::Distance4(const Vector& to) const\n    {\n        return (to - *this).Length4();\n    }\n\n    FORCE_INLINE float Vector::GetDistance2(const Vector& to) const\n    {\n        return (to - *this).Length2().GetX();\n    }\n\n    FORCE_INLINE float Vector::GetDistance3(const Vector& to) const\n    {\n        return (to - *this).Length3().GetX();\n    }\n\n    FORCE_INLINE float Vector::GetDistance4(const Vector& to) 
const\n    {\n        return (to - *this).Length4().GetX();\n    }\n\n    FORCE_INLINE Vector Vector::DistanceSquared2(const Vector& to) const\n    {\n        return (to - *this).LengthSquared2();\n    }\n\n    FORCE_INLINE Vector Vector::DistanceSquared3(const Vector& to) const\n    {\n        return (to - *this).LengthSquared3();\n    }\n\n    FORCE_INLINE Vector Vector::DistanceSquared4(const Vector& to) const\n    {\n        return (to - *this).LengthSquared4();\n    }\n\n    FORCE_INLINE float Vector::GetDistanceSquared2(const Vector& to) const\n    {\n        return (to - *this).GetLengthSquared2();\n    }\n\n    FORCE_INLINE float Vector::GetDistanceSquared3(const Vector& to) const\n    {\n        return (to - *this).GetLengthSquared3();\n    }\n\n    FORCE_INLINE float Vector::GetDistanceSquared4(const Vector& to) const\n    {\n        return (to - *this).GetLengthSquared4();\n    }\n\n    FORCE_INLINE bool Vector::IsNormalized2() const\n    {\n        return (LengthSquared2() - Vector::One).Abs().IsLessThanEqual4(Vector::NormalizeCheckThreshold);\n    }\n\n    FORCE_INLINE bool Vector::IsNormalized3() const\n    {\n        return (LengthSquared3() - Vector::One).Abs().IsLessThanEqual4(Vector::NormalizeCheckThreshold);\n    }\n\n    FORCE_INLINE bool Vector::IsNormalized4() const\n    {\n        return (LengthSquared4() - Vector::One).Abs().IsLessThanEqual4(Vector::NormalizeCheckThreshold);\n    }\n\n    FORCE_INLINE Vector Vector::InBounds(const Vector& bounds) const\n    {\n        // Test if less than or equal\n        auto vTemp1 = _mm_cmple_ps(m_data, bounds);\n        // Negate the bounds\n        auto vTemp2 = _mm_mul_ps(bounds, Vector::NegativeOne);\n        // Test if greater or equal (Reversed)\n        vTemp2 = _mm_cmple_ps(vTemp2, m_data);\n        // Blend answers\n        vTemp1 = _mm_and_ps(vTemp1, vTemp2);\n        return vTemp1;\n    }\n\n    FORCE_INLINE bool Vector::IsInBounds2(const Vector& bounds) const\n    {\n        return 
((_mm_movemask_ps(InBounds(bounds)) & 0x3) == 0x3) != 0;\n    }\n\n    FORCE_INLINE bool Vector::IsInBounds3(const Vector& bounds) const\n    {\n        return ((_mm_movemask_ps(InBounds(bounds)) & 0x7) == 0x7) != 0;\n    }\n\n    FORCE_INLINE bool Vector::IsInBounds4(const Vector& bounds) const\n    {\n        return (_mm_movemask_ps(InBounds(bounds)) == 0x0f) != 0;\n    }\n\n    FORCE_INLINE Vector Vector::Equal(const Vector& v) const\n    {\n        return _mm_cmpeq_ps(*this, v);\n    }\n\n    FORCE_INLINE bool Vector::IsEqual2(const Vector& v) const\n    {\n        return (((_mm_movemask_ps(Equal(v)) & 3) == 3) != 0);\n    }\n\n    FORCE_INLINE bool Vector::IsEqual3(const Vector& v) const\n    {\n        return (((_mm_movemask_ps(Equal(v)) & 7) == 7) != 0);\n    }\n\n    FORCE_INLINE bool Vector::IsEqual4(const Vector& v) const\n    {\n        return ((_mm_movemask_ps(Equal(v)) == 0x0f) != 0);\n    }\n\n    FORCE_INLINE Vector Vector::NearEqual(const Vector& v, const Vector& epsilon) const\n    {\n        // Get the difference\n        auto vDelta = _mm_sub_ps(m_data, v);\n        // Get the absolute value of the difference\n        auto vTemp = _mm_setzero_ps();\n        vTemp = _mm_sub_ps(vTemp, vDelta);\n        vTemp = _mm_max_ps(vTemp, vDelta);\n        vTemp = _mm_cmple_ps(vTemp, epsilon);\n        return vTemp;\n    }\n\n    FORCE_INLINE bool Vector::IsNearEqual2(const Vector& v, float epsilon) const\n    {\n        return IsNearEqual2(v, Vector(epsilon));\n    }\n\n    FORCE_INLINE bool Vector::IsNearEqual3(const Vector& v, float epsilon) const\n    {\n        return IsNearEqual3(v, Vector(epsilon));\n    }\n\n    FORCE_INLINE bool Vector::IsNearEqual4(const Vector& v, float epsilon) const\n    {\n        return IsNearEqual4(v, Vector(epsilon));\n    }\n\n    FORCE_INLINE bool Vector::IsNearEqual2(const Vector& v, const Vector& epsilon) const\n    {\n        return (((_mm_movemask_ps(NearEqual(v, epsilon)) & 3) == 0x3) != 0);\n    }\n\n    FORCE_INLINE 
bool Vector::IsNearEqual3(const Vector& v, const Vector& epsilon) const\n    {\n        return (((_mm_movemask_ps(NearEqual(v, epsilon)) & 7) == 0x7) != 0);\n    }\n\n    FORCE_INLINE bool Vector::IsNearEqual4(const Vector& v, const Vector& epsilon) const\n    {\n        return ((_mm_movemask_ps(NearEqual(v, epsilon)) == 0xf) != 0);\n    }\n\n    FORCE_INLINE Vector Vector::GreaterThan(const Vector& v) const\n    {\n        return _mm_cmpgt_ps(m_data, v);\n    }\n\n    FORCE_INLINE bool Vector::IsAnyGreaterThan(const Vector& v) const\n    {\n        return !GreaterThan(v).IsZero4();\n    }\n\n    FORCE_INLINE bool Vector::IsGreaterThan2(const Vector& v) const\n    {\n        return (((_mm_movemask_ps(GreaterThan(v)) & 3) == 3) != 0);\n    }\n\n    FORCE_INLINE bool Vector::IsGreaterThan3(const Vector& v) const\n    {\n        return (((_mm_movemask_ps(GreaterThan(v)) & 7) == 7) != 0);\n    }\n\n    FORCE_INLINE bool Vector::IsGreaterThan4(const Vector& v) const\n    {\n        return ((_mm_movemask_ps(GreaterThan(v)) == 0x0f) != 0);\n    }\n\n    FORCE_INLINE Vector Vector::GreaterThanEqual(const Vector& v) const\n    {\n        return _mm_cmpge_ps(m_data, v);\n    }\n\n    FORCE_INLINE bool Vector::IsAnyGreaterThanEqual(const Vector& v) const\n    {\n        return !GreaterThanEqual(v).IsZero4();\n    }\n\n    FORCE_INLINE bool Vector::IsGreaterThanEqual2(const Vector& v) const\n    {\n        return ((_mm_movemask_ps(GreaterThanEqual(v)) & 3) == 3) != 0;\n    }\n\n    FORCE_INLINE bool Vector::IsGreaterThanEqual3(const Vector& v) const\n    {\n        return ((_mm_movemask_ps(GreaterThanEqual(v)) & 7) == 7) != 0;\n    }\n\n    FORCE_INLINE bool Vector::IsGreaterThanEqual4(const Vector& v) const\n    {\n        return (_mm_movemask_ps(GreaterThanEqual(v)) == 0x0f) != 0;\n    }\n\n    FORCE_INLINE Vector Vector::LessThan(const Vector& v) const\n    {\n        return _mm_cmplt_ps(m_data, v);\n    }\n\n    FORCE_INLINE bool Vector::IsAnyLessThan(const Vector& v) 
const\n    {\n        return !LessThan(v).IsZero4();\n    }\n\n    FORCE_INLINE bool Vector::IsLessThan2(const Vector& v) const\n    {\n        return (((_mm_movemask_ps(LessThan(v)) & 3) == 3) != 0);\n    }\n\n    FORCE_INLINE bool Vector::IsLessThan3(const Vector& v) const\n    {\n        return (((_mm_movemask_ps(LessThan(v)) & 7) == 7) != 0);\n    }\n\n    FORCE_INLINE bool Vector::IsLessThan4(const Vector& v) const\n    {\n        return ((_mm_movemask_ps(LessThan(v)) == 0x0f) != 0);\n    }\n\n    FORCE_INLINE Vector Vector::LessThanEqual(const Vector& v) const\n    {\n        return _mm_cmple_ps(m_data, v);\n    }\n\n    FORCE_INLINE bool Vector::IsAnyLessThanEqual(const Vector& v) const\n    {\n        return !LessThanEqual(v).IsZero4();\n    }\n\n    FORCE_INLINE bool Vector::IsLessThanEqual2(const Vector& v) const\n    {\n        return (((_mm_movemask_ps(LessThanEqual(v)) & 3) == 3) != 0);\n    }\n\n    FORCE_INLINE bool Vector::IsLessThanEqual3(const Vector& v) const\n    {\n        return (((_mm_movemask_ps(LessThanEqual(v)) & 7) == 7) != 0);\n    }\n\n    FORCE_INLINE bool Vector::IsLessThanEqual4(const Vector& v) const\n    {\n        return ((_mm_movemask_ps(LessThanEqual(v)) == 0x0f) != 0);\n    }\n\n    FORCE_INLINE Vector Vector::EqualsZero() const\n    {\n        return Equal(Vector::Zero);\n    }\n\n    FORCE_INLINE bool Vector::IsAnyEqualToZero2() const\n    {\n        return !EqualsZero().IsZero2();\n    }\n\n    FORCE_INLINE bool Vector::IsAnyEqualToZero3() const\n    {\n        return !EqualsZero().IsZero3();\n    }\n\n    FORCE_INLINE bool Vector::IsAnyEqualToZero4() const\n    {\n        return !EqualsZero().IsZero4();\n    }\n\n    FORCE_INLINE bool Vector::IsZero2() const\n    {\n        return IsEqual2(Vector::Zero);\n    }\n\n    FORCE_INLINE bool Vector::IsZero3() const\n    {\n        return IsEqual3(Vector::Zero);\n    }\n\n    FORCE_INLINE bool Vector::IsZero4() const\n    {\n        return IsEqual4(Vector::Zero);\n    }\n\n    
FORCE_INLINE Vector Vector::NearEqualsZero(float epsilon) const\n    {\n        return NearEqual(Vector::Zero, Vector(epsilon));\n    }\n\n    FORCE_INLINE bool Vector::IsNearZero2(float epsilon) const\n    {\n        return IsNearEqual2(Vector::Zero, Vector(epsilon));\n    }\n\n    FORCE_INLINE bool Vector::IsNearZero3(float epsilon) const\n    {\n        return IsNearEqual3(Vector::Zero, Vector(epsilon));\n    }\n\n    FORCE_INLINE bool Vector::IsNearZero4(float epsilon) const\n    {\n        return IsNearEqual4(Vector::Zero, Vector(epsilon));\n    }\n\n    FORCE_INLINE Vector Vector::EqualsInfinity() const\n    {\n        __m128 vTemp = _mm_and_ps(m_data, SIMD::g_absMask);\n        return _mm_cmpeq_ps(vTemp, Vector::Infinity);\n    }\n\n    FORCE_INLINE bool Vector::IsInfinite2() const\n    {\n        return (_mm_movemask_ps(EqualsInfinity()) & 3) != 0;\n    }\n\n    FORCE_INLINE bool Vector::IsInfinite3() const\n    {\n        return (_mm_movemask_ps(EqualsInfinity()) & 7) != 0;\n    }\n\n    FORCE_INLINE bool Vector::IsInfinite4() const\n    {\n        return (_mm_movemask_ps(EqualsInfinity()) != 0);\n    }\n\n    FORCE_INLINE Vector Vector::EqualsNaN() const\n    {\n        return _mm_cmpneq_ps(m_data, m_data);\n    }\n\n    FORCE_INLINE bool Vector::IsNaN2() const\n    {\n        return (_mm_movemask_ps(EqualsNaN()) & 3) != 0;\n    }\n\n    FORCE_INLINE bool Vector::IsNaN3() const\n    {\n        return (_mm_movemask_ps(EqualsNaN()) & 7) != 0;\n    }\n\n    FORCE_INLINE bool Vector::IsNaN4() const\n    {\n        return (_mm_movemask_ps(EqualsNaN()) != 0);\n    }\n\n    FORCE_INLINE bool Vector::IsParallelTo(const Vector& v) const\n    {\n        Vector const vAbsDot = Vector::Dot3(*this, v).GetAbs();\n        Vector const vAbsDelta = Vector::One - vAbsDot;\n        return vAbsDelta.IsLessThanEqual4(Vector::Epsilon);\n    }\n\n    FORCE_INLINE void Vector::ToDirectionAndLength2(Vector& direction, float& length) const\n    {\n        Vector const vLength = 
Length2();\n        direction = Vector::Select(*this, Vector::Zero, Select0011);\n        direction /= vLength;\n        length = vLength.ToFloat();\n    }\n\n    FORCE_INLINE void Vector::ToDirectionAndLength3(Vector& direction, float& length) const\n    {\n        Vector const vLength = Length3();\n        direction = Vector::Select(*this, Vector::Zero, Select0001);\n        direction /= vLength;\n        length = vLength.ToFloat();\n    }\n\n    FORCE_INLINE bool Vector::operator==(const Vector& rhs) const\n    {\n        return IsEqual4(rhs);\n    }\n\n    FORCE_INLINE bool Vector::operator!=(const Vector& rhs) const\n    {\n        return !IsEqual4(rhs);\n    }\n}\n"
  },
  {
    "path": "MotionCorrection/src/cpp/Platform.h",
    "content": "/*\n * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: Apache-2.0\n */\n\n#pragma once\n\n// Finds the current platform\n#if defined( __WIN32__ ) || defined( _WIN32 )\n#    define PLATFORM_WIN32\n#else\n#    define PLATFORM_LINUX\n#endif\n\n//\n// Platform Specific Helpers/Functions\n//\n\n// DLL export\n#if defined(PLATFORM_WIN32) // Windows\n#    if defined(COMPILER_MSVC)\n#        if defined(STATIC_LIB)\n#            define API\n#        else\n#            if defined(API)\n#                define API __declspec(dllexport)\n#            else\n#                define API __declspec(dllimport)\n#            endif\n#        endif\n#    else\n#        if defined(STATIC_LIB)\n#            define API\n#        else\n#            if defined(API)\n#                define API __attribute__ ((dllexport))\n#            else\n#                define API __attribute__ ((dllimport))\n#            endif\n#        endif\n#    endif\n#    define DISABLE_OPTIMIZATION __pragma( optimize( \"\", off ) )\n#    define ENABLE_OPTIMIZATION __pragma( optimize( \"\", on ) )\n#    define DEBUG_BREAK() // __debugbreak()\n#else // Linux settings\n#    include <signal.h>\n#    define API __attribute__ ((visibility (\"default\")))\n#    define DISABLE_OPTIMIZATION\n#    define ENABLE_OPTIMIZATION\n#    define DEBUG_BREAK() // raise(SIGTRAP)\n#endif\n"
  },
  {
    "path": "README.md",
    "content": "<p align=\"center\">\n  <img src=\"./assets/banner.png\" alt=\"Banner\" width=\"100%\">\n  <a href=\"LICENSE\"><img src=\"https://img.shields.io/badge/License-Apache%202.0-76B900.svg\" alt=\"License\"></a>\n  <a href=\"https://research.nvidia.com/labs/sil/projects/kimodo/\"><img src=\"https://img.shields.io/badge/Project-Page-blue\" alt=\"Project Page\"></a>\n  <a href=\"https://research.nvidia.com/labs/sil/projects/kimodo/docs/index.html\"><img src=\"https://img.shields.io/badge/docs-online-green.svg\" alt=\"Documentation\"></a>\n</p>\n\n## Overview\n\nKimodo is a **ki**nematic **mo**tion **d**iffusi**o**n model trained on a large-scale (700 hours) commercially-friendly optical motion capture dataset. The model generates high-quality 3D human and robot motions, and is controlled through text prompts and an extensive set of constraints such as full-body pose keyframes, end-effector positions/rotations, 2D paths, and 2D waypoints. Full details of the model architecture and training are available in the [technical report](https://research.nvidia.com/labs/sil/projects/kimodo/assets/kimodo_tech_report.pdf).\n\nThis repository provides:\n- **Inference**: code and CLI to generate motions on both human and robot skeletons\n- **Interactive Demo**: easily author motions with a timeline interface of text prompts and kinematic controls\n- **Benchmark**: [test cases](https://huggingface.co/datasets/nvidia/Kimodo-Motion-Gen-Benchmark) and evaluation code built on the [BONES-SEED](https://huggingface.co/datasets/bones-studio/seed) dataset to evaluate motion generation models based on text and constraint-following abilities\n- **Annotations**: fine-grained temporal text descriptions created for the Kimodo project are included in the [BONES-SEED](https://huggingface.co/datasets/bones-studio/seed) dataset. 
For more information on these labels, see our separate [Hugging Face repo](https://huggingface.co/datasets/nvidia/SEED-Timeline-Annotations).\n\n<div align=\"center\">\n  <img src=\"assets/teaser.gif\" width=\"1280\">\n</div>\n\n## News\n\nSee the [full changelog](CHANGELOG.md) for a detailed list of all changes.\n\n- **[2026-05-03]** _FIX_: fixed a bug causing incorrect calculation of averaged metrics for constraint test cases in the benchmark\n- **[2026-04-24]** _NEW_: improved multi-prompt generation and better support for small VRAM GPUs via `TEXT_ENCODER_DEVICE=cpu` env var\n- **[2026-04-10]** Released the [Kimodo Motion Generation Benchmark](#kimodo-motion-generation-benchmark) alongside new v1.1 Kimodo-SOMA models\n- **[2026-03-19]** **Breaking:** Model inputs/outputs now use the SOMA 77-joint skeleton (`somaskel77`).\n- **[2026-03-16]** Initial open-source release of Kimodo with five model variants (SOMA, G1, SMPL-X), CLI, interactive demo, and timeline annotations for BONES-SEED.\n\n\n## Kimodo Models\n\nSeveral variations of Kimodo are available trained on various skeletons and datasets. 
All models support text-to-motion and kinematic controls.\n\n> Note: models will be downloaded automatically when attempting to generate from the CLI or Interactive Demo, so there is no need to download them manually\n\n| Model | Skeleton | Training Data | Release Date | Hugging Face | License |\n|:-------|:-------------|:------:|:------:|:-------------:|:-------------:|\n| **Kimodo-SOMA-RP-v1.1** | [SOMA](https://github.com/NVlabs/SOMA-X) | [Bones Rigplay 1](https://bones.studio/datasets#rp01) | April 10, 2026 | [Link](https://huggingface.co/nvidia/Kimodo-SOMA-RP-v1.1) | [NVIDIA Open Model](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/) |\n| **Kimodo-SOMA-SEED-v1.1** | [SOMA](https://github.com/NVlabs/SOMA-X) | [BONES-SEED](https://huggingface.co/datasets/bones-studio/seed) | April 10, 2026  | [Link](https://huggingface.co/nvidia/Kimodo-SOMA-SEED-v1.1) | [NVIDIA Open Model](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/) |\n| **Kimodo-SOMA-RP-v1** | [SOMA](https://github.com/NVlabs/SOMA-X) | [Bones Rigplay 1](https://bones.studio/datasets#rp01) | March 16, 2026 | [Link](https://huggingface.co/nvidia/Kimodo-SOMA-RP-v1) | [NVIDIA Open Model](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/) |\n| **Kimodo-G1-RP-v1** | [Unitree G1](https://github.com/unitreerobotics/unitree_mujoco/tree/main/unitree_robots/g1) | [Bones Rigplay 1](https://bones.studio/datasets#rp01) | March 16, 2026  | [Link](https://huggingface.co/nvidia/Kimodo-G1-RP-v1) | [NVIDIA Open Model](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/) |\n| **Kimodo-SOMA-SEED-v1** | [SOMA](https://github.com/NVlabs/SOMA-X) | [BONES-SEED](https://huggingface.co/datasets/bones-studio/seed) | March 16, 2026  | [Link](https://huggingface.co/nvidia/Kimodo-SOMA-SEED-v1) | [NVIDIA Open 
Model](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/) |\n| **Kimodo-G1-SEED-v1** | [Unitree G1](https://github.com/unitreerobotics/unitree_mujoco/tree/main/unitree_robots/g1) | [BONES-SEED](https://huggingface.co/datasets/bones-studio/seed) | March 16, 2026  | [Link](https://huggingface.co/nvidia/Kimodo-G1-SEED-v1) | [NVIDIA Open Model](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/) |\n| **Kimodo-SMPLX-RP-v1** | [SMPL-X](https://github.com/vchoutas/smplx) | [Bones Rigplay 1](https://bones.studio/datasets#rp01) | March 16, 2026  | [Link](https://huggingface.co/nvidia/Kimodo-SMPLX-RP-v1) | [NVIDIA R&D Model](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-internal-scientific-research-and-development-model-license/) |\n\nBy default, we recommend using the models trained on the full Bones Rigplay 1 dataset (700 hours of mocap) for your motion generation needs.\nThe models trained on BONES-SEED use 288 hours of [publicly available mocap data](https://huggingface.co/datasets/bones-studio/seed) so are less capable, but are useful for comparing to other models trained on BONES-SEED. To easily compare motion generation models to Kimodo, check out our [Motion Generation Benchmark](#kimodo-motion-generation-benchmark).\n\n### Changes in v1.1\nThe latest v1.1 Kimodo-SOMA models were released primarily for compatibility with our new [Motion Generation Benchmark](#kimodo-motion-generation-benchmark), but also contain minor quality improvements over v1. 
For details on these improvements, please see the Hugging Face pages for [Kimodo-SOMA-RP-v1.1](https://huggingface.co/nvidia/Kimodo-SOMA-RP-v1.1#changes-in-v11) and [Kimodo-SOMA-SEED-v1.1](https://huggingface.co/nvidia/Kimodo-SOMA-SEED-v1.1#changes-in-v11).\n\n## Getting Started\n\nPlease see the full documentation for detailed installation instructions, how to use the CLI and Interactive Demo, and other practical tips for generating motions with Kimodo:\n\n**[Full Documentation](https://research.nvidia.com/labs/sil/projects/kimodo/docs)**\n- [Quick Start Guide](https://research.nvidia.com/labs/sil/projects/kimodo/docs/getting_started/quick_start.html)\n- [Installation Instructions](https://research.nvidia.com/labs/sil/projects/kimodo/docs/getting_started/installation.html)\n- [Interactive Motion Authoring Demo](https://research.nvidia.com/labs/sil/projects/kimodo/docs/interactive_demo/index.html)\n- [Command-Line Interface](https://research.nvidia.com/labs/sil/projects/kimodo/docs/user_guide/cli.html)\n- [Benchmark Instructions](https://research.nvidia.com/labs/sil/projects/kimodo/docs/benchmark/introduction.html)\n- [API Reference](https://research.nvidia.com/labs/sil/projects/kimodo/docs/api_reference/index.html)\n\n**Before getting started** with motion generation, please review the [best practices](https://research.nvidia.com/labs/sil/projects/kimodo/docs/key_concepts/limitations.html) and be aware of [model limitations](https://research.nvidia.com/labs/sil/projects/kimodo/docs/key_concepts/limitations.html#limitations).\n\n\nSome notes on installation environment:\n- Kimodo requires ~17GB of VRAM to generate locally entirely on GPU, primarily due to the text embedding model. If you have a smaller card, set `TEXT_ENCODER_DEVICE=cpu` when running Kimodo commands to force text encoding to the CPU. 
This is slightly slower but reduces VRAM usage to <3 GB.\n- The model has been most extensively tested on GeForce RTX 3090, GeForce RTX 4090, and NVIDIA A100 GPUs, but should work on other recent cards with sufficient VRAM\n- This repo was developed on Linux, though Windows should work especially if using Docker\n\n## Interactive Motion Authoring Demo\n\n<div align=\"center\">\n  <img src=\"assets/demo_screenshot.png\" width=\"1000\">\n</div>\n\n</br>\n\n**[Demo Documentation and Tutorial](https://research.nvidia.com/labs/sil/projects/kimodo/docs/interactive_demo/index.html)**\n\nThe web-based interactive demo provides an intuitive interface for generating motions with any of the Kimodo model variations. After installation, the demo can be launched with the `kimodo_demo` command. It runs locally on http://127.0.0.1:7860. Open this URL in your browser to access the interface (or use port forwarding if set up on a server).\n\n### Demo Features\n- **Multiple Characters**: Supports generating with the SOMA, G1, and SMPL-X versions of Kimodo\n- **Text Prompts**: Enter one or more natural language descriptions of desired motions on the timeline\n- **Timeline Editor**: Add and edit keyframes and constrained intervals on multiple constraint tracks\n- **Constraint Types**:\n  - Full-Body: Complete joint position constraints at specific frames\n  - 2D Root: Define waypoints or full paths to follow on the ground plane\n  - End-Effectors: Control hands and feet positions/rotations\n- **Constraint Editing**: Editing mode allows for re-posing of constraints or adjusting waypoints\n- **3D Visualization**: Real-time rendering of generated motions with skeleton and skinned mesh options\n- **Playback Controls**: Preview generated motions with adjustable playback speed\n- **Multiple Samples**: Generate and compare multiple motion variations\n- **Examples**: Load pre-existing examples to better understand Kimodo's capabilities\n- **Export**: Save constraints and generated motions for 
later use\n\n## Command-Line Interface\n\n**[CLI Documentation and Examples](https://research.nvidia.com/labs/sil/projects/kimodo/docs/user_guide/cli.html)**\n\nMotions can also be generated directly from the command line with the `kimodo_gen` command or by running `python -m kimodo.scripts.generate` directly.\n\n**Key Arguments:**\n- `prompt`: A single text description or sequence of texts for the desired motion (required)\n- `--model`: Which Kimodo model to use for generation\n- `--duration`: Motion duration in seconds\n- `--num_samples`: Number of motion variations to generate\n- `--constraints`: Constraint file to control the generated motion (e.g., saved from the web demo)\n- `--diffusion_steps`: Number of denoising steps\n- `--cfg_type` / `--cfg_weight`: Classifier-free guidance (`nocfg`, `regular` with one weight, or `separated` with two weights for text vs. constraints); see the [CLI docs](https://research.nvidia.com/labs/sil/projects/kimodo/docs/user_guide/cli.html#classifier-free-guidance-cfg)\n- `--no-postprocess`: Flag to disable foot skate and constraint cleanup post-processing\n- `--seed`: Random seed for reproducible results\n\nThe script supports different output formats depending on which skeleton is used. 
By default, a custom NPZ format is saved that is compatible with the web demo.\nFor Kimodo-G1 models, the motion can be saved in the standard MuJoCo qpos CSV format.\nFor Kimodo-SMPLX, motion can be saved in the standard AMASS npz format for compatibility with existing pipelines.\n\n### Default NPZ Output Format\nGenerated motions are saved as NPZ files containing:\n- `posed_joints`: Global joint positions `[T, J, 3]`\n- `global_rot_mats`: Global joint rotation matrices `[T, J, 3, 3]`\n- `local_rot_mats`: Local (parent-relative) joint rotation matrices `[T, J, 3, 3]`\n- `foot_contacts`: Foot contact labels [left heel, left toe, right heel, right toe] `[T, 4]`\n- `smooth_root_pos`: Smoothed root representations outputted from the model `[T, 3]`\n- `root_positions`: The (non-smoothed) trajectory of the actual root joint (e.g., pelvis) `[T, 3]`\n- `global_root_heading`: The heading direction output from the model `[T, 2]`\n\n`T` is the number of frames and `J` is the number of joints.\n\n## Low-Level Python API\n\n**[Model API Documentation](https://research.nvidia.com/labs/sil/projects/kimodo/docs/api_reference/model.html#kimodo.model.kimodo_model.Kimodo.__call__)**\n\nFor maximum flexibility, the low-level model inference API can be called directly, rather than going through our high-level CLI.\nThis allows for advanced model configuration including classifier-free guidance weights and parameters related to transitions in multi-prompt sequences.\n\n## Downstream Robotics Applications of Kimodo\n\n### Visualizing G1 Motions with MuJoCo\n\n<div align=\"center\">\n  <img src=\"assets/mujoco_result.gif\" width=\"800\">\n</div>\n\nAfter generating motions on the G1 robot skeleton and saving to the MuJoCo qpos CSV file format, they can be easily used and visualized within MuJoCo.\nA minimal visualization script is available with:\n```\npython -m kimodo.scripts.mujoco_load\n```\nMake sure to edit the script to correctly point to your CSV file and install MuJoCo before running 
this.\n\n### Tracking Generated Motions with ProtoMotions\n\n<div align=\"center\">\n  <img src=\"assets/protomotions_results.gif\" width=\"1280\">\n</div>\n\n[ProtoMotions](https://github.com/NVlabs/ProtoMotions) is a GPU-accelerated simulation and learning framework for training physically simulated digital humans and humanoid robots. The Kimodo NPZ and CSV output formats are both compatible with ProtoMotions, making it easy to train physics-based policies with generated motions from Kimodo. ProtoMotions supports outputs on both the SOMA skeleton and Unitree G1.\n\nAfter generating motions with Kimodo, head over to the [ProtoMotions docs](https://github.com/NVlabs/ProtoMotions?tab=readme-ov-file#-motion-authoring-with-kimodo) to see how to import them.\n\n### Retargeting Motions to Other Robots with GMR\n\n<div align=\"center\">\n  <img src=\"assets/gmr_results.gif\" width=\"1280\">\n</div>\n\nMotions generated by Kimodo-SMPLX can be retargeted to other robots using [General Motion Retargeting (GMR)](https://github.com/YanjieZe/GMR).\nGMR supports the AMASS NPZ format out of the box, so simply generate motions with Kimodo and use `--output` to save; the AMASS NPZ is written to `stem_amass.npz` (single sample) or in the output folder (multiple samples). Then, use the [SMPL-X to Robot script](https://github.com/YanjieZe/GMR?tab=readme-ov-file#retargeting-from-smpl-x-amass-omomo-to-robot) in GMR to retarget to any supported robot. For example:\n```\n# run within GMR codebase\npython scripts/smplx_to_robot.py --smplx_file /path/to/saved/amass_format.npz --robot booster_t1\n```\n\n### Combining Kimodo with GEAR-SONIC\n\n<div align=\"center\">\n  <img src=\"assets/sonic_kimodo_demo.gif\" width=\"800\">\n</div>\n\nAs a proof of concept, we have also incorporated Kimodo into the [interactive GEAR-SONIC demo](https://nvlabs.github.io/GEAR-SONIC/demo.html). 
In the demo, Kimodo can be used to generate a kinematic motion on the G1 robot skeleton, then GEAR-SONIC tracks the motion in simulation.\n\n## Kimodo Motion Generation Benchmark\n\n[**[Benchmark Documentation](https://research.nvidia.com/labs/sil/projects/kimodo/docs/benchmark/introduction.html)**]\n[**[Test Suite on Hugging Face](https://huggingface.co/datasets/nvidia/Kimodo-Motion-Gen-Benchmark)**]\n\nAlongside the Kimodo models, we provide a benchmark designed to standardize evaluation for motion generation models with a comprehensive set of test cases. This includes:\n\n* **Evaluation Data**: A suite of test cases [available on Hugging Face](https://huggingface.co/datasets/nvidia/Kimodo-Motion-Gen-Benchmark) is used in concert with the [BONES-SEED](https://huggingface.co/datasets/bones-studio/seed) dataset to construct the full benchmark. \n* **Diverse Test Cases**: Test cases cover a wide range of text-conditioned and constraint-conditioned motion generation.\n* **Evaluation Pipeline**: Code for the full evaluation pipeline including benchmark construction, motion generation, and evaluation.\n* **Metrics**: Several metrics to evaluate generated motions that cover motion quality, constraint following, and text alignment. Our [TMR-SOMA-RP-v1](https://huggingface.co/nvidia/TMR-SOMA-RP-v1) model trained on all 700 hours of the Bones Rigplay dataset is a powerful embedding model to compute common metrics like R-precision and FID.\n\nTo facilitate future research, we [report benchmark results](https://research.nvidia.com/labs/sil/projects/kimodo/docs/benchmark/results.html) for Kimodo-SOMA-v1.1 models, which are reproducible and easily comparable to other methods trained on the BONES-SEED data. 
\n\n## Timeline Annotations for BONES-SEED\n\nAs detailed in the [tech report](https://research.nvidia.com/labs/sil/projects/kimodo/assets/kimodo_tech_report.pdf), Kimodo is trained using fine-grained temporal text annotations of mocap clips.\nWhile the full [Rigplay 1](https://bones.studio/datasets#rp01) dataset is proprietary, we have released the temporal segmentations for the public [BONES-SEED](https://huggingface.co/datasets/bones-studio/seed) subset.\nThese annotations are already included in the BONES-SEED dataset, but the standalone labels and additional information about them are [available on HuggingFace](https://huggingface.co/datasets/nvidia/SEED-Timeline-Annotations).\n\n\n## Related Humanoid Work at NVIDIA\nKimodo is part of a larger effort to enable humanoid motion data for robotics, physical AI, and other applications.\n\nCheck out these related works:\n* [SOMA Body Model](https://github.com/NVlabs/SOMA-X) - a unified parametric human body model\n* [BONES-SEED Dataset](https://huggingface.co/datasets/bones-studio/seed) - a large-scale human(oid) motion capture dataset in SOMA and G1 format\n* [ProtoMotions](https://github.com/NVlabs/ProtoMotions) - simulation and learning framework for training physically simulated human(oid)s\n* [SOMA Retargeter](https://github.com/NVIDIA/soma-retargeter) - SOMA to G1 retargeting tool\n* [GEM](https://github.com/NVlabs/GEM-X) - human motion reconstruction from video\n* [GEAR SONIC](https://github.com/NVlabs/GR00T-WholeBodyControl) - humanoid behavior foundation model for physical robots\n\n## Citation\n\nIf you use this code in your research, please cite:\n\n```bibtex\n@article{Kimodo2026,\n  title={Kimodo: Scaling Controllable Human Motion Generation},\n  author={Rempe, Davis and Petrovich, Mathis and Yuan, Ye and Zhang, Haotian and Peng, Xue Bin and Jiang, Yifeng and Wang, Tingwu and Iqbal, Umar and Minor, David and de Ruyter, Michael and Li, Jiefeng and Tessler, Chen and Lim, Edy and Jeong, Eugene and Wu, Sam 
and Hassani, Ehsan and Huang, Michael and Yu, Jin-Bey and Chung, Chaeyeon and Song, Lina and Dionne, Olivier and Kautz, Jan and Yuen, Simon and Fidler, Sanja},\n  journal={arXiv:2603.15546},\n  year={2026}\n}\n```\n\n## License\n\nThis codebase is licensed under [Apache-2.0](LICENSE). Note that model checkpoints and data are licensed separately as indicated on the HuggingFace download pages.\n\nThis project will download and install additional third-party open source software projects. Review the license terms of these open source projects before use.\n\n## Acknowledgments\n\nThis project builds upon excellent open-source projects:\n- [Viser](https://github.com/nerfstudio-project/viser) for 3D motion authoring demo\n- [LLM2Vec](https://github.com/McGill-NLP/llm2vec) for text encoding\n\n## Contact\n\nFor questions or issues, please open an issue on this repository or reach out directly to the authors.\n\n---\n"
  },
  {
    "path": "benchmark/create_benchmark.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\nStep (1) of evaluation pipeline.\n\nThis script builds the benchmark test suites from BVH motions in the Bones-SEED dataset using \nthe benchmark metadata. Currently it is only set up for the SOMA skeleton.\n\"\"\"\n\nimport argparse\nfrom functools import partial\nfrom multiprocessing import Pool\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nfrom tqdm import tqdm\n\nfrom kimodo.geometry import matrix_to_axis_angle\nfrom kimodo.motion_rep import KimodoMotionRep\nfrom kimodo.skeleton import SOMASkeleton77\nfrom kimodo.skeleton.bvh import parse_bvh_motion\nfrom kimodo.tools import load_json, save_json, to_numpy, to_torch\n\nFPS = 30\nBENCHMARK_REPO_ID = \"nvidia/Kimodo-Motion-Gen-Benchmark\"\n\n\ndef download_benchmark(dest: Path) -> Path:\n    \"\"\"Download the benchmark testsuite from HuggingFace to *dest*.\"\"\"\n    from huggingface_hub import snapshot_download\n\n    print(f\"Downloading benchmark testsuite from {BENCHMARK_REPO_ID} to {dest} ...\")\n    snapshot_dir = snapshot_download(\n        repo_id=BENCHMARK_REPO_ID,\n        repo_type=\"dataset\",\n        local_dir=str(dest),\n    )\n    return Path(snapshot_dir)\n\n\ndef discover_seed_motion_folders(root: Path) -> list[Path]:\n    \"\"\"Find all directories under root that contain seed_motion.json; return sorted list of those\n    dirs.\"\"\"\n    root = root.resolve()\n    if not root.is_dir():\n        raise FileNotFoundError(f\"Folder does not exist: {root}\")\n    out: list[Path] = []\n    for meta_path in root.rglob(\"seed_motion.json\"):\n        src_dir = meta_path.parent\n        out.append(src_dir)\n    return sorted(out)\n\n\ndef constraints_and_motion_from_seed(folder: str, dataset_folder: str, fps=FPS):\n    \"\"\"Load seed_motion.json and BVH from folder; subsample to fps, convert to SOMA gt_motion.npz\n    and 
constraints.\"\"\"\n    folder = Path(folder)\n    dataset_folder = Path(dataset_folder)\n    out_path = folder / \"gt_motion.npz\"\n\n    seed_motion = load_json(folder / \"seed_motion.json\")\n\n    start = seed_motion[\"crop_start_frame_index\"]\n    end = seed_motion[\"crop_end_frame_index\"]\n\n    bvh_path = dataset_folder / seed_motion[\"bvh_path\"].replace(\"BVH/\", \"bvh/\")\n\n    local_rot_mats, root_trans, bvh_fps = parse_bvh_motion(bvh_path)\n    step = round(bvh_fps / fps)\n\n    # Subsample fps\n    root_trans = root_trans[::step]\n    local_rot_mats = local_rot_mats[::step]\n\n    skeleton = SOMASkeleton77()\n    # Changing t_pose: essential step\n    local_rot_mats, global_rot_mats = skeleton.to_standard_tpose(local_rot_mats)\n\n    # Use the motion rep to canonicalize the motion (start z+ at 0,0)\n    # and get other components (smooth root, foot contacts etc)\n    motion_rep = KimodoMotionRep(skeleton, fps)\n    feats = motion_rep(local_rot_mats, root_trans, to_normalize=False)\n\n    # Crop the features and canonicalizing them\n    feats = feats[start:end]\n    can_feats = motion_rep.canonicalize(feats)\n    # Get back the motion\n    motion = motion_rep.inverse(can_feats, is_normalized=False)\n    motion = to_numpy(to_torch(motion, dtype=torch.float32))\n\n    np.savez(out_path, **motion)\n\n    seed_constraints_path = folder / \"seed_constraints.json\"\n    if seed_constraints_path.exists():\n        seed_constraints_lst = load_json(seed_constraints_path)\n\n        constraints_lst = []\n        for seed_cons in seed_constraints_lst:\n            cons = seed_cons.copy()\n            frame_indices = cons[\"frame_indices\"]\n\n            cons[\"smooth_root_2d\"] = motion[\"smooth_root_pos\"][frame_indices][..., [0, 2]].tolist()\n\n            if cons[\"type\"] == \"root2d\":\n                if cons.get(\"use_global_orient\", False):\n                    cons[\"global_root_heading\"] = motion[\"global_root_heading\"][  # noqa\n                  
      frame_indices\n                    ].tolist()\n            elif cons[\"type\"] in [\"fullbody\"] or cons[\"type\"] in [\n                \"left-hand\",\n                \"right-hand\",\n                \"left-foot\",\n                \"right-foot\",\n                \"end-effector\",\n            ]:\n                cons[\"local_joints_rot\"] = matrix_to_axis_angle(\n                    to_torch(motion[\"local_rot_mats\"][frame_indices])\n                ).tolist()\n                cons[\"root_positions\"] = motion[\"root_positions\"][frame_indices].tolist()\n            else:\n                raise TypeError(f\"This constraint type is not recognized: {cons['type']}\")\n\n            constraints_lst.append(cons)\n\n        # check that it is close to old_constraints_lst\n        save_json(folder / \"constraints.json\", constraints_lst)\n\n\ndef main():\n    parser = argparse.ArgumentParser(\n        description=\"Recursively find test case to fill with motions and constraints.\",\n    )\n    parser.add_argument(\n        \"benchmark\",\n        type=Path,\n        help=\"Root folder to search recursively or seed_motion.json for to download the benchmark testsuite from HuggingFace to.\",\n    )\n    parser.add_argument(\n        \"--dataset\",\n        type=Path,\n        default=\"datasets/bones-seed/soma_uniform\",\n        help=\"SEED dataset folder\",\n    )\n    parser.add_argument(\n        \"--overwrite\",\n        action=\"store_true\",\n        help=\"Redo the process even if gt_motion.npz already exists\",\n    )\n    parser.add_argument(\n        \"--workers\",\n        type=int,\n        default=1,\n        help=\"Number of parallel worker processes (default: 1, sequential)\",\n    )\n    args = parser.parse_args()\n\n    folder = args.benchmark.resolve()\n    if not folder.is_dir():\n        print(f\"Benchmark folder not found at {folder}, downloading from HuggingFace...\")\n        download_benchmark(folder)\n\n    dirs = 
discover_seed_motion_folders(folder)\n    if not dirs:\n        raise SystemExit(f\"No directories with seed_motion.json found under {folder}\")\n    print(f\"Discovered {len(dirs)} motion to populate.\")\n\n    skipped = 0\n    to_process = []\n    for d in dirs:\n        if not args.overwrite and (d / \"gt_motion.npz\").is_file():\n            skipped += 1\n        else:\n            to_process.append(d)\n\n    fn = partial(constraints_and_motion_from_seed, dataset_folder=args.dataset)\n    with Pool(args.workers) as pool:\n        list(tqdm(pool.imap_unordered(fn, to_process), total=len(to_process), desc=\"Extracting GT motions\"))\n\n    if skipped:\n        print(f\"Processed {len(dirs) - skipped} folders, skipped {skipped} (already present).\")\n    else:\n        print(\"Saved gt_motion.npz and constraints.json from the seed files.\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "benchmark/embed_folder.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\nStep (3) of evaluation pipeline.\n\nThis script recursively embeds generated motions, ground-truth motions, and text prompts from a test suite folder tree with the pre-trained TMR model.\n\"\"\"\n\nimport argparse\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nfrom tqdm import tqdm\n\nfrom kimodo.meta import parse_prompts_from_meta\nfrom kimodo.model.load_model import load_model\nfrom kimodo.tools import load_json\n\n\ndef discover_motion_folders(root: Path) -> list[Path]:\n    root = root.resolve()\n    if not root.is_dir():\n        raise FileNotFoundError(f\"Folder does not exist: {root}\")\n    out: list[Path] = []\n    for meta_path in root.rglob(\"meta.json\"):\n        src_dir = meta_path.parent\n        if (src_dir / \"motion.npz\").is_file() or (src_dir / \"gt_motion.npz\").is_file():\n            out.append(src_dir)\n    return sorted(out)\n\n\ndef _load_posed_joints(npz_path: Path, device: str) -> torch.Tensor:\n    data = np.load(npz_path)\n    if \"posed_joints\" not in data:\n        raise SystemExit(f\"NPZ must contain 'posed_joints': {npz_path}\")\n    posed_joints = data[\"posed_joints\"]\n    if posed_joints.ndim == 4:\n        if posed_joints.shape[0] != 1:\n            raise SystemExit(f\"Expected batch size 1 for posed_joints, got {posed_joints.shape[0]} in {npz_path}\")\n        posed_joints = posed_joints[0]\n    if posed_joints.ndim != 3:\n        raise SystemExit(f\"Expected posed_joints shape [T, J, 3], got {posed_joints.shape} in {npz_path}\")\n    return torch.from_numpy(posed_joints).float().to(device)\n\n\ndef main():\n    parser = argparse.ArgumentParser(\n        description=\"Recursively embed motion, gt_motion, and text; save motion_embedding.npy, gt_motion_embedding.npy, and text_embedding.npy when present.\",\n    )\n    parser.add_argument(\n        
\"folder\",\n        type=Path,\n        help=\"Root folder to search recursively for meta.json and motion.npz and/or gt_motion.npz\",\n    )\n    parser.add_argument(\n        \"--model\",\n        default=\"tmr-soma-rp\",\n        help=\"Model for encoding (e.g. TMR-SOMA-RP-v1, tmr-soma-rp). Default: tmr-soma-rp\",\n    )\n    parser.add_argument(\n        \"--device\",\n        default=None,\n        help=\"Device (default: cuda if available else cpu)\",\n    )\n    parser.add_argument(\n        \"--overwrite\",\n        action=\"store_true\",\n        help=\"Re-embed even if embedding files already exist\",\n    )\n    parser.add_argument(\n        \"--text_encoder_fp32\",\n        action=\"store_true\",\n        help=\"Uses fp32 for the text encoder rather than default bfloat16.\",\n    )\n    args = parser.parse_args()\n\n    folder = args.folder.resolve()\n    if not folder.is_dir():\n        raise SystemExit(f\"Folder does not exist or is not a directory: {folder}\")\n\n    device = args.device or (\"cuda\" if torch.cuda.is_available() else \"cpu\")\n    model = load_model(modelname=args.model, device=device, default_family=\"TMR\", text_encoder_fp32=args.text_encoder_fp32)\n\n    dirs = discover_motion_folders(folder)\n    if not dirs:\n        raise SystemExit(f\"No directories with meta.json and (motion.npz or gt_motion.npz) found under {folder}\")\n    print(f\"Discovered {len(dirs)} motion folders.\")\n\n    skipped_motion = 0\n    skipped_gt = 0\n    skipped_text = 0\n    for sample_dir in tqdm(dirs, desc=\"Embedding\"):\n        meta_path = sample_dir / \"meta.json\"\n        meta = load_json(meta_path)\n        texts, _ = parse_prompts_from_meta(meta)\n        if len(texts) != 1:\n            raise SystemExit(f\"Expected exactly one text per motion; got {len(texts)} in {meta_path}\")\n        text = texts[0]\n\n        # Embed motion.npz -> motion_embedding.npy\n        if (sample_dir / \"motion.npz\").is_file():\n            if not args.overwrite 
and (sample_dir / \"motion_embedding.npy\").is_file():\n                skipped_motion += 1\n            else:\n                npz_path = sample_dir / \"motion.npz\"\n                posed_joints = _load_posed_joints(npz_path, device)\n                with torch.inference_mode():\n                    motion_emb = model.encode_motion(posed_joints, unit_vector=True)\n                np.save(sample_dir / \"motion_embedding.npy\", motion_emb.cpu().numpy())\n\n        # Embed gt_motion.npz -> gt_motion_embedding.npy\n        if (sample_dir / \"gt_motion.npz\").is_file():\n            if not args.overwrite and (sample_dir / \"gt_motion_embedding.npy\").is_file():\n                skipped_gt += 1\n            else:\n                npz_path = sample_dir / \"gt_motion.npz\"\n                posed_joints = _load_posed_joints(npz_path, device)\n                with torch.inference_mode():\n                    gt_motion_emb = model.encode_motion(posed_joints, unit_vector=True)\n                np.save(sample_dir / \"gt_motion_embedding.npy\", gt_motion_emb.cpu().numpy())\n\n        # Embed text -> text_embedding.npy\n        if not args.overwrite and (sample_dir / \"text_embedding.npy\").is_file():\n            skipped_text += 1\n        else:\n            with torch.inference_mode():\n                text_emb = model.encode_raw_text([text], unit_vector=True)\n            np.save(sample_dir / \"text_embedding.npy\", text_emb.cpu().numpy())\n\n    total_skipped = skipped_motion + skipped_gt + skipped_text\n    if total_skipped:\n        print(f\"Embedded {len(dirs)} folders; skipped some existing files (use --overwrite to re-embed).\")\n    else:\n        print(f\"Saved motion_embedding.npy, gt_motion_embedding.npy, and text_embedding.npy in {len(dirs)} folders.\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "benchmark/evaluate_folder.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\nStep (4) of evaluation pipeline.\n\nThis script recursively computes metrics for generated and ground-truth motions within a test suite folder tree. \nSaves metrics json files per test case and per group of test cases in the folder tree.\n\"\"\"\n\nimport argparse\nimport json\nfrom itertools import groupby\nfrom pathlib import Path\nfrom typing import Any\n\nimport numpy as np\nimport torch\nfrom tqdm import tqdm\n\nfrom kimodo.constraints import load_constraints_lst\nfrom kimodo.meta import parse_prompts_from_meta\nfrom kimodo.metrics import (\n    ContraintFollow,\n    FootContactConsistency,\n    FootSkateFromContacts,\n    FootSkateFromHeight,\n    FootSkateRatio,\n    TMR_EmbeddingMetric,\n    aggregate_metrics,\n    clear_metrics,\n    compute_metrics,\n    compute_tmr_per_sample_retrieval,\n)\nfrom kimodo.skeleton import build_skeleton\nfrom kimodo.skeleton.definitions import SOMASkeleton30\nfrom kimodo.tools import load_json, to_torch\n\nDEFAULT_FPS = 30.0\n\n\ndef discover_motion_folders(root: Path) -> list[tuple[Path, Path]]:\n    root = root.resolve()\n    if not root.is_dir():\n        raise FileNotFoundError(f\"Folder does not exist: {root}\")\n    out: list[tuple[Path, Path]] = []\n    for meta_path in root.rglob(\"meta.json\"):\n        sample_dir = meta_path.parent\n        if (sample_dir / \"motion.npz\").is_file() and (sample_dir / \"gt_motion.npz\").is_file():\n            rel = sample_dir.relative_to(root)\n            out.append((sample_dir, rel))\n    return sorted(out, key=lambda x: str(x[1]))\n\n\ndef group_by_parent(examples: list[tuple[Path, Path]]) -> list[list[tuple[Path, Path]]]:\n    def parent_key(item: tuple[Path, Path]) -> Path:\n        return item[1].parent if len(item[1].parts) > 1 else Path(\".\")\n\n    sorted_examples = sorted(examples, key=parent_key)\n    
groups: list[list[tuple[Path, Path]]] = []\n    for _key, group in groupby(sorted_examples, key=parent_key):\n        groups.append(list(group))\n    return groups\n\n\ndef _to_scalar(t: torch.Tensor) -> float:\n    return float(t.mean().item()) if t.numel() > 0 else float(t.item())\n\n\ndef _to_p95(t: torch.Tensor) -> float:\n    if t.numel() == 0:\n        return float(\"nan\")\n    return float(torch.nanquantile(t, torch.tensor(0.95, device=t.device), dim=0).item())\n\n\ndef _per_sample_metrics_from_saved(metrics_list: list, n: int) -> list[dict[str, float]]:\n    per_sample: list[dict[str, float]] = [{} for _ in range(n)]\n    for metric in metrics_list:\n        for key, lst in metric.saved_metrics.items():\n            for i, t in enumerate(lst):\n                if i >= n:\n                    break\n                per_sample[i][key] = _to_scalar(t)\n    return per_sample\n\n\ndef _load_pair_embeddings(\n    sample_dir: Path,\n) -> tuple[np.ndarray, np.ndarray, np.ndarray | None] | None:\n    motion_emb_path = sample_dir / \"motion_embedding.npy\"\n    text_emb_path = sample_dir / \"text_embedding.npy\"\n    gt_motion_emb_path = sample_dir / \"gt_motion_embedding.npy\"\n    if not (motion_emb_path.is_file() and text_emb_path.is_file()):\n        return None\n\n    motion_emb = np.load(motion_emb_path)\n    text_emb = np.load(text_emb_path)\n    if motion_emb.ndim == 3 and motion_emb.shape[0] == 1:\n        motion_emb = motion_emb[0]\n    if text_emb.ndim == 3 and text_emb.shape[0] == 1:\n        text_emb = text_emb[0]\n\n    gt_motion_emb = None\n    if gt_motion_emb_path.is_file():\n        gt_motion_emb = np.load(gt_motion_emb_path)\n        if gt_motion_emb.ndim == 3 and gt_motion_emb.shape[0] == 1:\n            gt_motion_emb = gt_motion_emb[0]\n\n    return motion_emb, text_emb, gt_motion_emb\n\n\ndef _load_npz_motion(\n    npz_path: Path,\n    device: str,\n    soma30_skel: SOMASkeleton30 | None = None,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    
\"\"\"Load posed_joints and foot_contacts from an NPZ, upscaling SOMA30 to SOMA77 if needed.\"\"\"\n    data = np.load(npz_path)\n    posed_joints = to_torch(data[\"posed_joints\"], device=device)\n    foot_contacts = to_torch(data[\"foot_contacts\"], device=device)\n\n    if posed_joints.shape[-2] == 30 and soma30_skel is not None:\n        local_rot_mats = to_torch(data[\"local_rot_mats\"], device=device)\n        root_positions = to_torch(data[\"root_positions\"], device=device)\n        out77 = soma30_skel.output_to_SOMASkeleton77(\n            {\"local_rot_mats\": local_rot_mats, \"root_positions\": root_positions, \"foot_contacts\": foot_contacts}\n        )\n        posed_joints = out77[\"posed_joints\"]\n        foot_contacts = out77[\"foot_contacts\"]\n\n    return posed_joints, foot_contacts\n\n\ndef _run_eval_on_group(\n    group: list[tuple[Path, Path]],\n    skeleton: torch.nn.Module,\n    metrics_list: list,\n    device: str,\n    group_name: str = \"\",\n    soma30_skel: SOMASkeleton30 | None = None,\n) -> tuple[\n    list[dict[str, float]],\n    list[dict[str, float]],\n    dict[str, float],\n    dict[str, float],\n    dict[str, float],\n    list[dict[str, Any]],\n]:\n    \"\"\"Run two passes: gen (motion.npz + embeddings) and GT (gt_motion.npz only). 
Return\n    per_sample_gen, per_sample_gt, aggregated_gen, aggregated_gt, tmr_metrics, tmr_per_sample.\n    \"\"\"\n    n = len(group)\n    sample_ids: list[str] = []\n    texts: list[str] = []\n    motion_embs: list[np.ndarray] = []\n    text_embs: list[np.ndarray] = []\n\n    # ----- Pass 1: generation (motion.npz + all embeddings) -----\n    clear_metrics(metrics_list)\n    desc = f\"Samples ({group_name})\" if group_name else \"Samples\"\n    for sample_dir, rel_path in tqdm(group, desc=desc, unit=\"motion\"):\n        stem = rel_path.name\n        sample_ids.append(stem)\n        meta_path = sample_dir / \"meta.json\"\n        meta = load_json(meta_path)\n        texts_parsed, _ = parse_prompts_from_meta(meta)\n        texts.append(texts_parsed[0] if texts_parsed else \"\")\n\n        posed_joints, foot_contacts = _load_npz_motion(sample_dir / \"motion.npz\", device, soma30_skel)\n        nframes = posed_joints.shape[0]\n        lengths = torch.tensor(nframes, dtype=torch.long, device=device)\n        constraints_path = sample_dir / \"constraints.json\"\n        constraints_lst = (\n            load_constraints_lst(str(constraints_path), skeleton=skeleton) if constraints_path.is_file() else []\n        )\n        metrics_in: dict[str, Any] = {\n            \"posed_joints\": posed_joints,\n            \"foot_contacts\": foot_contacts,\n            \"lengths\": lengths,\n            \"constraints_lst\": constraints_lst,\n        }\n        text_this = texts_parsed[0] if texts_parsed else \"\"\n        embs = _load_pair_embeddings(sample_dir)\n        if (text_this or \"\").strip() and embs is not None:\n            motion_emb, text_emb, gt_motion_emb = embs\n            metrics_in[\"motion_emb\"] = motion_emb\n            metrics_in[\"text_emb\"] = text_emb\n            if gt_motion_emb is not None:\n                metrics_in[\"gt_motion_emb\"] = gt_motion_emb\n            motion_embs.append(motion_emb)\n            text_embs.append(text_emb)\n\n        
compute_metrics(metrics_list, metrics_in)\n\n    per_sample_gen = _per_sample_metrics_from_saved(metrics_list, n)\n    raw_aggregated_gen = aggregate_metrics(metrics_list)\n    aggregated_gen = {}\n    tmr_metrics: dict[str, float] = {}\n    has_text = len(motion_embs) == n and len(text_embs) == n\n    for key, v in raw_aggregated_gen.items():\n        val = _to_scalar(v)\n        if key.startswith(\"TMR/\"):\n            if has_text:\n                tmr_metrics[key] = val\n        else:\n            aggregated_gen[key] = val\n    if \"constraint_root2d_err\" in raw_aggregated_gen:\n        aggregated_gen[\"constraint_root2d_err_p95\"] = _to_p95(raw_aggregated_gen[\"constraint_root2d_err\"])\n\n    tmr_per_sample: list[dict[str, Any]] = []\n    if has_text and motion_embs and text_embs and len(motion_embs) == n and len(text_embs) == n:\n        motion_emb_stack = np.stack(motion_embs, axis=0)\n        text_emb_stack = np.stack(text_embs, axis=0)\n        tmr_per_sample = compute_tmr_per_sample_retrieval(motion_emb_stack, text_emb_stack, sample_ids, texts, top_k=5)\n\n    # ----- Pass 2: GT (gt_motion.npz only, no embeddings) -----\n    clear_metrics(metrics_list)\n    for sample_dir, rel_path in tqdm(group, desc=f\"GT ({group_name})\" if group_name else \"GT\", unit=\"motion\"):\n        posed_joints, foot_contacts = _load_npz_motion(sample_dir / \"gt_motion.npz\", device, soma30_skel)\n        nframes = posed_joints.shape[0]\n        lengths = torch.tensor(nframes, dtype=torch.long, device=device)\n        constraints_path = sample_dir / \"constraints.json\"\n        constraints_lst = (\n            load_constraints_lst(str(constraints_path), skeleton=skeleton) if constraints_path.is_file() else []\n        )\n        metrics_in = {\n            \"posed_joints\": posed_joints,\n            \"foot_contacts\": foot_contacts,\n            \"lengths\": lengths,\n            \"constraints_lst\": constraints_lst,\n        }\n        compute_metrics(metrics_list, 
metrics_in)\n\n    per_sample_gt = _per_sample_metrics_from_saved(metrics_list, n)\n    raw_aggregated_gt = aggregate_metrics(metrics_list)\n    aggregated_gt = {}\n    for key, v in raw_aggregated_gt.items():\n        if key.startswith(\"TMR/\"):\n            continue\n        aggregated_gt[key] = _to_scalar(v)\n    if \"constraint_root2d_err\" in raw_aggregated_gt:\n        aggregated_gt[\"constraint_root2d_err_p95\"] = _to_p95(raw_aggregated_gt[\"constraint_root2d_err\"])\n\n    return (\n        per_sample_gen,\n        per_sample_gt,\n        aggregated_gen,\n        aggregated_gt,\n        tmr_metrics,\n        tmr_per_sample,\n    )\n\n\ndef _write_json(path: Path, payload: dict[str, Any]) -> None:\n    path.parent.mkdir(parents=True, exist_ok=True)\n    path.write_text(json.dumps(payload, indent=2) + \"\\n\", encoding=\"utf-8\")\n\n\ndef main():\n    parser = argparse.ArgumentParser(\n        description=\"Recursively evaluate generated motions; write metrics.json per folder and <name>.json per parent.\",\n    )\n    parser.add_argument(\n        \"folder\",\n        type=Path,\n        help=\"Root folder to search recursively for meta.json + motion.npz + gt_motion.npz\",\n    )\n    parser.add_argument(\"--device\", default=None, help=\"cuda/cpu. 
Default: auto\")\n    args = parser.parse_args()\n\n    folder = args.folder.resolve()\n    if not folder.is_dir():\n        raise SystemExit(f\"Folder does not exist: {folder}\")\n\n    device = args.device or (\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n    examples = discover_motion_folders(folder)\n    if not examples:\n        raise SystemExit(f\"No directories with meta.json, motion.npz, and gt_motion.npz found under {folder}\")\n    print(f\"Discovered {len(examples)} motion folders.\")\n\n    first_posed = np.load(examples[0][0] / \"motion.npz\")[\"posed_joints\"]\n    num_joints = first_posed.shape[-2]\n\n    # SOMA models could generate 30-joint output; upscale to 77 for evaluation\n    soma30_skel: SOMASkeleton30 | None = None\n    if num_joints == 30:\n        soma30_skel = SOMASkeleton30().to(device)\n        _ = soma30_skel.somaskel77  # trigger lazy init\n        soma30_skel.somaskel77.to(device)\n        skeleton = soma30_skel.somaskel77\n        print(\"Detected SOMA30 motions; will upscale to SOMA77 for evaluation.\")\n    else:\n        skeleton = build_skeleton(num_joints).to(device)\n\n    fps = DEFAULT_FPS\n    kwargs = {\"skeleton\": skeleton, \"fps\": fps}\n    metrics_list = [\n        FootSkateFromHeight(**kwargs),\n        FootSkateFromContacts(**kwargs),\n        FootContactConsistency(**kwargs),\n        FootSkateRatio(**kwargs),\n        ContraintFollow(**kwargs),\n        TMR_EmbeddingMetric(**kwargs),\n    ]\n\n    groups = group_by_parent(examples)\n    for group in tqdm(groups, desc=\"Evaluating folders\"):\n        sample_dirs = [g[0] for g in group]\n        folder_for_group = sample_dirs[0].parent\n        folder_name = folder_for_group.name\n\n        (\n            per_sample_gen,\n            per_sample_gt,\n            aggregated_gen,\n            aggregated_gt,\n            tmr_metrics,\n            tmr_per_sample,\n        ) = _run_eval_on_group(group, skeleton, metrics_list, device, group_name=folder_name, 
soma30_skel=soma30_skel)\n\n        texts = []\n        for sample_dir, _ in group:\n            meta = load_json(sample_dir / \"meta.json\")\n            texts_parsed, _ = parse_prompts_from_meta(meta)\n            texts.append(texts_parsed[0] if texts_parsed else \"\")\n\n        for i, (sample_dir, _) in enumerate(group):\n            metrics_path = sample_dir / \"metrics.json\"\n            out = {\n                \"num_motions\": 1,\n                \"folder\": str(sample_dir),\n                \"per_motion_mean_gen\": per_sample_gen[i] if i < len(per_sample_gen) else {},\n                \"per_motion_mean_gt\": per_sample_gt[i] if i < len(per_sample_gt) else {},\n            }\n            if i < len(tmr_per_sample):\n                out[\"tmr\"] = {\n                    \"t2m_rank\": tmr_per_sample[i][\"rank\"],\n                    \"text\": texts[i] if i < len(texts) else \"\",\n                    \"top5_retrieved\": tmr_per_sample[i][\"top_k\"],\n                }\n            _write_json(metrics_path, out)\n\n        parent_json_path = folder_for_group.parent / f\"{folder_name}.json\"\n        full_metrics = {\n            \"num_motions\": len(group),\n            \"folder\": str(folder_for_group),\n            \"per_motion_mean_gen\": aggregated_gen,\n            \"per_motion_mean_gt\": aggregated_gt,\n        }\n        if tmr_metrics:\n            full_metrics[\"tmr\"] = tmr_metrics\n        _write_json(parent_json_path, full_metrics)\n\n    print(f\"Wrote metrics.json in each of {len(examples)} folders and folder-level JSONs for {len(groups)} groups.\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "benchmark/generate_eval.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\nStep (2) of evaluation pipeline.\n\nThis script recursively generates motions using Kimodo from a test suite folder tree.\n\"\"\"\n\nimport argparse\nimport shutil\nfrom pathlib import Path\nfrom typing import Any\n\nimport numpy as np\nimport torch\nfrom torch.utils.data import DataLoader, Dataset\nfrom tqdm.auto import tqdm\n\nfrom kimodo.constraints import load_constraints_lst\nfrom kimodo.meta import parse_prompts_from_meta\nfrom kimodo.model import DEFAULT_MODEL, load_model\nfrom kimodo.tools import load_json, seed_everything\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Recursively generate motions from a testsuite folder tree\")\n    parser.add_argument(\n        \"--benchmark\",\n        type=str,\n        default=\"testsuite\",\n        help=\"Root folder containing subfolders with meta.json (default: testsuite)\",\n    )\n    parser.add_argument(\n        \"--output\",\n        type=str,\n        default=None,\n        help=\"Output root; directory hierarchy is mirrored here. If omitted, motions are generated in-place inside the testsuite folder.\",\n    )\n    parser.add_argument(\n        \"--batch_size\",\n        type=int,\n        default=32,\n        help=\"Batch size for generating motions (default: 32)\",\n    )\n    parser.add_argument(\n        \"--num_workers\",\n        type=int,\n        default=4,\n        help=\"DataLoader workers for loading meta/constraints paths (default: 4)\",\n    )\n    parser.add_argument(\n        \"--model\",\n        type=str,\n        default=DEFAULT_MODEL,\n        help=\"Name of the model (e.g. 
Kimodo-SOMA-RP-v1.1, kimodo-soma-rp, or SOMA).\",\n    )\n    parser.add_argument(\n        \"--diffusion_steps\",\n        type=int,\n        default=100,\n        help=\"Number of diffusion steps (default: 100); overridden by meta.json if present\",\n    )\n    parser.add_argument(\n        \"--postprocess\",\n        action=\"store_true\",\n        help=\"Apply motion post-processing to reduce foot skating\",\n    )\n    parser.add_argument(\n        \"--overwrite\",\n        action=\"store_true\",\n        help=\"Regenerate outputs even if motion.npz already exists\",\n    )\n    parser.add_argument(\n        \"--text_encoder_fp32\",\n        action=\"store_true\",\n        help=\"Uses fp32 for instantiating the text encoder (if API is not already running) rather than default bfloat16.\",\n    )\n    return parser.parse_args()\n\n\ndef discover_example_folders(root: Path) -> list[tuple[Path, Path]]:\n    \"\"\"Discover leaf directories that contain meta.json.\n\n    Returns list of (src_dir, rel_path).\n    \"\"\"\n    root = root.resolve()\n    if not root.is_dir():\n        raise FileNotFoundError(f\"Testsuite folder does not exist: {root}\")\n    out: list[tuple[Path, Path]] = []\n    for meta_path in root.rglob(\"meta.json\"):\n        src_dir = meta_path.parent\n        rel = src_dir.relative_to(root)\n        out.append((src_dir, rel))\n    return sorted(out, key=lambda x: str(x[1]))\n\n\ndef copy_source_files(src_dir: Path, out_dir: Path) -> None:\n    \"\"\"Copy meta.json, constraints.json, and gt_motion.npz (if present) from src_dir to out_dir.\"\"\"\n    out_dir.mkdir(parents=True, exist_ok=True)\n    for name in (\"meta.json\", \"constraints.json\", \"gt_motion.npz\"):\n        src_file = src_dir / name\n        if src_file.is_file():\n            shutil.copy2(src_file, out_dir / name)\n\n\nclass EvalExampleDataset(Dataset):\n    \"\"\"Dataset of example folders: yields text, num_frame, constraints_path (and paths, meta).\n    No torch/skeleton in 
workers so num_workers > 0 is safe with CUDA.\n    \"\"\"\n\n    def __init__(\n        self,\n        examples: list[tuple[Path, Path]],\n        testsuite_root: Path,\n        generated_root: Path,\n        fps: float,\n    ):\n        self.examples = examples\n        self.testsuite_root = testsuite_root\n        self.generated_root = generated_root\n        self.fps = fps\n\n    def __len__(self) -> int:\n        return len(self.examples)\n\n    def __getitem__(self, idx: int) -> dict[str, Any]:\n        src_dir, rel_path = self.examples[idx]\n        out_dir = self.generated_root / rel_path\n        meta_path = src_dir / \"meta.json\"\n        meta = load_json(str(meta_path))\n        assert meta.get(\"num_samples\", 1) == 1, \"Expected num_samples to be absent or 1 in meta.json\"\n        texts, durations_sec = parse_prompts_from_meta(meta)\n        assert len(texts) == 1, \"Expected exactly one prompt (len(texts)==1) per example\"\n        num_frames = [int(float(d) * self.fps) for d in durations_sec]\n        assert len(num_frames) == 1, \"Expected exactly one duration per example\"\n        constraints_path = src_dir / \"constraints.json\"\n        cpath = str(constraints_path) if constraints_path.is_file() else None\n        return {\n            \"rel_path\": rel_path,\n            \"src_dir\": str(src_dir),\n            \"out_dir\": str(out_dir),\n            \"meta\": meta,\n            \"text\": texts[0],\n            \"num_frame\": num_frames[0],\n            \"constraints_path\": cpath,\n        }\n\n\ndef collate_examples(batch: list[dict]) -> dict[str, Any]:\n    \"\"\"Collate list of example dicts; keep list fields as lists (no stacking).\"\"\"\n    if not batch:\n        return {}\n    keys = batch[0].keys()\n    out: dict[str, Any] = {}\n    for k in keys:\n        vals = [b[k] for b in batch]\n        out[k] = vals\n    return out\n\n\ndef group_by_parent(\n    examples: list[tuple[Path, Path]],\n) -> list[list[tuple[Path, Path]]]:\n    
\"\"\"Group (src_dir, rel_path) by parent directory of rel_path for folder-by-folder processing.\"\"\"\n    from itertools import groupby\n\n    def parent_key(item: tuple[Path, Path]) -> Path:\n        rel = item[1]\n        return rel.parent if len(rel.parts) > 1 else Path(\".\")\n\n    sorted_examples = sorted(examples, key=parent_key)\n    groups: list[list[tuple[Path, Path]]] = []\n    for _key, group in groupby(sorted_examples, key=parent_key):\n        groups.append(list(group))\n    return groups\n\n\ndef _slice_output_at(output: dict[str, Any], index: int) -> dict[str, Any]:\n    \"\"\"Slice a (possibly nested) output dict at batch index for one sample.\"\"\"\n    out: dict[str, Any] = {}\n    for k, v in output.items():\n        if isinstance(v, dict):\n            out[k] = _slice_output_at(v, index)\n        elif isinstance(v, np.ndarray) and v.ndim > 0:\n            out[k] = v[index]\n        else:\n            out[k] = v\n    return out\n\n\ndef _crop_output(output: dict[str, Any], num_frames: int) -> dict[str, Any]:\n    \"\"\"Crop a single-sample output dict along the time dimension (axis 0).\"\"\"\n    out: dict[str, Any] = {}\n    for k, v in output.items():\n        if isinstance(v, dict):\n            out[k] = _crop_output(v, num_frames)\n        elif isinstance(v, np.ndarray) and v.ndim >= 1:\n            out[k] = v[:num_frames]\n        else:\n            out[k] = v\n    return out\n\n\ndef main():\n    device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n    print(f\"Using device: {device}\")\n\n    args = parse_args()\n    testsuite_root = Path(args.benchmark).resolve()\n    if args.output is not None:\n        generated_root = Path(args.output).resolve()\n    else:\n        generated_root = testsuite_root\n    in_place = generated_root == testsuite_root\n\n    examples = discover_example_folders(testsuite_root)\n    if not examples:\n        raise SystemExit(f\"No folders with meta.json found under {testsuite_root}\")\n    
print(f\"Discovered {len(examples)} example folders.\")\n\n    model, resolved_name = load_model(\n        args.model,\n        device=device,\n        default_family=\"Kimodo\",\n        return_resolved_name=True,\n        text_encoder_fp32=args.text_encoder_fp32,\n    )\n    # v1.1 models are meant to be used for benchmark evaluation\n    _deprecated_for_benchmark = {\n        \"kimodo-soma-rp-v1\": \"Kimodo-SOMA-RP-v1 was not trained to be compatible with the benchmark evaluation.\",\n        \"kimodo-soma-seed-v1\": \"Kimodo-SOMA-SEED-v1 is not the latest model for benchmark evaluation.\",\n    }\n    if resolved_name in _deprecated_for_benchmark:\n        import warnings\n\n        warnings.warn(\n            f\"Model '{args.model}' resolved to {resolved_name}: \"\n            f\"{_deprecated_for_benchmark[resolved_name]} Consider using v1.1.\",\n            stacklevel=1,\n        )\n    print(f\"Generating with model: {resolved_name}\")\n    fps = model.fps\n    default_diffusion_steps = args.diffusion_steps\n\n    groups = group_by_parent(examples)\n    total_generated = 0\n    total_skipped = 0\n\n    total_examples = len(examples)\n    for group in groups:\n        rel_path_0 = group[0][1]\n        if rel_path_0.parent != Path(\".\"):\n            folder_label = str(rel_path_0.parent)\n        else:\n            # Direct children of testsuite root: show root name (e.g. 
inbetweening)\n            folder_label = testsuite_root.name\n        num_in_folder = len(group)\n        print(f\"Generating folder: {folder_label} ({num_in_folder} motions)\")\n\n        dataset = EvalExampleDataset(\n            group,\n            testsuite_root,\n            generated_root,\n            fps=fps,\n        )\n        loader = DataLoader(\n            dataset,\n            batch_size=args.batch_size,\n            shuffle=False,\n            num_workers=args.num_workers,\n            collate_fn=collate_examples,\n        )\n\n        folder_generated = 0\n        folder_skipped = 0\n        for batch_idx, batch in enumerate(loader):\n            rel_paths = batch[\"rel_path\"]\n            src_dirs = batch[\"src_dir\"]\n            out_dirs = batch[\"out_dir\"]\n            metas = batch[\"meta\"]\n            batch_texts = batch[\"text\"]\n            batch_num_frames = batch[\"num_frame\"]\n            constraints_paths = batch[\"constraints_path\"]\n\n            # Filter out samples that are already generated (unless --overwrite).\n            if args.overwrite:\n                selected_indices = list(range(len(rel_paths)))\n            else:\n                selected_indices = []\n                for i, out_dir_str in enumerate(out_dirs):\n                    motion_path = Path(out_dir_str) / \"motion.npz\"\n                    if motion_path.is_file():\n                        folder_skipped += 1\n                        total_skipped += 1\n                        continue\n                    selected_indices.append(i)\n\n            if not selected_indices:\n                print(\n                    f\"\\r  Generated {folder_generated} / {num_in_folder} (skipped: {folder_skipped}) \"\n                    f\"(total: {total_generated + total_skipped} / {total_examples})\",\n                    end=\"\",\n                    flush=True,\n                )\n                continue\n\n            rel_paths = [rel_paths[i] for i in 
selected_indices]\n            src_dirs = [src_dirs[i] for i in selected_indices]\n            out_dirs = [out_dirs[i] for i in selected_indices]\n            metas = [metas[i] for i in selected_indices]\n            batch_texts = [batch_texts[i] for i in selected_indices]\n            batch_num_frames = [batch_num_frames[i] for i in selected_indices]\n            constraints_paths = [constraints_paths[i] for i in selected_indices]\n\n            # Load constraints in main process on model device (no torch in workers)\n            device_t = torch.device(device)\n            batch_constraints_lst = [\n                load_constraints_lst(cpath, model.skeleton, device=device_t) if cpath else []\n                for cpath in constraints_paths\n            ]\n\n            if not in_place:\n                for i in range(len(rel_paths)):\n                    copy_source_files(Path(src_dirs[i]), Path(out_dirs[i]))\n\n            # Use first example's diffusion_steps and seed for the whole batch\n            diffusion_steps = metas[0].get(\"diffusion_steps\", default_diffusion_steps)\n            seed = metas[0].get(\"seed\", None)\n            if seed is not None:\n                seed_everything(seed)\n            else:\n                print(\"Warning: No seed found in meta.json, not seeding this batch.\")\n\n            # Single model call for the entire batch (count in bar title, bar clears when done)\n            bar_desc = (\n                f\"  Generated {folder_generated} / {num_in_folder} \"\n                f\"(skipped: {folder_skipped}) (total: {total_generated + total_skipped} / {total_examples})\"\n            )\n            output = model(\n                batch_texts,\n                batch_num_frames,\n                constraint_lst=batch_constraints_lst,\n                num_denoising_steps=diffusion_steps,\n                multi_prompt=False,\n                post_processing=args.postprocess,\n                return_numpy=True,\n                
progress_bar=lambda x: tqdm(x, leave=False, desc=bar_desc),\n            )\n\n            # Save each sample to its output dir\n            B = len(batch_texts)\n            for b in range(B):\n                out_dir = Path(out_dirs[b])\n                sample_output = _slice_output_at(output, b)\n                sample_output = _crop_output(sample_output, batch_num_frames[b])\n                motion_path = out_dir / \"motion.npz\"\n                np.savez(motion_path, **sample_output)\n                total_generated += 1\n                folder_generated += 1\n\n            print(\n                f\"\\r  Generated {folder_generated} / {num_in_folder} (skipped: {folder_skipped}) \"\n                f\"(total: {total_generated + total_skipped} / {total_examples})\",\n                end=\"\",\n                flush=True,\n            )\n\n        print()\n        print(\n            f\"  Finished folder {folder_label} ({num_in_folder} motions, \"\n            f\"generated: {folder_generated}, skipped: {folder_skipped}).\"\n        )\n\n    if in_place:\n        print(f\"Generated {total_generated} motions in-place under {testsuite_root}.\")\n    else:\n        print(f\"Generated {total_generated} motions under {generated_root}.\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "benchmark/parse_folder.py",
    "content": "#!/usr/bin/env python3\n# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\nStep (5) of evaluation pipeline.\n\nValidate testcase result JSONs and aggregate benchmark rows.\n\nExpected testsuite layout (aligned with evaluate_folder output):\n\n    <root>/\n    ├── <split>/                    # e.g. content, repetition\n    │   ├── text2motion/            # text-following eval\n    │   │   ├── overview/           # or timeline_single, timeline_multi\n    │   │   │   └── <testcase>.json\n    │   │   └── ...\n    │   └── <category>/             # constraints_withtext, constraints_notext\n    │       └── .../                 # optional subdirs, e.g. root, fullbody\n    │           └── <testcase>/\n    │           └── <testcase>.json\n\nSamples are discovered via rglob('meta.json') with motion.npz and gt_motion.npz in the same dir.\nTestcase dir = parent of a sample dir. Result file = testcase_dir.parent / f\"{testcase_dir.name}.json\".\n\"\"\"\n\nfrom __future__ import annotations\n\nimport argparse\nimport json\nfrom collections import defaultdict\nfrom pathlib import Path\nfrom typing import Any\n\nSPLITS = (\"content\", \"repetition\")\nTEXT_FOLLOWING_CATEGORIES = (\"overview\", \"timeline_single\", \"timeline_multi\")\nCONSTRAINTS_CATEGORIES = (\"constraints_withtext\", \"constraints_notext\")\nROW_CATEGORIES = TEXT_FOLLOWING_CATEGORIES + CONSTRAINTS_CATEGORIES\n\n\ndef _discover_sample_dirs(root: Path) -> list[Path]:\n    sample_dirs: list[Path] = []\n    for meta_path in root.rglob(\"meta.json\"):\n        sample_dir = meta_path.parent\n        if (sample_dir / \"motion.npz\").is_file() and (sample_dir / \"gt_motion.npz\").is_file():\n            sample_dirs.append(sample_dir)\n    return sorted(set(sample_dirs))\n\n\ndef _discover_testcase_dirs(root: Path) -> list[Path]:\n    sample_dirs = _discover_sample_dirs(root)\n    return 
sorted({sample_dir.parent for sample_dir in sample_dirs})\n\n\ndef _expected_result_path(testcase_dir: Path) -> Path:\n    return testcase_dir.parent / f\"{testcase_dir.name}.json\"\n\n\ndef _parse_testcase_key(root: Path, testcase_dir: Path) -> tuple[str, str]:\n    rel_parts = testcase_dir.relative_to(root).parts\n    if len(rel_parts) < 2:\n        raise ValueError(f\"Unexpected testcase path shape: {testcase_dir} (relative: {'/'.join(rel_parts)})\")\n    split = rel_parts[0]\n    if split not in SPLITS:\n        raise ValueError(f\"Unknown split '{split}' for testcase {testcase_dir}\")\n    if len(rel_parts) >= 3 and rel_parts[1] == \"text2motion\":\n        category = rel_parts[2]\n        if category not in TEXT_FOLLOWING_CATEGORIES:\n            raise ValueError(f\"Unknown text-following category '{category}' for testcase {testcase_dir}\")\n    else:\n        category = rel_parts[1]\n        if category not in CONSTRAINTS_CATEGORIES:\n            raise ValueError(f\"Unknown category '{category}' for testcase {testcase_dir}\")\n    return split, category\n\n\ndef _accumulate_weighted(\n    sum_acc: dict[str, float],\n    weight_acc: dict[str, float],\n    metric_dict: dict[str, Any],\n    weight: float,\n) -> None:\n    for metric_name, value in metric_dict.items():\n        if isinstance(value, (int, float)):\n            sum_acc[metric_name] = sum_acc.get(metric_name, 0.0) + float(value) * weight\n            weight_acc[metric_name] = weight_acc.get(metric_name, 0.0) + weight\n\n\ndef _to_averages(\n    weighted_sum: dict[str, float], weight: dict[str, float]\n) -> dict[str, float]:\n    return {\n        k: v / weight[k]\n        for k, v in sorted(weighted_sum.items())\n        if weight.get(k, 0.0) > 0\n    }\n\n\ndef _load_result_row(\n    result_path: Path,\n) -> tuple[float, dict[str, Any], dict[str, Any], dict[str, Any]]:\n    payload = json.loads(result_path.read_text(encoding=\"utf-8\"))\n    num_motions = float(payload.get(\"num_motions\", 1))\n   
 per_motion_mean_gen = payload.get(\"per_motion_mean_gen\") or payload.get(\"per_motion_mean\", {})\n    per_motion_mean_gt = payload.get(\"per_motion_mean_gt\") or {}\n    tmr = payload.get(\"tmr\") or {}\n    if not isinstance(per_motion_mean_gen, dict):\n        raise ValueError(f\"'per_motion_mean_gen' / 'per_motion_mean' is not a dict in {result_path}\")\n    if not isinstance(per_motion_mean_gt, dict):\n        raise ValueError(f\"'per_motion_mean_gt' is not a dict in {result_path}\")\n    if not isinstance(tmr, dict):\n        raise ValueError(f\"'tmr' is not a dict in {result_path}\")\n    return num_motions, per_motion_mean_gen, per_motion_mean_gt, tmr\n\n\n# Display labels for table rows (paper-style).\nTEXT_FOLLOWING_ROW_LABELS = {\n    \"overview\": \"Overview\",\n    \"timeline_single\": \"Timeline single\",\n    \"timeline_multi\": \"Timeline multi\",\n}\nCONSTRAINTS_ROW_LABELS = {\n    \"constraints_withtext\": \"Constraints with text\",\n    \"constraints_notext\": \"Constraints without text\",\n}\n\n# Meters to cm for constraint position metrics.\nM_TO_CM = 100.0\n\n\ndef _table_value(val: float | None) -> float | str | None:\n    \"\"\"Return value for JSON table; use None for missing (omit or serialize as null).\"\"\"\n    if val is None:\n        return None\n    if isinstance(val, (int, float)) and (val != val or val == float(\"inf\")):  # nan or inf\n        return None\n    return val\n\n\ndef _build_tables(\n    row_acc: dict[tuple[str, str], dict[str, Any]],\n) -> dict[str, dict[str, list[dict[str, Any]]]]:\n    \"\"\"Build text_following and constraints tables per split for paper-style output.\"\"\"\n    tables: dict[str, dict[str, list[dict[str, Any]]]] = {}\n    for split in SPLITS:\n        tables[split] = {\"text_following\": [], \"constraints\": []}\n\n        # Text-following table: Overview, Timeline single, Timeline multi.\n        for category in TEXT_FOLLOWING_CATEGORIES:\n            acc = row_acc[(split, category)]\n            
per_motion_gen = _to_averages(acc[\"per_motion_mean_weighted_sum\"], acc[\"per_motion_mean_weight\"])\n            per_motion_gt = _to_averages(acc[\"per_motion_mean_gt_weighted_sum\"], acc[\"per_motion_mean_gt_weight\"])\n            tmr_avg = _to_averages(acc[\"tmr_weighted_sum\"], acc[\"tmr_weight\"]) if acc[\"tmr_weight\"] else {}\n            r03_gen = tmr_avg.get(\"TMR/t2m_R/R03\")\n            r03_gt = tmr_avg.get(\"TMR/t2m_gt_R/R03\")\n            fid_gen_text = tmr_avg.get(\"TMR/FID/gen_text\")\n            fid_gt_text = tmr_avg.get(\"TMR/FID/gt_text\")\n            fid_gen_gt = tmr_avg.get(\"TMR/FID/gen_gt\")\n            # Skate is velocity in m/s; convert to cm/s for display.\n            skate_gen = per_motion_gen.get(\"foot_skate_from_pred_contacts\")\n            skate_gt = per_motion_gt.get(\"foot_skate_from_pred_contacts\")\n            contact_gen = per_motion_gen.get(\"foot_contact_consistency\")\n            contact_gt = per_motion_gt.get(\"foot_contact_consistency\")\n            row_label = TEXT_FOLLOWING_ROW_LABELS[category]\n            tables[split][\"text_following\"].append(\n                {\n                    \"row\": row_label,\n                    \"R@3 (gen)\": _table_value(r03_gen),\n                    \"R@3 (GT)\": _table_value(r03_gt),\n                    \"FID gen-text\": _table_value(fid_gen_text),\n                    \"FID GT-text\": _table_value(fid_gt_text),\n                    \"FID gen-GT\": _table_value(fid_gen_gt),\n                    \"Skate (gen, cm/s)\": _table_value(skate_gen * 100.0 if skate_gen is not None else None),\n                    \"Skate (GT, cm/s)\": _table_value(skate_gt * 100.0 if skate_gt is not None else None),\n                    \"Contact (gen)\": _table_value(contact_gen),\n                    \"Contact (GT)\": _table_value(contact_gt),\n                }\n            )\n\n        # Constraints table: Constraints with text, Constraints without text.\n        for category in 
CONSTRAINTS_CATEGORIES:\n            acc = row_acc[(split, category)]\n            per_motion_gen = _to_averages(acc[\"per_motion_mean_weighted_sum\"], acc[\"per_motion_mean_weight\"])\n            per_motion_gt = _to_averages(acc[\"per_motion_mean_gt_weighted_sum\"], acc[\"per_motion_mean_gt_weight\"])\n            row_label = CONSTRAINTS_ROW_LABELS[category]\n            row_dict: dict[str, Any] = {\n                \"row\": row_label,\n                \"Full-Body Pos (gen, cm)\": _table_value(\n                    per_motion_gen.get(\"constraint_fullbody_keyframe\") * M_TO_CM\n                    if per_motion_gen.get(\"constraint_fullbody_keyframe\") is not None\n                    else None\n                ),\n                \"Full-Body Pos (GT, cm)\": _table_value(\n                    per_motion_gt.get(\"constraint_fullbody_keyframe\") * M_TO_CM\n                    if per_motion_gt.get(\"constraint_fullbody_keyframe\") is not None\n                    else None\n                ),\n                \"End-Effector Pos (gen, cm)\": _table_value(\n                    per_motion_gen.get(\"constraint_end_effector\") * M_TO_CM\n                    if per_motion_gen.get(\"constraint_end_effector\") is not None\n                    else None\n                ),\n                \"End-Effector Pos (GT, cm)\": _table_value(\n                    per_motion_gt.get(\"constraint_end_effector\") * M_TO_CM\n                    if per_motion_gt.get(\"constraint_end_effector\") is not None\n                    else None\n                ),\n                \"End-Effector Rot (deg)\": None,  # Not implemented in metrics.\n                \"2D Root Pos (gen, cm)\": _table_value(\n                    per_motion_gen.get(\"constraint_root2d_err\") * M_TO_CM\n                    if per_motion_gen.get(\"constraint_root2d_err\") is not None\n                    else None\n                ),\n                \"2D Root Pos (GT, cm)\": _table_value(\n                    
per_motion_gt.get(\"constraint_root2d_err\") * M_TO_CM\n                    if per_motion_gt.get(\"constraint_root2d_err\") is not None\n                    else None\n                ),\n                \"2D Pelvis Pos@95% (gen, cm)\": _table_value(\n                    per_motion_gen.get(\"constraint_root2d_err_p95\") * M_TO_CM\n                    if per_motion_gen.get(\"constraint_root2d_err_p95\") is not None\n                    else None\n                ),\n                \"2D Pelvis Pos@95% (GT, cm)\": _table_value(\n                    per_motion_gt.get(\"constraint_root2d_err_p95\") * M_TO_CM\n                    if per_motion_gt.get(\"constraint_root2d_err_p95\") is not None\n                    else None\n                ),\n            }\n            tables[split][\"constraints\"].append(row_dict)\n\n    return tables\n\n\ndef _fmt_md(val: float | None, decimals: int) -> str:\n    \"\"\"Format a numeric value for a markdown cell, or '-' for None/NaN.\"\"\"\n    if val is None:\n        return \"-\"\n    if isinstance(val, float) and (val != val or val == float(\"inf\")):\n        return \"-\"\n    return f\"{val:.{decimals}f}\"\n\n\ndef _print_tf_formatted_md(\n    splits_data: list[tuple[str, list[dict[str, Any]]]],\n    title: str,\n) -> None:\n    \"\"\"Print text-following table in markdown, mirroring the terminal layout.\"\"\"\n    groups = [\"Overview\", \"Timeline single\", \"Timeline multi\"]\n    specs: list[tuple[str, int]] = [\n        (\"R@3\\u2191\", 2),\n        (\"FID\\u2193\", 3),\n        (\"Skate\\u2193\", 3),\n        (\"Contact\\u2191\", 3),\n    ]\n    gt_keys = [\"R@3 (GT)\", None, \"Skate (GT, cm/s)\", \"Contact (GT)\"]\n    gen_keys = [\"R@3 (gen)\", \"FID gen-GT\", \"Skate (gen, cm/s)\", \"Contact (gen)\"]\n    gt_defaults: list[float | None] = [None, 0.0, None, None]\n\n    headers = [\"\"]\n    for g in groups:\n        for hdr, _ in specs:\n            headers.append(f\"{g} {hdr}\")\n\n    print(f\"\\n### {title}\\n\")\n   
 print(\"| \" + \" | \".join(headers) + \" |\")\n    print(\"| \" + \" | \".join(\"---\" for _ in headers) + \" |\")\n\n    for split_label, rows in splits_data:\n        for row_type, keys, defaults in [\n            (\"Ground Truth\", gt_keys, gt_defaults),\n            (\"Method\", gen_keys, [None] * len(specs)),\n        ]:\n            cells = [f\"**{split_label}** {row_type}\"]\n            for row in rows:\n                for j, (_, dec) in enumerate(specs):\n                    key = keys[j]\n                    val = defaults[j] if key is None else row.get(key)\n                    cells.append(_fmt_md(val, dec))\n            print(\"| \" + \" | \".join(cells) + \" |\")\n\n    print()\n\n\ndef _print_c_formatted_md(\n    splits_data: list[tuple[str, list[dict[str, Any]]]],\n    title: str,\n) -> None:\n    \"\"\"Print constraints table in markdown, mirroring the terminal layout.\"\"\"\n    groups = [\"With text\", \"Without text\"]\n    specs: list[tuple[str, int]] = [\n        (\"FB Pos\\u2193\", 3),\n        (\"EE Pos\\u2193\", 3),\n        (\"EE Rot\\u2193\", 3),\n        (\"2D Root\\u2193\", 3),\n        (\"Pelvis@95%\", 2),\n    ]\n    gt_keys = [\n        \"Full-Body Pos (GT, cm)\",\n        \"End-Effector Pos (GT, cm)\",\n        \"End-Effector Rot (deg)\",\n        \"2D Root Pos (GT, cm)\",\n        \"2D Pelvis Pos@95% (GT, cm)\",\n    ]\n    gen_keys = [\n        \"Full-Body Pos (gen, cm)\",\n        \"End-Effector Pos (gen, cm)\",\n        \"End-Effector Rot (deg)\",\n        \"2D Root Pos (gen, cm)\",\n        \"2D Pelvis Pos@95% (gen, cm)\",\n    ]\n\n    headers = [\"\"]\n    for g in groups:\n        for hdr, _ in specs:\n            headers.append(f\"{g} {hdr}\")\n\n    print(f\"\\n### {title}\\n\")\n    print(\"| \" + \" | \".join(headers) + \" |\")\n    print(\"| \" + \" | \".join(\"---\" for _ in headers) + \" |\")\n\n    for split_label, rows in splits_data:\n        for row_type, keys in [(\"Ground Truth\", gt_keys), (\"Method\", 
gen_keys)]:\n            cells = [f\"**{split_label}** {row_type}\"]\n            for row in rows:\n                for j, (_, dec) in enumerate(specs):\n                    cells.append(_fmt_md(row.get(keys[j]), dec))\n            print(\"| \" + \" | \".join(cells) + \" |\")\n\n    print()\n\n\ndef _print_formatted_gt_method_md(\n    tables: dict[str, dict[str, list[dict[str, Any]]]],\n) -> None:\n    \"\"\"Print combined tables in markdown format, mirroring the terminal layout.\"\"\"\n    tf_splits: list[tuple[str, list[dict[str, Any]]]] = []\n    c_splits: list[tuple[str, list[dict[str, Any]]]] = []\n    for split in SPLITS:\n        split_tables = tables.get(split, {})\n        tf_rows = split_tables.get(\"text_following\", [])\n        c_rows = split_tables.get(\"constraints\", [])\n        if tf_rows and len(tf_rows) == 3:\n            tf_splits.append((split.capitalize(), tf_rows))\n        if c_rows and len(c_rows) == 2:\n            c_splits.append((split.capitalize(), c_rows))\n\n    if tf_splits:\n        _print_tf_formatted_md(tf_splits, \"Text-Following Evaluation\")\n    if c_splits:\n        _print_c_formatted_md(c_splits, \"Constrained Evaluation\")\n\n\ndef _fmt(val: float | None, decimals: int, width: int) -> str:\n    \"\"\"Format a numeric value right-aligned to *width*, or '-' for None.\"\"\"\n    if val is None:\n        return f\"{'-':>{width}}\"\n    return f\"{val:>{width}.{decimals}f}\"\n\n\ndef _print_grouped_rows(\n    label: str,\n    rows: list[dict[str, Any]],\n    specs: list[tuple[str, int, int]],\n    keys: list[str],\n    mw: int,\n    sep: str,\n) -> None:\n    \"\"\"Print one data row across all column groups.\"\"\"\n    parts = [f\"{label:<{mw}}\"]\n    for i, row in enumerate(rows):\n        if i:\n            parts.append(sep)\n        for j, (_, dec, w) in enumerate(specs):\n            parts.append(_fmt(row.get(keys[j]), dec, w))\n    print(\"\".join(parts))\n\n\ndef _print_tf_formatted(\n    splits_data: list[tuple[str, 
list[dict[str, Any]]]],\n    title: str,\n) -> None:\n    \"\"\"Print text-following table with Overview / Timeline single / Timeline multi groups.\n\n    *splits_data* is a list of ``(split_label, category_rows)`` tuples so\n    that content and repetition splits appear as separate row-pairs inside\n    one table.\n    \"\"\"\n    groups = [\"Overview\", \"Timeline single\", \"Timeline multi\"]\n    specs: list[tuple[str, int, int]] = [\n        (\"R@3\\u2191\", 2, 7),\n        (\"FID\\u2193\", 3, 7),\n        (\"Skate\\u2193\", 3, 9),\n        (\"Contact\\u2191\", 3, 10),\n    ]\n    gt_keys = [\"R@3 (GT)\", None, \"Skate (GT, cm/s)\", \"Contact (GT)\"]\n    gen_keys = [\"R@3 (gen)\", \"FID gen-GT\", \"Skate (gen, cm/s)\", \"Contact (gen)\"]\n    gt_defaults: list[float | None] = [None, 0.0, None, None]\n\n    mw = 16\n    gw = sum(s[2] for s in specs)\n    sep = \" | \"\n    total_w = mw + len(groups) * gw + (len(groups) - 1) * len(sep)\n\n    print(f\"\\n{title:^{total_w}}\")\n    print(\"=\" * total_w)\n\n    parts: list[str] = [\" \" * mw]\n    for i, g in enumerate(groups):\n        if i:\n            parts.append(sep)\n        parts.append(g.center(gw))\n    print(\"\".join(parts))\n\n    parts = [f\"{'':<{mw}}\"]\n    for i in range(len(groups)):\n        if i:\n            parts.append(sep)\n        for hdr, _, w in specs:\n            parts.append(f\"{hdr:>{w}}\")\n    print(\"\".join(parts))\n\n    parts = [\"\\u2500\" * mw]\n    for i in range(len(groups)):\n        if i:\n            parts.append(\"\\u2500\\u253c\\u2500\")\n        parts.append(\"\\u2500\" * gw)\n    print(\"\".join(parts))\n\n    for si, (split_label, rows) in enumerate(splits_data):\n        tag = f\"\\u2500\\u2500 {split_label} \"\n        print(tag + \"\\u2500\" * (total_w - len(tag)))\n\n        parts = [f\"{'Ground Truth':<{mw}}\"]\n        for i, row in enumerate(rows):\n            if i:\n                parts.append(sep)\n            for j, (_, dec, w) in enumerate(specs):\n  
              key = gt_keys[j]\n                val = gt_defaults[j] if key is None else row.get(key)\n                parts.append(_fmt(val, dec, w))\n        print(\"\".join(parts))\n\n        _print_grouped_rows(\"Method\", rows, specs, gen_keys, mw, sep)\n\n    print()\n\n\ndef _print_c_formatted(\n    splits_data: list[tuple[str, list[dict[str, Any]]]],\n    title: str,\n) -> None:\n    \"\"\"Print constraints table with With text / Without text groups.\n\n    *splits_data* is a list of ``(split_label, category_rows)`` tuples.\n    \"\"\"\n    groups = [\"With text\", \"Without text\"]\n    specs: list[tuple[str, int, int]] = [\n        (\"FB Pos\\u2193\", 3, 10),\n        (\"EE Pos\\u2193\", 3, 10),\n        (\"EE Rot\\u2193\", 3, 10),\n        (\"2D Root\\u2193\", 3, 11),\n        (\"Pelvis@95%\", 2, 12),\n    ]\n    gt_keys = [\n        \"Full-Body Pos (GT, cm)\",\n        \"End-Effector Pos (GT, cm)\",\n        \"End-Effector Rot (deg)\",\n        \"2D Root Pos (GT, cm)\",\n        \"2D Pelvis Pos@95% (GT, cm)\",\n    ]\n    gen_keys = [\n        \"Full-Body Pos (gen, cm)\",\n        \"End-Effector Pos (gen, cm)\",\n        \"End-Effector Rot (deg)\",\n        \"2D Root Pos (gen, cm)\",\n        \"2D Pelvis Pos@95% (gen, cm)\",\n    ]\n\n    mw = 16\n    gw = sum(s[2] for s in specs)\n    sep = \" | \"\n    total_w = mw + len(groups) * gw + (len(groups) - 1) * len(sep)\n\n    print(f\"\\n{title:^{total_w}}\")\n    print(\"=\" * total_w)\n\n    parts: list[str] = [\" \" * mw]\n    for i, g in enumerate(groups):\n        if i:\n            parts.append(sep)\n        parts.append(g.center(gw))\n    print(\"\".join(parts))\n\n    parts = [f\"{'':<{mw}}\"]\n    for i in range(len(groups)):\n        if i:\n            parts.append(sep)\n        for hdr, _, w in specs:\n            parts.append(f\"{hdr:>{w}}\")\n    print(\"\".join(parts))\n\n    parts = [\"\\u2500\" * mw]\n    for i in range(len(groups)):\n        if i:\n            
parts.append(\"\\u2500\\u253c\\u2500\")\n        parts.append(\"\\u2500\" * gw)\n    print(\"\".join(parts))\n\n    for si, (split_label, rows) in enumerate(splits_data):\n        tag = f\"\\u2500\\u2500 {split_label} \"\n        print(tag + \"\\u2500\" * (total_w - len(tag)))\n\n        _print_grouped_rows(\"Ground Truth\", rows, specs, gt_keys, mw, sep)\n        _print_grouped_rows(\"Method\", rows, specs, gen_keys, mw, sep)\n\n    print()\n\n\ndef _print_formatted_gt_method(\n    tables: dict[str, dict[str, list[dict[str, Any]]]],\n) -> None:\n    \"\"\"Print combined tables with column groups separated by vertical bars.\n\n    Content and repetition splits are shown as separate row-pairs inside one text-following table\n    and one constraints table.\n    \"\"\"\n    tf_splits: list[tuple[str, list[dict[str, Any]]]] = []\n    c_splits: list[tuple[str, list[dict[str, Any]]]] = []\n    for split in SPLITS:\n        split_tables = tables.get(split, {})\n        tf_rows = split_tables.get(\"text_following\", [])\n        c_rows = split_tables.get(\"constraints\", [])\n        if tf_rows and len(tf_rows) == 3:\n            tf_splits.append((split.capitalize(), tf_rows))\n        if c_rows and len(c_rows) == 2:\n            c_splits.append((split.capitalize(), c_rows))\n\n    if tf_splits:\n        _print_tf_formatted(tf_splits, \"Text-Following Evaluation\")\n    if c_splits:\n        _print_c_formatted(c_splits, \"Constrained Evaluation\")\n\n\ndef _build_summary(root: Path) -> dict[str, Any]:\n    testcase_dirs = _discover_testcase_dirs(root)\n    if not testcase_dirs:\n        raise SystemExit(\n            f\"No testcase folders found under {root} (expected folders containing meta.json + motion.npz + gt_motion.npz samples).\"\n        )\n\n    missing_results: list[Path] = []\n    for testcase_dir in testcase_dirs:\n        result_path = _expected_result_path(testcase_dir)\n        if not result_path.is_file():\n            
missing_results.append(result_path)\n\n    if missing_results:\n        missing_text = \"\\n\".join(str(path) for path in missing_results)\n        raise SystemExit(f\"Missing {len(missing_results)} testcase result JSON files:\\n{missing_text}\")\n\n    row_acc: dict[tuple[str, str], dict[str, Any]] = {}\n    for split in SPLITS:\n        for category in ROW_CATEGORIES:\n            row_acc[(split, category)] = {\n                \"num_testcases\": 0,\n                \"num_motions\": 0.0,\n                \"per_motion_mean_weighted_sum\": {},\n                \"per_motion_mean_weight\": {},\n                \"per_motion_mean_gt_weighted_sum\": {},\n                \"per_motion_mean_gt_weight\": {},\n                \"tmr_weighted_sum\": {},\n                \"tmr_weight\": {},\n            }\n\n    for testcase_dir in testcase_dirs:\n        split, category = _parse_testcase_key(root, testcase_dir)\n        result_path = _expected_result_path(testcase_dir)\n        num_motions, per_motion_mean_gen, per_motion_mean_gt, tmr = _load_result_row(result_path)\n\n        acc = row_acc[(split, category)]\n        acc[\"num_testcases\"] += 1\n        acc[\"num_motions\"] += num_motions\n        _accumulate_weighted(\n            acc[\"per_motion_mean_weighted_sum\"],\n            acc[\"per_motion_mean_weight\"],\n            per_motion_mean_gen,\n            num_motions,\n        )\n        if per_motion_mean_gt:\n            _accumulate_weighted(\n                acc[\"per_motion_mean_gt_weighted_sum\"],\n                acc[\"per_motion_mean_gt_weight\"],\n                per_motion_mean_gt,\n                num_motions,\n            )\n        if tmr:\n            _accumulate_weighted(\n                acc[\"tmr_weighted_sum\"],\n                acc[\"tmr_weight\"],\n                tmr,\n                num_motions,\n            )\n\n    rows: list[dict[str, Any]] = []\n    for split in SPLITS:\n        for category in ROW_CATEGORIES:\n            acc = row_acc[(split, 
category)]\n            tmr_avg = _to_averages(acc[\"tmr_weighted_sum\"], acc[\"tmr_weight\"]) if acc[\"tmr_weight\"] else {}\n            per_motion_gt_avg = _to_averages(acc[\"per_motion_mean_gt_weighted_sum\"], acc[\"per_motion_mean_gt_weight\"])\n            row_dict: dict[str, Any] = {\n                \"split\": split,\n                \"category\": category,\n                \"num_testcases\": acc[\"num_testcases\"],\n                \"num_motions\": int(acc[\"num_motions\"]),\n                \"per_motion_mean\": _to_averages(acc[\"per_motion_mean_weighted_sum\"], acc[\"per_motion_mean_weight\"]),\n                \"tmr\": tmr_avg,\n            }\n            if per_motion_gt_avg:\n                row_dict[\"per_motion_mean_gt\"] = per_motion_gt_avg\n            rows.append(row_dict)\n\n        # Combined constraints row for this split.\n        withtext = row_acc[(split, \"constraints_withtext\")]\n        notext = row_acc[(split, \"constraints_notext\")]\n\n        combined_per_motion = defaultdict(float)\n        combined_per_motion_weight = defaultdict(float)\n        combined_per_motion_gt = defaultdict(float)\n        combined_per_motion_gt_weight = defaultdict(float)\n        combined_tmr = defaultdict(float)\n        combined_tmr_weight = defaultdict(float)\n        for sum_key, weight_key, sum_acc, weight_acc in (\n            (\"per_motion_mean_weighted_sum\", \"per_motion_mean_weight\", combined_per_motion, combined_per_motion_weight),\n            (\"per_motion_mean_gt_weighted_sum\", \"per_motion_mean_gt_weight\", combined_per_motion_gt, combined_per_motion_gt_weight),\n            (\"tmr_weighted_sum\", \"tmr_weight\", combined_tmr, combined_tmr_weight),\n        ):\n            for src in (withtext, notext):\n                for k, v in src[sum_key].items():\n                    sum_acc[k] += v\n                for k, w in src[weight_key].items():\n                    weight_acc[k] += w\n\n        combined_tmr_avg = 
_to_averages(dict(combined_tmr), dict(combined_tmr_weight)) if combined_tmr_weight else {}\n        combined_gt_avg = _to_averages(dict(combined_per_motion_gt), dict(combined_per_motion_gt_weight))\n        combined_row: dict[str, Any] = {\n            \"split\": split,\n            \"category\": \"constraints\",\n            \"num_testcases\": withtext[\"num_testcases\"] + notext[\"num_testcases\"],\n            \"num_motions\": int(withtext[\"num_motions\"] + notext[\"num_motions\"]),\n            \"per_motion_mean\": _to_averages(dict(combined_per_motion), dict(combined_per_motion_weight)),\n            \"tmr\": combined_tmr_avg,\n        }\n        if combined_gt_avg:\n            combined_row[\"per_motion_mean_gt\"] = combined_gt_avg\n        rows.append(combined_row)\n\n    tables = _build_tables(row_acc)\n    return {\n        \"folder\": str(root),\n        \"num_testcases\": len(testcase_dirs),\n        \"rows\": rows,\n        \"tables\": tables,\n    }\n\n\ndef main() -> None:\n    parser = argparse.ArgumentParser(\n        description=(\"Validate testcase XXX.json result files and aggregate averages by split/category.\")\n    )\n    parser.add_argument(\n        \"folder\",\n        type=Path,\n        help=\"Testsuite root folder (contains content/ and repetition/).\",\n    )\n    parser.add_argument(\n        \"--output\",\n        type=Path,\n        default=None,\n        help=\"Optional output JSON path. 
Default: <folder>/summary_rows.json\",\n    )\n    parser.add_argument(\n        \"--format\",\n        choices=[\"terminal\", \"md\"],\n        default=\"terminal\",\n        dest=\"table_format\",\n        help=\"Table output format: 'terminal' (default) for fixed-width tables, 'md' for markdown.\",\n    )\n    args = parser.parse_args()\n\n    folder = args.folder.resolve()\n    if not folder.is_dir():\n        raise SystemExit(f\"Folder does not exist: {folder}\")\n\n    summary = _build_summary(folder)\n\n    out_path = args.output.resolve() if args.output else folder / \"summary_rows.json\"\n    out_path.write_text(json.dumps(summary, indent=2) + \"\\n\", encoding=\"utf-8\")\n    print(f\"Wrote aggregated summary: {out_path}\")\n    print(f\"Rows: {len(summary['rows'])}, testcases: {summary['num_testcases']}\")\n    if args.table_format == \"md\":\n        _print_formatted_gt_method_md(summary[\"tables\"])\n    else:\n        _print_formatted_gt_method(summary[\"tables\"])\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "docker-compose.yaml",
    "content": "services:\n  text-encoder:\n    build:\n      context: .\n      dockerfile: Dockerfile\n    image: kimodo:1.0\n    container_name: text-encoder\n    working_dir: /workspace\n    command: python -m kimodo.scripts.run_text_encoder_server\n    volumes:\n      - ./:/workspace\n      # Cache HF downloads in host \"system-wide\" Hugging Face cache.\n      - ${HOME}/.cache/huggingface:/workspace/.cache/huggingface\n      # Mount the host HF auth token at the standard cache location in-container.\n      - ${HOME}/.cache/huggingface/token:/workspace/.cache/huggingface/token:ro\n    # expose to your host browser\n    ports:\n      - \"9550:9550\"\n    environment:\n      # Make Gradio reachable from other containers\n      # - GRADIO_SERVER_NAME=0.0.0.0\n      # - GRADIO_SERVER_PORT=9550\n      - HF_HOME=/workspace/.cache/huggingface\n      # Host user mapping (for non-root ownership + proper shell prompt)\n      - HOST_USER=${USER:-user}\n\n      # GPU\n      - NVIDIA_VISIBLE_DEVICES=all\n      - NVIDIA_DRIVER_CAPABILITIES=compute,utility\n\n    shm_size: \"16gb\"\n    ipc: host\n\n    # Wait until Gradio responds on HTTP\n    healthcheck:\n      test:\n        [\"CMD\", \"bash\", \"-lc\", \"curl -fsS http://localhost:9550/ > /dev/null\"]\n      interval: 3s\n      timeout: 2s\n      retries: 40\n\n    deploy:\n      resources:\n        reservations:\n          devices:\n            - driver: nvidia\n              count: all\n              capabilities: [gpu]\n\n  demo:\n    build:\n      context: .\n      dockerfile: Dockerfile\n    image: kimodo:1.0\n    container_name: demo\n    working_dir: /workspace\n    command: python -m kimodo.demo\n    volumes:\n      - ./:/workspace\n      - ${HOME}/.cache/huggingface:/workspace/.cache/huggingface\n      - ${HOME}/.cache/huggingface/token:/workspace/.cache/huggingface/token:ro\n      # Explicit checkpoint mount (avoids surprises if the repo bind mount isn't what you expect).\n      - 
./checkpoints:/workspace/checkpoints:ro\n    ports:\n      - \"${SERVER_PORT:-7860}:${SERVER_PORT:-7860}\"\n    environment:\n      # Point the model at the text-encoder service.\n      - TEXT_ENCODER_URL=http://text-encoder:9550/\n      # Make checkpoint paths robust (Hydra config reads this).\n      - SERVER_PORT=${SERVER_PORT:-7860}\n      - HF_HOME=/workspace/.cache/huggingface\n      # Host user mapping (for non-root ownership + proper shell prompt)\n      - HOST_USER=${USER:-user}\n\n      # GPU\n      - NVIDIA_VISIBLE_DEVICES=all\n      - NVIDIA_DRIVER_CAPABILITIES=compute,utility\n\n    shm_size: \"16gb\"\n    ipc: host\n\n    depends_on:\n      text-encoder:\n        condition: service_healthy\n\n    deploy:\n      resources:\n        reservations:\n          devices:\n            - driver: nvidia\n              count: all\n              capabilities: [gpu]\n"
  },
  {
    "path": "docker_requirements.in",
    "content": "#\n# Human-maintained direct dependencies (top-level).\n# Use `uv` to compile this into a fully pinned `requirements.txt` lockfile.\n#\n# IMPORTANT:\n# - We intentionally do NOT list `torch` here because the Docker image base\n#   (`nvcr.io/nvidia/pytorch`) already provides it. Installing torch via pip\n#   during image build is slow and can lead to ABI/CUDA mismatches.\n# - If you are NOT using Docker, install an appropriate PyTorch build separately.\n#\n\n# Config / wiring\nhydra-core>=1.3\nomegaconf>=2.3\n\n# Core numerics\nnumpy>=1.23,<2\nscipy>=1.10,<2\n\n# Model / embeddings\n# NOTE: `kimodo/model/llm2vec` has only been tested with transformers==5.1.0\ntransformers==5.1.0\nurllib3>=2.6.3\nboto3\npeft>=0.12\neinops>=0.7\n\n# Misc\ntqdm>=4.0\npackaging>=21.0\npydantic>=2.0\n\n# UI / client\nfilelock>=3.20.3\ngradio>=6.8.0\ngradio_client>=1.0\n\n# Visualization\ntrimesh>=3.21.7\nscenepic>=1.1.0\npillow>=9.0\nav>=16.1.0\n\npy-soma-x @ git+https://github.com/NVlabs/SOMA-X.git\n\n# Local packages (editable installs for viser and kimodo; MotionCorrection non-editable)\n./MotionCorrection\n-e .\n-e ./kimodo-viser\n"
  },
  {
    "path": "docker_requirements.txt",
    "content": "# This file was autogenerated by uv via the following command:\n# NOTE: `torch` (and its CUDA wheels) are intentionally omitted from this lockfile.\n# The Docker base image (nvcr.io/nvidia/pytorch) already provides a tested PyTorch build.\n#\n#    uv pip compile docker_requirements.in -o docker_requirements.txt --python-version 3.10 --python-platform x86_64-manylinux2014\n-e .\n    # via -r docker_requirements.in\n-e ./kimodo-viser\n    # via -r docker_requirements.in\npy-soma-x @ git+https://github.com/NVlabs/SOMA-X.git\n    # via -r docker_requirements.in\naccelerate==1.13.0\n    # via peft\naiofiles==24.1.0\n    # via gradio\nannotated-doc==0.0.4\n    # via\n    #   fastapi\n    #   typer\nannotated-types==0.7.0\n    # via pydantic\nantlr4-python3-runtime==4.9.3\n    # via\n    #   hydra-core\n    #   omegaconf\nanyio==4.12.1\n    # via\n    #   gradio\n    #   httpx\n    #   starlette\nattrs==25.4.0\n    # via\n    #   jsonschema\n    #   referencing\nav==16.1.0\n    # via\n    #   -r docker_requirements.in\n    #   kimodo\nboto3==1.42.66\n    # via\n    #   -r docker_requirements.in\n    #   kimodo\nbotocore==1.42.66\n    # via\n    #   boto3\n    #   s3transfer\nbrotli==1.2.0\n    # via gradio\ncertifi==2026.2.25\n    # via\n    #   httpcore\n    #   httpx\n    #   requests\ncharset-normalizer==3.4.5\n    # via\n    #   requests\n    #   trimesh\nclick==8.3.1\n    # via\n    #   typer\n    #   uvicorn\ncolorlog==6.10.1\n    # via trimesh\neinops==0.8.2\n    # via\n    #   -r docker_requirements.in\n    #   kimodo\nembreex==2.17.7.post7\n    # via trimesh\nexceptiongroup==1.3.1\n    # via anyio\nfastapi==0.135.1\n    # via gradio\nffmpy==1.0.0\n    # via gradio\nfilelock==3.25.2\n    # via\n    #   -r docker_requirements.in\n    #   huggingface-hub\n    #   kimodo\n    #   torch\nfsspec==2026.2.0\n    # via\n    #   gradio-client\n    #   huggingface-hub\n    #   torch\ngradio==6.9.0\n    # via\n    #   -r docker_requirements.in\n    #   
kimodo\ngradio-client==2.3.0\n    # via\n    #   -r docker_requirements.in\n    #   gradio\n    #   kimodo\ngroovy==0.1.2\n    # via gradio\nh11==0.16.0\n    # via\n    #   httpcore\n    #   uvicorn\nhf-xet==1.4.0\n    # via huggingface-hub\nhttpcore==1.0.9\n    # via httpx\nhttpx==0.28.1\n    # via\n    #   gradio\n    #   gradio-client\n    #   huggingface-hub\n    #   safehttpx\n    #   trimesh\nhuggingface-hub==1.6.0\n    # via\n    #   accelerate\n    #   gradio\n    #   gradio-client\n    #   peft\n    #   tokenizers\n    #   transformers\nhydra-core==1.3.2\n    # via\n    #   -r docker_requirements.in\n    #   kimodo\nidna==3.11\n    # via\n    #   anyio\n    #   httpx\n    #   requests\nimageio==2.37.3\n    # via viser\njinja2==3.1.6\n    # via\n    #   gradio\n    #   torch\njmespath==1.1.0\n    # via\n    #   boto3\n    #   botocore\njsonschema==4.26.0\n    # via trimesh\njsonschema-specifications==2025.9.1\n    # via jsonschema\nlxml==6.0.2\n    # via\n    #   trimesh\n    #   yourdfpy\nmanifold3d==3.4.0\n    # via trimesh\nmapbox-earcut==2.0.0\n    # via trimesh\nmarkdown-it-py==4.0.0\n    # via rich\nmarkupsafe==3.0.3\n    # via\n    #   gradio\n    #   jinja2\nmdurl==0.1.2\n    # via markdown-it-py\n./MotionCorrection\n    # via -r docker_requirements.in\nmsgspec==0.20.0\n    # via viser\nnodeenv==1.10.0\n    # via viser\nnumpy==1.26.4\n    # via\n    #   -r docker_requirements.in\n    #   accelerate\n    #   embreex\n    #   gradio\n    #   imageio\n    #   kimodo\n    #   manifold3d\n    #   mapbox-earcut\n    #   motion-correction\n    #   pandas\n    #   peft\n    #   pycollada\n    #   scenepic\n    #   scipy\n    #   shapely\n    #   transformers\n    #   trimesh\n    #   vhacdx\n    #   viser\n    #   yourdfpy\nomegaconf==2.3.0\n    # via\n    #   -r docker_requirements.in\n    #   hydra-core\n    #   kimodo\norjson==3.11.7\n    # via gradio\npackaging==26.0\n    # via\n    #   -r docker_requirements.in\n    #   accelerate\n    #   gradio\n    
#   gradio-client\n    #   huggingface-hub\n    #   hydra-core\n    #   kimodo\n    #   peft\n    #   transformers\npandas==2.3.3\n    # via gradio\npeft==0.18.1\n    # via\n    #   -r docker_requirements.in\n    #   kimodo\npillow==12.1.1\n    # via\n    #   -r docker_requirements.in\n    #   gradio\n    #   imageio\n    #   kimodo\n    #   scenepic\n    #   trimesh\npsutil==7.2.2\n    # via\n    #   accelerate\n    #   peft\npycollada==0.9.3\n    # via trimesh\npydantic==2.12.5\n    # via\n    #   -r docker_requirements.in\n    #   fastapi\n    #   gradio\n    #   kimodo\npydantic-core==2.41.5\n    # via pydantic\npydub==0.25.1\n    # via gradio\npygments==2.19.2\n    # via rich\npython-dateutil==2.9.0.post0\n    # via\n    #   botocore\n    #   pandas\n    #   pycollada\npython-multipart==0.0.22\n    # via gradio\npytz==2026.1.post1\n    # via\n    #   gradio\n    #   pandas\npyyaml==6.0.3\n    # via\n    #   accelerate\n    #   gradio\n    #   huggingface-hub\n    #   omegaconf\n    #   peft\n    #   transformers\nreferencing==0.37.0\n    # via\n    #   jsonschema\n    #   jsonschema-specifications\nregex==2026.2.28\n    # via transformers\nrequests==2.32.5\n    # via viser\nrich==14.3.3\n    # via\n    #   typer\n    #   viser\nrpds-py==0.30.0\n    # via\n    #   jsonschema\n    #   referencing\nrtree==1.4.1\n    # via trimesh\ns3transfer==0.16.0\n    # via boto3\nsafehttpx==0.1.7\n    # via gradio\nsafetensors==0.7.0\n    # via\n    #   accelerate\n    #   peft\n    #   transformers\nscenepic==1.1.2\n    # via\n    #   -r docker_requirements.in\n    #   kimodo\nscipy==1.15.3\n    # via\n    #   -r docker_requirements.in\n    #   kimodo\n    #   scenepic\n    #   trimesh\nsemantic-version==2.10.0\n    # via gradio\nshapely==2.1.2\n    # via trimesh\nshellingham==1.5.4\n    # via typer\nsix==1.17.0\n    # via\n    #   python-dateutil\n    #   yourdfpy\nstarlette==0.52.1\n    # via\n    #   fastapi\n    #   gradio\nsvg-path==7.0\n    # via 
trimesh\ntokenizers==0.22.2\n    # via transformers\ntomlkit==0.13.3\n    # via gradio\ntqdm==4.67.3\n    # via\n    #   -r docker_requirements.in\n    #   huggingface-hub\n    #   kimodo\n    #   peft\n    #   transformers\n    #   viser\ntransformers==5.1.0\n    # via\n    #   -r docker_requirements.in\n    #   kimodo\n    #   peft\ntrimesh==4.11.3\n    # via\n    #   -r docker_requirements.in\n    #   kimodo\n    #   viser\n    #   yourdfpy\ntyper==0.24.1\n    # via\n    #   gradio\n    #   huggingface-hub\n    #   typer-slim\ntyper-slim==0.24.0\n    # via transformers\ntyping-extensions==4.15.0\n    # via\n    #   anyio\n    #   exceptiongroup\n    #   fastapi\n    #   gradio\n    #   gradio-client\n    #   huggingface-hub\n    #   pydantic\n    #   pydantic-core\n    #   referencing\n    #   starlette\n    #   torch\n    #   typing-inspection\n    #   uvicorn\n    #   viser\ntyping-inspection==0.4.2\n    # via\n    #   fastapi\n    #   pydantic\ntzdata==2025.3\n    # via pandas\nurllib3==2.6.3\n    # via\n    #   -r docker_requirements.in\n    #   botocore\n    #   kimodo\n    #   requests\nuvicorn==0.41.0\n    # via gradio\nvhacdx==0.0.10\n    # via trimesh\nwebsockets==15.0.1\n    # via viser\nxxhash==3.6.0\n    # via trimesh\nyourdfpy==0.0.60\n    # via viser\n"
  },
  {
    "path": "docs/.gitattributes",
    "content": "source/_static/quick_tour.mp4 filter=lfs diff=lfs merge=lfs -text\n"
  },
  {
    "path": "docs/Makefile",
    "content": "# Minimal makefile for Sphinx documentation\n#\n\nSPHINXOPTS    ?=\nSPHINXBUILD   ?= sphinx-build\nSOURCEDIR     = source\nBUILDDIR      = build\n\n.PHONY: help Makefile apidoc\n\n# Catch-all target: route all unknown targets to Sphinx\n%: Makefile\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS)\n\napidoc:\n\t@$(SPHINXBUILD) -b html -q \"$(SOURCEDIR)\" \"$(BUILDDIR)\" >/dev/null 2>&1 || true\n\t@sphinx-apidoc -o \"$(SOURCEDIR)/api_reference/_generated\" -t \"$(SOURCEDIR)/_templates/apidoc\" ../kimodo ../kimodo/**/tests* ../kimodo/**/test* -f\n"
  },
  {
    "path": "docs/README.md",
    "content": "# Documentation\n\n## Local build\n\nInstall doc dependencies:\n\n```bash\npip install -r docs/requirements.txt\n```\n\nBuild HTML:\n\n```bash\ncd docs\nmake html\n```\n\nOpen the output at `docs/build/html/index.html`.\n\n## API reference generation\n\nGenerate API stubs from the Python packages:\n\n```bash\ncd docs\nmake apidoc\nmake html\n```\n\nNote: generated stubs are written to `docs/source/api_reference/_generated` and are not\nincluded in the default navigation. Add them to a toctree if you want to expose them.\n"
  },
  {
    "path": "docs/make.bat",
    "content": "@ECHO OFF\n\npushd %~dp0\n\nset SPHINXOPTS=\nset SPHINXBUILD=sphinx-build\nset SOURCEDIR=source\nset BUILDDIR=build\n\nif \"%1\" == \"\" goto help\n\n%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%\ngoto end\n\n:help\n%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%\n\n:end\npopd\n"
  },
  {
    "path": "docs/requirements.txt",
    "content": "sphinx>=7.0,<9.0\nnvidia-sphinx-theme\nsphinx-copybutton\nmyst-parser\nsphinx-design\n"
  },
  {
    "path": "docs/source/_static/custom.css",
    "content": ".hero {\n  padding: 2.5rem 2rem;\n  border-radius: 12px;\n  background: linear-gradient(135deg, #0f1a0c 0%, #1c2b16 55%, #76b900 100%);\n  color: #f8f9fb;\n  margin: 1.5rem 0 2rem 0;\n}\n\n.hero-title {\n  font-size: 2.2rem;\n  margin: 0 0 0.6rem 0;\n}\n\n.hero-subtitle {\n  font-size: 1.1rem;\n  margin: 0 0 1.2rem 0;\n  opacity: 0.9;\n}\n\n.hero-actions a {\n  display: inline-block;\n  margin-right: 0.8rem;\n  padding: 0.5rem 0.9rem;\n  border-radius: 6px;\n  background: #76b900;\n  color: #0f1a0c;\n  text-decoration: none;\n  font-weight: 600;\n}\n\n.hero-actions a.secondary {\n  background: transparent;\n  color: #f8f9fb;\n  border: 1px solid #f8f9fb;\n}\n\n.card-grid {\n  display: grid;\n  gap: 1rem;\n  grid-template-columns: repeat(auto-fit, minmax(220px, 1fr));\n  margin: 1.5rem 0 2rem 0;\n}\n\n.card {\n  border: 1px solid rgba(0, 0, 0, 0.08);\n  border-radius: 10px;\n  padding: 1rem 1.2rem;\n  background: #ffffff;\n}\n\n.card h3 {\n  margin-top: 0;\n  margin-bottom: 0.4rem;\n}\n\n.card p {\n  margin: 0;\n  color: #3c4758;\n}\n\n.quick-links {\n  display: flex;\n  flex-wrap: wrap;\n  gap: 0.8rem;\n  margin: 1rem 0 2rem 0;\n}\n\n.quick-links a {\n  display: inline-block;\n  padding: 0.4rem 0.8rem;\n  border-radius: 999px;\n  background: #edf2f7;\n  color: #1a202c;\n  text-decoration: none;\n  font-weight: 600;\n}\n"
  },
  {
    "path": "docs/source/_templates/apidoc/module.rst.jinja",
    "content": "{%- if show_headings %}\n{{- [basename, \"module\"] | join(' ') | e | heading }}\n\n{% endif -%}\n.. automodule:: {{ qualname }}\n{%- set preferred_order = ['members', 'undoc-members', 'show-inheritance'] %}\n{%- for option in preferred_order %}\n{%- if option in automodule_options %}\n   :{{ option }}:\n{%- endif %}\n{%- endfor %}\n{%- for option in automodule_options %}\n{%- if option not in preferred_order %}\n   :{{ option }}:\n{%- endif %}\n{%- endfor %}\n"
  },
  {
    "path": "docs/source/_templates/apidoc/package.rst.jinja",
    "content": "{%- set preferred_order = ['members', 'undoc-members', 'show-inheritance'] %}\n{%- macro automodule(modname, options) -%}\n.. automodule:: {{ modname }}\n{%- for option in preferred_order %}\n{%- if option in options %}\n   :{{ option }}:\n{%- endif %}\n{%- endfor %}\n{%- for option in options %}\n{%- if option not in preferred_order %}\n   :{{ option }}:\n{%- endif %}\n{%- endfor %}\n{%- endmacro %}\n\n{%- macro toctree(docnames) -%}\n.. toctree::\n   :maxdepth: {{ maxdepth }}\n{% for docname in docnames %}\n   {{ docname }}\n{%- endfor %}\n{%- endmacro %}\n\n{%- if is_namespace %}\n{{- [pkgname, \"namespace\"] | join(\" \") | e | heading }}\n{% else %}\n{{- [pkgname, \"package\"] | join(\" \") | e | heading }}\n{% endif %}\n\n{%- if is_namespace %}\n.. py:module:: {{ pkgname }}\n{% endif %}\n\n{%- if modulefirst and not is_namespace %}\n{{ automodule(pkgname, automodule_options) }}\n{% endif %}\n\n{%- if subpackages %}\nSubpackages\n-----------\n\n{{ toctree(subpackages) }}\n{% endif %}\n\n{%- if submodules %}\nSubmodules\n----------\n{% if separatemodules %}\n{{ toctree(submodules) }}\n{% else %}\n{%- for submodule in submodules %}\n{% if show_headings %}\n{{- [submodule, \"module\"] | join(\" \") | e | heading(2) }}\n{% endif %}\n{{ automodule(submodule, automodule_options) }}\n{% endfor %}\n{%- endif %}\n{%- endif %}\n\n{%- if not modulefirst and not is_namespace %}\nModule contents\n---------------\n\n{{ automodule(pkgname, automodule_options) }}\n{% endif %}\n"
  },
  {
    "path": "docs/source/api_reference/_generated/kimodo.demo.rst",
    "content": "kimodo.demo package\n===================\n\nSubmodules\n----------\n\nkimodo.demo.app module\n----------------------\n\n.. automodule:: kimodo.demo.app\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.demo.config module\n-------------------------\n\n.. automodule:: kimodo.demo.config\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.demo.embedding\\_cache module\n-----------------------------------\n\n.. automodule:: kimodo.demo.embedding_cache\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.demo.generation module\n-----------------------------\n\n.. automodule:: kimodo.demo.generation\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.demo.queue\\_manager module\n---------------------------------\n\n.. automodule:: kimodo.demo.queue_manager\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.demo.state module\n------------------------\n\n.. automodule:: kimodo.demo.state\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.demo.ui module\n---------------------\n\n.. automodule:: kimodo.demo.ui\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: kimodo.demo\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api_reference/_generated/kimodo.exports.rst",
    "content": "kimodo.exports package\n======================\n\nSubmodules\n----------\n\nkimodo.exports.bvh module\n-------------------------\n\n.. automodule:: kimodo.exports.bvh\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.exports.motion\\_convert\\_lib module\n------------------------------------------\n\n.. automodule:: kimodo.exports.motion_convert_lib\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.exports.motion\\_formats module\n-------------------------------------\n\n.. automodule:: kimodo.exports.motion_formats\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.exports.motion\\_io module\n--------------------------------\n\n.. automodule:: kimodo.exports.motion_io\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.exports.mujoco module\n----------------------------\n\n.. automodule:: kimodo.exports.mujoco\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.exports.smplx module\n---------------------------\n\n.. automodule:: kimodo.exports.smplx\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: kimodo.exports\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api_reference/_generated/kimodo.metrics.rst",
    "content": "kimodo.metrics package\n======================\n\nSubmodules\n----------\n\nkimodo.metrics.base module\n--------------------------\n\n.. automodule:: kimodo.metrics.base\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.metrics.constraints module\n---------------------------------\n\n.. automodule:: kimodo.metrics.constraints\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.metrics.foot\\_skate module\n---------------------------------\n\n.. automodule:: kimodo.metrics.foot_skate\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.metrics.tmr module\n-------------------------\n\n.. automodule:: kimodo.metrics.tmr\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: kimodo.metrics\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api_reference/_generated/kimodo.model.llm2vec.models.rst",
    "content": "kimodo.model.llm2vec.models package\n===================================\n\nSubmodules\n----------\n\nkimodo.model.llm2vec.models.attn\\_mask\\_utils module\n----------------------------------------------------\n\n.. automodule:: kimodo.model.llm2vec.models.attn_mask_utils\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.model.llm2vec.models.bidirectional\\_llama module\n-------------------------------------------------------\n\n.. automodule:: kimodo.model.llm2vec.models.bidirectional_llama\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.model.llm2vec.models.utils module\n----------------------------------------\n\n.. automodule:: kimodo.model.llm2vec.models.utils\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: kimodo.model.llm2vec.models\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api_reference/_generated/kimodo.model.llm2vec.rst",
    "content": "kimodo.model.llm2vec package\n============================\n\nSubpackages\n-----------\n\n.. toctree::\n   :maxdepth: 4\n\n   kimodo.model.llm2vec.models\n\nSubmodules\n----------\n\nkimodo.model.llm2vec.llm2vec module\n-----------------------------------\n\n.. automodule:: kimodo.model.llm2vec.llm2vec\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.model.llm2vec.llm2vec\\_wrapper module\n--------------------------------------------\n\n.. automodule:: kimodo.model.llm2vec.llm2vec_wrapper\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: kimodo.model.llm2vec\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api_reference/_generated/kimodo.model.rst",
    "content": "kimodo.model package\n====================\n\nSubpackages\n-----------\n\n.. toctree::\n   :maxdepth: 4\n\n   kimodo.model.llm2vec\n\nSubmodules\n----------\n\nkimodo.model.backbone module\n----------------------------\n\n.. automodule:: kimodo.model.backbone\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.model.cfg module\n-----------------------\n\n.. automodule:: kimodo.model.cfg\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.model.common module\n--------------------------\n\n.. automodule:: kimodo.model.common\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.model.diffusion module\n-----------------------------\n\n.. automodule:: kimodo.model.diffusion\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.model.kimodo\\_model module\n---------------------------------\n\n.. automodule:: kimodo.model.kimodo_model\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.model.load\\_model module\n-------------------------------\n\n.. automodule:: kimodo.model.load_model\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.model.loading module\n---------------------------\n\n.. automodule:: kimodo.model.loading\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.model.registry module\n----------------------------\n\n.. automodule:: kimodo.model.registry\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.model.text\\_encoder\\_api module\n--------------------------------------\n\n.. automodule:: kimodo.model.text_encoder_api\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.model.tmr module\n-----------------------\n\n.. automodule:: kimodo.model.tmr\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.model.twostage\\_denoiser module\n--------------------------------------\n\n.. 
automodule:: kimodo.model.twostage_denoiser\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: kimodo.model\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api_reference/_generated/kimodo.motion_rep.reps.rst",
    "content": "kimodo.motion\\_rep.reps package\n===============================\n\nSubmodules\n----------\n\nkimodo.motion\\_rep.reps.base module\n-----------------------------------\n\n.. automodule:: kimodo.motion_rep.reps.base\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.motion\\_rep.reps.kimodo\\_motionrep module\n------------------------------------------------\n\n.. automodule:: kimodo.motion_rep.reps.kimodo_motionrep\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.motion\\_rep.reps.tmr\\_motionrep module\n---------------------------------------------\n\n.. automodule:: kimodo.motion_rep.reps.tmr_motionrep\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: kimodo.motion_rep.reps\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api_reference/_generated/kimodo.motion_rep.rst",
    "content": "kimodo.motion\\_rep package\n==========================\n\nSubpackages\n-----------\n\n.. toctree::\n   :maxdepth: 4\n\n   kimodo.motion_rep.reps\n\nSubmodules\n----------\n\nkimodo.motion\\_rep.conditioning module\n--------------------------------------\n\n.. automodule:: kimodo.motion_rep.conditioning\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.motion\\_rep.feature\\_utils module\n----------------------------------------\n\n.. automodule:: kimodo.motion_rep.feature_utils\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.motion\\_rep.feet module\n------------------------------\n\n.. automodule:: kimodo.motion_rep.feet\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.motion\\_rep.smooth\\_root module\n--------------------------------------\n\n.. automodule:: kimodo.motion_rep.smooth_root\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.motion\\_rep.stats module\n-------------------------------\n\n.. automodule:: kimodo.motion_rep.stats\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: kimodo.motion_rep\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api_reference/_generated/kimodo.rst",
    "content": "kimodo package\n==============\n\nSubpackages\n-----------\n\n.. toctree::\n   :maxdepth: 4\n\n   kimodo.demo\n   kimodo.exports\n   kimodo.metrics\n   kimodo.model\n   kimodo.motion_rep\n   kimodo.scripts\n   kimodo.skeleton\n   kimodo.viz\n\nSubmodules\n----------\n\nkimodo.assets module\n--------------------\n\n.. automodule:: kimodo.assets\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.constraints module\n-------------------------\n\n.. automodule:: kimodo.constraints\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.geometry module\n----------------------\n\n.. automodule:: kimodo.geometry\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.meta module\n------------------\n\n.. automodule:: kimodo.meta\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.postprocess module\n-------------------------\n\n.. automodule:: kimodo.postprocess\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.sanitize module\n----------------------\n\n.. automodule:: kimodo.sanitize\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.tools module\n-------------------\n\n.. automodule:: kimodo.tools\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: kimodo\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api_reference/_generated/kimodo.scripts.rst",
    "content": "kimodo.scripts package\n======================\n\nSubmodules\n----------\n\nkimodo.scripts.generate module\n------------------------------\n\n.. automodule:: kimodo.scripts.generate\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.scripts.gradio\\_theme module\n-----------------------------------\n\n.. automodule:: kimodo.scripts.gradio_theme\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.scripts.lock\\_requirements module\n----------------------------------------\n\n.. automodule:: kimodo.scripts.lock_requirements\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.scripts.motion\\_convert module\n-------------------------------------\n\n.. automodule:: kimodo.scripts.motion_convert\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.scripts.mujoco\\_load module\n----------------------------------\n\n.. automodule:: kimodo.scripts.mujoco_load\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.scripts.run\\_text\\_encoder\\_server module\n------------------------------------------------\n\n.. automodule:: kimodo.scripts.run_text_encoder_server\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: kimodo.scripts\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api_reference/_generated/kimodo.skeleton.rst",
    "content": "kimodo.skeleton package\n=======================\n\nSubmodules\n----------\n\nkimodo.skeleton.base module\n---------------------------\n\n.. automodule:: kimodo.skeleton.base\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.skeleton.bvh module\n--------------------------\n\n.. automodule:: kimodo.skeleton.bvh\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.skeleton.definitions module\n----------------------------------\n\n.. automodule:: kimodo.skeleton.definitions\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.skeleton.kinematics module\n---------------------------------\n\n.. automodule:: kimodo.skeleton.kinematics\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.skeleton.registry module\n-------------------------------\n\n.. automodule:: kimodo.skeleton.registry\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.skeleton.transforms module\n---------------------------------\n\n.. automodule:: kimodo.skeleton.transforms\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: kimodo.skeleton\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api_reference/_generated/kimodo.viz.rst",
    "content": "kimodo.viz package\n==================\n\nSubmodules\n----------\n\nkimodo.viz.constraint\\_ui module\n--------------------------------\n\n.. automodule:: kimodo.viz.constraint_ui\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.viz.coords module\n------------------------\n\n.. automodule:: kimodo.viz.coords\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.viz.g1\\_rig module\n-------------------------\n\n.. automodule:: kimodo.viz.g1_rig\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.viz.gui module\n---------------------\n\n.. automodule:: kimodo.viz.gui\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.viz.playback module\n--------------------------\n\n.. automodule:: kimodo.viz.playback\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.viz.scene module\n-----------------------\n\n.. automodule:: kimodo.viz.scene\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.viz.smplx\\_skin module\n-----------------------------\n\n.. automodule:: kimodo.viz.smplx_skin\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.viz.soma\\_layer\\_skin module\n-----------------------------------\n\n.. automodule:: kimodo.viz.soma_layer_skin\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.viz.soma\\_skin module\n----------------------------\n\n.. automodule:: kimodo.viz.soma_skin\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nkimodo.viz.viser\\_utils module\n------------------------------\n\n.. automodule:: kimodo.viz.viser_utils\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nModule contents\n---------------\n\n.. automodule:: kimodo.viz\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api_reference/_generated/modules.rst",
    "content": "kimodo\n======\n\n.. toctree::\n   :maxdepth: 4\n\n   kimodo\n"
  },
  {
    "path": "docs/source/api_reference/constraints.rst",
    "content": "Constraints\n===========\n\nConstraint definitions and utilities.\n\n.. automodule:: kimodo.constraints\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api_reference/exports.rst",
    "content": "Exports\n=======\n\nExport utilities for common formats.\n\n.. automodule:: kimodo.exports.bvh\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n.. automodule:: kimodo.exports.mujoco\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n.. automodule:: kimodo.exports.smplx\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api_reference/index.rst",
    "content": "API Reference\n=============\n\nThis section contains the API documentation for Kimodo, organized by domain.\n\n.. toctree::\n   :maxdepth: 2\n   :caption: Core Modules\n\n   model\n   motion_rep\n   constraints\n   exports\n   viz\n   utilities\n   post-processing\n"
  },
  {
    "path": "docs/source/api_reference/model.rst",
    "content": "Model\n=====\n\nCore model architecture, diffusion logic, and text encoders.\n\n\nKimodo Model\n------------\n\n.. automodule:: kimodo.model.kimodo_model\n   :members:\n   :undoc-members:\n   :special-members: __call__\n\nDenoiser and Backbone\n---------------------\n.. automodule:: kimodo.model.twostage_denoiser\n   :members:\n   :undoc-members:\n\n.. automodule:: kimodo.model.backbone\n   :members:\n   :undoc-members:\n\nClassifier-Free Guidance\n------------------------\n\n.. automodule:: kimodo.model.cfg\n   :members:\n   :undoc-members:\n\nModel Loading\n-------------\n\n.. automodule:: kimodo.model.load_model\n   :members:\n   :undoc-members:\n\nText Encoder\n------------\n\n.. automodule:: kimodo.model.text_encoder_api\n   :members:\n   :undoc-members:\n   :special-members: __call__\n"
  },
  {
    "path": "docs/source/api_reference/motion_rep.rst",
    "content": "Motion Representation\n=====================\n\nMotion representation utilities and kinematics helpers.\n\nSkeleton\n--------\n\n.. automodule:: kimodo.skeleton\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nForward Kinematics\n------------------\n\n.. automodule:: kimodo.skeleton.kinematics\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nMotion Representations\n----------------------\n\n.. automodule:: kimodo.motion_rep.reps.base\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n.. automodule:: kimodo.motion_rep.reps.kimodo_motionrep\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n.. automodule:: kimodo.motion_rep.reps.tmr_motionrep\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\nUtilities\n---------\n\n.. automodule:: kimodo.motion_rep.feet\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n.. automodule:: kimodo.motion_rep.stats\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n.. automodule:: kimodo.motion_rep.smooth_root\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api_reference/post-processing.rst",
    "content": "Post-Processing Bindings\n========================\n\n.. automodule:: kimodo.postprocess\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api_reference/utilities.rst",
    "content": "Utilities\n=========\n\nGeneral utilities used across the codebase.\n\n.. automodule:: kimodo.tools\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n.. automodule:: kimodo.geometry\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n.. automodule:: kimodo.sanitize\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/api_reference/viz.rst",
    "content": "Visualization\n=============\n\nVisualization helpers for rendering skeletons and meshes.\n\n.. automodule:: kimodo.viz.g1_rig\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n.. automodule:: kimodo.viz.smplx_skin\n   :members:\n   :undoc-members:\n   :show-inheritance:\n\n.. automodule:: kimodo.viz.viser_utils\n   :members:\n   :undoc-members:\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/benchmark/introduction.md",
    "content": "# Benchmark Introduction\n\nWe provide a benchmark to evaluate text-to-motion and constrained motion generation on a shared test suite.\nFor reproducibility, all test content is stored on disk as folders and files, so anyone can run exactly the same cases.\nThe benchmark test suite is available to download from HuggingFace at [`nvidia/Kimodo-Motion-Gen-Benchmark`](https://huggingface.co/datasets/nvidia/Kimodo-Motion-Gen-Benchmark) and is currently set up for use with models trained on the [SOMA](https://github.com/NVlabs/SOMA-X) body skeleton.\n\nThe benchmark contains text prompts, durations, and constraint configurations for a variety of test cases, but **not** the ground-truth motion data itself. The ground-truth motions are derived from the [BONES-SEED dataset](https://huggingface.co/datasets/bones-studio/seed), which has its own license you should consider. So to construct the full benchmark motions, you must download the BONES-SEED dataset separately and run our `create_benchmark` script to populate the test suite with ground-truth motions. \n\nConstructing the benchmark with `create_benchmark` is the first step in the full [Evaluation Pipeline](pipeline.md), which is described in detail on the next page. In addition to the benchmark test cases, we provide code to run generation with Kimodo and compute a variety of [metrics](metrics.md) measuring motion quality, text alignment, and constraint following. While this open-sourced public test suite is not the exact same used in the [Kimodo tech report](https://research.nvidia.com/labs/sil/projects/kimodo/assets/kimodo_tech_report.pdf) (Sec. 6.1), the evaluation metrics are the same and evaluation methodology is similar.\n\nOn this page, we describe the overall structure of the test suite and details of the different test cases. 
Then in subsequent pages, we describe how to run the full [evaluation pipeline](pipeline.md), detail the [metrics](metrics.md), and finally provide the [results](results.md) of Kimodo-SOMA-RP and Kimodo-SOMA-SEED on the benchmark.\n\n## Dataset Splits\nTo evaluate a model on the benchmark, it should be trained with the [provided splits](https://huggingface.co/datasets/nvidia/Kimodo-Motion-Gen-Benchmark/tree/main/splits) for the [BONES-SEED dataset](https://huggingface.co/datasets/bones-studio/seed).\n\nThe different splits are defined in:\n\n- `train_split_paths.txt` - filenames of training data\n- `test_content_split_paths.txt` - filenames for test split containing new semantic \"content\". This split contains motions with `content_name` (from the BONES-SEED metadata) that are not seen in the training split. This tests model generalization to new semantic motion types, e.g. for text-to-motion generalization.\n- `test_repetition_split_paths.txt` - filenames for test split containing new motions from content that was seen in training. This split contains motions where the `content_name` is contained in the training split, but the exact motion itself was not seen. This tests a model's ability to generalize to novel performances of a familiar motion type, e.g., for constraint-following generalization.\n\nThe training split should be used for training, while the two test splits (`content` and `repetition`) are used in the test suite, as described below. Note that the test cases in the benchmark do not cover the entire content and repetition test splits; instead, we strategically sample a subset that maximizes content diversity.\n\n## Test Suite Structure\n\nThe full test suite contains 22,474 test cases spanning text and constraint-conditioned motion generation. 
\nThe suite is organized hierarchically to logically group together test cases, so the evaluation pipeline can be run on a subset of the benchmark instead of the full thing, if desired.\n\nAfter the benchmark has been constructed and motions generated for the model to evaluate, a **test case** is a single folder containing:\n\n- `meta.json` (**required**): text prompt(s) and duration(s),\n- `constraints.json` (**optional**): constraints for controlled generation, using the [constraints format](../user_guide/constraints.md),\n- `gt_motion.npz` (**optional**): ground-truth/reference motion, using the [NPZ output format](../user_guide/output_formats.md),\n- `motion.npz` (**optional**): output of the model given the `meta.json` prompt/duration and optional `constraints.json`, using the same [NPZ output format](../user_guide/output_formats.md).\n\nIn addition to being used in the evaluation pipeline, each test case can be:\n\n- loaded in the interactive demo through **Load Example** for visualization,\n- loaded in `kimodo_gen` with `--input_folder` for generation from folder-defined inputs.\n\n### Benchmark Folder Hierarchy\n\nThe full suite is organized as follows:\n\n```text\ntestsuite\n├── content\n│   ├── constraints_notext\n│   │   ├── end-effectors\n│   │   ├── fullbody\n│   │   ├── mixture\n│   │   └── root\n│   ├── constraints_withtext\n│   │   ├── end-effectors\n│   │   ├── fullbody\n│   │   ├── mixture\n│   │   └── root\n│   └── text2motion\n│       ├── overview\n│       ├── timeline_multi\n│       └── timeline_single\n└── repetition\n    ├── constraints_notext\n    │   ├── end-effectors\n    │   ├── fullbody\n    │   ├── mixture\n    │   └── root\n    ├── constraints_withtext\n    │   ├── end-effectors\n    │   ├── fullbody\n    │   ├── mixture\n    │   └── root\n    └── text2motion\n        ├── overview\n        ├── timeline_multi\n        └── timeline_single\n```\n\nAt the highest level, the test suite is organized by the test split used. 
As discussed previously, `content` refers to the test split with held out semantic categories of motion, while `repetition` refers to held out motions from semantic categories seen during training. \n\nWithin each test split, test cases are organized into:\n\n* `text2motion`: test cases with only text prompts as input (no constraints)\n* `constraints_notext`:  test cases with only constraints as input (no text prompt)\n* `constraints_withtext`: test cases with both prompt and constraints as input\n\n### Text2Motion Test Cases\n\nThese test cases are pure text-to-motion with no constraints as input. `text2motion` test cases exclusively use prompts derived from our [SEED timeline annotations](https://huggingface.co/datasets/nvidia/SEED-Timeline-Annotations). It contains three types of test cases:\n\n* `overview`: medium-detail prompt that describes a full motion. Corresponds to `overview_description` in the [NVIDIA SEED timelines](https://huggingface.co/datasets/nvidia/SEED-Timeline-Annotations) or equivalently `content_natural_desc_4` in the [BONES SEED](https://huggingface.co/datasets/bones-studio/seed) metadata.\n* `timeline_single`: fine-grained prompt describing a single segment of a timeline annotation. Corresponds to a single event in a SEED timeline.\n* `timeline_multi`: fine-grained prompt describing multiple subsequent segments of a timeline annotation. Corresponds to multiple contiguous events in a SEED timeline, which have been concatenated with an LLM to get a single natural text description.\n\n### Constrained Test Cases\n\nConstrained test cases provide a constraint input either without a text prompt (i.e., `constraints_notext`) or with an `overview` text prompt (i.e., `constraints_withtext`). 
The different types of constraint categories mirror the [constraint types supported by Kimodo](../key_concepts/constraints.md) and include:\n\n* `fullbody`: constrains all joint positions in the skeleton at specific frames\n* `end-effectors`: constrains the position and rotations of hand and/or feet joints at specific frames\n* `root`: constrains the 2D root position and optionally heading on a path or at specific frames\n* `mixture`: evaluates compositional control when multiple constraint families are combined\n\nWithin each constraint type in the hierarchy are multiple subtypes that vary the constraint sparsity patterns (either in time or in space). So the hierarchy of a `constraint` folder is:\n\n```text\nconstraints_XX\n├── end-effectors\n│   ├── feet_posrot          # feet only constraints\n│   ├── hands_feet_posrot    # hands + feet constraints\n│   └── hands_posrot         # hands only constraints\n├── fullbody\n│   ├── inbetweening         # constraints at start and end only\n│   └── random               # constraints at random frames\n├── mixture\n│   ├── root_ee_hands_feet_posrot_fullbody    # mix of (1) root trajectory, (2) hand + foot, and (3) full-body \n│   ├── root_ee_hands_posrot                  # mix of (1) root keyframe, and (2) hands\n│   ├── root_ee_hands_posrot_fullbody         # mix of (1) root keyframe, (2) hands, and (3) full-body\n│   └── root_path_fullbody                    # mix of (1) root trajectory, and (2) full-body\n└── root\n    ├── path_2dpos             # root trajectory position\n    ├── path_2dposrot          # root trajectory position + heading\n    ├── waypoint_2dpos         # root waypoint position\n    └── waypoint_2dposrot      # root waypoint position + heading\n```\n\n### Indexed Test Cases in Leaf Folders\n\nEach leaf folder contains indexed test cases (`0000`, `0001`, `0002`, ...).\nFor example:\n\n```text\nend-effectors/feet_posrot/\n├── 0000/\n├── 0001/\n├── 0002/\n...\n└── 0255/\n```\n\nEach index folder is one 
standalone test case with its own `meta.json`, optional `constraints.json`, optional `gt_motion.npz`, and optional `motion.npz`.\n"
  },
  {
    "path": "docs/source/benchmark/metrics.md",
    "content": "# Metrics\n\nThe benchmark evaluates generated motion along three axes:\n\n- **Motion quality** -- foot-skate and contact-consistency metrics,\n- **Constraint following** -- position error for root, end-effector, and full-body constraints,\n- **Text alignment** -- TMR retrieval and distributional metrics.\n\nMetrics are implemented in `kimodo/metrics/` and orchestrated by `benchmark/evaluate_folder.py`.\nThe protocol is aligned with the [tech report](https://research.nvidia.com/labs/sil/projects/kimodo/assets/kimodo_tech_report.pdf) (Sec. 6.1, \"Evaluation Metrics\").\n\n## Evaluation Protocol\n\nThe evaluation pipeline runs two passes over each group of test cases:\n\n1. **Generated pass** -- evaluates `motion.npz` with all metrics (foot skate, contact consistency, constraint following) and, when TMR embeddings are available, computes retrieval and FID scores.\n2. **Ground-truth pass** -- evaluates `gt_motion.npz` with the same motion-quality and constraint metrics. TMR retrieval metrics are not recomputed in this pass.\n\nRunning both passes enables side-by-side comparison: the GT row serves as an empirical upper bound for motion quality, and deviations between GT and generated metrics highlight where the model can improve. See [Evaluation pipeline](pipeline.md) for the full workflow.\n\n## Metrics Reference\n\nThe table below lists every key written to `metrics.json`. 
Detailed descriptions follow in subsequent sections.\n\n| Key | Category | Unit | Direction |\n| --- | --- | --- | --- |\n| `foot_skate_from_height` | Motion quality | m/s | Lower is better |\n| `foot_skate_from_pred_contacts` | Motion quality | m/s | Lower is better |\n| `foot_skate_max_vel` | Motion quality | m/s | Lower is better |\n| `foot_skate_ratio` | Motion quality | ratio (0--1) | Lower is better |\n| `foot_contact_consistency` | Motion quality | ratio (0--1) | Higher is better |\n| `constraint_root2d_err` | Constraint follow | m | Lower is better |\n| `constraint_root2d_err_p95` | Constraint follow | m | Lower is better |\n| `constraint_root2d_acc` | Constraint follow | ratio (0--1) | Higher is better |\n| `constraint_fullbody_keyframe` | Constraint follow | m | Lower is better |\n| `constraint_end_effector` | Constraint follow | m | Lower is better |\n| `TMR/t2m_sim` | Text alignment | score (0--1) | Higher is better |\n| `TMR/t2m_R/R01` ... `R10` | Text alignment | % | Higher is better |\n| `TMR/t2m_R/MedR` | Text alignment | rank | Lower is better |\n| `TMR/FID/gen_text` | Text alignment | distance | Lower is better |\n| `TMR/FID/gen_gt` | Text alignment | distance | Lower is better |\n| `TMR/FID/gt_text` | Text alignment | distance | Lower is better |\n| `TMR/m2m_sim` | Text alignment | score (0--1) | Higher is better |\n| `TMR/t2m_gt_sim` | Text alignment | score (0--1) | Higher is better |\n| `TMR/m2m_R/R01` ... `R10` | Text alignment | % | Higher is better |\n| `TMR/t2m_gt_R/R01` ... `R10` | Text alignment | % | Higher is better |\n\n:::{note}\nRaw metric values are stored in SI units (meters for positions, m/s for velocities).\nThe summary tables printed by `benchmark/parse_folder.py` convert constraint position errors to **cm** and foot-skate velocities to **cm/s** for readability.\n:::\n\n### Foot Skating Metrics\n\nFoot skating measures how much a foot slides along the ground when it should be in static contact with the ground. 
Four complementary metrics capture different aspects of this artifact.\n\n- **`foot_skate_from_height`** (m/s, lower is better):\n  Mean velocity of the **toe joints** (left toe, right toe) on frames where the toe height is below a floor threshold (`height_thresh = 0.05 m`).\n  This metric does not rely on predicted contact labels -- it uses a geometric criterion (Y-coordinate < threshold) to identify ground-contact frames.\n\n- **`foot_skate_from_pred_contacts`** (m/s, lower is better):\n  Mean velocity of all **four foot joints** (left/right heel and toe) on frames where the model predicts contact via the `foot_contacts` output.\n  Unlike `foot_skate_from_height`, this metric trusts the model's own contact predictions and measures all four foot joints rather than toes only.\n\n- **`foot_skate_max_vel`** (m/s, lower is better):\n  Maximum velocity across all four foot joints and all time steps where predicted contact is active.\n  This captures worst-case slip spikes that mean-based metrics can hide.\n\n- **`foot_skate_ratio`** (ratio 0--1, lower is better):\n  Fraction of ground-contact frames where toe velocity exceeds a threshold (`vel_thresh = 0.2 m/s`). A frame counts as ground contact when the toe is below `height_thresh = 0.05 m` on both the current and the next frame. 
Inspired by the [GMD](https://github.com/korrawe/guided-motion-diffusion) skating metric.\n\n### Contact Consistency Metric\n\n- **`foot_contact_consistency`** (ratio 0--1, higher is better):\n  Agreement between the model's predicted foot contacts and a heuristic contact detector based on joint height and velocity (`vel_thresh = 0.15 m/s`, `height_thresh = 0.10 m`).\n  Computed as accuracy (`1 - incorrect_ratio`) over all time steps and four contact channels.\n  A score of 1.0 means perfect agreement between predicted and heuristic contacts.\n  As noted in the [tech report](https://research.nvidia.com/labs/sil/projects/kimodo/assets/kimodo_tech_report.pdf), this metric provides important context for interpreting the contact-based foot-skate metrics above: if contact consistency is low, `foot_skate_from_pred_contacts` may be unreliable.\n\n### Constraint-Following Metrics\n\nConstraint metrics are computed only when the test case includes a `constraints.json` file. The `ContraintFollow` metric class dispatches by [constraint type](../key_concepts/constraints.md):\n\n- **`constraint_end_effector`** (m, lower is better):\n  Mean Euclidean distance between target end-effector positions and generated joint positions at the constrained frames.\n  Only position-constrained joints are evaluated (rotation targets are not measured by this metric).\n\n- **`constraint_fullbody_keyframe`** (m, lower is better):\n  Mean per-joint Euclidean distance between target and generated full-body joint positions at keyframes.\n  The error is averaged over all joints and all keyframe frames.\n\n- **`constraint_root2d_err`** (m, lower is better):\n  Mean 2D Euclidean distance (in the XZ ground plane) between target and generated root positions at constrained frames.\n\n- **`constraint_root2d_err_p95`** (m, lower is better):\n  95th percentile of the per-frame root 2D error across all samples in a group.\n  Computed during aggregation by `evaluate_folder.py` to capture tail-end failures 
that the mean can mask.\n\n- **`constraint_root2d_acc`** (ratio 0--1, higher is better):\n  Fraction of constrained root frames where the 2D position error is within a distance threshold (`root_threshold = 0.10 m`).\n\n### TMR-Based Metrics\n\nText alignment is evaluated using [TMR](https://mathis.petrovich.fr/tmr/) (Text-to-Motion Retrieval), a separate encoder model that maps both text and motion into a shared embedding space. TMR is not used for generation -- it is loaded only for evaluation (see `kimodo/model/tmr.py`).\n\nWe release a version of TMR retrained on the full Rigplay dataset as [`TMR-SOMA-RP-v1`](https://huggingface.co/nvidia/TMR-SOMA-RP-v1). The original TMR was trained on HumanML3D; our retrained variant uses the same architecture but is trained on the Rigplay motion dataset, SOMA skeleton, and with [LLM2Vec](https://github.com/McGill-NLP/llm2vec) text embeddings.\n\n#### Similarity Scores\n\nTMR encodes each text prompt and each motion clip into a unit-length embedding vector. 
Cosine similarity between text and motion embeddings is rescaled to a [0, 1] range:\n\n```\nscore = cosine_similarity / 2 + 0.5\n```\n\nThree per-test-case similarity scores are recorded:\n\n- **`TMR/t2m_sim`** (0--1, higher is better): similarity between the text prompt and the generated motion.\n- **`TMR/m2m_sim`** (0--1, higher is better): similarity between the generated and ground-truth motions (only when GT is available).\n- **`TMR/t2m_gt_sim`** (0--1, higher is better): similarity between the text prompt and the GT motion (only when GT is available).\n\n#### R-precision (Retrieval Accuracy)\n\nR-precision measures whether the correct motion can be retrieved from a pool given its corresponding text query.\nFor each text query in the evaluation group, all motions are ranked by TMR similarity.\nR@k is the percentage of queries where the correct motion appears in the top k results.\n\nReported keys: `TMR/t2m_R/R01`, `R02`, `R03`, `R05`, `R10` (%), and `TMR/t2m_R/MedR` (median rank, lower is better) correspond to retrieval accuracy when using generated motions.\n\nWhen ground-truth motions are available, analogous retrieval metrics are computed for motion-to-GT-motion (`TMR/m2m_R/...`) and text-to-GT-motion (`TMR/t2m_gt_R/...`).\n\n:::{note}\nNear-duplicate text prompts can artificially penalize retrieval ranking. The evaluation handles this by grouping prompts whose text-text similarity exceeds a threshold of 0.99 and treating any motion in that group as a valid match.\n:::\n\n#### FID (Frechet Inception Distance)\n\nFID measures distributional distance between two sets of TMR embeddings by fitting a multivariate Gaussian to each set and computing the Frechet distance. Three FID variants are reported:\n\n- **`TMR/FID/gen_gt`**: distance between generated-motion and GT-motion embeddings (only when GT is available). 
This is the FID metric that is typically reported in the motion generation literature.\n- **`TMR/FID/gen_text`**: distance between generated-motion embeddings and text embeddings. \n- **`TMR/FID/gt_text`**: distance between GT-motion and text embeddings (only when GT is available).\n\nLower values indicate that the two distributions are more similar. FID requires at least 2 samples; groups with fewer samples report `NaN`.\n\n#### Per-Test-Case Retrieval\n\nIn addition to the aggregate metrics above, each test case's `metrics.json` includes a `tmr` block with single motion retrieval results:\n\n- `t2m_rank`: the rank of the correct motion when retrieving with this test case's text query.\n- `top5_retrieved`: the top-5 retrieved motions (sample IDs and text prompts) for inspection.\n\n## JSON Output Format\n\nBelow is a representative `metrics.json` written by `evaluate_folder.py` for a single test case with mixed constraints (root + end-effector + full-body) and TMR embeddings:\n\n```json\n{\n  \"num_motions\": 1,\n  \"folder\": \"...\",\n  \"per_motion_mean_gen\": {\n    \"foot_skate_from_height\": 0.3144,\n    \"foot_skate_from_pred_contacts\": 0.0672,\n    \"foot_skate_max_vel\": 0.2109,\n    \"foot_contact_consistency\": 0.9522,\n    \"foot_skate_ratio\": 0.2182,\n    \"constraint_end_effector\": 0.0286,\n    \"constraint_root2d_err\": 0.0534,\n    \"constraint_root2d_acc\": 1.0,\n    \"constraint_fullbody_keyframe\": 0.0324,\n    \"TMR/t2m_sim\": 0.8209\n  },\n  \"per_motion_mean_gt\": {\n    \"foot_skate_from_height\": 0.2361,\n    \"foot_skate_from_pred_contacts\": 0.0269,\n    \"foot_skate_max_vel\": 0.1459,\n    \"foot_contact_consistency\": 1.0,\n    \"foot_skate_ratio\": 0.1402,\n    \"constraint_end_effector\": 9.82e-07,\n    \"constraint_root2d_err\": 0.0407,\n    \"constraint_root2d_acc\": 1.0,\n    \"constraint_fullbody_keyframe\": 8.73e-07\n  },\n  \"tmr\": {\n    \"t2m_rank\": 2,\n    \"text\": \"A person is swiftly performing a dance move by moving 
their hands and legs.\",\n    \"top5_retrieved\": [\n      {\n        \"id\": \"0231\",\n        \"text\": \"A person is performing dance steps while stepping back and forward...\"\n      },\n      {\n        \"id\": \"0029\",\n        \"text\": \"A person is swiftly performing a dance move by moving their hands and legs.\"\n      }\n    ]\n  }\n}\n```\n\nGroup-level aggregate JSONs (`<group_name>.json`) have the same structure but with `num_motions > 1`, averaged per-motion metrics, additional keys like `constraint_root2d_err_p95`, and a `tmr` block containing the aggregate retrieval and FID scores:\n\n```json\n{\n  \"num_motions\": 256,\n  \"folder\": \"...\",\n  \"per_motion_mean_gen\": {\n    \"foot_skate_from_height\": 0.1742,\n    \"foot_skate_from_pred_contacts\": 0.0611,\n    \"foot_skate_max_vel\": 0.3747,\n    \"foot_contact_consistency\": 0.9483,\n    \"foot_skate_ratio\": 0.1499,\n    \"constraint_end_effector\": 0.0367,\n    \"constraint_root2d_err\": 0.0495,\n    \"constraint_root2d_acc\": 0.9212,\n    \"constraint_fullbody_keyframe\": 0.0324,\n    \"constraint_root2d_err_p95\": 0.1115\n  },\n  \"per_motion_mean_gt\": {\n    \"foot_skate_from_height\": 0.1617,\n    \"foot_skate_from_pred_contacts\": 0.0235,\n    \"foot_skate_max_vel\": 0.1185,\n    \"foot_contact_consistency\": 1.0,\n    \"foot_skate_ratio\": 0.1214,\n    \"constraint_end_effector\": 1.48e-06,\n    \"constraint_root2d_err\": 0.0376,\n    \"constraint_root2d_acc\": 1.0,\n    \"constraint_fullbody_keyframe\": 1.16e-06,\n    \"constraint_root2d_err_p95\": 0.0602\n  },\n  \"tmr\": {\n    \"TMR/t2m_sim\": 0.8742,\n    \"TMR/t2m_R/R01\": 75.39,\n    \"TMR/t2m_R/R02\": 85.55,\n    \"TMR/t2m_R/R03\": 88.28,\n    \"TMR/t2m_R/R05\": 90.23,\n    \"TMR/t2m_R/R10\": 93.36,\n    \"TMR/t2m_R/MedR\": 1.0,\n    \"TMR/t2m_R/len\": 256.0,\n    \"TMR/FID/gen_text\": 0.1442,\n    \"TMR/m2m_R/R01\": 94.53,\n    \"TMR/m2m_R/R02\": 97.66,\n    \"TMR/m2m_R/R03\": 98.05,\n    \"TMR/m2m_R/R05\": 98.83,\n    
\"TMR/m2m_R/R10\": 99.22,\n    \"TMR/m2m_R/MedR\": 1.0,\n    \"TMR/m2m_R/len\": 256.0,\n    \"TMR/t2m_gt_R/R01\": 80.47,\n    \"TMR/t2m_gt_R/R02\": 88.28,\n    \"TMR/t2m_gt_R/R03\": 91.02,\n    \"TMR/t2m_gt_R/R05\": 92.58,\n    \"TMR/t2m_gt_R/R10\": 94.53,\n    \"TMR/t2m_gt_R/MedR\": 1.0,\n    \"TMR/t2m_gt_R/len\": 256.0,\n    \"TMR/FID/gen_gt\": 0.0387,\n    \"TMR/FID/gt_text\": 0.1349\n  }\n}\n```\n"
  },
  {
    "path": "docs/source/benchmark/pipeline.md",
    "content": "# Evaluation Pipeline\n\nThis page describes the full benchmark workflow, which uses scripts in the `benchmark` directory:\n\n1. Build full test suite using ground-truth motions from BONES-SEED BVH data and benchmark metadata (`create_benchmark.py`),\n2. Generate motions with a model for all or part of the test suite (`generate_eval.py`),\n3. Compute text/motion embeddings with pre-trained TMR model (`embed_folder.py `),\n4. Evaluate metrics over all generated samples (`evaluate_folder.py`),\n5. Aggregate and summarize results (`parse_folder.py`).\n\nThis pipeline works off-the-shelf for Kimodo models. To evaluate your own model, step (2) will need to be modified to generate with your custom model and output in the expected npz format.\n\n## Prerequisite: Download Motion Data and Metadata\nThe benchmark is constructed from motions in the BONES-SEED dataset and our released metadata. Make sure you have downloaded the [BONES-SEED dataset](https://huggingface.co/datasets/bones-studio/seed) along with the metadata for the test suite from HuggingFace at [`nvidia/Kimodo-Motion-Gen-Benchmark`](https://huggingface.co/datasets/nvidia/Kimodo-Motion-Gen-Benchmark). \n\nThe `testsuite` folder from the downloaded metadata contains the directory structure described in the [benchmark introduction](introduction.md) with `meta.json`, `seed_motion.json`, and `seed_constraints.json` metadata files in the leaf folders. These metadata files contain info about the text prompts, durations, and constraint definitions for each test case. 
The first two steps of the evaluation pipeline will create the following in the leaf folders to prepare for computing metrics:\n\n- **Ground-Truth Motion** (`gt_motion.npz`): produced by `create_benchmark.py` from SEED BVH + metadata.\n- **Constraints Configuration** (`constraints.json`): for test cases with constraint inputs, this file is created by `create_benchmark.py` from SEED BVH + metadata.\n- **Generated Motion** (`motion.npz`): produced by the generation step from the model to evaluate (e.g. `generate_eval.py`).\n\nTo perform the full evaluation, including metrics for both ground-truth and generated motions (steps 3--5), each leaf folder must contain both `gt_motion.npz` and `motion.npz`.\n\n> Note: all of the following steps will work with a _subset_ of the full test suite, if desired. Anywhere the `testsuite` directory is passed in, it can be replaced with a specific subset such as `testsuite/content/text2motion` to only run this subset of the benchmark.\n\n## 1. Build Full Benchmark (`create_benchmark.py`)\n\n The `create_benchmark.py` script bridges the ground truth motions and metadata: it downloads the testsuite structure (if not already present locally), then reads the referenced BVH files from a local copy of BONES-SEED and writes `gt_motion.npz` and `constraints.json` into each sample folder.\n\n```bash\npython benchmark/create_benchmark.py path/to/testsuite --dataset datasets/bones-seed/soma_uniform\n```\n\nBy default, this construction can take several hours and the resulting folder is about **26 GB**. \n\nTo run faster, you can increase the number of parallel workers for processing:\n```bash\nOMP_NUM_THREADS=2 python benchmark/create_benchmark.py path/to/testsuite --dataset datasets/bones-seed/soma_uniform --workers 16\n```\nThis example runs well with a 32-core system, but you may need to adjust the number of threads-per-worker and total workers for your system. 
Generally, a lower number of threads-per-worker with a larger number of workers (up to your available CPU capacity) runs fastest.\n\nOptions:\n\n- `--dataset`: path to the local SEED dataset folder (default: `datasets/bones-seed/soma_uniform`).\n- `--workers`: number of parallel workers to use for benchmark construction (default: 1, sequential)\n- `--overwrite`: rebuild `gt_motion.npz` even if it already exists.\n\nFor each test case, the script:\n\n1. parses the BVH file into local rotation matrices and root translation,\n2. subsamples to 30 FPS,\n3. converts to the standard T-pose via `SOMASkeleton77.to_standard_tpose`,\n4. computes Kimodo motion features and canonicalizes the motion,\n5. writes the resulting motion dictionary as `gt_motion.npz`.\n\nFor a detailed walkthrough of steps 1--4, see [Loading BONES-SEED BVH data](../user_guide/seed_dataset.md).\n\n## 2. Generate Motions (`generate_eval.py`)\n\nThe next step is to generate a motion for each test case.\nThe script `benchmark/generate_eval.py` recursively generates one motion with Kimodo per test case from either the full `testsuite` or a desired subset. \n\n```bash\npython benchmark/generate_eval.py \\\n  --benchmark path/to/testsuite \\\n  --output generated_folder \\\n  --model kimodo-soma-rp \\\n  --batch_size 32 \\\n  --num_workers 4\n```\n\nThe batch size and number of data workers should be adjusted for your system. The script is intended to be run with the latest Kimodo-SOMA models (right now v1.1) which are compatible with the benchmark.\n\n> Note: each test case has a seed in `meta.json` that is loaded and used for generation to enable reproducibility. However, by default, the generation script uses the first seed in a batch to seed the whole batch, so to make results completely repeatable, you must set the batch size to 1 or always use the same batch size when running generation.\n\nUseful options:\n\n- `--model`: Kimodo model to use for generation. 
See [available models](../getting_started/quick_start.md#overview-kimodo-models) for the full list. \n- `--output`: output root directory. The testsuite hierarchy is mirrored here. If omitted, motions are generated **in-place** inside the testsuite folder.\n- `--overwrite`: regenerate even if `motion.npz` already exists.\n- `--diffusion_steps`: default denoising steps (can be overridden by each sample `meta.json`).\n- `--postprocess`: enable post-processing. For fair evaluation, it is recommended to **not** use post-processing so that metrics reflect the raw model output.\n- `--text_encoder_fp32`: will instantiate the text encoder (if needed) with float32 precision instead of bfloat16. The Kimodo v1.1 models are trained with float32 text encodings, so this slightly improves accuracy but requires extra VRAM.\n\nAfter generation, the output tree mirrors the `testsuite` hierarchy and includes generated motions (`motion.npz`). If the testsuite was built with `create_benchmark.py`, each leaf already has `gt_motion.npz`; the generation step adds `motion.npz` per sample.\n\n```text\ngenerated_folder/\n└── .../0000/\n    ├── meta.json\n    ├── constraints.json                # present if available in testsuite\n    ├── gt_motion.npz                   # if built with create_benchmark\n    └── motion.npz                      # generated\n```\n\n### Using Custom Models\n\nThe `generate_eval` script is set up to work with Kimodo models, but it can be easily adapted or replaced by generation with a custom model. The only requirement to be able to compute all metrics is to output the `motion.npz` file for each test case that minimally contains: (1) `posed_joints` field with global joint positions on the SOMA 77-joint skeleton and (2) `foot_contacts` field with binary foot contact predictions. Please see the [output formats docs](../user_guide/output_formats.md) for more details on the `NPZ` format.\n\n## 3. 
Embed with Pre-Trained TMR (`embed_folder.py`)\n\nSeveral evaluation metrics such as R-precision, FID, and latent similarity rely on latent embeddings of both motion and text. For this purpose, we use a [Text-Motion-Retrieval (TMR)](https://mathis.petrovich.fr/tmr/) model trained on the full Bones Rigplay dataset. See [Metrics](metrics.md) for details on the TMR evaluation protocol and metrics. \n\nThe next step in the eval pipeline is using this TMR model with the `benchmark/embed_folder.py` script to recursively embed each generated motion (`motion.npz`), GT motion (`gt_motion.npz`) when present, and the text prompt from `meta.json`:\n\n```bash\npython benchmark/embed_folder.py generated_folder --model tmr-soma-rp\n```\n\nThe default TMR model (`tmr-soma-rp`) trained on the full Rigplay dataset is released as [`TMR-SOMA-RP-v1`](https://huggingface.co/nvidia/TMR-SOMA-RP-v1). It is automatically downloaded from HuggingFace on first use of the embedding script. \n\nOptions:\n\n- `--model`: TMR model to use for encoding (default: `tmr-soma-rp`).\n- `--device`: compute device (`cuda` or `cpu`). Defaults to `cuda` if available, otherwise `cpu`.\n- `--overwrite`: re-embed even if embedding files already exist.\n- `--text_encoder_fp32`: will instantiate the text encoder (if needed) with float32 precision instead of bfloat16. The TMR model is trained with float32 text encodings, so this slightly improves accuracy but requires extra VRAM.\n\nRunning this script saves the embeddings to each test case folder that has the corresponding motion file(s) and `meta.json`:\n\n- `motion_embedding.npy` (when `motion.npz` exists)\n- `gt_motion_embedding.npy` (when `gt_motion.npz` exists)\n- `text_embedding.npy`\n\n> Note: this script can take over 1 hour to run for the full test suite, depending on your GPU.\n\n## 4. 
Compute Evaluation Metrics (`evaluate_folder.py`)\n\nNext, use `benchmark/evaluate_folder.py` to compute per-test-case and aggregated metrics across the test suite (or a specific subset folder). Each leaf folder must contain both `motion.npz` and `gt_motion.npz` to compute the metrics.\n\n```bash\npython benchmark/evaluate_folder.py generated_folder\n```\n\nOptions:\n\n- `--device`: compute device (`cuda` or `cpu`). Defaults to `cuda` if available, otherwise `cpu`.\n\nThe script runs two evaluation passes: one on the generated motion (`motion.npz`) and one on the ground-truth motion (`gt_motion.npz`). It outputs:\n\n- per test case results: `metrics.json` inside each test case (leaf) folder with metrics summarized for that single test case\n- per group results: `<group_name>.json` one level above each group of test-case folders that aggregates metrics over all contained test cases\n\nPlease see the [Metrics](metrics.md) page for a detailed explanation of these json formats.\n\nAfter embedding and evaluation, the folder structure should look like:\n\n```text\ngenerated_folder/\n├── .../0000/\n│   ├── motion.npz\n│   ├── gt_motion.npz\n│   ├── motion_embedding.npy\n│   ├── gt_motion_embedding.npy\n│   ├── text_embedding.npy\n│   └── metrics.json              # single test-case metrics\n└── .../<group_name>.json         # folder-level aggregate summary of all contained test cases\n```\n\n## 5. Summarize Results of Full Benchmark (`parse_folder.py`)\n\nIf you have computed metrics for the _entire_ test suite (both `content` and `repetition` splits), use `benchmark/parse_folder.py` to validate all per-test-case result JSONs and aggregate metrics into summary tables. 
Unlike the previous steps, this script expects the user to pass in the root `testsuite` and for the test suite to follow the standard split/category hierarchy (see [Introduction](introduction.md)):\n\n- **Splits**: `content`, `repetition`\n- **Categories**: `overview`, `timeline_single`, `timeline_multi` (text-following), `constraints_withtext`, `constraints_notext` (constrained generation)\n\n```bash\npython benchmark/parse_folder.py generated_folder\n```\n\nOptions:\n\n- `--output`: path for the output JSON (default: `<folder>/summary_rows.json`).\n- `--format`: table output format. `terminal` (default) for fixed-width tables, `md` for markdown tables suitable for copy-pasting into documentation.\n\nThe script:\n\n1. discovers all grouped test case directories (folders containing single test cases with `meta.json`, `motion.npz`, and `gt_motion.npz`),\n2. loads each group's `<group_name>.json` result files written by `evaluate_folder`,\n3. computes weighted averages of all metrics by split and category,\n4. writes `summary_rows.json` with per-row and per-table aggregated results,\n5. prints formatted benchmark tables to the terminal (text-following and constraints, with GT and method rows side by side).\n\nMetric values in the tables are converted to user-friendly units (e.g., constraint position errors in cm, foot skating in cm/s). See [Metrics](metrics.md) for definitions of individual metrics.\n"
  },
  {
    "path": "docs/source/benchmark/results.md",
    "content": "# Kimodo Results\n\nOn this page, we report the results for the latest Kimodo models on the benchmark test suite. These results are reproducible with the [evaluation pipeline](pipeline.md) and should be used when comparing against other models. Note that the reported numbers differ from the numbers in the [tech report](https://research.nvidia.com/labs/sil/projects/kimodo/assets/kimodo_tech_report.pdf) (Sec. 6) due to differences in skeleton, test suite composition, and evaluation details.\n\nTo reproduce these results or evaluate your own model, follow the [evaluation pipeline](pipeline.md) and use `parse_folder --format md` to generate summary tables in markdown format.\n\n**Note on reproducibility**: to exactly reproduce the results in the tables below, use batch size 1 when generating with Kimodo (i.e., when running `generate_eval.py`). This way, every test case is individually seeded according to `meta.json`. The reported results were computed using LLM2Vec in the default `bfloat16` precision. However, the Kimodo-SOMA-v1.1 and TMR models were actually trained with `float32` embeddings, so if you want to get the best possible performance (and you have enough VRAM), you can include `--text_encoder_fp32` when running the generation and embedding steps, even though the results will not match the tables here.\n\nResults are reported on the two splits described in [the introduction](introduction.md#dataset-splits):\n\n- **Content**: test cases with novel semantic content not present in training (e.g. unseen action categories).\n- **Repetition**: content categories seen during training, but specific motion clips are held out and unseen. Note that due to the annotations in Bones Rigplay and SEED datasets, the text prompts in this test split have already been seen during training.\n\nFor each split, we also report metrics for the ground truth motion. 
These rows serve as an empirical upper bound for motion quality, and deviations between ground truth and generated metrics highlight where the model can improve.\n\nWe split results for each model into two tables corresponding to different test cases in the test suite:\n\n- **Text-Following**: `overview`, `timeline_single`, and `timeline_multi`\n- **Constrained**: `constraints_withtext`, `constraints_notext`\n\n<!-- \n## Kimodo-SOMA-SEED-v1.1\nThese results are for the Kimodo model trained on the public [BONES-SEED](https://huggingface.co/datasets/bones-studio/seed) dataset. The results are comparable to any model trained on SEED that uses our recommended splits [described in the introduction](introduction.md#dataset-splits).\n\n### Text-Following Evaluation\n\n|  | Overview R@3↑ | Overview FID↓ | Overview Skate↓ | Overview Contact↑ | Timeline single R@3↑ | Timeline single FID↓ | Timeline single Skate↓ | Timeline single Contact↑ | Timeline multi R@3↑ | Timeline multi FID↓ | Timeline multi Skate↓ | Timeline multi Contact↑ |\n| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |\n| **Content** Ground Truth | 89.09 | 0.000 | 1.849 | 1.000 | 86.26 | 0.000 | 1.789 | 1.000 | 88.47 | 0.000 | 1.711 | 1.000 |\n| **Content** Kimodo | 81.13 | 0.035 | 4.077 | 0.977 | 73.17 | 0.028 | 3.873 | 0.980 | 80.10 | 0.032 | 3.685 | 0.981 |\n| **Repetition** Ground Truth | 93.91 | 0.000 | 2.106 | 1.000 | 90.13 | 0.000 | 2.037 | 1.000 | 94.49 | 0.000 | 1.931 | 1.000 |\n| **Repetition** Kimodo | 90.92 | 0.004 | 4.573 | 0.972 | 80.38 | 0.007 | 4.442 | 0.976 | 92.58 | 0.006 | 4.199 | 0.974 |\n\n\n### Constrained Evaluation\n\n|  | With text FB Pos↓ | With text EE Pos↓ | With text EE Rot↓ | With text 2D Root↓ | With text Pelvis@95% | Without text FB Pos↓ | Without text EE Pos↓ | Without text EE Rot↓ | Without text 2D Root↓ | Without text Pelvis@95% |\n| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |\n| **Content** Ground Truth | 0.000 | 0.000 | - | 
3.837 | 5.36 | 0.000 | 0.000 | - | 3.913 | 5.41 |\n| **Content** Kimodo | 3.421 | 3.817 | - | 4.979 | 9.14 | 3.320 | 3.664 | - | 4.797 | 9.03 |\n| **Repetition** Ground Truth | 0.000 | 0.000 | - | 3.607 | 5.44 | 0.000 | 0.000 | - | 3.567 | 5.42 |\n| **Repetition** Kimodo | 3.187 | 3.852 | - | 4.734 | 9.19 | 3.120 | 3.510 | - | 4.264 | 7.89 |\n\n\n\n## Kimodo-SOMA-RP-v1.1\nThese results are for the Kimodo model trained on the full (proprietary) Bones Rigplay dataset which is a superset of BONES-SEED. Though the training split is larger, the model is not trained on the SEED test splits to ensure a fair comparison.\n\n### Text-Following Evaluation\n\n|  | Overview R@3↑ | Overview FID↓ | Overview Skate↓ | Overview Contact↑ | Timeline single R@3↑ | Timeline single FID↓ | Timeline single Skate↓ | Timeline single Contact↑ | Timeline multi R@3↑ | Timeline multi FID↓ | Timeline multi Skate↓ | Timeline multi Contact↑ |\n| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |\n| **Content** Ground Truth | 89.09 | 0.000 | 1.849 | 1.000 | 86.26 | 0.000 | 1.789 | 1.000 | 88.47 | 0.000 | 1.711 | 1.000 |\n| **Content** Kimodo | 83.32 | 0.025 | 3.641 | 0.982 | 78.08 | 0.026 | 3.523 | 0.984 | 84.79 | 0.028 | 3.278 | 0.985 |\n| **Repetition** Ground Truth | 93.91 | 0.000 | 2.106 | 1.000 | 90.13 | 0.000 | 2.037 | 1.000 | 94.49 | 0.000 | 1.931 | 1.000 |\n| **Repetition** Kimodo | 87.90 | 0.008 | 4.103 | 0.977 | 77.02 | 0.011 | 3.938 | 0.981 | 88.59 | 0.009 | 3.727 | 0.980 |\n\n\n### Constrained Evaluation\n\n|  | With text FB Pos↓ | With text EE Pos↓ | With text EE Rot↓ | With text 2D Root↓ | With text Pelvis@95% | Without text FB Pos↓ | Without text EE Pos↓ | Without text EE Rot↓ | Without text 2D Root↓ | Without text Pelvis@95% |\n| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |\n| **Content** Ground Truth | 0.000 | 0.000 | - | 3.837 | 5.36 | 0.000 | 0.000 | - | 3.913 | 5.41 |\n| **Content** Kimodo | 2.929 | 3.029 | - | 4.581 | 7.77 | 2.935 
| 2.994 | - | 4.411 | 7.37 |\n| **Repetition** Ground Truth | 0.000 | 0.000 | - | 3.607 | 5.44 | 0.000 | 0.000 | - | 3.567 | 5.42 |\n| **Repetition** Kimodo | 2.804 | 2.983 | - | 4.260 | 7.63 | 2.829 | 2.969 | - | 4.027 | 7.21 |\n-->\n\n\n## Quantitative Results\n\nResults are reported for two models:\n\n- **Kimodo-SOMA-SEED-v1.1**:  trained on the public [BONES-SEED](https://huggingface.co/datasets/bones-studio/seed) dataset. The results are comparable to any model trained on SEED that uses our recommended splits [described in the introduction](introduction.md#dataset-splits).\n- **Kimodo-SOMA-RP-v1.1**: trained on the full (proprietary) Bones Rigplay dataset which is a superset of BONES-SEED. Though the training split is larger, the model is not trained on the SEED test splits to ensure a fair comparison.\n\n### Text-Following Evaluation\n\n|  | Overview R@3↑ | Overview FID↓ | Overview Skate↓ | Overview Contact↑ | Timeline single R@3↑ | Timeline single FID↓ | Timeline single Skate↓ | Timeline single Contact↑ | Timeline multi R@3↑ | Timeline multi FID↓ | Timeline multi Skate↓ | Timeline multi Contact↑ |\n| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |\n| **Content** Ground Truth | 89.09 | 0.000 | 1.849 | 1.000 | 86.26 | 0.000 | 1.789 | 1.000 | 88.47 | 0.000 | 1.711 | 1.000 |\n| **Content** Kimodo-SOMA-SEED-v1.1 | 81.13 | 0.035 | 4.077 | 0.977 | 73.17 | 0.028 | 3.873 | 0.980 | 80.10 | 0.032 | 3.685 | 0.981 |\n| **Content** Kimodo-SOMA-RP-v1.1 | 83.32 | 0.025 | 3.641 | 0.982 | 78.08 | 0.026 | 3.523 | 0.984 | 84.79 | 0.028 | 3.278 | 0.985 |\n| **Repetition** Ground Truth | 93.91 | 0.000 | 2.106 | 1.000 | 90.13 | 0.000 | 2.037 | 1.000 | 94.49 | 0.000 | 1.931 | 1.000 |\n| **Repetition** Kimodo-SOMA-SEED-v1.1 | 90.92 | 0.004 | 4.573 | 0.972 | 80.38 | 0.007 | 4.442 | 0.976 | 92.58 | 0.006 | 4.199 | 0.974 |\n| **Repetition** Kimodo-SOMA-RP-v1.1 | 87.90 | 0.008 | 4.103 | 0.977 | 77.02 | 0.011 | 3.938 | 0.981 | 88.59 | 0.009 | 3.727 | 0.980 
|\n\n### Constrained Evaluation\n\n|  | With text FB Pos↓ | With text EE Pos↓ | With text EE Rot↓ | With text 2D Root↓ | With text Pelvis@95% | Without text FB Pos↓ | Without text EE Pos↓ | Without text EE Rot↓ | Without text 2D Root↓ | Without text Pelvis@95% |\n| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |\n| **Content** Ground Truth | 0.000 | 0.000 | - | 3.837 | 5.36 | 0.000 | 0.000 | - | 3.913 | 5.41 |\n| **Content** Kimodo-SOMA-SEED-v1.1 | 3.421 | 3.817 | - | 4.979 | 9.14 | 3.320 | 3.664 | - | 4.797 | 9.03 |\n| **Content** Kimodo-SOMA-RP-v1.1 | 2.929 | 3.029 | - | 4.581 | 7.77 | 2.935 | 2.994 | - | 4.411 | 7.37 |\n| **Repetition** Ground Truth | 0.000 | 0.000 | - | 3.607 | 5.44 | 0.000 | 0.000 | - | 3.567 | 5.42 |\n| **Repetition** Kimodo-SOMA-SEED-v1.1 | 3.187 | 3.852 | - | 4.734 | 9.19 | 3.120 | 3.510 | - | 4.264 | 7.89 |\n| **Repetition** Kimodo-SOMA-RP-v1.1 | 2.804 | 2.983 | - | 4.260 | 7.63 | 2.829 | 2.969 | - | 4.027 | 7.21 |"
  },
  {
    "path": "docs/source/conf.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\nimport os\nimport sys\n\n# -- Path setup --------------------------------------------------------------\nsys.path.insert(0, os.path.abspath(\"../..\"))\n\n# -- Project information -----------------------------------------------------\n\nproject = \"Kimodo\"\ncopyright = \"2026, NVIDIA\"\nauthor = \"NVIDIA\"\n\nversion = \"\"\nrelease = \"\"\n\n# -- General configuration ---------------------------------------------------\n\nextensions = [\n    \"sphinx.ext.autodoc\",\n    \"sphinx.ext.napoleon\",\n    \"sphinx.ext.viewcode\",\n    \"sphinx.ext.intersphinx\",\n    \"sphinx.ext.autosummary\",\n    \"sphinx.ext.githubpages\",\n    \"sphinx_copybutton\",\n    \"myst_parser\",\n    \"sphinx_design\",\n]\n\nnapoleon_google_docstring = True\nnapoleon_numpy_docstring = False\nnapoleon_include_init_with_doc = True\nnapoleon_use_param = True\nnapoleon_use_rtype = True\n\nautodoc_default_options = {\n    \"members\": True,\n    \"member-order\": \"bysource\",\n    \"special-members\": \"__init__\",\n    \"undoc-members\": True,\n    \"exclude-members\": \"__weakref__\",\n    \"show-inheritance\": False,\n}\nautodoc_typehints = \"none\"\n\nautosummary_generate = True\n\n# Avoid initialization issues for optional native libs\nos.environ.setdefault(\"MUJOCO_GL\", \"osmesa\")\nos.environ.setdefault(\"PYOPENGL_PLATFORM\", \"osmesa\")\n\n\nclass Mock:\n    \"\"\"Mock class for imports that can't be satisfied.\"\"\"\n\n    def __init__(self, *args, **kwargs):\n        pass\n\n    def __call__(self, *args, **kwargs):\n        return Mock()\n\n    def __getattr__(self, name):\n        if name in (\"__file__\", \"__path__\"):\n            return \"/dev/null\"\n        if name == \"__version__\":\n            # Some libraries (e.g. 
safetensors) parse torch.__version__ with\n            # packaging.version.Version, so this must be a valid PEP 440 string.\n            return \"0.0.0\"\n        if name == \"__signature__\":\n            return None\n        if name == \"__mro_entries__\":\n            return lambda bases: ()\n        return Mock()\n\n    def __getitem__(self, name):\n        return Mock()\n\n    def __iter__(self):\n        return iter([])\n\n    def __or__(self, other):\n        return Mock()\n\n    def __ror__(self, other):\n        return Mock()\n\n\nmock_modules = [\n    \"torch\",\n    \"torch.nn\",\n    \"torch.nn.functional\",\n    \"torch.optim\",\n    \"torch.distributed\",\n    \"torch.cuda\",\n    \"torch.utils\",\n    \"torch.utils.data\",\n    \"lightning\",\n    \"lightning.fabric\",\n    \"lightning_fabric\",\n    \"pytorch_lightning\",\n    \"tensordict\",\n    \"pydantic\",\n    \"pydantic.dataclasses\",\n    \"pydantic_core\",\n    \"mujoco\",\n    \"isaacgym\",\n    \"isaacgymenvs\",\n    \"genesis\",\n    \"omni\",\n    \"wandb\",\n    \"hydra\",\n    \"omegaconf\",\n    \"tqdm\",\n    \"trimesh\",\n    \"pyvista\",\n    \"smplx\",\n    \"smpl\",\n    \"scipy\",\n    \"scipy.spatial\",\n    \"scipy.spatial.transform\",\n    \"peft\",\n    \"transformers\",\n    \"safetensors\",\n    \"safetensors.torch\",\n    \"sklearn\",\n    \"PIL\",\n    \"cv2\",\n    \"rich\",\n    \"rich.progress\",\n    \"skimage\",\n    \"imageio\",\n    \"openmesh\",\n    \"gym\",\n    \"easydict\",\n    \"dm_control\",\n    \"dm_control.mjcf\",\n    \"dm_control.mujoco\",\n    \"matplotlib\",\n    \"matplotlib.pyplot\",\n]\n\nfor mod in mock_modules:\n    sys.modules[mod] = Mock()\n\nautodoc_mock_imports = mock_modules\n\ntemplates_path = [\"_templates\"]\nexclude_patterns = [\"api_reference/_generated/**\"]\n\nlanguage = \"en\"\n\nsource_suffix = {\n    \".rst\": \"restructuredtext\",\n    \".md\": \"markdown\",\n}\n\nmaster_doc = \"index\"\n\n# -- Options for HTML output 
-------------------------------------------------\n\nhtml_theme = \"nvidia_sphinx_theme\"\nhtml_static_path = [\"_static\"]\nhtml_css_files = [\"custom.css\"]\nhtml_logo = \"_static/logo-placeholder.svg\"\nhtml_show_sourcelink = False\n\nhtml_theme_options = {\n    \"collapse_navigation\": False,\n    \"navigation_depth\": 4,\n}\n\ntoc_object_entries_show_parents = \"hide\"\n\nhtmlhelp_basename = \"Kimododoc\"\n\n# -- Options for intersphinx -------------------------------------------------\n\nintersphinx_mapping = {\n    \"python\": (\"https://docs.python.org/3\", None),\n    \"torch\": (\"https://pytorch.org/docs/stable/\", None),\n    \"numpy\": (\"https://numpy.org/doc/stable/\", None),\n}\n\ncopybutton_prompt_text = r\">>> |\\.\\.\\. |\\$ |In \\[\\d*\\]: | {2,5}\\.\\.\\.: | {5,8}: \"\ncopybutton_prompt_is_regexp = True\n\n# Generate heading anchors so cross-doc links like path.md#fragment resolve (local ids).\nmyst_heading_anchors = 4\n\n# Required so `:::{dropdown}` and other fenced directives in .md files are parsed (not shown as plain text).\nmyst_enable_extensions = [\"colon_fence\"]\n\n\ndef setup(app):\n    app.add_css_file(\"custom.css\")\n"
  },
  {
    "path": "docs/source/getting_started/installation.md",
    "content": "# Installation\n\n> Note: This project will download and install additional third-party open source software projects. Review the license terms of these open source projects before use.\n\n> Note: This repo was developed and primarily tested on Linux\n\nThere are two ways to install Kimodo: (1) as a package, or (2) download the source code and install.\nBoth require setting up a Hugging Face token to use the text encoder at generation time.\n\n## Set Up Hugging Face Token\n\nThe Kimodo text encoder relies on the **gated** `meta-llama/Meta-Llama-3-8B-Instruct` model, which requires:\n- Your HF account has been granted access to the [model page](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct).\n- You provide a HF token for runtime\n\nAfter receiving access to the Llama repo, please create an access token [here](https://huggingface.co/settings/tokens/new?tokenType=read). Then use it to log in on your command line:\n\n```bash\nhf auth login\n```\nor alternatively, paste the token in this file ``~/.cache/huggingface/token``. If you don't have `hf` installed, you will first need to run `pip install --upgrade huggingface_hub`.\n\n## Kimodo Install Option 1: Package Install\n\nThe easiest way to get started is simply installing Kimodo as a package without needing to clone the codebase. This will allow you to generate motions and run the demo as a black box.\n\nWe suggest creating a new Python environment for the install, for example with `venv` or conda:\n```bash\nconda create -n kimodo python=3.10\nconda activate kimodo\n```\n\nTo ensure you have a version of [PyTorch](https://pytorch.org/get-started/locally/) that is compatible with your system and CUDA version, it is recommended to manually install the best version of PyTorch for you before installing Kimodo. Anything over PyTorch 2.0 is sufficient. 
We strongly suggest using a GPU-capable version of PyTorch to generate motions in a reasonable amount of time.\n\nInstalling the base Kimodo package will allow you to generate motions with the command line:\n```bash\npip install git+https://github.com/nv-tlabs/kimodo.git\n```\n\nIf you want to be able to run the interactive demo as well, use this command which installs additional dependencies:\n```bash\npip install \"kimodo[all] @ git+https://github.com/nv-tlabs/kimodo.git\"\n```\n\nYou should now be ready to use Kimodo. Check out the [quick start guide](quick_start.md) to see how to generate motions.\n\nIf you experience issues with package or system compatibility using the above install strategy, we recommend downloading the codebase and using the Docker install detailed below.\n\n## Kimodo Install Option 2: Source Code Install\n\nIf you plan to build on Kimodo or dig into the codebase, you'll want to clone and install the repo.\n\n### Clone Kimodo Repository\n\n```bash\ngit clone https://github.com/nv-tlabs/kimodo.git\ncd kimodo\n```\n\n### Choose Your Installation Route\nKimodo can be installed by building and running through a virtual environment (e.g., `conda`) or within a Docker container.\n\n```{toctree}\n:maxdepth: 1\n\ninstallation_virtual_env\ninstallation_docker\ninstallation_smpl\n```\n"
  },
  {
    "path": "docs/source/getting_started/installation_docker.md",
    "content": "# Installation With Docker\n\n> Note: the first time building and running with Docker can take several minutes, please be patient.\n\n## Clone Modified Viser Library\nThe interactive demo relies on [a fork of Viser](https://github.com/nv-tlabs/kimodo-viser) that implements a timeline interface and more. Clone it within the `kimodo` directory before building with Docker using:\n```bash\ngit clone https://github.com/nv-tlabs/kimodo-viser.git\n```\n\n## Quick Install\n\nBefore running Docker, make sure your Hugging Face token is available at\n`~/.cache/huggingface/token` on the host, for example by running\n`hf auth login` once outside the container (see the [Installation](installation.md) instructions).\n\nThe easiest way to build and immediately run the interactive demo webapp (with the text-encoder service) in one command is:\n\n```bash\ndocker compose up -d --build\n```\n\n## Step-by-Step Installation\n\nAlternatively, you can first build with:\n\n```bash\ndocker compose build\n```\n\nThis builds `text-encoder` and `demo` containers corresponding to the text encoding service and the interactive motion authoring webapp, respectively. 
Please see the [quick start guide](quick_start.md) for more information on these.\n\n<details>\n\n<summary>Advanced Configuration of Dependencies</summary>\n\nThis repo uses:\n- `docker_requirements.in`: human-maintained, top-level dependencies\n- `docker_requirements.txt`: pinned lockfile (automatically generated)\n\nNotes:\n- We keep a lockfile for **reproducible Docker builds** (so a rebuild next week pulls the same deps).\n- The lockfile intentionally **omits `torch`/CUDA wheels** because the Docker base image\n  (`nvcr.io/nvidia/pytorch`) already provides a tested PyTorch build (avoids slow installs and CUDA mismatches).\n\n</details>\n<br>\n\nAfter building, you will need to manually start the text-encoder service before doing any motion generation:\n```bash\ndocker compose up text-encoder\n```\nNote, the first time running this command will take a long time as the Llama-based text encoder is downloaded.\n\nFinally, to start the interactive demo:\n```bash\ndocker compose up demo\n```\n\nFor more information on using the Docker setup, see the [Quick Start](quick_start.md) guide next.\n"
  },
  {
    "path": "docs/source/getting_started/installation_smpl.md",
    "content": "# Using Kimodo-SMPLX Model\n\nUsing the [Kimodo-SMPLX-RP-v1](https://huggingface.co/nvidia/Kimodo-SMPLX-RP-v1) model requires a few extra installation steps.\n\n## Request Model Access\n\nThe SMPL-X version of Kimodo is gated, so before trying to generate motions with it in the CLI or demo, go to the [Hugging Face model page](https://huggingface.co/nvidia/Kimodo-SMPLX-RP-v1) and request access. As described in the [installation](./installation.md) process, make sure your HF token is properly set up so your access to the model can be authenticated.\n\n## Download SMPL-X Body Model\nIf you want to visualize generated SMPL-X motions in the demo, you will need to download the SMPL-X body model.\nGo to the [SMPL-X](https://smpl-x.is.tue.mpg.de/) webpage and then sign in or create an account and go to the \"Download\" page.\nClick \"Download SMPL-X with removed head bun (NPZ)\" and then copy the `SMPLX_NEUTRAL.npz` file to the Kimodo codebase to be at `kimodo/kimodo/assets/skeletons/smplx22/SMPLX_NEUTRAL.npz`.\n\nNote that if you installed Kimodo as a package without downloading the codebase, you'll need to find where the assets directory is located by running:\n```bash\npython -c \"from kimodo.assets import skeleton_asset_path; print(skeleton_asset_path('smplx22'))\"\n```\n"
  },
  {
    "path": "docs/source/getting_started/installation_virtual_env.md",
    "content": "# Installation With Virtual Environment\n\n> Note: the repo was tested with Python 3.10+ and PyTorch 2.0+.\n\n## Create Enviroment\nWe recommend setting up a separate virtual environment for Kimodo to avoid dependency conflicts.\n\n### Using venv\n```bash\npython -m venv venv\nsource venv/bin/activate\n```\n\n### Using Conda\n```bash\nconda create -n kimodo python=3.10\nconda activate kimodo\n```\n\n## Install Dependencies\n\n### Install PyTorch\nFirst, make sure to install a version of [PyTorch](https://pytorch.org/get-started/locally/) that works with your system and CUDA version. We suggest anything over PyTorch 2.0. We strongly suggest using a GPU-capable version of PyTorch to generate motions in a reasonable amount of time.\n\n### (Optional) Clone Modified Viser Library\nThe interactive demo relies on [a fork of Viser](https://github.com/nv-tlabs/kimodo-viser) that implements a timeline interface and more. If you want to have an editable install of this version of Viser (i.e., you expect to modify it), clone and install it within the `kimodo` directory using:\n```bash\ngit clone https://github.com/nv-tlabs/kimodo-viser.git\npip install -e kimodo-viser\n```\n\n### Install Kimodo\nNext, install Kimodo run this command from the base of repo:\n```bash\npip install -e .\n```\nThis results in a single editable install for Kimodo and the MotionCorrection package.\n\nIf you plan to use the demo, you can instead run:\n```bash\npip install -e \".[all]\"\n```\nThis will install our [Viser fork](https://github.com/nv-tlabs/kimodo-viser) (if not already installed in the previous step) and the [SOMA body model](https://github.com/NVlabs/SOMA-X).\n\nNext, head over to the [Quick Start](quick_start.md) page to test out your installation by generating some motions.\n"
  },
  {
    "path": "docs/source/getting_started/quick_start.md",
    "content": "# Quick Start\n\nThis page provides a quick introduction to motion generation with Kimodo. For detailed explanations, we recommend reviewing the full documentation pages linked in each section.\n\nBefore running these commands, follow the [installation guide](installation.md) to install Kimodo in a virtual environment or using Docker.\n\n## Overview: Kimodo Models\nMotion generation can be performed with several trained Kimodo models that vary by skeleton and training dataset.\n\n> Note: models will be downloaded automatically when attempting to generate from the CLI or Interactive Demo, so there is no need to download them manually\n\n| Model | Skeleton | Training Data | Release Date | Hugging Face | License |\n|-------|------|------|-------------|-------------|----|\n| **Kimodo-SOMA-RP-v1.1** | [SOMA](https://github.com/NVlabs/SOMA-X) | [Bones Rigplay 1](https://bones.studio/datasets#rp01) | April 10, 2026 | [Link](https://huggingface.co/nvidia/Kimodo-SOMA-RP-v1.1) | [NVIDIA Open Model](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/) |\n| **Kimodo-SOMA-SEED-v1.1** | [SOMA](https://github.com/NVlabs/SOMA-X) | [BONES-SEED](https://huggingface.co/datasets/bones-studio/seed) | April 10, 2026  | [Link](https://huggingface.co/nvidia/Kimodo-SOMA-SEED-v1.1) | [NVIDIA Open Model](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/) |\n| **Kimodo-SOMA-RP-v1** | [SOMA](https://github.com/NVlabs/SOMA-X) | [Bones Rigplay 1](https://bones.studio/datasets#rp01) | March 16, 2026 | [Link](https://huggingface.co/nvidia/Kimodo-SOMA-RP-v1) | [NVIDIA Open Model](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/) |\n| **Kimodo-G1-RP-v1** | [Unitree G1](https://github.com/unitreerobotics/unitree_mujoco/tree/main/unitree_robots/g1) | [Bones Rigplay 1](https://bones.studio/datasets#rp01) | March 16, 2026  | [Link](https://huggingface.co/nvidia/Kimodo-G1-RP-v1) | 
[NVIDIA Open Model](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/) |\n| **Kimodo-SOMA-SEED-v1** | [SOMA](https://github.com/NVlabs/SOMA-X) | [BONES-SEED](https://huggingface.co/datasets/bones-studio/seed) | March 16, 2026  | [Link](https://huggingface.co/nvidia/Kimodo-SOMA-SEED-v1) | [NVIDIA Open Model](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/) |\n| **Kimodo-G1-SEED-v1** | [Unitree G1](https://github.com/unitreerobotics/unitree_mujoco/tree/main/unitree_robots/g1) | [BONES-SEED](https://huggingface.co/datasets/bones-studio/seed) | March 16, 2026  | [Link](https://huggingface.co/nvidia/Kimodo-G1-SEED-v1) | [NVIDIA Open Model](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/) |\n| **Kimodo-SMPLX-RP-v1** | [SMPL-X](https://github.com/vchoutas/smplx) | [Bones Rigplay 1](https://bones.studio/datasets#rp01) | March 16, 2026  | [Link](https://huggingface.co/nvidia/Kimodo-SMPLX-RP-v1) | [NVIDIA R&D Model](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-internal-scientific-research-and-development-model-license/) |\n\nBy default, we recommend using the models trained on the full Bones Rigplay dataset (700 hours of mocap) for your motion generation needs.\nThe models trained on BONES-SEED use 288 hours of [publicly available mocap data](https://huggingface.co/datasets/bones-studio/seed) so are less capable, but are useful for comparing your own trained models on the same dataset. See the [benchmark](../benchmark/introduction.md) for a standardized evaluation suite on BONES-SEED.\n\n### Recommended Hardware\nKimodo requires  ~17GB of VRAM to generate locally entirely on GPU, due primarily to the size of the text embedding model. If you have a smaller card, set `TEXT_ENCODER_DEVICE=cpu` when running Kimodo commands to force text encoding to the CPU. 
This is slightly slower but reduces VRAM usage to <3 GB.\n\nThe model has been most extensively tested on GeForce RTX 3090, GeForce RTX 4090, and NVIDIA A100 GPUs, but it should work on other recent cards with sufficient VRAM.\n\n## Run Text-Encoder Service\nMotion generation relies on embedding the input text prompt, which becomes the input to Kimodo. Although it is fine to run the CLI commands and demo on their own, it may be preferred to start the _text encoder service_ in the background, which can be shared across all motion generation requests. This is much more efficient when making many consecutive CLI calls, as it avoids needing to instantiate the large text encoder every time.\n\nTo start the text encoder service:\n```bash\nkimodo_textencoder\n```\n\nThe first run of the service will take a while as it downloads the embedding model. We recommend running this in the background or in a separate terminal where it will stay open and usable by other scripts.\n\nIf you are using the Docker set up, the service can alternatively be started in the container with:\n```bash\ndocker compose up text-encoder\n```\n\n> Note: when the text encoder is initialized, the transformers library will report several unexpected and missing layers for LLM2Vec. These are expected and can be safely ignored.\n\nIf you are running on a GPU with <16 GB VRAM, you can force the text encoder to the CPU, for example:\n```bash\nTEXT_ENCODER_DEVICE=cpu kimodo_textencoder\n```\n\n## Command-Line Text-to-Motion Generation\n**[CLI Documentation](../user_guide/cli.md)**\n\nYou can generate motions from the command line using the generate script:\n\n```bash\nkimodo_gen \"A person walks forward.\" \\\n    --model Kimodo-SOMA-RP-v1 \\\n    --duration 5.0 \\\n    --output output\n```\n\nThe `--model` command corresponds to the model name in the table above. The output motion will be saved using the stem name given by `--output` in the Kimodo [output format](../user_guide/output_formats.md). 
For a detailed description of all generation arguments, including how to generate motion with constraints, see the full [CLI documentation](../user_guide/cli.md).\n\nIf you set up Kimodo with Docker, you can instead run generation inside the Docker container, replacing `kimodo_gen XXX` with `docker compose run --rm demo kimodo_gen XXX`. If you will be running generation multiple times, it is better to start the `demo` container (e.g., in another terminal or in the background), and then run commands inside it with `docker compose exec demo kimodo_gen XXX`.\n\n\n## Interactive Motion Authoring Demo\n**[Demo Documentation](../interactive_demo/index.md)**\n\nThe demo allows easily generating motions with an intuitive control interface for text prompting and constraints.\n\nThe demo can be started with:\n```bash\nkimodo_demo\n```\n\nThe demo is a webapp that will run on [http://localhost:7860](http://localhost:7860). Open this URL in your browser to access the interface.\n\nIf you are using Docker, the demo can be launched with:\n```bash\ndocker compose up demo\n```\nor if you want to start the demo and text encoder service (explained above) at the same time, use:\n```bash\ndocker compose up\n```\n\n<details>\n<summary>Additional Tips for Docker</summary>\n\nYou may find the following commands useful if running Kimodo within the Docker containers. In the example commands below, you can also replace `demo` by `text-encoder`:\n\n**Check logs:**\n\n```bash\ndocker compose logs demo\n```\n\n**Stop service:**\n\n```bash\ndocker compose stop demo\n```\n\n**Restart service:**\n\n```bash\ndocker compose restart demo\n```\n\n**Stop and remove everything:**\n\n```bash\ndocker compose down\n```\n\n</details>\n"
  },
  {
    "path": "docs/source/index.md",
    "content": "# Kimodo Documentation\n\n<div class=\"hero\">\n  <div class=\"hero-title\">Kimodo</div>\n  <div class=\"hero-subtitle\">\n    Scaling controllable human motion generation\n  </div>\n  <div class=\"hero-actions\">\n    <a href=\"getting_started/installation.html\">Get Started</a>\n    <a class=\"secondary\" href=\"interactive_demo/index.html\">Interactive Demo</a>\n    <a class=\"secondary\" href=\"https://github.com/nv-tlabs/kimodo\">GitHub</a>\n    <a class=\"secondary\" href=\"https://research.nvidia.com/labs/sil/projects/kimodo/assets/kimodo_tech_report.pdf\">Tech Report</a>\n  </div>\n</div>\n\n\n## Overview\n\nKimodo is a **ki**nematic **mo**tion **d**iffusi**o**n model trained on a large-scale (700 hours) commercially-friendly optical motion capture dataset. The model generates high-quality 3D human and robot motions, and is controlled through text prompts and an extensive set of constraints such as full-body pose keyframes, end-effector positions/rotations, 2D paths, and 2D waypoints. 
See the [project page](https://research.nvidia.com/labs/sil/projects/kimodo/) for details.\n\n## Highlights\n\n<div class=\"card-grid\">\n  <div class=\"card\">\n    <h3>Controlled Generation</h3>\n    <p>Text prompts combined with full-body, root, and end-effector constraints.</p>\n  </div>\n  <div class=\"card\">\n    <h3>Human(oid) Support</h3>\n    <p>Model variations for both digital humans and humanoid robots.</p>\n  </div>\n  <div class=\"card\">\n    <h3>Interactive Demo</h3>\n    <p>Timeline editing, real-time 3D visualization, and example presets.</p>\n  </div>\n</div>\n\n## Quick links\n\n- [Installation](getting_started/installation.md)\n- [Quick Start](getting_started/quick_start.md)\n- [Command Line Interface](user_guide/cli.md)\n- [Interactive Demo](interactive_demo/index.md)\n- [Project Structure](project_structure.md)\n\n```{toctree}\n:maxdepth: 3\n:caption: Getting Started\n:hidden:\n\ngetting_started/installation\ngetting_started/quick_start\n```\n\n```{toctree}\n:maxdepth: 2\n:caption: User Guide\n:hidden:\n\ninteractive_demo/index\nuser_guide/cli\nuser_guide/constraints\nuser_guide/output_formats\nuser_guide/motion_convert\nuser_guide/seed_dataset\nuser_guide/configuration\n```\n\n```{toctree}\n:maxdepth: 2\n:caption: Key Concepts\n:hidden:\n\nkey_concepts/model\nkey_concepts/limitations\nkey_concepts/motion_representation\nkey_concepts/constraints\nkey_concepts/skeleton\n```\n\n```{toctree}\n:maxdepth: 2\n:caption: Benchmark\n:hidden:\n\nbenchmark/introduction\nbenchmark/pipeline\nbenchmark/metrics\nbenchmark/results\n```\n\n```{toctree}\n:maxdepth: 2\n:caption: Reference\n:hidden:\n\nproject_structure\nproject_info\napi_reference/index\n```\n"
  },
  {
    "path": "docs/source/interactive_demo/constraints.md",
    "content": "# Constraints\n\nConstraints guide the motion at specific frames or intervals. To learn about the types of constraints details of each, see the [constraints concepts](../key_concepts/constraints.md) and [constraints format](../user_guide/constraints.md) pages.\n\n![Constraints panel](../_static/demo/constraints_panel.png)\n![Editing mode](../_static/demo/editing_mode.png)\n\nThe constraint panel allows you to configure constraints and editing:\n\n- **Enter Editing Mode**: enable FK pose editing in the viewer. Gizmos will be displayed on joints that can be edited. If there is already a constraint on the timeline for the current frame, any pose editing will adjust that constraint, otherwise you need to add a constraint on the timeline after adjusting the pose.\n- **Gizmo space**: whether to display the rotation gizmos in local or global joint space while editing\n- **Snap to Constraint**: will snap the current frame of motion to the constraint at that frame. This can be useful if a generated pose does not exactly meet the constraint and you want to continue editing the constraint.\n- **Reset Constraint**: does the opposite by snapping the pose back to the original generated motion from the constrained pose.\n- **Root 2D Options > Make Smooth Path**: if you have laid down root waypoint constraints, checking this box will turn the waypoints into a smoothed dense path constraint. If there is not a waypoint at the first and last frames of the motion, they will be automatically added since Kimodo is only trained on full-sequence paths.\n- **Clear All Constraints**: clears all current constraints from the viewer and timeline.\n"
  },
  {
    "path": "docs/source/interactive_demo/examples.md",
    "content": "# Examples\n\nThe Examples Tab within the settings panel contains several examples that highlight the key capabilities and potential workflows with Kimodo.\nExamples are included for the `Kimodo-SOMA-RP` and `Kimodo-G1-RP` models.\n\n![Skeleton overview](../_static/demo/examples_panel.png)\n\nAfter choosing an example from the dropdown menu, click \"Load Example\" to load the example configuration into the viewer.\n\nThe viewer will display the pre-generated motion along with the prompts and constraints on the timeline that were used to generate it. All settings used to generate the model are also loaded with the example (e.g., seed, classifier-free guidance settings), so you should be able to click \"Generate\" in the panel to recover the same result.\n\nExample cover a variety of ways to use one or more text prompts along with kinematic constraints for generation.\n\n**Saving New Examples**: after you've generated a motion, you can save a new Example under the \"Load/Save\" tab of the Settings panel. You should immediately see the Examples dropdown update with your new saved example so it can be loaded in later.\n\nThis section walks through common workflows and how to use the webapp. Each\nworkflow has its own section and an accompanying video.\n"
  },
  {
    "path": "docs/source/interactive_demo/export_results.md",
    "content": "# Saving/Loading\n\nThe Load/Save and Exports panels allow saving generated results and load in previously generated results\n\n![Export panel](../_static/demo/exports_panel.png)\n\n- **Load/Save**\n    - **Motion**: save the current motion in the [NPZ format](../user_guide/output_formats.md#kimodo-npz-format) to a specific path. Motion NPZs can also be loaded into the viewer from this panel. This is useful to load in motions generated with the CLI.\n    - **Constraints**: save the current constraints in the [JSON format](../user_guide/constraints.md) to a specific path. Constraint JSON files can also be load into the viewer.\n    - **Example**: allows saving a new example that encompasses the current motion, constraints, and all settings. This is useful for reloading previous work. If examples are saved to the demo examples directory, they will be loadable from the Examples dropdown menu, otherwise you can load them through file path in this menu.\n\n- **Exports**\n    - **Screenshot**: save current canvas as an image that can be downloaded through your browser\n    - **Video**: record the current motion to a video that can be download through your browser\n    - **Motion**: save the current motion to a format of your choice depending on the loaded skeleton:\n      - SOMA: `NPZ` or `BVH`\n      - G1: `NPZ` or `CSV`\n      - SMPL-X: `NPZ` or `AMASS NPZ`\n      These formats are described in [output formats](../user_guide/output_formats.md).\n"
  },
  {
    "path": "docs/source/interactive_demo/generation.md",
    "content": "# Generation\n\nThe most important panel is the \"Generate\" which allows you to call Kimodo to generate one or more motions based on the prompts, constraints, and settings provided.\n\n![Generate panel](../_static/demo/generate_panel.png)\n\n- **Num Samples**: the number of motions to generate based on the current settings. When multiple samples are generated, you _must_ choose a single sample by clicking the character in the viewer before editing constraints or generating new motion.\n- **SOMA Layer**: if using a `Kimodo-SOMA` model, this option will appear. It allows you to use the SOMA body layer to skin the character instead of using the SOMA rig. For details on the difference between the two, see the [Skeletons page](../key_concepts/skeleton.md#soma-default).\n- **Seed**: random seed for repeatable generation\n- **Denoising steps**: number of steps to use with DDIM\n- **CFG Text/Constraint Weight**: the weights to use for classifier-free guidance\n- **Post-Processing**: whether to use foot skate cleanup and constraint post-optimization to improve motion after generation\n    - **Root Margin**: if the skeleton root deviates more than this margin from a constraint, the post-processing will fix it\n"
  },
  {
    "path": "docs/source/interactive_demo/index.md",
    "content": "# Interactive Demo\n\nThe web-based interactive demo provides an intuitive interface for generating motions with any of the Kimodo model variations.\n\n![Demo Interface](../_static/overview.png)\n*Interactive demo interface build with [Viser](https://github.com/viser-project/viser)*\n\n```{note}\nTo see the demo in action, follow the [setup instructions](launching.md) below and launch it locally. After launching, open the demo in a web browser at http://127.0.0.1:7860 or use port forwarding if running on a server.\n```\n\nThe demo provides a timeline-based interface for composing text prompts and\nconstraints, with real-time 3D visualization. Here are some key features:\n\n- **Multiple Characters**: Supports generating with the SOMA, G1, and SMPL-X versions of Kimodo\n- **Text Prompts**: Enter one or more natural language descriptions of desired motions on the timeline\n- **Timeline Editor**: Add and edit keyframes and constrained intervals on multiple constraint tracks\n- **Constraint Types**:\n  - Full-Body: Complete joint position constraints at specific frames\n  - 2D Root: Define waypoints or full paths to follow on the ground plane\n  - End-Effectors: Control hands and feet positions/rotations\n- **Constraint Editing**: Editing mode allows for re-posing of constraints or adjusting waypoints\n- **3D Visualization**: Real-time rendering of generated motions with skeleton and skinned mesh options\n- **Playback Controls**: Preview generated motions with adjustable playback speed\n- **Multiple Samples**: Generate and compare multiple motion variations\n- **Examples**: Load pre-existing examples to better understand Kimodo's capabilities\n- **Export**: Save constraints and generated motions for later use\n\n\n## Quick Links\n\n- [Starting the Demo](launching.md)\n- [UI Overview](ui_overview.md)\n- [Examples](examples.md)\n\n\n```{toctree}\n:maxdepth: 
2\n:hidden:\n\nlaunching\nui_overview\nmodel_selection\nexamples\ngeneration\nconstraints\nexport_results\n```\n"
  },
  {
    "path": "docs/source/interactive_demo/launching.md",
    "content": "# Running the Demo\n\nAfter following the installation [instructions](../getting_started/installation.md), the demo can be launched with the commands below. The demo runs in the web browser at [http://localhost:7860](http://localhost:7860).\n\n</details>\n\n<details>\n<summary>If you run the demo on a server, you can use port forwarding to access it.</summary>\n\nTo access the demo's web interface when running on a remote server, set up SSH port forwarding so your web browser can reach `http://localhost:7860` as if it was local.\n\n**Option 1: Add LocalForward to your SSH config**\n\nEdit (or create) your SSH config file (typically `~/.ssh/config`):\n\n```\nHost your-server-name\n    HostName your.server.address\n    User username\n    LocalForward 7860 localhost:7860\n```\nThen connect with:\n```\nssh your-server-name\n```\n\n**Option 2: Use the SSH command-line directly**\n\nFrom your local machine, run:\n```\nssh -N -L 7860:localhost:7860 username@your.server.address\n```\nThis will forward your local port 7860 to the server's port 7860.\nAfter connecting, open [`http://localhost:7860`](http://localhost:7860) in your web browser.\n\nReplace `username` and `your.server.address` with your actual user and server info.\n\n</details>\n</br>\n\nIf you will be restarting the demo frequently, we recommend first starting the text encoder service in the background, as detailed in the [quick start guide](../getting_started/quick_start.md#run-text-encoder-service). 
If the text encoder service is not running, the demo will automatically load the text encoder model.\n\nThe demo will also automatically download the Kimodo model checkpoint on launch and whenever needed when the model preference is changed in the UI.\n\n## Launch from Command Line\nIf you installed Kimodo as a package or from source, the demo can be started with:\n```bash\nkimodo_demo\n```\n\n## Launch with Docker\nIf you installed with Docker, you can start the demo with:\n```bash\ndocker compose up demo\n```\n\n<details>\n<summary>Additional Tips for Docker</summary>\n\nYou may find the following commands useful if running Kimodo within the Docker containers. In the example commands below, you can also replace `demo` by `text-encoder`:\n\n**Check logs:**\n\n```bash\ndocker compose logs demo\n```\n\n**Stop service:**\n\n```bash\ndocker compose stop demo\n```\n\n**Restart service:**\n\n```bash\ndocker compose restart demo\n```\n\n**Stop and remove everything:**\n\n```bash\ndocker compose down\n```\n\n</details>\n"
  },
  {
    "path": "docs/source/interactive_demo/model_selection.md",
    "content": "# Model Selection\n\nModel selection allows choosing between the Kimodo models detailed in the [quick start guide](../getting_started/quick_start.md#overview-kimodo-models).\n\nThe models determine which character is loaded in the scene and the possible export options.\n\n- **SOMA**: default human skeleton\n- **G1**: MuJoCo-compatible exports\n- **SMPL-X**: SMPL-X compatible outputs\n\nFor details on each skeleton, see [Skeletons](../key_concepts/skeleton.md).\n\n<img src=\"../_static/demo/model_selection.png\" alt=\"Model selection UI\" width=\"60%\">\n\n\n![Skeleton overview](../_static/skeletons/skeletons.png)\n"
  },
  {
    "path": "docs/source/interactive_demo/ui_overview.md",
    "content": "# UI Overview\n\nThis page gives an overview of each of the main elements of the demo UI and how to use them.\n\n![Demo Interface](../_static/overview.png)\n*An example scene within the demo webapp*\n\n## Viewer\n![Viewer](../_static/demo/viewer.png)\n\nThe 3D viewer shows the currently generated motion. It supports skeleton or skinned mesh rendering, which is configurable in the \"Visualize\" panel.\n\n### Camera\n- **Left-drag**: rotate\n- **Right-drag**: pan\n- **Scroll**: zoom\n\n### Playback\n- **Space** to play/pause\n- **←/→** to step frames, or click the frame number.\n\n## Timeline\n\n![Timeline](../_static/demo/timeline.png)\n\nThe timeline is where you:\n\n- add, edit, and delete **prompt segments**\n- add and delete **constraints** at frames or intervals and adjust timing\n\n### Timeline Navigation\n- **Scroll Up/Down** in the timeline: move left/right\n- **Shift + Scroll** in the timeline: zoom in/out\n\n### Prompts\n- **Double-Click** a text prompt to edit the text\n- **Click and Drag** the right edge of a prompt box to extend/shorten it (2-10 sec)\n- **Click Empty Space** to add a prompt\n- **Right-Click** a prompt to delete it\n\n### Constraints\nConstraints can be added after generating for the first time when there is an active motion in the viewer:\n- **Click** in the timeline tracks (Full-Body / 2D root etc) to add a constraint of that type using the pose at that frame\n- **Ctrl/Cmd + Click + Drag** to add an interval constraint, or expand a keyframe into an interval\n- **Click + Drag** an existing constraint to move it to a different frame\n- **Right-Click** on a constraint to delete it\n- To **edit** a constraint:\n    - Move playback to the target frame\n    - Click **Enter Editing Mode** in the Constraints tab of the Settings Panel. 
Note you must exit editing mode before generating again.\n\n\n## Settings Panel\n![Panel](../_static/demo/panel.png)\n\nThe settings panel includes:\n- model selection\n- loading examples\n- model parameter selection for generation and post-processing\n- parameters for constraint editing\n- motion loading and saving\n- visualization options\n\nImportant settings panels are individually explained on subsequent pages.\n"
  },
  {
    "path": "docs/source/key_concepts/constraints.md",
    "content": "# Constraints\n\nConstraints are time-localized signals that steer the generated motion toward\nspecific spatial goals while keeping the rest of the motion free for the model\nto resolve. You can combine constraints with text prompts to control trajectory,\npose, and end-effectors. Constraints are most easily defined in the [interactive demo](../interactive_demo/constraints.md) and can be saved to the [JSON format](../user_guide/constraints.md).\n\n![Overview diagram of constraint types on a timeline](../_static/constraints.png)\n\n## Why Constraints?\n\nConstraints allow you to:\n\n- pin the character to a target pose or keyframe\n- guide a path on the ground while preserving natural motion\n- fix hands or feet at specific times (for example, touch or contact events)\n\n## Constraint Types\n\nKimodo is trained to excel at specific types of constraints.\n\n**Sparse root 2D waypoint**: ground-plane 2D waypoints that guide the global translation of the character. This constrains the 2D components of the smoothed root representation generated by Kimodo.\n\n**Dense root 2D path**: dense 2D path constraints that guide a continuous trajectory. This constrains the 2D components of the smoothed root representation generated by Kimodo.\n\n**Sparse full-body keyframe**: full-body pose targets at specific frames. Within the Kimodo motion representation, this constrains the smoothed root position and all body joint positions at a specific frame.\n\n**Sparse end-effector constraint**: hands or feet targets while leaving the rest of the body flexible. This constrains the smoothed root position along with the specified end-effectors. For hands, this will constrain the wrist position and rotation along with the hand end position. For feet, it constraints the heel position and rotation along with the toe position. Kimodo is trained to support arbitrary subsets of end-effectors.\n\n**Foot contacts**: toe/heel contact patterns. 
While the model is trained to support this, it is not currently implemented in the demo UI or Python API.\n\n\n```{note}\nFor SOMA models, constraints may be authored or displayed on the full `somaskel77` skeleton, but Kimodo converts them to the reduced `somaskel30` representation before passing them to the model. See the [skeleton](./skeleton.md) section for more details.\n```\n\n## Coordinate Space\n\nAll constraint values are in a **Y-up** coordinate system with units in **meters**. The model expects constraints relative to a canonical origin where the root starts at XZ = (0, 0) at frame 0. The initial heading can be set via the `first_heading_angle` generation parameter (defaults to 0, facing +Z). See the [constraints JSON format](../user_guide/constraints.md#coordinate-space-and-units) for full details on each field.\n\n## Time and Scope\n\nIn our CLI and demo, constraints can be defined at:\n\n- **Single frames**: keyframe-style constraints\n- **Intervals**: guidance across a range of frames\n\nHowever, as described above, the model is trained to excel mostly at sparse keyframes, with dense keyframes usually only seen for root paths. See [best practices](./limitations.md) for more details.\n\n## Post-Processing\n\nSince it is very challenging for a neural network to strictly adhere to constraints, the demo and CLI support motion post-processing to ensure motion _exactly_ hits constraints. This is done through a lightweight optimization that smoothly adjusts joints while minimizing changes in acceleration and velocity.\n"
  },
  {
    "path": "docs/source/key_concepts/limitations.md",
    "content": "# Best Practices\n\nOn this page, we summarize the best approaches to maximize Kimodo's capabilities in terms of prompting and constraints, and also summarize known limitations and failure cases. For additional context, please see the [tech report](https://research.nvidia.com/labs/sil/projects/kimodo/assets/kimodo_tech_report.pdf).\n\n## Text Prompting\n- For best results, begin each prompt with \"A person...\" (e.g., \"A person walks forward\" or \"A person jumps and waves\"). This phrasing helps clarify the subject and intent of the motion, and is more closely aligned with the style of prompts used in the training data. The subject can also be stylized to better describe the motion such as \"An old person...\" or \"A drunk person...\"\n- Keep each prompt focused one or at most two behaviors. For long sequences of action, split them into multiple prompts and generate in sequence.\n- It's best to use a medium level of detail when describing a motion. Prompts like \"A person walks.\" are too short and vague, while very long prompts describing detailed motion of each body part will be too much for the model to handle. Most training data is a middleground between these two. We recommend looking at the prompts in the [BONES-SEED dataset](https://huggingface.co/datasets/bones-studio/seed) to get an idea of prompt granularity.\n- Kimodo is trained on a specific set of human behaviors. The training data tends to cover locomotion, gestures, everyday activities, common object interactions, videogame combat, dancing, and various styles including tired, angry, happy, sad, scared, drunk, injured, stealthy, old, and childlike. Prompts for actions outside of these categories will likely give bad results. For example \"A baseball player walks up to the plate and swings a bat\" is not good, becuase Kimodo has not trained on baseball data.\n- When using multiple prompts (e.g., in the timeline UI), make sure each prompt has enough information on its own. 
For example, if prompt 1 is \"A person is walking while carrying an object\", then prompt 2 could be \"A person walking carrying an object comes to a stop\". If prompt 2 were instead \"Then the person stops\", the model will not have enough context for what happened previously and may generat poor quality motions.\n\n## Constraints\n- Avoid using constraints that contradict the given text prompt or other types of constraints. If you are having trouble with a tradeoff between constraint and text accuracy, try adjusting the [classifier-free guidance weights](../user_guide/configuration.md).\n- Except for dense 2d root paths, Kimodo is mainly trained to handle sparse temporal constraints. Kimodo will perform best when the number of constraints per constraint type is less than 20 keyframes.\n- When foot contact accuracy and hitting constraints is high priority, make sure to enable [post-processing](./constraints.md#post-processing).\n\n## Limitations\n- **Motion length:** Maximum generated motion duration is 10 sec per prompt\n- **Number of constraints:** The number of constrained frames per constraint type should be less than 20 (excluding the root path constraint)\n- **Overly long or complex prompts** can blur motion intent, especially when many distinct actions are packed into a single prompt.\n- **Conflicting constraints:** can lead to artifacts or constraints that are ignored\n- **Multi-prompt sequences**: When generating motions with a sequence of prompts, each motion is generated one at a time. The second motion is conditioned on the last frames of the first, so the transition between prompts actually happens at the start of the second motion. This means the second prompt must devote some of its duration to performing a smooth transition, which may reduce the time available to realize the new prompt content fully.\n- **Post-processing**: The model by itself can generate foot skating and will not exactly hit constraints. 
Post-processing helps with this, but currently does not work well for the G1 robot skeleton.\n"
  },
  {
    "path": "docs/source/key_concepts/model.md",
    "content": "# Model Overview\n\nAt a glance:\n- Input: text prompt + optional constraints.\n- Output: full-body motion sequence\n- Core Idea: denoise motion features with a two-stage transformer at each step.\n\nKimodo is an explicit motion diffusion model that generates 3D human motion by denoising a sequence of skeleton poses. The model operates on a carefully designed motion representation that enables precise control over generated motion while minimizing common artifacts, such as floating and foot skating. The motion representation features a smoothed root that emulates paths drawn in practical animation tools, along with global joint rotations and positions amenable to sparse keyframe constraints.\n\nFor full details, see the [tech report](https://research.nvidia.com/labs/sil/projects/kimodo/assets/kimodo_tech_report.pdf)\n\n![Kimodo model architecture](../_static/arch.png)\n\n## Diffusion Process\n\nAt each step of the denoising process, the model takes in an embedding of the text prompt, a set of kinematic constraints, and the current noisy motion. Constraints are specified using the same motion representation as the input motion, and are used to overwrite the corresponding values in the noisy motion. Additionally, a mask indicating which elements are constrained is concatentated to the input motion. The goal is to predict a clean version of the input motion.\n\n## Two-Stage Transformer Denoiser\n\nGiven these inputs, the two-stage transformer denoiser predicts a clean motion that aligns with the text and constraints. The two-stage denoiser decomposes root and body motion prediction: the root denoiser first predicts global root motion, which is transformed into a local representation as input to the body denoiser. 
The final output is the concatenation of the two stages.\n\n## Training Dataset\n\nA key component to effectively train Kimodo is the [Bones Rigplay](https://bones.studio/ai-datasets/) dataset, a large studio mocap dataset containing over 700 hours of production-quality human motion with corresponding text descriptions. The data covers locomotion, gestures, everyday activities, common object interactions, videogame combat, dancing, and various styles including tired, angry, happy, sad, scared, drunk, injured, stealthy, old, and childlike.\n"
  },
  {
    "path": "docs/source/key_concepts/motion_representation.md",
    "content": "# Motion Representation\n\nKimodo uses a motion representation that combines a smoothed root representation with global joint positions, rotations, and various auxiliary features.\nFor full details, please refer to the [tech report](https://research.nvidia.com/labs/sil/projects/kimodo/assets/kimodo_tech_report.pdf).\n\nThe representation is implemented in `kimodo/motion_rep/reps/kimodo_motionrep.py` and allows easily going to and from this feature representation.\n\n## Coordinate System\n\nAll motion features use a right-handed coordinate system with:\n\n- **Y up**\n- **+Z forward**\n\n## Smoothed Root Representation\n\nWe use a smoothed root trajectory for the global root position to make\npath-following constraints more natural and controllable. Smoothing removes\nhigh-frequency pelvis jitter while preserving overall motion direction, so\n2D waypoints or paths drawn by users remain clean and easy to match during\ngeneration, while the pelvis can still move naturally around the smoothed\ncurve.\n\n![Comparison of smoothed root rep](../_static/smoothed_root.png)\n\n## Pose Feature\n\nAt each frame, the pose feature vector is the concatenation of:\n\n- **Smooth root position** (`smooth_root_pos`, 3): Smoothed pelvis/root position.\n  The x/z components track ground-plane motion and y stores height.\n- **Global root heading** (`global_root_heading`, 2): `[cos(theta), sin(theta)]`\n  heading direction of the root.\n- **Local joint positions** (`local_joints_positions`, `J x 3`): Joint positions\n  in a pelvis-relative space with the smoothed root x/z offset applied.\n- **Global joint rotations** (`global_rot_data`, `J x 6`): 6D rotation\n  representation of each joint's global orientation.\n- **Joint velocities** (`velocities`, `J x 3`): Global joint velocities.\n- **Foot contacts** (`foot_contacts`, 4): Binary contact indicators for the\n  left/right foot contact points.\n"
  },
  {
    "path": "docs/source/key_concepts/skeleton.md",
    "content": "# Skeletons\n\nDifferent versions of Kimodo support different skeletons (character). A separate model is trained for each skeleton, with the\ncurrently available options being [SOMA](https://github.com/NVlabs/SOMA-X), [G1](https://github.com/unitreerobotics/unitree_mujoco/tree/main/unitree_robots/g1), and [SMPL-X](https://github.com/vchoutas/smplx).\n\nThe skeletons discussed on this page are defined in `kimodo/skeleton/definitions.py`.\n\n![Skeleton overview](../_static/skeletons/skeletons.png)\n\n## SOMA (default)\n\nSOMA is the default skeleton used for Kimodo. It it based on the [SOMA body model](https://github.com/NVlabs/SOMA-X), which is also used in the [BONES-SEED dataset](https://huggingface.co/datasets/bones-studio/seed).\nKimodo uses two closely related SOMA skeleton definitions:\n\n- **`somaskel30`**: the reduced 30-joint skeleton used internally by the model and by the core SOMA constraint formulation. It removes most finger and hand detail.\n- **`somaskel77`**: the full 77-joint SOMA skeleton used for public-facing visualization and SOMA motion exports.\n\nIn practice, Kimodo predicts SOMA motions on `somaskel30` and converts them to `somaskel77` when returning or visualizing results in the demo. Older assets and examples may still be stored on `somaskel30`, and the tooling keeps backward compatibility with those files.\n\nNote that all training data for Kimodo is on a uniform skeleton proportion corresponding to one single set of identity parameters for the SOMA body model.\n\n![\"SOMA skeletons\"](../_static/skeletons/soma_skels.png)\n\nOutputs on the SOMA skeleton can be visualized in two ways. The first is by articulating a fixed SOMA rig and doing traditional skinning (corresponds to `kimodo/viz/soma_skin.py` in the codebase).\nAlternatively, we can take generated joint rotations and feed them through the SOMA layer with the set of identity parameters that correspond to the body shape of our uniform skeleton. 
An example of this in the codebase at `kimodo/viz/soma_layer_skin.py`, which uses the identity parameters defined from `kimodo/assets/skeletons/somaskel30/soma_base_fit_mhr_params.npz` (the same ones from BONES-SEED data).\n\nDue to peculiarities with data processing, using the SOMA rig and SOMA layer give very slightly different results in visualization, with the SOMA rig better reflecting the data that Kimodo was trained on.\n\n## Unitree G1\n\nThe G1 skeleton targets MuJoCo-compatible exports and robotics workflows.\nThe version that Kimodo uses is a 34-joint skeleton, with extra joints added for the toes to ease learning. When generated motions are exported to the MuJoCo `qpos` CSV format, these joints are removed to be compatible with downstream applications.\n\n<img src=\"../_static/skeletons/g1.png\" alt=\"G1 skeleton\" width=\"60%\">\n\n## SMPL-X\n\nThis aligns with the SMPL-X model and supports AMASS-style exports. It uses 22 joints corresponding to only the body joints. This option is useful for compatibility with SMPL-X pipelines or downstream tools expecting AMASS parameters, but it is **not** the recommended Kimodo model to use since generated motions may display particularly severe retargeting artifacts.\n\n<img src=\"../_static/skeletons/smplx.png\" alt=\"SMPL-X skeleton\" width=\"60%\">\n"
  },
  {
    "path": "docs/source/project_info.md",
    "content": "# Project Information\n\n## Citation\n\nIf you use this code in your research, please cite:\n\n```bibtex\n@article{Kimodo2026,\n  title={Kimodo: Scaling Controllable Human Motion Generation},\n  author={Rempe, Davis and Petrovich, Mathis and Yuan, Ye and Zhang, Haotian and Peng, Xue Bin and Jiang, Yifeng and Wang, Tingwu and Iqbal, Umar and Minor, David and de Ruyter, Michael and Li, Jiefeng and Tessler, Chen and Lim, Edy and Jeong, Eugene and Wu, Sam and Hassani, Ehsan and Huang, Michael and Yu, Jin-Bey and Chung, Chaeyeon and Song, Lina and Dionne, Olivier and Kautz, Jan and Yuen, Simon and Fidler, Sanja},\n  journal={arXiv:2603.15546},\n  year={2026}\n}\n```\n\n## License\n\nThe codebase is licensed under Apache-2.0. Please see the codebase for full license text. Note that model checkpoints are licensed separately as indicated on the HuggingFace download pages.\n\n## Acknowledgments\n\nThis project builds upon several excellent open-source projects:\n\n- [Viser](https://github.com/nerfstudio-project/viser) for 3D visualization\n- [LLM2Vec](https://github.com/McGill-NLP/llm2vec) for text encoding\n\n## Contact\n\nFor questions or issues, plese open an issue on this repository or reach out directly to the authors.\n"
  },
  {
    "path": "docs/source/project_structure.md",
    "content": "# Project Structure\n\n```text\nkimodo/\n├── kimodo/                       # Main Python package\n│   ├── model/                    # Model architecture and loading\n│   │   ├── kimodo_model.py       # Kimodo diffusion model wrapper\n│   │   ├── twostage_denoiser.py  # Two-stage denoising architecture\n│   │   ├── backbone.py           # Transformer encoder backbone\n│   │   ├── diffusion.py          # Diffusion process\n│   │   ├── cfg.py                # Classifier-free guidance\n│   │   ├── common.py             # Shared model utilities\n│   │   ├── load_model.py         # Model loading and registry lookup\n│   │   ├── loading.py            # Checkpoint loading utilities\n│   │   ├── registry.py           # Model registry (skeleton, checkpoint URLs)\n│   │   ├── text_encoder_api.py   # Text encoder API client\n│   │   ├── tmr.py                # TMR compatibility\n│   │   └── llm2vec/              # LLM-based text encoder\n│   ├── motion_rep/               # Motion representation\n│   │   ├── reps/                 # Skeleton-specific motion reps\n│   │   │   ├── base.py           # Base motion rep types\n│   │   │   ├── kimodo_motionrep.py\n│   │   │   └── tmr_motionrep.py\n│   │   ├── conditioning.py       # Conditioning (text, constraints)\n│   │   ├── feature_utils.py      # Feature extraction\n│   │   ├── feet.py               # Foot contact / smoothing\n│   │   ├── smooth_root.py        # Smooth root representation\n│   │   └── stats.py              # Normalization statistics\n│   ├── skeleton/                 # Skeleton definitions and kinematics\n│   │   ├── definitions.py        # Skeleton topology (joints, chains)\n│   │   ├── registry.py           # Skeleton registry\n│   │   ├── base.py               # Base skeleton types\n│   │   ├── kinematics.py         # Forward kinematics\n│   │   ├── transforms.py         # Rotation/transform utilities\n│   │   └── bvh.py                # BVH I/O\n│   ├── viz/                      # 
Visualization\n│   │   ├── scene.py              # 3D scene setup\n│   │   ├── playback.py           # Timeline / motion playback\n│   │   ├── viser_utils.py        # Viser 3D helpers\n│   │   ├── gui.py                # Demo GUI components\n│   │   ├── constraint_ui.py      # Constraint editing UI\n│   │   ├── coords.py             # Coordinate frames\n│   │   ├── soma_skin.py          # SOMA character skinning\n│   │   ├── soma_layer_skin.py    # SOMA layer-based skinning\n│   │   ├── smplx_skin.py         # SMPL-X skinning\n│   │   └── g1_rig.py             # G1 robot rig\n│   ├── demo/                     # Interactive web demo\n│   │   ├── app.py                # Demo entry (Gradio / Viser)\n│   │   ├── config.py             # Demo configuration\n│   │   ├── state.py              # Application state\n│   │   ├── ui.py                 # UI layout and callbacks\n│   │   ├── generation.py         # Generation pipeline for demo\n│   │   ├── embedding_cache.py    # Cached text embeddings\n│   │   ├── queue_manager.py      # Request queue for demo\n│   │   └── __main__.py           # Demo run as module\n│   ├── exports/                  # Motion I/O and format conversion\n│   │   ├── motion_io.py          # Kimodo motion dict helpers (load, save, resample)\n│   │   ├── motion_convert_lib.py # Library API for format conversion\n│   │   ├── motion_formats.py     # Format detection and FPS resolution\n│   │   ├── bvh.py                # SOMA BVH read/write\n│   │   ├── mujoco.py             # G1 MuJoCo qpos conversion\n│   │   └── smplx.py              # AMASS / SMPL-X conversion\n│   ├── metrics/                  # Evaluation metrics\n│   │   ├── base.py               # Metric base classes\n│   │   ├── foot_skate.py         # Foot skate metrics\n│   │   ├── constraints.py        # Constraint metrics\n│   │   └── tmr.py                # TMR-based metrics\n│   ├── scripts/                  # CLI and helper scripts\n│   │   ├── generate.py           # CLI for motion 
synthesis (kimodo_gen)\n│   │   ├── motion_convert.py     # CLI for format conversion (kimodo_convert)\n│   │   ├── run_text_encoder_server.py  # Text encoder server (kimodo_textencoder)\n│   │   ├── gradio_theme.py       # Gradio theme for demo\n│   │   ├── lock_requirements.py  # Dependency locking\n│   │   └── mujoco_load.py        # MuJoCo g1 csv loading\n│   ├── assets/                   # Package data (shipped with package)\n│   │   ├── demo/                 # Demo examples and config\n│   │   └── skeletons/            # Skeleton assets\n│   ├── constraints.py            # Constraint definitions and handling\n│   ├── geometry.py               # Geometric utilities\n│   ├── postprocess.py            # Post-processing (e.g. MotionCorrection)\n│   ├── meta.py                   # Motion metadata\n│   ├── sanitize.py               # Input sanitization\n│   ├── assets.py                 # Asset path resolution\n│   └── tools.py                  # General utilities\n├── benchmark/                    # Evaluation pipeline scripts\n│   ├── create_benchmark.py       # Step 1: Build test suite from SEED + metadata\n│   ├── generate_eval.py          # Step 2: Generate motions for test suite\n│   ├── embed_folder.py           # Step 3: Embed motions and text with TMR\n│   ├── evaluate_folder.py        # Step 4: Compute metrics for test cases\n│   └── parse_folder.py           # Step 5: Aggregate and display results\n├── MotionCorrection/             # Optional C++/Python post-processing\n│   ├── python/motion_correction/ # Python bindings\n│   └── src/cpp/                  # C++ implementation\n├── docs/                         # Documentation (Sphinx)\n│   └── source/                   # RST/MD sources\n├── assets/                       # Repo-level assets (banner, screenshots)\n├── pyproject.toml                # Package config and entry points\n├── setup.py                      # Setuptools entry (if needed)\n├── Dockerfile                    # Container image for 
demo\n├── docker-compose.yaml           # Docker Compose for demo + text encoder\n└── README.md\n```\n\nEntry points (from `pyproject.toml`):\n\n- **`kimodo_gen`** — command-line motion synthesis (`kimodo.scripts.generate:main`)\n- **`kimodo_demo`** — interactive web demo (`kimodo.demo:main`)\n- **`kimodo_convert`** — motion format conversion (`kimodo.scripts.motion_convert:main`)\n- **`kimodo_textencoder`** — text encoder server (`kimodo.scripts.run_text_encoder_server:main`)\n"
  },
  {
    "path": "docs/source/user_guide/cli.md",
    "content": "# Command-Line Interface\n\nThe primary CLI entrypoint is the `kimodo_gen` command. This corresponds to the script located in `kimodo/scripts/generate.py`, therefore you can equivalently use `python -m kimodo.scripts.generate`.\n\n**Docker Usage**: If you set up Kimodo with Docker, you can instead run generation inside the Docker container, replacing `kimodo_gen XXX` with `docker compose run --rm demo kimodo_gen XXX`. If you will be running generation multiple times, it is better to start the `demo` container (e.g., in another terminal or in the background), and then run commands inside it with `docker compose exec demo kimodo_gen XXX`.\n\n**Single Prompt Generation:**\n\n```bash\nkimodo_gen \"A person walks forward.\" \\\n    --model Kimodo-SOMA-RP-v1 \\\n    --duration 5.0 \\\n    --output output\n```\n\nThe `--model` command corresponds to the Kimodo model checkpoint to generate with. By default, the `Kimodo-SOMA-RP-v1` is used if not provided. The output motion will be saved using the stem name given by `--output` in the Kimodo [output format](../user_guide/output_formats.md). If generating with a G1 or SMPL-X model, you can also save to other output formats like MuJoCo qpos CSV file and AMASS NPZ format.\n\nFor **offline conversion** between Kimodo NPZ, AMASS NPZ, SOMA BVH, and G1 MuJoCo CSV after generating, use `kimodo_convert` (see [Motion format conversion](motion_convert.md)).\n\n**Multi-Prompt Generation:**\n\nGenerating from a sequence of text prompts can be achieved by using multiple sentences separated by periods with corresponding durations:\n\n```bash\nkimodo_gen \"A person walks forward. A person is walking backwards.\" \\\n    --duration \"5.0 4.0\" \\\n```\n\nThis command will use Kimodo to generate each prompt in sequence, with constraints added to the beginning of the second prompt to ensure continuity with the first generated motion. 
You can control how many frames are used to blend consecutive motions with the `--num_transition_frames` argument.\n\n**Single Prompt with Constraints:**\n\nGeneration can be constrained by providing a constraints JSON file (see the [Constraints Format Definition](constraints.md)).\n\n```bash\nkimodo_gen \"A person walks forward and picks something up from the ground.\" \\\n    --model Kimodo-SOMA-RP-v1 \\\n    --duration 5.0 \\\n    --constraints kimodo/assets/demo/examples/kimodo-soma-rp/03_full_body_keyframes/constraints.json\n```\n\nConstraint files can be created and saved from the interactive demo or manually defined following\nthe [constraints format guide](constraints.md).\n\n## Output Formats\n\nFor full details on output formats, see [this page](output_formats.md).\n\nTo convert between these formats offline, see [Motion format conversion](motion_convert.md) (`kimodo_convert`).\n\nCLI generation uses a single **output stem** (`--output`) for all formats (NPZ, AMASS NPZ, CSV, and BVH). It can write either **one file** or **a folder of files**, depending on the number of samples:\n\n- **One sample** (`--num_samples 1`): writes a single file per format at the stem (e.g. `--output test` → `test.npz`, `test.csv`). No folder is created. For SMPLX, AMASS is written to `test_amass.npz`.\n- **Multiple samples**: creates a folder with that stem and writes one file per sample with suffixes `_00`, `_01`, etc. (e.g. `--output test` → `test/test_00.npz`, ...).\n\nUse the `--bvh` flag to also export BVH (SOMA only) to the same stem.\n\n### Output Rest Pose\n\nFor SOMA-based Kimodo models, motions can be exported with respect to two different rest poses. The default rest pose, that is always used by the `NPZ` format, is a standard T-pose consistent with the canonical T-pose of the SOMA model. 
For `BVH` outputs, the default rest pose is a non-standard pose, but it is consistent with the BVH format of the [BONES-SEED dataset](https://huggingface.co/datasets/bones-studio/seed). To output a `BVH` file with the standard T-pose as the rest pose, you can use the `--bvh_standard_tpose` option.\n\nThe standard T-pose used by Kimodo is available as a BVH file in the [repo assets](https://github.com/nv-tlabs/kimodo/tree/main/kimodo/assets/skeletons/somaskel77).\n\n## Visualizing Generated Motions\n\nMotions generated with the CLI can be visualized in the demo UI. To do this, under \"Load/Save\" > \"Motion\", type in the path of the generated output npz file, then click \"Load Motion\" to load it into the viewer. If you used constraints when generating, those can also be loaded in in a similar way.\n\n## Arguments\n\nTo see all available flags, run `kimodo_gen --help`. They are:\n\n- `prompt`: Text description of the desired motion (required)\n- `--model`: Model name to use (default: `Kimodo-SOMA-RP-v1`; options are the models in [this table](../getting_started/quick_start.md#overview-kimodo-models))\n- `--duration`: Motion duration in seconds (default: `5.0`). For multiple prompts,\n  pass space-separated durations in a string.\n- `--diffusion_steps`: Number of denoising steps (default: `100`)\n- `--num_samples`: Number of motion variations to generate (default: `1`)\n- `--num_transition_frames`: Frames used to blend between prompts (default: `5`)\n- `--constraints`: Path to a JSON file containing constraints\n- `--output`: Output stem name (default: `output`). Used for all formats (NPZ, AMASS NPZ, CSV, BVH). With one sample, writes a single file per format (e.g. `test.npz`, `test.csv`). With multiple samples, creates a folder and writes `test_00.npz`, `test_01.npz`, … inside it. 
For SMPLX with one sample, AMASS is written to `stem_amass.npz` so it does not overwrite the main NPZ.\n- `--save_example_dir`: If given, saves outputs to an \"example\" directory structure that can be loaded in the Kimodo demo.\n- `--bvh`: Optional flag. When set, also export BVH (SOMA models only) using the same stem as `--output`.\n- `--bvh_standard_tpose`: If exporting BVH, export with the rest pose being the standard T-pose rather than the rest pose consistent with the BONES-SEED dataset.\n- `--seed`: Seed for reproducible results\n- `--no-postprocess`: Disable post-processing (includes foot skate cleanup and constraint optimization)\n- `--input_folder`: Folder containing meta.json and optional constraints.json. If set, generation settings are loaded from meta.json. These are found in demo example folders.\n- `--cfg_type`: Classifier-free guidance mode: `nocfg`, `regular`, or `separated` (the custom mode with independent text and constraint scales). See {ref}`Classifier-free guidance (details) <classifier-free-guidance-cfg>` below.\n- `--cfg_weight`: One float for `regular` CFG, or two floats `[text_weight, constraint_weight]` for `separated` CFG. If you pass only weights (no `--cfg_type`), one value implies `regular` and two imply `separated`. Not used with `nocfg`.\n\n:::{dropdown} Classifier-free guidance (CFG)\n:name: classifier-free-guidance-cfg\n\nThe CLI mirrors the Python API in [Generation parameters](configuration.md): Kimodo supports standard CFG (`regular`) and a **separated** variant with two scales—text vs. constraints—which is the usual setting in this project.\n\n**Rules:**\n\n- `nocfg`: no weights; do not pass `--cfg_weight`.\n- `regular`: pass exactly one value after `--cfg_weight`.\n- `separated`: pass exactly two values after `--cfg_weight`.\n\nIf you pass **`--cfg_type` or `--cfg_weight` on the command line**, those values override any `cfg` block in `meta.json` when using `--input_folder`. 
If you omit both flags, `meta.json` may still supply CFG via `cfg.enabled`, `cfg.text_weight`, and `cfg.constraint_weight` (same shape as the interactive demo examples). If there is no CLI CFG and no `cfg` in meta, the model uses its built-in defaults.\n\nExamples:\n\n```bash\n# No classifier-free guidance\nkimodo_gen \"A person walks.\" --cfg_type nocfg\n\n# Standard CFG (single scale)\nkimodo_gen \"A person walks.\" --cfg_type regular --cfg_weight 2.5\n\n# Separated CFG (text scale, then constraint scale)\nkimodo_gen \"A person walks.\" --cfg_type separated --cfg_weight 2.0 1.5\n\n# Infer mode from arity: one float -> regular; two floats -> separated\nkimodo_gen \"A person walks.\" --cfg_weight 2.0 2.0\n```\n\n:::\n\n## Python API\nThe `kimodo/scripts/generate.py` script is a good place to start to familiarize yourself with the Python API of Kimodo if you'd like to use this directly. The full model API is detailed in the [API documentation](../api_reference/index.rst).\n\nIf you want to use kimodo in another project, you can interact with it like this:\n\n```python\nfrom kimodo import load_model\n\nmodel = load_model(\"kimodo-soma-rp\", device=\"cuda\")\noutput = model(\n    prompt=\"A person jumps\",\n    num_frames=150,\n    num_denoising_steps=100,\n)\n```\n"
  },
  {
    "path": "docs/source/user_guide/configuration.md",
    "content": "# Generation Parameters\n\nIn the demo UI, command-line tool (`kimodo_gen` / `python -m kimodo.scripts.generate`), and low-level Python API, Kimodo allows some advanced configuration for motion generation.\n\n## Classifier-Free Guidance\n\nControl the strength of text and constraint guidance:\n\n```python\noutput = model(\n    prompt=\"A person jumps\",\n    num_frames=150,\n    cfg_weight=[2.0, 2.0],  # [text_weight, constraint_weight]\n    cfg_type=\"separated\",  # Options: \"nocfg\", \"regular\", \"separated\"\n    num_denoising_steps=100,\n)\n```\n\nThese are helpful when there is a tradeoff between following the prompt and hitting constraints.\n\nThe CFG options are:\n- `cfg_type=\"nocfg\"`: No guidance (faster, less controllable)\n- `cfg_type=\"regular\"`: \"Standard\" classifier-free guidance\n    - Equation: `out_uncond + w * (out_text_and_constraint - out_uncond)`\n- `cfg_type=\"separated\"`: Separate weights for text and constraints\n    - Equation: `out_uncond + w_text * (out_text - out_uncond) + w_constraint * (out_constraint - out_uncond)`\n\n### CLI\n\nThe same options are available from the command line as `--cfg_type` and `--cfg_weight`. See the {ref}`CLI user guide (CFG) <classifier-free-guidance-cfg>` for examples, validation rules, and how `meta.json` interacts with explicit flags when using `--input_folder`.\n\n## Denoising Steps\nThe number of denoising steps used in DDIM sampling can be used to control the speed vs. quality trade-off:\n- Fewer steps (50-100): Faster inference, slightly lower quality\n- More steps (100-200): Higher quality, slower inference\n"
  },
  {
    "path": "docs/source/user_guide/constraints.md",
    "content": "# Constraints JSON Format\n\nThe `--constraints` flag in the CLI expects a JSON file containing a list of constraint objects.\nIt is easiest to look at the examples provided with the demo to see how these are formatted. These can be seen for various model types in `kimodo/assets/demo/examples`.\n\n> Tip: the easiest way to get a valid constraints file is to create constraints in the interactive demo and to click on `Save Constraints`.\n\n## High-Level Structure\n\n- The file is a JSON array: `[{...}, {...}, ...]`\n- Each element is an object with at least:\n  - `type` (string)\n    - `root2d`, `fullbody`, `left-hand`, `right-hand`, `left-foot`, `right-foot`, `end-effector`\n  - `frame_indices` (array of integers): 0-based frame indices within the generated clip.\n\n\n```{note}\nFor SOMA models, constraints may be authored or displayed on the full `somaskel77` skeleton, but Kimodo converts them to the reduced `somaskel30` representation before passing them to the model. See the [skeleton](../key_concepts/skeleton.md) section for more details.\n```\n\n## Coordinate Space and Units\n\nAll spatial values in constraints use the same coordinate system as Kimodo's internal motion representation:\n\n- **Axes**: **Y-up**, with locomotion on the **XZ ground plane**. The Y axis points up, X and Z span the horizontal ground plane.\n- **Units**: **Meters**. 
Joint positions, root translations, and 2D root coordinates are all in meters.\n\n### Canonicalization\n\nDuring training, every motion is *canonicalized* so that the (smoothed) root starts at the XZ origin `(0, 0)` at frame 0.\nThe initial body heading (facing direction) is randomly rotated and passed to the model as an explicit input (`first_heading_angle`), so the model is robust to arbitrary initial orientations.\n\nAt inference, constraints should be authored **relative to this canonical origin**:\n- `smooth_root_2d` values at frame 0 should be at `(0, 0)`, with subsequent frames expressing displacement from there.\n- `root_positions` XZ components follow the same convention; Y is the **absolute hip height above the ground** (typically ~0.9 m for a standing pose, lower for crouching/sitting).\n- `first_heading_angle` (a generation parameter, not part of the constraints JSON) defaults to `0.0` radians (facing +Z) but can be set to any value to change the initial facing direction.\n\n### Field-specific notes\n\n| Field | Space | Notes |\n|-------|-------|-------|\n| `smooth_root_2d` | `[x, z]` ground plane (meters) | Relative to the canonical origin. |\n| `root_positions` | `[x, y, z]` (meters) | Y is absolute hip height above ground. XZ relative to canonical origin. |\n| `global_root_heading` | `[cos(θ), sin(θ)]` | **Not** a raw radian value — must be a 2-element cosine/sine pair per frame (i.e. the heading direction vector). |\n| `local_joints_rot` | axis-angle (radians) | Local joint rotations in the skeleton's rest-pose frame. |\n\n### Constraints not at frame 0\n\nAdding a constraint at frame 0 is **not** required. If the first constrained frame is later in the sequence (e.g. frame 45), Kimodo generates the initial frames freely from its learned distribution, starting near XZ = (0, 0) with the heading set by `first_heading_angle`. 
The constraint just needs to be reachable from that starting configuration given the text prompt and motion duration.\n\n## Constraint Types\nDepending on `type`, additional fields are required or optional. All numeric arrays are plain nested JSON lists. In the following definitions `T` is the number of constrained frames (i.e., number of `frame_indices`) and `J` is the number of skeleton joints.\n\n\n### `root2d`\nThis captures 2D root waypoints and 2D root paths. It requires:\n\n- `smooth_root_2d` (array shaped `[T, 2]`): Smoothed root positions `[x, z]` on the ground plane at the given `frame_indices`.\n\nand optionally:\n- `global_root_heading` (array shaped `[T, 2]`): Global root heading direction `[cos, sin]` at the given `frame_indices`.\n\n### `fullbody`\nThis captures full-body keyframe constraints on joint positions. It includes:\n\n- `local_joints_rot` (array shaped `[T, J, 3]`): Per-frame per-joint **axis-angle** local rotations (radians). Constraint joint positions will be derived from these.\n- `root_positions` (array shaped `[T, 3]`): Root (hips) translation `[x, y, z]`.\n- `smooth_root_2d` (optional; array of `[T, 2]`): Smoothed root positions `[x, z]`. If omitted, it is taken as the `[x, z]` components of `root_positions`.\n\nNote that `local_joints_rot` will not be constrained explicitly; the constraint is on the joint positions that result from forward kinematics (FK) with the given joint rotations.\n\n### `left-hand` / `right-hand` / `left-foot` / `right-foot`\nCaptures end-effector constraints on the hand/foot joint positions and global rotations.\n\nThese use the same fields as `fullbody`. However, under the hood these will only affect the corresponding end-effectors and hips. Each of these types is a shorthand for `end-effector` with pre-set joint names.\n\n### `end-effector`\nA general end-effector constraint that requires an additional field:\n\n- `joint_names` (array of strings): Which end-effectors to constrain (e.g. 
`[\"left_hand\"]`, `[\"right_foot\", \"left_foot\"]`). Available names depend on the skeleton; see the skeleton's `expand_joint_names()` for the full mapping.\n\nOtherwise uses the same fields as `fullbody` (`local_joints_rot`, `root_positions`, optional `smooth_root_2d`).\n\n## Examples\n\n### Root 2D waypoints\n\n```json\n[\n  {\n    \"type\": \"root2d\",\n    \"frame_indices\": [0, 30, 60],\n    \"smooth_root_2d\": [[0.0, 0.0], [0.5, 0.0], [1.0, 0.1]]\n  }\n]\n```\n\n### Full-body keyframe\n\n```json\n[\n  {\n    \"type\": \"fullbody\",\n    \"frame_indices\": [60],\n    \"root_positions\": [[0.0, 0.96, 1.5]],\n    \"local_joints_rot\": [[[0.0, 0.0, 0.0], \"... one [3] per joint ...\"]]\n  }\n]\n```\n\nHere `root_positions` places the hips at x=0, y=0.96 m (standing height), z=1.5 m forward from the origin. `local_joints_rot` is a `[T, J, 3]` array of axis-angle rotations for every joint in the skeleton.\n"
  },
  {
    "path": "docs/source/user_guide/motion_convert.md",
    "content": "# Motion Format Conversion\n\nThe `kimodo_convert` command converts between the formats described in [Output formats](output_formats.md): **Kimodo NPZ**, **AMASS NPZ** (SMPL-X), **SOMA BVH**, and **G1 MuJoCo CSV**.\n\n## Frame rate (30 Hz Kimodo NPZ)\n\nAny conversion **to Kimodo NPZ** (from AMASS, SOMA BVH, or G1 CSV) **writes motion at 30 Hz**, matching Kimodo’s common generation rate. If the detected source rate differs, the tool **resamples** along time, then derived channels (contacts, smooth root, heading) are recomputed via forward kinematics.\n\nIf resampling is required, a **warning** is emitted with the assumed source rate, input/output frame counts, and a reminder that `--source-fps` sets the **source** rate if autodetection is wrong. When the source is already ~30 Hz with the same frame count, no warning is shown (motion is only re-derived via FK for consistency).\n\n<details>\n<summary>Resampling strategy details</summary>\n\nThe resampler picks one of two strategies based on the ratio `source_fps / target_fps`:\n\n- **Integer-ratio fast path** — When the ratio is close to an integer ≥ 2 (within a tolerance of 0.05), the resampler simply takes every *step*-th frame (`frames[::step]`). For example, 120 Hz → 30 Hz has ratio 4, so every 4th frame is kept. This is exact and very fast.\n- **Interpolation fallback** — Otherwise, the output timeline is linearly spaced over the input range. Root positions are linearly interpolated, and local joint rotations are interpolated via quaternion slerp. This handles arbitrary rate conversions (e.g. 50 Hz → 30 Hz).\n\nIn both cases, `complete_motion_dict` is re-run at the target rate so that all derived channels (velocities, foot contacts, heading, smooth root) stay consistent with the new frame spacing.\n\n</details>\n\n## Usage\n\n```bash\nkimodo_convert INPUT OUTPUT [options]\n```\n\nFormats are inferred from file extensions and (for `.npz`) from file contents. 
You can override with `--from` and `--to`.\n\n### Supported conversions\n\n| From | To | Notes |\n|------|-----|--------|\n| AMASS `.npz` | Kimodo `.npz` | SMPL-X, 22 joints. Uses `--z-up` by default (same as Kimodo’s AMASS export). |\n| Kimodo `.npz` | AMASS `.npz` | Requires `local_rot_mats` with 22 joints (SMPL-X). |\n| SOMA `.bvh` | Kimodo `.npz` | Expects a **Kimodo-exported** SOMA BVH (same hierarchy as `save_motion_bvh`). If the BVH uses the standard T-pose as rest pose, pass in `--bvh_standard_tpose`. |\n| Kimodo `.npz` | SOMA `.bvh` | Accepts 77 joints (SOMA full) or 30 joints (somaskel30, auto-expanded to 77 with relaxed-hand rest poses). If you want the output BVH to use the standard T-pose as rest pose, pass in `--bvh_standard_tpose`. |\n| G1 `.csv` | Kimodo `.npz` | Rows of shape `(36,)` = root xyz + root quat + 29 joint angles (see [output_formats](output_formats.md#csv-format-for-kimodo-g1)). |\n| Kimodo `.npz` | G1 `.csv` | Requires 34 joints (G1). |\n\n### Common options\n\n- **`--source-fps`**: Source motion frame rate in Hz (used before resampling to 30 Hz for Kimodo NPZ). If omitted, the tool auto-detects from `mocap_frame_rate` (AMASS), `Frame Time` (BVH), or defaults to **30** Hz. 
The legacy `--fps` alias is still accepted for backward compatibility.\n- **`--no-z-up`**: For AMASS, disable the Y-up ↔ Z-up transform (treat data as already in Kimodo Y-up, +Z forward).\n- **`--mujoco-rest-zero`**: For G1 CSV, match the `mujoco_rest_zero` flag used when the CSV was written (see `MujocoQposConverter.dict_to_qpos`).\n- **`--bvh_standard_tpose`**: If input or output is BVH: the BVH file uses the standard T-pose as its rest pose instead of the BONES-SEED rest pose.\n\n### Examples\n\n```bash\n# AMASS → Kimodo NPZ\nkimodo_convert motion_amass.npz motion_kimodo.npz\n\n# Kimodo NPZ → AMASS\nkimodo_convert motion_kimodo.npz motion_out_amass.npz\n\n# Kimodo SOMA NPZ → BVH\nkimodo_convert motion_kimodo.npz motion.bvh\n\n# BVH → Kimodo NPZ\nkimodo_convert motion.bvh motion_kimodo.npz\n\n# G1 CSV → Kimodo NPZ\nkimodo_convert motion.csv motion_kimodo.npz\n\n# Kimodo G1 NPZ → CSV\nkimodo_convert motion_kimodo.npz motion.csv\n```\n\nWhen both input and output are `.npz`, the tool assumes **AMASS → Kimodo** if the input is AMASS, and **Kimodo → AMASS** if the input is already a Kimodo NPZ. Use `--from` / `--to` if you need to disambiguate.\n\n## Limitations\n\n- **BVH import** is intended for BVHs produced by Kimodo (`Root` wrapper + SOMA77 joint names) and is also compatible with the BONES-SEED dataset, which uses the same skeleton hierarchy. Arbitrary BVH files with different joint names or hierarchies may not work.\n- **G1 CSV** encodes only the degrees of freedom exposed in MuJoCo; the inverse path reconstructs local rotations from those angles (same convention as `to_qpos`).\n"
  },
  {
    "path": "docs/source/user_guide/output_formats.md",
    "content": "# Output Formats\n\n## Converting Between Formats\n\nTo convert between the formats described below, see [Motion format conversion](motion_convert.md) (`kimodo_convert`).\n\n## Kimodo NPZ Format\n\nGenerated motions are stored as NPZ files (one file per sample, e.g. `motion_00.npz`) containing:\n\n- `posed_joints`: Global joint positions `[T, J, 3]`\n- `global_rot_mats`: Global joint rotation matrices `[T, J, 3, 3]`\n- `local_rot_mats`: Local (parent-relative) joint rotation matrices `[T, J, 3, 3]`\n- `foot_contacts`: Foot contact labels [left heel, left toe, right heel, right toes] `[T, 4]`\n- `smooth_root_pos`: Smoothed root representations outputted from the model `[T, 3]`\n- `root_positions`: The (non-smoothed) trajectory of the actual root joint (e.g., pelvis) `[T, 3]`\n- `global_root_heading`: The heading direction output from the model `[T, 2]`\n\nWhere:\n\n- `T`: number of frames\n- `J`: number of joints in the exported skeleton representation (`77` for SOMA NPZ exports, `34` for G1, `22` for SMPL-X)\n\nIf multiple samples are generated, files are saved with suffixes like `_00`, `_01`, etc.\n\nFor SOMA models, the exported NPZ uses the full **`somaskel77`** skeleton even though the model itself operates internally on the reduced **`somaskel30`** skeleton. This means the saved `posed_joints`, `global_rot_mats`, and `local_rot_mats` arrays are written in the 77-joint SOMA layout. Older 30-joint SOMA NPZ files may still exist and remain loadable for backward compatibility.\n\nAlso for SOMA models, the output motion is saved such that the rest pose (i.e. zero pose) is the standard T-pose that Kimodo uses internally. This differs from the default behavior of BVH export (see below), which uses a rest pose consistent with the BONES-SEED dataset format. 
The standard T-pose as a BVH file is also available [in the assets of the repo](https://github.com/nv-tlabs/kimodo/tree/main/kimodo/assets/skeletons/somaskel77).\n\n## BVH Format for Kimodo-SOMA\n\nWhen using a SOMA model and passing the `--bvh` flag to CLI generation, Kimodo also writes a BVH file alongside the NPZ output.\n\n- BVH export is supported for **SOMA models only**\n- the exported hierarchy uses the full **`somaskel77`** skeleton\n- if the motion is still in internal `somaskel30` form, Kimodo converts it to `somaskel77` before writing the BVH\n- the file stores root translation plus per-joint local rotations for the clip at the generated frame rate\n- by default, the rest pose (i.e., zero pose) of the saved BVH file is consistent with the BONES-SEED dataset format. If you prefer a standard T-pose as the rest pose, pass in `--bvh_standard_tpose` when generating.\n\nThe exporter writes a standard plain-text BVH file and scales joint offsets and root motion from meters to centimeters (same format as the SEED dataset release). If multiple samples are generated, files are saved with suffixes like `_00`, `_01`, etc.\n\n## CSV Format for Kimodo-G1\n\nWhen using `Kimodo-G1` models and providing `--output` to CLI generation, the exporter writes MuJoCo `qpos`\ndata to a CSV file. Each row corresponds to a pose in the motion and contains 36 values:\n\n- Root translation `[x, y, z]`\n- Root rotation quaternion `[w, x, y, z]`\n- 29 joint 1-DoF values (in G1 joint order)\n\nThe CSV uses the MuJoCo coordinate system (z-up, +x forward). If multiple samples are generated, files are saved with suffixes like `_00`, `_01`, etc.\n\n\n## AMASS NPZ Format for Kimodo-SMPLX\n\nWhen using the `Kimodo-SMPLX-RP` model and `--output` is specified to CLI generation, the exporter writes an\nAMASS-style SMPL-X `.npz` file. 
Keys include:\n\n- `trans`: Root translation `[T, 3]`\n- `root_orient`: Root orientation axis-angle `[T, 3]`\n- `pose_body`: Body pose axis-angle `[T, 63]` (21 joints x 3)\n- `pose_hand`: Hand pose axis-angle `[T, 90]` (15 joints x 2 hands x 3)\n- `pose_jaw`: Jaw pose axis-angle `[T, 3]`\n- `pose_eye`: Eye pose axis-angle `[T, 6]`\n- `betas`: Shape coefficients\n- `num_betas`: Number of shape coefficients\n- `gender`: `neutral`\n- `surface_model_type`: `smplx`\n- `mocap_frame_rate`: Frame rate (fps)\n- `mocap_time_length`: Motion duration in seconds\n\nThe exporter converts from the Kimodo coordinate system (y-up, +z forward)\nto AMASS coordinates (z-up, +y forward). If multiple samples are generated, files are saved with suffixes like `_00`, `_01`, etc.\n"
  },
  {
    "path": "docs/source/user_guide/seed_dataset.md",
    "content": "# Loading BONES-SEED BVH data\n\nThe [BONES-SEED dataset](https://huggingface.co/datasets/bones-studio/seed) is a publicly available optical motion-capture dataset distributed as BVH files with the [SOMA 77-joint skeleton](../key_concepts/skeleton.md). This page walks through the steps to parse a SEED BVH file and convert it into Kimodo's internal motion representation.\n\nThis is a similar pipeline used by the benchmark to extract ground-truth motions from SEED data (see the [benchmark pipeline](../benchmark/pipeline.md)).\n\n## Step-by-Step Conversion\n\n### 1. Parse the BVH file\n\n`parse_bvh_motion` reads a BVH file and returns local joint rotation matrices, root translation (in meters), and the source frame rate.\n\n```python\nfrom kimodo.skeleton.bvh import parse_bvh_motion\n\nlocal_rot_mats, root_trans, bvh_fps = parse_bvh_motion(bvh_path)\n```\n\n### 2. Subsample to 30 FPS\n\nKimodo operates at 30 Hz. If the source BVH has a different frame rate (120 FPS for BONES-SEED), subsample by striding:\n\n```python\nfps = 30\nstep = round(bvh_fps / fps)\nroot_trans = root_trans[::step]\nlocal_rot_mats = local_rot_mats[::step]\n```\n\n### 3. Convert to the standard T-pose\n\nThe SEED BVH rest pose differs from Kimodo's canonical T-pose. The `to_standard_tpose` function remaps the local rotations accordingly and returns both local and global rotation matrices:\n\n```python\nfrom kimodo.skeleton import SOMASkeleton77\n\nskeleton = SOMASkeleton77()\nlocal_rot_mats, global_rot_mats = skeleton.to_standard_tpose(local_rot_mats)\n```\n\n### 4. Compute Kimodo motion features\n\nBuild the motion feature tensor used by the model. The feature layout is described in [Motion representation](../key_concepts/motion_representation.md).\n\n```python\nfrom kimodo.motion_rep import KimodoMotionRep\n\nmotion_rep = KimodoMotionRep(skeleton, fps)\nfeats = motion_rep(local_rot_mats, root_trans, to_normalize=False)\n```\n\n### 5. 
Canonicalize (optionally) and recover the motion dictionary\n\nCanonicalize so that the motion starts at the origin facing +Z, then invert the features back into a full motion dictionary:\n\n```python\ncan_feats = motion_rep.canonicalize(feats)\nmotion_dict = motion_rep.inverse(can_feats, is_normalized=False)\n```\n\n`motion_dict` is a dictionary with keys such as `local_rot_mats`, `global_rot_mats`, `posed_joints`, `root_positions`, `smooth_root_pos`, `foot_contacts`, etc. See [Output formats](output_formats.md) for details on the Kimodo NPZ layout.\n\n## Full script\n\n```python\nfrom kimodo.motion_rep import KimodoMotionRep\nfrom kimodo.skeleton import SOMASkeleton77\nfrom kimodo.skeleton.bvh import parse_bvh_motion\n\n# 1. Parse BVH\nlocal_rot_mats, root_trans, bvh_fps = parse_bvh_motion(bvh_path)\n\n# 2. Subsample to 30 fps\nfps = 30\nstep = round(bvh_fps / fps)\nroot_trans = root_trans[::step]\nlocal_rot_mats = local_rot_mats[::step]\n\n# 3. Convert to standard T-pose\nskeleton = SOMASkeleton77()\nlocal_rot_mats, global_rot_mats = skeleton.to_standard_tpose(local_rot_mats)\n\n# 4. Compute motion features\nmotion_rep = KimodoMotionRep(skeleton, fps)\nfeats = motion_rep(local_rot_mats, root_trans, to_normalize=False)\n\n# 5. Canonicalize and get the full motion dictionary\ncan_feats = motion_rep.canonicalize(feats)\nmotion_dict = motion_rep.inverse(can_feats, is_normalized=False)\n```\n"
  },
  {
    "path": "kimodo/__init__.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Kimodo: text-driven and constrained motion generation model.\"\"\"\n\nfrom .model.load_model import AVAILABLE_MODELS, DEFAULT_MODEL, load_model\n\n__all__ = [\n    \"AVAILABLE_MODELS\",\n    \"DEFAULT_MODEL\",\n    \"load_model\",\n]\n"
  },
  {
    "path": "kimodo/assets/demo/examples/kimodo-g1-rp/01_single_text_prompt/meta.json",
    "content": "{\n  \"text\": \"A person walking forward quickly stumbles but maintains their balance\",\n  \"duration\": 5.0,\n  \"num_samples\": 1,\n  \"seed\": 43,\n  \"diffusion_steps\": 100,\n  \"cfg\": {\n    \"enabled\": true,\n    \"text_weight\": 2.0,\n    \"constraint_weight\": 2.0\n  }\n}\n"
  },
  {
    "path": "kimodo/assets/demo/examples/kimodo-g1-rp/02_multi_text_ee_constraint/constraints.json",
    "content": "[\n  {\n    \"type\": \"left-hand\",\n    \"frame_indices\": [\n      40,\n      155\n    ],\n    \"local_joints_rot\": [\n      [\n        [\n          -0.20672118663787842,\n          0.0004979433142580092,\n          0.010066316463053226\n        ],\n        [\n          0.0789145976305008,\n          0.0008333905134350061,\n          -5.267082087812014e-05\n        ],\n        [\n          -0.1686924546957016,\n          -0.0027884345036000013,\n          0.0520743690431118\n        ],\n        [\n          0.000989485066384077,\n          0.1385614573955536,\n          0.0005803265958093107\n        ],\n        [\n          1.0274103879928589,\n          -0.0004089517460670322,\n          0.0007986496202647686\n        ],\n        [\n          -0.39034226536750793,\n          -0.001306047779507935,\n          -4.922552761854604e-05\n        ],\n        [\n          0.0023066187277436256,\n          -0.0007853881106711924,\n          -0.0062883589416742325\n        ],\n        [\n          4.49517356173601e-05,\n          0.0033443598076701164,\n          -0.0014551420463249087\n        ],\n        [\n          0.07268467545509338,\n          -0.0011258760932832956,\n          -3.953919076593593e-05\n        ],\n        [\n          -0.1719113141298294,\n          0.018712127581238747,\n          0.06082615628838539\n        ],\n        [\n          0.0011432868195697665,\n          0.02744375728070736,\n          0.0025501118507236242\n        ],\n        [\n          0.41685307025909424,\n          -0.002692570211365819,\n          -0.0006283970433287323\n        ],\n        [\n          -0.1283608227968216,\n          0.0030534265097230673,\n          0.00016949126438703388\n        ],\n        [\n          -0.005590266548097134,\n          0.0014076301595196128,\n          -0.038615260273218155\n        ],\n        [\n          -0.00013014793512411416,\n          0.001360177993774414,\n          6.41088408883661e-05\n        ],\n        [\n  
        0.00010043015936389565,\n          -0.01370090153068304,\n          -0.00014910128084011376\n        ],\n        [\n          0.00023336269077844918,\n          0.0025421029422432184,\n          0.04833226650953293\n        ],\n        [\n          0.056574925780296326,\n          0.0006874562823213637,\n          0.0004548647266346961\n        ],\n        [\n          -0.37481847405433655,\n          -0.054357241839170456,\n          0.2803272306919098\n        ],\n        [\n          0.0013725318713113666,\n          0.009074348025023937,\n          -0.0021504403557628393\n        ],\n        [\n          -0.0012184121878817677,\n          -0.4267229437828064,\n          0.011203057132661343\n        ],\n        [\n          1.255251407623291,\n          0.0009449978824704885,\n          0.0010158077348023653\n        ],\n        [\n          -0.003570390399545431,\n          -0.003947308287024498,\n          -0.5030224323272705\n        ],\n        [\n          0.1881941556930542,\n          -0.000495748536195606,\n          0.0016725400928407907\n        ],\n        [\n          -0.002223622752353549,\n          0.11821465194225311,\n          0.007546884939074516\n        ],\n        [\n          -0.00137770373839885,\n          -0.0031452146358788013,\n          -0.0015015294775366783\n        ],\n        [\n          -0.3751647472381592,\n          0.05314668267965317,\n          -0.28086331486701965\n        ],\n        [\n          -0.007756246719509363,\n          -0.016310883685946465,\n          -0.02847120724618435\n        ],\n        [\n          -0.0002517815155442804,\n          0.427451491355896,\n          3.640262002591044e-05\n        ],\n        [\n          1.2455408573150635,\n          -0.0014789876295253634,\n          0.0008519256953150034\n        ],\n        [\n          0.004311776254326105,\n          0.009671058505773544,\n          0.5968337655067444\n        ],\n        [\n          0.1335560381412506,\n          
0.0011528844479471445,\n          -0.0008361327927559614\n        ],\n        [\n          0.001167859067209065,\n          -0.1551152616739273,\n          0.00019725598394870758\n        ],\n        [\n          -0.0014258474111557007,\n          0.0034801543224602938,\n          0.0009809854673221707\n        ]\n      ],\n      [\n        [\n          -0.047659896314144135,\n          -0.11130385845899582,\n          -0.0020901868119835854\n        ],\n        [\n          -1.5705475807189941,\n          -0.0014125468442216516,\n          -0.0008221857133321464\n        ],\n        [\n          -0.16147980093955994,\n          0.014729475602507591,\n          0.4458121657371521\n        ],\n        [\n          -0.00045561062870547175,\n          -0.1160486489534378,\n          -0.006125911604613066\n        ],\n        [\n          2.811251401901245,\n          0.0016747766640037298,\n          -0.005349006038159132\n        ],\n        [\n          -0.8591147065162659,\n          0.0037903853226453066,\n          0.00048354381578974426\n        ],\n        [\n          0.006445891689509153,\n          -0.0036706889513880014,\n          -0.03472399711608887\n        ],\n        [\n          -0.001481462037190795,\n          0.0015367366140708327,\n          -0.0015593112912029028\n        ],\n        [\n          -1.5751848220825195,\n          0.001112997648306191,\n          0.0009848373010754585\n        ],\n        [\n          -0.16862420737743378,\n          -0.016877643764019012,\n          -0.26229384541511536\n        ],\n        [\n          -9.055795817403123e-05,\n          0.09453120082616806,\n          -0.0134742371737957\n        ],\n        [\n          2.811314344406128,\n          0.003919574897736311,\n          0.005575981922447681\n        ],\n        [\n          -0.8299098014831543,\n          -0.003791244002059102,\n          0.0012802339624613523\n        ],\n        [\n          0.005852710455656052,\n          0.005849692039191723,\n  
        0.1632416546344757\n        ],\n        [\n          -0.0015579514438286424,\n          9.288851288147271e-05,\n          0.001196552417241037\n        ],\n        [\n          0.00043879495933651924,\n          0.04429133981466293,\n          0.0002551022043917328\n        ],\n        [\n          -0.0019886596128344536,\n          0.008745947852730751,\n          -0.00962099153548479\n        ],\n        [\n          0.5197923183441162,\n          -0.0010678194230422378,\n          0.0002590256044641137\n        ],\n        [\n          -0.9051622152328491,\n          -0.12138096988201141,\n          0.25749173760414124\n        ],\n        [\n          0.010689850896596909,\n          -0.01072163600474596,\n          0.20382197201251984\n        ],\n        [\n          -0.0009684870601631701,\n          -0.5894762873649597,\n          0.0032688004430383444\n        ],\n        [\n          1.30536949634552,\n          -0.002206705743446946,\n          -0.0020471925381571054\n        ],\n        [\n          0.0067055909894406796,\n          -0.015674468129873276,\n          -0.9086763262748718\n        ],\n        [\n          -0.26612186431884766,\n          -0.00016191616305150092,\n          0.002851327648386359\n        ],\n        [\n          0.003539646975696087,\n          0.20451955497264862,\n          -0.02575569413602352\n        ],\n        [\n          0.003367731347680092,\n          0.0018452388467267156,\n          -0.00026573429931886494\n        ],\n        [\n          -0.9464634656906128,\n          0.12737642228603363,\n          -0.2577688992023468\n        ],\n        [\n          0.00046661958913318813,\n          -0.008693858049809933,\n          -0.19606870412826538\n        ],\n        [\n          -0.0058177076280117035,\n          0.6349377036094666,\n          -0.0003108184027951211\n        ],\n        [\n          1.4694209098815918,\n          0.0046353572979569435,\n          0.002392316237092018\n        ],\n        
[\n          0.022281549870967865,\n          0.006433307193219662,\n          1.1441218852996826\n        ],\n        [\n          -0.16217999160289764,\n          -0.0005673008854500949,\n          -0.0028868752997368574\n        ],\n        [\n          0.0011142585426568985,\n          0.036793302744627,\n          0.06873425096273422\n        ],\n        [\n          0.001964340452104807,\n          -0.004202086944133043,\n          0.0034294212237000465\n        ]\n      ]\n    ],\n    \"root_positions\": [\n      [\n        0.014979152008891106,\n        0.7896444201469421,\n        0.8725281357765198\n      ],\n      [\n        0.12546521425247192,\n        0.30551770329475403,\n        2.3331315517425537\n      ]\n    ],\n    \"smooth_root_2d\": [\n      [\n        0.014979152008891106,\n        0.8725281357765198\n      ],\n      [\n        0.12546521425247192,\n        2.3331315517425537\n      ]\n    ]\n  },\n  {\n    \"type\": \"right-hand\",\n    \"frame_indices\": [\n      40,\n      155\n    ],\n    \"local_joints_rot\": [\n      [\n        [\n          -0.20672118663787842,\n          0.0004979433142580092,\n          0.010066316463053226\n        ],\n        [\n          0.0789145976305008,\n          0.0008333905134350061,\n          -5.267082087812014e-05\n        ],\n        [\n          -0.1686924546957016,\n          -0.0027884345036000013,\n          0.0520743690431118\n        ],\n        [\n          0.000989485066384077,\n          0.1385614573955536,\n          0.0005803265958093107\n        ],\n        [\n          1.0274103879928589,\n          -0.0004089517460670322,\n          0.0007986496202647686\n        ],\n        [\n          -0.39034226536750793,\n          -0.001306047779507935,\n          -4.922552761854604e-05\n        ],\n        [\n          0.0023066187277436256,\n          -0.0007853881106711924,\n          -0.0062883589416742325\n        ],\n        [\n          4.49517356173601e-05,\n          0.0033443598076701164,\n 
         -0.0014551420463249087\n        ],\n        [\n          0.07268467545509338,\n          -0.0011258760932832956,\n          -3.953919076593593e-05\n        ],\n        [\n          -0.1719113141298294,\n          0.018712127581238747,\n          0.06082615628838539\n        ],\n        [\n          0.0011432868195697665,\n          0.02744375728070736,\n          0.0025501118507236242\n        ],\n        [\n          0.41685307025909424,\n          -0.002692570211365819,\n          -0.0006283970433287323\n        ],\n        [\n          -0.1283608227968216,\n          0.0030534265097230673,\n          0.00016949126438703388\n        ],\n        [\n          -0.005590266548097134,\n          0.0014076301595196128,\n          -0.038615260273218155\n        ],\n        [\n          -0.00013014793512411416,\n          0.001360177993774414,\n          6.41088408883661e-05\n        ],\n        [\n          0.00010043015936389565,\n          -0.01370090153068304,\n          -0.00014910128084011376\n        ],\n        [\n          0.00023336269077844918,\n          0.0025421029422432184,\n          0.04833226650953293\n        ],\n        [\n          0.056574925780296326,\n          0.0006874562823213637,\n          0.0004548647266346961\n        ],\n        [\n          -0.37481847405433655,\n          -0.054357241839170456,\n          0.2803272306919098\n        ],\n        [\n          0.0013725318713113666,\n          0.009074348025023937,\n          -0.0021504403557628393\n        ],\n        [\n          -0.0012184121878817677,\n          -0.4267229437828064,\n          0.011203057132661343\n        ],\n        [\n          1.255251407623291,\n          0.0009449978824704885,\n          0.0010158077348023653\n        ],\n        [\n          -0.003570390399545431,\n          -0.003947308287024498,\n          -0.5030224323272705\n        ],\n        [\n          0.1881941556930542,\n          -0.000495748536195606,\n          0.0016725400928407907\n       
 ],\n        [\n          -0.002223622752353549,\n          0.11821465194225311,\n          0.007546884939074516\n        ],\n        [\n          -0.00137770373839885,\n          -0.0031452146358788013,\n          -0.0015015294775366783\n        ],\n        [\n          -0.3751647472381592,\n          0.05314668267965317,\n          -0.28086331486701965\n        ],\n        [\n          -0.007756246719509363,\n          -0.016310883685946465,\n          -0.02847120724618435\n        ],\n        [\n          -0.0002517815155442804,\n          0.427451491355896,\n          3.640262002591044e-05\n        ],\n        [\n          1.2455408573150635,\n          -0.0014789876295253634,\n          0.0008519256953150034\n        ],\n        [\n          0.004311776254326105,\n          0.009671058505773544,\n          0.5968337655067444\n        ],\n        [\n          0.1335560381412506,\n          0.0011528844479471445,\n          -0.0008361327927559614\n        ],\n        [\n          0.001167859067209065,\n          -0.1551152616739273,\n          0.00019725598394870758\n        ],\n        [\n          -0.0014258474111557007,\n          0.0034801543224602938,\n          0.0009809854673221707\n        ]\n      ],\n      [\n        [\n          -0.047659896314144135,\n          -0.11130385845899582,\n          -0.0020901868119835854\n        ],\n        [\n          -1.5705475807189941,\n          -0.0014125468442216516,\n          -0.0008221857133321464\n        ],\n        [\n          -0.16147980093955994,\n          0.014729475602507591,\n          0.4458121657371521\n        ],\n        [\n          -0.00045561062870547175,\n          -0.1160486489534378,\n          -0.006125911604613066\n        ],\n        [\n          2.811251401901245,\n          0.0016747766640037298,\n          -0.005349006038159132\n        ],\n        [\n          -0.8591147065162659,\n          0.0037903853226453066,\n          0.00048354381578974426\n        ],\n        [\n          
0.006445891689509153,\n          -0.0036706889513880014,\n          -0.03472399711608887\n        ],\n        [\n          -0.001481462037190795,\n          0.0015367366140708327,\n          -0.0015593112912029028\n        ],\n        [\n          -1.5751848220825195,\n          0.001112997648306191,\n          0.0009848373010754585\n        ],\n        [\n          -0.16862420737743378,\n          -0.016877643764019012,\n          -0.26229384541511536\n        ],\n        [\n          -9.055795817403123e-05,\n          0.09453120082616806,\n          -0.0134742371737957\n        ],\n        [\n          2.811314344406128,\n          0.003919574897736311,\n          0.005575981922447681\n        ],\n        [\n          -0.8299098014831543,\n          -0.003791244002059102,\n          0.0012802339624613523\n        ],\n        [\n          0.005852710455656052,\n          0.005849692039191723,\n          0.1632416546344757\n        ],\n        [\n          -0.0015579514438286424,\n          9.288851288147271e-05,\n          0.001196552417241037\n        ],\n        [\n          0.00043879495933651924,\n          0.04429133981466293,\n          0.0002551022043917328\n        ],\n        [\n          -0.0019886596128344536,\n          0.008745947852730751,\n          -0.00962099153548479\n        ],\n        [\n          0.5197923183441162,\n          -0.0010678194230422378,\n          0.0002590256044641137\n        ],\n        [\n          -0.9051622152328491,\n          -0.12138096988201141,\n          0.25749173760414124\n        ],\n        [\n          0.010689850896596909,\n          -0.01072163600474596,\n          0.20382197201251984\n        ],\n        [\n          -0.0009684870601631701,\n          -0.5894762873649597,\n          0.0032688004430383444\n        ],\n        [\n          1.30536949634552,\n          -0.002206705743446946,\n          -0.0020471925381571054\n        ],\n        [\n          0.0067055909894406796,\n          
-0.015674468129873276,\n          -0.9086763262748718\n        ],\n        [\n          -0.26612186431884766,\n          -0.00016191616305150092,\n          0.002851327648386359\n        ],\n        [\n          0.003539646975696087,\n          0.20451955497264862,\n          -0.02575569413602352\n        ],\n        [\n          0.003367731347680092,\n          0.0018452388467267156,\n          -0.00026573429931886494\n        ],\n        [\n          -0.9464634656906128,\n          0.12737642228603363,\n          -0.2577688992023468\n        ],\n        [\n          0.00046661958913318813,\n          -0.008693858049809933,\n          -0.19606870412826538\n        ],\n        [\n          -0.0058177076280117035,\n          0.6349377036094666,\n          -0.0003108184027951211\n        ],\n        [\n          1.4694209098815918,\n          0.0046353572979569435,\n          0.002392316237092018\n        ],\n        [\n          0.022281549870967865,\n          0.006433307193219662,\n          1.1441218852996826\n        ],\n        [\n          -0.16217999160289764,\n          -0.0005673008854500949,\n          -0.0028868752997368574\n        ],\n        [\n          0.0011142585426568985,\n          0.036793302744627,\n          0.06873425096273422\n        ],\n        [\n          0.001964340452104807,\n          -0.004202086944133043,\n          0.0034294212237000465\n        ]\n      ]\n    ],\n    \"root_positions\": [\n      [\n        0.014979152008891106,\n        0.7896444201469421,\n        0.8725281357765198\n      ],\n      [\n        0.12546521425247192,\n        0.30551770329475403,\n        2.3331315517425537\n      ]\n    ],\n    \"smooth_root_2d\": [\n      [\n        0.014979152008891106,\n        0.8725281357765198\n      ],\n      [\n        0.12546521425247192,\n        2.3331315517425537\n      ]\n    ]\n  }\n]\n"
  },
  {
    "path": "kimodo/assets/demo/examples/kimodo-g1-rp/02_multi_text_ee_constraint/meta.json",
    "content": "{\n  \"texts\": [\n    \"A person walks forward while carrying a box\",\n    \"A person sets a box down onto the ground\"\n  ],\n  \"durations\": [\n    3.533333333333333,\n    4.066666666666666\n  ],\n  \"num_samples\": 1,\n  \"seed\": 60,\n  \"diffusion_steps\": 100,\n  \"cfg\": {\n    \"enabled\": true,\n    \"text_weight\": 1.0,\n    \"constraint_weight\": 2.0\n  }\n}\n"
  },
  {
    "path": "kimodo/assets/demo/examples/kimodo-g1-rp/03_full_body_keyframes/constraints.json",
    "content": "[\n  {\n    \"type\": \"fullbody\",\n    \"frame_indices\": [\n      59,\n      106,\n      148\n    ],\n    \"local_joints_rot\": [\n      [\n        [\n          0.42420727014541626,\n          0.058721136301755905,\n          -0.1945635825395584\n        ],\n        [\n          -0.5268475413322449,\n          -0.0005157420528121293,\n          0.0004701620200648904\n        ],\n        [\n          -0.17267920076847076,\n          0.027239520102739334,\n          0.36560261249542236\n        ],\n        [\n          0.004160718061029911,\n          -0.22976335883140564,\n          0.010524176992475986\n        ],\n        [\n          1.5825881958007812,\n          -0.01814083196222782,\n          -0.00019598894868977368\n        ],\n        [\n          -0.8827329277992249,\n          0.009902671910822392,\n          -0.00021610780095215887\n        ],\n        [\n          0.0067768096923828125,\n          -0.013547217473387718,\n          -0.16673408448696136\n        ],\n        [\n          0.0006806282908655703,\n          0.004601094871759415,\n          -0.0043960982002317905\n        ],\n        [\n          -1.4894901514053345,\n          -0.003371267579495907,\n          -0.001970127457752824\n        ],\n        [\n          -0.17904962599277496,\n          0.004051337484270334,\n          0.19225701689720154\n        ],\n        [\n          -0.0033012183848768473,\n          -0.29656991362571716,\n          0.004984850063920021\n        ],\n        [\n          1.5931552648544312,\n          -0.007282367907464504,\n          -0.0052862209267914295\n        ],\n        [\n          -0.35364261269569397,\n          0.0049067274667322636,\n          0.0010333984391763806\n        ],\n        [\n          0.0023804877419024706,\n          -0.005421861540526152,\n          -0.19129839539527893\n        ],\n        [\n          0.0008946731686592102,\n          0.0049979668110609055,\n          -0.0008540445705875754\n        ],\n        
[\n          -0.00037546976818703115,\n          -0.09826900064945221,\n          0.0006841858848929405\n        ],\n        [\n          0.004415650386363268,\n          0.0112489964812994,\n          0.025344429537653923\n        ],\n        [\n          0.5182019472122192,\n          0.002875699894502759,\n          0.002064053900539875\n        ],\n        [\n          -0.7899102568626404,\n          -0.11301380395889282,\n          0.261331170797348\n        ],\n        [\n          -0.004763631150126457,\n          0.003188431030139327,\n          0.191846564412117\n        ],\n        [\n          -0.0006821855786256492,\n          -0.24938665330410004,\n          0.0013275814708322287\n        ],\n        [\n          1.1367335319519043,\n          0.0038948820438236,\n          0.0009569167159497738\n        ],\n        [\n          0.006261332891881466,\n          0.020894864574074745,\n          -1.050469160079956\n        ],\n        [\n          0.06118401885032654,\n          0.0005131644429638982,\n          0.00042430072790011764\n        ],\n        [\n          0.0017778673209249973,\n          0.08777552843093872,\n          -0.044312309473752975\n        ],\n        [\n          -0.0006084830965846777,\n          0.0022449076641350985,\n          -0.001873409142717719\n        ],\n        [\n          0.33878403902053833,\n          -0.04740850627422333,\n          -0.2796333432197571\n        ],\n        [\n          0.02221747301518917,\n          0.013649695552885532,\n          -0.11847231537103653\n        ],\n        [\n          0.007714178413152695,\n          0.6182990074157715,\n          0.009067214094102383\n        ],\n        [\n          0.8923805952072144,\n          -0.00016622581460978836,\n          0.0021162345074117184\n        ],\n        [\n          0.0038995807990431786,\n          -0.006832453887909651,\n          0.3025287687778473\n        ],\n        [\n          0.03307999297976494,\n          
0.0005516205565072596,\n          0.0009820020059123635\n        ],\n        [\n          0.0015379488468170166,\n          -0.08221427351236343,\n          -0.014401843771338463\n        ],\n        [\n          -0.00022057670867070556,\n          0.002010792726650834,\n          0.0012923656031489372\n        ]\n      ],\n      [\n        [\n          -0.08197958767414093,\n          0.10326994955539703,\n          -0.1510602980852127\n        ],\n        [\n          0.28157129883766174,\n          0.0011461800895631313,\n          0.000703590689226985\n        ],\n        [\n          -0.182321235537529,\n          0.05269569158554077,\n          0.2730983793735504\n        ],\n        [\n          -0.0003947282093577087,\n          0.09641454368829727,\n          0.0040251282043755054\n        ],\n        [\n          1.089223861694336,\n          -0.00700604822486639,\n          -0.002539312234148383\n        ],\n        [\n          -0.09248486906290054,\n          0.003849609522148967,\n          0.0016473153373226523\n        ],\n        [\n          -0.010541710071265697,\n          0.004344945307821035,\n          0.07663393765687943\n        ],\n        [\n          -0.00044715296826325357,\n          -0.004340745974332094,\n          0.007171581499278545\n        ],\n        [\n          -0.3379390239715576,\n          0.0015806800220161676,\n          -0.0003471111413091421\n        ],\n        [\n          -0.1781967729330063,\n          0.016616491600871086,\n          0.1652776598930359\n        ],\n        [\n          -0.002019439358264208,\n          -0.11581386625766754,\n          0.0009603232610970736\n        ],\n        [\n          0.6794841289520264,\n          -5.403390241554007e-05,\n          -0.0012657493352890015\n        ],\n        [\n          -0.09013757854700089,\n          0.0018549489323049784,\n          -0.000238976048422046\n        ],\n        [\n          -0.0009166855015791953,\n          -0.0007138565997593105,\n        
  -0.0742788091301918\n        ],\n        [\n          -0.0009655999601818621,\n          0.0029521933756768703,\n          -0.00039851426845416427\n        ],\n        [\n          -0.0006129079265519977,\n          -0.19495022296905518,\n          -0.0019512351136654615\n        ],\n        [\n          0.0019297772087156773,\n          -0.0025066917296499014,\n          0.1518552601337433\n        ],\n        [\n          0.18073193728923798,\n          -0.0008597049745731056,\n          0.00023304206843022257\n        ],\n        [\n          -0.19048453867435455,\n          -0.02173178642988205,\n          0.2785468101501465\n        ],\n        [\n          0.0032724339980632067,\n          0.001481848070397973,\n          0.00837984960526228\n        ],\n        [\n          0.0037242062389850616,\n          -0.19455766677856445,\n          0.009616612456738949\n        ],\n        [\n          -0.19767794013023376,\n          0.004192049615085125,\n          0.004219892434775829\n        ],\n        [\n          -0.018522148951888084,\n          0.01758752018213272,\n          -1.4997444152832031\n        ],\n        [\n          -0.07066819816827774,\n          -0.0006776255904696882,\n          0.00122307357378304\n        ],\n        [\n          0.007704276591539383,\n          0.14503517746925354,\n          0.0951184555888176\n        ],\n        [\n          0.004533262457698584,\n          -0.0066575342789292336,\n          -0.010643035173416138\n        ],\n        [\n          0.3773331642150879,\n          -0.05414784327149391,\n          -0.2780730128288269\n        ],\n        [\n          0.003753547091037035,\n          0.002539943205192685,\n          0.12321871519088745\n        ],\n        [\n          -0.004724413156509399,\n          0.46992960572242737,\n          0.001832474721595645\n        ],\n        [\n          1.2976007461547852,\n          0.0007234009681269526,\n          -0.001626322278752923\n        ],\n        [\n         
 -0.0016050372505560517,\n          -0.00880438182502985,\n          0.17947044968605042\n        ],\n        [\n          0.05334911122918129,\n          -0.00018671243742574006,\n          0.0010833276901394129\n        ],\n        [\n          -0.0015367609448730946,\n          -0.05425700917840004,\n          0.01668459363281727\n        ],\n        [\n          -0.00021225935779511929,\n          0.001713683595880866,\n          0.0009809889597818255\n        ]\n      ],\n      [\n        [\n          -0.21817633509635925,\n          -0.012708673253655434,\n          -0.029821090400218964\n        ],\n        [\n          0.3743710219860077,\n          0.0007941523799672723,\n          0.00032366320374421775\n        ],\n        [\n          -0.16750676929950714,\n          0.003018906805664301,\n          0.07928019016981125\n        ],\n        [\n          -0.0003895726113114506,\n          0.030501781031489372,\n          0.0013912678696215153\n        ],\n        [\n          0.2578306794166565,\n          -0.0026517061050981283,\n          -0.0001139347514254041\n        ],\n        [\n          -0.227533221244812,\n          0.0004564583650790155,\n          -0.0004175934591330588\n        ],\n        [\n          -0.0015815469669178128,\n          0.0026496825739741325,\n          -0.017801448702812195\n        ],\n        [\n          0.00024288007989525795,\n          0.000392801477573812,\n          -2.9845070457668044e-05\n        ],\n        [\n          0.31938642263412476,\n          -0.0006790655897930264,\n          -0.0004276619874872267\n        ],\n        [\n          -0.17199693620204926,\n          0.007707139942795038,\n          0.014987054280936718\n        ],\n        [\n          0.0012992072151973844,\n          0.003620905103161931,\n          -0.001210421440191567\n        ],\n        [\n          0.22853288054466248,\n          -0.0018797506345435977,\n          -0.0002660619793459773\n        ],\n        [\n          
-0.1335543692111969,\n          0.0010313205420970917,\n          0.0001083972238120623\n        ],\n        [\n          0.003520265920087695,\n          0.0035283963661640882,\n          0.016698163002729416\n        ],\n        [\n          0.0001443400833522901,\n          -0.001745356246829033,\n          -2.3336755475611426e-05\n        ],\n        [\n          0.0003554633294697851,\n          -0.05629483610391617,\n          -0.0006463310564868152\n        ],\n        [\n          -0.00298635708168149,\n          0.0020182463340461254,\n          -0.03614736720919609\n        ],\n        [\n          0.21955031156539917,\n          0.0005465149879455566,\n          0.00011243963672313839\n        ],\n        [\n          -0.0715053528547287,\n          -0.010282701812684536,\n          0.28057143092155457\n        ],\n        [\n          0.0007245761225931346,\n          0.0019379559671506286,\n          -0.018530432134866714\n        ],\n        [\n          -0.0020012110471725464,\n          -0.5585712194442749,\n          0.0002525273594073951\n        ],\n        [\n          1.1451164484024048,\n          0.000756395107600838,\n          -0.00042264885269105434\n        ],\n        [\n          -0.004087591078132391,\n          -0.0022635578643530607,\n          -0.1811828911304474\n        ],\n        [\n          0.15393203496932983,\n          -0.00010327681229682639,\n          0.000951180059928447\n        ],\n        [\n          -0.0005707733216695487,\n          0.07005079090595245,\n          -0.0018504050094634295\n        ],\n        [\n          -0.0013123765820637345,\n          -0.0004375300486572087,\n          0.0002970081695821136\n        ],\n        [\n          -0.09115279465913773,\n          0.013008617796003819,\n          -0.2808595299720764\n        ],\n        [\n          0.0015214721206575632,\n          -0.007811791729182005,\n          0.031220799311995506\n        ],\n        [\n          -0.00048553026863373816,\n       
   0.5777612328529358,\n          0.0003351669874973595\n        ],\n        [\n          1.0913182497024536,\n          0.0011191898956894875,\n          -0.0027903772424906492\n        ],\n        [\n          0.000775794149376452,\n          0.00010774911061162129,\n          0.10287072509527206\n        ],\n        [\n          0.0997936949133873,\n          0.0003015398688148707,\n          -0.0006937433499842882\n        ],\n        [\n          0.0003619014751166105,\n          -0.18787385523319244,\n          -0.0010270585771650076\n        ],\n        [\n          -0.001584835583344102,\n          0.0037561857607215643,\n          -0.002414965769276023\n        ]\n      ]\n    ],\n    \"root_positions\": [\n      [\n        -0.17535515129566193,\n        0.5689253807067871,\n        0.9417929649353027\n      ],\n      [\n        -0.16934014856815338,\n        0.7382326722145081,\n        2.169330596923828\n      ],\n      [\n        -0.1823902279138565,\n        0.7819305658340454,\n        2.954490900039673\n      ]\n    ],\n    \"smooth_root_2d\": [\n      [\n        -0.17535515129566193,\n        0.9417929649353027\n      ],\n      [\n        -0.16934014856815338,\n        2.169330596923828\n      ],\n      [\n        -0.1823902279138565,\n        2.954490900039673\n      ]\n    ]\n  }\n]\n"
  },
  {
    "path": "kimodo/assets/demo/examples/kimodo-g1-rp/03_full_body_keyframes/meta.json",
    "content": "{\n  \"text\": \"A person walking forward picks up something off the ground\",\n  \"duration\": 5.0,\n  \"num_samples\": 1,\n  \"seed\": 51,\n  \"diffusion_steps\": 100,\n  \"cfg\": {\n    \"enabled\": true,\n    \"text_weight\": 2.0,\n    \"constraint_weight\": 1.0\n  }\n}\n"
  },
  {
    "path": "kimodo/assets/demo/examples/kimodo-g1-rp/04_ee_constraint/constraints.json",
    "content": "[\n  {\n    \"type\": \"right-hand\",\n    \"frame_indices\": [\n      129,\n      93,\n      0\n    ],\n    \"local_joints_rot\": [\n      [\n        [\n          -0.11223886162042618,\n          0.20531758666038513,\n          0.13568778336048126\n        ],\n        [\n          0.1075688898563385,\n          0.0032202948350459337,\n          0.0006892754463478923\n        ],\n        [\n          -0.17058254778385162,\n          -0.011657492257654667,\n          -0.23103317618370056\n        ],\n        [\n          -0.02866872400045395,\n          0.4262913167476654,\n          -0.010209682397544384\n        ],\n        [\n          0.2924644649028778,\n          0.007188746705651283,\n          0.0005000674282200634\n        ],\n        [\n          -0.13080132007598877,\n          -0.0029640060383826494,\n          -0.0007075564353726804\n        ],\n        [\n          -0.005761375650763512,\n          -0.002191383158788085,\n          0.15397773683071136\n        ],\n        [\n          -0.00023041102394927293,\n          -0.0010889451950788498,\n          0.0007837787852622569\n        ],\n        [\n          -0.3537895977497101,\n          -0.0006238390924409032,\n          -0.0010272490326315165\n        ],\n        [\n          -0.16032733023166656,\n          -0.02506295032799244,\n          -0.22620464861392975\n        ],\n        [\n          0.0138308797031641,\n          0.21655774116516113,\n          0.0317748561501503\n        ],\n        [\n          1.5745534896850586,\n          0.003732866607606411,\n          0.0021063678432255983\n        ],\n        [\n          -0.17066748440265656,\n          -0.002285068854689598,\n          -0.0029538189992308617\n        ],\n        [\n          0.02313886024057865,\n          -0.07020875811576843,\n          -0.05658446252346039\n        ],\n        [\n          2.5580025976523757e-05,\n          0.004435115493834019,\n          -0.006514436099678278\n        ],\n        [\n     
     0.0015886364271864295,\n          -0.292732834815979,\n          -0.0014166575856506824\n        ],\n        [\n          -0.008558829315006733,\n          0.0066470191814005375,\n          -0.010221566073596478\n        ],\n        [\n          0.10141321271657944,\n          -0.0028386565390974283,\n          -0.0006978976307436824\n        ],\n        [\n          0.002506372518837452,\n          0.001101600006222725,\n          0.2779805362224579\n        ],\n        [\n          0.020367039367556572,\n          -0.028616085648536682,\n          0.0971180647611618\n        ],\n        [\n          -0.011572631075978279,\n          -0.5930124521255493,\n          -0.026975814253091812\n        ],\n        [\n          0.9286840558052063,\n          -0.00046807233593426645,\n          -0.00013331411173567176\n        ],\n        [\n          -0.031172338873147964,\n          -0.04484722763299942,\n          0.03643424063920975\n        ],\n        [\n          0.03150894120335579,\n          -0.00101278827060014,\n          0.0015338404336944222\n        ],\n        [\n          0.0005915925721637905,\n          0.0930531769990921,\n          -0.028835415840148926\n        ],\n        [\n          -0.001440802589058876,\n          0.0010614224011078477,\n          0.0006542576011270285\n        ],\n        [\n          -0.4149414598941803,\n          0.06656259298324585,\n          -0.2730332314968109\n        ],\n        [\n          -0.006371266208589077,\n          -0.02150307223200798,\n          -1.3590242862701416\n        ],\n        [\n          0.00956002902239561,\n          -0.17155548930168152,\n          0.026624836027622223\n        ],\n        [\n          0.8084958791732788,\n          -0.003991501871496439,\n          0.0007233448559418321\n        ],\n        [\n          -0.020737944170832634,\n          -0.011397535912692547,\n          0.14019189774990082\n        ],\n        [\n          -0.18261606991291046,\n          
0.005134414881467819,\n          -0.001045998651534319\n        ],\n        [\n          -0.028427572920918465,\n          -0.6557883620262146,\n          0.038063470274209976\n        ],\n        [\n          0.005555277690291405,\n          0.012246276251971722,\n          0.004770371131598949\n        ]\n      ],\n      [\n        [\n          -0.06392758339643478,\n          0.3478183448314667,\n          0.1171446293592453\n        ],\n        [\n          0.12243298441171646,\n          0.003146131755784154,\n          0.00017438907525502145\n        ],\n        [\n          -0.17841783165931702,\n          -0.0256511103361845,\n          -0.2805330455303192\n        ],\n        [\n          -0.022625330835580826,\n          0.348234087228775,\n          -0.009928824380040169\n        ],\n        [\n          0.28284141421318054,\n          0.009495020844042301,\n          0.0010556986089795828\n        ],\n        [\n          -0.17478667199611664,\n          -0.004891794174909592,\n          -0.0013969563879072666\n        ],\n        [\n          -0.002641322324052453,\n          -0.005833400413393974,\n          0.20226475596427917\n        ],\n        [\n          -0.0009078677394427359,\n          -0.002073301700875163,\n          0.0012749496381729841\n        ],\n        [\n          -0.48070675134658813,\n          0.0005347213009372354,\n          -0.0004243548901285976\n        ],\n        [\n          -0.16694584488868713,\n          -0.03390314802527428,\n          -0.09055406600236893\n        ],\n        [\n          0.009182179346680641,\n          0.1743844896554947,\n          0.01932411640882492\n        ],\n        [\n          1.6481772661209106,\n          0.0002097517135553062,\n          0.0010239556431770325\n        ],\n        [\n          -0.17133140563964844,\n          0.0028362423181533813,\n          -0.004689408931881189\n        ],\n        [\n          0.025385459885001183,\n          -0.06771048158407211,\n          
-0.011561849154531956\n        ],\n        [\n          -0.00012663791130762547,\n          0.001872184220701456,\n          -0.002834505634382367\n        ],\n        [\n          0.001523697399534285,\n          -0.48211750388145447,\n          -0.0005278618773445487\n        ],\n        [\n          -0.00822246354073286,\n          -0.00923906546086073,\n          -0.01643195189535618\n        ],\n        [\n          0.04035002365708351,\n          -0.004922393709421158,\n          -0.0005214703269302845\n        ],\n        [\n          -0.02120170183479786,\n          -0.000465662480564788,\n          0.27964550256729126\n        ],\n        [\n          0.042349521070718765,\n          -0.043123405426740646,\n          0.21025802195072174\n        ],\n        [\n          -0.01620035618543625,\n          -0.5838293433189392,\n          -0.03403719887137413\n        ],\n        [\n          1.1832103729248047,\n          0.0004754749243147671,\n          -0.0014872060855850577\n        ],\n        [\n          -0.040768858045339584,\n          -0.04618615657091141,\n          0.04847611486911774\n        ],\n        [\n          0.04482508823275566,\n          -0.0005392982857301831,\n          0.00035259113064967096\n        ],\n        [\n          0.00015537742001470178,\n          -0.024237608537077904,\n          -0.003044326091185212\n        ],\n        [\n          -0.0012453795643523335,\n          0.004743263591080904,\n          0.004625802394002676\n        ],\n        [\n          -0.14595142006874084,\n          0.0308919008821249,\n          -0.2779163420200348\n        ],\n        [\n          -0.03314027562737465,\n          -0.07205720245838165,\n          -1.3401029109954834\n        ],\n        [\n          0.02448190003633499,\n          -0.468079537153244,\n          0.018310735002160072\n        ],\n        [\n          0.9222347140312195,\n          -0.00624655419960618,\n          -0.0003706512216012925\n        ],\n        [\n        
  0.0311859343200922,\n          -0.01980999857187271,\n          -0.4311404228210449\n        ],\n        [\n          -0.05887744575738907,\n          0.009565972723066807,\n          0.0008855919586494565\n        ],\n        [\n          -0.0638674795627594,\n          -1.1882448196411133,\n          -0.07744041085243225\n        ],\n        [\n          0.002320833969861269,\n          0.014880148693919182,\n          0.00827236007899046\n        ]\n      ],\n      [\n        [\n          0.028708748519420624,\n          0.023731501772999763,\n          -0.05906220152974129\n        ],\n        [\n          0.36697518825531006,\n          0.0020822372753173113,\n          9.442192094866186e-06\n        ],\n        [\n          -0.17328320443630219,\n          -0.029694421216845512,\n          -0.2592017650604248\n        ],\n        [\n          -0.027558816596865654,\n          0.44522055983543396,\n          0.00263651879504323\n        ],\n        [\n          0.45747342705726624,\n          0.006375299766659737,\n          0.000838644162286073\n        ],\n        [\n          -0.29932498931884766,\n          -0.0034287264570593834,\n          -0.005712746176868677\n        ],\n        [\n          0.010242770425975323,\n          0.0686849057674408,\n          0.12300582230091095\n        ],\n        [\n          0.0019906593952327967,\n          -0.006487288512289524,\n          0.004740884527564049\n        ],\n        [\n          -0.417245090007782,\n          0.002172173699364066,\n          -0.000527464144397527\n        ],\n        [\n          -0.16229933500289917,\n          -0.015825729817152023,\n          0.26093363761901855\n        ],\n        [\n          -0.01547759398818016,\n          -0.4560239017009735,\n          -0.001296655391342938\n        ],\n        [\n          0.520811140537262,\n          -0.016100304201245308,\n          -0.0033653294667601585\n        ],\n        [\n          -0.061035193502902985,\n          
0.013747301883995533,\n          0.0011975782690569758\n        ],\n        [\n          0.002211581217125058,\n          0.013100380077958107,\n          -0.41168421506881714\n        ],\n        [\n          0.000723487522918731,\n          -0.0009448538185097277,\n          -0.0023157261312007904\n        ],\n        [\n          -0.0008414603653363883,\n          -0.22904154658317566,\n          0.0037871438544243574\n        ],\n        [\n          -0.004434449132531881,\n          -0.0019493037834763527,\n          0.04349867254495621\n        ],\n        [\n          0.11113234609365463,\n          -0.001496539101935923,\n          -6.745033260813216e-06\n        ],\n        [\n          0.03568394109606743,\n          0.00850191805511713,\n          0.2815527021884918\n        ],\n        [\n          0.007574420887976885,\n          -0.005988026969134808,\n          -0.04585442319512367\n        ],\n        [\n          -0.014899174682796001,\n          -0.6360949277877808,\n          0.014495083130896091\n        ],\n        [\n          1.1318601369857788,\n          -0.0009174949955195189,\n          -0.008180576376616955\n        ],\n        [\n          -0.038145799189805984,\n          -0.05923198536038399,\n          -0.04122990742325783\n        ],\n        [\n          0.07719366252422333,\n          -0.0010574767366051674,\n          0.0009220906649716198\n        ],\n        [\n          -0.0010063578374683857,\n          0.12876589596271515,\n          -0.021289559081196785\n        ],\n        [\n          -0.0008511252817697823,\n          -0.0003541657351888716,\n          -0.0006832815706729889\n        ],\n        [\n          0.10374817997217178,\n          -0.014772959984838963,\n          -0.28235113620758057\n        ],\n        [\n          0.029763568192720413,\n          0.00017807059339247644,\n          0.007343007251620293\n        ],\n        [\n          -0.0064206854440271854,\n          0.3665950298309326,\n          
-0.0003897137939929962\n        ],\n        [\n          1.0820642709732056,\n          -0.0005379249923862517,\n          -0.0039028781466186047\n        ],\n        [\n          -0.004170380067080259,\n          0.06480656564235687,\n          -0.10721305757761002\n        ],\n        [\n          -0.06350508332252502,\n          0.0011865347623825073,\n          -0.0005369586870074272\n        ],\n        [\n          -0.0021817537490278482,\n          -0.08756759762763977,\n          -0.008148521184921265\n        ],\n        [\n          0.00243115471675992,\n          -0.003949992824345827,\n          0.005949904676526785\n        ]\n      ]\n    ],\n    \"root_positions\": [\n      [\n        2.639763593673706,\n        0.767427384853363,\n        2.341259479522705\n      ],\n      [\n        1.9461809396743774,\n        0.7706995010375977,\n        1.7243560552597046\n      ],\n      [\n        0.003315839683637023,\n        0.7415399551391602,\n        -0.0012030001962557435\n      ]\n    ],\n    \"smooth_root_2d\": [\n      [\n        2.639763593673706,\n        2.341259479522705\n      ],\n      [\n        1.9461809396743774,\n        1.7243560552597046\n      ],\n      [\n        0.003315839683637023,\n        -0.0012030001962557435\n      ]\n    ]\n  },\n  {\n    \"type\": \"left-foot\",\n    \"frame_indices\": [\n      93,\n      0\n    ],\n    \"local_joints_rot\": [\n      [\n        [\n          -0.06392758339643478,\n          0.3478183448314667,\n          0.1171446293592453\n        ],\n        [\n          0.12243298441171646,\n          0.003146131755784154,\n          0.00017438907525502145\n        ],\n        [\n          -0.17841783165931702,\n          -0.0256511103361845,\n          -0.2805330455303192\n        ],\n        [\n          -0.022625330835580826,\n          0.348234087228775,\n          -0.009928824380040169\n        ],\n        [\n          0.28284141421318054,\n          0.009495020844042301,\n          
0.0010556986089795828\n        ],\n        [\n          -0.17478667199611664,\n          -0.004891794174909592,\n          -0.0013969563879072666\n        ],\n        [\n          -0.002641322324052453,\n          -0.005833400413393974,\n          0.20226475596427917\n        ],\n        [\n          -0.0009078677394427359,\n          -0.002073301700875163,\n          0.0012749496381729841\n        ],\n        [\n          -0.48070675134658813,\n          0.0005347213009372354,\n          -0.0004243548901285976\n        ],\n        [\n          -0.16694584488868713,\n          -0.03390314802527428,\n          -0.09055406600236893\n        ],\n        [\n          0.009182179346680641,\n          0.1743844896554947,\n          0.01932411640882492\n        ],\n        [\n          1.6481772661209106,\n          0.0002097517135553062,\n          0.0010239556431770325\n        ],\n        [\n          -0.17133140563964844,\n          0.0028362423181533813,\n          -0.004689408931881189\n        ],\n        [\n          0.025385459885001183,\n          -0.06771048158407211,\n          -0.011561849154531956\n        ],\n        [\n          -0.00012663791130762547,\n          0.001872184220701456,\n          -0.002834505634382367\n        ],\n        [\n          0.001523697399534285,\n          -0.48211750388145447,\n          -0.0005278618773445487\n        ],\n        [\n          -0.00822246354073286,\n          -0.00923906546086073,\n          -0.01643195189535618\n        ],\n        [\n          0.04035002365708351,\n          -0.004922393709421158,\n          -0.0005214703269302845\n        ],\n        [\n          -0.02120170183479786,\n          -0.000465662480564788,\n          0.27964550256729126\n        ],\n        [\n          0.042349521070718765,\n          -0.043123405426740646,\n          0.21025802195072174\n        ],\n        [\n          -0.01620035618543625,\n          -0.5838293433189392,\n          -0.03403719887137413\n        ],\n        
[\n          1.1832103729248047,\n          0.0004754749243147671,\n          -0.0014872060855850577\n        ],\n        [\n          -0.040768858045339584,\n          -0.04618615657091141,\n          0.04847611486911774\n        ],\n        [\n          0.04482508823275566,\n          -0.0005392982857301831,\n          0.00035259113064967096\n        ],\n        [\n          0.00015537742001470178,\n          -0.024237608537077904,\n          -0.003044326091185212\n        ],\n        [\n          -0.0012453795643523335,\n          0.004743263591080904,\n          0.004625802394002676\n        ],\n        [\n          -0.14595142006874084,\n          0.0308919008821249,\n          -0.2779163420200348\n        ],\n        [\n          -0.03314027562737465,\n          -0.07205720245838165,\n          -1.3401029109954834\n        ],\n        [\n          0.02448190003633499,\n          -0.468079537153244,\n          0.018310735002160072\n        ],\n        [\n          0.9222347140312195,\n          -0.00624655419960618,\n          -0.0003706512216012925\n        ],\n        [\n          0.0311859343200922,\n          -0.01980999857187271,\n          -0.4311404228210449\n        ],\n        [\n          -0.05887744575738907,\n          0.009565972723066807,\n          0.0008855919586494565\n        ],\n        [\n          -0.0638674795627594,\n          -1.1882448196411133,\n          -0.07744041085243225\n        ],\n        [\n          0.002320833969861269,\n          0.014880148693919182,\n          0.00827236007899046\n        ]\n      ],\n      [\n        [\n          0.028708748519420624,\n          0.023731501772999763,\n          -0.05906220152974129\n        ],\n        [\n          0.36697518825531006,\n          0.0020822372753173113,\n          9.442192094866186e-06\n        ],\n        [\n          -0.17328320443630219,\n          -0.029694421216845512,\n          -0.2592017650604248\n        ],\n        [\n          -0.027558816596865654,\n          
0.44522055983543396,\n          0.00263651879504323\n        ],\n        [\n          0.45747342705726624,\n          0.006375299766659737,\n          0.000838644162286073\n        ],\n        [\n          -0.29932498931884766,\n          -0.0034287264570593834,\n          -0.005712746176868677\n        ],\n        [\n          0.010242770425975323,\n          0.0686849057674408,\n          0.12300582230091095\n        ],\n        [\n          0.0019906593952327967,\n          -0.006487288512289524,\n          0.004740884527564049\n        ],\n        [\n          -0.417245090007782,\n          0.002172173699364066,\n          -0.000527464144397527\n        ],\n        [\n          -0.16229933500289917,\n          -0.015825729817152023,\n          0.26093363761901855\n        ],\n        [\n          -0.01547759398818016,\n          -0.4560239017009735,\n          -0.001296655391342938\n        ],\n        [\n          0.520811140537262,\n          -0.016100304201245308,\n          -0.0033653294667601585\n        ],\n        [\n          -0.061035193502902985,\n          0.013747301883995533,\n          0.0011975782690569758\n        ],\n        [\n          0.002211581217125058,\n          0.013100380077958107,\n          -0.41168421506881714\n        ],\n        [\n          0.000723487522918731,\n          -0.0009448538185097277,\n          -0.0023157261312007904\n        ],\n        [\n          -0.0008414603653363883,\n          -0.22904154658317566,\n          0.0037871438544243574\n        ],\n        [\n          -0.004434449132531881,\n          -0.0019493037834763527,\n          0.04349867254495621\n        ],\n        [\n          0.11113234609365463,\n          -0.001496539101935923,\n          -6.745033260813216e-06\n        ],\n        [\n          0.03568394109606743,\n          0.00850191805511713,\n          0.2815527021884918\n        ],\n        [\n          0.007574420887976885,\n          -0.005988026969134808,\n          -0.04585442319512367\n 
       ],\n        [\n          -0.014899174682796001,\n          -0.6360949277877808,\n          0.014495083130896091\n        ],\n        [\n          1.1318601369857788,\n          -0.0009174949955195189,\n          -0.008180576376616955\n        ],\n        [\n          -0.038145799189805984,\n          -0.05923198536038399,\n          -0.04122990742325783\n        ],\n        [\n          0.07719366252422333,\n          -0.0010574767366051674,\n          0.0009220906649716198\n        ],\n        [\n          -0.0010063578374683857,\n          0.12876589596271515,\n          -0.021289559081196785\n        ],\n        [\n          -0.0008511252817697823,\n          -0.0003541657351888716,\n          -0.0006832815706729889\n        ],\n        [\n          0.10374817997217178,\n          -0.014772959984838963,\n          -0.28235113620758057\n        ],\n        [\n          0.029763568192720413,\n          0.00017807059339247644,\n          0.007343007251620293\n        ],\n        [\n          -0.0064206854440271854,\n          0.3665950298309326,\n          -0.0003897137939929962\n        ],\n        [\n          1.0820642709732056,\n          -0.0005379249923862517,\n          -0.0039028781466186047\n        ],\n        [\n          -0.004170380067080259,\n          0.06480656564235687,\n          -0.10721305757761002\n        ],\n        [\n          -0.06350508332252502,\n          0.0011865347623825073,\n          -0.0005369586870074272\n        ],\n        [\n          -0.0021817537490278482,\n          -0.08756759762763977,\n          -0.008148521184921265\n        ],\n        [\n          0.00243115471675992,\n          -0.003949992824345827,\n          0.005949904676526785\n        ]\n      ]\n    ],\n    \"root_positions\": [\n      [\n        1.9461809396743774,\n        0.7706995010375977,\n        1.7243560552597046\n      ],\n      [\n        0.003315839683637023,\n        0.7415399551391602,\n        -0.0012030001962557435\n      ]\n    ],\n    
\"smooth_root_2d\": [\n      [\n        1.9461809396743774,\n        1.7243560552597046\n      ],\n      [\n        0.003315839683637023,\n        -0.0012030001962557435\n      ]\n    ]\n  },\n  {\n    \"type\": \"right-foot\",\n    \"frame_indices\": [\n      0\n    ],\n    \"local_joints_rot\": [\n      [\n        [\n          0.028708748519420624,\n          0.023731501772999763,\n          -0.05906220152974129\n        ],\n        [\n          0.36697518825531006,\n          0.0020822372753173113,\n          9.442192094866186e-06\n        ],\n        [\n          -0.17328320443630219,\n          -0.029694421216845512,\n          -0.2592017650604248\n        ],\n        [\n          -0.027558816596865654,\n          0.44522055983543396,\n          0.00263651879504323\n        ],\n        [\n          0.45747342705726624,\n          0.006375299766659737,\n          0.000838644162286073\n        ],\n        [\n          -0.29932498931884766,\n          -0.0034287264570593834,\n          -0.005712746176868677\n        ],\n        [\n          0.010242770425975323,\n          0.0686849057674408,\n          0.12300582230091095\n        ],\n        [\n          0.0019906593952327967,\n          -0.006487288512289524,\n          0.004740884527564049\n        ],\n        [\n          -0.417245090007782,\n          0.002172173699364066,\n          -0.000527464144397527\n        ],\n        [\n          -0.16229933500289917,\n          -0.015825729817152023,\n          0.26093363761901855\n        ],\n        [\n          -0.01547759398818016,\n          -0.4560239017009735,\n          -0.001296655391342938\n        ],\n        [\n          0.520811140537262,\n          -0.016100304201245308,\n          -0.0033653294667601585\n        ],\n        [\n          -0.061035193502902985,\n          0.013747301883995533,\n          0.0011975782690569758\n        ],\n        [\n          0.002211581217125058,\n          0.013100380077958107,\n          -0.41168421506881714\n      
  ],\n        [\n          0.000723487522918731,\n          -0.0009448538185097277,\n          -0.0023157261312007904\n        ],\n        [\n          -0.0008414603653363883,\n          -0.22904154658317566,\n          0.0037871438544243574\n        ],\n        [\n          -0.004434449132531881,\n          -0.0019493037834763527,\n          0.04349867254495621\n        ],\n        [\n          0.11113234609365463,\n          -0.001496539101935923,\n          -6.745033260813216e-06\n        ],\n        [\n          0.03568394109606743,\n          0.00850191805511713,\n          0.2815527021884918\n        ],\n        [\n          0.007574420887976885,\n          -0.005988026969134808,\n          -0.04585442319512367\n        ],\n        [\n          -0.014899174682796001,\n          -0.6360949277877808,\n          0.014495083130896091\n        ],\n        [\n          1.1318601369857788,\n          -0.0009174949955195189,\n          -0.008180576376616955\n        ],\n        [\n          -0.038145799189805984,\n          -0.05923198536038399,\n          -0.04122990742325783\n        ],\n        [\n          0.07719366252422333,\n          -0.0010574767366051674,\n          0.0009220906649716198\n        ],\n        [\n          -0.0010063578374683857,\n          0.12876589596271515,\n          -0.021289559081196785\n        ],\n        [\n          -0.0008511252817697823,\n          -0.0003541657351888716,\n          -0.0006832815706729889\n        ],\n        [\n          0.10374817997217178,\n          -0.014772959984838963,\n          -0.28235113620758057\n        ],\n        [\n          0.029763568192720413,\n          0.00017807059339247644,\n          0.007343007251620293\n        ],\n        [\n          -0.0064206854440271854,\n          0.3665950298309326,\n          -0.0003897137939929962\n        ],\n        [\n          1.0820642709732056,\n          -0.0005379249923862517,\n          -0.0039028781466186047\n        ],\n        [\n          
-0.004170380067080259,\n          0.06480656564235687,\n          -0.10721305757761002\n        ],\n        [\n          -0.06350508332252502,\n          0.0011865347623825073,\n          -0.0005369586870074272\n        ],\n        [\n          -0.0021817537490278482,\n          -0.08756759762763977,\n          -0.008148521184921265\n        ],\n        [\n          0.00243115471675992,\n          -0.003949992824345827,\n          0.005949904676526785\n        ]\n      ]\n    ],\n    \"root_positions\": [\n      [\n        0.003315839683637023,\n        0.7415399551391602,\n        -0.0012030001962557435\n      ]\n    ],\n    \"smooth_root_2d\": [\n      [\n        0.003315839683637023,\n        -0.0012030001962557435\n      ]\n    ]\n  },\n  {\n    \"type\": \"left-hand\",\n    \"frame_indices\": [\n      0\n    ],\n    \"local_joints_rot\": [\n      [\n        [\n          0.028708748519420624,\n          0.023731501772999763,\n          -0.05906220152974129\n        ],\n        [\n          0.36697518825531006,\n          0.0020822372753173113,\n          9.442192094866186e-06\n        ],\n        [\n          -0.17328320443630219,\n          -0.029694421216845512,\n          -0.2592017650604248\n        ],\n        [\n          -0.027558816596865654,\n          0.44522055983543396,\n          0.00263651879504323\n        ],\n        [\n          0.45747342705726624,\n          0.006375299766659737,\n          0.000838644162286073\n        ],\n        [\n          -0.29932498931884766,\n          -0.0034287264570593834,\n          -0.005712746176868677\n        ],\n        [\n          0.010242770425975323,\n          0.0686849057674408,\n          0.12300582230091095\n        ],\n        [\n          0.0019906593952327967,\n          -0.006487288512289524,\n          0.004740884527564049\n        ],\n        [\n          -0.417245090007782,\n          0.002172173699364066,\n          -0.000527464144397527\n        ],\n        [\n          
-0.16229933500289917,\n          -0.015825729817152023,\n          0.26093363761901855\n        ],\n        [\n          -0.01547759398818016,\n          -0.4560239017009735,\n          -0.001296655391342938\n        ],\n        [\n          0.520811140537262,\n          -0.016100304201245308,\n          -0.0033653294667601585\n        ],\n        [\n          -0.061035193502902985,\n          0.013747301883995533,\n          0.0011975782690569758\n        ],\n        [\n          0.002211581217125058,\n          0.013100380077958107,\n          -0.41168421506881714\n        ],\n        [\n          0.000723487522918731,\n          -0.0009448538185097277,\n          -0.0023157261312007904\n        ],\n        [\n          -0.0008414603653363883,\n          -0.22904154658317566,\n          0.0037871438544243574\n        ],\n        [\n          -0.004434449132531881,\n          -0.0019493037834763527,\n          0.04349867254495621\n        ],\n        [\n          0.11113234609365463,\n          -0.001496539101935923,\n          -6.745033260813216e-06\n        ],\n        [\n          0.03568394109606743,\n          0.00850191805511713,\n          0.2815527021884918\n        ],\n        [\n          0.007574420887976885,\n          -0.005988026969134808,\n          -0.04585442319512367\n        ],\n        [\n          -0.014899174682796001,\n          -0.6360949277877808,\n          0.014495083130896091\n        ],\n        [\n          1.1318601369857788,\n          -0.0009174949955195189,\n          -0.008180576376616955\n        ],\n        [\n          -0.038145799189805984,\n          -0.05923198536038399,\n          -0.04122990742325783\n        ],\n        [\n          0.07719366252422333,\n          -0.0010574767366051674,\n          0.0009220906649716198\n        ],\n        [\n          -0.0010063578374683857,\n          0.12876589596271515,\n          -0.021289559081196785\n        ],\n        [\n          -0.0008511252817697823,\n          
-0.0003541657351888716,\n          -0.0006832815706729889\n        ],\n        [\n          0.10374817997217178,\n          -0.014772959984838963,\n          -0.28235113620758057\n        ],\n        [\n          0.029763568192720413,\n          0.00017807059339247644,\n          0.007343007251620293\n        ],\n        [\n          -0.0064206854440271854,\n          0.3665950298309326,\n          -0.0003897137939929962\n        ],\n        [\n          1.0820642709732056,\n          -0.0005379249923862517,\n          -0.0039028781466186047\n        ],\n        [\n          -0.004170380067080259,\n          0.06480656564235687,\n          -0.10721305757761002\n        ],\n        [\n          -0.06350508332252502,\n          0.0011865347623825073,\n          -0.0005369586870074272\n        ],\n        [\n          -0.0021817537490278482,\n          -0.08756759762763977,\n          -0.008148521184921265\n        ],\n        [\n          0.00243115471675992,\n          -0.003949992824345827,\n          0.005949904676526785\n        ]\n      ]\n    ],\n    \"root_positions\": [\n      [\n        0.003315839683637023,\n        0.7415399551391602,\n        -0.0012030001962557435\n      ]\n    ],\n    \"smooth_root_2d\": [\n      [\n        0.003315839683637023,\n        -0.0012030001962557435\n      ]\n    ]\n  }\n]\n"
  },
  {
    "path": "kimodo/assets/demo/examples/kimodo-g1-rp/04_ee_constraint/meta.json",
    "content": "{\n  \"text\": \"A person walks diagonally to the left and waves at someone on their right\",\n  \"duration\": 4.966666666666667,\n  \"num_samples\": 1,\n  \"seed\": 44,\n  \"diffusion_steps\": 100,\n  \"cfg\": {\n    \"enabled\": true,\n    \"text_weight\": 2.0,\n    \"constraint_weight\": 2.0\n  }\n}\n"
  },
  {
    "path": "kimodo/assets/demo/examples/kimodo-g1-rp/05_root_path/constraints.json",
    "content": "[\n  {\n    \"type\": \"root2d\",\n    \"frame_indices\": [\n      0,\n      1,\n      2,\n      3,\n      4,\n      5,\n      6,\n      7,\n      8,\n      9,\n      10,\n      11,\n      12,\n      13,\n      14,\n      15,\n      16,\n      17,\n      18,\n      19,\n      20,\n      21,\n      22,\n      23,\n      24,\n      25,\n      26,\n      27,\n      28,\n      29,\n      30,\n      31,\n      32,\n      33,\n      34,\n      35,\n      36,\n      37,\n      38,\n      39,\n      40,\n      41,\n      42,\n      43,\n      44,\n      45,\n      46,\n      47,\n      48,\n      49,\n      50,\n      51,\n      52,\n      53,\n      54,\n      55,\n      56,\n      57,\n      58,\n      59,\n      60,\n      61,\n      62,\n      63,\n      64,\n      65,\n      66,\n      67,\n      68,\n      69,\n      70,\n      71,\n      72,\n      73,\n      74,\n      75,\n      76,\n      77,\n      78,\n      79,\n      80,\n      81,\n      82,\n      83,\n      84,\n      85,\n      86,\n      87,\n      88,\n      89,\n      90,\n      91,\n      92,\n      93,\n      94,\n      95,\n      96,\n      97,\n      98,\n      99,\n      100,\n      101,\n      102,\n      103,\n      104,\n      105,\n      106,\n      107,\n      108,\n      109,\n      110,\n      111,\n      112,\n      113,\n      114,\n      115,\n      116,\n      117,\n      118,\n      119,\n      120,\n      121,\n      122,\n      123,\n      124,\n      125,\n      126,\n      127,\n      128,\n      129,\n      130,\n      131,\n      132,\n      133,\n      134,\n      135,\n      136,\n      137,\n      138,\n      139,\n      140,\n      141,\n      142,\n      143,\n      144,\n      145,\n      146,\n      147,\n      148,\n      149,\n      150,\n      151,\n      152,\n      153,\n      154,\n      155,\n      156,\n      157,\n      158,\n      159,\n      160,\n      161,\n      162,\n      163,\n      164,\n      165,\n      166,\n      167,\n      168,\n     
 169,\n      170,\n      171,\n      172,\n      173,\n      174,\n      175,\n      176,\n      177,\n      178,\n      179,\n      180\n    ],\n    \"smooth_root_2d\": [\n      [\n        -0.024789854884147644,\n        0.01764228567481041\n      ],\n      [\n        -0.019911596551537514,\n        0.03666473180055618\n      ],\n      [\n        -0.015032900497317314,\n        0.05568705126643181\n      ],\n      [\n        -0.010153300128877163,\n        0.07470902800559998\n      ],\n      [\n        -0.005272198934108019,\n        0.09373034536838531\n      ],\n      [\n        -0.00038888092967681587,\n        0.11275061219930649\n      ],\n      [\n        0.004497467540204525,\n        0.1317693293094635\n      ],\n      [\n        0.009387745521962643,\n        0.15078598260879517\n      ],\n      [\n        0.014282921329140663,\n        0.16979998350143433\n      ],\n      [\n        0.019184017553925514,\n        0.18881070613861084\n      ],\n      [\n        0.024092093110084534,\n        0.20781749486923218\n      ],\n      [\n        0.029008235782384872,\n        0.226819708943367\n      ],\n      [\n        0.033933546394109726,\n        0.24581670761108398\n      ],\n      [\n        0.038869116455316544,\n        0.2648078203201294\n      ],\n      [\n        0.04381602630019188,\n        0.2837924659252167\n      ],\n      [\n        0.048775337636470795,\n        0.30277004837989807\n      ],\n      [\n        0.05374806746840477,\n        0.321740061044693\n      ],\n      [\n        0.058735184371471405,\n        0.3407020568847656\n      ],\n      [\n        0.06373759359121323,\n        0.35965561866760254\n      ],\n      [\n        0.06875615566968918,\n        0.37860047817230225\n      ],\n      [\n        0.07379162311553955,\n        0.3975364565849304\n      ],\n      [\n        0.07884468138217926,\n        0.4164634943008423\n      ],\n      [\n        0.08391592651605606,\n        0.43538162112236023\n      ],\n      [\n        
0.08900584280490875,\n        0.45429113507270813\n      ],\n      [\n        0.09411482512950897,\n        0.47319236397743225\n      ],\n      [\n        0.0992431491613388,\n        0.49208587408065796\n      ],\n      [\n        0.10439097136259079,\n        0.5109724998474121\n      ],\n      [\n        0.1095583438873291,\n        0.5298531651496887\n      ],\n      [\n        0.11474518477916718,\n        0.5487290620803833\n      ],\n      [\n        0.11995130032300949,\n        0.5676016807556152\n      ],\n      [\n        0.12517637014389038,\n        0.5864726901054382\n      ],\n      [\n        0.13041996955871582,\n        0.6053440570831299\n      ],\n      [\n        0.13568153977394104,\n        0.6242179274559021\n      ],\n      [\n        0.1409604400396347,\n        0.6430967450141907\n      ],\n      [\n        0.14625589549541473,\n        0.6619831919670105\n      ],\n      [\n        0.15156707167625427,\n        0.6808802485466003\n      ],\n      [\n        0.15689301490783691,\n        0.6997910141944885\n      ],\n      [\n        0.16223272681236267,\n        0.7187188267707825\n      ],\n      [\n        0.16759774088859558,\n        0.7376715540885925\n      ],\n      [\n        0.17303690314292908,\n        0.7566697001457214\n      ],\n      [\n        0.17862369120121002,\n        0.7757418751716614\n      ],\n      [\n        0.1844315379858017,\n        0.7949170470237732\n      ],\n      [\n        0.19053390622138977,\n        0.8142240643501282\n      ],\n      [\n        0.19700415432453156,\n        0.8336920142173767\n      ],\n      [\n        0.20391567051410675,\n        0.8533498644828796\n      ],\n      [\n        0.21134179830551147,\n        0.8732268214225769\n      ],\n      [\n        0.21935580670833588,\n        0.8933521509170532\n      ],\n      [\n        0.22803090512752533,\n        0.9137551784515381\n      ],\n      [\n        0.23744019865989685,\n        0.9344654083251953\n      ],\n      [\n       
 0.24765664339065552,\n        0.9555124640464783\n      ],\n      [\n        0.2587530016899109,\n        0.9769262671470642\n      ],\n      [\n        0.2708017826080322,\n        0.9987370371818542\n      ],\n      [\n        0.2838752567768097,\n        1.0209753513336182\n      ],\n      [\n        0.29804527759552,\n        1.0436722040176392\n      ],\n      [\n        0.3133833110332489,\n        1.0668591260910034\n      ],\n      [\n        0.32996034622192383,\n        1.0905684232711792\n      ],\n      [\n        0.3478468656539917,\n        1.1148326396942139\n      ],\n      [\n        0.36711281538009644,\n        1.1396855115890503\n      ],\n      [\n        0.3878275454044342,\n        1.1651611328125\n      ],\n      [\n        0.41000601649284363,\n        1.1912426948547363\n      ],\n      [\n        0.4336090087890625,\n        1.2178623676300049\n      ],\n      [\n        0.45859649777412415,\n        1.24495267868042\n      ],\n      [\n        0.4849279224872589,\n        1.272446632385254\n      ],\n      [\n        0.5125620365142822,\n        1.300277590751648\n      ],\n      [\n        0.5414570569992065,\n        1.3283785581588745\n      ],\n      [\n        0.571570634841919,\n        1.3566826581954956\n      ],\n      [\n        0.6028600931167603,\n        1.3851218223571777\n      ],\n      [\n        0.6352822780609131,\n        1.4136276245117188\n      ],\n      [\n        0.6687941551208496,\n        1.4421300888061523\n      ],\n      [\n        0.7033523917198181,\n        1.4705579280853271\n      ],\n      [\n        0.7389140725135803,\n        1.4988375902175903\n      ],\n      [\n        0.7754364013671875,\n        1.5268937349319458\n      ],\n      [\n        0.8128772974014282,\n        1.554648518562317\n      ],\n      [\n        0.8511953353881836,\n        1.5820214748382568\n      ],\n      [\n        0.8903500437736511,\n        1.6089295148849487\n      ],\n      [\n        0.930302083492279,\n        
1.6352869272232056\n      ],\n      [\n        0.9710133075714111,\n        1.6610050201416016\n      ],\n      [\n        1.0124471187591553,\n        1.685992956161499\n      ],\n      [\n        1.0545682907104492,\n        1.7101572751998901\n      ],\n      [\n        1.0973432064056396,\n        1.7334026098251343\n      ],\n      [\n        1.1407400369644165,\n        1.755631923675537\n      ],\n      [\n        1.1847283840179443,\n        1.7767466306686401\n      ],\n      [\n        1.229279637336731,\n        1.7966474294662476\n      ],\n      [\n        1.2743664979934692,\n        1.8152343034744263\n      ],\n      [\n        1.3199630975723267,\n        1.8324071168899536\n      ],\n      [\n        1.3660447597503662,\n        1.848065733909607\n      ],\n      [\n        1.4125876426696777,\n        1.8621103763580322\n      ],\n      [\n        1.4595685005187988,\n        1.8744415044784546\n      ],\n      [\n        1.5069485902786255,\n        1.8850340843200684\n      ],\n      [\n        1.5546728372573853,\n        1.8939374685287476\n      ],\n      [\n        1.6026861667633057,\n        1.9012004137039185\n      ],\n      [\n        1.650932788848877,\n        1.9068700075149536\n      ],\n      [\n        1.6993565559387207,\n        1.9109913110733032\n      ],\n      [\n        1.7479000091552734,\n        1.9136062860488892\n      ],\n      [\n        1.7965046167373657,\n        1.9147534370422363\n      ],\n      [\n        1.8451100587844849,\n        1.9144660234451294\n      ],\n      [\n        1.893654465675354,\n        1.9127724170684814\n      ],\n      [\n        1.942073941230774,\n        1.9096946716308594\n      ],\n      [\n        1.990302324295044,\n        1.9052486419677734\n      ],\n      [\n        2.03827166557312,\n        1.8994430303573608\n      ],\n      [\n        2.0859110355377197,\n        1.8922799825668335\n      ],\n      [\n        2.133148193359375,\n        1.8837546110153198\n      ],\n     
 [\n        2.179908037185669,\n        1.8738549947738647\n      ],\n      [\n        2.2261133193969727,\n        1.862563133239746\n      ],\n      [\n        2.27168607711792,\n        1.8498553037643433\n      ],\n      [\n        2.316545248031616,\n        1.8357020616531372\n      ],\n      [\n        2.360609769821167,\n        1.8200697898864746\n      ],\n      [\n        2.403796911239624,\n        1.8029208183288574\n      ],\n      [\n        2.44602370262146,\n        1.7842146158218384\n      ],\n      [\n        2.4872069358825684,\n        1.7639081478118896\n      ],\n      [\n        2.5272626876831055,\n        1.7419570684432983\n      ],\n      [\n        2.566108465194702,\n        1.7183157205581665\n      ],\n      [\n        2.603734254837036,\n        1.693010687828064\n      ],\n      [\n        2.640204906463623,\n        1.6661417484283447\n      ],\n      [\n        2.6755847930908203,\n        1.6378077268600464\n      ],\n      [\n        2.7099392414093018,\n        1.6081076860427856\n      ],\n      [\n        2.743333101272583,\n        1.5771397352218628\n      ],\n      [\n        2.7758309841156006,\n        1.5450016260147095\n      ],\n      [\n        2.80749773979187,\n        1.5117899179458618\n      ],\n      [\n        2.8383967876434326,\n        1.477600336074829\n      ],\n      [\n        2.868591785430908,\n        1.4425268173217773\n      ],\n      [\n        2.8981447219848633,\n        1.4066622257232666\n      ],\n      [\n        2.9271178245544434,\n        1.3700973987579346\n      ],\n      [\n        2.9555718898773193,\n        1.3329222202301025\n      ],\n      [\n        2.983566999435425,\n        1.2952247858047485\n      ],\n      [\n        3.011162757873535,\n        1.2570923566818237\n      ],\n      [\n        3.038418769836426,\n        1.2186110019683838\n      ],\n      [\n        3.0653929710388184,\n        1.1798664331436157\n      ],\n      [\n        3.092144250869751,\n        
1.1409443616867065\n      ],\n      [\n        3.118730306625366,\n        1.1019304990768433\n      ],\n      [\n        3.1451311111450195,\n        1.062860131263733\n      ],\n      [\n        3.171248197555542,\n        1.0237183570861816\n      ],\n      [\n        3.1969823837280273,\n        0.9844915866851807\n      ],\n      [\n        3.222233295440674,\n        0.945167064666748\n      ],\n      [\n        3.246898889541626,\n        0.905733585357666\n      ],\n      [\n        3.270875930786133,\n        0.8661811947822571\n      ],\n      [\n        3.294057846069336,\n        0.826501190662384\n      ],\n      [\n        3.3163373470306396,\n        0.7866860032081604\n      ],\n      [\n        3.3376033306121826,\n        0.7467291951179504\n      ],\n      [\n        3.357743263244629,\n        0.7066251039505005\n      ],\n      [\n        3.3766419887542725,\n        0.6663689613342285\n      ],\n      [\n        3.394181966781616,\n        0.6259563565254211\n      ],\n      [\n        3.4102442264556885,\n        0.5853835344314575\n      ],\n      [\n        3.424708127975464,\n        0.5446467995643616\n      ],\n      [\n        3.4374516010284424,\n        0.5037427544593811\n      ],\n      [\n        3.448352098464966,\n        0.46266797184944153\n      ],\n      [\n        3.457287073135376,\n        0.42141908407211304\n      ],\n      [\n        3.4641330242156982,\n        0.3799927234649658\n      ],\n      [\n        3.468876838684082,\n        0.33839157223701477\n      ],\n      [\n        3.471616506576538,\n        0.2966245114803314\n      ],\n      [\n        3.4724483489990234,\n        0.2547004222869873\n      ],\n      [\n        3.4714694023132324,\n        0.21262840926647186\n      ],\n      [\n        3.4687745571136475,\n        0.17041781544685364\n      ],\n      [\n        3.4644577503204346,\n        0.1280783712863922\n      ],\n      [\n        3.4586100578308105,\n        0.0856202244758606\n      ],\n      
[\n        3.4513206481933594,\n        0.043054141104221344\n      ],\n      [\n        3.442674398422241,\n        0.0003915314737241715\n      ],\n      [\n        3.432753562927246,\n        -0.04235544055700302\n      ],\n      [\n        3.421635389328003,\n        -0.08517380803823471\n      ],\n      [\n        3.409393072128296,\n        -0.12804976105690002\n      ],\n      [\n        3.3960955142974854,\n        -0.17096871137619019\n      ],\n      [\n        3.3818066120147705,\n        -0.21391519904136658\n      ],\n      [\n        3.366586685180664,\n        -0.25687310099601746\n      ],\n      [\n        3.3504908084869385,\n        -0.29982560873031616\n      ],\n      [\n        3.333570718765259,\n        -0.34275543689727783\n      ],\n      [\n        3.315875291824341,\n        -0.3856448531150818\n      ],\n      [\n        3.297449827194214,\n        -0.42847591638565063\n      ],\n      [\n        3.278337240219116,\n        -0.47123050689697266\n      ],\n      [\n        3.2585792541503906,\n        -0.5138905048370361\n      ],\n      [\n        3.238215923309326,\n        -0.5564379692077637\n      ],\n      [\n        3.217292308807373,\n        -0.5988707542419434\n      ],\n      [\n        3.1958582401275635,\n        -0.6412028074264526\n      ],\n      [\n        3.1739635467529297,\n        -0.6834480166435242\n      ],\n      [\n        3.1516590118408203,\n        -0.7256200909614563\n      ],\n      [\n        3.1289961338043213,\n        -0.7677323818206787\n      ],\n      [\n        3.1060280799865723,\n        -0.8097975850105286\n      ],\n      [\n        3.082807779312134,\n        -0.8518276214599609\n      ],\n      [\n        3.0593905448913574,\n        -0.8938331604003906\n      ],\n      [\n        3.0358314514160156,\n        -0.9358235001564026\n      ],\n      [\n        3.0062689781188965,\n        -0.9883013367652893\n      ],\n      [\n        2.9885144233703613,\n        -1.0197867155075073\n      ]\n    
]\n  }\n]\n"
  },
  {
    "path": "kimodo/assets/demo/examples/kimodo-g1-rp/05_root_path/meta.json",
    "content": "{\n  \"text\": \"Initially standing still and calm, the person then starts jogging in a counterclockwise arc.\",\n  \"duration\": 6.033333333333333,\n  \"num_samples\": 1,\n  \"seed\": 62,\n  \"diffusion_steps\": 100,\n  \"cfg\": {\n    \"enabled\": true,\n    \"text_weight\": 2.0,\n    \"constraint_weight\": 2.0\n  }\n}\n"
  },
  {
    "path": "kimodo/assets/demo/examples/kimodo-g1-rp/06_root_waypoints/constraints.json",
    "content": "[\n  {\n    \"type\": \"root2d\",\n    \"frame_indices\": [\n      0,\n      87,\n      169,\n      240\n    ],\n    \"smooth_root_2d\": [\n      [\n        0.037946805357933044,\n        -0.036908961832523346\n      ],\n      [\n        2.2506563663482666,\n        0.06945009529590607\n      ],\n      [\n        2.23332142829895,\n        -2.0749685764312744\n      ],\n      [\n        4.0815324783325195,\n        -2.273184061050415\n      ]\n    ]\n  }\n]\n"
  },
  {
    "path": "kimodo/assets/demo/examples/kimodo-g1-rp/06_root_waypoints/meta.json",
    "content": "{\n  \"text\": \"A person is walking while carrying a small object in their left hand\",\n  \"duration\": 8.033333333333333,\n  \"num_samples\": 1,\n  \"seed\": 61,\n  \"diffusion_steps\": 100,\n  \"cfg\": {\n    \"enabled\": true,\n    \"text_weight\": 2.0,\n    \"constraint_weight\": 2.0\n  }\n}\n"
  },
  {
    "path": "kimodo/assets/demo/examples/kimodo-g1-rp/07_text_terrain/meta.json",
    "content": "{\n  \"text\": \"A person begins walking up the stairs\",\n  \"duration\": 3.5,\n  \"num_samples\": 1,\n  \"seed\": 44,\n  \"diffusion_steps\": 100,\n  \"cfg\": {\n    \"enabled\": true,\n    \"text_weight\": 2.0,\n    \"constraint_weight\": 2.0\n  }\n}\n"
  },
  {
    "path": "kimodo/assets/demo/examples/kimodo-g1-rp/08_text_object/meta.json",
    "content": "{\n  \"text\": \"A person picks up an object from low on their left side and places it up high\",\n  \"duration\": 5.033333333333333,\n  \"num_samples\": 1,\n  \"seed\": 47,\n  \"diffusion_steps\": 100,\n  \"cfg\": {\n    \"enabled\": true,\n    \"text_weight\": 2.0,\n    \"constraint_weight\": 2.0\n  }\n}\n"
  },
  {
    "path": "kimodo/assets/demo/examples/kimodo-soma-rp/01_single_text_prompt/meta.json",
    "content": "{\n  \"text\": \"A person runs forward and then leaps over an obstacle in front of them.\",\n  \"duration\": 5.0,\n  \"num_samples\": 1,\n  \"seed\": 42,\n  \"diffusion_steps\": 100,\n  \"cfg\": {\n    \"enabled\": true,\n    \"text_weight\": 2.0,\n    \"constraint_weight\": 2.0\n  }\n}\n"
  },
  {
    "path": "kimodo/assets/demo/examples/kimodo-soma-rp/02_multi_text_prompt/meta.json",
    "content": "{\n  \"texts\": [\n    \"A person is walking forward casually.\",\n    \"A person turns to the right and starts sneakily moving forward\"\n  ],\n  \"durations\": [\n    3.533333333333333,\n    4.033333333333333\n  ],\n  \"num_samples\": 1,\n  \"seed\": 42,\n  \"diffusion_steps\": 100,\n  \"cfg\": {\n    \"enabled\": true,\n    \"text_weight\": 2.0,\n    \"constraint_weight\": 2.0\n  }\n}\n"
  },
  {
    "path": "kimodo/assets/demo/examples/kimodo-soma-rp/03_full_body_keyframes/constraints.json",
    "content": "[\n  {\n    \"type\": \"fullbody\",\n    \"frame_indices\": [\n      79,\n      134\n    ],\n    \"local_joints_rot\": [\n      [\n        [\n          0.2765098512172699,\n          0.3728594183921814,\n          -0.3292054831981659\n        ],\n        [\n          0.35604047775268555,\n          0.018222831189632416,\n          -0.054862238466739655\n        ],\n        [\n          0.12065527588129044,\n          -0.027457308024168015,\n          -0.06907646358013153\n        ],\n        [\n          0.6048485636711121,\n          -0.11472737789154053,\n          -0.19573566317558289\n        ],\n        [\n          -0.12398597598075867,\n          0.03840772435069084,\n          0.18822282552719116\n        ],\n        [\n          -0.06553511321544647,\n          0.13032270967960358,\n          0.04257704317569733\n        ],\n        [\n          -0.24969959259033203,\n          0.06990747153759003,\n          0.13426002860069275\n        ],\n        [\n          -0.002762501360848546,\n          0.0010064352536574006,\n          -0.0012083332985639572\n        ],\n        [\n          -0.18770116567611694,\n          -0.06528781354427338,\n          0.006136383395642042\n        ],\n        [\n          -0.18933561444282532,\n          0.06753389537334442,\n          -0.00862747710198164\n        ],\n        [\n          0.1765439361333847,\n          -0.5079103708267212,\n          0.11742556095123291\n        ],\n        [\n          -0.6833809614181519,\n          -0.36341744661331177,\n          -0.09875624626874924\n        ],\n        [\n          -0.004083660896867514,\n          -0.2955799102783203,\n          0.007416445296257734\n        ],\n        [\n          -0.46948903799057007,\n          0.0019703502766788006,\n          0.2218078076839447\n        ],\n        [\n          0.15589098632335663,\n          0.29247695207595825,\n          -0.2839103043079376\n        ],\n        [\n          -0.006183772347867489,\n          
0.039787642657756805,\n          -1.0509610176086426\n        ],\n        [\n          0.28110796213150024,\n          -0.01673225313425064,\n          0.05465283617377281\n        ],\n        [\n          0.4582408368587494,\n          0.6058111786842346,\n          1.040449619293213\n        ],\n        [\n          -0.016165010631084442,\n          0.7843144536018372,\n          0.007565980777144432\n        ],\n        [\n          -0.21160456538200378,\n          0.009858175180852413,\n          0.022257711738348007\n        ],\n        [\n          0.08559019863605499,\n          -0.26941442489624023,\n          0.28404051065444946\n        ],\n        [\n          -0.0722564086318016,\n          -0.055347055196762085,\n          0.8767912983894348\n        ],\n        [\n          -0.9036330580711365,\n          -0.19308030605316162,\n          0.6912829875946045\n        ],\n        [\n          1.7018375396728516,\n          -0.052370231598615646,\n          0.0016176343197003007\n        ],\n        [\n          -0.6713079810142517,\n          -0.22423480451107025,\n          -0.17199599742889404\n        ],\n        [\n          -0.2397085577249527,\n          -0.04111046716570854,\n          0.02976534143090248\n        ],\n        [\n          -1.4084941148757935,\n          -0.42399686574935913,\n          0.23780424892902374\n        ],\n        [\n          1.488803744316101,\n          -0.006882219575345516,\n          0.005796314682811499\n        ],\n        [\n          -0.34890878200531006,\n          0.25402817130088806,\n          -0.10165958851575851\n        ],\n        [\n          -0.017090337350964546,\n          0.013983047567307949,\n          -0.02469288557767868\n        ]\n      ],\n      [\n        [\n          -0.10219376534223557,\n          0.15241079032421112,\n          -0.1140606626868248\n        ],\n        [\n          -0.07097288966178894,\n          -0.023205779492855072,\n          0.014893154613673687\n        ],\n     
   [\n          -0.11436910182237625,\n          -0.07182353734970093,\n          -0.024793410673737526\n        ],\n        [\n          0.32571300864219666,\n          -0.11312247067689896,\n          -0.017911700531840324\n        ],\n        [\n          0.036515623331069946,\n          -0.0007576555362902582,\n          0.14029929041862488\n        ],\n        [\n          -0.06553909182548523,\n          0.07225329428911209,\n          0.0065536051988601685\n        ],\n        [\n          -0.09946814924478531,\n          0.02283940091729164,\n          0.060293473303318024\n        ],\n        [\n          -0.0007363191107288003,\n          0.0019088855478912592,\n          0.00034123589284718037\n        ],\n        [\n          -0.18651022017002106,\n          -0.06423485279083252,\n          0.0069741918705403805\n        ],\n        [\n          -0.18586836755275726,\n          0.06800899654626846,\n          -0.0060585117898881435\n        ],\n        [\n          0.23363706469535828,\n          -0.20687633752822876,\n          -0.07240967452526093\n        ],\n        [\n          -0.3135974407196045,\n          -0.2623864710330963,\n          -1.0657873153686523\n        ],\n        [\n          -0.012310811318457127,\n          -1.6650079488754272,\n          -0.010509567335247993\n        ],\n        [\n          -0.8171713352203369,\n          -0.2551392912864685,\n          0.08705981075763702\n        ],\n        [\n          0.13723036646842957,\n          0.2864063084125519,\n          -0.2900709807872772\n        ],\n        [\n          -0.005930017679929733,\n          0.05293968319892883,\n          -1.0459250211715698\n        ],\n        [\n          0.24218180775642395,\n          0.02018338069319725,\n          0.1226770281791687\n        ],\n        [\n          0.3315959572792053,\n          0.3782292902469635,\n          1.2296319007873535\n        ],\n        [\n          -0.0014527677558362484,\n          0.3045952022075653,\n     
     -0.0014049106976017356\n        ],\n        [\n          -0.20010970532894135,\n          -0.07485076785087585,\n          0.0041703470051288605\n        ],\n        [\n          0.08470325917005539,\n          -0.3079097270965576,\n          0.29375413060188293\n        ],\n        [\n          -0.09725581854581833,\n          -0.055068179965019226,\n          0.8742175698280334\n        ],\n        [\n          0.4040503203868866,\n          -0.016711091622710228,\n          0.21672509610652924\n        ],\n        [\n          0.5082376599311829,\n          -0.013459251262247562,\n          0.004872385878115892\n        ],\n        [\n          0.1745426058769226,\n          -0.24501416087150574,\n          -0.003703102469444275\n        ],\n        [\n          -0.33402949571609497,\n          -0.035541169345378876,\n          0.032360970973968506\n        ],\n        [\n          -0.37681734561920166,\n          0.02067263424396515,\n          0.10783999413251877\n        ],\n        [\n          0.4257254898548126,\n          0.0016118268249556422,\n          0.0033562832977622747\n        ],\n        [\n          0.04139057174324989,\n          0.032555095851421356,\n          0.04008425772190094\n        ],\n        [\n          -0.03090120106935501,\n          0.01570875011384487,\n          -0.024774780496954918\n        ]\n      ]\n    ],\n    \"root_positions\": [\n      [\n        -0.18697306513786316,\n        0.7126776576042175,\n        1.1559109687805176\n      ],\n      [\n        -0.014062155969440937,\n        0.9611971974372864,\n        2.898127555847168\n      ]\n    ],\n    \"smooth_root_2d\": [\n      [\n        -0.18697306513786316,\n        1.1559109687805176\n      ],\n      [\n        -0.014062155969440937,\n        2.898127555847168\n      ]\n    ]\n  }\n]\n"
  },
  {
    "path": "kimodo/assets/demo/examples/kimodo-soma-rp/03_full_body_keyframes/meta.json",
    "content": "{\n  \"text\": \"A person walks forward and picks something up from the ground\",\n  \"duration\": 5.0,\n  \"num_samples\": 1,\n  \"seed\": 43,\n  \"diffusion_steps\": 100,\n  \"cfg\": {\n    \"enabled\": true,\n    \"text_weight\": 2.0,\n    \"constraint_weight\": 2.0\n  }\n}\n"
  },
  {
    "path": "kimodo/assets/demo/examples/kimodo-soma-rp/04_ee_constraint/constraints.json",
    "content": "[\n  {\n    \"type\": \"right-foot\",\n    \"frame_indices\": [\n      28,\n      94\n    ],\n    \"local_joints_rot\": [\n      [\n        [\n          0.14788010716438293,\n          -0.010833931155502796,\n          -0.01388303842395544\n        ],\n        [\n          -0.03901153802871704,\n          0.0003969503741245717,\n          -0.00016447225061710924\n        ],\n        [\n          -0.09507032483816147,\n          0.008639314211905003,\n          -0.0073561337776482105\n        ],\n        [\n          0.21237806975841522,\n          -0.02139095962047577,\n          -0.01700877584517002\n        ],\n        [\n          -0.20991119742393494,\n          0.06551700085401535,\n          -0.05272415280342102\n        ],\n        [\n          -0.06337061524391174,\n          0.05204080045223236,\n          0.014292852953076363\n        ],\n        [\n          0.07047945261001587,\n          0.08330309391021729,\n          -0.002013514516875148\n        ],\n        [\n          -0.0019600456580519676,\n          -0.0013381227618083358,\n          -2.7628393581835553e-06\n        ],\n        [\n          -0.18709787726402283,\n          -0.06659803539514542,\n          0.0078862514346838\n        ],\n        [\n          -0.18698126077651978,\n          0.06395528465509415,\n          -0.008215037174522877\n        ],\n        [\n          0.08230585604906082,\n          -0.38376951217651367,\n          0.05542140454053879\n        ],\n        [\n          -0.7260366082191467,\n          -0.24878422915935516,\n          -0.35609468817710876\n        ],\n        [\n          0.004249485209584236,\n          -0.4476320147514343,\n          -0.018469776958227158\n        ],\n        [\n          -0.9212101697921753,\n          -0.1470143049955368,\n          0.5044775605201721\n        ],\n        [\n          0.14870156347751617,\n          0.2985619604587555,\n          -0.29298385977745056\n        ],\n        [\n          
0.001955621177330613,\n          0.055549487471580505,\n          -1.0630463361740112\n        ],\n        [\n          0.11859050393104553,\n          0.46535199880599976,\n          -0.030845582485198975\n        ],\n        [\n          -0.7298654317855835,\n          0.5346517562866211,\n          0.2791443467140198\n        ],\n        [\n          0.008972911164164543,\n          0.48752307891845703,\n          0.01847967691719532\n        ],\n        [\n          -0.5805565118789673,\n          -0.08708631247282028,\n          -0.15088550746440887\n        ],\n        [\n          0.08582834899425507,\n          -0.2886488735675812,\n          0.2854447066783905\n        ],\n        [\n          -0.0898093581199646,\n          -0.05874425172805786,\n          0.8657776117324829\n        ],\n        [\n          -0.3135877549648285,\n          0.07464626431465149,\n          0.0517989918589592\n        ],\n        [\n          0.29447537660598755,\n          -0.003720453940331936,\n          0.0011728419922292233\n        ],\n        [\n          -0.12890003621578217,\n          0.0839272066950798,\n          -0.090343177318573\n        ],\n        [\n          0.008360159583389759,\n          -0.03457032889127731,\n          0.02827553078532219\n        ],\n        [\n          -0.3120643198490143,\n          -0.01133657619357109,\n          -0.03218594938516617\n        ],\n        [\n          0.2538771331310272,\n          0.0018040596041828394,\n          0.0009352069464512169\n        ],\n        [\n          -0.0887608677148819,\n          -0.03465384244918823,\n          0.07154331356287003\n        ],\n        [\n          0.01681467890739441,\n          0.01778421923518181,\n          -0.025033073499798775\n        ]\n      ],\n      [\n        [\n          0.21243979036808014,\n          1.0922467708587646,\n          -0.05739659443497658\n        ],\n        [\n          -0.04288899898529053,\n          0.019888481125235558,\n          
-0.014078406617045403\n        ],\n        [\n          -0.09594971686601639,\n          0.10335114598274231,\n          -0.007776615675538778\n        ],\n        [\n          0.2422163188457489,\n          0.08445896953344345,\n          -0.05605608597397804\n        ],\n        [\n          -0.14986605942249298,\n          0.10279522091150284,\n          -0.19410337507724762\n        ],\n        [\n          -0.07278254628181458,\n          0.00021229058620519936,\n          -0.0064666238613426685\n        ],\n        [\n          -0.18101167678833008,\n          -0.047196485102176666,\n          0.09371022135019302\n        ],\n        [\n          -0.0013136633206158876,\n          -0.0020103836432099342,\n          -0.0002618256548885256\n        ],\n        [\n          -0.1867513209581375,\n          -0.0681525468826294,\n          0.0023792991414666176\n        ],\n        [\n          -0.18714284896850586,\n          0.06443598866462708,\n          -0.003183535533025861\n        ],\n        [\n          0.1040755957365036,\n          -0.1164601668715477,\n          -0.08953910320997238\n        ],\n        [\n          -0.7818892598152161,\n          -0.40082883834838867,\n          -0.40901198983192444\n        ],\n        [\n          0.0014971806667745113,\n          -0.7006690502166748,\n          -0.003588718129321933\n        ],\n        [\n          -0.7653300762176514,\n          -0.030549153685569763,\n          0.5779297947883606\n        ],\n        [\n          0.1444747895002365,\n          0.30648332834243774,\n          -0.2944350242614746\n        ],\n        [\n          0.00627485616132617,\n          0.05844533443450928,\n          -1.0504485368728638\n        ],\n        [\n          0.16790169477462769,\n          0.6803913116455078,\n          -0.0802350640296936\n        ],\n        [\n          -0.7650246620178223,\n          0.2571314871311188,\n          0.044474273920059204\n        ],\n        [\n          
0.00177879654802382,\n          0.32478848099708557,\n          0.024663111194968224\n        ],\n        [\n          -1.1130585670471191,\n          0.06198093295097351,\n          -0.1499929279088974\n        ],\n        [\n          0.09419120848178864,\n          -0.28672322630882263,\n          0.2861841320991516\n        ],\n        [\n          -0.08110660314559937,\n          -0.06315471976995468,\n          0.8641197085380554\n        ],\n        [\n          -0.4702282249927521,\n          -0.2976788580417633,\n          -0.08966172486543655\n        ],\n        [\n          0.2188275307416916,\n          -0.010813144035637379,\n          -0.0024994502309709787\n        ],\n        [\n          0.12644176185131073,\n          -0.4933742582798004,\n          -0.23269610106945038\n        ],\n        [\n          -0.05216464772820473,\n          -0.03182952478528023,\n          0.026469329372048378\n        ],\n        [\n          -0.21055173873901367,\n          -0.5854666233062744,\n          -0.08316371589899063\n        ],\n        [\n          0.2703852653503418,\n          -0.0070351893082261086,\n          0.00034556735772639513\n        ],\n        [\n          -0.20080512762069702,\n          -0.5529999136924744,\n          0.08794122189283371\n        ],\n        [\n          -0.020619722083210945,\n          0.01961597241461277,\n          -0.02498687617480755\n        ]\n      ]\n    ],\n    \"root_positions\": [\n      [\n        0.006224155426025391,\n        1.0099574327468872,\n        0.0004121592501178384\n      ],\n      [\n        0.025673866271972656,\n        1.0039517879486084,\n        0.0002174415858462453\n      ]\n    ],\n    \"smooth_root_2d\": [\n      [\n        0.006224155426025391,\n        0.0004121592501178384\n      ],\n      [\n        0.025673866271972656,\n        0.0002174415858462453\n      ]\n    ]\n  },\n  {\n    \"type\": \"left-foot\",\n    \"frame_indices\": [\n      28,\n      94\n    ],\n    
\"local_joints_rot\": [\n      [\n        [\n          0.14788010716438293,\n          -0.010833931155502796,\n          -0.01388303842395544\n        ],\n        [\n          -0.03901153802871704,\n          0.0003969503741245717,\n          -0.00016447225061710924\n        ],\n        [\n          -0.09507032483816147,\n          0.008639314211905003,\n          -0.0073561337776482105\n        ],\n        [\n          0.21237806975841522,\n          -0.02139095962047577,\n          -0.01700877584517002\n        ],\n        [\n          -0.20991119742393494,\n          0.06551700085401535,\n          -0.05272415280342102\n        ],\n        [\n          -0.06337061524391174,\n          0.05204080045223236,\n          0.014292852953076363\n        ],\n        [\n          0.07047945261001587,\n          0.08330309391021729,\n          -0.002013514516875148\n        ],\n        [\n          -0.0019600456580519676,\n          -0.0013381227618083358,\n          -2.7628393581835553e-06\n        ],\n        [\n          -0.18709787726402283,\n          -0.06659803539514542,\n          0.0078862514346838\n        ],\n        [\n          -0.18698126077651978,\n          0.06395528465509415,\n          -0.008215037174522877\n        ],\n        [\n          0.08230585604906082,\n          -0.38376951217651367,\n          0.05542140454053879\n        ],\n        [\n          -0.7260366082191467,\n          -0.24878422915935516,\n          -0.35609468817710876\n        ],\n        [\n          0.004249485209584236,\n          -0.4476320147514343,\n          -0.018469776958227158\n        ],\n        [\n          -0.9212101697921753,\n          -0.1470143049955368,\n          0.5044775605201721\n        ],\n        [\n          0.14870156347751617,\n          0.2985619604587555,\n          -0.29298385977745056\n        ],\n        [\n          0.001955621177330613,\n          0.055549487471580505,\n          -1.0630463361740112\n        ],\n        [\n          
0.11859050393104553,\n          0.46535199880599976,\n          -0.030845582485198975\n        ],\n        [\n          -0.7298654317855835,\n          0.5346517562866211,\n          0.2791443467140198\n        ],\n        [\n          0.008972911164164543,\n          0.48752307891845703,\n          0.01847967691719532\n        ],\n        [\n          -0.5805565118789673,\n          -0.08708631247282028,\n          -0.15088550746440887\n        ],\n        [\n          0.08582834899425507,\n          -0.2886488735675812,\n          0.2854447066783905\n        ],\n        [\n          -0.0898093581199646,\n          -0.05874425172805786,\n          0.8657776117324829\n        ],\n        [\n          -0.3135877549648285,\n          0.07464626431465149,\n          0.0517989918589592\n        ],\n        [\n          0.29447537660598755,\n          -0.003720453940331936,\n          0.0011728419922292233\n        ],\n        [\n          -0.12890003621578217,\n          0.0839272066950798,\n          -0.090343177318573\n        ],\n        [\n          0.008360159583389759,\n          -0.03457032889127731,\n          0.02827553078532219\n        ],\n        [\n          -0.3120643198490143,\n          -0.01133657619357109,\n          -0.03218594938516617\n        ],\n        [\n          0.2538771331310272,\n          0.0018040596041828394,\n          0.0009352069464512169\n        ],\n        [\n          -0.0887608677148819,\n          -0.03465384244918823,\n          0.07154331356287003\n        ],\n        [\n          0.01681467890739441,\n          0.01778421923518181,\n          -0.025033073499798775\n        ]\n      ],\n      [\n        [\n          0.21243979036808014,\n          1.0922467708587646,\n          -0.05739659443497658\n        ],\n        [\n          -0.04288899898529053,\n          0.019888481125235558,\n          -0.014078406617045403\n        ],\n        [\n          -0.09594971686601639,\n          0.10335114598274231,\n          
-0.007776615675538778\n        ],\n        [\n          0.2422163188457489,\n          0.08445896953344345,\n          -0.05605608597397804\n        ],\n        [\n          -0.14986605942249298,\n          0.10279522091150284,\n          -0.19410337507724762\n        ],\n        [\n          -0.07278254628181458,\n          0.00021229058620519936,\n          -0.0064666238613426685\n        ],\n        [\n          -0.18101167678833008,\n          -0.047196485102176666,\n          0.09371022135019302\n        ],\n        [\n          -0.0013136633206158876,\n          -0.0020103836432099342,\n          -0.0002618256548885256\n        ],\n        [\n          -0.1867513209581375,\n          -0.0681525468826294,\n          0.0023792991414666176\n        ],\n        [\n          -0.18714284896850586,\n          0.06443598866462708,\n          -0.003183535533025861\n        ],\n        [\n          0.1040755957365036,\n          -0.1164601668715477,\n          -0.08953910320997238\n        ],\n        [\n          -0.7818892598152161,\n          -0.40082883834838867,\n          -0.40901198983192444\n        ],\n        [\n          0.0014971806667745113,\n          -0.7006690502166748,\n          -0.003588718129321933\n        ],\n        [\n          -0.7653300762176514,\n          -0.030549153685569763,\n          0.5779297947883606\n        ],\n        [\n          0.1444747895002365,\n          0.30648332834243774,\n          -0.2944350242614746\n        ],\n        [\n          0.00627485616132617,\n          0.05844533443450928,\n          -1.0504485368728638\n        ],\n        [\n          0.16790169477462769,\n          0.6803913116455078,\n          -0.0802350640296936\n        ],\n        [\n          -0.7650246620178223,\n          0.2571314871311188,\n          0.044474273920059204\n        ],\n        [\n          0.00177879654802382,\n          0.32478848099708557,\n          0.024663111194968224\n        ],\n        [\n          -1.1130585670471191,\n  
        0.06198093295097351,\n          -0.1499929279088974\n        ],\n        [\n          0.09419120848178864,\n          -0.28672322630882263,\n          0.2861841320991516\n        ],\n        [\n          -0.08110660314559937,\n          -0.06315471976995468,\n          0.8641197085380554\n        ],\n        [\n          -0.4702282249927521,\n          -0.2976788580417633,\n          -0.08966172486543655\n        ],\n        [\n          0.2188275307416916,\n          -0.010813144035637379,\n          -0.0024994502309709787\n        ],\n        [\n          0.12644176185131073,\n          -0.4933742582798004,\n          -0.23269610106945038\n        ],\n        [\n          -0.05216464772820473,\n          -0.03182952478528023,\n          0.026469329372048378\n        ],\n        [\n          -0.21055173873901367,\n          -0.5854666233062744,\n          -0.08316371589899063\n        ],\n        [\n          0.2703852653503418,\n          -0.0070351893082261086,\n          0.00034556735772639513\n        ],\n        [\n          -0.20080512762069702,\n          -0.5529999136924744,\n          0.08794122189283371\n        ],\n        [\n          -0.020619722083210945,\n          0.01961597241461277,\n          -0.02498687617480755\n        ]\n      ]\n    ],\n    \"root_positions\": [\n      [\n        0.006224155426025391,\n        1.0099574327468872,\n        0.0004121592501178384\n      ],\n      [\n        0.025673866271972656,\n        1.0039517879486084,\n        0.0002174415858462453\n      ]\n    ],\n    \"smooth_root_2d\": [\n      [\n        0.006224155426025391,\n        0.0004121592501178384\n      ],\n      [\n        0.025673866271972656,\n        0.0002174415858462453\n      ]\n    ]\n  },\n  {\n    \"type\": \"left-hand\",\n    \"frame_indices\": [\n      28,\n      94\n    ],\n    \"local_joints_rot\": [\n      [\n        [\n          0.14788010716438293,\n          -0.010833931155502796,\n          -0.01388303842395544\n        ],\n      
  [\n          -0.03901153802871704,\n          0.0003969503741245717,\n          -0.00016447225061710924\n        ],\n        [\n          -0.09507032483816147,\n          0.008639314211905003,\n          -0.0073561337776482105\n        ],\n        [\n          0.21237806975841522,\n          -0.02139095962047577,\n          -0.01700877584517002\n        ],\n        [\n          -0.20991119742393494,\n          0.06551700085401535,\n          -0.05272415280342102\n        ],\n        [\n          -0.06337061524391174,\n          0.05204080045223236,\n          0.014292852953076363\n        ],\n        [\n          0.07047945261001587,\n          0.08330309391021729,\n          -0.002013514516875148\n        ],\n        [\n          -0.0019600456580519676,\n          -0.0013381227618083358,\n          -2.7628393581835553e-06\n        ],\n        [\n          -0.18709787726402283,\n          -0.06659803539514542,\n          0.0078862514346838\n        ],\n        [\n          -0.18698126077651978,\n          0.06395528465509415,\n          -0.008215037174522877\n        ],\n        [\n          0.08230585604906082,\n          -0.38376951217651367,\n          0.05542140454053879\n        ],\n        [\n          -0.7260366082191467,\n          -0.24878422915935516,\n          -0.35609468817710876\n        ],\n        [\n          0.004249485209584236,\n          -0.4476320147514343,\n          -0.018469776958227158\n        ],\n        [\n          -0.9212101697921753,\n          -0.1470143049955368,\n          0.5044775605201721\n        ],\n        [\n          0.14870156347751617,\n          0.2985619604587555,\n          -0.29298385977745056\n        ],\n        [\n          0.001955621177330613,\n          0.055549487471580505,\n          -1.0630463361740112\n        ],\n        [\n          0.11859050393104553,\n          0.46535199880599976,\n          -0.030845582485198975\n        ],\n        [\n          -0.7298654317855835,\n          0.5346517562866211,\n 
         0.2791443467140198\n        ],\n        [\n          0.008972911164164543,\n          0.48752307891845703,\n          0.01847967691719532\n        ],\n        [\n          -0.5805565118789673,\n          -0.08708631247282028,\n          -0.15088550746440887\n        ],\n        [\n          0.08582834899425507,\n          -0.2886488735675812,\n          0.2854447066783905\n        ],\n        [\n          -0.0898093581199646,\n          -0.05874425172805786,\n          0.8657776117324829\n        ],\n        [\n          -0.3135877549648285,\n          0.07464626431465149,\n          0.0517989918589592\n        ],\n        [\n          0.29447537660598755,\n          -0.003720453940331936,\n          0.0011728419922292233\n        ],\n        [\n          -0.12890003621578217,\n          0.0839272066950798,\n          -0.090343177318573\n        ],\n        [\n          0.008360159583389759,\n          -0.03457032889127731,\n          0.02827553078532219\n        ],\n        [\n          -0.3120643198490143,\n          -0.01133657619357109,\n          -0.03218594938516617\n        ],\n        [\n          0.2538771331310272,\n          0.0018040596041828394,\n          0.0009352069464512169\n        ],\n        [\n          -0.0887608677148819,\n          -0.03465384244918823,\n          0.07154331356287003\n        ],\n        [\n          0.01681467890739441,\n          0.01778421923518181,\n          -0.025033073499798775\n        ]\n      ],\n      [\n        [\n          0.21243979036808014,\n          1.0922467708587646,\n          -0.05739659443497658\n        ],\n        [\n          -0.04288899898529053,\n          0.019888481125235558,\n          -0.014078406617045403\n        ],\n        [\n          -0.09594971686601639,\n          0.10335114598274231,\n          -0.007776615675538778\n        ],\n        [\n          0.2422163188457489,\n          0.08445896953344345,\n          -0.05605608597397804\n        ],\n        [\n          
-0.14986605942249298,\n          0.10279522091150284,\n          -0.19410337507724762\n        ],\n        [\n          -0.07278254628181458,\n          0.00021229058620519936,\n          -0.0064666238613426685\n        ],\n        [\n          -0.18101167678833008,\n          -0.047196485102176666,\n          0.09371022135019302\n        ],\n        [\n          -0.0013136633206158876,\n          -0.0020103836432099342,\n          -0.0002618256548885256\n        ],\n        [\n          -0.1867513209581375,\n          -0.0681525468826294,\n          0.0023792991414666176\n        ],\n        [\n          -0.18714284896850586,\n          0.06443598866462708,\n          -0.003183535533025861\n        ],\n        [\n          0.1040755957365036,\n          -0.1164601668715477,\n          -0.08953910320997238\n        ],\n        [\n          -0.7818892598152161,\n          -0.40082883834838867,\n          -0.40901198983192444\n        ],\n        [\n          0.0014971806667745113,\n          -0.7006690502166748,\n          -0.003588718129321933\n        ],\n        [\n          -0.7653300762176514,\n          -0.030549153685569763,\n          0.5779297947883606\n        ],\n        [\n          0.1444747895002365,\n          0.30648332834243774,\n          -0.2944350242614746\n        ],\n        [\n          0.00627485616132617,\n          0.05844533443450928,\n          -1.0504485368728638\n        ],\n        [\n          0.16790169477462769,\n          0.6803913116455078,\n          -0.0802350640296936\n        ],\n        [\n          -0.7650246620178223,\n          0.2571314871311188,\n          0.044474273920059204\n        ],\n        [\n          0.00177879654802382,\n          0.32478848099708557,\n          0.024663111194968224\n        ],\n        [\n          -1.1130585670471191,\n          0.06198093295097351,\n          -0.1499929279088974\n        ],\n        [\n          0.09419120848178864,\n          -0.28672322630882263,\n          
0.2861841320991516\n        ],\n        [\n          -0.08110660314559937,\n          -0.06315471976995468,\n          0.8641197085380554\n        ],\n        [\n          -0.4702282249927521,\n          -0.2976788580417633,\n          -0.08966172486543655\n        ],\n        [\n          0.2188275307416916,\n          -0.010813144035637379,\n          -0.0024994502309709787\n        ],\n        [\n          0.12644176185131073,\n          -0.4933742582798004,\n          -0.23269610106945038\n        ],\n        [\n          -0.05216464772820473,\n          -0.03182952478528023,\n          0.026469329372048378\n        ],\n        [\n          -0.21055173873901367,\n          -0.5854666233062744,\n          -0.08316371589899063\n        ],\n        [\n          0.2703852653503418,\n          -0.0070351893082261086,\n          0.00034556735772639513\n        ],\n        [\n          -0.20080512762069702,\n          -0.5529999136924744,\n          0.08794122189283371\n        ],\n        [\n          -0.020619722083210945,\n          0.01961597241461277,\n          -0.02498687617480755\n        ]\n      ]\n    ],\n    \"root_positions\": [\n      [\n        0.006224155426025391,\n        1.0099574327468872,\n        0.0004121592501178384\n      ],\n      [\n        0.025673866271972656,\n        1.0039517879486084,\n        0.0002174415858462453\n      ]\n    ],\n    \"smooth_root_2d\": [\n      [\n        0.006224155426025391,\n        0.0004121592501178384\n      ],\n      [\n        0.025673866271972656,\n        0.0002174415858462453\n      ]\n    ]\n  },\n  {\n    \"type\": \"right-hand\",\n    \"frame_indices\": [\n      28,\n      94\n    ],\n    \"local_joints_rot\": [\n      [\n        [\n          0.14788010716438293,\n          -0.010833931155502796,\n          -0.01388303842395544\n        ],\n        [\n          -0.03901153802871704,\n          0.0003969503741245717,\n          -0.00016447225061710924\n        ],\n        [\n          
-0.09507032483816147,\n          0.008639314211905003,\n          -0.0073561337776482105\n        ],\n        [\n          0.21237806975841522,\n          -0.02139095962047577,\n          -0.01700877584517002\n        ],\n        [\n          -0.20991119742393494,\n          0.06551700085401535,\n          -0.05272415280342102\n        ],\n        [\n          -0.06337061524391174,\n          0.05204080045223236,\n          0.014292852953076363\n        ],\n        [\n          0.07047945261001587,\n          0.08330309391021729,\n          -0.002013514516875148\n        ],\n        [\n          -0.0019600456580519676,\n          -0.0013381227618083358,\n          -2.7628393581835553e-06\n        ],\n        [\n          -0.18709787726402283,\n          -0.06659803539514542,\n          0.0078862514346838\n        ],\n        [\n          -0.18698126077651978,\n          0.06395528465509415,\n          -0.008215037174522877\n        ],\n        [\n          0.08230585604906082,\n          -0.38376951217651367,\n          0.05542140454053879\n        ],\n        [\n          -0.7260366082191467,\n          -0.24878422915935516,\n          -0.35609468817710876\n        ],\n        [\n          0.004249485209584236,\n          -0.4476320147514343,\n          -0.018469776958227158\n        ],\n        [\n          -0.9212101697921753,\n          -0.1470143049955368,\n          0.5044775605201721\n        ],\n        [\n          0.14870156347751617,\n          0.2985619604587555,\n          -0.29298385977745056\n        ],\n        [\n          0.001955621177330613,\n          0.055549487471580505,\n          -1.0630463361740112\n        ],\n        [\n          0.11859050393104553,\n          0.46535199880599976,\n          -0.030845582485198975\n        ],\n        [\n          -0.7298654317855835,\n          0.5346517562866211,\n          0.2791443467140198\n        ],\n        [\n          0.008972911164164543,\n          0.48752307891845703,\n          
0.01847967691719532\n        ],\n        [\n          -0.5805565118789673,\n          -0.08708631247282028,\n          -0.15088550746440887\n        ],\n        [\n          0.08582834899425507,\n          -0.2886488735675812,\n          0.2854447066783905\n        ],\n        [\n          -0.0898093581199646,\n          -0.05874425172805786,\n          0.8657776117324829\n        ],\n        [\n          -0.3135877549648285,\n          0.07464626431465149,\n          0.0517989918589592\n        ],\n        [\n          0.29447537660598755,\n          -0.003720453940331936,\n          0.0011728419922292233\n        ],\n        [\n          -0.12890003621578217,\n          0.0839272066950798,\n          -0.090343177318573\n        ],\n        [\n          0.008360159583389759,\n          -0.03457032889127731,\n          0.02827553078532219\n        ],\n        [\n          -0.3120643198490143,\n          -0.01133657619357109,\n          -0.03218594938516617\n        ],\n        [\n          0.2538771331310272,\n          0.0018040596041828394,\n          0.0009352069464512169\n        ],\n        [\n          -0.0887608677148819,\n          -0.03465384244918823,\n          0.07154331356287003\n        ],\n        [\n          0.01681467890739441,\n          0.01778421923518181,\n          -0.025033073499798775\n        ]\n      ],\n      [\n        [\n          0.21243979036808014,\n          1.0922467708587646,\n          -0.05739659443497658\n        ],\n        [\n          -0.04288899898529053,\n          0.019888481125235558,\n          -0.014078406617045403\n        ],\n        [\n          -0.09594971686601639,\n          0.10335114598274231,\n          -0.007776615675538778\n        ],\n        [\n          0.2422163188457489,\n          0.08445896953344345,\n          -0.05605608597397804\n        ],\n        [\n          -0.14986605942249298,\n          0.10279522091150284,\n          -0.19410337507724762\n        ],\n        [\n          
-0.07278254628181458,\n          0.00021229058620519936,\n          -0.0064666238613426685\n        ],\n        [\n          -0.18101167678833008,\n          -0.047196485102176666,\n          0.09371022135019302\n        ],\n        [\n          -0.0013136633206158876,\n          -0.0020103836432099342,\n          -0.0002618256548885256\n        ],\n        [\n          -0.1867513209581375,\n          -0.0681525468826294,\n          0.0023792991414666176\n        ],\n        [\n          -0.18714284896850586,\n          0.06443598866462708,\n          -0.003183535533025861\n        ],\n        [\n          0.1040755957365036,\n          -0.1164601668715477,\n          -0.08953910320997238\n        ],\n        [\n          -0.7818892598152161,\n          -0.40082883834838867,\n          -0.40901198983192444\n        ],\n        [\n          0.0014971806667745113,\n          -0.7006690502166748,\n          -0.003588718129321933\n        ],\n        [\n          -0.7653300762176514,\n          -0.030549153685569763,\n          0.5779297947883606\n        ],\n        [\n          0.1444747895002365,\n          0.30648332834243774,\n          -0.2944350242614746\n        ],\n        [\n          0.00627485616132617,\n          0.05844533443450928,\n          -1.0504485368728638\n        ],\n        [\n          0.16790169477462769,\n          0.6803913116455078,\n          -0.0802350640296936\n        ],\n        [\n          -0.7650246620178223,\n          0.2571314871311188,\n          0.044474273920059204\n        ],\n        [\n          0.00177879654802382,\n          0.32478848099708557,\n          0.024663111194968224\n        ],\n        [\n          -1.1130585670471191,\n          0.06198093295097351,\n          -0.1499929279088974\n        ],\n        [\n          0.09419120848178864,\n          -0.28672322630882263,\n          0.2861841320991516\n        ],\n        [\n          -0.08110660314559937,\n          -0.06315471976995468,\n          
0.8641197085380554\n        ],\n        [\n          -0.4702282249927521,\n          -0.2976788580417633,\n          -0.08966172486543655\n        ],\n        [\n          0.2188275307416916,\n          -0.010813144035637379,\n          -0.0024994502309709787\n        ],\n        [\n          0.12644176185131073,\n          -0.4933742582798004,\n          -0.23269610106945038\n        ],\n        [\n          -0.05216464772820473,\n          -0.03182952478528023,\n          0.026469329372048378\n        ],\n        [\n          -0.21055173873901367,\n          -0.5854666233062744,\n          -0.08316371589899063\n        ],\n        [\n          0.2703852653503418,\n          -0.0070351893082261086,\n          0.00034556735772639513\n        ],\n        [\n          -0.20080512762069702,\n          -0.5529999136924744,\n          0.08794122189283371\n        ],\n        [\n          -0.020619722083210945,\n          0.01961597241461277,\n          -0.02498687617480755\n        ]\n      ]\n    ],\n    \"root_positions\": [\n      [\n        0.006224155426025391,\n        1.0099574327468872,\n        0.0004121592501178384\n      ],\n      [\n        0.025673866271972656,\n        1.0039517879486084,\n        0.0002174415858462453\n      ]\n    ],\n    \"smooth_root_2d\": [\n      [\n        0.006224155426025391,\n        0.0004121592501178384\n      ],\n      [\n        0.025673866271972656,\n        0.0002174415858462453\n      ]\n    ]\n  }\n]\n"
  },
  {
    "path": "kimodo/assets/demo/examples/kimodo-soma-rp/04_ee_constraint/meta.json",
    "content": "{\n  \"text\": \"A person picks up an object in front of them with two hands and places it to the left side\",\n  \"duration\": 5.033333333333333,\n  \"num_samples\": 1,\n  \"seed\": 48,\n  \"diffusion_steps\": 100,\n  \"cfg\": {\n    \"enabled\": true,\n    \"text_weight\": 2.0,\n    \"constraint_weight\": 2.0\n  }\n}\n"
  },
  {
    "path": "kimodo/assets/demo/examples/kimodo-soma-rp/05_root_path/constraints.json",
    "content": "[\n  {\n    \"type\": \"root2d\",\n    \"frame_indices\": [\n      0,\n      1,\n      2,\n      3,\n      4,\n      5,\n      6,\n      7,\n      8,\n      9,\n      10,\n      11,\n      12,\n      13,\n      14,\n      15,\n      16,\n      17,\n      18,\n      19,\n      20,\n      21,\n      22,\n      23,\n      24,\n      25,\n      26,\n      27,\n      28,\n      29,\n      30,\n      31,\n      32,\n      33,\n      34,\n      35,\n      36,\n      37,\n      38,\n      39,\n      40,\n      41,\n      42,\n      43,\n      44,\n      45,\n      46,\n      47,\n      48,\n      49,\n      50,\n      51,\n      52,\n      53,\n      54,\n      55,\n      56,\n      57,\n      58,\n      59,\n      60,\n      61,\n      62,\n      63,\n      64,\n      65,\n      66,\n      67,\n      68,\n      69,\n      70,\n      71,\n      72,\n      73,\n      74,\n      75,\n      76,\n      77,\n      78,\n      79,\n      80,\n      81,\n      82,\n      83,\n      84,\n      85,\n      86,\n      87,\n      88,\n      89,\n      90,\n      91,\n      92,\n      93,\n      94,\n      95,\n      96,\n      97,\n      98,\n      99,\n      100,\n      101,\n      102,\n      103,\n      104,\n      105,\n      106,\n      107,\n      108,\n      109,\n      110,\n      111,\n      112,\n      113,\n      114,\n      115,\n      116,\n      117,\n      118,\n      119,\n      120,\n      121,\n      122,\n      123,\n      124,\n      125,\n      126,\n      127,\n      128,\n      129,\n      130,\n      131,\n      132,\n      133,\n      134,\n      135,\n      136,\n      137,\n      138,\n      139,\n      140,\n      141,\n      142,\n      143,\n      144,\n      145,\n      146,\n      147,\n      148,\n      149,\n      150,\n      151,\n      152,\n      153,\n      154,\n      155,\n      156,\n      157,\n      158,\n      159,\n      160,\n      161,\n      162,\n      163,\n      164,\n      165,\n      166,\n      167,\n      168,\n     
 169,\n      170,\n      171,\n      172,\n      173,\n      174,\n      175,\n      176,\n      177,\n      178,\n      179,\n      180,\n      181,\n      182,\n      183,\n      184,\n      185,\n      186,\n      187,\n      188,\n      189,\n      190,\n      191,\n      192,\n      193,\n      194,\n      195,\n      196,\n      197,\n      198,\n      199,\n      200,\n      201,\n      202,\n      203,\n      204,\n      205,\n      206,\n      207,\n      208,\n      209,\n      210,\n      211,\n      212,\n      213,\n      214,\n      215,\n      216,\n      217,\n      218,\n      219,\n      220,\n      221,\n      222,\n      223,\n      224,\n      225,\n      226,\n      227,\n      228,\n      229,\n      230,\n      231,\n      232,\n      233,\n      234,\n      235,\n      236,\n      237,\n      238,\n      239,\n      240,\n      241,\n      242,\n      243,\n      244,\n      245,\n      246,\n      247,\n      248,\n      249,\n      250,\n      251,\n      252,\n      253,\n      254,\n      255,\n      256,\n      257,\n      258,\n      259,\n      260,\n      261,\n      262,\n      263,\n      264,\n      265,\n      266,\n      267,\n      268,\n      269,\n      270,\n      271,\n      272,\n      273,\n      274,\n      275,\n      276,\n      277,\n      278,\n      279,\n      280,\n      281,\n      282,\n      283,\n      284,\n      285,\n      286,\n      287,\n      288,\n      289,\n      290,\n      291,\n      292,\n      293,\n      294,\n      295,\n      296,\n      297,\n      298,\n      299\n    ],\n    \"smooth_root_2d\": [\n      [\n        0.0720488652586937,\n        0.005473949480801821\n      ],\n      [\n        0.08682604879140854,\n        0.03799768537282944\n      ],\n      [\n        0.10160323977470398,\n        0.07052142173051834\n      ],\n      [\n        0.11638043075799942,\n        0.10304517298936844\n      ],\n      [\n        0.13115762174129486,\n        0.13556894659996033\n      ],\n      
[\n        0.1459348350763321,\n        0.1680927276611328\n      ],\n      [\n        0.16071203351020813,\n        0.20061656832695007\n      ],\n      [\n        0.17548926174640656,\n        0.23314043879508972\n      ],\n      [\n        0.19026650488376617,\n        0.26566436886787415\n      ],\n      [\n        0.2050437480211258,\n        0.29818838834762573\n      ],\n      [\n        0.2198210209608078,\n        0.3307124972343445\n      ],\n      [\n        0.234598308801651,\n        0.3632366955280304\n      ],\n      [\n        0.2493756115436554,\n        0.39576101303100586\n      ],\n      [\n        0.2641529440879822,\n        0.42828547954559326\n      ],\n      [\n        0.27893027663230896,\n        0.4608100950717926\n      ],\n      [\n        0.29370763897895813,\n        0.4933348596096039\n      ],\n      [\n        0.3084850013256073,\n        0.5258598327636719\n      ],\n      [\n        0.32326239347457886,\n        0.5583849549293518\n      ],\n      [\n        0.3380397856235504,\n        0.5909103155136108\n      ],\n      [\n        0.352817177772522,\n        0.623435914516449\n      ],\n      [\n        0.36759456992149353,\n        0.6559617519378662\n      ],\n      [\n        0.3823719322681427,\n        0.6884878277778625\n      ],\n      [\n        0.39714932441711426,\n        0.721014142036438\n      ],\n      [\n        0.41192665696144104,\n        0.7535408139228821\n      ],\n      [\n        0.4267039895057678,\n        0.7860677242279053\n      ],\n      [\n        0.4414812922477722,\n        0.8185949325561523\n      ],\n      [\n        0.4562585949897766,\n        0.8511224389076233\n      ],\n      [\n        0.47103583812713623,\n        0.8836503028869629\n      ],\n      [\n        0.48581308126449585,\n        0.9161785244941711\n      ],\n      [\n        0.5005902647972107,\n        0.948707103729248\n      ],\n      [\n        0.5153675079345703,\n        0.9812359809875488\n      ],\n      [\n        
0.5301446914672852,\n        1.0137652158737183\n      ],\n      [\n        0.5449219346046448,\n        1.046294927597046\n      ],\n      [\n        0.5596991777420044,\n        1.0788248777389526\n      ],\n      [\n        0.5744765400886536,\n        1.1113553047180176\n      ],\n      [\n        0.5892539024353027,\n        1.1438862085342407\n      ],\n      [\n        0.6040313243865967,\n        1.1764174699783325\n      ],\n      [\n        0.6188088655471802,\n        1.208949089050293\n      ],\n      [\n        0.6335865259170532,\n        1.2414813041687012\n      ],\n      [\n        0.648364245891571,\n        1.274013876914978\n      ],\n      [\n        0.6631421446800232,\n        1.3065470457077026\n      ],\n      [\n        0.6779201030731201,\n        1.3390806913375854\n      ],\n      [\n        0.6926981806755066,\n        1.371614933013916\n      ],\n      [\n        0.7074640989303589,\n        1.4041519165039062\n      ],\n      [\n        0.7221670746803284,\n        1.4367012977600098\n      ],\n      [\n        0.7367299199104309,\n        1.4692773818969727\n      ],\n      [\n        0.7510751485824585,\n        1.5018945932388306\n      ],\n      [\n        0.7651242613792419,\n        1.5345673561096191\n      ],\n      [\n        0.7787973880767822,\n        1.5673108100891113\n      ],\n      [\n        0.7920125126838684,\n        1.6001399755477905\n      ],\n      [\n        0.8046852350234985,\n        1.6330705881118774\n      ],\n      [\n        0.8167278170585632,\n        1.66611909866333\n      ],\n      [\n        0.8280492424964905,\n        1.6993021965026855\n      ],\n      [\n        0.8385547399520874,\n        1.7326377630233765\n      ],\n      [\n        0.8481456637382507,\n        1.766144037246704\n      ],\n      [\n        0.856719434261322,\n        1.7998400926589966\n      ],\n      [\n        0.8641700744628906,\n        1.8337457180023193\n      ],\n      [\n        0.8703880906105042,\n        
1.8678812980651855\n      ],\n      [\n        0.875261127948761,\n        1.9022676944732666\n      ],\n      [\n        0.8786745071411133,\n        1.9369266033172607\n      ],\n      [\n        0.8805115222930908,\n        1.971879482269287\n      ],\n      [\n        0.8806543946266174,\n        2.0071487426757812\n      ],\n      [\n        0.8789843320846558,\n        2.0427565574645996\n      ],\n      [\n        0.8753821849822998,\n        2.0787250995635986\n      ],\n      [\n        0.869838297367096,\n        2.1150567531585693\n      ],\n      [\n        0.8624524474143982,\n        2.1517333984375\n      ],\n      [\n        0.8533244729042053,\n        2.1887366771698\n      ],\n      [\n        0.8425538539886475,\n        2.226048469543457\n      ],\n      [\n        0.8302397131919861,\n        2.263650894165039\n      ],\n      [\n        0.816480278968811,\n        2.301525831222534\n      ],\n      [\n        0.8013728260993958,\n        2.3396553993225098\n      ],\n      [\n        0.7850133180618286,\n        2.3780221939086914\n      ],\n      [\n        0.7674961686134338,\n        2.4166083335876465\n      ],\n      [\n        0.7489144802093506,\n        2.4553961753845215\n      ],\n      [\n        0.7293595671653748,\n        2.494368553161621\n      ],\n      [\n        0.7089214324951172,\n        2.533508062362671\n      ],\n      [\n        0.6876888871192932,\n        2.5727970600128174\n      ],\n      [\n        0.665749728679657,\n        2.6122183799743652\n      ],\n      [\n        0.6431912779808044,\n        2.651754379272461\n      ],\n      [\n        0.6200692653656006,\n        2.691394805908203\n      ],\n      [\n        0.5964087247848511,\n        2.731137275695801\n      ],\n      [\n        0.5722349882125854,\n        2.770979166030884\n      ],\n      [\n        0.5475742816925049,\n        2.810917615890503\n      ],\n      [\n        0.5224538445472717,\n        2.8509483337402344\n      ],\n      [\n      
  0.49690231680870056,\n        2.8910679817199707\n      ],\n      [\n        0.47094982862472534,\n        2.93127179145813\n      ],\n      [\n        0.44462811946868896,\n        2.971554756164551\n      ],\n      [\n        0.4179706573486328,\n        3.011911630630493\n      ],\n      [\n        0.3910125195980072,\n        3.0523364543914795\n      ],\n      [\n        0.3637904226779938,\n        3.0928235054016113\n      ],\n      [\n        0.336342453956604,\n        3.133366107940674\n      ],\n      [\n        0.3087080717086792,\n        3.173957586288452\n      ],\n      [\n        0.2809275984764099,\n        3.2145910263061523\n      ],\n      [\n        0.25304216146469116,\n        3.2552595138549805\n      ],\n      [\n        0.2250932902097702,\n        3.2959556579589844\n      ],\n      [\n        0.19712261855602264,\n        3.336672067642212\n      ],\n      [\n        0.16917157173156738,\n        3.3774020671844482\n      ],\n      [\n        0.14128103852272034,\n        3.418138027191162\n      ],\n      [\n        0.11349108070135117,\n        3.4588732719421387\n      ],\n      [\n        0.08584070205688477,\n        3.499600648880005\n      ],\n      [\n        0.05836760997772217,\n        3.540313243865967\n      ],\n      [\n        0.031108075752854347,\n        3.5810046195983887\n      ],\n      [\n        0.004096813499927521,\n        3.6216683387756348\n      ],\n      [\n        -0.022633060812950134,\n        3.6622982025146484\n      ],\n      [\n        -0.049050018191337585,\n        3.702888250350952\n      ],\n      [\n        -0.07512406259775162,\n        3.7434325218200684\n      ],\n      [\n        -0.10082659870386124,\n        3.7839250564575195\n      ],\n      [\n        -0.12613031268119812,\n        3.8243606090545654\n      ],\n      [\n        -0.1510089486837387,\n        3.8647332191467285\n      ],\n      [\n        -0.17543718218803406,\n        3.9050378799438477\n      ],\n      [\n        
-0.19939035177230835,\n        3.9452688694000244\n      ],\n      [\n        -0.22284428775310516,\n        3.9854207038879395\n      ],\n      [\n        -0.24577516317367554,\n        4.025487899780273\n      ],\n      [\n        -0.26815930008888245,\n        4.065464496612549\n      ],\n      [\n        -0.28985288739204407,\n        4.1053338050842285\n      ],\n      [\n        -0.3105919361114502,\n        4.145066261291504\n      ],\n      [\n        -0.33011239767074585,\n        4.184632301330566\n      ],\n      [\n        -0.34815022349357605,\n        4.224003314971924\n      ],\n      [\n        -0.3644413650035858,\n        4.263148784637451\n      ],\n      [\n        -0.3787217438220978,\n        4.302039623260498\n      ],\n      [\n        -0.3907274007797241,\n        4.340645790100098\n      ],\n      [\n        -0.4001944959163666,\n        4.378937721252441\n      ],\n      [\n        -0.40685927867889404,\n        4.416884899139404\n      ],\n      [\n        -0.41045811772346497,\n        4.4544572830200195\n      ],\n      [\n        -0.41072750091552734,\n        4.491624355316162\n      ],\n      [\n        -0.40740400552749634,\n        4.528356552124023\n      ],\n      [\n        -0.4004855453968048,\n        4.564655303955078\n      ],\n      [\n        -0.3902314007282257,\n        4.600553512573242\n      ],\n      [\n        -0.37690070271492004,\n        4.636085033416748\n      ],\n      [\n        -0.3607523441314697,\n        4.67128324508667\n      ],\n      [\n        -0.3420449197292328,\n        4.706181049346924\n      ],\n      [\n        -0.32103657722473145,\n        4.740812301635742\n      ],\n      [\n        -0.2979850471019745,\n        4.775211334228516\n      ],\n      [\n        -0.2731475234031677,\n        4.809412002563477\n      ],\n      [\n        -0.24678070843219757,\n        4.843447685241699\n      ],\n      [\n        -0.21914079785346985,\n        4.877353668212891\n      ],\n      [\n        
-0.19048355519771576,\n        4.911164283752441\n      ],\n      [\n        -0.16106447577476501,\n        4.944913864135742\n      ],\n      [\n        -0.13102509081363678,\n        4.978619575500488\n      ],\n      [\n        -0.10039319843053818,\n        5.0122785568237305\n      ],\n      [\n        -0.06919693201780319,\n        5.0458903312683105\n      ],\n      [\n        -0.03746507689356804,\n        5.079452991485596\n      ],\n      [\n        -0.005227350629866123,\n        5.1129655838012695\n      ],\n      [\n        0.027485284954309464,\n        5.146428108215332\n      ],\n      [\n        0.06064034625887871,\n        5.179840087890625\n      ],\n      [\n        0.09420355409383774,\n        5.213201522827148\n      ],\n      [\n        0.12813864648342133,\n        5.246513843536377\n      ],\n      [\n        0.16240715980529785,\n        5.279778003692627\n      ],\n      [\n        0.19696833193302155,\n        5.312995910644531\n      ],\n      [\n        0.2317790538072586,\n        5.3461689949035645\n      ],\n      [\n        0.266793817281723,\n        5.379299640655518\n      ],\n      [\n        0.30196475982666016,\n        5.412391662597656\n      ],\n      [\n        0.3372417688369751,\n        5.4454474449157715\n      ],\n      [\n        0.37257257103919983,\n        5.478470325469971\n      ],\n      [\n        0.40790289640426636,\n        5.511464595794678\n      ],\n      [\n        0.4431767165660858,\n        5.544434547424316\n      ],\n      [\n        0.478336364030838,\n        5.577383518218994\n      ],\n      [\n        0.5133227705955505,\n        5.610316753387451\n      ],\n      [\n        0.5480756759643555,\n        5.643238544464111\n      ],\n      [\n        0.5825338363647461,\n        5.676154136657715\n      ],\n      [\n        0.6166353225708008,\n        5.709067344665527\n      ],\n      [\n        0.6503174901008606,\n        5.741983413696289\n      ],\n      [\n        0.6835171580314636,\n 
       5.774907112121582\n      ],\n      [\n        0.7161709666252136,\n        5.8078436851501465\n      ],\n      [\n        0.7482153177261353,\n        5.840796947479248\n      ],\n      [\n        0.7795863747596741,\n        5.873773097991943\n      ],\n      [\n        0.8102203011512756,\n        5.906775951385498\n      ],\n      [\n        0.8400532603263855,\n        5.939810276031494\n      ],\n      [\n        0.8690049648284912,\n        5.9728803634643555\n      ],\n      [\n        0.8969439268112183,\n        6.005988121032715\n      ],\n      [\n        0.9237036108970642,\n        6.039134979248047\n      ],\n      [\n        0.9491175413131714,\n        6.072321891784668\n      ],\n      [\n        0.9730191230773926,\n        6.105550289154053\n      ],\n      [\n        0.9952419996261597,\n        6.138820171356201\n      ],\n      [\n        1.0156195163726807,\n        6.172133445739746\n      ],\n      [\n        1.0339852571487427,\n        6.205490589141846\n      ],\n      [\n        1.0501729249954224,\n        6.238892555236816\n      ],\n      [\n        1.0640157461166382,\n        6.272340774536133\n      ],\n      [\n        1.075347661972046,\n        6.305835723876953\n      ],\n      [\n        1.084001898765564,\n        6.339378356933594\n      ],\n      [\n        1.0898123979568481,\n        6.372969627380371\n      ],\n      [\n        1.0927863121032715,\n        6.406609058380127\n      ],\n      [\n        1.093105435371399,\n        6.440292835235596\n      ],\n      [\n        1.090950846672058,\n        6.474018096923828\n      ],\n      [\n        1.0865041017532349,\n        6.507782459259033\n      ],\n      [\n        1.079946517944336,\n        6.541581630706787\n      ],\n      [\n        1.0714592933654785,\n        6.575413227081299\n      ],\n      [\n        1.0612238645553589,\n        6.609274387359619\n      ],\n      [\n        1.0494211912155151,\n        6.643161773681641\n      ],\n      [\n        
1.036232590675354,\n        6.677072525024414\n      ],\n      [\n        1.0218391418457031,\n        6.71100378036499\n      ],\n      [\n        1.006421685218811,\n        6.7449517250061035\n      ],\n      [\n        0.9901613593101501,\n        6.778914451599121\n      ],\n      [\n        0.9732388854026794,\n        6.812887668609619\n      ],\n      [\n        0.9558353424072266,\n        6.846869468688965\n      ],\n      [\n        0.9380521178245544,\n        6.880856990814209\n      ],\n      [\n        0.9199115633964539,\n        6.91485071182251\n      ],\n      [\n        0.9014359712600708,\n        6.948850154876709\n      ],\n      [\n        0.8826476335525513,\n        6.98285436630249\n      ],\n      [\n        0.8635689616203308,\n        7.016862869262695\n      ],\n      [\n        0.8442226052284241,\n        7.050876140594482\n      ],\n      [\n        0.8246312141418457,\n        7.084892749786377\n      ],\n      [\n        0.8048177361488342,\n        7.118912696838379\n      ],\n      [\n        0.7848052978515625,\n        7.15293550491333\n      ],\n      [\n        0.7646171450614929,\n        7.186960697174072\n      ],\n      [\n        0.7442769408226013,\n        7.220987796783447\n      ],\n      [\n        0.7238084673881531,\n        7.255016326904297\n      ],\n      [\n        0.703235924243927,\n        7.289045810699463\n      ],\n      [\n        0.682583749294281,\n        7.323075771331787\n      ],\n      [\n        0.6618766784667969,\n        7.357105731964111\n      ],\n      [\n        0.6411397457122803,\n        7.391135215759277\n      ],\n      [\n        0.6203982830047607,\n        7.425163269042969\n      ],\n      [\n        0.5996780395507812,\n        7.4591898918151855\n      ],\n      [\n        0.5790049433708191,\n        7.4932146072387695\n      ],\n      [\n        0.5584054589271545,\n        7.5272369384765625\n      ],\n      [\n        0.5379061102867126,\n        7.56125545501709\n      
],\n      [\n        0.5175339579582214,\n        7.595271110534668\n      ],\n      [\n        0.4973162114620209,\n        7.629281997680664\n      ],\n      [\n        0.4772806167602539,\n        7.663288116455078\n      ],\n      [\n        0.457455039024353,\n        7.697288990020752\n      ],\n      [\n        0.43786779046058655,\n        7.731284141540527\n      ],\n      [\n        0.41854748129844666,\n        7.765272617340088\n      ],\n      [\n        0.3995230197906494,\n        7.799253940582275\n      ],\n      [\n        0.38082367181777954,\n        7.833227634429932\n      ],\n      [\n        0.3624790608882904,\n        7.867193222045898\n      ],\n      [\n        0.34451907873153687,\n        7.901149749755859\n      ],\n      [\n        0.32697397470474243,\n        7.935096263885498\n      ],\n      [\n        0.3098742961883545,\n        7.969033241271973\n      ],\n      [\n        0.2932509779930115,\n        8.002959251403809\n      ],\n      [\n        0.2771351933479309,\n        8.036873817443848\n      ],\n      [\n        0.2615584135055542,\n        8.070775985717773\n      ],\n      [\n        0.24655242264270782,\n        8.10466480255127\n      ],\n      [\n        0.23214924335479736,\n        8.138541221618652\n      ],\n      [\n        0.21838118135929108,\n        8.172403335571289\n      ],\n      [\n        0.20528072118759155,\n        8.206250190734863\n      ],\n      [\n        0.19288058578968048,\n        8.240081787109375\n      ],\n      [\n        0.18121366202831268,\n        8.273897171020508\n      ],\n      [\n        0.17031297087669373,\n        8.307695388793945\n      ],\n      [\n        0.1602116823196411,\n        8.341476440429688\n      ],\n      [\n        0.15094305574893951,\n        8.375238418579102\n      ],\n      [\n        0.14254039525985718,\n        8.408982276916504\n      ],\n      [\n        0.13503706455230713,\n        8.442705154418945\n      ],\n      [\n        
0.12846647202968597,\n        8.476408958435059\n      ],\n      [\n        0.12282804399728775,\n        8.510091781616211\n      ],\n      [\n        0.11808725446462631,\n        8.543754577636719\n      ],\n      [\n        0.11420957744121552,\n        8.577399253845215\n      ],\n      [\n        0.11116043478250504,\n        8.6110258102417\n      ],\n      [\n        0.10890527069568634,\n        8.644634246826172\n      ],\n      [\n        0.10740949213504791,\n        8.678226470947266\n      ],\n      [\n        0.10663850605487823,\n        8.711803436279297\n      ],\n      [\n        0.1065577045083046,\n        8.74536418914795\n      ],\n      [\n        0.10713250190019608,\n        8.778911590576172\n      ],\n      [\n        0.10832829773426056,\n        8.812445640563965\n      ],\n      [\n        0.11011053621768951,\n        8.845966339111328\n      ],\n      [\n        0.112444669008255,\n        8.879474639892578\n      ],\n      [\n        0.11529617011547089,\n        8.912972450256348\n      ],\n      [\n        0.11863056570291519,\n        8.946459770202637\n      ],\n      [\n        0.12241341173648834,\n        8.979937553405762\n      ],\n      [\n        0.12661030888557434,\n        9.013405799865723\n      ],\n      [\n        0.1311868578195572,\n        9.046866416931152\n      ],\n      [\n        0.13610877096652985,\n        9.080318450927734\n      ],\n      [\n        0.14134173095226288,\n        9.113764762878418\n      ],\n      [\n        0.14685149490833282,\n        9.147205352783203\n      ],\n      [\n        0.15260380506515503,\n        9.18064022064209\n      ],\n      [\n        0.158564493060112,\n        9.214071273803711\n      ],\n      [\n        0.16469934582710266,\n        9.24749755859375\n      ],\n      [\n        0.17097420990467072,\n        9.280921936035156\n      ],\n      [\n        0.17735493183135986,\n        9.314343452453613\n      ],\n      [\n        0.1838073432445526,\n        
9.347764015197754\n      ],\n      [\n        0.19029729068279266,\n        9.381183624267578\n      ],\n      [\n        0.19679751992225647,\n        9.414603233337402\n      ],\n      [\n        0.20329780876636505,\n        9.448022842407227\n      ],\n      [\n        0.2097981721162796,\n        9.481443405151367\n      ],\n      [\n        0.21629860997200012,\n        9.514863014221191\n      ],\n      [\n        0.22279909253120422,\n        9.548283576965332\n      ],\n      [\n        0.2292996346950531,\n        9.581703186035156\n      ],\n      [\n        0.23580022156238556,\n        9.615123748779297\n      ],\n      [\n        0.2423008531332016,\n        9.648544311523438\n      ],\n      [\n        0.24880154430866241,\n        9.681964874267578\n      ],\n      [\n        0.2553022503852844,\n        9.715385437011719\n      ],\n      [\n        0.2618030309677124,\n        9.74880599975586\n      ],\n      [\n        0.2683038115501404,\n        9.7822265625\n      ],\n      [\n        0.27480462193489075,\n        9.815648078918457\n      ],\n      [\n        0.2813054919242859,\n        9.849068641662598\n      ],\n      [\n        0.28780636191368103,\n        9.882490158081055\n      ],\n      [\n        0.29430726170539856,\n        9.915910720825195\n      ],\n      [\n        0.3008081614971161,\n        9.949331283569336\n      ],\n      [\n        0.307309091091156,\n        9.982752799987793\n      ],\n      [\n        0.3138100206851959,\n        10.01617431640625\n      ],\n      [\n        0.3203109800815582,\n        10.04959487915039\n      ],\n      [\n        0.32681193947792053,\n        10.083016395568848\n      ],\n      [\n        0.33331289887428284,\n        10.116436958312988\n      ],\n      [\n        0.33981388807296753,\n        10.149858474731445\n      ],\n      [\n        0.34631484746932983,\n        10.183279991149902\n      ],\n      [\n        0.3528158366680145,\n        10.216700553894043\n      ],\n      
[\n        0.3593168258666992,\n        10.2501220703125\n      ],\n      [\n        0.3658177852630615,\n        10.283543586730957\n      ],\n      [\n        0.3723187744617462,\n        10.316965103149414\n      ],\n      [\n        0.3804450035095215,\n        10.35874080657959\n      ],\n      [\n        0.3853207528591156,\n        10.383807182312012\n      ]\n    ]\n  }\n]\n"
  },
  {
    "path": "kimodo/assets/demo/examples/kimodo-soma-rp/05_root_path/meta.json",
    "content": "{\n  \"text\": \"A person is casually walking forward slowly\",\n  \"duration\": 10.0,\n  \"num_samples\": 1,\n  \"seed\": 42,\n  \"diffusion_steps\": 100,\n  \"cfg\": {\n    \"enabled\": true,\n    \"text_weight\": 2.0,\n    \"constraint_weight\": 2.0\n  }\n}\n"
  },
  {
    "path": "kimodo/assets/demo/examples/kimodo-soma-rp/06_root_waypoints/constraints.json",
    "content": "[\n  {\n    \"type\": \"root2d\",\n    \"frame_indices\": [\n      0,\n      90,\n      180\n    ],\n    \"smooth_root_2d\": [\n      [\n        0.0,\n        -0.013232914730906487\n      ],\n      [\n        -1.1690130233764648,\n        1.5332785844802856\n      ],\n      [\n        0.738669753074646,\n        1.4469488859176636\n      ]\n    ]\n  }\n]\n"
  },
  {
    "path": "kimodo/assets/demo/examples/kimodo-soma-rp/06_root_waypoints/meta.json",
    "content": "{\n  \"text\": \"A person is doing a hip hop dance while moving around\",\n  \"duration\": 6.033333333333333,\n  \"num_samples\": 1,\n  \"seed\": 42,\n  \"diffusion_steps\": 100,\n  \"cfg\": {\n    \"enabled\": true,\n    \"text_weight\": 2.0,\n    \"constraint_weight\": 2.0\n  }\n}\n"
  },
  {
    "path": "kimodo/assets/demo/examples/kimodo-soma-rp/07_mixed_constraints/constraints.json",
    "content": "[\n  {\n    \"type\": \"fullbody\",\n    \"frame_indices\": [\n      108\n    ],\n    \"local_joints_rot\": [\n      [\n        [\n          -0.035887543112039566,\n          -0.02776639349758625,\n          -0.005372282117605209\n        ],\n        [\n          0.06515975296497345,\n          -0.010784560814499855,\n          0.006556123960763216\n        ],\n        [\n          -0.06292378902435303,\n          -0.05156821012496948,\n          -0.009085050784051418\n        ],\n        [\n          0.11570766568183899,\n          -0.0793282613158226,\n          -0.03867234289646149\n        ],\n        [\n          0.09106606245040894,\n          0.06571822613477707,\n          0.002558206906542182\n        ],\n        [\n          -0.06086159870028496,\n          0.10295507311820984,\n          0.02592187374830246\n        ],\n        [\n          -0.15437740087509155,\n          0.16596992313861847,\n          0.009326435625553131\n        ],\n        [\n          -0.0005251984694041312,\n          0.0018051519291475415,\n          -9.946066711563617e-05\n        ],\n        [\n          -0.184775248169899,\n          -0.064349465072155,\n          0.00573313795030117\n        ],\n        [\n          -0.18454650044441223,\n          0.068090058863163,\n          -0.005659883841872215\n        ],\n        [\n          0.20501427352428436,\n          -0.14578332006931305,\n          -0.04773213341832161\n        ],\n        [\n          0.26504039764404297,\n          -0.16855353116989136,\n          -1.0829373598098755\n        ],\n        [\n          0.006512798834592104,\n          -0.6961542367935181,\n          -0.011537229642271996\n        ],\n        [\n          0.07062757760286331,\n          0.03925099968910217,\n          -0.027518808841705322\n        ],\n        [\n          0.14896969497203827,\n          0.29287680983543396,\n          -0.2919791340827942\n        ],\n        [\n          0.009383739903569221,\n          
0.0508926659822464,\n          -1.056564450263977\n        ],\n        [\n          0.11172245442867279,\n          0.12029653787612915,\n          -0.12930497527122498\n        ],\n        [\n          -0.41130027174949646,\n          -0.5924108028411865,\n          -0.0006285393028520048\n        ],\n        [\n          0.006594705395400524,\n          0.4732210040092468,\n          -0.002528452081605792\n        ],\n        [\n          -0.32021215558052063,\n          -0.25638389587402344,\n          -0.3734903335571289\n        ],\n        [\n          0.09024477005004883,\n          -0.2926441431045532,\n          0.2660353481769562\n        ],\n        [\n          -0.09575983881950378,\n          -0.055268142372369766,\n          0.8844737410545349\n        ],\n        [\n          -0.0118059441447258,\n          0.07546520978212357,\n          0.0746397078037262\n        ],\n        [\n          0.8310757875442505,\n          -0.012923321686685085,\n          0.004925338551402092\n        ],\n        [\n          0.03474503755569458,\n          -0.23956389725208282,\n          -0.16712959110736847\n        ],\n        [\n          -0.09206951409578323,\n          -0.03187529370188713,\n          0.027407124638557434\n        ],\n        [\n          -0.2677958309650421,\n          0.11606352031230927,\n          0.036957308650016785\n        ],\n        [\n          0.394832044839859,\n          -0.0007178321247920394,\n          0.0004849981633014977\n        ],\n        [\n          -0.09032224863767624,\n          -0.14483025670051575,\n          -0.015989331528544426\n        ],\n        [\n          -0.0217722300440073,\n          0.01900928094983101,\n          -0.025495363399386406\n        ]\n      ]\n    ],\n    \"root_positions\": [\n      [\n        -0.09470777958631516,\n        0.9947724342346191,\n        -3.980208396911621\n      ]\n    ],\n    \"smooth_root_2d\": [\n      [\n        -0.09470777958631516,\n        -3.980208396911621\n      
]\n    ]\n  },\n  {\n    \"type\": \"root2d\",\n    \"frame_indices\": [\n      0,\n      1,\n      2,\n      3,\n      4,\n      5,\n      6,\n      7,\n      8,\n      9,\n      10,\n      11,\n      12,\n      13,\n      14,\n      15,\n      16,\n      17,\n      18,\n      19,\n      20,\n      21,\n      22,\n      23,\n      24,\n      25,\n      26,\n      27,\n      28,\n      29,\n      30,\n      31,\n      32,\n      33,\n      34,\n      35,\n      36,\n      37,\n      38,\n      39,\n      40,\n      41,\n      42,\n      43,\n      44,\n      45,\n      46,\n      47,\n      48,\n      49,\n      50,\n      51,\n      52,\n      53,\n      54,\n      55,\n      56,\n      57,\n      58,\n      59,\n      60,\n      61,\n      62,\n      63,\n      64,\n      65,\n      66,\n      67,\n      68,\n      69,\n      70,\n      71,\n      72,\n      73,\n      74,\n      75,\n      76,\n      77,\n      78,\n      79,\n      80,\n      81,\n      82,\n      83,\n      84,\n      85,\n      86,\n      87,\n      88,\n      89,\n      90,\n      91,\n      92,\n      93,\n      94,\n      95,\n      96,\n      97,\n      98,\n      99,\n      100,\n      101,\n      102,\n      103,\n      104,\n      105,\n      106,\n      107,\n      108,\n      109,\n      110,\n      111,\n      112,\n      113,\n      114,\n      115,\n      116,\n      117,\n      118,\n      119,\n      120,\n      121,\n      122,\n      123,\n      124,\n      125,\n      126,\n      127,\n      128,\n      129,\n      130,\n      131,\n      132,\n      133,\n      134,\n      135,\n      136,\n      137,\n      138,\n      139,\n      140,\n      141,\n      142,\n      143,\n      144,\n      145,\n      146,\n      147,\n      148,\n      149,\n      150,\n      151\n    ],\n    \"smooth_root_2d\": [\n      [\n        -0.022358937188982964,\n        0.03532936051487923\n      ],\n      [\n        -0.024468135088682175,\n        -0.0013195642968639731\n      ],\n      [\n      
  -0.02657654881477356,\n        -0.037969205528497696\n      ],\n      [\n        -0.028683679178357124,\n        -0.07462010532617569\n      ],\n      [\n        -0.030789025127887726,\n        -0.11127285659313202\n      ],\n      [\n        -0.032892078161239624,\n        -0.14792808890342712\n      ],\n      [\n        -0.03499194607138634,\n        -0.184586301445961\n      ],\n      [\n        -0.03708736225962639,\n        -0.2212478667497635\n      ],\n      [\n        -0.03917701542377472,\n        -0.25791314244270325\n      ],\n      [\n        -0.04125956818461418,\n        -0.2945826053619385\n      ],\n      [\n        -0.04333365708589554,\n        -0.3312567174434662\n      ],\n      [\n        -0.045397885143756866,\n        -0.3679359555244446\n      ],\n      [\n        -0.04745082929730415,\n        -0.4046209156513214\n      ],\n      [\n        -0.04949106276035309,\n        -0.44131216406822205\n      ],\n      [\n        -0.05151714012026787,\n        -0.4780103266239166\n      ],\n      [\n        -0.05352761223912239,\n        -0.5147159695625305\n      ],\n      [\n        -0.05552104488015175,\n        -0.551429808139801\n      ],\n      [\n        -0.05749599635601044,\n        -0.5881525278091431\n      ],\n      [\n        -0.059451062232255936,\n        -0.6248847842216492\n      ],\n      [\n        -0.061384834349155426,\n        -0.6616273522377014\n      ],\n      [\n        -0.06329593807458878,\n        -0.6983808875083923\n      ],\n      [\n        -0.06518300622701645,\n        -0.7351461052894592\n      ],\n      [\n        -0.06704472005367279,\n        -0.7719237804412842\n      ],\n      [\n        -0.06887973845005035,\n        -0.8087146878242493\n      ],\n      [\n        -0.07068677246570587,\n        -0.8455195426940918\n      ],\n      [\n        -0.07246451079845428,\n        -0.8823391795158386\n      ],\n      [\n        -0.07421167194843292,\n        -0.9191742539405823\n      ],\n      [\n        
-0.07592695951461792,\n        -0.9560256004333496\n      ],\n      [\n        -0.07760907709598541,\n        -0.9928940534591675\n      ],\n      [\n        -0.07925672084093094,\n        -1.029780387878418\n      ],\n      [\n        -0.08086856454610825,\n        -1.0666853189468384\n      ],\n      [\n        -0.0824432522058487,\n        -1.1036096811294556\n      ],\n      [\n        -0.08397942036390305,\n        -1.1405543088912964\n      ],\n      [\n        -0.08547566086053848,\n        -1.1775201559066772\n      ],\n      [\n        -0.08693055063486099,\n        -1.2145079374313354\n      ],\n      [\n        -0.08834262937307358,\n        -1.2515183687210083\n      ],\n      [\n        -0.08971039950847626,\n        -1.2885526418685913\n      ],\n      [\n        -0.09103234112262726,\n        -1.3256113529205322\n      ],\n      [\n        -0.09230689704418182,\n        -1.362695574760437\n      ],\n      [\n        -0.0935325101017952,\n        -1.3998061418533325\n      ],\n      [\n        -0.09470757842063904,\n        -1.4369438886642456\n      ],\n      [\n        -0.09583047777414322,\n        -1.4741098880767822\n      ],\n      [\n        -0.0968996062874794,\n        -1.5113049745559692\n      ],\n      [\n        -0.09791331738233566,\n        -1.548530101776123\n      ],\n      [\n        -0.09886999428272247,\n        -1.58578622341156\n      ],\n      [\n        -0.0997680053114891,\n        -1.6230742931365967\n      ],\n      [\n        -0.10060573369264603,\n        -1.6603953838348389\n      ],\n      [\n        -0.1013815775513649,\n        -1.6977503299713135\n      ],\n      [\n        -0.10209395736455917,\n        -1.7351402044296265\n      ],\n      [\n        -0.1027413085103035,\n        -1.7725658416748047\n      ],\n      [\n        -0.10332208126783371,\n        -1.8100284337997437\n      ],\n      [\n        -0.10383477061986923,\n        -1.8475286960601807\n      ],\n      [\n        -0.10427788645029068,\n        
-1.8850678205490112\n      ],\n      [\n        -0.10464996099472046,\n        -1.9226467609405518\n      ],\n      [\n        -0.10494954138994217,\n        -1.9602664709091187\n      ],\n      [\n        -0.10517755895853043,\n        -1.997925877571106\n      ],\n      [\n        -0.10533731430768967,\n        -2.0356218814849854\n      ],\n      [\n        -0.10543208569288254,\n        -2.0733516216278076\n      ],\n      [\n        -0.10546516627073288,\n        -2.111111879348755\n      ],\n      [\n        -0.10543984919786453,\n        -2.148899555206299\n      ],\n      [\n        -0.10535937547683716,\n        -2.1867120265960693\n      ],\n      [\n        -0.10522699356079102,\n        -2.224546194076538\n      ],\n      [\n        -0.10504589974880219,\n        -2.262399435043335\n      ],\n      [\n        -0.10481927543878555,\n        -2.3002686500549316\n      ],\n      [\n        -0.10455025732517242,\n        -2.338151216506958\n      ],\n      [\n        -0.10424194484949112,\n        -2.376044511795044\n      ],\n      [\n        -0.10389743000268936,\n        -2.4139459133148193\n      ],\n      [\n        -0.10351976752281189,\n        -2.451852560043335\n      ],\n      [\n        -0.10311200469732285,\n        -2.4897620677948\n      ],\n      [\n        -0.10267717391252518,\n        -2.5276718139648438\n      ],\n      [\n        -0.10221832990646362,\n        -2.5655791759490967\n      ],\n      [\n        -0.10173854231834412,\n        -2.6034812927246094\n      ],\n      [\n        -0.10124091058969498,\n        -2.64137601852417\n      ],\n      [\n        -0.10072856396436691,\n        -2.67926025390625\n      ],\n      [\n        -0.100204698741436,\n        -2.7171311378479004\n      ],\n      [\n        -0.09967257082462311,\n        -2.754986047744751\n      ],\n      [\n        -0.09913549572229385,\n        -2.7928221225738525\n      ],\n      [\n        -0.09859687089920044,\n        -2.8306362628936768\n      ],\n      [\n   
     -0.09806016832590103,\n        -2.8684253692626953\n      ],\n      [\n        -0.09752892702817917,\n        -2.906186103820801\n      ],\n      [\n        -0.09700676798820496,\n        -2.943915367126465\n      ],\n      [\n        -0.09649737179279327,\n        -2.98160982131958\n      ],\n      [\n        -0.09600447863340378,\n        -3.019265651702881\n      ],\n      [\n        -0.09553186595439911,\n        -3.056879758834839\n      ],\n      [\n        -0.09508336335420609,\n        -3.0944483280181885\n      ],\n      [\n        -0.09466280788183212,\n        -3.131967782974243\n      ],\n      [\n        -0.09427405893802643,\n        -3.1694345474243164\n      ],\n      [\n        -0.09392096847295761,\n        -3.2068448066711426\n      ],\n      [\n        -0.09360739588737488,\n        -3.244194984436035\n      ],\n      [\n        -0.09333716332912445,\n        -3.2814812660217285\n      ],\n      [\n        -0.09311125427484512,\n        -3.3187034130096436\n      ],\n      [\n        -0.09292776882648468,\n        -3.355863571166992\n      ],\n      [\n        -0.0927848145365715,\n        -3.3929643630981445\n      ],\n      [\n        -0.09268050640821457,\n        -3.4300084114074707\n      ],\n      [\n        -0.09261301904916763,\n        -3.4669981002807617\n      ],\n      [\n        -0.09258053451776505,\n        -3.5039358139038086\n      ],\n      [\n        -0.09258133918046951,\n        -3.5408236980438232\n      ],\n      [\n        -0.09261377900838852,\n        -3.5776638984680176\n      ],\n      [\n        -0.0926763191819191,\n        -3.6144583225250244\n      ],\n      [\n        -0.09276753664016724,\n        -3.6512088775634766\n      ],\n      [\n        -0.09288612008094788,\n        -3.687917470932007\n      ],\n      [\n        -0.0930309146642685,\n        -3.7245850563049316\n      ],\n      [\n        -0.09320087730884552,\n        -3.7612133026123047\n      ],\n      [\n        -0.09339512139558792,\n        
-3.7978031635284424\n      ],\n      [\n        -0.09361287951469421,\n        -3.8343558311462402\n      ],\n      [\n        -0.09385351091623306,\n        -3.8708720207214355\n      ],\n      [\n        -0.09411647915840149,\n        -3.9073524475097656\n      ],\n      [\n        -0.09440135210752487,\n        -3.9437978267669678\n      ],\n      [\n        -0.09470777958631516,\n        -3.980208396911621\n      ],\n      [\n        -0.09503547102212906,\n        -4.016584873199463\n      ],\n      [\n        -0.09538418799638748,\n        -4.052927494049072\n      ],\n      [\n        -0.09575372189283371,\n        -4.089236736297607\n      ],\n      [\n        -0.09614387899637222,\n        -4.125512599945068\n      ],\n      [\n        -0.0965544655919075,\n        -4.1617560386657715\n      ],\n      [\n        -0.09698529541492462,\n        -4.197966575622559\n      ],\n      [\n        -0.09743614494800568,\n        -4.234145641326904\n      ],\n      [\n        -0.09790677577257156,\n        -4.27029275894165\n      ],\n      [\n        -0.09839694201946259,\n        -4.306408882141113\n      ],\n      [\n        -0.09890634566545486,\n        -4.342494487762451\n      ],\n      [\n        -0.09943470358848572,\n        -4.378549575805664\n      ],\n      [\n        -0.09998169541358948,\n        -4.41457462310791\n      ],\n      [\n        -0.10054702311754227,\n        -4.450570583343506\n      ],\n      [\n        -0.10113038867712021,\n        -4.486537456512451\n      ],\n      [\n        -0.10173150897026062,\n        -4.522475242614746\n      ],\n      [\n        -0.1023501306772232,\n        -4.558384895324707\n      ],\n      [\n        -0.10298605263233185,\n        -4.594265937805176\n      ],\n      [\n        -0.10363911837339401,\n        -4.6301188468933105\n      ],\n      [\n        -0.10430921614170074,\n        -4.665942668914795\n      ],\n      [\n        -0.10499630123376846,\n        -4.701738357543945\n      ],\n      [\n        
-0.10570038110017776,\n        -4.737504482269287\n      ],\n      [\n        -0.10642150044441223,\n        -4.7732415199279785\n      ],\n      [\n        -0.10715975612401962,\n        -4.808948040008545\n      ],\n      [\n        -0.10791526734828949,\n        -4.844624042510986\n      ],\n      [\n        -0.10868816822767258,\n        -4.880269527435303\n      ],\n      [\n        -0.10947857797145844,\n        -4.915882587432861\n      ],\n      [\n        -0.11028657108545303,\n        -4.9514641761779785\n      ],\n      [\n        -0.11111218482255936,\n        -4.98701286315918\n      ],\n      [\n        -0.11195536702871323,\n        -5.022529602050781\n      ],\n      [\n        -0.1128159612417221,\n        -5.058013439178467\n      ],\n      [\n        -0.11369368433952332,\n        -5.093465328216553\n      ],\n      [\n        -0.11458808928728104,\n        -5.128885746002197\n      ],\n      [\n        -0.11549859493970871,\n        -5.164275646209717\n      ],\n      [\n        -0.11642441153526306,\n        -5.199635982513428\n      ],\n      [\n        -0.11736457794904709,\n        -5.234969615936279\n      ],\n      [\n        -0.11831795424222946,\n        -5.270277500152588\n      ],\n      [\n        -0.11928320676088333,\n        -5.305562496185303\n      ],\n      [\n        -0.12025882303714752,\n        -5.340827941894531\n      ],\n      [\n        -0.12124315649271011,\n        -5.3760762214660645\n      ],\n      [\n        -0.12223441153764725,\n        -5.41131067276001\n      ],\n      [\n        -0.12323068082332611,\n        -5.446536064147949\n      ],\n      [\n        -0.12448007613420486,\n        -5.4905595779418945\n      ],\n      [\n        -0.1252303272485733,\n        -5.516972541809082\n      ]\n    ]\n  }\n]\n"
  },
  {
    "path": "kimodo/assets/demo/examples/kimodo-soma-rp/07_mixed_constraints/meta.json",
    "content": "{\n  \"text\": \"A person walking backward points to the right side with their right hand\",\n  \"duration\": 5.066666666666666,\n  \"num_samples\": 1,\n  \"seed\": 49,\n  \"diffusion_steps\": 100,\n  \"cfg\": {\n    \"enabled\": true,\n    \"text_weight\": 2.0,\n    \"constraint_weight\": 2.0\n  }\n}\n"
  },
  {
    "path": "kimodo/assets/demo/examples/kimodo-soma-rp/08_stylized_text/meta.json",
    "content": "{\n  \"text\": \"A zombie with their left arm extended forward walks with an uneven gait at a slow pace.\",\n  \"duration\": 4.033333333333333,\n  \"num_samples\": 1,\n  \"seed\": 42,\n  \"diffusion_steps\": 100,\n  \"cfg\": {\n    \"enabled\": true,\n    \"text_weight\": 2.0,\n    \"constraint_weight\": 2.0\n  }\n}\n"
  },
  {
    "path": "kimodo/assets/skeletons/g1skel34/xml/g1.xml",
    "content": "<mujoco model=\"g1\">\n    <compiler angle=\"radian\" meshdir=\"../meshes/g1\"/>\n\n    <default>\n        <default class=\"g1\">\n            <geom contype=\"0\" conaffinity=\"0\"/>\n\n            <joint frictionloss=\"0.1\" solimplimit=\"0.97 0.995 0.001\"/>\n\n            <default class=\"hip\">\n                <default class=\"hip_pitch\">\n                    <joint axis=\"0 1 0\" range=\"-2.5307 2.8798\" actuatorfrcrange=\"-88 88\" armature=\"0.01017752004\"/>\n                </default>\n                <default class=\"hip_roll\">\n                    <joint axis=\"1 0 0\" actuatorfrcrange=\"-139 139\" armature=\"0.025101925\"/>\n                </default>\n                <default class=\"hip_yaw\">\n                    <joint axis=\"0 0 1\" range=\"-2.7576 2.7576\" actuatorfrcrange=\"-88 88\" armature=\"0.01017752004\"/>\n                </default>\n            </default>\n            <default class=\"knee\">\n                <joint axis=\"0 1 0\" range=\"-0.087267 2.8798\" actuatorfrcrange=\"-139 139\" armature=\"0.025101925\"/>\n            </default>\n            <default class=\"ankle\">\n                <default class=\"ankle_pitch\">\n                    <joint axis=\"0 1 0\" range=\"-0.87267 0.5236\" actuatorfrcrange=\"-50 50\" armature=\"0.00721945\"/>\n                </default>\n                <default class=\"ankle_roll\">\n                    <joint axis=\"1 0 0\" range=\"-0.2618 0.2618\" actuatorfrcrange=\"-50 50\" armature=\"0.00721945\"/>\n                </default>\n            </default>\n            <default class=\"waist_yaw\">\n                <joint axis=\"0 0 1\" range=\"-2.618 2.618\" actuatorfrcrange=\"-88 88\" armature=\"0.01017752004\"/>\n            </default>\n            <default class=\"waist_pitch\">\n                <joint axis=\"0 1 0\" range=\"-0.52 0.52\" actuatorfrcrange=\"-50 50\" armature=\"0.00721945\"/>\n            </default>\n            <default class=\"waist_roll\">\n                <joint 
axis=\"1 0 0\" range=\"-0.52 0.52\" actuatorfrcrange=\"-50 50\" armature=\"0.00721945\"/>\n            </default>\n            <default class=\"shoulder\">\n                <default class=\"shoulder_pitch\">\n                    <joint axis=\"0 1 0\" range=\"-3.0892 2.6704\" actuatorfrcrange=\"-25 25\" armature=\"0.003609725\"/>\n                </default>\n                <default class=\"shoulder_roll\">\n                    <joint axis=\"1 0 0\" actuatorfrcrange=\"-25 25\" armature=\"0.003609725\"/>\n                </default>\n                <default class=\"shoulder_yaw\">\n                    <joint axis=\"0 0 1\" range=\"-2.618 2.618\" actuatorfrcrange=\"-25 25\" armature=\"0.003609725\"/>\n                </default>\n            </default>\n            <default class=\"elbow\">\n                <joint axis=\"0 1 0\" range=\"-1.0472 2.0944\" actuatorfrcrange=\"-25 25\" armature=\"0.003609725\"/>\n            </default>\n            <default class=\"wrist\">\n                <default class=\"wrist_roll\">\n                    <joint axis=\"1 0 0\" range=\"-1.97222 1.97222\" actuatorfrcrange=\"-25 25\" armature=\"0.003609725\"/>\n                </default>\n                <default class=\"wrist_pitch\">\n                    <joint axis=\"0 1 0\" range=\"-1.61443 1.61443\" actuatorfrcrange=\"-5 5\" armature=\"0.00425\"/>\n                </default>\n                <default class=\"wrist_yaw\">\n                    <joint axis=\"0 0 1\" range=\"-1.61443 1.61443\" actuatorfrcrange=\"-5 5\" armature=\"0.00425\"/>\n                </default>\n            </default>\n\n            <default class=\"visual\">\n                <geom group=\"2\" type=\"mesh\" density=\"0\" material=\"silver\"/>\n            </default>\n            <default class=\"collision\">\n                <geom group=\"3\" rgba=\".2 .6 .2 .3\" contype=\"1\" conaffinity=\"1\"/>\n                <default class=\"foot\">\n                    <geom size=\"0.085 0.03 0.005\"/>\n                
</default>\n            </default>\n            <site group=\"5\" rgba=\"1 0 0 1\"/>\n        </default>\n    </default>\n\n    <asset>\n        <material name=\"silver\" rgba=\"0.7 0.7 0.7 1\"/>\n        <material name=\"black\" rgba=\"0.2 0.2 0.2 1\"/>\n\n        <mesh name=\"pelvis\" file=\"pelvis.STL\"/>\n        <mesh name=\"pelvis_contour_link\" file=\"pelvis_contour_link.STL\"/>\n        <mesh name=\"left_hip_pitch_link\" file=\"left_hip_pitch_link.STL\"/>\n        <mesh name=\"left_hip_roll_link\" file=\"left_hip_roll_link.STL\"/>\n        <mesh name=\"left_hip_yaw_link\" file=\"left_hip_yaw_link.STL\"/>\n        <mesh name=\"left_knee_link\" file=\"left_knee_link.STL\"/>\n        <mesh name=\"left_ankle_pitch_link\"\n              file=\"left_ankle_pitch_link.STL\"/>\n        <mesh name=\"left_ankle_roll_link\" file=\"left_ankle_roll_link.STL\"/>\n        <mesh name=\"right_hip_pitch_link\" file=\"right_hip_pitch_link.STL\"/>\n        <mesh name=\"right_hip_roll_link\" file=\"right_hip_roll_link.STL\"/>\n        <mesh name=\"right_hip_yaw_link\" file=\"right_hip_yaw_link.STL\"/>\n        <mesh name=\"right_knee_link\" file=\"right_knee_link.STL\"/>\n        <mesh name=\"right_ankle_pitch_link\"\n              file=\"right_ankle_pitch_link.STL\"/>\n        <mesh name=\"right_ankle_roll_link\"\n              file=\"right_ankle_roll_link.STL\"/>\n        <mesh name=\"waist_yaw_link\" file=\"waist_yaw_link_rev_1_0.STL\"/>\n        <mesh name=\"waist_roll_link\" file=\"waist_roll_link_rev_1_0.STL\"/>\n        <mesh name=\"torso_link\" file=\"torso_link_rev_1_0.STL\"/>\n        <mesh name=\"logo_link\" file=\"logo_link.STL\"/>\n        <mesh name=\"head_link\" file=\"head_link.STL\"/>\n        <mesh name=\"left_shoulder_pitch_link\"\n              file=\"left_shoulder_pitch_link.STL\"/>\n        <mesh name=\"left_shoulder_roll_link\"\n              file=\"left_shoulder_roll_link.STL\"/>\n        <mesh name=\"left_shoulder_yaw_link\"\n              
file=\"left_shoulder_yaw_link.STL\"/>\n        <mesh name=\"left_elbow_link\" file=\"left_elbow_link.STL\"/>\n        <mesh name=\"left_wrist_roll_link\" file=\"left_wrist_roll_link.STL\"/>\n        <mesh name=\"left_wrist_pitch_link\"\n              file=\"left_wrist_pitch_link.STL\"/>\n        <mesh name=\"left_wrist_yaw_link\" file=\"left_wrist_yaw_link.STL\"/>\n        <mesh name=\"left_rubber_hand\" file=\"left_rubber_hand.STL\"/>\n        <mesh name=\"right_shoulder_pitch_link\"\n              file=\"right_shoulder_pitch_link.STL\"/>\n        <mesh name=\"right_shoulder_roll_link\"\n              file=\"right_shoulder_roll_link.STL\"/>\n        <mesh name=\"right_shoulder_yaw_link\"\n              file=\"right_shoulder_yaw_link.STL\"/>\n        <mesh name=\"right_elbow_link\" file=\"right_elbow_link.STL\"/>\n        <mesh name=\"right_wrist_roll_link\"\n              file=\"right_wrist_roll_link.STL\"/>\n        <mesh name=\"right_wrist_pitch_link\"\n              file=\"right_wrist_pitch_link.STL\"/>\n        <mesh name=\"right_wrist_yaw_link\" file=\"right_wrist_yaw_link.STL\"/>\n        <mesh name=\"right_rubber_hand\" file=\"right_rubber_hand.STL\"/>\n    </asset>\n\n    <worldbody>\n        <body name=\"pelvis\" pos=\"0 0 0.793\" childclass=\"g1\">\n            <inertial pos=\"0 0 -0.07605\" quat=\"1 0 -0.000399148 0\" mass=\"3.813\" diaginertia=\"0.010549 0.0093089 0.0079184\"/>\n            <freejoint name=\"floating_base_joint\"/>\n            <geom class=\"visual\" material=\"black\" mesh=\"pelvis\"/>\n            <geom class=\"visual\" mesh=\"pelvis_contour_link\"/>\n            <geom mesh=\"pelvis_contour_link\" class=\"visual\"/>\n            <geom name=\"pelvis_collision\" class=\"collision\" type=\"sphere\" size=\"0.07\" pos=\"0 0 -0.08\"/>\n            <site name=\"imu_in_pelvis\" size=\"0.01\" pos=\"0.04525 0 -0.08339\"/>\n            <site name=\"pelvis\" size=\"0.01\" pos=\"0 0 0\"/>\n            <body name=\"left_hip_pitch_link\" pos=\"0 
0.064452 -0.1027\">\n                <inertial pos=\"0.002741 0.047791 -0.02606\" quat=\"0.954862 0.293964 0.0302556 0.030122\" mass=\"1.35\"\n                          diaginertia=\"0.00181517 0.00153422 0.00116212\"/>\n                <joint name=\"left_hip_pitch_joint\" class=\"hip_pitch\"/>\n                <geom class=\"visual\" material=\"black\" mesh=\"left_hip_pitch_link\"/>\n                <geom material=\"black\" mesh=\"left_hip_pitch_link\" class=\"visual\"/>\n                <body name=\"left_hip_roll_link\" pos=\"0 0.052 -0.030465\" quat=\"0.996179 0 -0.0873386 0\">\n                    <inertial pos=\"0.029812 -0.001045 -0.087934\" quat=\"0.977808 -1.97119e-05 0.205576 -0.0403793\" mass=\"1.52\"\n                              diaginertia=\"0.00254986 0.00241169 0.00148755\"/>\n                    <joint name=\"left_hip_roll_joint\" class=\"hip_roll\" range=\"-0.5236 2.9671\"/>\n                    <geom class=\"visual\" mesh=\"left_hip_roll_link\"/>\n                    <geom mesh=\"left_hip_roll_link\" class=\"visual\"/>\n                    <geom name=\"left_thigh\" class=\"collision\" type=\"capsule\" size=\"0.05\" fromto=\"0.02 0 0 0.02 0 -0.2\"/>\n                    <body name=\"left_hip_yaw_link\" pos=\"0.025001 0 -0.12412\">\n                        <inertial pos=\"-0.057709 -0.010981 -0.15078\" quat=\"0.600598 0.15832 0.223482 0.751181\" mass=\"1.702\"\n                                  diaginertia=\"0.00776166 0.00717575 0.00160139\"/>\n                        <joint name=\"left_hip_yaw_joint\" class=\"hip_yaw\"/>\n                        <geom class=\"visual\" mesh=\"left_hip_yaw_link\"/>\n                        <body name=\"left_knee_link\" pos=\"-0.078273 0.0021489 -0.17734\" quat=\"0.996179 0 0.0873386 0\">\n                            <inertial pos=\"0.005457 0.003964 -0.12074\" quat=\"0.923418 -0.0327699 0.0158246 0.382067\" mass=\"1.932\"\n                                      diaginertia=\"0.0113804 0.0112778 0.00146458\"/>\n       
                     <joint name=\"left_knee_joint\" class=\"knee\"/>\n                            <geom class=\"visual\" mesh=\"left_knee_link\"/>\n                            <geom name=\"left_shin\" class=\"collision\" type=\"capsule\" size=\"0.04\" fromto=\"0.02 0 0 0.02 0 -0.25\"/>\n                            <body name=\"left_ankle_pitch_link\" pos=\"0 -9.4445e-05 -0.30001\">\n                                <inertial pos=\"-0.007269 0 0.011137\" quat=\"0.603053 0.369225 0.369225 0.603053\" mass=\"0.074\"\n                                          diaginertia=\"1.89e-05 1.40805e-05 6.9195e-06\"/>\n                                <joint name=\"left_ankle_pitch_joint\" class=\"ankle_pitch\"/>\n                                <geom class=\"visual\" mesh=\"left_ankle_pitch_link\"/>\n                                <body name=\"left_ankle_roll_link\" pos=\"0 0 -0.017558\">\n                                    <site name=\"left_foot\" rgba=\"1 0 0 1\"/>\n                                    <inertial pos=\"0.026505 0 -0.016425\" quat=\"-0.000481092 0.728482 -0.000618967 0.685065\" mass=\"0.608\"\n                                              diaginertia=\"0.00167218 0.0016161 0.000217621\"/>\n                                    <joint name=\"left_ankle_roll_joint\" class=\"ankle_roll\"/>\n                                    <geom class=\"visual\" material=\"black\" mesh=\"left_ankle_roll_link\"/>\n                                    <geom name=\"left_foot1_collision\" type=\"capsule\" size=\"0.01\" class=\"collision\" fromto=\"0.1 -0.026 -0.025 0.05 -0.027 -0.025\"/>\n                                    <geom name=\"left_foot2_collision\" type=\"capsule\" size=\"0.01\" class=\"collision\" fromto=\"-0.044 -0.018 -0.025 0.123 -0.018 -0.025\"/>\n                                    <geom name=\"left_foot3_collision\" type=\"capsule\" size=\"0.01\" class=\"collision\" fromto=\"-0.052 -0.01 -0.025 0.13 -0.01 -0.025\"/>\n                                    <geom 
name=\"left_foot4_collision\" type=\"capsule\" size=\"0.01\" class=\"collision\" fromto=\"-0.054 0 -0.025 0.132 0 -0.025\"/>\n                                    <geom name=\"left_foot5_collision\" type=\"capsule\" size=\"0.01\" class=\"collision\" fromto=\"-0.052 0.01 -0.025 0.13 0.01 -0.025\"/>\n                                    <geom name=\"left_foot6_collision\" type=\"capsule\" size=\"0.01\" class=\"collision\" fromto=\"-0.044 0.018 -0.025 0.123 0.018 -0.025\"/>\n                                    <geom name=\"left_foot7_collision\" type=\"capsule\" size=\"0.01\" class=\"collision\" fromto=\"0.1 0.026 -0.025 0.05 0.026 -0.025\"/>\n                                </body>\n                            </body>\n                        </body>\n                    </body>\n                </body>\n            </body>\n            <body name=\"right_hip_pitch_link\" pos=\"0 -0.064452 -0.1027\">\n                <inertial pos=\"0.002741 -0.047791 -0.02606\" quat=\"0.954862 -0.293964 0.0302556 -0.030122\" mass=\"1.35\"\n                          diaginertia=\"0.00181517 0.00153422 0.00116212\"/>\n                <joint name=\"right_hip_pitch_joint\" class=\"hip_pitch\"/>\n                <geom class=\"visual\" material=\"black\" mesh=\"right_hip_pitch_link\"/>\n                <body name=\"right_hip_roll_link\" pos=\"0 -0.052 -0.030465\" quat=\"0.996179 0 -0.0873386 0\">\n                    <inertial pos=\"0.029812 0.001045 -0.087934\" quat=\"0.977808 1.97119e-05 0.205576 0.0403793\" mass=\"1.52\"\n                              diaginertia=\"0.00254986 0.00241169 0.00148755\"/>\n                    <joint name=\"right_hip_roll_joint\" class=\"hip_roll\" range=\"-2.9671 0.5236\"/>\n                    <geom class=\"visual\" mesh=\"right_hip_roll_link\"/>\n                    <geom name=\"right_thigh\" class=\"collision\" type=\"capsule\" size=\"0.05\" fromto=\"0.02 0 0 0.02 0 -0.2\"/>\n                    <body name=\"right_hip_yaw_link\" pos=\"0.025001 0 
-0.12412\">\n                        <inertial pos=\"-0.057709 0.010981 -0.15078\" quat=\"0.751181 0.223482 0.15832 0.600598\" mass=\"1.702\"\n                                  diaginertia=\"0.00776166 0.00717575 0.00160139\"/>\n                        <joint name=\"right_hip_yaw_joint\" class=\"hip_yaw\"/>\n                        <geom class=\"visual\" mesh=\"right_hip_yaw_link\"/>\n                        <body name=\"right_knee_link\" pos=\"-0.078273 -0.0021489 -0.17734\" quat=\"0.996179 0 0.0873386 0\">\n                            <inertial pos=\"0.005457 -0.003964 -0.12074\" quat=\"0.923439 0.0345276 0.0116333 -0.382012\" mass=\"1.932\"\n                                      diaginertia=\"0.011374 0.0112843 0.00146452\"/>\n                            <joint name=\"right_knee_joint\" class=\"knee\"/>\n                            <geom class=\"visual\" mesh=\"right_knee_link\"/>\n                            <geom name=\"right_shin\" class=\"collision\" type=\"capsule\" size=\"0.04\" fromto=\"0.02 0 0 0.02 0 -0.25\"/>\n                            <body name=\"right_ankle_pitch_link\" pos=\"0 9.4445e-05 -0.30001\">\n                                <inertial pos=\"-0.007269 0 0.011137\" quat=\"0.603053 0.369225 0.369225 0.603053\" mass=\"0.074\"\n                                          diaginertia=\"1.89e-05 1.40805e-05 6.9195e-06\"/>\n                                <joint name=\"right_ankle_pitch_joint\" class=\"ankle_pitch\"/>\n                                <geom class=\"visual\" mesh=\"right_ankle_pitch_link\"/>\n                                <body name=\"right_ankle_roll_link\" pos=\"0 0 -0.017558\">\n                                    <site name=\"right_foot\" pos=\"0 0 0\"/>\n                                    <inertial pos=\"0.026505 0 -0.016425\" quat=\"0.000481092 0.728482 0.000618967 0.685065\" mass=\"0.608\"\n                                              diaginertia=\"0.00167218 0.0016161 0.000217621\"/>\n                                    
<joint name=\"right_ankle_roll_joint\" class=\"ankle_roll\"/>\n                                    <geom class=\"visual\" material=\"black\" mesh=\"right_ankle_roll_link\"/>\n                                    <geom name=\"right_foot1_collision\" type=\"capsule\" size=\"0.01\" class=\"collision\" fromto=\"0.1 -0.026 -0.025 0.05 -0.026 -0.025\"/>\n                                    <geom name=\"right_foot2_collision\" type=\"capsule\" size=\"0.008\" class=\"collision\" fromto=\"-0.044 -0.018 -0.025 0.123 -0.018 -0.025\"/>\n                                    <geom name=\"right_foot3_collision\" type=\"capsule\" size=\"0.01\" class=\"collision\" fromto=\"-0.052 -0.01 -0.025 0.13 -0.01 -0.025\"/>\n                                    <geom name=\"right_foot4_collision\" type=\"capsule\" size=\"0.01\" class=\"collision\" fromto=\"-0.054 0 -0.025 0.132 0 -0.025\"/>\n                                    <geom name=\"right_foot5_collision\" type=\"capsule\" size=\"0.01\" class=\"collision\" fromto=\"-0.052 0.01 -0.025 0.13 0.01 -0.025\"/>\n                                    <geom name=\"right_foot6_collision\" type=\"capsule\" size=\"0.008\" class=\"collision\" fromto=\"-0.044 0.018 -0.025 0.123 0.018 -0.025\"/>\n                                    <geom name=\"right_foot7_collision\" type=\"capsule\" size=\"0.01\" class=\"collision\" fromto=\"0.1 0.026 -0.025 0.05 0.026 -0.025\"/>\n                                </body>\n                            </body>\n                        </body>\n                    </body>\n                </body>\n            </body>\n            <body name=\"waist_yaw_link\">\n                <inertial pos=\"0.003494 0.000233 0.018034\" quat=\"0.289697 0.591001 -0.337795 0.672821\" mass=\"0.214\"\n                          diaginertia=\"0.000163531 0.000107714 0.000102205\"/>\n                <joint name=\"waist_yaw_joint\" class=\"waist_yaw\"/>\n                <geom class=\"visual\" mesh=\"waist_yaw_link\"/>\n                <body 
name=\"waist_roll_link\" pos=\"-0.0039635 0 0.044\">\n                    <inertial pos=\"0 2.3e-05 0\" quat=\"0.5 0.5 -0.5 0.5\" mass=\"0.086\" diaginertia=\"8.245e-06 7.079e-06 6.339e-06\"/>\n                    <joint name=\"waist_roll_joint\" class=\"waist_roll\"/>\n                    <geom class=\"visual\" mesh=\"waist_roll_link\"/>\n                    <body name=\"torso_link\">\n                        <inertial pos=\"0.00203158 0.000339683 0.184568\" quat=\"0.999803 -6.03319e-05 0.0198256 0.00131986\"\n                                  mass=\"7.818\" diaginertia=\"0.121847 0.109825 0.0273735\"/>\n                        <joint name=\"waist_pitch_joint\" class=\"waist_pitch\"/>\n                        <geom class=\"visual\" mesh=\"torso_link\"/>\n                        <geom pos=\"0.0039635 0 -0.044\" quat=\"1 0 0 0\" class=\"visual\" material=\"black\" mesh=\"logo_link\"/>\n                        <geom pos=\"0.0039635 0 -0.044\" class=\"visual\" material=\"black\" mesh=\"head_link\"/>\n\n                        <geom name=\"torso_collision1\" class=\"collision\" type=\"capsule\" size=\"0.073\" fromto=\"0.005 -0.032 .22 0.005 0.032 .22\"/>\n                        <geom name=\"torso_collision2\" class=\"collision\" type=\"capsule\" size=\"0.07\" fromto=\"0.005 -0.028 .13 0.005 0.028 .13\"/>\n                        <geom name=\"torso_collision3\" class=\"collision\" type=\"capsule\" size=\"0.065\" fromto=\"0.005 -0.02 .06 0.005 0.02 .06\"/>\n                        <geom name=\"head_collision\" class=\"collision\" type=\"capsule\" size=\"0.068\" fromto=\"0.01 0 .41 0.01 0 .42\"/>\n\n                        <site name=\"imu_in_torso\" size=\"0.01\" pos=\"-0.03959 -0.00224 0.14792\"/>\n                        <site name=\"mid360\" size=\"0.01\" pos=\"0.0002835 0.00003 0.41618\" quat=\"0.00000094 -0.99979404 0.00004632 0.02029493\"/>\n                        <body name=\"left_shoulder_pitch_link\" pos=\"0.0039563 0.10022 0.24778\"\n                          
    quat=\"0.990264 0.139201 1.38722e-05 -9.86868e-05\">\n                            <inertial pos=\"0 0.035892 -0.011628\" quat=\"0.654152 0.0130458 -0.326267 0.68225\" mass=\"0.718\"\n                                      diaginertia=\"0.000465864 0.000432842 0.000406394\"/>\n                            <joint name=\"left_shoulder_pitch_joint\" class=\"shoulder_pitch\"/>\n                            <geom class=\"visual\" mesh=\"left_shoulder_pitch_link\"/>\n                            <geom size=\"0.03 0.025\" pos=\"0 0.04 -0.01\" quat=\"0.707107 0 0.707107 0\" type=\"cylinder\" class=\"visual\"/>\n                            <body name=\"left_shoulder_roll_link\" pos=\"0 0.038 -0.013831\" quat=\"0.990268 -0.139172 0 0\">\n                                <inertial pos=\"-0.000227 0.00727 -0.063243\" quat=\"0.701256 -0.0196223 -0.00710317 0.712604\" mass=\"0.643\"\n                                          diaginertia=\"0.000691311 0.000618011 0.000388977\"/>\n                                <joint name=\"left_shoulder_roll_joint\" range=\"-1.5882 2.2515\" class=\"shoulder_roll\"/>\n                                <geom class=\"visual\" mesh=\"left_shoulder_roll_link\"/>\n                                <geom size=\"0.03 0.015\" pos=\"-0.004 0.006 -0.053\" type=\"cylinder\" class=\"visual\"/>\n                                <body name=\"left_shoulder_yaw_link\" pos=\"0 0.00624 -0.1032\">\n                                    <inertial pos=\"0.010773 -0.002949 -0.072009\" quat=\"0.716879 -0.0964829 -0.0679942 0.687134\"\n                                              mass=\"0.734\" diaginertia=\"0.00106187 0.00103217 0.000400661\"/>\n                                    <joint name=\"left_shoulder_yaw_joint\" class=\"shoulder_yaw\"/>\n                                    <geom class=\"visual\" mesh=\"left_shoulder_yaw_link\"/>\n                                    <geom name=\"left_shoulder_yaw_collision\" class=\"collision\" type=\"capsule\" size=\"0.035\" 
fromto=\"0 0 -0.08 0 0 0.05\"/>\n                                    <body name=\"left_elbow_link\" pos=\"0.015783 0 -0.080518\">\n                                        <inertial pos=\"0.064956 0.004454 -0.010062\" quat=\"0.541765 0.636132 0.388821 0.388129\" mass=\"0.6\"\n                                                  diaginertia=\"0.000443035 0.000421612 0.000259353\"/>\n                                        <joint name=\"left_elbow_joint\" class=\"elbow\"/>\n                                        <geom class=\"visual\" mesh=\"left_elbow_link\"/>\n                                        <geom name=\"left_elbow_yaw_collision\" class=\"collision\" type=\"capsule\" size=\"0.035\" fromto=\"-0.01 0 -0.01 0.12 0 -0.01\"/>\n                                        <body name=\"left_wrist_roll_link\" pos=\"0.1 0.00188791 -0.01\">\n                                            <inertial pos=\"0.0171394 0.000537591 4.8864e-07\" quat=\"0.575338 0.411667 -0.574906 0.411094\"\n                                                      mass=\"0.085445\" diaginertia=\"5.48211e-05 4.96646e-05 3.57798e-05\"/>\n                                            <joint name=\"left_wrist_roll_joint\" class=\"wrist_roll\"/>\n                                            <geom class=\"visual\" mesh=\"left_wrist_roll_link\"/>\n                                            <body name=\"left_wrist_pitch_link\" pos=\"0.038 0 0\">\n                                                <inertial pos=\"0.0229999 -0.00111685 -0.00111658\" quat=\"0.249998 0.661363 0.293036 0.643608\"\n                                                          mass=\"0.48405\" diaginertia=\"0.000430353 0.000429873 0.000164648\"/>\n                                                <joint name=\"left_wrist_pitch_joint\" class=\"wrist_pitch\"/>\n                                                <geom class=\"visual\" mesh=\"left_wrist_pitch_link\"/>\n                                                <body name=\"left_wrist_yaw_link\" 
pos=\"0.046 0 0\">\n                                                    <inertial pos=\"0.0708244 0.000191745 0.00161742\" quat=\"0.510571 0.526295 0.468078 0.493188\"\n                                                              mass=\"0.254576\" diaginertia=\"0.000646113 0.000559993 0.000147566\"/>\n                                                    <joint name=\"left_wrist_yaw_joint\" class=\"wrist_yaw\"/>\n                                                    <geom class=\"visual\" mesh=\"left_wrist_yaw_link\"/>\n                                                    <geom pos=\"0.0415 0.003 0\" quat=\"1 0 0 0\" class=\"visual\" mesh=\"left_rubber_hand\"/>\n                                                    <site name=\"left_palm\" pos=\"0.08 0 0\" size=\"0.01\"/>\n                                                    <geom name=\"left_hand_collision\" class=\"collision\" type=\"capsule\" size=\"0.05\" fromto=\"0.05 0 0 0.1 0 0\"/>\n                                                </body>\n                                            </body>\n                                        </body>\n                                    </body>\n                                </body>\n                            </body>\n                        </body>\n                        <body name=\"right_shoulder_pitch_link\" pos=\"0.0039563 -0.10021 0.24778\"\n                              quat=\"0.990264 -0.139201 1.38722e-05 9.86868e-05\">\n                            <inertial pos=\"0 -0.035892 -0.011628\" quat=\"0.68225 -0.326267 0.0130458 0.654152\" mass=\"0.718\"\n                                      diaginertia=\"0.000465864 0.000432842 0.000406394\"/>\n                            <joint name=\"right_shoulder_pitch_joint\" class=\"shoulder_pitch\"/>\n                            <geom class=\"visual\" mesh=\"right_shoulder_pitch_link\"/>\n                            <geom size=\"0.03 0.025\" pos=\"0 -0.04 -0.01\" quat=\"0.707107 0 0.707107 0\" type=\"cylinder\" 
class=\"visual\"/>\n                            <body name=\"right_shoulder_roll_link\" pos=\"0 -0.038 -0.013831\" quat=\"0.990268 0.139172 0 0\">\n                                <inertial pos=\"-0.000227 -0.00727 -0.063243\" quat=\"0.712604 -0.00710317 -0.0196223 0.701256\"\n                                          mass=\"0.643\" diaginertia=\"0.000691311 0.000618011 0.000388977\"/>\n                                <joint name=\"right_shoulder_roll_joint\" range=\"-2.2515 1.5882\" class=\"shoulder_roll\"/>\n                                <geom class=\"visual\" mesh=\"right_shoulder_roll_link\"/>\n                                <geom size=\"0.03 0.015\" pos=\"-0.004 -0.006 -0.053\" type=\"cylinder\" class=\"visual\"/>\n                                <body name=\"right_shoulder_yaw_link\" pos=\"0 -0.00624 -0.1032\">\n                                    <inertial pos=\"0.010773 0.002949 -0.072009\" quat=\"0.687134 -0.0679942 -0.0964829 0.716879\"\n                                              mass=\"0.734\" diaginertia=\"0.00106187 0.00103217 0.000400661\"/>\n                                    <joint name=\"right_shoulder_yaw_joint\" class=\"shoulder_yaw\"/>\n                                    <geom class=\"visual\" mesh=\"right_shoulder_yaw_link\"/>\n                                    <geom name=\"right_shoulder_yaw_collision\" class=\"collision\" type=\"capsule\" size=\"0.035\" fromto=\"0 0 -0.08 0 0 0.05\"/>\n                                    <body name=\"right_elbow_link\" pos=\"0.015783 0 -0.080518\">\n                                        <inertial pos=\"0.064956 -0.004454 -0.010062\" quat=\"0.388129 0.388821 0.636132 0.541765\" mass=\"0.6\"\n                                                  diaginertia=\"0.000443035 0.000421612 0.000259353\"/>\n                                        <joint name=\"right_elbow_joint\" class=\"elbow\"/>\n                                        <geom class=\"visual\" mesh=\"right_elbow_link\"/>\n                       
                 <geom name=\"right_elbow_yaw_collision\" class=\"collision\" type=\"capsule\" size=\"0.035\" fromto=\"-0.01 0 -0.01 0.12 0 -0.01\"/>\n                                        <body name=\"right_wrist_roll_link\" pos=\"0.1 -0.00188791 -0.01\">\n                                            <inertial pos=\"0.0171394 -0.000537591 4.8864e-07\" quat=\"0.411667 0.575338 -0.411094 0.574906\"\n                                                      mass=\"0.085445\" diaginertia=\"5.48211e-05 4.96646e-05 3.57798e-05\"/>\n                                            <joint name=\"right_wrist_roll_joint\" class=\"wrist_roll\"/>\n                                            <geom class=\"visual\" mesh=\"right_wrist_roll_link\"/>\n                                            <body name=\"right_wrist_pitch_link\" pos=\"0.038 0 0\">\n                                                <inertial pos=\"0.0229999 0.00111685 -0.00111658\" quat=\"0.643608 0.293036 0.661363 0.249998\"\n                                                          mass=\"0.48405\" diaginertia=\"0.000430353 0.000429873 0.000164648\"/>\n                                                <joint name=\"right_wrist_pitch_joint\" class=\"wrist_pitch\"/>\n                                                <geom class=\"visual\" mesh=\"right_wrist_pitch_link\"/>\n                                                <body name=\"right_wrist_yaw_link\" pos=\"0.046 0 0\">\n                                                    <inertial pos=\"0.0708244 -0.000191745 0.00161742\" quat=\"0.493188 0.468078 0.526295 0.510571\"\n                                                              mass=\"0.254576\" diaginertia=\"0.000646113 0.000559993 0.000147566\"/>\n                                                    <joint name=\"right_wrist_yaw_joint\" class=\"wrist_yaw\"/>\n                                                    <geom class=\"visual\" mesh=\"right_wrist_yaw_link\"/>\n                                                    
<geom pos=\"0.0415 -0.003 0\" quat=\"1 0 0 0\" class=\"visual\" mesh=\"right_rubber_hand\"/>\n                                                    <site name=\"right_palm\" pos=\"0.08 0 0\" size=\"0.01\"/>\n                                                    <geom name=\"right_hand_collision\" class=\"collision\" type=\"capsule\" size=\"0.05\" fromto=\"0.05 0 0 0.1 0 0\"/>\n                                                </body>\n                                            </body>\n                                        </body>\n                                    </body>\n                                </body>\n                            </body>\n                        </body>\n                    </body>\n                </body>\n            </body>\n        </body>\n    </worldbody>\n\n    <contact>\n        <!-- left foot - floor -->\n        <pair name=\"left_foot1_floor\" geom1=\"left_foot1_collision\" geom2=\"floor\" solref=\"0.01 1\" friction=\"0.8 0.8\"/>\n        <pair name=\"left_foot2_floor\" geom1=\"left_foot2_collision\" geom2=\"floor\" solref=\"0.01 1\" friction=\"0.8 0.8\"/>\n        <pair name=\"left_foot3_floor\" geom1=\"left_foot3_collision\" geom2=\"floor\" solref=\"0.01 1\" friction=\"0.8 0.8\"/>\n        <pair name=\"left_foot4_floor\" geom1=\"left_foot4_collision\" geom2=\"floor\" solref=\"0.01 1\" friction=\"0.8 0.8\"/>\n        <pair name=\"left_foot5_floor\" geom1=\"left_foot5_collision\" geom2=\"floor\" solref=\"0.01 1\" friction=\"0.8 0.8\"/>\n        <pair name=\"left_foot6_floor\" geom1=\"left_foot6_collision\" geom2=\"floor\" solref=\"0.01 1\" friction=\"0.8 0.8\"/>\n        <pair name=\"left_foot7_floor\" geom1=\"left_foot7_collision\" geom2=\"floor\" solref=\"0.01 1\" friction=\"0.8 0.8\"/>\n        <!-- right foot - floor -->\n        <pair name=\"right_foot1_floor\" geom1=\"right_foot1_collision\" geom2=\"floor\" solref=\"0.01 1\" friction=\"0.8 0.8\"/>\n        <pair name=\"right_foot2_floor\" geom1=\"right_foot2_collision\" 
geom2=\"floor\" solref=\"0.01 1\" friction=\"0.8 0.8\"/>\n        <pair name=\"right_foot3_floor\" geom1=\"right_foot3_collision\" geom2=\"floor\" solref=\"0.01 1\" friction=\"0.8 0.8\"/>\n        <pair name=\"right_foot4_floor\" geom1=\"right_foot4_collision\" geom2=\"floor\" solref=\"0.01 1\" friction=\"0.8 0.8\"/>\n        <pair name=\"right_foot5_floor\" geom1=\"right_foot5_collision\" geom2=\"floor\" solref=\"0.01 1\" friction=\"0.8 0.8\"/>\n        <pair name=\"right_foot6_floor\" geom1=\"right_foot6_collision\" geom2=\"floor\" solref=\"0.01 1\" friction=\"0.8 0.8\"/>\n        <pair name=\"right_foot7_floor\" geom1=\"right_foot7_collision\" geom2=\"floor\" solref=\"0.01 1\" friction=\"0.8 0.8\"/>\n    </contact>\n\n    <sensor>\n        <framequat name=\"base_quat\" objtype=\"site\" objname=\"imu_in_pelvis\"/>\n        <gyro name=\"base_gyro\" site=\"imu_in_pelvis\"/>\n        <accelerometer name=\"base_accel\" site=\"imu_in_pelvis\"/>\n\n        <framequat name=\"mid360_quat\" objtype=\"site\" objname=\"mid360\"/>\n        <framepos name=\"mid360_pos\" objtype=\"site\" objname=\"mid360\"/>\n    </sensor>\n\n    <visual>\n        <headlight diffuse=\"0.6 0.6 0.6\" ambient=\"0.3 0.3 0.3\" specular=\"0 0 0\"/>\n        <rgba haze=\"0.15 0.25 0.35 1\"/>\n        <global azimuth=\"120\" elevation=\"-20\"/>\n    </visual>\n\n    <asset>\n        <texture type=\"skybox\" builtin=\"gradient\" rgb1=\"0.3 0.5 0.7\" rgb2=\"0 0 0\" width=\"512\" height=\"3072\"/>\n        <texture type=\"2d\" name=\"groundplane\" builtin=\"checker\" mark=\"edge\" rgb1=\"0.2 0.3 0.4\" rgb2=\"0.1 0.2 0.3\"\n                 markrgb=\"0.8 0.8 0.8\" width=\"300\" height=\"300\"/>\n        <material name=\"groundplane\" texture=\"groundplane\" texuniform=\"true\" texrepeat=\"5 5\" reflectance=\"0.2\"/>\n    </asset>\n\n    <worldbody>\n        <light pos=\"0 0 1.5\" dir=\"0 0 -1\" directional=\"true\"/>\n        <geom name=\"floor\" size=\"0 0 0.01\" type=\"plane\" material=\"groundplane\" 
conaffinity=\"1\"/>\n    </worldbody>\n\n</mujoco>\n"
  },
  {
    "path": "kimodo/assets/skeletons/somaskel77/somaskel77_standard_tpose.bvh",
    "content": "HIERARCHY\nROOT Root\n{\n  OFFSET 0.0 0.0 0.0\n  CHANNELS 6 Xposition Yposition Zposition Zrotation Yrotation Xrotation\n  JOINT Hips\n  {\n    OFFSET 0.0 100.0 0.0\n    CHANNELS 6 Xposition Yposition Zposition Zrotation Yrotation Xrotation\n    JOINT Spine1\n    {\n      OFFSET -0.013727 5.003763 -0.053727\n      CHANNELS 3 Zrotation Yrotation Xrotation\n      JOINT Spine2\n      {\n        OFFSET -0.0 7.125301 -0.029825\n        CHANNELS 3 Zrotation Yrotation Xrotation\n        JOINT Chest\n        {\n          OFFSET -1e-06 7.550063 -0.815971\n          CHANNELS 3 Zrotation Yrotation Xrotation\n          JOINT Neck1\n          {\n            OFFSET -0.181677 26.311295 -0.553348\n            CHANNELS 3 Zrotation Yrotation Xrotation\n            JOINT Neck2\n            {\n              OFFSET -3e-06 7.709397 2.302585\n              CHANNELS 3 Zrotation Yrotation Xrotation\n              JOINT Head\n              {\n                OFFSET -5e-06 6.128916 1.953709\n                CHANNELS 3 Zrotation Yrotation Xrotation\n                JOINT HeadEnd\n                {\n                  OFFSET 0.003598 16.065403 -1.835379\n                  CHANNELS 3 Zrotation Yrotation Xrotation\n                }\n                JOINT Jaw\n                {\n                  OFFSET 0.002637 0.475592 3.094941\n                  CHANNELS 3 Zrotation Yrotation Xrotation\n                }\n                JOINT LeftEye\n                {\n                  OFFSET 3.206381 5.380205 7.586883\n                  CHANNELS 3 Zrotation Yrotation Xrotation\n                }\n                JOINT RightEye\n                {\n                  OFFSET -3.22244 5.361869 7.558234\n                  CHANNELS 3 Zrotation Yrotation Xrotation\n                }\n              }\n            }\n          }\n          JOINT LeftShoulder\n          {\n            OFFSET 1.621652 23.237164 5.113413\n            CHANNELS 3 Zrotation Yrotation Xrotation\n            JOINT LeftArm\n  
          {\n              OFFSET 14.919846 2e-06 -5.502326\n              CHANNELS 3 Zrotation Yrotation Xrotation\n              JOINT LeftForeArm\n              {\n                OFFSET 28.739307 0.0 -0.002588\n                CHANNELS 3 Zrotation Yrotation Xrotation\n                JOINT LeftHand\n                {\n                  OFFSET 27.093981 -1e-06 0.002609\n                  CHANNELS 3 Zrotation Yrotation Xrotation\n                  JOINT LeftHandThumb1\n                  {\n                    OFFSET 2.276482 -1.392045 3.191413\n                    CHANNELS 3 Zrotation Yrotation Xrotation\n                    JOINT LeftHandThumb2\n                    {\n                      OFFSET 4.012836 -1.828127 1.641654\n                      CHANNELS 3 Zrotation Yrotation Xrotation\n                      JOINT LeftHandThumb3\n                      {\n                        OFFSET 2.798515 0.0 -3e-06\n                        CHANNELS 3 Zrotation Yrotation Xrotation\n                        JOINT LeftHandThumbEnd\n                        {\n                          OFFSET 3.180793 -4e-06 4e-06\n                          CHANNELS 3 Zrotation Yrotation Xrotation\n                        }\n                      }\n                    }\n                  }\n                  JOINT LeftHandIndex1\n                  {\n                    OFFSET 3.247555 -0.531998 2.296169\n                    CHANNELS 3 Zrotation Yrotation Xrotation\n                    JOINT LeftHandIndex2\n                    {\n                      OFFSET 6.364578 0.01206 0.1786\n                      CHANNELS 3 Zrotation Yrotation Xrotation\n                      JOINT LeftHandIndex3\n                      {\n                        OFFSET 3.662364 0.0 0.0\n                        CHANNELS 3 Zrotation Yrotation Xrotation\n                        JOINT LeftHandIndex4\n                        {\n                          OFFSET 2.329242 4e-06 4e-06\n                          CHANNELS 3 
Zrotation Yrotation Xrotation\n                          JOINT LeftHandIndexEnd\n                          {\n                            OFFSET 2.759615 -0.180537 -0.113024\n                            CHANNELS 3 Zrotation Yrotation Xrotation\n                          }\n                        }\n                      }\n                    }\n                  }\n                  JOINT LeftHandMiddle1\n                  {\n                    OFFSET 3.163495 0.240981 1.000332\n                    CHANNELS 3 Zrotation Yrotation Xrotation\n                    JOINT LeftHandMiddle2\n                    {\n                      OFFSET 6.19078 -0.259278 -1.002548\n                      CHANNELS 3 Zrotation Yrotation Xrotation\n                      JOINT LeftHandMiddle3\n                      {\n                        OFFSET 4.35652 -4e-06 -1e-06\n                        CHANNELS 3 Zrotation Yrotation Xrotation\n                        JOINT LeftHandMiddle4\n                        {\n                          OFFSET 2.996877 -8e-06 0.0\n                          CHANNELS 3 Zrotation Yrotation Xrotation\n                          JOINT LeftHandMiddleEnd\n                          {\n                            OFFSET 2.304287 -0.294569 -0.031741\n                            CHANNELS 3 Zrotation Yrotation Xrotation\n                          }\n                        }\n                      }\n                    }\n                  }\n                  JOINT LeftHandRing1\n                  {\n                    OFFSET 2.882643 -0.053652 -0.322543\n                    CHANNELS 3 Zrotation Yrotation Xrotation\n                    JOINT LeftHandRing2\n                    {\n                      OFFSET 5.854541 -0.486202 -1.373841\n                      CHANNELS 3 Zrotation Yrotation Xrotation\n                      JOINT LeftHandRing3\n                      {\n                        OFFSET 4.350578 0.0 3e-06\n                        CHANNELS 3 Zrotation 
Yrotation Xrotation\n                        JOINT LeftHandRing4\n                        {\n                          OFFSET 2.651321 7e-06 2e-06\n                          CHANNELS 3 Zrotation Yrotation Xrotation\n                          JOINT LeftHandRingEnd\n                          {\n                            OFFSET 1.936105 0.077687 -7.1e-05\n                            CHANNELS 3 Zrotation Yrotation Xrotation\n                          }\n                        }\n                      }\n                    }\n                  }\n                  JOINT LeftHandPinky1\n                  {\n                    OFFSET 2.8655 -0.310005 -1.600378\n                    CHANNELS 3 Zrotation Yrotation Xrotation\n                    JOINT LeftHandPinky2\n                    {\n                      OFFSET 5.087849 -1.331141 -1.77123\n                      CHANNELS 3 Zrotation Yrotation Xrotation\n                      JOINT LeftHandPinky3\n                      {\n                        OFFSET 3.070974 4e-06 0.0\n                        CHANNELS 3 Zrotation Yrotation Xrotation\n                        JOINT LeftHandPinky4\n                        {\n                          OFFSET 1.549672 0.0 1e-06\n                          CHANNELS 3 Zrotation Yrotation Xrotation\n                          JOINT LeftHandPinkyEnd\n                          {\n                            OFFSET 1.944893 -0.157802 0.057219\n                            CHANNELS 3 Zrotation Yrotation Xrotation\n                          }\n                        }\n                      }\n                    }\n                  }\n                }\n              }\n            }\n          }\n          JOINT RightShoulder\n          {\n            OFFSET -1.380118 23.180309 5.214158\n            CHANNELS 3 Zrotation Yrotation Xrotation\n            JOINT RightArm\n            {\n              OFFSET -15.037196 1.2e-05 -5.545604\n              CHANNELS 3 Zrotation Yrotation Xrotation\n    
          JOINT RightForeArm\n              {\n                OFFSET -28.736639 2e-06 -0.002597\n                CHANNELS 3 Zrotation Yrotation Xrotation\n                JOINT RightHand\n                {\n                  OFFSET -27.133619 -0.0 0.002613\n                  CHANNELS 3 Zrotation Yrotation Xrotation\n                  JOINT RightHandThumb1\n                  {\n                    OFFSET -2.274032 -1.383988 3.163127\n                    CHANNELS 3 Zrotation Yrotation Xrotation\n                    JOINT RightHandThumb2\n                    {\n                      OFFSET -4.011429 -1.827466 1.640914\n                      CHANNELS 3 Zrotation Yrotation Xrotation\n                      JOINT RightHandThumb3\n                      {\n                        OFFSET -2.794935 -4e-06 -3e-06\n                        CHANNELS 3 Zrotation Yrotation Xrotation\n                        JOINT RightHandThumbEnd\n                        {\n                          OFFSET -3.183852 4e-06 1e-06\n                          CHANNELS 3 Zrotation Yrotation Xrotation\n                        }\n                      }\n                    }\n                  }\n                  JOINT RightHandIndex1\n                  {\n                    OFFSET -3.253266 -0.520057 2.282866\n                    CHANNELS 3 Zrotation Yrotation Xrotation\n                    JOINT RightHandIndex2\n                    {\n                      OFFSET -6.341917 0.012471 0.178266\n                      CHANNELS 3 Zrotation Yrotation Xrotation\n                      JOINT RightHandIndex3\n                      {\n                        OFFSET -3.654871 -8e-06 -0.0\n                        CHANNELS 3 Zrotation Yrotation Xrotation\n                        JOINT RightHandIndex4\n                        {\n                          OFFSET -2.327586 0.0 1e-06\n                          CHANNELS 3 Zrotation Yrotation Xrotation\n                          JOINT RightHandIndexEnd\n                 
         {\n                            OFFSET -2.76179 -0.180656 -0.113078\n                            CHANNELS 3 Zrotation Yrotation Xrotation\n                          }\n                        }\n                      }\n                    }\n                  }\n                  JOINT RightHandMiddle1\n                  {\n                    OFFSET -3.168106 0.246593 1.00103\n                    CHANNELS 3 Zrotation Yrotation Xrotation\n                    JOINT RightHandMiddle2\n                    {\n                      OFFSET -6.180828 -0.258836 -1.000895\n                      CHANNELS 3 Zrotation Yrotation Xrotation\n                      JOINT RightHandMiddle3\n                      {\n                        OFFSET -4.348901 0.0 -0.0\n                        CHANNELS 3 Zrotation Yrotation Xrotation\n                        JOINT RightHandMiddle4\n                        {\n                          OFFSET -3.00024 -4e-06 -2e-06\n                          CHANNELS 3 Zrotation Yrotation Xrotation\n                          JOINT RightHandMiddleEnd\n                          {\n                            OFFSET -2.30252 -0.29437 -0.031706\n                            CHANNELS 3 Zrotation Yrotation Xrotation\n                          }\n                        }\n                      }\n                    }\n                  }\n                  JOINT RightHandRing1\n                  {\n                    OFFSET -2.88569 -0.067952 -0.308858\n                    CHANNELS 3 Zrotation Yrotation Xrotation\n                    JOINT RightHandRing2\n                    {\n                      OFFSET -5.854198 -0.48613 -1.373731\n                      CHANNELS 3 Zrotation Yrotation Xrotation\n                      JOINT RightHandRing3\n                      {\n                        OFFSET -4.33881 -4e-06 -0.0\n                        CHANNELS 3 Zrotation Yrotation Xrotation\n                        JOINT RightHandRing4\n                        
{\n                          OFFSET -2.654903 -4e-06 4e-06\n                          CHANNELS 3 Zrotation Yrotation Xrotation\n                          JOINT RightHandRingEnd\n                          {\n                            OFFSET -1.933568 0.077527 -5.2e-05\n                            CHANNELS 3 Zrotation Yrotation Xrotation\n                          }\n                        }\n                      }\n                    }\n                  }\n                  JOINT RightHandPinky1\n                  {\n                    OFFSET -2.866425 -0.342796 -1.584145\n                    CHANNELS 3 Zrotation Yrotation Xrotation\n                    JOINT RightHandPinky2\n                    {\n                      OFFSET -5.091371 -1.332055 -1.772385\n                      CHANNELS 3 Zrotation Yrotation Xrotation\n                      JOINT RightHandPinky3\n                      {\n                        OFFSET -3.062664 -4e-06 1e-06\n                        CHANNELS 3 Zrotation Yrotation Xrotation\n                        JOINT RightHandPinky4\n                        {\n                          OFFSET -1.546529 4e-06 -2e-06\n                          CHANNELS 3 Zrotation Yrotation Xrotation\n                          JOINT RightHandPinkyEnd\n                          {\n                            OFFSET -1.945119 -0.157718 0.057211\n                            CHANNELS 3 Zrotation Yrotation Xrotation\n                          }\n                        }\n                      }\n                    }\n                  }\n                }\n              }\n            }\n          }\n        }\n      }\n    }\n    JOINT LeftLeg\n    {\n      OFFSET 10.043214 -8.434526 2.595655\n      CHANNELS 3 Zrotation Yrotation Xrotation\n      JOINT LeftShin\n      {\n        OFFSET -1e-06 -43.221752 -0.802913\n        CHANNELS 3 Zrotation Yrotation Xrotation\n        JOINT LeftFoot\n        {\n          OFFSET 1e-06 -42.155094 -3.481523\n          CHANNELS 
3 Zrotation Yrotation Xrotation\n          JOINT LeftToeBase\n          {\n            OFFSET 0.0 -5.059472 13.231529\n            CHANNELS 3 Zrotation Yrotation Xrotation\n            JOINT LeftToeEnd\n            {\n              OFFSET -0.009607 -1.647619 6.513017\n              CHANNELS 3 Zrotation Yrotation Xrotation\n            }\n          }\n        }\n      }\n    }\n    JOINT RightLeg\n    {\n      OFFSET -10.047278 -8.29526 2.620317\n      CHANNELS 3 Zrotation Yrotation Xrotation\n      JOINT RightShin\n      {\n        OFFSET 1e-06 -43.362206 -0.805556\n        CHANNELS 3 Zrotation Yrotation Xrotation\n        JOINT RightFoot\n        {\n          OFFSET 2e-06 -42.117393 -3.478398\n          CHANNELS 3 Zrotation Yrotation Xrotation\n          JOINT RightToeBase\n          {\n            OFFSET -0.0 -5.079609 13.284196\n            CHANNELS 3 Zrotation Yrotation Xrotation\n            JOINT RightToeEnd\n            {\n              OFFSET 0.009532 -1.634378 6.460591\n              CHANNELS 3 Zrotation Yrotation Xrotation\n            }\n          }\n        }\n      }\n    }\n  }\n}\nMOTION\nFrames: 1\nFrame Time: 0.03333333333333333\n0.0 0.0 0.0 0.0 0.0 0.0 0.0 100.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n"
  },
  {
    "path": "kimodo/assets.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\nfrom pathlib import Path\n\nPACKAGE_ROOT = Path(__file__).resolve().parent\nASSETS_ROOT = PACKAGE_ROOT / \"assets\"\nDEMO_ASSETS_ROOT = ASSETS_ROOT / \"demo\"\nDEMO_EXAMPLES_ROOT = DEMO_ASSETS_ROOT / \"examples\"\nSKELETONS_ROOT = ASSETS_ROOT / \"skeletons\"\nSOMA_ASSETS_ROOT = ASSETS_ROOT / \"SOMA\"\n\n\ndef skeleton_asset_path(*parts: str) -> Path:\n    return SKELETONS_ROOT.joinpath(*parts)\n\n\ndef demo_asset_path(*parts: str) -> Path:\n    return DEMO_ASSETS_ROOT.joinpath(*parts)\n"
  },
  {
    "path": "kimodo/constraints.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Constraint sets for conditioning motion generation (root 2D, full body, end-effectors).\"\"\"\n\nfrom typing import Optional, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom kimodo.motion_rep.feature_utils import compute_heading_angle\nfrom kimodo.skeleton import SkeletonBase, SOMASkeleton30, SOMASkeleton77\nfrom kimodo.tools import ensure_batched, load_json, save_json\n\nfrom .geometry import axis_angle_to_matrix, matrix_to_axis_angle\n\n\ndef _convert_constraint_local_rots_to_skeleton(local_rot_mats: Tensor, skeleton: SkeletonBase) -> Tensor:\n    \"\"\"Convert loaded local rotation matrices to match the skeleton's joint count.\n\n    Handles SOMA 30↔77: constraint files may have been saved with 30 or 77 joints while the session\n    skeleton (e.g. from the SOMA30 model) uses SOMASkeleton77.\n    \"\"\"\n    n_joints = local_rot_mats.shape[-3]\n    skeleton_joints = skeleton.nbjoints\n    if n_joints == skeleton_joints:\n        return local_rot_mats\n    if n_joints == 77 and skeleton_joints == 30 and isinstance(skeleton, SOMASkeleton30):\n        return skeleton.from_SOMASkeleton77(local_rot_mats)\n    if n_joints == 30 and skeleton_joints == 77 and isinstance(skeleton, SOMASkeleton77):\n        skel30 = SOMASkeleton30()\n        return skel30.to_SOMASkeleton77(local_rot_mats)\n    raise ValueError(\n        f\"Constraint joint count ({n_joints}) does not match skeleton joint count \"\n        f\"({skeleton_joints}). 
Only SOMA 30↔77 conversion is supported.\"\n    )\n\n\ndef create_pairs(tensor_A: Tensor, tensor_B: Tensor) -> Tensor:\n    \"\"\"Form all (a, b) pairs from two 1D tensors; output shape (len(A)*len(B), 2).\"\"\"\n    pairs = torch.stack(\n        (\n            tensor_A[:, None].expand(-1, len(tensor_B)),\n            tensor_B.expand(len(tensor_A), -1),\n        ),\n        dim=-1,\n    ).reshape(-1, 2)\n    return pairs\n\n\ndef compute_global_heading(global_joints_positions: Tensor, skeleton: SkeletonBase) -> Tensor:\n    \"\"\"Compute global root heading (cos, sin) from global joint positions using skeleton.\"\"\"\n    root_heading_angle = compute_heading_angle(global_joints_positions, skeleton)\n    global_root_heading = torch.stack([torch.cos(root_heading_angle), torch.sin(root_heading_angle)], dim=-1)\n    return global_root_heading\n\n\ndef _tensor_to(\n    t: Tensor,\n    device: Optional[Union[str, torch.device]] = None,\n    dtype: Optional[torch.dtype] = None,\n) -> Tensor:\n    \"\"\"Move tensor to device and/or dtype.\n\n    Returns same tensor if no args.\n    \"\"\"\n    if device is not None and dtype is not None:\n        return t.to(device=device, dtype=dtype)\n    if device is not None:\n        return t.to(device=device)\n    if dtype is not None:\n        return t.to(dtype=dtype)\n    return t\n\n\nclass Root2DConstraintSet:\n    \"\"\"Constraint set fixing root (x, z) trajectory and optionally global heading on given\n    frames.\"\"\"\n\n    name = \"root2d\"\n\n    def __init__(\n        self,\n        skeleton: SkeletonBase,\n        frame_indices: Tensor,\n        smooth_root_2d: Tensor,\n        to_crop: bool = False,\n        global_root_heading: Optional[Tensor] = None,\n    ) -> None:\n        self.skeleton = skeleton\n\n        # if we pass the full smooth root 3D as input\n        if smooth_root_2d.shape[-1] == 3:\n            smooth_root_2d = smooth_root_2d[..., [0, 1]]\n\n        if to_crop:\n            smooth_root_2d = 
smooth_root_2d[frame_indices]\n            if global_root_heading is not None:\n                global_root_heading = global_root_heading[frame_indices]\n        else:\n            assert len(smooth_root_2d) == len(\n                frame_indices\n            ), \"The number of smooth root 2d should be match the number of frames\"\n            if global_root_heading is not None:\n                assert len(global_root_heading) == len(\n                    frame_indices\n                ), \"The number of global root heading should be match the number of frames\"\n\n        self.smooth_root_2d = smooth_root_2d\n        self.global_root_heading = global_root_heading\n        self.frame_indices = frame_indices\n\n    def update_constraints(self, data_dict: dict, index_dict: dict) -> None:\n        \"\"\"Append this constraint's smooth_root_2d (and optional global_root_heading) to data/index\n        dicts.\"\"\"\n        data_dict[\"smooth_root_2d\"].append(self.smooth_root_2d)\n        index_dict[\"smooth_root_2d\"].append(self.frame_indices)\n\n        if self.global_root_heading is not None:\n            # constraint the global heading\n            data_dict[\"global_root_heading\"].append(self.global_root_heading)\n            index_dict[\"global_root_heading\"].append(self.frame_indices)\n\n    def crop_move(self, start: int, end: int) -> \"Root2DConstraintSet\":\n        \"\"\"Return a new constraint set for the cropped frame range [start, end).\"\"\"\n        mask = (self.frame_indices >= start) & (self.frame_indices < end)\n\n        if self.global_root_heading is not None:\n            masked_global_root_heading = self.global_root_heading[mask]\n        else:\n            masked_global_root_heading = None\n\n        return Root2DConstraintSet(\n            self.skeleton,\n            self.frame_indices[mask] - start,\n            self.smooth_root_2d[mask],\n            global_root_heading=masked_global_root_heading,\n        )\n\n    def get_save_info(self) 
-> dict:\n        \"\"\"Return a dict suitable for JSON serialization (frame_indices, smooth_root_2d, optional\n        global_root_heading).\"\"\"\n        out = {\n            \"type\": self.name,\n            \"frame_indices\": self.frame_indices,\n            \"smooth_root_2d\": self.smooth_root_2d,\n        }\n        if self.global_root_heading is not None:\n            out[\"global_root_heading\"] = self.global_root_heading\n        return out\n\n    def to(\n        self,\n        device: Optional[Union[str, torch.device]] = None,\n        dtype: Optional[torch.dtype] = None,\n    ) -> \"Root2DConstraintSet\":\n        self.smooth_root_2d = _tensor_to(self.smooth_root_2d, device, dtype)\n        self.frame_indices = _tensor_to(self.frame_indices, device, dtype)\n        if self.global_root_heading is not None:\n            self.global_root_heading = _tensor_to(self.global_root_heading, device, dtype)\n        if device is not None and hasattr(self.skeleton, \"to\"):\n            self.skeleton = self.skeleton.to(device)\n        return self\n\n    @classmethod\n    def from_dict(cls, skeleton: SkeletonBase, dico: dict) -> \"Root2DConstraintSet\":\n        \"\"\"Build a Root2DConstraintSet from a dict (e.g. 
loaded from JSON).\"\"\"\n        device = skeleton.device if hasattr(skeleton, \"device\") else \"cpu\"\n\n        if \"global_root_heading\" in dico:\n            global_root_heading = torch.tensor(dico[\"global_root_heading\"], device=device)\n        else:\n            global_root_heading = None\n\n        return cls(\n            skeleton,\n            frame_indices=torch.tensor(dico[\"frame_indices\"]),\n            smooth_root_2d=torch.tensor(dico[\"smooth_root_2d\"], device=device),\n            global_root_heading=global_root_heading,\n        )\n\n\nclass FullBodyConstraintSet:\n    \"\"\"Constraint set fixing full-body global positions and rotations on given keyframes.\"\"\"\n\n    name = \"fullbody\"\n\n    def __init__(\n        self,\n        skeleton: SkeletonBase,\n        frame_indices: Tensor,\n        global_joints_positions: Tensor,\n        global_joints_rots: Tensor,\n        smooth_root_2d: Optional[Tensor] = None,\n        to_crop: bool = False,\n    ):\n        self.skeleton = skeleton\n        self.frame_indices = frame_indices\n\n        # if we pass the full smooth root 3D as input\n        if smooth_root_2d is not None and smooth_root_2d.shape[-1] == 3:\n            smooth_root_2d = smooth_root_2d[..., [0, 1]]\n\n        if to_crop:\n            global_joints_positions = global_joints_positions[frame_indices]\n            global_joints_rots = global_joints_rots[frame_indices]\n            if smooth_root_2d is not None:\n                smooth_root_2d = smooth_root_2d[frame_indices]\n        else:\n            assert len(global_joints_positions) == len(\n                frame_indices\n            ), \"The number of global positions should be match the number of frames\"\n            assert len(global_joints_rots) == len(\n                frame_indices\n            ), \"The number of global joint rotations should be match the number of frames\"\n\n            if smooth_root_2d is not None:\n                assert len(smooth_root_2d) == 
len(\n                    frame_indices\n                ), \"The number of smooth root 2d (if specified) should be match the number of frames\"\n\n        if smooth_root_2d is None:\n            # substitute the smooth root 2d with the real root\n            smooth_root_2d = global_joints_positions[:, skeleton.root_idx, [0, 2]]\n\n        # root y: from smooth or pelvis is the same\n        self.root_y_pos = global_joints_positions[:, skeleton.root_idx, 1]\n\n        self.global_joints_positions = global_joints_positions\n        self.global_joints_rots = global_joints_rots\n        self.global_root_heading = compute_global_heading(global_joints_positions, skeleton)\n        self.smooth_root_2d = smooth_root_2d\n\n    def update_constraints(self, data_dict: dict, index_dict: dict) -> None:\n        \"\"\"Append global positions, smooth root 2D, root y, and global heading to data/index\n        dicts.\"\"\"\n        nbjoints = self.skeleton.nbjoints\n        indices_lst = create_pairs(\n            self.frame_indices,\n            torch.arange(nbjoints, device=self.frame_indices.device),\n        )\n        data_dict[\"global_joints_positions\"].append(\n            self.global_joints_positions.reshape(-1, 3)\n        )  # flatten the global positions\n        index_dict[\"global_joints_positions\"].append(indices_lst)\n\n        # global rotations are not used here\n\n        # as we use smooth root, also constraint the smooth root to get the same full body\n        # maybe keep storing the hips offset, if we smooth it ourselves\n        data_dict[\"smooth_root_2d\"].append(self.smooth_root_2d)\n        index_dict[\"smooth_root_2d\"].append(self.frame_indices)\n\n        # constraint the y pos of the root\n        data_dict[\"root_y_pos\"].append(self.root_y_pos)\n        index_dict[\"root_y_pos\"].append(self.frame_indices)\n\n        # constraint the global heading\n        data_dict[\"global_root_heading\"].append(self.global_root_heading)\n        
index_dict[\"global_root_heading\"].append(self.frame_indices)\n\n    def crop_move(self, start: int, end: int) -> \"FullBodyConstraintSet\":\n        \"\"\"Return a new FullBodyConstraintSet for the cropped frame range [start, end).\"\"\"\n        mask = (self.frame_indices >= start) & (self.frame_indices < end)\n        return FullBodyConstraintSet(\n            self.skeleton,\n            self.frame_indices[mask] - start,\n            self.global_joints_positions[mask],\n            self.global_joints_rots[mask],\n            self.smooth_root_2d[mask],\n        )\n\n    def get_save_info(self) -> dict:\n        \"\"\"Return a dict for JSON save: type, frame_indices, local_joints_rot, root_positions, smooth_root_2d.\"\"\"\n        local_joints_rot = self.skeleton.global_rots_to_local_rots(self.global_joints_rots)\n        if isinstance(self.skeleton, SOMASkeleton30):\n            local_joints_rot = self.skeleton.to_SOMASkeleton77(local_joints_rot)\n        local_joints_rot = matrix_to_axis_angle(local_joints_rot)\n\n        root_positions = self.global_joints_positions[:, self.skeleton.root_idx]\n        return {\n            \"type\": self.name,\n            \"frame_indices\": self.frame_indices,\n            \"local_joints_rot\": local_joints_rot,\n            \"root_positions\": root_positions,\n            \"smooth_root_2d\": self.smooth_root_2d,\n        }\n\n    def to(\n        self,\n        device: Optional[Union[str, torch.device]] = None,\n        dtype: Optional[torch.dtype] = None,\n    ) -> \"FullBodyConstraintSet\":\n        self.frame_indices = _tensor_to(self.frame_indices, device, dtype)\n        self.global_joints_positions = _tensor_to(self.global_joints_positions, device, dtype)\n        self.global_joints_rots = _tensor_to(self.global_joints_rots, device, dtype)\n        self.root_y_pos = _tensor_to(self.root_y_pos, device, dtype)\n        self.global_root_heading = _tensor_to(self.global_root_heading, device, dtype)\n        
self.smooth_root_2d = _tensor_to(self.smooth_root_2d, device, dtype)\n        if device is not None and hasattr(self.skeleton, \"to\"):\n            self.skeleton = self.skeleton.to(device)\n        return self\n\n    @classmethod\n    def from_dict(cls, skeleton: SkeletonBase, dico: dict) -> \"FullBodyConstraintSet\":\n        \"\"\"Build a FullBodyConstraintSet from a dict (e.g. loaded from JSON).\"\"\"\n        frame_indices = torch.tensor(dico[\"frame_indices\"])\n        device = skeleton.device if hasattr(skeleton, \"device\") else \"cpu\"\n        local_rot = torch.tensor(dico[\"local_joints_rot\"], device=device)\n        local_rot_mats = axis_angle_to_matrix(local_rot)\n        local_rot_mats = _convert_constraint_local_rots_to_skeleton(local_rot_mats, skeleton)\n        global_joints_rots, global_joints_positions, _ = skeleton.fk(\n            local_rot_mats,\n            torch.tensor(dico[\"root_positions\"], device=device),\n        )\n        smooth_root_2d = None\n        if \"smooth_root_2d\" in dico:\n            smooth_root_2d = torch.tensor(dico[\"smooth_root_2d\"], device=device)\n\n        return cls(\n            skeleton,\n            frame_indices=frame_indices,\n            global_joints_positions=global_joints_positions,\n            global_joints_rots=global_joints_rots,\n            smooth_root_2d=smooth_root_2d,\n        )\n\n\nclass EndEffectorConstraintSet:\n    \"\"\"Constraint set fixing selected end-effector positions and rotations on given frames.\"\"\"\n\n    name = \"end-effector\"\n\n    def __init__(\n        self,\n        skeleton: SkeletonBase,\n        frame_indices: Tensor,\n        global_joints_positions: Tensor,\n        global_joints_rots: Tensor,\n        smooth_root_2d: Optional[Tensor],\n        *,\n        joint_names: list[str],\n        to_crop: bool = False,\n    ) -> None:\n        self.skeleton = skeleton\n        self.frame_indices = frame_indices\n        self.joint_names = joint_names\n\n        # 
joint_names are constant for all the frames\n        rot_joint_names, pos_joint_names = self.skeleton.expand_joint_names(self.joint_names)\n        # indexing works for motion_rep with smooth root only (contains pelvis index)\n        self.pos_indices = torch.tensor([self.skeleton.bone_index[jname] for jname in pos_joint_names])\n        self.rot_indices = torch.tensor([self.skeleton.bone_index[jname] for jname in rot_joint_names])\n\n        # if we pass the full smooth root 3D as input\n        if smooth_root_2d is not None and smooth_root_2d.shape[-1] == 3:\n            smooth_root_2d = smooth_root_2d[..., [0, 1]]\n\n        if to_crop:\n            global_joints_positions = global_joints_positions[frame_indices]\n            global_joints_rots = global_joints_rots[frame_indices]\n            if smooth_root_2d is not None:\n                smooth_root_2d = smooth_root_2d[frame_indices]\n        else:\n            assert len(global_joints_positions) == len(\n                frame_indices\n            ), \"The number of global positions should be match the number of frames\"\n            assert len(global_joints_rots) == len(\n                frame_indices\n            ), \"The number of global joint rotations should be match the number of frames\"\n            if smooth_root_2d is not None:\n                assert len(smooth_root_2d) == len(\n                    frame_indices\n                ), \"The number of smooth root 2d (if specified) should be match the number of frames\"\n\n        if smooth_root_2d is None:\n            # substitute the smooth root 2d with the real root\n            smooth_root_2d = global_joints_positions[:, skeleton.root_idx, [0, 2]]\n\n        # root y: from smooth or pelvis is the same\n        self.root_y_pos = global_joints_positions[:, skeleton.root_idx, 1]\n\n        self.global_joints_positions = global_joints_positions\n        self.global_root_heading = compute_global_heading(global_joints_positions, skeleton)\n        
self.global_joints_rots = global_joints_rots\n        self.smooth_root_2d = smooth_root_2d\n\n    def update_constraints(self, data_dict: dict, index_dict: dict) -> None:\n        \"\"\"Append constrained joint positions/rots, smooth root 2D, root y, and heading to\n        data/index dicts.\"\"\"\n        crop_frames_indexing = torch.arange(len(self.frame_indices), device=self.frame_indices.device)\n\n        # constraint positions\n        pos_indices_real = create_pairs(\n            self.frame_indices,\n            self.pos_indices,\n        )\n        pos_indices_crop = create_pairs(\n            crop_frames_indexing,\n            self.pos_indices,\n        )\n        data_dict[\"global_joints_positions\"].append(self.global_joints_positions[tuple(pos_indices_crop.T)])\n        index_dict[\"global_joints_positions\"].append(pos_indices_real)\n\n        # constraint rotations\n        rot_indices_real = create_pairs(\n            self.frame_indices,\n            self.rot_indices,\n        )\n        rot_indices_crop = create_pairs(\n            crop_frames_indexing,\n            self.rot_indices,\n        )\n        data_dict[\"global_joints_rots\"].append(self.global_joints_rots[tuple(rot_indices_crop.T)])\n        index_dict[\"global_joints_rots\"].append(rot_indices_real)\n\n        # as we use smooth root, also constraint the smooth root to get the same full body\n        # maybe keep storing the hips offset, if we smooth it ourselves\n        data_dict[\"smooth_root_2d\"].append(self.smooth_root_2d)\n        index_dict[\"smooth_root_2d\"].append(self.frame_indices)\n\n        # constraint the y pos of the root\n        data_dict[\"root_y_pos\"].append(self.root_y_pos)\n        index_dict[\"root_y_pos\"].append(self.frame_indices)\n\n        # constraint the global heading\n        data_dict[\"global_root_heading\"].append(self.global_root_heading)\n        index_dict[\"global_root_heading\"].append(self.frame_indices)\n\n    def crop_move(self, start: int, 
end: int) -> \"EndEffectorConstraintSet\":\n        \"\"\"Return a new EndEffectorConstraintSet for the cropped frame range [start, end).\"\"\"\n        mask = (self.frame_indices >= start) & (self.frame_indices < end)\n\n        cls = type(self)\n        kwargs = {}\n        if not hasattr(cls, \"joint_names\"):\n            kwargs[\"joint_names\"] = self.joint_names\n\n        return cls(\n            self.skeleton,\n            self.frame_indices[mask] - start,\n            self.global_joints_positions[mask],\n            self.global_joints_rots[mask],\n            self.smooth_root_2d[mask],\n            **kwargs,\n        )\n\n    def get_save_info(self) -> dict:\n        \"\"\"Return a dict for JSON save: type, frame_indices, local_joints_rot, root_positions, smooth_root_2d, joint_names.\"\"\"\n        local_joints_rot = self.skeleton.global_rots_to_local_rots(self.global_joints_rots)\n        if isinstance(self.skeleton, SOMASkeleton30):\n            local_joints_rot = self.skeleton.to_SOMASkeleton77(local_joints_rot)\n        local_joints_rot = matrix_to_axis_angle(local_joints_rot)\n\n        root_positions = self.global_joints_positions[:, self.skeleton.root_idx]\n        output = {\n            \"type\": self.name,\n            \"frame_indices\": self.frame_indices,\n            \"local_joints_rot\": local_joints_rot,\n            \"root_positions\": root_positions,\n            \"smooth_root_2d\": self.smooth_root_2d,\n        }\n        if not hasattr(self.__class__, \"joint_names\"):\n            # save the joint_names for this base class\n            # but not for children\n            output[\"joint_names\"] = self.joint_names\n        return output\n\n    def to(\n        self,\n        device: Optional[Union[str, torch.device]] = None,\n        dtype: Optional[torch.dtype] = None,\n    ) -> \"EndEffectorConstraintSet\":\n        self.frame_indices = _tensor_to(self.frame_indices, device, dtype)\n        self.pos_indices = 
_tensor_to(self.pos_indices, device, dtype)\n        self.rot_indices = _tensor_to(self.rot_indices, device, dtype)\n        self.root_y_pos = _tensor_to(self.root_y_pos, device, dtype)\n        self.global_joints_positions = _tensor_to(self.global_joints_positions, device, dtype)\n        self.global_root_heading = _tensor_to(self.global_root_heading, device, dtype)\n        self.global_joints_rots = _tensor_to(self.global_joints_rots, device, dtype)\n        self.smooth_root_2d = _tensor_to(self.smooth_root_2d, device, dtype)\n        if device is not None and hasattr(self.skeleton, \"to\"):\n            self.skeleton = self.skeleton.to(device)\n        return self\n\n    @classmethod\n    def from_dict(cls, skeleton: SkeletonBase, dico: dict) -> \"EndEffectorConstraintSet\":\n        \"\"\"Build an EndEffectorConstraintSet from a dict (e.g. loaded from JSON).\"\"\"\n        frame_indices = torch.tensor(dico[\"frame_indices\"])\n        device = skeleton.device if hasattr(skeleton, \"device\") else \"cpu\"\n        local_rot = torch.tensor(dico[\"local_joints_rot\"], device=device)\n        local_rot_mats = axis_angle_to_matrix(local_rot)\n        local_rot_mats = _convert_constraint_local_rots_to_skeleton(local_rot_mats, skeleton)\n        global_joints_rots, global_joints_positions, _ = skeleton.fk(\n            local_rot_mats,\n            torch.tensor(dico[\"root_positions\"], device=device),\n        )\n        smooth_root_2d = None\n        if \"smooth_root_2d\" in dico:\n            smooth_root_2d = torch.tensor(dico[\"smooth_root_2d\"], device=device)\n\n        kwargs = {}\n        if not hasattr(cls, \"joint_names\"):\n            kwargs[\"joint_names\"] = dico[\"joint_names\"]\n\n        return cls(\n            skeleton,\n            frame_indices=frame_indices,\n            global_joints_positions=global_joints_positions,\n            global_joints_rots=global_joints_rots,\n            smooth_root_2d=smooth_root_2d,\n            **kwargs,\n        
)\n\n\nclass LeftHandConstraintSet(EndEffectorConstraintSet):\n    \"\"\"End-effector constraint for the left hand only.\"\"\"\n\n    name = \"left-hand\"\n    joint_names: list[str] = [\"LeftHand\"]\n\n    def __init__(self, *args, **kwargs: dict):\n        super().__init__(*args, joint_names=self.joint_names, **kwargs)\n\n\nclass RightHandConstraintSet(EndEffectorConstraintSet):\n    \"\"\"End-effector constraint for the right hand only.\"\"\"\n\n    name = \"right-hand\"\n    joint_names: list[str] = [\"RightHand\"]\n\n    def __init__(self, *args, **kwargs: dict):\n        super().__init__(*args, joint_names=self.joint_names, **kwargs)\n\n\nclass LeftFootConstraintSet(EndEffectorConstraintSet):\n    \"\"\"End-effector constraint for the left foot only.\"\"\"\n\n    name = \"left-foot\"\n    joint_names: list[str] = [\"LeftFoot\"]\n\n    def __init__(self, *args, **kwargs: dict):\n        super().__init__(*args, joint_names=self.joint_names, **kwargs)\n\n\nclass RightFootConstraintSet(EndEffectorConstraintSet):\n    \"\"\"End-effector constraint for the right foot only.\"\"\"\n\n    name = \"right-foot\"\n    joint_names: list[str] = [\"RightFoot\"]\n\n    def __init__(self, *args, **kwargs: dict):\n        super().__init__(*args, joint_names=self.joint_names, **kwargs)\n\n\nTYPE_TO_CLASS = {\n    \"root2d\": Root2DConstraintSet,\n    \"fullbody\": FullBodyConstraintSet,\n    \"left-hand\": LeftHandConstraintSet,\n    \"right-hand\": RightHandConstraintSet,\n    \"left-foot\": LeftFootConstraintSet,\n    \"right-foot\": RightFootConstraintSet,\n    \"end-effector\": EndEffectorConstraintSet,\n}\n\n\ndef load_constraints_lst(\n    path_or_data: str | list,\n    skeleton: SkeletonBase,\n    device: Optional[Union[str, torch.device]] = None,\n    dtype: Optional[torch.dtype] = None,\n):\n    \"\"\"Load a list of constraints from JSON path or list of dicts.\n\n    Args:\n        path_or_data: Path to constraints.json or list of constraint dicts.\n        skeleton: 
Skeleton instance (used for from_dict).\n        device: If set, move all constraint tensors and skeleton to this device.\n        dtype: If set, cast constraint tensors to this dtype.\n    \"\"\"\n    if isinstance(path_or_data, str):\n        saved = load_json(path_or_data)\n    else:\n        saved = path_or_data\n\n    constraints_lst = []\n    for el in saved:\n        cls = TYPE_TO_CLASS[el[\"type\"]]\n        c = cls.from_dict(skeleton, el)\n        if device is not None or dtype is not None:\n            c.to(device=device, dtype=dtype)\n        constraints_lst.append(c)\n    return constraints_lst\n\n\ndef save_constraints_lst(path: str, constraints_lst: list) -> list | None:\n    \"\"\"Save a list of constraint sets to a JSON file.\n\n    Returns None if list is empty.\n    \"\"\"\n    if not constraints_lst:\n        print(\"The constraints lst is empty. Skip saving\")\n        return\n\n    to_save = []\n\n    def tensor_to_list(obj):\n        \"\"\"Recursively convert tensors to lists for JSON serialization.\"\"\"\n        if isinstance(obj, Tensor):\n            return obj.cpu().tolist()\n        elif isinstance(obj, dict):\n            return {k: tensor_to_list(v) for k, v in obj.items()}\n        elif isinstance(obj, list):\n            return [tensor_to_list(v) for v in obj]\n        else:\n            return obj\n\n    for constraint in constraints_lst:\n        constraint_info = constraint.get_save_info()\n        # Convert all tensors to lists for JSON serialization\n        constraint_info = tensor_to_list(constraint_info)\n        to_save.append(constraint_info)\n\n    save_json(path, to_save)\n    print(f\"Saved constraints to {path}\")\n    return to_save\n"
  },
  {
    "path": "kimodo/demo/__init__.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\n# ruff: noqa: I001\nimport argparse\n\nfrom kimodo.model import DEFAULT_MODEL\nfrom kimodo.model.registry import resolve_model_name\n\nfrom .app import Demo\n\n\ndef main() -> None:\n    parser = argparse.ArgumentParser(description=\"Run the kimodo demo UI.\")\n    parser.add_argument(\n        \"--model\",\n        type=str,\n        default=DEFAULT_MODEL,\n        help=\"Default model to load (e.g. Kimodo-SOMA-RP-v1, kimodo-soma-rp, or SOMA).\",\n    )\n    args = parser.parse_args()\n\n    resolved = resolve_model_name(args.model, \"Kimodo\")\n    demo = Demo(default_model_name=resolved)\n    demo.run()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "kimodo/demo/__main__.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Entry point for `python -m kimodo.demo`.\"\"\"\n\nfrom kimodo.demo import main\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "kimodo/demo/app.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\nimport base64\nimport os\nimport shutil\nimport threading\nimport time\nfrom typing import Optional\n\nimport numpy as np\nimport torch\n\nimport viser\nfrom kimodo.assets import DEMO_ASSETS_ROOT\nfrom kimodo.model.load_model import load_model\nfrom kimodo.model.registry import resolve_model_name\nfrom kimodo.skeleton import SkeletonBase, SOMASkeleton30\nfrom kimodo.tools import load_json\nfrom kimodo.viz import viser_utils\nfrom kimodo.viz.viser_utils import (\n    Character,\n    CharacterMotion,\n    EEJointsKeyframeSet,\n    FullbodyKeyframeSet,\n    RootKeyframe2DSet,\n)\nfrom viser.theme import TitlebarButton, TitlebarConfig, TitlebarImage\n\nfrom . import generation, ui\nfrom .config import (\n    DARK_THEME,\n    DEFAULT_CUR_DURATION,\n    DEFAULT_MODEL,\n    DEFAULT_PLAYBACK_SPEED,\n    DEFAULT_PROMPT,\n    DEMO_UI_QUICK_START_MODAL_MD,\n    EXAMPLES_ROOT_DIR,\n    HF_MODE,\n    LIGHT_THEME,\n    MAX_ACTIVE_USERS,\n    MAX_DURATION,\n    MAX_SESSION_MINUTES,\n    MIN_DURATION,\n    MODEL_EXAMPLES_DIRS,\n    MODEL_NAMES,\n    SERVER_NAME,\n    SERVER_PORT,\n)\nfrom .embedding_cache import CachedTextEncoder\nfrom .queue_manager import QueueManager, UserQueue\nfrom .state import ClientSession, ModelBundle\n\n\nclass Demo:\n    def __init__(self, default_model_name: str = DEFAULT_MODEL):\n        self.device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n        print(f\"Using device: {self.device}\")\n        self.models: dict[str, ModelBundle] = {}\n        self._text_encoder = None\n        resolved = resolve_model_name(default_model_name, \"Kimodo\")\n        if resolved not in MODEL_NAMES:\n            raise ValueError(f\"Unknown model '{default_model_name}'. 
Expected one of: {MODEL_NAMES}\")\n        self.default_model_name = resolved\n        self.ensure_examples_layout()\n        self.load_model(self.default_model_name)\n\n        # Serialize GPU-bound generation across all clients\n        self._generation_lock = threading.Lock()\n        self._cuda_healthy = True\n\n        # Per-client sessions\n        self.client_sessions: dict[int, ClientSession] = {}\n        self.start_direction_markers: dict[int, viser_utils.WaypointMesh] = {}\n        self.grid_handles: dict[int, viser.GridHandle] = {}\n\n        self.server = viser.ViserServer(\n            host=SERVER_NAME,\n            port=SERVER_PORT,\n            label=\"Kimodo\",\n            enable_camera_keyboard_controls=False,  # don't move the camera with the arrow keys\n        )\n        self.server.scene.world_axes.visible = False  # used for debugging\n        self.server.scene.set_up_direction(\"+y\")\n\n        # Register callbacks for session handling\n        self.server.on_client_connect(self.on_client_connect)\n        self.server.on_client_disconnect(self.on_client_disconnect)\n\n        # HF mode: queue and session limit\n        if HF_MODE:\n            self.user_queue = UserQueue(MAX_ACTIVE_USERS, MAX_SESSION_MINUTES)\n            self.queue_manager = QueueManager(\n                queue=self.user_queue,\n                server=self.server,\n                setup_demo_for_client=self._setup_demo_for_client,\n                cleanup_session=self._cleanup_session_for_client,\n            )\n        else:\n            self.user_queue = None\n            self.queue_manager = None\n\n        # create grid and floor\n        self.floor_len = 20.0  # meters\n\n    def ensure_examples_layout(self) -> None:\n        os.makedirs(EXAMPLES_ROOT_DIR, exist_ok=True)\n        for model_dir in MODEL_EXAMPLES_DIRS.values():\n            os.makedirs(model_dir, exist_ok=True)\n\n        for entry in os.listdir(EXAMPLES_ROOT_DIR):\n            if entry in 
MODEL_EXAMPLES_DIRS:\n                continue\n            src = os.path.join(EXAMPLES_ROOT_DIR, entry)\n            if not os.path.isdir(src):\n                continue\n            dst = os.path.join(\n                MODEL_EXAMPLES_DIRS.get(DEFAULT_MODEL, next(iter(MODEL_EXAMPLES_DIRS.values()))),\n                entry,\n            )\n            if not os.path.exists(dst):\n                shutil.move(src, dst)\n\n    def get_examples_base_dir(self, model_name: str, absolute: bool = True) -> str:\n        return MODEL_EXAMPLES_DIRS[model_name]\n\n    def load_model(self, model_name: str) -> ModelBundle:\n        if model_name in self.models:\n            return self.models[model_name]\n\n        print(f\"Loading model {model_name}...\")\n        try:\n            model = load_model(\n                modelname=model_name,\n                device=self.device,\n                text_encoder=self._text_encoder,\n            )\n        except Exception as e:\n            print(f\"Error loading model: {e}\\nMake sure text encoder server is running!\")\n            raise e\n\n        if hasattr(model, \"text_encoder\"):\n            if self._text_encoder is None:\n                self._text_encoder = model.text_encoder\n            model.text_encoder = CachedTextEncoder(model.text_encoder, model_name=model_name)\n\n        skeleton = model.motion_rep.skeleton\n        if isinstance(skeleton, SOMASkeleton30):\n            skeleton = skeleton.somaskel77.to(model.device)\n        bundle = ModelBundle(\n            model=model,\n            motion_rep=model.motion_rep,\n            skeleton=skeleton,\n            model_fps=model.motion_rep.fps,\n        )\n        self.models[model_name] = bundle\n        print(f\"Model {model_name} loaded successfully\")\n        self.prewarm_embedding_cache(model_name, bundle.model)\n        return bundle\n\n    def prewarm_embedding_cache(self, model_name: str, model: object) -> None:\n        encoder = getattr(model, 
\"text_encoder\", None)\n        if not isinstance(encoder, CachedTextEncoder):\n            return\n\n        prompt_set = set()\n        prompt_set.add(DEFAULT_PROMPT)\n\n        examples_dir = MODEL_EXAMPLES_DIRS.get(model_name)\n        if examples_dir and os.path.isdir(examples_dir):\n            for entry in os.listdir(examples_dir):\n                example_dir = os.path.join(examples_dir, entry)\n                if not os.path.isdir(example_dir):\n                    continue\n                meta_path = os.path.join(example_dir, \"meta.json\")\n                if not os.path.exists(meta_path):\n                    continue\n                try:\n                    meta = load_json(meta_path)\n                except Exception:\n                    continue\n                for prompt in meta.get(\"prompts_text\", []):\n                    if isinstance(prompt, str):\n                        prompt_set.add(prompt)\n\n        if prompt_set:\n            encoder.prewarm(list(prompt_set))\n\n    def build_constraint_tracks(\n        self, client: viser.ClientHandle, skeleton: SkeletonBase\n    ) -> dict[str, viser_utils.ConstraintSet]:\n        return {\n            \"Full-Body\": FullbodyKeyframeSet(\n                name=\"Full-Body\",\n                server=client,\n                skeleton=skeleton,\n            ),\n            \"End-Effectors\": EEJointsKeyframeSet(\n                name=\"End-Effectors\",\n                server=client,\n                skeleton=skeleton,\n            ),\n            \"2D Root\": RootKeyframe2DSet(\n                name=\"2D Root\",\n                server=client,\n                skeleton=skeleton,\n            ),\n        }\n\n    def set_timeline_defaults(self, timeline, model_fps: float) -> None:\n        timeline.set_defaults(\n            default_text=DEFAULT_PROMPT,\n            default_duration=int(DEFAULT_CUR_DURATION * model_fps - 1),\n            min_duration=int(MIN_DURATION * model_fps - 1),  # 2 seconds 
minimum,\n            max_duration=int(\n                MAX_DURATION * model_fps - 1  # - NB_TRANSITION_FRAMES\n            ),  # 10 seconds maximum, minus the transition frames, if needed\n            default_num_frames_zoom=int(1.10 * 10 * model_fps),  # a bit more than the max\n            max_frames_zoom=1000,\n            fps=model_fps,\n        )\n\n    def _apply_constraint_overlay_visibility(self, session: ClientSession) -> None:\n        \"\"\"Apply show-all vs show-only-current-frame to constraint overlays.\"\"\"\n        only_frame = session.frame_idx if session.show_only_current_constraint else None\n        for constraint in session.constraints.values():\n            constraint.set_overlay_visibility(only_frame)\n\n    def set_constraint_tracks_visible(self, session: ClientSession, visible: bool) -> None:\n        timeline = session.client.timeline\n        timeline_data = session.timeline_data\n        if timeline_data.get(\"constraint_tracks_visible\", True) == visible:\n            return\n\n        with timeline_data[\"keyframe_update_lock\"]:\n            if visible:\n                for track_id, track_info in timeline_data[\"tracks\"].items():\n                    timeline.add_track(\n                        track_info[\"name\"],\n                        track_type=track_info.get(\"track_type\", \"keyframe\"),\n                        color=track_info.get(\"color\"),\n                        height_scale=track_info.get(\"height_scale\", 1.0),\n                        uuid=track_id,\n                    )\n\n                for keyframe_id, keyframe_data in timeline_data[\"keyframes\"].items():\n                    timeline.add_keyframe(\n                        track_id=keyframe_data[\"track_id\"],\n                        frame=keyframe_data[\"frame\"],\n                        value=keyframe_data.get(\"value\"),\n                        opacity=keyframe_data.get(\"opacity\", 1.0),\n                        locked=keyframe_data.get(\"locked\", 
False),\n                        uuid=keyframe_id,\n                    )\n\n                for interval_id, interval_data in timeline_data[\"intervals\"].items():\n                    timeline.add_interval(\n                        track_id=interval_data[\"track_id\"],\n                        start_frame=interval_data[\"start_frame_idx\"],\n                        end_frame=interval_data[\"end_frame_idx\"],\n                        value=interval_data.get(\"value\"),\n                        opacity=interval_data.get(\"opacity\", 1.0),\n                        locked=interval_data.get(\"locked\", False),\n                        uuid=interval_id,\n                    )\n            else:\n                for track_id in list(timeline_data[\"tracks\"].keys()):\n                    timeline.remove_track(track_id)\n\n        timeline_data[\"constraint_tracks_visible\"] = visible\n\n    def _cleanup_session_for_client(self, client_id: int) -> None:\n        \"\"\"Remove session and scene state for a client (e.g. 
on session expiry).\"\"\"\n        if client_id in self.client_sessions:\n            del self.client_sessions[client_id]\n        self.start_direction_markers.pop(client_id, None)\n        self.grid_handles.pop(client_id, None)\n\n    def _setup_demo_for_client(self, client: viser.ClientHandle) -> None:\n        \"\"\"Initialize scene, GUI, and session state for a client (no modals).\"\"\"\n        self.setup_scene(client)\n\n        model_bundle = self.load_model(self.default_model_name)\n\n        # Initialize each empty constraint track\n        constraint_tracks = self.build_constraint_tracks(client, model_bundle.skeleton)\n\n        # Create GUI elements for this client\n        (\n            gui_elements,\n            timeline_tracks,\n            example_dict,\n            gui_examples_dropdown,\n            gui_save_example_path_text,\n            gui_model_selector,\n        ) = ui.create_gui(\n            demo=self,\n            client=client,\n            model_name=self.default_model_name,\n            model_fps=model_bundle.model_fps,\n        )\n        timeline_data = {\n            \"tracks\": timeline_tracks,\n            \"tracks_ids\": {val[\"name\"]: key for key, val in timeline_tracks.items()},\n            \"keyframes\": {},\n            \"intervals\": {},\n            \"keyframe_update_lock\": threading.Lock(),\n            \"keyframe_move_timers\": {},\n            \"pending_keyframe_moves\": {},  # keyframe_id -> new_frame\n            \"constraint_tracks_visible\": True,\n            \"dense_path_after_release_timer\": None,\n        }\n\n        # Initialize session state\n        cur_duration = DEFAULT_CUR_DURATION\n        max_frame_idx = int(cur_duration * model_bundle.model_fps - 1)\n\n        session = ClientSession(\n            client=client,\n            gui_elements=gui_elements,\n            motions={},\n            constraints=constraint_tracks,\n            timeline_data=timeline_data,\n            frame_idx=0,\n            
playing=False,\n            playback_speed=DEFAULT_PLAYBACK_SPEED,\n            cur_duration=cur_duration,\n            max_frame_idx=max_frame_idx,\n            updating_motions=False,\n            edit_mode=False,\n            model_name=self.default_model_name,\n            model_fps=model_bundle.model_fps,\n            skeleton=model_bundle.skeleton,\n            motion_rep=model_bundle.motion_rep,\n            examples_base_dir=self.get_examples_base_dir(self.default_model_name, absolute=True),\n            example_dict=example_dict,\n            gui_examples_dropdown=gui_examples_dropdown,\n            gui_save_example_path_text=gui_save_example_path_text,\n            gui_model_selector=gui_model_selector,\n        )\n\n        self.client_sessions[client.client_id] = session\n\n        # Initialize default character for this client\n        self.add_character_motion(client, session.skeleton)\n\n    def on_client_connect(self, client: viser.ClientHandle) -> None:\n        \"\"\"Initialize GUI and state for each new client.\"\"\"\n        print(f\"Client {client.client_id} connected\")\n\n        if HF_MODE and self.queue_manager is not None:\n            self.queue_manager.on_client_connect(client)\n        else:\n            # Show quick start popup when a browser client connects (non-HF mode).\n            with client.gui.add_modal(\n                \"Welcome — Quick Start\",\n                size=\"xl\",\n                show_close_button=True,\n                save_choice=\"kimodo.demo.quick_start_ack\",\n            ) as modal:\n                client.gui.add_markdown(DEMO_UI_QUICK_START_MODAL_MD)\n                client.gui.add_button(\"Got it (don't remind me again)\").on_click(lambda _event: modal.close())\n            self._setup_demo_for_client(client)\n\n    def setup_scene(self, client: viser.ClientHandle) -> None:\n        self.configure_theme(client)\n        client.camera.position = np.array(\n            [2.7417358737841426, 
1.8790455698853281, 7.675741569777456],\n            dtype=np.float64,\n        )\n        client.camera.look_at = np.array([0.0, 0.0, 0.0], dtype=np.float64)\n        client.camera.up_direction = np.array(\n            [-1.1102230246251568e-16, 1.0, 1.3596310734468913e-32],\n            dtype=np.float64,\n        )\n        client.camera.fov = np.deg2rad(45.0)\n        grid_handle = client.scene.add_grid(\n            \"/grid\",\n            width=self.floor_len,\n            height=self.floor_len,\n            wxyz=viser.transforms.SO3.from_x_radians(-np.pi / 2.0).wxyz,\n            position=(0.0, 0.0001, 0.0),\n            fade_distance=3 * self.floor_len,\n            section_color=LIGHT_THEME[\"grid\"],\n            infinite_grid=True,\n        )\n        self.grid_handles[client.client_id] = grid_handle\n        # marker for origin\n        origin_waypoint = viser_utils.WaypointMesh(\n            \"/origin_waypoint\",\n            client,\n            position=np.array([0.0, 0.0, 0.0]),\n            heading=np.array([0.0, 1.0]),\n            color=(0, 0, 255),\n        )\n        self.start_direction_markers[client.client_id] = origin_waypoint\n\n    def on_client_disconnect(self, client: viser.ClientHandle) -> None:\n        \"\"\"Clean up when client disconnects.\"\"\"\n        print(f\"Client {client.client_id} disconnected\")\n        client_id = client.client_id\n\n        if HF_MODE and self.queue_manager is not None:\n            self.queue_manager.on_client_disconnect(client_id)\n\n        self._cleanup_session_for_client(client_id)\n\n    def set_start_direction_visible(self, client_id: int, visible: bool) -> None:\n        marker = self.start_direction_markers.get(client_id)\n        if marker is None:\n            return\n        marker.set_visible(visible)\n\n    def client_active(self, client_id: int) -> bool:\n        return client_id in self.client_sessions\n\n    def add_character_motion(\n        self,\n        client: viser.ClientHandle,\n   
     skeleton: SkeletonBase,\n        joints_pos: Optional[torch.Tensor] = None,\n        joints_rot: Optional[torch.Tensor] = None,\n        foot_contacts: Optional[torch.Tensor] = None,\n    ) -> None:\n        client_id = client.client_id\n        if not self.client_active(client_id):\n            return\n        session = self.client_sessions[client_id]\n\n        ci = len(session.motions)\n        character_name = f\"character{ci}\"\n        # build character skeleton and skinning mesh\n        if \"g1\" in session.model_name:\n            mesh_mode = \"g1_stl\"\n        elif \"smplx\" in session.model_name:\n            mesh_mode = \"smplx_skin\"\n        elif \"soma\" in session.model_name:\n            if session.gui_elements.gui_use_soma_layer_checkbox.value:\n                mesh_mode = \"soma_layer_skin\"\n            else:\n                mesh_mode = \"soma_skin\"\n        else:\n            raise ValueError(\"The model name is not recognized for skinning.\")\n\n        new_character = Character(\n            character_name,\n            client,\n            skeleton,\n            create_skeleton_mesh=True,\n            create_skinned_mesh=True,\n            visible_skeleton=False,  # don't show immediately\n            visible_skinned_mesh=False,  # don't show immediately\n            skinned_mesh_opacity=session.gui_elements.gui_viz_skinned_mesh_opacity_slider.value,\n            show_foot_contacts=session.gui_elements.gui_viz_foot_contacts_checkbox.value,\n            dark_mode=session.gui_elements.gui_dark_mode_checkbox.value,\n            mesh_mode=mesh_mode,\n            gui_use_soma_layer_checkbox=session.gui_elements.gui_use_soma_layer_checkbox,\n        )\n\n        # if no motion given, initialize to character default (rest) pose for one frame\n        init_joints_pos, init_joints_rot = new_character.get_pose()\n        if joints_pos is None:\n            joints_pos = init_joints_pos[None].repeat(session.max_frame_idx + 1, 1, 1)\n        if 
joints_rot is None:\n            joints_rot = init_joints_rot[None].repeat(session.max_frame_idx + 1, 1, 1, 1)\n\n        new_motion = CharacterMotion(new_character, joints_pos, joints_rot, foot_contacts)\n        # save the motion in our dict\n        session.motions[character_name] = new_motion\n\n        # put the character at the right frame\n        new_motion.set_frame(session.frame_idx)\n\n        # put them visible with a small delay\n        # so that the set_frame function has time to finish\n        def _set_visibility():\n            new_motion.character.set_skinned_mesh_visibility(session.gui_elements.gui_viz_skinned_mesh_checkbox.value)\n            new_motion.character.set_skeleton_visibility(session.gui_elements.gui_viz_skeleton_checkbox.value)\n\n        timer = threading.Timer(\n            0.2,  # 0.2s delay\n            _set_visibility,\n        )\n        timer.start()\n\n    def clear_motions(self, client_id: int) -> None:\n        if not self.client_active(client_id):\n            return\n        session = self.client_sessions[client_id]\n        for motion in list(session.motions.values()):\n            motion.clear()\n        session.motions.clear()\n\n    def compute_model_constraints_lst(\n        self,\n        session: ClientSession,\n        model_bundle: ModelBundle,\n        num_frames: int,\n    ):\n        return generation.compute_model_constraints_lst(session, model_bundle, num_frames, self.device)\n\n    def check_cuda_health(self) -> bool:\n        \"\"\"Check if CUDA is still functional.\n\n        Trigger auto-restart if corrupted.\n        \"\"\"\n        if self.device == \"cpu\":\n            return True\n        try:\n            torch.tensor([1.0], device=self.device) + torch.tensor([1.0], device=self.device)\n            return True\n        except RuntimeError as e:\n            if \"device-side assert\" in str(e) or \"CUDA error\" in str(e):\n                if self._cuda_healthy:\n                    
self._cuda_healthy = False\n                    print(\"FATAL: CUDA context is corrupted (device-side assert). \" \"The process must be restarted.\")\n                    self._trigger_restart()\n                return False\n            raise\n\n    def _trigger_restart(self) -> None:\n        \"\"\"Exit the process so the HF Space (or systemd/Docker) can restart it.\"\"\"\n        import sys\n\n        print(\"Initiating automatic restart due to unrecoverable CUDA error...\")\n        sys.stdout.flush()\n        sys.stderr.flush()\n        os._exit(1)\n\n    def generate(\n        self,\n        client: viser.ClientHandle,\n        prompts: list[str],\n        num_frames: list[int],\n        num_samples: int,\n        seed: int,\n        diffusion_steps: int,\n        cfg_weight: Optional[list[float]] = None,\n        cfg_type: Optional[str] = None,\n        postprocess_parameters: Optional[dict] = None,\n        transitions_parameters: Optional[dict] = None,\n        real_robot_rotations: bool = False,\n    ) -> None:\n        if not self._cuda_healthy:\n            raise RuntimeError(\"CUDA is in a corrupted state. The space is restarting...\")\n\n        locked = self._generation_lock.acquire(blocking=False)\n        if not locked:\n            waiting_notif = client.add_notification(\n                title=\"Waiting for GPU...\",\n                body=\"Another generation is in progress. 
Yours will start automatically.\",\n                loading=True,\n                with_close_button=False,\n            )\n            self._generation_lock.acquire()\n            waiting_notif.remove()\n\n        try:\n            session = self.client_sessions[client.client_id]\n            model_bundle = self.load_model(session.model_name)\n            generation.generate(\n                client=client,\n                session=session,\n                model_bundle=model_bundle,\n                prompts=prompts,\n                num_frames=num_frames,\n                num_samples=num_samples,\n                seed=seed,\n                diffusion_steps=diffusion_steps,\n                cfg_weight=cfg_weight,\n                cfg_type=cfg_type,\n                postprocess_parameters=postprocess_parameters,\n                transitions_parameters=transitions_parameters,\n                real_robot_rotations=real_robot_rotations,\n                device=self.device,\n                clear_motions=self.clear_motions,\n                add_character_motion=self.add_character_motion,\n            )\n        finally:\n            self._generation_lock.release()\n\n    def set_frame(self, client_id: int, frame_idx: int, update_timeline: bool = True):\n        if not self.client_active(client_id):\n            return\n\n        session = self.client_sessions[client_id]\n\n        session.frame_idx = frame_idx\n        if update_timeline:\n            session.client.timeline.set_current_frame(frame_idx)\n        for motion in list(session.motions.values()):\n            motion.set_frame(frame_idx)\n        self._apply_constraint_overlay_visibility(session)\n\n    def run(self) -> None:\n        update_counter = 0\n        cuda_check_interval = 300\n        while True:\n            last_update_time = time.time()\n            if self.models:\n                # the max playback speed is 2x the model fps (from gui_playback_speed_buttons)\n                playback_fps = 
max(bundle.model_fps for bundle in self.models.values()) * 2.0\n            else:\n                playback_fps = 60.0\n\n            # update each client session independently\n            #   copy to a list first to avoid changing size if client disconnects\n            for client_id, session in list(self.client_sessions.items()):\n                update_interval = int(playback_fps / (session.playback_speed * session.model_fps))\n                new_frame_idx = session.frame_idx\n                if session.playing and update_counter % update_interval == 0:\n                    if session.frame_idx >= session.max_frame_idx:\n                        new_frame_idx = 0\n                    else:\n                        new_frame_idx = session.frame_idx + 1\n\n                    # make sure the client is still active before updating the frame\n                    if self.client_active(client_id):\n                        self.set_frame(client_id, new_frame_idx)\n\n            if update_counter % cuda_check_interval == 0:\n                self.check_cuda_health()\n\n            time_remaining = max(0, 1.0 / playback_fps - (time.time() - last_update_time))\n            time.sleep(time_remaining)\n            update_counter += 1\n            update_counter %= playback_fps  # wrap around to 0 every second\n\n    def configure_theme(\n        self,\n        client: viser.ClientHandle,\n        dark_mode: bool = False,\n        titlebar_dark_mode_checkbox_uuid: str | None = None,\n    ):\n        # Sync grid color with theme (light vs dark)\n        theme = DARK_THEME if dark_mode else LIGHT_THEME\n        grid_handle = self.grid_handles.get(client.client_id)\n        if grid_handle is not None:\n            grid_handle.section_color = theme[\"grid\"]\n\n        #\n        # setup theme\n        #\n        buttons = (\n            TitlebarButton(\n                text=\"Documentation\",\n                icon=\"Description\",\n                
href=\"https://research.nvidia.com/labs/sil/projects/kimodo/docs/interactive_demo/index.html\",\n            ),\n            TitlebarButton(\n                text=\"Project Page\",\n                icon=None,\n                href=\"https://research.nvidia.com/labs/sil/projects/kimodo/\",\n            ),\n            TitlebarButton(\n                text=\"Github\",\n                icon=\"GitHub\",\n                href=\"https://github.com/nv-tlabs/kimodo\",\n            ),\n        )\n        assets_dir = DEMO_ASSETS_ROOT\n        logo_light_path = assets_dir / \"nvidia_logo.png\"\n        logo_dark_path = assets_dir / \"nvidia_logo_dark.png\"\n        if logo_light_path.exists():\n            light_b64 = base64.standard_b64encode(logo_light_path.read_bytes()).decode(\"ascii\")\n            dark_b64 = (\n                base64.standard_b64encode(logo_dark_path.read_bytes()).decode(\"ascii\")\n                if logo_dark_path.exists()\n                else None\n            )\n            image = TitlebarImage(\n                image_url_light=f\"data:image/png;base64,{light_b64}\",\n                image_url_dark=(f\"data:image/png;base64,{dark_b64}\" if dark_b64 else None),\n                image_alt=\"NVIDIA\",\n                href=\"https://www.nvidia.com/\",\n            )\n        else:\n            image = None\n        titlebar_theme = TitlebarConfig(buttons=buttons, image=image, title_text=\"Kimodo\")\n        client.gui.set_panel_label(\"Kimodo\")\n        client.gui.configure_theme(\n            titlebar_content=titlebar_theme,\n            control_layout=\"floating\",  # \"floating\",  # ['floating', 'collapsible', 'fixed']\n            control_width=\"large\",  # ['small', 'medium', 'large']\n            dark_mode=dark_mode,\n            show_logo=False,  # hide viser logo on bottom left corner\n            show_share_button=False,\n            titlebar_dark_mode_checkbox_uuid=titlebar_dark_mode_checkbox_uuid,\n            brand_color=(152, 189, 
255),  # (60, 131, 0),  # (R, G, B) tuple\n        )\n"
  },
  {
    "path": "kimodo/demo/config.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\nimport os\n\nfrom kimodo.assets import DEMO_EXAMPLES_ROOT\nfrom kimodo.model.registry import (\n    AVAILABLE_MODELS,\n    DEFAULT_MODEL,\n    FRIENDLY_NAMES,\n    get_datasets,\n    get_model_info,\n    get_models_for_dataset_skeleton,\n    get_short_key_from_display_name,\n    get_skeleton_display_name,\n    get_skeleton_display_names_for_dataset,\n    get_skeleton_key_from_display_name,\n    get_skeletons_for_dataset,\n    get_versions_for_dataset_skeleton,\n    resolve_to_short_key,\n)\n\nSERVER_NAME = os.environ.get(\"SERVER_NAME\", \"0.0.0.0\")\nSERVER_PORT = int(os.environ.get(\"SERVER_PORT\", \"7860\"))\nHF_MODE = os.environ.get(\"HF_MODE\", False)\n\n# HF mode: user queue and session limit (override via env in Spaces)\nMAX_ACTIVE_USERS = int(os.environ.get(\"MAX_ACTIVE_USERS\", \"5\"))\nMAX_SESSION_MINUTES = float(os.environ.get(\"MAX_SESSION_MINUTES\", \"5.0\"))\n\nDEFAULT_PLAYBACK_SPEED = 1.0\n# default start duration is 6.0 sec, but model can handle up to 10 sec\nDEFAULT_CUR_DURATION = 6.0\nDEFAULT_PROMPT = \"A person walks forward.\"\nMIN_DURATION = 2.0\nMAX_DURATION = 10.0\n\nSHOW_TRANSITION_PARAMS = True\nINIT_POSTPROCESSING = True\nNB_TRANSITION_FRAMES = 5\n\nLIGHT_THEME = dict(\n    floor=(220, 220, 220),\n    grid=(180, 180, 180),\n)\n\n# Dark theme: slightly lighter grid and floor for better visibility and less flat black\nDARK_THEME = dict(\n    floor=(48, 48, 52),\n    grid=(105, 105, 110),\n)\n\nEXAMPLES_ROOT_DIR = str(DEMO_EXAMPLES_ROOT)\n\n# Model list and paths from kimodo registry (all models: Kimodo + TMR)\nMODEL_NAMES = tuple(AVAILABLE_MODELS)\nMODEL_EXAMPLES_DIRS = {name: os.path.join(EXAMPLES_ROOT_DIR, name) for name in MODEL_NAMES}\n# Display labels for backward compatibility (short_key -> display name)\nMODEL_LABELS = {name: FRIENDLY_NAMES.get(name, f\"Model ({name})\") for name 
in MODEL_NAMES}\nMODEL_LABEL_TO_NAME = {label: name for name, label in MODEL_LABELS.items()}\n\n# -----------------------------------------------------------------------------\n# Demo UI copy\n# -----------------------------------------------------------------------------\n\nDEMO_UI_QUICK_START_CORE_MD = \"\"\"\n### Camera\n- **Left-drag**: rotate\n- **Right-drag**: pan\n- **Scroll**: zoom\n\n### Playback\n- **Space** to play/pause\n- **←/→** to step frames, or click the frame number.\n- **Scroll up/down** in the timeline: move left/right\n- **Shift + scroll** in the timeline: zoom in/out\n\n### Prompts\n- **Double-click** a text prompt to edit it.\n- **Click and drag** the right edge of a prompt box to extend/shorten it.\n- **Click empty space** to add a prompt.\n- **Right-click** a prompt to delete it.\n\n### Generate\n- Go to the **Generate** tab to modify options\n- It is also possible to **load** examples\n- Click **Generate** to generate a motion\n\n### Constraints\n- This is **optional**: should be use after a first generation\n- **Click** in the timeline tracks (Full-Body / 2D root etc) to add a constraint.\n- **Right-click** on a constraint to delete it.\n- To **edit** a constraint:\n    - Move playback to the target frame\n    - Click **Enter Editing Mode** in the Constraints tab.\n\"\"\"\n\nDEMO_UI_QUICK_START_MODAL_MD = (\n    DEMO_UI_QUICK_START_CORE_MD\n    + \"\"\"\n\nSee the **Instructions** tab for the full user manual.\n\"\"\"\n)\n\nDEMO_UI_INSTRUCTIONS_TAB_MD = (\n    \"\"\"\n## How to Use This Demo\n\n\"\"\"\n    + DEMO_UI_QUICK_START_CORE_MD\n    + \"\"\"\n\n---\n\n### Generating Motion (step-by-step)\n\n1. **Edit the text prompts** in the timeline (e.g., \"A person walks forward.\")\n2. **Modify the duration** by moving the right edge of each prompts (2–10 seconds)\n3. 
**Add constraints** (optional) to control the motion:\n   - Click **Enter Editing Mode** to adjust the character pose\n   - Use the timeline to place keyframes or intervals in constraint tracks (see below)\n4. **Click Generate** to create the motion\n5. If generating multiple samples, **click on a mesh** to select which one to keep\n\n### Timeline Editing\n\n**Adding Constraints:**\n1. Click anywhere on the timeline to add a keyframe at that frame. The keyframe is created based on the current character motion.\n2. Ctrl/Cmd+click+drag to add an interval constraint, or expand a keyframe into an interval\n3. Enter editing mode with the **Enter Editing Mode** button to adjust character pose before/after adding constraints.\n\n**Constraint Types:**\n- **Full-Body**: constrains the entire character pose\n- **2D Root**: constrains the character's path on the ground plane\n  - Enable **Densify** to create a continuous path\n- **End-Effectors**: constrains hands and feet positions\n  - Use separate tracks for Left/Right Hand/Foot\n\n\n**Moving & Deleting:**\n- **Drag keyframes/intervals** to move them to different frames\n- **Right-click** a keyframe or interval to delete it\n- Use **Clear All Constraints** to remove everything\n\n**Tips:**\n- The posing skeleton becomes visible in editing mode for precise positioning\n- Use **Snap to constraint** to align the current frame to a constraint\n\n### Saving & Loading\n\nYou can save the current constraints or current motion to load in later from the Load/Save menu.\nSaving an **Example** will save the full constraints, motion, and generation metadata.\n\n### Visualization Options\n\nSwitch to the **Visualize** tab to:\n- Toggle mesh and skeleton visibility\n- Adjust mesh opacity\n- Show/hide foot contact indicators\n- Switch between light and dark modes\n\"\"\"\n)\n"
  },
  {
    "path": "kimodo/demo/embedding_cache.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\nimport contextlib\nimport contextvars\nimport hashlib\nimport json\nimport os\nimport threading\nimport time\nfrom collections import OrderedDict\nfrom dataclasses import dataclass\nfrom typing import Iterable, Optional\n\nimport numpy as np\nimport torch\n\nfrom kimodo.sanitize import sanitize_texts\n\n_ACTIVE_SESSION = contextvars.ContextVar(\"kimodo_demo_active_session\", default=None)\n\n\n@dataclass\nclass CacheStats:\n    hits: int = 0\n    misses: int = 0\n    disk_hits: int = 0\n\n\nclass EmbeddingCache:\n    \"\"\"Disk-backed text embedding cache with a small in-memory LRU.\"\"\"\n\n    def __init__(\n        self,\n        *,\n        model_name: str,\n        encoder_id: str,\n        base_dir: Optional[str] = None,\n        max_mem_entries: int = 128,\n    ) -> None:\n        cache_root = base_dir or os.environ.get(\n            \"kimodo_EMBED_CACHE_DIR\",\n            os.path.join(\"~\", \".cache\", \"kimodo_demo\", \"embeddings\"),\n        )\n        self.base_dir = os.path.expanduser(cache_root)\n        self.model_name = model_name\n        self.encoder_id = encoder_id\n        self.max_mem_entries = max_mem_entries\n        self.stats = CacheStats()\n\n        self._lock = threading.Lock()\n        self._mem_cache: OrderedDict[str, np.ndarray] = OrderedDict()\n        self._index = {}\n        self._index_loaded = False\n\n    def _model_dir(self) -> str:\n        return os.path.join(self.base_dir, self.model_name)\n\n    def _index_path(self) -> str:\n        return os.path.join(self._model_dir(), \"index.json\")\n\n    def _prewarm_marker_path(self, key: str) -> str:\n        return os.path.join(self._model_dir(), f\"prewarm_{key}.json\")\n\n    def has_prewarm_marker(self, key: str) -> bool:\n        return os.path.exists(self._prewarm_marker_path(key))\n\n    def write_prewarm_marker(self, 
key: str, *, prompt_count: int) -> None:\n        os.makedirs(self._model_dir(), exist_ok=True)\n        payload = {\"prompt_count\": prompt_count, \"updated_at\": time.time()}\n        tmp_path = f\"{self._prewarm_marker_path(key)}.tmp\"\n        with open(tmp_path, \"w\", encoding=\"utf-8\") as f:\n            json.dump(payload, f)\n        os.replace(tmp_path, self._prewarm_marker_path(key))\n\n    def _load_index(self) -> None:\n        if self._index_loaded:\n            return\n        index_path = self._index_path()\n        if os.path.exists(index_path):\n            try:\n                with open(index_path, \"r\", encoding=\"utf-8\") as f:\n                    self._index = json.load(f)\n            except json.JSONDecodeError:\n                self._index = {}\n        self._index_loaded = True\n\n    def _save_index(self) -> None:\n        os.makedirs(self._model_dir(), exist_ok=True)\n        tmp_path = f\"{self._index_path()}.tmp\"\n        with open(tmp_path, \"w\", encoding=\"utf-8\") as f:\n            json.dump(self._index, f)\n        os.replace(tmp_path, self._index_path())\n\n    def _make_key(self, text: str) -> str:\n        key_src = f\"{self.model_name}|{self.encoder_id}|{text}\"\n        return hashlib.sha256(key_src.encode(\"utf-8\")).hexdigest()\n\n    def _entry_path(self, key: str) -> str:\n        return os.path.join(self._model_dir(), f\"{key}.npy\")\n\n    def _mem_get(self, key: str) -> Optional[np.ndarray]:\n        if key in self._mem_cache:\n            self._mem_cache.move_to_end(key)\n            return self._mem_cache[key]\n        return None\n\n    def _mem_put(self, key: str, value: np.ndarray) -> None:\n        self._mem_cache[key] = value\n        self._mem_cache.move_to_end(key)\n        while len(self._mem_cache) > self.max_mem_entries:\n            self._mem_cache.popitem(last=False)\n\n    def _disk_load(self, key: str) -> Optional[np.ndarray]:\n        path = self._entry_path(key)\n        if not 
os.path.exists(path):\n            return None\n        try:\n            return np.load(path)\n        except Exception:\n            return None\n\n    def _disk_save(self, key: str, value: np.ndarray) -> None:\n        os.makedirs(self._model_dir(), exist_ok=True)\n        np.save(self._entry_path(key), value)\n        self._index[key] = {\n            \"length\": int(value.shape[0]),\n            \"dtype\": str(value.dtype),\n            \"updated_at\": time.time(),\n        }\n\n    def _maybe_use_session_cache(self, texts: list[str]):\n        session = _ACTIVE_SESSION.get()\n        if session is None:\n            return None\n        if session.last_prompt_texts == texts and session.last_prompt_embeddings is not None:\n            return session.last_prompt_embeddings, session.last_prompt_lengths\n        return None\n\n    def _update_session_cache(self, texts: list[str], tensor: torch.Tensor, lengths: list[int]) -> None:\n        session = _ACTIVE_SESSION.get()\n        if session is None:\n            return\n        session.last_prompt_texts = texts\n        session.last_prompt_embeddings = tensor\n        session.last_prompt_lengths = lengths\n\n    def get_or_encode(self, texts: Iterable[str], encoder):\n        if isinstance(texts, str):\n            texts = [texts]\n        texts = sanitize_texts(list(texts))\n        if len(texts) == 0:\n            empty = torch.empty()\n            return empty, []\n\n        session_cache = self._maybe_use_session_cache(texts)\n        if session_cache is not None:\n            return session_cache\n\n        arrays: list[Optional[np.ndarray]] = [None] * len(texts)\n        lengths: list[int] = [0] * len(texts)\n        misses: list[tuple[int, str, str]] = []\n\n        with self._lock:\n            self._load_index()\n            for idx, text in enumerate(texts):\n                key = self._make_key(text)\n                cached = self._mem_get(key)\n                if cached is not None:\n                   
 arrays[idx] = cached\n                    lengths[idx] = cached.shape[0]\n                    self.stats.hits += 1\n                    continue\n\n                cached = self._disk_load(key)\n                if cached is not None:\n                    arrays[idx] = cached\n                    lengths[idx] = cached.shape[0]\n                    self._mem_put(key, cached)\n                    self.stats.disk_hits += 1\n                    continue\n\n                misses.append((idx, text, key))\n                self.stats.misses += 1\n\n        if misses:\n            miss_texts = [text for _, text, _ in misses]\n            miss_tensor, miss_lengths = encoder(miss_texts)\n            miss_tensor = miss_tensor.detach().cpu()\n            miss_tensor_np = miss_tensor.numpy()\n\n            with self._lock:\n                self._load_index()\n                for miss_idx, length in enumerate(miss_lengths):\n                    idx, _text, key = misses[miss_idx]\n                    arr = miss_tensor_np[miss_idx, :length].copy()\n                    arrays[idx] = arr\n                    lengths[idx] = int(length)\n                    self._mem_put(key, arr)\n                    self._disk_save(key, arr)\n                self._save_index()\n\n        max_len = max(lengths) if lengths else 0\n        feat_dim = arrays[0].shape[-1] if arrays[0] is not None else 0\n        dtype = arrays[0].dtype if arrays[0] is not None else np.float32\n        padded = np.zeros((len(texts), max_len, feat_dim), dtype=dtype)\n        for idx, arr in enumerate(arrays):\n            if arr is None:\n                continue\n            padded[idx, : arr.shape[0]] = arr\n\n        result = torch.from_numpy(padded)\n        self._update_session_cache(texts, result, lengths)\n        return result, lengths\n\n\nclass CachedTextEncoder:\n    \"\"\"Wrapper around a text encoder to add disk-backed caching.\"\"\"\n\n    def __init__(self, encoder, *, model_name: str, base_dir: 
Optional[str] = None):\n        self.encoder = encoder\n        self.model_name = model_name\n        encoder_id = f\"{type(encoder).__name__}\"\n        self.cache = EmbeddingCache(model_name=model_name, encoder_id=encoder_id, base_dir=base_dir)\n\n    def __call__(self, texts):\n        return self.cache.get_or_encode(texts, self.encoder)\n\n    def prewarm(self, texts) -> None:\n        if isinstance(texts, str):\n            texts = [texts]\n        texts = sanitize_texts(list(texts))\n        prewarm_key = hashlib.sha256(\"|\".join(texts).encode(\"utf-8\")).hexdigest()\n        if self.cache.has_prewarm_marker(prewarm_key):\n            return\n        self.cache.get_or_encode(texts, self.encoder)\n        self.cache.write_prewarm_marker(prewarm_key, prompt_count=len(texts))\n\n    def to(self, device=None, dtype=None):\n        if hasattr(self.encoder, \"to\"):\n            self.encoder.to(device=device, dtype=dtype)\n        return self\n\n    @contextlib.contextmanager\n    def session_context(self, session):\n        token = _ACTIVE_SESSION.set(session)\n        try:\n            yield\n        finally:\n            _ACTIVE_SESSION.reset(token)\n\n    def __getattr__(self, name):\n        return getattr(self.encoder, name)\n"
  },
  {
    "path": "kimodo/demo/generation.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\nfrom collections import defaultdict\nfrom typing import Optional\n\nimport numpy as np\nimport torch\n\nimport viser\nfrom kimodo.constraints import (\n    TYPE_TO_CLASS,\n    FullBodyConstraintSet,\n    Root2DConstraintSet,\n)\nfrom kimodo.exports.mujoco import apply_g1_real_robot_projection\nfrom kimodo.skeleton import G1Skeleton34, SOMASkeleton30\nfrom kimodo.tools import seed_everything\n\nfrom .embedding_cache import CachedTextEncoder\nfrom .state import ClientSession, ModelBundle\n\n\ndef compute_model_constraints_lst(\n    session: ClientSession,\n    model_bundle: ModelBundle,\n    num_frames: int,\n    device: str,\n):\n    \"\"\"Compute the lst of constraints for the model based on the constraints in viser.\"\"\"\n    assert len(session.motions) == 1, \"Only one motion allowed for constrained generation\"\n    if not session.constraints:\n        return []\n\n    model_skeleton = model_bundle.model.skeleton\n    # For SOMA, UI uses somaskel77; extract 30-joint subset for the model\n    use_skel_slice = isinstance(model_skeleton, SOMASkeleton30) and session.skeleton.nbjoints != model_skeleton.nbjoints\n    skel_slice = model_skeleton.get_skel_slice(session.skeleton) if use_skel_slice else None\n\n    dense_smooth_root_pos_2d = None\n    if session.constraints[\"2D Root\"].dense_path:\n        # get the full 2d root\n        dense_smooth_root_pos_2d = session.constraints[\"2D Root\"].get_constraint_info(device=device)[\"root_pos\"][\n            :, [0, 2]\n        ]\n\n    model_constraints = []\n    for track_name, constraint in session.constraints.items():\n        constraint_info = constraint.get_constraint_info(device=device)\n        frame_idx = constraint_info[\"frame_idx\"]\n        # drop any constraints outside the generation range\n        valid_info = [(i, fi) for i, fi in 
enumerate(frame_idx) if fi < num_frames]\n        valid_idx = [i for i, _ in valid_info]\n        valid_frame_idx = [fi for _, fi in valid_info]\n\n        if len(valid_frame_idx) == 0:\n            continue\n\n        frame_indices = torch.tensor(valid_frame_idx)\n        if track_name == \"2D Root\":\n            smooth_root_pos_2d = constraint_info[\"root_pos\"][valid_idx][:, [0, 2]].to(device)\n            # same as \"smooth_root_2d\"\n            model_constraints.append(\n                Root2DConstraintSet(\n                    model_skeleton,\n                    frame_indices,\n                    smooth_root_pos_2d,\n                )\n            )\n        elif track_name == \"Full-Body\":\n            constraint_joints_pos = constraint_info[\"joints_pos\"][valid_idx].to(device)\n            constraint_joints_rot = constraint_info[\"joints_rot\"][valid_idx].to(device)\n            if skel_slice is not None:\n                constraint_joints_pos = constraint_joints_pos[:, skel_slice]\n                constraint_joints_rot = constraint_joints_rot[:, skel_slice]\n\n            smooth_root_pos_2d = None\n            if dense_smooth_root_pos_2d is not None:\n                smooth_root_pos_2d = dense_smooth_root_pos_2d[frame_indices]\n\n            model_constraints.append(\n                FullBodyConstraintSet(\n                    model_skeleton,\n                    frame_indices,\n                    constraint_joints_pos,\n                    constraint_joints_rot,\n                    smooth_root_2d=smooth_root_pos_2d,\n                )\n            )\n        elif track_name == \"End-Effectors\":\n            constraint_joints_pos = constraint_info[\"joints_pos\"][valid_idx].to(device)\n            constraint_joints_rot = constraint_info[\"joints_rot\"][valid_idx].to(device)\n            if skel_slice is not None:\n                constraint_joints_pos = constraint_joints_pos[:, skel_slice]\n                constraint_joints_rot = 
constraint_joints_rot[:, skel_slice]\n\n            end_effector_type_set_lst = [\n                end_effector_type_set\n                for i, end_effector_type_set in enumerate(constraint_info[\"end_effector_type\"])\n                if i in valid_idx\n            ]\n\n            # regroup the end effector data by type\n            cls_idx = defaultdict(list)\n            for idx, end_effector_type_set in enumerate(end_effector_type_set_lst):\n                for end_effector_type in end_effector_type_set:\n                    cls_idx[TYPE_TO_CLASS[end_effector_type]].append(idx)\n\n            for cls, lst_idx in cls_idx.items():\n                frame_indices_cls = frame_indices[lst_idx]\n                smooth_root_pos_2d = None\n                if dense_smooth_root_pos_2d is not None:\n                    smooth_root_pos_2d = dense_smooth_root_pos_2d[frame_indices_cls]\n\n                constraint_joints_pos_el = constraint_joints_pos[lst_idx]\n                constraint_joints_rot_el = constraint_joints_rot[lst_idx]\n\n                model_constraints.append(\n                    cls(\n                        model_skeleton,\n                        frame_indices_cls,\n                        constraint_joints_pos_el,\n                        constraint_joints_rot_el,\n                        smooth_root_2d=smooth_root_pos_2d,\n                    )\n                )\n        else:\n            raise ValueError(f\"Unsupported constraint type: {constraint.display_name}\")\n    return model_constraints\n\n\ndef generate(\n    *,\n    client: viser.ClientHandle,\n    session: ClientSession,\n    model_bundle: ModelBundle,\n    prompts: list[str],\n    num_frames: list[int],\n    num_samples: int,\n    seed: int,\n    diffusion_steps: int,\n    cfg_weight: Optional[list[float]] = None,\n    cfg_type: Optional[str] = None,\n    postprocess_parameters: Optional[dict] = None,\n    transitions_parameters: Optional[dict] = None,\n    real_robot_rotations: bool = 
False,\n    device: str,\n    clear_motions,\n    add_character_motion,\n) -> None:\n    client_id = client.client_id\n    print(\n        f\"Generating {num_samples} samples for a total of {sum(num_frames)} frames with those prompt: {prompts} (client {client_id})\"\n    )\n\n    seed_everything(seed)\n\n    model_constraints = compute_model_constraints_lst(session, model_bundle, sum(num_frames), device)\n    cfg_weight = cfg_weight or [2.0, 2.0]\n    postprocess_parameters = postprocess_parameters or {}\n    transitions_parameters = transitions_parameters or {}\n\n    encoder = getattr(model_bundle.model, \"text_encoder\", None)\n    if isinstance(encoder, CachedTextEncoder):\n        with encoder.session_context(session):\n            pred_joints_output = model_bundle.model(\n                prompts,\n                num_frames,\n                diffusion_steps,\n                multi_prompt=True,\n                constraint_lst=model_constraints,\n                cfg_weight=cfg_weight,\n                num_samples=num_samples,\n                cfg_type=cfg_type,\n                **(postprocess_parameters | transitions_parameters),\n            )  # [B, T, motion_rep_dim]\n    else:\n        pred_joints_output = model_bundle.model(\n            prompts,\n            num_frames,\n            diffusion_steps,\n            multi_prompt=True,\n            constraint_lst=model_constraints,\n            cfg_weight=cfg_weight,\n            num_samples=num_samples,\n            cfg_type=cfg_type,\n            **(postprocess_parameters | transitions_parameters),\n        )  # [B, T, motion_rep_dim]\n\n    joints_pos = pred_joints_output[\"posed_joints\"]  # [B, T, J, 3]\n    joints_rot = pred_joints_output[\"global_rot_mats\"]\n    foot_contacts = pred_joints_output.get(\"foot_contacts\")\n\n    # Optionally project G1 to real robot DoF (1-DoF per joint, clamped) for display.\n    if real_robot_rotations and isinstance(session.skeleton, G1Skeleton34):\n        joints_pos, 
joints_rot = apply_g1_real_robot_projection(\n            session.skeleton,\n            pred_joints_output[\"posed_joints\"],\n            pred_joints_output[\"global_rot_mats\"],\n            clamp_to_limits=True,\n        )\n\n    # Display on characters (callbacks keep this module UI-agnostic).\n    clear_motions(client_id)\n    # Keep one sample centered at the origin so constraints align.\n    spread_factor = 1.0  # meters\n    center_idx = num_samples // 2\n    x_trans = (np.arange(num_samples) - center_idx) * spread_factor\n    for i in range(num_samples):\n        cur_joints_pos = joints_pos[i]\n        cur_joints_pos[..., 0] += x_trans[i]\n        add_character_motion(\n            client,\n            session.skeleton,\n            cur_joints_pos,\n            joints_rot[i],\n            foot_contacts[i],\n        )\n"
  },
  {
    "path": "kimodo/demo/queue_manager.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"HF mode user queue and session time limit.\"\"\"\n\nimport math\nimport threading\nimport time\nfrom collections.abc import Callable\nfrom typing import Any\n\nimport viser\n\nfrom .config import DEMO_UI_QUICK_START_MODAL_MD, MAX_SESSION_MINUTES\n\n# Link for \"Duplicate this Space\" on Hugging Face (used in queue and expiry modals).\nDUPLICATE_SPACE_URL = \"https://huggingface.co/spaces/nvidia/Kimodo?duplicate=true\"\nGITHUB_REPO_URL = \"https://github.com/nv-tlabs/kimodo\"\n\n# How often to refresh queue modal content (position, total, estimated wait).\nQUEUE_MODAL_REFRESH_INTERVAL_SEC = 15\n\n\nclass UserQueue:\n    \"\"\"Thread-safe queue: active users (with activation timestamp) and waiting queue.\"\"\"\n\n    def __init__(self, max_active: int, max_minutes: float) -> None:\n        self._max_active = max_active\n        self._max_minutes = max_minutes\n        self._max_seconds = max_minutes * 60.0\n        self._active: dict[int, float] = {}  # client_id -> activation timestamp\n        self._queued: list[int] = []\n        self._lock = threading.Lock()\n\n    def try_activate(self, client_id: int) -> bool:\n        \"\"\"If a slot is free, add client as active and return True.\n\n        Else return False.\n        \"\"\"\n        with self._lock:\n            if len(self._active) < self._max_active:\n                self._active[client_id] = time.time()\n                return True\n            return False\n\n    def enqueue(self, client_id: int) -> None:\n        with self._lock:\n            if client_id not in self._queued:\n                self._queued.append(client_id)\n\n    def remove(self, client_id: int) -> bool:\n        \"\"\"Remove from active or queue.\n\n        Returns True if was active.\n        \"\"\"\n        with self._lock:\n            was_active = client_id in self._active\n 
           self._active.pop(client_id, None)\n            if client_id in self._queued:\n                self._queued.remove(client_id)\n            return was_active\n\n    def promote_next(self) -> int | None:\n        \"\"\"If queue non-empty, pop first, activate them, return their client_id.\n\n        Else None.\n        \"\"\"\n        with self._lock:\n            if not self._queued:\n                return None\n            client_id = self._queued.pop(0)\n            self._active[client_id] = time.time()\n            return client_id\n\n    def get_queue_position(self, client_id: int) -> tuple[int, int] | None:\n        \"\"\"(1-based position, total_in_queue) or None if not queued.\"\"\"\n        with self._lock:\n            if client_id not in self._queued:\n                return None\n            pos = self._queued.index(client_id)\n            return (pos + 1, len(self._queued))\n\n    def get_estimated_wait_seconds(self, client_id: int) -> float:\n        \"\"\"Estimated seconds until this queued client gets a slot.\"\"\"\n        with self._lock:\n            if client_id not in self._queued:\n                return 0.0\n            pos = self._queued.index(client_id) + 1  # 1-based\n            # Expiry times of active users (when they free a slot)\n            now = time.time()\n            expiries = sorted(now + self._max_seconds - (now - t) for t in self._active.values())\n            if not expiries:\n                return 0.0\n            # Nth slot to free (1-indexed) wraps over expiries\n            idx = (pos - 1) % len(expiries)\n            cycles = (pos - 1) // len(expiries)\n            slot_free_time = expiries[idx] + cycles * self._max_seconds\n            return max(0.0, slot_free_time - now)\n\n    def is_active(self, client_id: int) -> bool:\n        with self._lock:\n            return client_id in self._active\n\n    def was_active(self, client_id: int) -> bool:\n        \"\"\"True if client is currently active (for use when 
already holding lock).\"\"\"\n        return client_id in self._active\n\n\ndef _format_wait(seconds: float) -> str:\n    if seconds < 60:\n        return \"less than a minute\"\n    mins = int(math.ceil(seconds / 60))\n    return f\"~{mins} minute{'s' if mins != 1 else ''}\"\n\n\ndef _queue_modal_markdown(position: int, total: int, estimated_wait_sec: float) -> str:\n    wait_str = _format_wait(estimated_wait_sec)\n    mins = int(MAX_SESSION_MINUTES) if MAX_SESSION_MINUTES == int(MAX_SESSION_MINUTES) else MAX_SESSION_MINUTES\n    return f\"\"\"## Kimodo Demo — Please Wait\n\nThis demo runs with limited capacity.\nEach user gets **{mins} minute{\"s\" if mins != 1 else \"\"}** of interactive time.\n\n**Your position in queue:** {position} / {total}\n\n**Estimated wait:** {wait_str}\n\nPlease keep this tab open — the demo will start automatically when it's your turn.\n\n---\n*Want unlimited access? [Duplicate this Space]({DUPLICATE_SPACE_URL}) or clone the [GitHub repo]({GITHUB_REPO_URL}) to run locally!*\n\"\"\"\n\n\ndef _welcome_modal_markdown() -> str:\n    mins = int(MAX_SESSION_MINUTES) if MAX_SESSION_MINUTES == int(MAX_SESSION_MINUTES) else MAX_SESSION_MINUTES\n    return f\"\"\"## Welcome to Kimodo Demo\n\nYou have been granted a **{mins}-minute** demo session.\nYour session timer has started.\n\nClick the button below to begin!\n\"\"\"\n\n\ndef _expiry_modal_markdown() -> str:\n    mins = int(MAX_SESSION_MINUTES) if MAX_SESSION_MINUTES == int(MAX_SESSION_MINUTES) else MAX_SESSION_MINUTES\n    return f\"\"\"## Session Expired\n\nYour {mins}-minute demo session has ended.\nThank you for trying Kimodo!\n\nRefresh this page to rejoin the queue, or [duplicate this Space]({DUPLICATE_SPACE_URL}) for unlimited access.\n\"\"\"\n\n\nclass QueueManager:\n    \"\"\"Orchestrates HF mode: queue modals, welcome modal, session timer, promotion.\"\"\"\n\n    def __init__(\n        self,\n        queue: UserQueue,\n        server: viser.ViserServer,\n        
setup_demo_for_client: Callable[[viser.ClientHandle], None],\n        cleanup_session: Callable[[int], None],\n    ) -> None:\n        self._queue = queue\n        self._server = server\n        self._setup_demo_for_client = setup_demo_for_client\n        self._cleanup_session = cleanup_session\n        self._max_seconds = queue._max_seconds\n\n        self._queue_modal_handles: dict[int, tuple[Any, Any]] = {}\n        self._welcome_modal_handles: dict[int, Any] = {}\n        self._expiry_timers: dict[int, threading.Timer] = {}\n        self._lock = threading.Lock()\n        self._refresh_stop = threading.Event()\n        self._refresh_thread = threading.Thread(\n            target=self._queue_modal_refresh_loop,\n            name=\"queue-modal-refresh\",\n            daemon=True,\n        )\n        self._refresh_thread.start()\n\n    def _queue_modal_refresh_loop(self) -> None:\n        \"\"\"Periodically refresh queue modals so position, total, and estimated wait stay current.\"\"\"\n        while not self._refresh_stop.wait(timeout=QUEUE_MODAL_REFRESH_INTERVAL_SEC):\n            self._update_all_queue_modals()\n\n    def on_client_connect(self, client: viser.ClientHandle) -> None:\n        \"\"\"Handle new connection: activate if slot free, else enqueue and show queue modal.\"\"\"\n        client_id = client.client_id\n        if self._queue.try_activate(client_id):\n            try:\n                self._setup_demo_for_client(client)\n            except RuntimeError as e:\n                if \"CUDA error\" in str(e):\n                    print(f\"CUDA error while setting up client {client_id}: {e}\")\n                    return\n                raise\n            self._start_session_timer(client_id)\n            self._show_welcome_modal(client)\n        else:\n            self._queue.enqueue(client_id)\n            self._show_queue_modal(client)\n            self._update_all_queue_modals()\n\n    def on_client_disconnect(self, client_id: int) -> None:\n       
 \"\"\"Remove from queue/active, cancel timer, promote next if was active.\n\n        Session/scene cleanup is done by the demo's on_client_disconnect.\n        \"\"\"\n        with self._lock:\n            self._expiry_timers.pop(client_id, None)\n            self._queue_modal_handles.pop(client_id, None)\n            self._welcome_modal_handles.pop(client_id, None)\n        was_active = self._queue.remove(client_id)\n        if was_active:\n            self._promote_next_user()\n        else:\n            self._update_all_queue_modals()\n\n    def _show_queue_modal(self, client: viser.ClientHandle) -> None:\n        client_id = client.client_id\n        pos, total = self._queue.get_queue_position(client_id) or (0, 0)\n        wait_sec = self._queue.get_estimated_wait_seconds(client_id)\n        md_content = _queue_modal_markdown(pos, total, wait_sec)\n\n        modal = client.gui.add_modal(\n            \"Kimodo Demo — Please Wait\",\n            size=\"xl\",\n            show_close_button=False,\n        )\n        with modal:\n            md_handle = client.gui.add_markdown(md_content)\n        with self._lock:\n            self._queue_modal_handles[client_id] = (modal, md_handle)\n\n    def _show_quick_start_modal(self, client: viser.ClientHandle) -> None:\n        \"\"\"Show the quick start instructions modal (same as non-HF mode).\"\"\"\n        with client.gui.add_modal(\n            \"Welcome — Quick Start\",\n            size=\"xl\",\n            show_close_button=True,\n            save_choice=\"kimodo.demo.quick_start_ack\",\n        ) as quick_start_modal:\n            client.gui.add_markdown(DEMO_UI_QUICK_START_MODAL_MD)\n            client.gui.add_button(\"Got it (don't remind me again)\").on_click(lambda _: quick_start_modal.close())\n\n    def _show_welcome_modal(self, client: viser.ClientHandle) -> None:\n        client_id = client.client_id\n\n        def _on_start_demo(_: Any) -> None:\n            modal.close()\n            
self._show_quick_start_modal(client)\n\n        modal = client.gui.add_modal(\n            \"Welcome to Kimodo Demo\",\n            size=\"xl\",\n            show_close_button=True,\n        )\n        with modal:\n            client.gui.add_markdown(_welcome_modal_markdown())\n            client.gui.add_button(\"Start Demo\").on_click(_on_start_demo)\n        with self._lock:\n            self._welcome_modal_handles[client_id] = modal\n\n    def _update_all_queue_modals(self) -> None:\n        with self._lock:\n            handles = list(self._queue_modal_handles.items())\n        for client_id, (modal, md_handle) in handles:\n            pos_total = self._queue.get_queue_position(client_id)\n            if pos_total is None:\n                continue\n            pos, total = pos_total\n            wait_sec = self._queue.get_estimated_wait_seconds(client_id)\n            try:\n                md_handle.content = _queue_modal_markdown(pos, total, wait_sec)\n            except Exception:\n                pass\n\n    def _promote_next_user(self) -> None:\n        promoted_id = self._queue.promote_next()\n        if promoted_id is None:\n            return\n        clients = self._server.get_clients()\n        client = clients.get(promoted_id)\n        if client is None:\n            return\n        with self._lock:\n            old = self._queue_modal_handles.pop(promoted_id, None)\n        if old is not None:\n            try:\n                old[0].close()\n            except Exception:\n                pass\n        try:\n            self._setup_demo_for_client(client)\n        except RuntimeError as e:\n            if \"CUDA error\" in str(e):\n                print(f\"CUDA error while setting up client {promoted_id}: {e}\")\n                return\n            raise\n        self._start_session_timer(promoted_id)\n        self._show_welcome_modal(client)\n        self._update_all_queue_modals()\n\n    def _start_session_timer(self, client_id: int) -> None:\n   
     def on_expiry() -> None:\n            self._on_session_expired(client_id)\n\n        t = threading.Timer(self._max_seconds, on_expiry)\n        t.daemon = True\n        with self._lock:\n            self._expiry_timers[client_id] = t\n        t.start()\n\n    def _on_session_expired(self, client_id: int) -> None:\n        with self._lock:\n            self._expiry_timers.pop(client_id, None)\n        if not self._queue.is_active(client_id):\n            return\n        self._queue.remove(client_id)\n        clients = self._server.get_clients()\n        client = clients.get(client_id)\n        if client is not None:\n            try:\n                with client.gui.add_modal(\n                    \"Session Expired\",\n                    size=\"lg\",\n                    show_close_button=False,\n                ) as modal_ctx:\n                    client.gui.add_markdown(_expiry_modal_markdown())\n            except Exception:\n                pass\n        self._cleanup_session(client_id)\n        self._promote_next_user()\n"
  },
  {
    "path": "kimodo/demo/state.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\nfrom dataclasses import dataclass, field\nfrom typing import Optional\n\nimport torch\n\nimport kimodo.viz.viser_utils as viser_utils\nimport viser\nfrom kimodo.skeleton import SkeletonBase\nfrom kimodo.viz.viser_utils import GuiElements\n\nfrom .config import (\n    DEFAULT_CUR_DURATION,\n    DEFAULT_MODEL,\n    DEFAULT_PLAYBACK_SPEED,\n)\n\n\n@dataclass(frozen=True)\nclass ModelBundle:\n    model: object\n    motion_rep: object\n    skeleton: SkeletonBase\n    model_fps: float\n\n\n@dataclass\nclass ClientSession:\n    \"\"\"Per-client session data.\"\"\"\n\n    client: viser.ClientHandle\n    gui_elements: GuiElements\n    motions: dict  # character_name -> CharacterMotion\n    constraints: dict[str, viser_utils.ConstraintSet] = field(default_factory=dict)\n    timeline_data: object = None\n    frame_idx: int = 0\n    playing: bool = False\n    playback_speed: float = DEFAULT_PLAYBACK_SPEED\n    cur_duration: float = DEFAULT_CUR_DURATION\n    max_frame_idx: int = 100  # will be updated based on model_fps\n    updating_motions: bool = False\n    edit_mode: bool = False\n    model_name: str = DEFAULT_MODEL\n    model_fps: float = 0.0\n    skeleton: SkeletonBase | None = None\n    motion_rep: object | None = None\n    examples_base_dir: str = \"\"\n    example_dict: dict[str, str] = field(default_factory=dict)\n    gui_examples_dropdown: Optional[viser.GuiInputHandle] = None\n    gui_save_example_path_text: Optional[viser.GuiInputHandle] = None\n    gui_model_selector: Optional[viser.GuiInputHandle] = None\n    last_prompt_texts: Optional[list[str]] = None\n    last_prompt_embeddings: Optional[torch.Tensor] = None\n    last_prompt_lengths: Optional[list[int]] = None\n    edit_mode_snapshot: Optional[dict[int, dict[str, object]]] = None\n    undo_drag_snapshot: Optional[dict[str, object]] = None\n    
show_only_current_constraint: bool = False  # False = Show All, True = Show only Current\n"
  },
  {
    "path": "kimodo/demo/ui.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\n# ruff: noqa: I001\nimport math\nimport os\nimport threading\nfrom typing import Optional\n\nfrom kimodo.constraints import load_constraints_lst, save_constraints_lst\nfrom kimodo.exports.bvh import motion_to_bvh_bytes, save_motion_bvh\nfrom kimodo.exports.motion_io import (\n    amass_npz_to_bytes,\n    g1_csv_to_bytes,\n    kimodo_npz_to_bytes,\n    load_motion_file,\n    save_kimodo_npz,\n)\nfrom kimodo.model.registry import kimodo_short_key_for_skeleton_dataset, registry_skeleton_for_joint_count\nfrom kimodo.tools import to_torch\nfrom kimodo.viz import viser_utils\nfrom kimodo.viz.viser_utils import GuiElements\nimport numpy as np\nimport torch\nimport viser\nfrom viser._timeline_api import PROMPT_COLORS\n\nfrom . import generation\nfrom .config import (\n    DEFAULT_CUR_DURATION,\n    DEMO_UI_INSTRUCTIONS_TAB_MD,\n    get_datasets,\n    get_model_info,\n    get_models_for_dataset_skeleton,\n    get_skeleton_display_name,\n    get_skeleton_display_names_for_dataset,\n    get_skeleton_key_from_display_name,\n    get_short_key_from_display_name,\n    HF_MODE,\n    INIT_POSTPROCESSING,\n    MODEL_NAMES,\n    NB_TRANSITION_FRAMES,\n    SHOW_TRANSITION_PARAMS,\n)\nfrom .state import ClientSession\nfrom kimodo.skeleton import G1Skeleton34, SOMASkeleton30, SOMASkeleton77\n\n\ndef extract_intervals_and_singles(t: torch.Tensor):\n    intervals = []\n    intervals_indices = []\n    single_frames = []\n    single_frames_indices = []\n\n    start_idx = 0\n\n    for i in range(1, len(t) + 1):\n        # End of run if:\n        #  - end of tensor\n        #  - non-consecutive value\n        if i == len(t) or t[i] != t[i - 1] + 1:\n            run_length = i - start_idx\n\n            if run_length >= 2:\n                intervals.append((int(t[start_idx]), int(t[i - 1])))\n                
intervals_indices.append((start_idx, i - 1))\n            else:\n                single_frames.append(int(t[start_idx]))\n                single_frames_indices.append(start_idx)\n\n            start_idx = i\n\n    return intervals, intervals_indices, single_frames, single_frames_indices\n\n\ndef create_gui(\n    demo,\n    client: viser.ClientHandle,\n    model_name: str,\n    model_fps: float,\n):\n    \"\"\"Create GUI elements for a specific client.\"\"\"\n    client_id = client.client_id\n\n    def get_active_session(event_client: viser.ClientHandle | None):\n        if event_client is None:\n            return None\n        if not demo.client_active(event_client.client_id):\n            return None\n        return demo.client_sessions[event_client.client_id]\n\n    def build_timeline_tracks():\n        timeline = client.timeline\n        demo.set_timeline_defaults(timeline, model_fps)\n        timeline.set_visible(True)\n        timeline.set_current_frame(0)\n\n        timeline_tracks = {}\n        fullbody_id = timeline.add_track(\n            \"Full-Body\",\n            track_type=\"keyframe\",\n            color=(219, 148, 86),\n            height_scale=0.5,\n        )\n        timeline_tracks[fullbody_id] = {\n            \"name\": \"Full-Body\",\n            \"track_type\": \"keyframe\",\n            \"color\": (219, 148, 86),\n            \"height_scale\": 0.5,\n        }\n\n        root2d_id = timeline.add_track(\n            \"2D Root\",\n            track_type=\"keyframe\",\n            color=(150, 100, 200),\n            height_scale=0.5,\n        )\n        timeline_tracks[root2d_id] = {\n            \"name\": \"2D Root\",\n            \"track_type\": \"keyframe\",\n            \"color\": (150, 100, 200),\n            \"height_scale\": 0.5,\n        }\n        lefthand_id = timeline.add_track(\n            \"Left Hand\",\n            track_type=\"keyframe\",\n            color=(100, 200, 150),\n            height_scale=0.5,\n        )\n        
timeline_tracks[lefthand_id] = {\n            \"name\": \"Left Hand\",\n            \"track_type\": \"keyframe\",\n            \"color\": (100, 200, 150),\n            \"height_scale\": 0.5,\n        }\n        righthand_id = timeline.add_track(\n            \"Right Hand\",\n            track_type=\"keyframe\",\n            color=(200, 100, 150),\n            height_scale=0.5,\n        )\n        timeline_tracks[righthand_id] = {\n            \"name\": \"Right Hand\",\n            \"track_type\": \"keyframe\",\n            \"color\": (200, 100, 150),\n            \"height_scale\": 0.5,\n        }\n        leftfoot_id = timeline.add_track(\n            \"Left Foot\",\n            track_type=\"keyframe\",\n            color=(219, 148, 86),\n            height_scale=0.5,\n        )\n        timeline_tracks[leftfoot_id] = {\n            \"name\": \"Left Foot\",\n            \"track_type\": \"keyframe\",\n            \"color\": (219, 148, 86),\n            \"height_scale\": 0.5,\n        }\n        rightfoot_id = timeline.add_track(\n            \"Right Foot\",\n            track_type=\"keyframe\",\n            color=(150, 100, 200),\n            height_scale=0.5,\n        )\n        timeline_tracks[rightfoot_id] = {\n            \"name\": \"Right Foot\",\n            \"track_type\": \"keyframe\",\n            \"color\": (150, 100, 200),\n            \"height_scale\": 0.5,\n        }\n        return timeline, timeline_tracks\n\n    timeline, timeline_tracks = build_timeline_tracks()\n    # These handles are part of GuiElements, but the demo currently uses timeline + buttons\n    # embedded in the Viser UI instead of custom controls.\n    gui_play_pause_button = None\n    gui_next_frame_button = None\n    gui_prev_frame_button = None\n    gui_timeline = None\n    gui_duration_slider = None\n\n    # now other gui elements\n    tab_group = client.gui.add_tab_group()\n\n    #\n    # Playback and Motion generation controls\n    #\n    with tab_group.add_tab(\"Generate\", 
viser.Icon.WALK):\n        with client.gui.add_folder(\"Model Selection\", expand_by_default=True):\n            info = get_model_info(model_name)\n            if info is None:\n                info = get_model_info(next(iter(MODEL_NAMES)))\n\n            def get_allowed_skeleton_labels(dataset_ui_label: str) -> list[str]:\n                labels = get_skeleton_display_names_for_dataset(dataset_ui_label, family=\"Kimodo\")\n                if HF_MODE:\n                    labels = [label for label in labels if get_skeleton_key_from_display_name(label) != \"SMPLX\"]\n                return labels\n\n            dataset_ui_label = \"Rigplay\" if HF_MODE else info.dataset_ui_label\n            datasets = [\"Rigplay\"] if HF_MODE else get_datasets(family=\"Kimodo\")\n            skeleton_labels = get_allowed_skeleton_labels(dataset_ui_label)\n            initial_skeleton_label = get_skeleton_display_name(info.skeleton)\n            if initial_skeleton_label not in skeleton_labels and skeleton_labels:\n                initial_skeleton_label = skeleton_labels[0]\n            initial_skeleton_key = (\n                get_skeleton_key_from_display_name(initial_skeleton_label) if skeleton_labels else None\n            )\n            models_for_pair = (\n                get_models_for_dataset_skeleton(dataset_ui_label, initial_skeleton_key, family=\"Kimodo\")\n                if initial_skeleton_key is not None\n                else []\n            )\n            version_options = [m.display_name for m in models_for_pair]\n            initial_version = (\n                info.display_name\n                if info.display_name in version_options\n                else (version_options[0] if version_options else \"\")\n            )\n            gui_dataset_selector = client.gui.add_dropdown(\n                \"Training dataset\",\n                options=datasets,\n                initial_value=dataset_ui_label,\n                visible=not HF_MODE,\n            )\n            
gui_skeleton_selector = client.gui.add_dropdown(\n                \"Model\" if HF_MODE else \"Skeleton\",\n                options=skeleton_labels,\n                initial_value=initial_skeleton_label,\n            )\n            gui_version_selector = client.gui.add_dropdown(\n                \"Version\",\n                options=version_options,\n                initial_value=initial_version,\n            )\n            gui_version_selector.visible = len(models_for_pair) > 1\n            gui_model_display = client.gui.add_markdown(\n                content=f\"**Model:** {initial_version}\",\n            )\n            gui_load_model_button = client.gui.add_button(\n                \"Load model\",\n                hint=\"Load the selected model (dataset, skeleton, version).\",\n            )\n\n            class ModelSelectorHandle:\n                \"\"\"Wrapper so session and callbacks can treat three dropdowns as one.\"\"\"\n\n                def __init__(self):\n                    self._dataset = gui_dataset_selector\n                    self._skeleton = gui_skeleton_selector\n                    self._version = gui_version_selector\n                    self._display = gui_model_display\n\n                @property\n                def value(self) -> str:\n                    return get_short_key_from_display_name(self._version.value) or \"\"\n\n                def set_from_short_key(self, short_key: str) -> None:\n                    info = get_model_info(short_key)\n                    if info is None:\n                        return\n                    dataset_ui_label = \"Rigplay\" if HF_MODE else info.dataset_ui_label\n                    self._dataset.value = dataset_ui_label\n                    self._skeleton.options = get_allowed_skeleton_labels(dataset_ui_label)\n                    skeleton_label = get_skeleton_display_name(info.skeleton)\n                    if skeleton_label not in self._skeleton.options and self._skeleton.options:\n            
            skeleton_label = self._skeleton.options[0]\n                    self._skeleton.value = skeleton_label\n                    skeleton_key = get_skeleton_key_from_display_name(skeleton_label)\n                    if skeleton_key is None:\n                        return\n                    models = get_models_for_dataset_skeleton(dataset_ui_label, skeleton_key, family=\"Kimodo\")\n                    self._version.options = [m.display_name for m in models]\n                    self._version.value = (\n                        info.display_name if info.display_name in self._version.options else self._version.options[0]\n                    )\n                    self._version.visible = len(models) > 1\n                    self._display.content = f\"**Model:** {self._version.value}\"\n\n            gui_model_selector = ModelSelectorHandle()\n\n        with client.gui.add_folder(\"Examples\", expand_by_default=True):\n            examples_base_dir = demo.get_examples_base_dir(model_name, absolute=True)\n            example_dict = viser_utils.load_example_cases(examples_base_dir)\n            example_names = list(example_dict.keys())\n            if not example_names:\n                example_names = [\"<no examples>\"]\n            gui_examples_dropdown = client.gui.add_dropdown(\n                \"Example\",\n                options=example_names,\n                initial_value=example_names[0],\n            )\n            gui_load_example_button = client.gui.add_button(\n                \"Load Example\",\n                hint=\"Load the selected example.\",\n                disabled=not example_dict,\n            )\n\n            def update_examples_dropdown(\n                new_example_dict: dict[str, str],\n                keep_selection: bool = True,\n            ) -> None:\n                if not new_example_dict:\n                    gui_examples_dropdown.options = [\"<no examples>\"]\n                    gui_examples_dropdown.value = \"<no 
examples>\"\n                    gui_load_example_button.disabled = True\n                    return\n                gui_load_example_button.disabled = False\n                example_names_local = list(new_example_dict.keys())\n                gui_examples_dropdown.options = example_names_local\n                if keep_selection and gui_examples_dropdown.value in example_names_local:\n                    return\n                gui_examples_dropdown.value = example_names_local[0]\n\n        with client.gui.add_folder(\"Generate\", expand_by_default=True):\n            gui_duration = client.gui.add_markdown(content=f\"Total duration: {DEFAULT_CUR_DURATION:.1f} (sec)\")\n\n            def update_duration_gui(duration):\n                gui_duration.content = f\"Total duration: {duration:.1f} (sec)\"\n\n            def compute_prompt_num_frames(prompt_values):\n                \"\"\"Convert timeline prompt bounds to per-prompt frame counts.\n\n                Convention in this demo:\n                - All prompts except the last are treated as [start_frame, end_frame)\n                  (end is exclusive).\n                - The last prompt is treated as [start_frame, end_frame] (end is inclusive).\n                - This assumes the prompts values are sorted by start_frame.\n                \"\"\"\n                if len(prompt_values) == 0:\n                    return []\n                num_frames = []\n                for i, x in enumerate(prompt_values):\n                    cur = x.end_frame - x.start_frame\n                    if i == len(prompt_values) - 1:\n                        cur += 1\n                    num_frames.append(cur)\n                return num_frames\n\n            def update_duration_auto():\n                session = demo.client_sessions[client_id]\n                prompt_values = sorted(\n                    [x for x in timeline._prompts.values()],\n                    key=lambda x: x.start_frame,\n                )\n                
num_frames = compute_prompt_num_frames(prompt_values)\n                total_nb_frames = sum(num_frames)\n                cur_duration = total_nb_frames / session.model_fps\n                set_new_duration(client_id, cur_duration)\n                update_duration_gui(cur_duration)\n\n            gui_num_samples_slider = client.gui.add_slider(\n                \"Num Samples\",\n                min=1,\n                max=10,\n                step=1,\n                initial_value=1,\n                visible=not HF_MODE,\n            )\n\n            gui_use_soma_layer_checkbox = client.gui.add_checkbox(\n                \"SOMA layer\",\n                initial_value=False,\n                visible=\"soma\" in (model_name or \"\"),\n            )\n\n            with client.gui.add_folder(\"Model Parameters\", expand_by_default=False):\n                gui_seed = client.gui.add_number(\"Seed\", initial_value=42)\n\n                with client.gui.add_folder(\"Diffusion\", expand_by_default=False):\n                    gui_diffusion_steps_slider = client.gui.add_slider(\n                        \"Denoising Steps\",\n                        min=2,\n                        max=1000,\n                        step=10,\n                        initial_value=100,\n                    )\n                with client.gui.add_folder(\"Classifier-Free Guidance\", expand_by_default=False):\n                    gui_cfg_checkbox = client.gui.add_checkbox(\n                        \"Enable\",\n                        initial_value=True,\n                        visible=True,\n                    )\n\n                    gui_cfg_text_weight_slider = client.gui.add_slider(\n                        \"Text Weight\",\n                        min=0.0,\n                        max=5.0,\n                        step=0.1,\n                        initial_value=2.0,\n                        visible=True,\n                    )\n                    gui_cfg_constraint_weight_slider = 
client.gui.add_slider(\n                        \"Constraint Weight\",\n                        min=0.0,\n                        max=5.0,\n                        step=0.1,\n                        initial_value=2.0,\n                        visible=True,\n                    )\n                with client.gui.add_folder(\n                    \"Transitions\",\n                    expand_by_default=False,\n                    visible=SHOW_TRANSITION_PARAMS,\n                ):\n                    gui_num_transition_frames_slider = client.gui.add_slider(\n                        \"Transition frames\",\n                        min=1,\n                        max=10,\n                        step=1,\n                        initial_value=NB_TRANSITION_FRAMES,\n                        visible=True,\n                    )\n\n            with client.gui.add_folder(\"Post Processing\", expand_by_default=False):\n                _model_name = model_name or \"\"\n                _postprocess_visible = \"g1\" not in _model_name\n                gui_postprocess_checkbox = client.gui.add_checkbox(\n                    \"Enable\",\n                    initial_value=INIT_POSTPROCESSING,\n                    hint=\"Apply motion post-processing (not available for G1)\",\n                    visible=_postprocess_visible,\n                )\n                gui_root_margin = client.gui.add_number(\n                    \"Root Margin\",\n                    min=0.0,\n                    # max=0.5,\n                    step=0.01,\n                    initial_value=0.04,\n                    hint=\"Margin for root position (meters). 
Lower values pin root closer to target.\",\n                    visible=INIT_POSTPROCESSING and _postprocess_visible,\n                )\n\n                @gui_postprocess_checkbox.on_update\n                def _(event: viser.GuiEvent) -> None:\n                    if get_active_session(event.client) is None:\n                        return\n                    # disable the slider if sharing transition is False\n                    gui_root_margin.visible = gui_postprocess_checkbox.value\n\n                gui_real_robot_rotations_checkbox = client.gui.add_checkbox(\n                    \"Real robot rotations\",\n                    initial_value=False,\n                    hint=\"Project joint rotations to G1 real robot DoF (1-DoF per joint) and clamp to axis limits from the MuJoCo XML.\",\n                    visible=\"g1\" in _model_name,\n                )\n\n            gui_generate_button = client.gui.add_button(\"Generate\", color=\"green\")\n        with client.gui.add_folder(\"Constraints\", expand_by_default=False):\n            gui_gizmo_space_dropdown = client.gui.add_dropdown(\n                \"Gizmo space\",\n                (\"Local\", \"World\"),\n                initial_value=\"Local\",\n                visible=\"g1\" not in _model_name,\n            )\n            gui_edit_constraint_button = client.gui.add_button(\"Enter Editing Mode\")\n            gui_snap_to_constraint_button = client.gui.add_button(\n                \"Snap to Constraint\",\n                disabled=True,\n            )\n            gui_reset_constraint_button = client.gui.add_button(\n                \"Reset Constraint\",\n                disabled=True,\n            )\n            gui_undo_drag_button = client.gui.add_button(\n                \"Undo Move\",\n                disabled=True,\n            )\n\n            with client.gui.add_folder(\"Root 2D Options\", expand_by_default=True):\n                gui_dense_path_checkbox = client.gui.add_checkbox(\n               
     \"Make Smooth Path\",\n                    initial_value=False,\n                    visible=True,\n                )\n\n            gui_show_only_current_constraint_checkbox = client.gui.add_checkbox(\n                \"Show only Current\",\n                initial_value=False,\n                hint=\"Show only constraint overlays at the current frame; uncheck to show all.\",\n            )\n\n            def apply_constraint_overlay_visibility(session: ClientSession) -> None:\n                demo._apply_constraint_overlay_visibility(session)\n\n            @gui_show_only_current_constraint_checkbox.on_update\n            def _(event: viser.GuiEvent) -> None:\n                session = get_active_session(event.client)\n                if session is None:\n                    return\n                session.show_only_current_constraint = gui_show_only_current_constraint_checkbox.value\n                apply_constraint_overlay_visibility(session)\n\n            gui_clear_all_constraints_button = client.gui.add_button(\n                \"Clear All Constraints\",\n                color=\"red\",\n            )\n\n            def has_constraint_at_frame(session: ClientSession, frame_idx: int) -> bool:\n                for constraint_name in [\"Full-Body\", \"End-Effectors\", \"2D Root\"]:\n                    constraint = session.constraints.get(constraint_name)\n                    if constraint is None:\n                        continue\n                    if frame_idx in constraint.keyframes:\n                        return True\n                return False\n\n            def update_snap_to_constraint_button(session: ClientSession) -> None:\n                gui_snap_to_constraint_button.disabled = not has_constraint_at_frame(session, session.frame_idx)\n\n            def ensure_edit_snapshot(session: ClientSession, motion, frame_idx: int) -> None:\n                if session.edit_mode_snapshot is None:\n                    session.edit_mode_snapshot = {}\n    
            if frame_idx in session.edit_mode_snapshot:\n                    return\n                session.edit_mode_snapshot[frame_idx] = {\n                    \"joints_pos\": motion.get_joints_pos(frame_idx),\n                    \"joints_rot\": motion.get_joints_rot(frame_idx),\n                }\n\n            def _update_dense_path(motion, session):\n                constraint_info = session.constraints[\"2D Root\"].get_constraint_info()\n\n                if len(constraint_info[\"frame_idx\"]) > 0:\n                    min_root_frame = min(constraint_info[\"frame_idx\"])\n                    max_root_frame = max(constraint_info[\"frame_idx\"])\n                    motion.set_projected_root_pos_path(\n                        constraint_info[\"root_pos\"][:, [0, 2]],\n                        min_frame_idx=min_root_frame,\n                        max_frame_idx=max_root_frame,\n                    )\n\n            # Delay (ms) after last keyframe/interval move before updating path = \"on release\".\n            DENSE_PATH_AFTER_RELEASE_MS = 300\n\n            def _schedule_dense_path_after_release(session):\n                \"\"\"Schedule a single path update to run after user stops dragging.\"\"\"\n                if \"2D Root\" not in session.constraints or not session.constraints[\"2D Root\"].dense_path:\n                    return\n                tdata = session.timeline_data\n                if tdata.get(\"dense_path_after_release_timer\"):\n                    tdata[\"dense_path_after_release_timer\"].cancel()\n                delay = DENSE_PATH_AFTER_RELEASE_MS / 1000.0\n\n                def run():\n                    if not demo.client_active(client_id):\n                        return\n                    sess = demo.client_sessions[client_id]\n                    tdata[\"dense_path_after_release_timer\"] = None\n                    if \"2D Root\" not in sess.constraints or not sess.constraints[\"2D Root\"].dense_path:\n                        
return\n                    mot = list(sess.motions.values())[0]\n                    _update_dense_path(mot, sess)\n\n                t = threading.Timer(delay, run)\n                tdata[\"dense_path_after_release_timer\"] = t\n                t.start()\n\n            @gui_dense_path_checkbox.on_update\n            def _(event: viser.GuiEvent) -> None:\n                session = get_active_session(event.client)\n                if session is None:\n                    return\n\n                if gui_dense_path_checkbox.value:\n                    # Make sure 0 and max_frame_idx keyframes are added to the constraint\n                    # since dense path should cover full duration for best model performance\n                    root_2d_track = session.timeline_data[\"tracks_ids\"][\"2D Root\"]\n\n                    # add a locked keyframe at 0\n                    start_keyframe_id = client.timeline.add_locked_keyframe(  # noqa\n                        root_2d_track,\n                        0,\n                        opacity=0.0,\n                    )\n                    session.timeline_data[\"keyframes\"][start_keyframe_id] = {\n                        \"frame\": 0,\n                        \"track_id\": root_2d_track,\n                        \"locked\": True,\n                        \"opacity\": 0.0,\n                        \"value\": None,\n                    }\n                    add_constraint_callback(\n                        start_keyframe_id,\n                        \"2D Root\",\n                        (0, 0),\n                        verbose=False,\n                    )\n\n                    # add a locked keyframe at max_frame_idx\n                    end_keyframe_id = client.timeline.add_locked_keyframe(\n                        root_2d_track,\n                        session.max_frame_idx,\n                        opacity=0.0,\n                    )\n                    session.timeline_data[\"keyframes\"][end_keyframe_id] = {\n      
                  \"frame\": session.max_frame_idx,\n                        \"track_id\": root_2d_track,\n                        \"locked\": True,\n                        \"opacity\": 0.0,\n                        \"value\": None,\n                    }\n                    add_constraint_callback(\n                        end_keyframe_id,\n                        \"2D Root\",\n                        (session.max_frame_idx, session.max_frame_idx),\n                        verbose=False,\n                    )\n\n                    # add a locked interval only for visual purposes\n                    locked_interval = client.timeline.add_locked_interval(  # noqa\n                        root_2d_track,\n                        start_frame=0,\n                        end_frame=session.max_frame_idx,\n                    )\n                    session.timeline_data[\"intervals\"][locked_interval] = {\n                        \"track_id\": root_2d_track,\n                        \"start_frame_idx\": 0,\n                        \"end_frame_idx\": session.max_frame_idx,\n                        \"locked\": True,\n                        \"opacity\": 0.3,\n                        \"value\": None,\n                    }\n\n                session.constraints[\"2D Root\"].set_dense_path(gui_dense_path_checkbox.value)\n                if session.constraints[\"2D Root\"].dense_path:\n                    # update the character motion to reflect the full path\n                    # will be full length by construction, no need to specify min/max frame idx\n                    motion = list(session.motions.values())[0]\n                    _update_dense_path(motion, session)\n\n                # remove locked interval and locked keyframes\n                if not gui_dense_path_checkbox.value:\n                    # Get all locked keyframes\n                    keyframes_to_remove = []\n                    for uuid, keyframe in client.timeline._keyframes.items():\n             
           if keyframe.locked:\n                            keyframes_to_remove.append(uuid)\n                            _data = session.timeline_data[\"keyframes\"][uuid]\n                            remove_constraint_callback(\n                                uuid,\n                                constraint_type=session.timeline_data[\"tracks\"][_data[\"track_id\"]][\"name\"],\n                                frame_range=(_data[\"frame\"], _data[\"frame\"]),\n                                verbose=False,\n                            )\n\n                    intervals_to_remove = []\n                    # remove all locked intervals\n                    for uuid, interval in client.timeline._intervals.items():\n                        if interval.locked:\n                            intervals_to_remove.append(uuid)\n\n                    # removing keyframes and intervals\n                    for uuid in keyframes_to_remove:\n                        client.timeline.remove_keyframe(uuid)\n\n                    for uuid in intervals_to_remove:\n                        client.timeline.remove_interval(uuid)\n\n                apply_constraint_overlay_visibility(session)\n\n        with client.gui.add_folder(\n            \"Load/Save\",\n            expand_by_default=False,\n            visible=not HF_MODE,\n        ):\n            with client.gui.add_folder(\"Motion\", expand_by_default=False):\n                gui_save_motion_path_text = client.gui.add_text(\"Save Path\", initial_value=\"output\")\n                gui_save_motion_format_dropdown = client.gui.add_dropdown(\n                    \"Save Format\",\n                    options=(\n                        [\"NPZ\", \"CSV\"]\n                        if \"g1\" in model_name.lower()\n                        else [\"NPZ\", \"AMASS NPZ\"]\n                        if \"smplx\" in model_name.lower()\n                        else [\"NPZ\", \"BVH\"]\n                    ),\n                    
initial_value=\"NPZ\",\n                )\n                gui_save_bvh_standard_tpose_checkbox = client.gui.add_checkbox(\n                    \"Standard T-pose\",\n                    initial_value=False,\n                    hint=\"For BVH export, use the standard T-pose rest skeleton.\",\n                    visible=False,\n                )\n                gui_save_motion_button = client.gui.add_button(\n                    \"Save Motion\",\n                    hint=\"Save the current motion (format + path above)\",\n                )\n                gui_load_motion_path_text = client.gui.add_text(\n                    \"Load Path\",\n                    initial_value=\"output.npz\",\n                    hint=\"SOMA .bvh, Kimodo or AMASS .npz, or G1 MuJoCo .csv\",\n                )\n                gui_load_motion_button = client.gui.add_button(\n                    \"Load Motion\",\n                    hint=\"Load the selected motion\",\n                )\n            with client.gui.add_folder(\"Constraints\", expand_by_default=False):\n                gui_save_constraints_path_text = client.gui.add_text(\n                    \"Save Path\", initial_value=\"output_constraints.json\"\n                )\n                gui_save_constraints_button = client.gui.add_button(\"Save Constraints\")\n                gui_load_constraints_path_text = client.gui.add_text(\n                    \"Load Path\", initial_value=\"output_constraints.json\"\n                )\n                gui_load_constraints_button = client.gui.add_button(\"Load Constraints\")\n            with client.gui.add_folder(\"Example\", expand_by_default=False):\n                gui_save_example_path_text = client.gui.add_text(\n                    \"Save Dir\",\n                    initial_value=os.path.join(\n                        demo.get_examples_base_dir(model_name, absolute=True),\n                        \"custom_example_1\",\n                    ),\n                )\n                
gui_save_example_button = client.gui.add_button(\"Save Example\")\n                gui_load_example_path_text = client.gui.add_text(\n                    \"Load Dir\",\n                    initial_value=os.path.join(\n                        demo.get_examples_base_dir(model_name, absolute=True),\n                        \"custom_example_1\",\n                    ),\n                )\n                gui_load_gt_checkbox = client.gui.add_checkbox(\n                    \"Load GT instead\",\n                    initial_value=False,\n                )\n                gui_load_example_from_path_button = client.gui.add_button(\"Load Example\")\n\n            def _get_primary_motion(session: ClientSession):\n                return list(session.motions.values())[0]\n\n            def _motion_to_numpy_dict(motion) -> dict[str, np.ndarray]:\n                joints_pos = motion.joints_pos.detach().cpu().numpy()\n                joints_rot = motion.joints_rot.detach().cpu().numpy()\n                joints_local_rot = motion.joints_local_rot.detach().cpu().numpy()\n\n                if joints_pos.ndim != 3:\n                    raise ValueError(f\"Expected unbatched joints_pos with shape [T, J, 3], got {joints_pos.shape}\")\n                if joints_rot.ndim != 4:\n                    raise ValueError(f\"Expected unbatched joints_rot with shape [T, J, 3, 3], got {joints_rot.shape}\")\n                if joints_local_rot.ndim != 4:\n                    raise ValueError(\n                        \"Expected unbatched joints_local_rot with shape \" f\"[T, J, 3, 3], got {joints_local_rot.shape}\"\n                    )\n\n                motion_data = {\n                    \"posed_joints\": joints_pos,\n                    \"global_rot_mats\": joints_rot,\n                    \"local_rot_mats\": joints_local_rot,\n                    \"root_positions\": joints_pos[:, motion.skeleton.root_idx, :],\n                }\n                if motion.foot_contacts is not None:\n          
          foot_contacts = motion.foot_contacts.detach().cpu().numpy()\n                    if foot_contacts.ndim != 2:\n                        raise ValueError(\n                            f\"Expected unbatched foot_contacts with shape [T, C], got {foot_contacts.shape}\"\n                        )\n                    motion_data[\"foot_contacts\"] = foot_contacts\n                return motion_data\n\n            def _coerce_save_path(raw_path: str, *, ext: str) -> str:\n                \"\"\"Ensure the save path ends with the correct extension for the chosen format.\"\"\"\n                name = (raw_path or \"\").strip()\n                if name == \"\":\n                    return f\"output{ext}\"\n                known_exts = (\".npz\", \".bvh\", \".csv\")\n                if name.lower().endswith(known_exts):\n                    return os.path.splitext(name)[0] + ext\n                if os.path.splitext(name)[1] == \"\":\n                    return name + ext\n                return name\n\n            def save_motion(client, save_path, fmt):\n                session = demo.client_sessions[client.client_id]\n                motion = _get_primary_motion(session)\n                motion_data = _motion_to_numpy_dict(motion)\n\n                if fmt == \"BVH\":\n                    save_path = _coerce_save_path(save_path, ext=\".bvh\")\n                    save_motion_bvh(\n                        save_path,\n                        motion.joints_local_rot,\n                        motion.joints_pos[:, session.skeleton.root_idx, :],\n                        skeleton=session.skeleton,\n                        fps=float(session.model_fps),\n                        standard_tpose=bool(gui_save_bvh_standard_tpose_checkbox.value),\n                    )\n                elif fmt == \"CSV\":\n                    save_path = _coerce_save_path(save_path, ext=\".csv\")\n                    data = g1_csv_to_bytes(motion_data, session.skeleton, demo.device)\n            
        with open(save_path, \"wb\") as f:\n                        f.write(data)\n                elif fmt == \"AMASS NPZ\":\n                    save_path = _coerce_save_path(save_path, ext=\".npz\")\n                    data = amass_npz_to_bytes(motion_data, session.skeleton, session.model_fps)\n                    with open(save_path, \"wb\") as f:\n                        f.write(data)\n                else:\n                    save_path = _coerce_save_path(save_path, ext=\".npz\")\n                    save_kimodo_npz(save_path, motion_data)\n                return save_path\n\n            @gui_save_motion_button.on_click\n            def _(event: viser.GuiEvent) -> None:\n                event_client = event.client\n                if get_active_session(event_client) is None:\n                    return\n\n                raw_path = gui_save_motion_path_text.value\n                fmt = str(gui_save_motion_format_dropdown.value).upper()\n                try:\n                    saved_path = save_motion(event_client, raw_path, fmt)\n                    event_client.add_notification(\n                        title=\"Motion saved!\",\n                        body=f\"Saved motion to {saved_path}\",\n                        auto_close_seconds=5.0,\n                        color=\"green\",\n                    )\n                except Exception as e:\n                    import traceback\n\n                    traceback.print_exc()\n                    event_client.add_notification(\n                        title=\"Failed to save motion!\",\n                        body=str(e),\n                        auto_close_seconds=5.0,\n                        color=\"red\",\n                    )\n\n            def load_motion(client, load_path):\n                session = demo.client_sessions[client.client_id]\n\n                fps_arg = session.model_fps if session.model_fps and session.model_fps > 0 else None\n                motion_dict, num_joints_motion = 
load_motion_file(load_path, target_fps=fps_arg)\n\n                target_skel = registry_skeleton_for_joint_count(num_joints_motion)\n                current_info = get_model_info(session.model_name)\n                current_skel = current_info.skeleton if current_info is not None else None\n\n                if current_skel != target_skel:\n                    dataset = current_info.dataset if current_info is not None else \"RP\"\n                    new_key = kimodo_short_key_for_skeleton_dataset(target_skel, dataset)\n                    if new_key is None:\n                        new_key = kimodo_short_key_for_skeleton_dataset(target_skel, \"RP\")\n                    if new_key is None:\n                        raise ValueError(\n                            f\"No Kimodo model found for skeleton {target_skel} (motion has J={num_joints_motion}).\"\n                        )\n                    if new_key != session.model_name:\n                        gui_model_selector.set_from_short_key(new_key)\n                        apply_model_selection(new_key)\n                        _update_visibility_for_loaded_model(new_key)\n                        client.add_notification(\n                            title=\"Model switched\",\n                            body=f\"Switched to {new_key} to match loaded motion (J={num_joints_motion}).\",\n                            auto_close_seconds=5.0,\n                            color=\"blue\",\n                        )\n                    session = demo.client_sessions[client.client_id]\n\n                joints_pos = motion_dict[\"posed_joints\"].to(device=demo.device, dtype=torch.float32)\n                joints_rot = motion_dict[\"global_rot_mats\"].to(device=demo.device, dtype=torch.float32)\n                foot_contacts = motion_dict.get(\"foot_contacts\")\n                if foot_contacts is not None:\n                    foot_contacts = foot_contacts.to(device=demo.device, dtype=torch.float32)\n\n                # 
Support both batched [B, T, J, 3] and unbatched [T, J, 3]; take first sample if batched\n                if joints_pos.ndim == 4:\n                    joints_pos = joints_pos[0]\n                if joints_rot.ndim == 5:\n                    joints_rot = joints_rot[0]\n                if foot_contacts is not None and foot_contacts.ndim == 3:\n                    foot_contacts = foot_contacts[0]\n\n                # Motion must match the current model's skeleton after auto-switch\n                num_joints_loaded = joints_pos.shape[1]\n                num_joints_skeleton = session.skeleton.nbjoints\n                if num_joints_loaded != num_joints_skeleton:\n                    # Backward compat: expand 30-joint SOMA motion to 77\n                    if (\n                        num_joints_loaded == 30\n                        and num_joints_skeleton == 77\n                        and isinstance(session.skeleton, SOMASkeleton77)\n                    ):\n                        from kimodo.skeleton import global_rots_to_local_rots\n\n                        skel30 = SOMASkeleton30().to(demo.device)\n                        if \"local_rot_mats\" in motion_dict:\n                            local_rot_30 = motion_dict[\"local_rot_mats\"].to(device=demo.device, dtype=torch.float32)\n                            if local_rot_30.ndim == 4:\n                                local_rot_30 = local_rot_30[0]\n                        else:\n                            local_rot_30 = global_rots_to_local_rots(joints_rot, skel30)\n                        local_rot_77 = skel30.to_SOMASkeleton77(local_rot_30)\n                        root_positions = joints_pos[:, skel30.root_idx, :]\n                        joints_rot, joints_pos, _ = session.skeleton.fk(local_rot_77, root_positions)\n\n                        if foot_contacts is not None and foot_contacts.shape[-1] == 4:\n                            foot_contacts = torch.cat(\n                                [\n                   
                 foot_contacts[..., :2],\n                                    foot_contacts[..., 1:2],\n                                    foot_contacts[..., 2:4],\n                                    foot_contacts[..., 3:4],\n                                ],\n                                dim=-1,\n                            )\n                    else:\n                        raise ValueError(\n                            f\"The loaded motion has {num_joints_loaded} joints but the current model \"\n                            f\"({session.model_name}) has {num_joints_skeleton} joints. \"\n                            \"Load a motion generated with the same skeleton, or switch the model to match the motion.\"\n                        )\n                elif joints_rot.shape[1] != num_joints_skeleton:\n                    raise ValueError(\n                        f\"Rotation data has {joints_rot.shape[1]} joints but the current model has \"\n                        f\"{num_joints_skeleton} joints. 
The NPZ may be corrupted or from a different skeleton.\"\n                    )\n\n                # Apply G1 real robot projection (1-DoF per joint + axis limits) if enabled.\n                if (\n                    \"g1\" in session.model_name\n                    and isinstance(session.skeleton, G1Skeleton34)\n                    and gui_real_robot_rotations_checkbox.value\n                ):\n                    joints_pos, joints_rot = generation.apply_g1_real_robot_projection(\n                        session.skeleton, joints_pos, joints_rot\n                    )\n\n                # Update duration and frame range based on loaded motion\n                num_frames = joints_pos.shape[0]\n                duration = num_frames / session.model_fps\n\n                # Update GUI elements\n                session.cur_duration = duration\n                session.max_frame_idx = num_frames - 1\n\n                # Clear existing motions and add the loaded one\n                demo.clear_motions(client.client_id)\n                demo.add_character_motion(\n                    client,\n                    session.skeleton,\n                    joints_pos,\n                    joints_rot,\n                    foot_contacts,\n                )\n\n                # Reset to frame 0\n                demo.set_frame(client.client_id, 0)\n\n            @gui_load_motion_button.on_click\n            def _(event: viser.GuiEvent) -> None:\n                event_client = event.client\n                session = get_active_session(event_client)\n                if session is None:\n                    return\n\n                load_path = gui_load_motion_path_text.value\n                loading_notif = event_client.add_notification(\n                    title=\"Loading motion...\",\n                    body=f\"Loading from {load_path}\",\n                    loading=True,\n                    with_close_button=False,\n                    auto_close_seconds=None,\n              
  )\n                try:\n                    load_motion(event_client, load_path)\n\n                    loading_notif.title = \"Motion loaded!\"\n                    loading_notif.body = f\"Loaded motion from {load_path} ({session.max_frame_idx + 1} frames, {session.cur_duration:.2f}s)\"\n                    loading_notif.loading = False\n                    loading_notif.with_close_button = True\n                    loading_notif.auto_close_seconds = 5.0\n                    loading_notif.color = \"green\"\n                except Exception as e:\n                    import traceback\n\n                    traceback.print_exc()\n                    loading_notif.title = \"Failed to load motion!\"\n                    loading_notif.body = str(e)\n                    loading_notif.loading = False\n                    loading_notif.with_close_button = True\n                    loading_notif.auto_close_seconds = 10.0\n                    loading_notif.color = \"red\"\n\n            def save_constraints(client, save_path):\n                session = demo.client_sessions[client.client_id]\n                # Keep save behavior aligned with demo frame convention:\n                # valid frame indices are [0, max_frame_idx], so count is +1.\n                num_frames = session.max_frame_idx + 1\n                model_bundle = demo.load_model(session.model_name)\n                constraints_lst = demo.compute_model_constraints_lst(session, model_bundle, num_frames)\n                save_constraints_lst(save_path, constraints_lst)\n\n            @gui_save_constraints_button.on_click\n            def _(event: viser.GuiEvent) -> None:\n                event_client = event.client\n                if get_active_session(event_client) is None:\n                    return\n\n                try:\n                    save_path = gui_save_constraints_path_text.value\n                    save_constraints(event_client, save_path)\n                    
event_client.add_notification(\n                        title=\"Constraints saved!\",\n                        body=f\"Saved constraints to {save_path}\",\n                        auto_close_seconds=5.0,\n                        color=\"green\",\n                    )\n                except Exception as e:\n                    import traceback\n\n                    traceback.print_exc()\n                    event_client.add_notification(\n                        title=\"Failed to save constraints!\",\n                        body=str(e),\n                        auto_close_seconds=10.0,\n                        color=\"red\",\n                    )\n\n            def load_constraints(client, load_path):\n                session = demo.client_sessions[client.client_id]\n                constraints_lst = load_constraints_lst(load_path, skeleton=session.skeleton)\n\n                # Clear existing constraints first\n                with session.timeline_data[\"keyframe_update_lock\"]:\n                    for constraint in list(session.constraints.values()):\n                        constraint.clear()\n                    client.timeline.clear_keyframes()\n                    client.timeline.clear_intervals()\n\n                # Add loaded constraints to the session\n                # We need to directly add constraint data, not read from current motion\n                device = demo.device\n                for constraint_obj in constraints_lst:\n                    constraint_type = constraint_obj.name\n\n                    # decompose the frame indices into intervals or single keyframes\n                    frame_indices = constraint_obj.frame_indices\n                    (\n                        intervals,\n                        intervals_indices,\n                        single_frames,\n                        single_frames_indices,\n                    ) = extract_intervals_and_singles(frame_indices)\n\n                    load_targets: list[dict] = []\n 
                   root_pos = None\n\n                    if constraint_type == \"root2d\":\n                        # smooth_root_2d is [T, 2] (x, z), convert to [T, 3] (x, 0, z)\n                        num_frames = constraint_obj.smooth_root_2d.shape[0]\n                        root_pos = torch.zeros(num_frames, 3, device=device)\n                        root_pos[:, 0] = constraint_obj.smooth_root_2d[:, 0]\n                        root_pos[:, 2] = constraint_obj.smooth_root_2d[:, 1]\n                        load_targets = [\n                            {\n                                \"track_name\": \"2D Root\",\n                                \"constraint_track\": session.constraints[\"2D Root\"],\n                            }\n                        ]\n                    elif constraint_type == \"fullbody\":\n                        load_targets = [\n                            {\n                                \"track_name\": \"Full-Body\",\n                                \"constraint_track\": session.constraints[\"Full-Body\"],\n                            }\n                        ]\n                    elif constraint_type in {\n                        \"left-hand\",\n                        \"right-hand\",\n                        \"left-foot\",\n                        \"right-foot\",\n                    }:\n                        track_name = {\n                            \"left-hand\": \"Left Hand\",\n                            \"right-hand\": \"Right Hand\",\n                            \"left-foot\": \"Left Foot\",\n                            \"right-foot\": \"Right Foot\",\n                        }[constraint_type]\n                        load_targets = [\n                            {\n                                \"track_name\": track_name,\n                                \"constraint_track\": session.constraints[\"End-Effectors\"],\n                                \"joint_names\": constraint_obj.joint_names,\n                 
               \"end_effector_type\": constraint_type,\n                            }\n                        ]\n                    elif constraint_type in {\"end-effector\", \"end-effectors\"}:\n                        # Backward-compatible loader:\n                        # split a generic end-effector constraint into per-limb timeline tracks.\n                        joint_names_set = set(constraint_obj.joint_names)\n                        for jname, track_name, eff_type in [\n                            (\"LeftHand\", \"Left Hand\", \"left-hand\"),\n                            (\"RightHand\", \"Right Hand\", \"right-hand\"),\n                            (\"LeftFoot\", \"Left Foot\", \"left-foot\"),\n                            (\"RightFoot\", \"Right Foot\", \"right-foot\"),\n                        ]:\n                            if jname not in joint_names_set:\n                                continue\n                            target_joint_names = [jname]\n                            if \"Hips\" in joint_names_set:\n                                target_joint_names.append(\"Hips\")\n                            load_targets.append(\n                                {\n                                    \"track_name\": track_name,\n                                    \"constraint_track\": session.constraints[\"End-Effectors\"],\n                                    \"joint_names\": target_joint_names,\n                                    \"end_effector_type\": eff_type,\n                                }\n                            )\n                        if not load_targets:\n                            raise KeyError(\n                                \"No recognized end-effector joint in constraint \"\n                                f\"joint_names={constraint_obj.joint_names}\"\n                            )\n                    else:\n                        raise KeyError(f\"Unsupported constraint type in loader: {constraint_type}\")\n\n         
           for target in load_targets:\n                        track_id = session.timeline_data[\"tracks_ids\"][target[\"track_name\"]]\n                        constraint_track = target[\"constraint_track\"]\n\n                        # add intervals\n                        for (start_idx, end_idx), (start_idx_t, end_idx_t) in zip(intervals, intervals_indices):\n                            # Add to timeline\n                            interval_id = client.timeline.add_interval(track_id, start_idx, end_idx)\n                            session.timeline_data[\"intervals\"][interval_id] = {\n                                \"track_id\": track_id,\n                                \"start_frame_idx\": start_idx,\n                                \"end_frame_idx\": end_idx,\n                                \"locked\": False,\n                                \"opacity\": 1.0,\n                                \"value\": None,\n                            }\n                            if constraint_type == \"root2d\":\n                                constraint_track.add_interval(\n                                    interval_id,\n                                    start_idx,\n                                    end_idx,\n                                    root_pos[start_idx_t : end_idx_t + 1],\n                                )\n                            elif constraint_type == \"fullbody\":\n                                constraint_track.add_interval(\n                                    interval_id,\n                                    start_idx,\n                                    end_idx,\n                                    constraint_obj.global_joints_positions[start_idx_t : end_idx_t + 1],\n                                    constraint_obj.global_joints_rots[start_idx_t : end_idx_t + 1],\n                                )\n                            else:\n                                constraint_track.add_interval(\n                                   
 interval_id,\n                                    start_idx,\n                                    end_idx,\n                                    constraint_obj.global_joints_positions[start_idx_t : end_idx_t + 1],\n                                    constraint_obj.global_joints_rots[start_idx_t : end_idx_t + 1],\n                                    target[\"joint_names\"],\n                                    target[\"end_effector_type\"],\n                                )\n\n                        # add keyframes\n                        for frame, frame_t in zip(single_frames, single_frames_indices):\n                            # Add to timeline\n                            keyframe_id = client.timeline.add_keyframe(track_id, frame)\n                            session.timeline_data[\"keyframes\"][keyframe_id] = {\n                                \"track_id\": track_id,\n                                \"frame\": frame,\n                                \"locked\": False,\n                                \"opacity\": 1.0,\n                                \"value\": None,\n                            }\n                            if constraint_type == \"root2d\":\n                                constraint_track.add_keyframe(\n                                    keyframe_id,\n                                    frame,\n                                    root_pos[frame_t],\n                                )\n                            elif constraint_type == \"fullbody\":\n                                constraint_track.add_keyframe(\n                                    keyframe_id,\n                                    frame,\n                                    constraint_obj.global_joints_positions[frame_t],\n                                    constraint_obj.global_joints_rots[frame_t],\n                                )\n                            else:\n                                constraint_track.add_keyframe(\n                                    
keyframe_id,\n                                    frame,\n                                    constraint_obj.global_joints_positions[frame_t],\n                                    constraint_obj.global_joints_rots[frame_t],\n                                    target[\"joint_names\"],\n                                    target[\"end_effector_type\"],\n                                )\n\n            @gui_load_constraints_button.on_click\n            def _(event: viser.GuiEvent) -> None:\n                event_client = event.client\n                if get_active_session(event_client) is None:\n                    return\n\n                try:\n                    load_path = gui_load_constraints_path_text.value\n                    load_constraints(event_client, load_path)\n                    session = demo.client_sessions[event_client.client_id]\n                    apply_constraint_overlay_visibility(session)\n\n                    event_client.add_notification(\n                        title=\"Constraints loaded!\",\n                        body=f\"Loaded constraints from {load_path}\",\n                        auto_close_seconds=5.0,\n                        color=\"green\",\n                    )\n                except Exception as e:\n                    import traceback\n\n                    traceback.print_exc()\n                    event_client.add_notification(\n                        title=\"Failed to load constraints!\",\n                        body=str(e),\n                        auto_close_seconds=10.0,\n                        color=\"red\",\n                    )\n\n        with client.gui.add_folder(\"Exports\", expand_by_default=False):\n            with client.gui.add_folder(\"Screenshot\", expand_by_default=False, visible=not HF_MODE):\n                gui_screenshot_path_text = client.gui.add_text(\n                    \"Save Path\",\n                    initial_value=\"render.png\",\n                    hint=\"Filename for the 
screenshot (PNG).\",\n                )\n                gui_screenshot_button = client.gui.add_button(\n                    \"Download Screenshot\",\n                    hint=\"Capture the current canvas and download a PNG.\",\n                )\n            with client.gui.add_folder(\"Video\", expand_by_default=False, visible=not HF_MODE):\n                gui_video_path_text = client.gui.add_text(\n                    \"Save Path\",\n                    initial_value=\"render.mp4\",\n                    hint=\"Filename for the video (MP4).\",\n                )\n                gui_video_button = client.gui.add_button(\n                    \"Download Video\",\n                    hint=\"Render every frame and download as MP4.\",\n                )\n            with client.gui.add_folder(\"Motion\", expand_by_default=True):\n                gui_download_name_text = client.gui.add_text(\n                    \"Name\",\n                    initial_value=\"output\",\n                    hint=\"Base filename to save as (extension will be added based on format if omitted).\",\n                )\n                gui_download_format_dropdown = client.gui.add_dropdown(\n                    \"Format\",\n                    options=(\n                        [\"NPZ\", \"CSV\"]\n                        if \"g1\" in model_name.lower()\n                        else [\"NPZ\", \"AMASS NPZ\"]\n                        if \"smplx\" in model_name.lower()\n                        else [\"NPZ\", \"BVH\"]\n                    ),\n                    initial_value=\"NPZ\",\n                )\n                gui_download_bvh_standard_tpose_checkbox = client.gui.add_checkbox(\n                    \"Standard T-pose\",\n                    initial_value=False,\n                    hint=\"For BVH export, use the standard T-pose rest skeleton.\",\n                    visible=False,\n                )\n                gui_download_button = client.gui.add_button(\n                    
\"Download\",\n                    hint=\"Download the current motion (format + name above).\",\n                )\n\n            def _download_bytes_to_browser(\n                event_client: viser.ClientHandle,\n                *,\n                data: bytes,\n                filename: str,\n                mime_type: str = \"application/octet-stream\",\n            ) -> None:\n                \"\"\"Trigger a browser download for an in-memory byte payload.\n\n                Important: this intentionally does NOT use `showSaveFilePicker()` to avoid\n                Chrome/Edge's file-write permission prompt (\"this site can see edits you make\").\n                If you want \"always ask where to save\", configure your browser download settings.\n                \"\"\"\n                import base64\n                import json\n\n                # Base64 is the most robust way to move binary over our websocket JS channel.\n                b64 = base64.b64encode(data).decode(\"ascii\")\n                js = f\"\"\"\n(() => {{\n  const filename = {json.dumps(filename)};\n  const mimeType = {json.dumps(mime_type)};\n  const b64 = {json.dumps(b64)};\n\n  // Decode base64 -> Uint8Array.\n  const binStr = atob(b64);\n  const bytes = new Uint8Array(binStr.length);\n  for (let i = 0; i < binStr.length; i++) bytes[i] = binStr.charCodeAt(i);\n  const blob = new Blob([bytes], {{ type: mimeType }});\n\n  // Standard browser download behavior.\n  const url = URL.createObjectURL(blob);\n  const a = document.createElement(\"a\");\n  a.href = url;\n  a.download = filename;\n  document.body.appendChild(a);\n  a.click();\n  a.remove();\n  URL.revokeObjectURL(url);\n}})();\n\"\"\"\n                # Reuse viser’s JS execution mechanism (used for Plotly setup).\n                from viser import _messages as _viser_messages\n\n                event_client.gui._websock_interface.queue_message(  # type: ignore[attr-defined]\n                    
_viser_messages.RunJavascriptMessage(source=js)\n                )\n\n            def _motion_to_npz_bytes(motion) -> bytes:\n                motion_data = _motion_to_numpy_dict(motion)\n                return kimodo_npz_to_bytes(motion_data)\n\n            def _motion_to_csv_bytes(motion, session: ClientSession) -> bytes:\n                motion_data = _motion_to_numpy_dict(motion)\n                return g1_csv_to_bytes(motion_data, session.skeleton, demo.device)\n\n            def _motion_to_amass_npz_bytes(motion, session: ClientSession) -> bytes:\n                motion_data = _motion_to_numpy_dict(motion)\n                return amass_npz_to_bytes(motion_data, session.skeleton, session.model_fps)\n\n            def _get_motion_export_formats(loaded_model_name: str) -> list[str]:\n                model_name_lower = (loaded_model_name or \"\").lower()\n                if \"g1\" in model_name_lower:\n                    return [\"NPZ\", \"CSV\"]\n                if \"smplx\" in model_name_lower:\n                    return [\"NPZ\", \"AMASS NPZ\"]\n                return [\"NPZ\", \"BVH\"]\n\n            def _update_format_dropdown(dropdown, loaded_model_name: str) -> None:\n                new_options = _get_motion_export_formats(loaded_model_name)\n                current_value = str(dropdown.value)\n                dropdown.options = new_options\n                dropdown.value = current_value if current_value in new_options else new_options[0]\n\n            def _update_motion_export_dropdown(loaded_model_name: str) -> None:\n                _update_format_dropdown(gui_download_format_dropdown, loaded_model_name)\n                _update_format_dropdown(gui_save_motion_format_dropdown, loaded_model_name)\n                _update_bvh_standard_tpose_visibility()\n\n            def _update_bvh_standard_tpose_visibility() -> None:\n                gui_save_bvh_standard_tpose_checkbox.visible = (\n                    
str(gui_save_motion_format_dropdown.value).upper() == \"BVH\"\n                )\n                gui_download_bvh_standard_tpose_checkbox.visible = (\n                    str(gui_download_format_dropdown.value).upper() == \"BVH\"\n                )\n\n            @gui_save_motion_format_dropdown.on_update\n            def _(_event: viser.GuiEvent) -> None:\n                _update_bvh_standard_tpose_visibility()\n\n            @gui_download_format_dropdown.on_update\n            def _(_event: viser.GuiEvent) -> None:\n                _update_bvh_standard_tpose_visibility()\n\n            def _coerce_download_filename(raw_name: str, *, ext: str) -> str:\n                \"\"\"Coerce a user-entered filename to a safe basename with the desired extension.\n\n                - If empty: uses \"output{ext}\"\n                - If no extension: appends ext\n                - If endswith a known export extension: rewrites extension to ext (prevents mismatches)\n                - Any provided directory components are stripped\n                \"\"\"\n                import os\n\n                name = (raw_name or \"\").strip()\n                name = os.path.basename(name.replace(\"\\\\\", \"/\"))\n                if name == \"\":\n                    return f\"output{ext}\"\n\n                known_exts = (\".npz\", \".bvh\", \".csv\", \".png\", \".mp4\")\n                lower = name.lower()\n                if lower.endswith(known_exts):\n                    return os.path.splitext(name)[0] + ext\n\n                root, cur_ext = os.path.splitext(name)\n                if cur_ext == \"\":\n                    return name + ext\n                return name\n\n            def _get_render_size(event_client: viser.ClientHandle) -> tuple[int, int]:\n                width = int(event_client.camera.image_width)\n                height = int(event_client.camera.image_height)\n                if width <= 0 or height <= 0:\n                    # Fall back to a reasonable 
default if the camera hasn't synced yet.\n                    return (1280, 720)\n                return (width, height)\n\n            def _round_up_to_multiple(value: int, multiple: int) -> int:\n                if multiple <= 0:\n                    return value\n                return ((value + multiple - 1) // multiple) * multiple\n\n            def _download_canvas_to_browser(event_client: viser.ClientHandle, *, filename: str) -> None:\n                \"\"\"Use the client-side canvas save path to avoid server-side renders.\"\"\"\n                import json\n\n                js = f\"\"\"\n(() => {{\n  const filename = {json.dumps(filename)};\n  const canvases = Array.from(document.querySelectorAll(\"canvas\"));\n  if (!canvases.length) {{\n    console.error(\"No canvases found to save.\");\n    return;\n  }}\n  // Pick the largest canvas by area (usually the main 3D view).\n  const canvas = canvases.reduce((best, cur) => {{\n    const bestArea = (best?.width || 0) * (best?.height || 0);\n    const curArea = (cur?.width || 0) * (cur?.height || 0);\n    return curArea > bestArea ? 
cur : best;\n  }}, null);\n  if (!canvas) {{\n    console.error(\"No canvas selected to save.\");\n    return;\n  }}\n  canvas.toBlob((blob) => {{\n    if (!blob) {{\n      console.error(\"Export failed\");\n      return;\n    }}\n    const url = URL.createObjectURL(blob);\n    const a = document.createElement(\"a\");\n    a.href = url;\n    a.download = filename;\n    document.body.appendChild(a);\n    a.click();\n    a.remove();\n    URL.revokeObjectURL(url);\n  }}, \"image/png\");\n}})();\n\"\"\"\n                from viser import _messages as _viser_messages\n\n                event_client.gui._websock_interface.queue_message(  # type: ignore[attr-defined]\n                    _viser_messages.RunJavascriptMessage(source=js)\n                )\n\n            @gui_screenshot_button.on_click\n            def _(event: viser.GuiEvent) -> None:\n                event_client = event.client\n                if get_active_session(event_client) is None:\n                    return\n\n                try:\n                    filename = _coerce_download_filename(\n                        str(gui_screenshot_path_text.value),\n                        ext=\".png\",\n                    )\n                    _download_canvas_to_browser(event_client, filename=filename)\n                    event_client.add_notification(\n                        title=\"Screenshot download started\",\n                        body=f\"Saving {filename}\",\n                        auto_close_seconds=5.0,\n                        color=\"green\",\n                    )\n                except Exception as e:\n                    import traceback\n\n                    traceback.print_exc()\n                    event_client.add_notification(\n                        title=\"Failed to download screenshot!\",\n                        body=str(e),\n                        auto_close_seconds=10.0,\n                        color=\"red\",\n                    )\n\n            @gui_video_button.on_click\n 
           def _(event: viser.GuiEvent) -> None:\n                event_client = event.client\n                session = get_active_session(event_client)\n                if session is None:\n                    return\n                recording_notification: viser.NotificationHandle | None = None\n                try:\n                    recording_notification = event_client.add_notification(\n                        title=\"Recording video...\",\n                        body=\"Saving frames, please wait.\",\n                        loading=True,\n                        with_close_button=False,\n                        auto_close_seconds=None,\n                        color=\"blue\",\n                    )\n                    event_client.timeline.disable_constraints()\n                    width, height = _get_render_size(event_client)\n                    # Avoid ffmpeg macro block resizing warnings.\n                    width = _round_up_to_multiple(width, 16)\n                    height = _round_up_to_multiple(height, 16)\n                    original_frame = session.frame_idx\n                    frames = []\n                    for frame_idx in range(session.max_frame_idx + 1):\n                        demo.set_frame(\n                            event_client.client_id,\n                            frame_idx,\n                            update_timeline=True,\n                        )\n                        frames.append(\n                            event_client.get_render(\n                                height=height,\n                                width=width,\n                                transport_format=\"jpeg\",\n                            )\n                        )\n\n                    # Restore the original frame (and timeline).\n                    demo.set_frame(event_client.client_id, original_frame)\n\n                    import imageio.v3 as iio\n\n                    filename = _coerce_download_filename(\n                      
  str(gui_video_path_text.value),\n                        ext=\".mp4\",\n                    )\n                    payload = iio.imwrite(\n                        \"<bytes>\",\n                        frames,\n                        extension=\".mp4\",\n                        fps=float(session.model_fps),\n                        codec=\"h264\",\n                        plugin=\"pyav\",\n                    )\n                    event_client.send_file_download(filename, payload, save_immediately=True)\n                    event_client.add_notification(\n                        title=\"Video download started\",\n                        body=f\"Saving {filename}\",\n                        auto_close_seconds=5.0,\n                        color=\"green\",\n                    )\n                except Exception as e:\n                    import traceback\n\n                    traceback.print_exc()\n                    event_client.add_notification(\n                        title=\"Failed to download video!\",\n                        body=str(e),\n                        auto_close_seconds=10.0,\n                        color=\"red\",\n                    )\n                finally:\n                    event_client.timeline.enable_constraints()\n                    if recording_notification is not None:\n                        recording_notification.remove()\n\n            @gui_download_button.on_click\n            def _(event: viser.GuiEvent) -> None:\n                event_client = event.client\n                session = get_active_session(event_client)\n                if session is None:\n                    return\n                motion = _get_primary_motion(session)\n                try:\n                    fmt = str(gui_download_format_dropdown.value).upper()\n                    raw_name = str(gui_download_name_text.value)\n\n                    if fmt == \"BVH\":\n                        filename = _coerce_download_filename(raw_name, ext=\".bvh\")\n 
                       payload = motion_to_bvh_bytes(\n                            motion.joints_local_rot,\n                            motion.joints_pos[:, session.skeleton.root_idx, :],  # root positions\n                            skeleton=session.skeleton,\n                            fps=float(session.model_fps),\n                            standard_tpose=bool(gui_download_bvh_standard_tpose_checkbox.value),\n                        )\n                        mime = \"text/plain\"\n                    elif fmt == \"CSV\":\n                        filename = _coerce_download_filename(raw_name, ext=\".csv\")\n                        payload = _motion_to_csv_bytes(motion, session)\n                        mime = \"text/csv\"\n                    elif fmt == \"AMASS NPZ\":\n                        filename = _coerce_download_filename(raw_name, ext=\".npz\")\n                        payload = _motion_to_amass_npz_bytes(motion, session)\n                        mime = \"application/octet-stream\"\n                    else:\n                        # Default to NPZ (most common and matches existing save/load).\n                        filename = _coerce_download_filename(raw_name, ext=\".npz\")\n                        payload = _motion_to_npz_bytes(motion)\n                        mime = \"application/octet-stream\"\n\n                    _download_bytes_to_browser(\n                        event_client,\n                        data=payload,\n                        filename=filename,\n                        mime_type=mime,\n                    )\n\n                    event_client.add_notification(\n                        title=\"Download started\",\n                        body=f\"Saving {filename}\",\n                        auto_close_seconds=5.0,\n                        color=\"green\",\n                    )\n                except Exception as e:\n                    import traceback\n\n                    traceback.print_exc()\n                    
event_client.add_notification(\n                        title=\"Failed to download motion!\",\n                        body=str(e),\n                        auto_close_seconds=10.0,\n                        color=\"red\",\n                    )\n\n        @gui_save_example_button.on_click\n        def _(event: viser.GuiEvent) -> None:\n            from kimodo.tools import save_json\n\n            event_client = event.client\n            session = get_active_session(event_client)\n            if session is None:\n                return\n\n            save_dir = gui_save_example_path_text.value\n            if os.path.exists(save_dir):\n                event_client.add_notification(\n                    title=\"Failed to save example!\",\n                    body=\"Example directory already exists\",\n                    auto_close_seconds=10.0,\n                    color=\"red\",\n                )\n                return\n\n            try:\n                os.makedirs(save_dir)\n                # save the constraints\n                constraint_path = os.path.join(save_dir, \"constraints.json\")\n                save_constraints(event_client, constraint_path)\n                # save the motion\n                motion_path = os.path.join(save_dir, \"motion.npz\")\n                save_motion(event_client, motion_path, \"NPZ\")\n                # save the gui metadata\n                meta_path = os.path.join(save_dir, \"meta.json\")\n                prompt_texts = []\n                prompt_durations_sec = []\n                prompt_values = sorted(\n                    [x for x in client.timeline._prompts.values()],\n                    key=lambda x: x.start_frame,\n                )\n                for i, prompt in enumerate(prompt_values):\n                    prompt_texts.append(prompt.text)\n                    # Match demo/generation convention:\n                    # non-last prompts: [start, end) ; last prompt: [start, end].\n                    n_frames = 
prompt.end_frame - prompt.start_frame\n                    if i == len(prompt_values) - 1:\n                        n_frames += 1\n                    prompt_durations_sec.append(n_frames / session.model_fps)\n                if len(prompt_texts) == 1:\n                    meta_info = {\n                        \"text\": prompt_texts[0],\n                        \"duration\": prompt_durations_sec[0],\n                    }\n                else:\n                    meta_info = {\n                        \"texts\": prompt_texts,\n                        \"durations\": prompt_durations_sec,\n                    }\n                meta_info[\"num_samples\"] = gui_num_samples_slider.value\n                meta_info[\"seed\"] = gui_seed.value\n                meta_info[\"diffusion_steps\"] = gui_diffusion_steps_slider.value\n                meta_info[\"cfg\"] = {\n                    \"enabled\": gui_cfg_checkbox.value,\n                    \"text_weight\": gui_cfg_text_weight_slider.value,\n                    \"constraint_weight\": gui_cfg_constraint_weight_slider.value,\n                }\n                save_json(meta_path, meta_info)\n\n                # update the example dropdown\n                session.example_dict = viser_utils.load_example_cases(session.examples_base_dir)\n                update_examples_dropdown(session.example_dict, keep_selection=True)\n\n                event_client.add_notification(\n                    title=\"Example saved!\",\n                    body=f\"Saved example to {save_dir}\",\n                    auto_close_seconds=5.0,\n                    color=\"green\",\n                )\n            except Exception as e:\n                import traceback\n\n                traceback.print_exc()\n                event_client.add_notification(\n                    title=\"Failed to save example!\",\n                    body=str(e),\n                    auto_close_seconds=10.0,\n                    color=\"red\",\n                )\n\n  
      def set_new_duration(client_id, new_duration):\n            session = demo.client_sessions[client_id]\n            session.cur_duration = new_duration\n            update_duration_gui(new_duration)\n            session.max_frame_idx = int(session.cur_duration * session.model_fps - 1)\n            if session.frame_idx > session.max_frame_idx:\n                demo.set_frame(client_id, session.max_frame_idx)\n\n        def apply_model_selection(new_model_name: str) -> None:\n            session = demo.client_sessions[client_id]\n            if new_model_name == session.model_name:\n                return\n\n            session.playing = False  # Pause playback when switching models.\n\n            old_model_fps = session.model_fps\n            old_duration = session.cur_duration\n            old_prompts = [\n                (prompt.text, prompt.start_frame, prompt.end_frame) for prompt in client.timeline._prompts.values()\n            ]\n            old_default_zoom_frames = client.timeline._default_num_frames_zoom\n            old_max_zoom_frames = client.timeline._max_frames_zoom\n\n            model_bundle = demo.load_model(new_model_name)\n\n            # Clear motions and constraints when switching models.\n            if session.edit_mode and session.motions:\n                exit_editing_mode(session)\n            session.edit_mode = False\n            demo.clear_motions(client_id)\n            with session.timeline_data[\"keyframe_update_lock\"]:\n                for constraint in list(session.constraints.values()):\n                    constraint.clear()\n                session.constraints = demo.build_constraint_tracks(client, model_bundle.skeleton)\n                session.timeline_data[\"keyframes\"] = {}\n                session.timeline_data[\"intervals\"] = {}\n                client.timeline.clear_keyframes()\n                client.timeline.clear_intervals()\n\n            session.model_name = new_model_name\n            session.model_fps = 
model_bundle.model_fps\n            session.skeleton = model_bundle.skeleton\n            session.motion_rep = model_bundle.motion_rep\n            session.cur_duration = old_duration\n            session.max_frame_idx = int(session.cur_duration * session.model_fps - 1)\n            session.frame_idx = 0\n            session.edit_mode = False\n\n            demo.set_timeline_defaults(client.timeline, session.model_fps)\n            client.timeline.set_current_frame(0)\n            gui_model_fps.value = session.model_fps\n            update_duration_gui(session.cur_duration)\n\n            if old_model_fps > 0:\n                default_zoom_seconds = old_default_zoom_frames / old_model_fps\n                max_zoom_seconds = old_max_zoom_frames / old_model_fps\n                new_default_zoom = int(round(default_zoom_seconds * session.model_fps))\n                new_max_zoom = int(round(max_zoom_seconds * session.model_fps))\n                new_default_zoom = max(1, new_default_zoom)\n                new_max_zoom = max(new_default_zoom, new_max_zoom)\n                client.timeline.set_zoom_settings(\n                    default_num_frames_zoom=new_default_zoom,\n                    max_frames_zoom=new_max_zoom,\n                )\n\n            client.timeline.clear_prompts()\n            if old_prompts and old_model_fps > 0:\n                for i, (prompt_text, start_frame, end_frame) in enumerate(old_prompts):\n                    start_sec = start_frame / old_model_fps\n                    end_sec = end_frame / old_model_fps\n                    new_start = int(round(start_sec * session.model_fps))\n                    new_end = int(round(end_sec * session.model_fps))\n                    new_start = max(0, min(new_start, session.max_frame_idx))\n                    new_end = max(new_start, min(new_end, session.max_frame_idx))\n                    color = PROMPT_COLORS[i % len(PROMPT_COLORS)]\n                    client.timeline.add_prompt(prompt_text, 
new_start, new_end, color=color)\n\n            session.examples_base_dir = demo.get_examples_base_dir(new_model_name, absolute=True)\n            session.example_dict = viser_utils.load_example_cases(session.examples_base_dir)\n            update_examples_dropdown(session.example_dict, keep_selection=False)\n            gui_save_example_path_text.value = os.path.join(\n                demo.get_examples_base_dir(new_model_name, absolute=True),\n                \"custom_example_1\",\n            )\n            gui_load_example_path_text.value = os.path.join(\n                demo.get_examples_base_dir(new_model_name, absolute=True),\n                \"custom_example_1\",\n            )\n\n            demo.add_character_motion(client, session.skeleton)\n            apply_constraint_overlay_visibility(session)\n\n        def _update_version_and_display_from_dataset_skeleton() -> None:\n            dataset_ui = gui_dataset_selector.value\n            skeleton_display = gui_skeleton_selector.value\n            skeleton_val = get_skeleton_key_from_display_name(skeleton_display)\n            if skeleton_val is None:\n                return\n            models = get_models_for_dataset_skeleton(dataset_ui, skeleton_val, family=\"Kimodo\")\n            if not models:\n                return\n            gui_version_selector.options = [m.display_name for m in models]\n            gui_version_selector.value = models[0].display_name\n            gui_version_selector.visible = len(models) > 1\n            gui_model_display.content = f\"**Model:** {models[0].display_name}\"\n\n        def _update_visibility_for_loaded_model(loaded_model_name: str) -> None:\n            \"\"\"Update model-specific controls from the currently loaded model only.\"\"\"\n            if not loaded_model_name:\n                return\n            _update_motion_export_dropdown(loaded_model_name)\n            gui_use_soma_layer_checkbox.visible = \"soma\" in loaded_model_name\n            _is_g1 = \"g1\" 
in loaded_model_name\n            gui_real_robot_rotations_checkbox.visible = _is_g1\n            gui_postprocess_checkbox.visible = not _is_g1\n            gui_root_margin.visible = not _is_g1 and gui_postprocess_checkbox.value\n            if _is_g1:\n                gui_gizmo_space_dropdown.value = \"Local\"\n            gui_gizmo_space_dropdown.visible = not _is_g1\n            gui_gizmo_space_dropdown.disabled = _is_g1\n\n        def _on_load_model_click(event: viser.GuiEvent) -> None:\n            \"\"\"Load the currently selected model (called from Load model button).\"\"\"\n            if get_active_session(event.client) is None:\n                return\n            new_model_name = gui_model_selector.value\n            if not new_model_name:\n                return\n            info = get_model_info(new_model_name)\n            if info is None:\n                return\n            session = demo.client_sessions[event.client.client_id]\n            if new_model_name == session.model_name:\n                return\n            loading_notif = event.client.add_notification(\n                title=\"Loading model...\",\n                body=f\"Loading {info.display_name}\",\n                loading=True,\n                with_close_button=False,\n            )\n            try:\n                apply_model_selection(new_model_name)\n                _update_visibility_for_loaded_model(new_model_name)\n                loading_notif.title = \"Model loaded\"\n                loading_notif.body = f\"{info.display_name} is ready.\"\n                loading_notif.loading = False\n                loading_notif.with_close_button = True\n                loading_notif.auto_close_seconds = 5.0\n                loading_notif.color = \"green\"\n            except Exception as e:\n                loading_notif.loading = False\n                loading_notif.with_close_button = True\n                event.client.add_notification(\n                    title=\"Model failed to 
load\",\n                    body=str(e),\n                    color=\"red\",\n                    auto_close_seconds=10.0,\n                )\n                gui_model_selector.set_from_short_key(session.model_name)\n\n        @gui_load_model_button.on_click\n        def _(event: viser.GuiEvent) -> None:\n            _on_load_model_click(event)\n\n        @gui_dataset_selector.on_update\n        def _(event: viser.GuiEvent) -> None:\n            if get_active_session(event.client) is None:\n                return\n            skeleton_labels = get_allowed_skeleton_labels(gui_dataset_selector.value)\n            gui_skeleton_selector.options = skeleton_labels\n            gui_skeleton_selector.value = skeleton_labels[0] if skeleton_labels else \"\"\n            _update_version_and_display_from_dataset_skeleton()\n\n        @gui_skeleton_selector.on_update\n        def _(event: viser.GuiEvent) -> None:\n            if get_active_session(event.client) is None:\n                return\n            _update_version_and_display_from_dataset_skeleton()\n\n        @gui_version_selector.on_update\n        def _(event: viser.GuiEvent) -> None:\n            if get_active_session(event.client) is None:\n                return\n            info = get_model_info(gui_model_selector.value)\n            if info is not None:\n                gui_model_display.content = f\"**Model:** {info.display_name}\"\n\n        @gui_use_soma_layer_checkbox.on_update\n        def _(event: viser.GuiEvent) -> None:\n            session = get_active_session(event.client)\n            if session is None or \"soma\" not in (session.model_name or \"\"):\n                return\n\n            loading_notif = event.client.add_notification(\n                title=\"Applying SOMA layer...\",\n                body=\"Updating mesh.\",\n                loading=True,\n                with_close_button=False,\n            )\n            try:\n                current_motion = list(session.motions.values())[0] 
if session.motions else None\n                current_frame_idx = session.frame_idx\n\n                # Recreate the character to apply the new SOMA mesh mode selection.\n                demo.clear_motions(event.client.client_id)\n                if current_motion is None:\n                    demo.add_character_motion(event.client, session.skeleton)\n                else:\n                    demo.add_character_motion(\n                        event.client,\n                        session.skeleton,\n                        current_motion.joints_pos,\n                        current_motion.joints_rot,\n                        current_motion.foot_contacts,\n                    )\n\n                demo.set_frame(event.client.client_id, current_frame_idx)\n            except Exception as e:\n                print(e)\n                event.client.add_notification(\n                    title=\"SOMA layer failed\",\n                    body=str(e),\n                    color=\"red\",\n                    auto_close_seconds=10.0,\n                )\n                gui_use_soma_layer_checkbox.value = not gui_use_soma_layer_checkbox.value\n            finally:\n                loading_notif.loading = False\n                loading_notif.with_close_button = True\n                loading_notif.auto_close_seconds = 2.0\n\n        @gui_real_robot_rotations_checkbox.on_update\n        def _(event: viser.GuiEvent) -> None:\n            session = get_active_session(event.client)\n            if session is None or \"g1\" not in session.model_name:\n                return\n            if not isinstance(session.skeleton, G1Skeleton34) or not session.motions:\n                return\n            if not gui_real_robot_rotations_checkbox.value:\n                return\n            # Reproject all displayed G1 motions to real robot DoF (1-DoF per joint + axis limits).\n            from kimodo.skeleton import global_rots_to_local_rots\n\n            current_frame_idx = 
session.frame_idx\n            for motion in session.motions.values():\n                if motion.length <= 1:\n                    continue\n                rest_pos = motion.joints_pos[0:1]\n                rest_rot = motion.joints_rot[0:1]\n                same_as_rest = (motion.joints_pos - rest_pos).abs().max().item() < 1e-6 and (\n                    motion.joints_rot - rest_rot\n                ).abs().max().item() < 1e-6\n                if same_as_rest:\n                    continue\n                new_pos, new_rot = generation.apply_g1_real_robot_projection(\n                    session.skeleton,\n                    motion.joints_pos,\n                    motion.joints_rot,\n                )\n                motion.joints_pos = new_pos\n                motion.joints_rot = new_rot\n                motion.joints_local_rot = global_rots_to_local_rots(new_rot, session.skeleton)\n                # Refresh skeleton and skinned mesh caches so the viz uses new positions.\n                motion.precompute_mesh_info()\n            demo.set_frame(event.client.client_id, current_frame_idx)\n            event.client.add_notification(\n                title=\"Real robot projection applied\",\n                body=\"The motion is projected to G1 real robot DoF (1-DoF per joint, clamped to axis limits).\",\n                auto_close_seconds=4.0,\n                color=\"green\",\n            )\n\n        def load_example_from_path(\n            event_client: viser.ClientHandle,\n            example_path: str,\n            load_gt: bool = False,\n        ) -> None:\n            from kimodo.meta import parse_prompts_from_meta\n            from kimodo.tools import load_json\n\n            session = get_active_session(event_client)\n            if session is None:\n                return\n\n            # Pause playback when loading an example.\n            session.playing = False\n\n            if not os.path.isdir(example_path):\n                
event_client.add_notification(\n                    title=\"Example path not found\",\n                    body=f\"Directory does not exist: {example_path}\",\n                    auto_close_seconds=5.0,\n                    color=\"red\",\n                )\n                return\n\n            # Long motions trigger a skinning precompute that can take several\n            # seconds; show a persistent \"loading\" notification so the user\n            # knows the app isn't frozen. Cleared in the finally block below.\n            loading_notif = event_client.add_notification(\n                title=\"Loading example...\",\n                body=f\"Loading {os.path.basename(example_path.rstrip(os.sep))}. This may take a moment for long motions.\",\n                loading=True,\n                with_close_button=False,\n            )\n\n            try:\n                # constraints\n                constraints_path = os.path.join(example_path, \"constraints.json\")\n                if os.path.exists(constraints_path):\n                    load_constraints(event_client, constraints_path)\n                else:\n                    # clear all existing constraints\n                    with session.timeline_data[\"keyframe_update_lock\"]:\n                        for constraint in list(session.constraints.values()):\n                            constraint.clear()\n                        event_client.timeline.clear_keyframes()\n                        event_client.timeline.clear_intervals()\n                # motion\n                motion_filename = \"gt_motion.npz\" if load_gt else \"motion.npz\"\n                motion_path = os.path.join(example_path, motion_filename)\n                if os.path.exists(motion_path):\n                    load_motion(event_client, motion_path)\n                # metadata\n                meta_path = os.path.join(example_path, \"meta.json\")\n                if os.path.exists(meta_path):\n                    meta_info = 
load_json(meta_path)\n                    event_client.timeline.clear_prompts()\n\n                    texts, durations_sec = parse_prompts_from_meta(meta_info)\n                    fps = session.model_fps\n                    # Convert durations (seconds) to consecutive frame bounds\n                    num_frames = 0\n                    frame_bounds = []\n                    for i, d in enumerate(durations_sec):\n                        n_frames = max(1, int(round(d * fps)))\n                        start_frame = num_frames\n                        # Inverse of compute_prompt_num_frames():\n                        # non-last prompts end at next prompt start (exclusive),\n                        # last prompt includes its end frame.\n                        if i == len(durations_sec) - 1:\n                            end_frame = num_frames + n_frames - 1\n                        else:\n                            end_frame = num_frames + n_frames\n                        frame_bounds.append((start_frame, end_frame))\n                        num_frames += n_frames\n\n                    # Adapt timeline zoom to the loaded motion.\n                    target_visible_frames = int(math.ceil(1.10 * num_frames))\n                    event_client.timeline.set_zoom_settings(\n                        default_num_frames_zoom=target_visible_frames,\n                    )\n\n                    for i, (prompt_text, (start_frame, end_frame)) in enumerate(zip(texts, frame_bounds)):\n                        color = PROMPT_COLORS[i % len(PROMPT_COLORS)]\n                        event_client.timeline.add_prompt(prompt_text, start_frame, end_frame, color=color)\n\n                    update_duration_auto()\n\n                    # Only load optional fields if present\n                    if \"num_samples\" in meta_info:\n                        gui_num_samples_slider.value = meta_info[\"num_samples\"]\n                    if \"seed\" in meta_info:\n                        
gui_seed.value = meta_info[\"seed\"]\n                    if \"diffusion_steps\" in meta_info:\n                        gui_diffusion_steps_slider.value = meta_info[\"diffusion_steps\"]\n                    if \"cfg\" in meta_info:\n                        cfg = meta_info[\"cfg\"]\n                        if \"enabled\" in cfg:\n                            gui_cfg_checkbox.value = cfg[\"enabled\"]\n                        if \"text_weight\" in cfg:\n                            gui_cfg_text_weight_slider.value = cfg[\"text_weight\"]\n                        if \"constraint_weight\" in cfg:\n                            gui_cfg_constraint_weight_slider.value = cfg[\"constraint_weight\"]\n\n                # Set frame to 0 when example is loaded.\n                session.frame_idx = 0\n                event_client.timeline.set_current_frame(0)\n                demo.set_frame(event_client.client_id, 0)\n\n                event_client.add_notification(\n                    title=\"Example loaded!\",\n                    body=f\"Loaded example from {example_path}\",\n                    auto_close_seconds=5.0,\n                    color=\"green\",\n                )\n            except Exception as e:\n                import traceback\n\n                traceback.print_exc()\n                event_client.add_notification(\n                    title=\"Failed to load example!\",\n                    body=str(e),\n                    auto_close_seconds=10.0,\n                    color=\"red\",\n                )\n            finally:\n                loading_notif.remove()\n\n        @gui_load_example_button.on_click\n        def _(event: viser.GuiEvent) -> None:\n            event_client = event.client\n            session = get_active_session(event_client)\n            if session is None:\n                return\n\n            if not session.example_dict or (gui_examples_dropdown.value not in session.example_dict):\n                event_client.add_notification(\n          
          title=\"No examples available\",\n                    body=\"No examples found for the selected model.\",\n                    auto_close_seconds=5.0,\n                    color=\"red\",\n                )\n                return\n\n            example_path = session.example_dict[gui_examples_dropdown.value]\n            load_example_from_path(event_client, example_path, gui_load_gt_checkbox.value)\n\n        @gui_load_example_from_path_button.on_click\n        def _(event: viser.GuiEvent) -> None:\n            event_client = event.client\n            session = get_active_session(event_client)\n            if session is None:\n                return\n\n            example_path = gui_load_example_path_text.value\n            if not example_path:\n                event_client.add_notification(\n                    title=\"No example path\",\n                    body=\"Please provide an example directory.\",\n                    auto_close_seconds=5.0,\n                    color=\"red\",\n                )\n                return\n            load_example_from_path(event_client, example_path, gui_load_gt_checkbox.value)\n\n        @gui_cfg_checkbox.on_update\n        def _(_) -> None:\n            if not demo.client_active(client_id):\n                return\n            val = gui_cfg_checkbox.value\n            gui_cfg_text_weight_slider.visible = val\n            gui_cfg_constraint_weight_slider.visible = val\n\n        def exit_editing_mode(session: ClientSession):\n            gui_edit_constraint_button.label = \"Enter Editing Mode\"\n            gui_generate_button.disabled = False\n            gui_generate_button.label = \"Generate\"\n            gui_reset_constraint_button.disabled = True\n            if \"g1\" in session.model_name:\n                gui_gizmo_space_dropdown.value = \"Local\"\n                gui_gizmo_space_dropdown.disabled = True\n                gui_gizmo_space_dropdown.visible = False\n            else:\n                
gui_gizmo_space_dropdown.disabled = False\n                gui_gizmo_space_dropdown.visible = True\n            gui_undo_drag_button.disabled = True\n            gui_use_soma_layer_checkbox.disabled = False\n            session.edit_mode_snapshot = None\n            session.undo_drag_snapshot = None\n\n            motion = list(session.motions.values())[0]\n            motion.clear_all_gizmos()\n            motion.character.set_skinned_mesh_wireframe(False)\n            motion.character.set_skeleton_visibility(False)\n            motion.character.set_skinned_mesh_visibility(True)\n            motion.character.set_skinned_mesh_opacity(1.0)\n            session.gui_elements.gui_viz_skinned_mesh_opacity_slider.value = 1.0\n\n            # If the path is dense, put the motion back on the path\n            if \"2D Root\" in session.constraints and session.constraints[\"2D Root\"].dense_path:\n                _update_dense_path(motion, session)\n\n            gui_viz_skinned_mesh_checkbox.value = True\n            gui_viz_skeleton_checkbox.value = False\n\n        # enter editing mode callback\n        @gui_edit_constraint_button.on_click\n        def _(event: viser.GuiEvent) -> None:\n            event_client = event.client\n            session = get_active_session(event_client)\n            if session is None:\n                return\n\n            session.edit_mode = not session.edit_mode\n\n            edit_alert = \"Entered editing mode\"\n            no_edit_alert = \"Exited editing mode\"\n            edit_message = \"You can now modify pose or path constraints.\"\n            no_edit_message = \"Can now generate motions.\"\n            event_client.add_notification(\n                title=edit_alert if session.edit_mode else no_edit_alert,\n                body=edit_message if session.edit_mode else no_edit_message,\n                auto_close_seconds=10.0,\n                color=\"blue\",\n            )\n\n            if session.edit_mode:\n                
gui_edit_constraint_button.label = \"Exit Editing Mode\"\n                gui_generate_button.disabled = True\n                gui_generate_button.label = \"Generate Disabled In Editing Mode\"\n                if \"g1\" in session.model_name:\n                    gui_gizmo_space_dropdown.value = \"Local\"\n                gui_gizmo_space_dropdown.disabled = True\n                gui_use_soma_layer_checkbox.disabled = True\n\n                assert len(session.motions) == 1, \"Only one motion allowed in edit mode\"\n                motion = list(session.motions.values())[0]\n                snapshot_frame_idx = min(session.frame_idx, motion.length - 1)\n                session.edit_mode_snapshot = {}\n                ensure_edit_snapshot(session, motion, snapshot_frame_idx)\n                gui_reset_constraint_button.disabled = False\n\n                motion.character.set_skeleton_visibility(True)\n                # motion.character.set_skinned_mesh_wireframe(True)\n                motion.character.set_skinned_mesh_opacity(0.65)\n                session.gui_elements.gui_viz_skinned_mesh_opacity_slider.value = 0.65\n                motion.character.set_skinned_mesh_visibility(True)\n                gui_viz_skinned_mesh_checkbox.value = True\n                gui_viz_skeleton_checkbox.value = True\n\n                # need gizmos for root translation and individual joints\n                def _on_root2d_gizmo_release():\n                    if \"2D Root\" in session.constraints and session.constraints[\"2D Root\"].dense_path:\n                        mot = list(session.motions.values())[0]\n                        _update_dense_path(mot, session)\n\n                def _on_gizmo_drag_start():\n                    mot = list(session.motions.values())[0]\n                    frame_idx = min(session.frame_idx, mot.length - 1)\n                    session.undo_drag_snapshot = {\n                        \"frame_idx\": frame_idx,\n                        \"joints_pos\": 
mot.get_joints_pos(frame_idx),\n                        \"joints_rot\": mot.get_joints_rot(frame_idx),\n                    }\n                    gui_undo_drag_button.disabled = False\n\n                motion.add_root_translation_gizmo(\n                    session.constraints,\n                    on_2d_root_drag_end=_on_root2d_gizmo_release,\n                    on_drag_start=_on_gizmo_drag_start,\n                )\n                gizmo_space = \"local\" if \"g1\" in session.model_name else gui_gizmo_space_dropdown.value.lower()\n                motion.add_joint_gizmos(\n                    session.constraints,\n                    space=gizmo_space,\n                    on_drag_start=_on_gizmo_drag_start,\n                )\n            else:\n                exit_editing_mode(session)\n\n        @gui_reset_constraint_button.on_click\n        def _(event: viser.GuiEvent) -> None:\n            event_client = event.client\n            session = get_active_session(event_client)\n            if session is None or not session.edit_mode_snapshot:\n                return\n\n            if not session.motions:\n                return\n            motion = list(session.motions.values())[0]\n            snapshot_frame_idx = min(session.frame_idx, motion.length - 1)\n            if snapshot_frame_idx not in session.edit_mode_snapshot:\n                return\n            motion.update_pose_at_frame(\n                snapshot_frame_idx,\n                joints_pos=session.edit_mode_snapshot[snapshot_frame_idx][\"joints_pos\"],\n                joints_rot=session.edit_mode_snapshot[snapshot_frame_idx][\"joints_rot\"],\n            )\n            demo.set_frame(event_client.client_id, snapshot_frame_idx, update_timeline=False)\n\n        @gui_undo_drag_button.on_click\n        def _(event: viser.GuiEvent) -> None:\n            event_client = event.client\n            session = get_active_session(event_client)\n            if session is None or session.undo_drag_snapshot 
is None:\n                return\n\n            if not session.motions:\n                return\n            motion = list(session.motions.values())[0]\n            frame_idx = session.undo_drag_snapshot[\"frame_idx\"]\n            motion.update_pose_at_frame(\n                frame_idx,\n                joints_pos=session.undo_drag_snapshot[\"joints_pos\"],\n                joints_rot=session.undo_drag_snapshot[\"joints_rot\"],\n            )\n            demo.set_frame(event_client.client_id, frame_idx, update_timeline=False)\n            session.undo_drag_snapshot = None\n            gui_undo_drag_button.disabled = True\n\n        def validate_interval(start_frame_idx: int, end_frame_idx: int, max_frame_idx: int) -> bool:\n            if start_frame_idx < 0 or start_frame_idx > max_frame_idx:\n                return False\n            if end_frame_idx < 0 or end_frame_idx > max_frame_idx:\n                return False\n            if end_frame_idx < start_frame_idx:\n                return False\n            return True\n\n        def clamp_interval_to_range(\n            start_frame_idx: int, end_frame_idx: int, max_frame_idx: int\n        ) -> Optional[tuple[int, int]]:\n            if end_frame_idx < 0 or start_frame_idx > max_frame_idx:\n                return None\n            start_clamped = max(0, start_frame_idx)\n            end_clamped = min(max_frame_idx, end_frame_idx)\n            if end_clamped < start_clamped:\n                return None\n            return start_clamped, end_clamped\n\n        # add constraint callback\n        def add_constraint_callback(\n            constraint_id: str,\n            constraint_type: str,\n            frame_range: tuple[int, int],\n            joint_names: list[str] = None,\n            verbose: bool = True,\n        ):\n            \"\"\"Add a constraint to the session.\n\n            Args:\n                constraint_type: str, the type of constraint to add\n                frame_range: tuple[int, int], the 
frame range to add the constraint to\n                joint_names: list[str], the names of the joints to constraint if the constraint type is End-Effectors\n            \"\"\"\n            # Check if session still exists\n            if not demo.client_active(client_id):\n                return\n            session = demo.client_sessions[client_id]\n\n            assert len(session.motions) == 1, \"Only one motion allowed for adding constraints\"\n            motion = list(session.motions.values())[0]\n\n            end_effector_type = None\n            if constraint_type in [\n                \"Left Hand\",\n                \"Right Hand\",\n                \"Left Foot\",\n                \"Right Foot\",\n            ]:\n                joint_names = [constraint_type.replace(\" \", \"\"), \"Hips\"]\n                # Hips are required because of smooth root representation\n                end_effector_type = constraint_type.replace(\" \", \"-\").lower()\n                constraint_type = \"End-Effectors\"\n\n            # check to make sure interval is valid\n            is_interval = frame_range[1] != frame_range[0]\n            start_frame_idx = int(frame_range[0])\n            end_frame_idx = int(frame_range[1])\n\n            if is_interval:\n                clamped = clamp_interval_to_range(start_frame_idx, end_frame_idx, session.max_frame_idx)\n                if clamped is None:\n                    print(\"Interval outside range! Couldn't add constraint.\")\n                    return\n                start_frame_idx, end_frame_idx = clamped\n            else:\n                if not validate_interval(start_frame_idx, end_frame_idx, session.max_frame_idx):\n                    print(\"Invalid interval! 
Couldn't add constraint.\")\n                    return\n\n            # collect input args for the constraint based on which track it is\n            if is_interval:\n                constraint_kwargs = {\n                    \"interval_id\": constraint_id,\n                    \"start_frame_idx\": start_frame_idx,\n                    \"end_frame_idx\": end_frame_idx,\n                }\n            else:\n                constraint_kwargs = {\n                    \"keyframe_id\": constraint_id,\n                    \"frame_idx\": start_frame_idx,\n                }\n\n            if constraint_type in [\"Full-Body\", \"End-Effectors\"]:\n                constraint_kwargs[\"joints_pos\"] = motion.get_joints_pos(start_frame_idx, end_frame_idx)\n                constraint_kwargs[\"joints_rot\"] = motion.get_joints_rot(start_frame_idx, end_frame_idx)\n                if constraint_type == \"End-Effectors\":\n                    constraint_kwargs[\"joint_names\"] = joint_names\n                    constraint_kwargs[\"end_effector_type\"] = end_effector_type\n\n            elif constraint_type == \"2D Root\":\n                constraint_kwargs[\"root_pos\"] = motion.get_projected_root_pos(start_frame_idx, end_frame_idx)\n\n            # add the keyframe(s) to the constraint track\n            constraint = session.constraints[constraint_type]\n            if is_interval:\n                constraint.add_interval(**constraint_kwargs)\n            else:\n                constraint.add_keyframe(**constraint_kwargs)\n\n            apply_constraint_overlay_visibility(session)\n\n            if verbose:\n                client.add_notification(\n                    title=\"Constraint added\",\n                    body=\"\",\n                    auto_close_seconds=5.0,\n                    color=\"blue\",\n                )\n\n        # timeline callbacks for keyframes and intervals\n        @client.timeline.on_keyframe_add\n        def _(keyframe_id: str, track_id: str, 
frame: int):\n            \"\"\"Called when a keyframe is added to a track.\"\"\"\n            if not demo.client_active(client_id):\n                return\n            session = demo.client_sessions[client_id]\n            with session.timeline_data[\"keyframe_update_lock\"]:\n                constraint_type = session.timeline_data[\"tracks\"][track_id][\"name\"]\n                add_constraint_callback(\n                    keyframe_id,\n                    constraint_type,\n                    (frame, frame),\n                    verbose=False,\n                )\n                keyframe_data = client.timeline._keyframes.get(keyframe_id)\n                session.timeline_data[\"keyframes\"][keyframe_id] = {\n                    \"frame\": frame,\n                    \"track_id\": track_id,\n                    \"locked\": bool(keyframe_data.locked) if keyframe_data is not None else False,\n                    \"opacity\": keyframe_data.opacity if keyframe_data is not None else 1.0,\n                    \"value\": keyframe_data.value if keyframe_data is not None else None,\n                }\n                # Update smooth path when adding a keyframe (single action, not drag).\n                if constraint_type == \"2D Root\" and session.constraints[\"2D Root\"].dense_path:\n                    motion = list(session.motions.values())[0]\n                    _update_dense_path(motion, session)\n\n        @client.timeline.on_interval_add\n        def handle_interval_add(interval_id: str, track_id: str, start_frame: int, end_frame: int):\n            \"\"\"Called when an interval is added to a track.\"\"\"\n            if not demo.client_active(client_id):\n                return\n            session = demo.client_sessions[client_id]\n            with session.timeline_data[\"keyframe_update_lock\"]:\n                constraint_type = session.timeline_data[\"tracks\"][track_id][\"name\"]\n                add_constraint_callback(\n                    
interval_id,\n                    constraint_type,\n                    (start_frame, end_frame),\n                    verbose=False,\n                )\n                interval_data = client.timeline._intervals.get(interval_id)\n                session.timeline_data[\"intervals\"][interval_id] = {\n                    \"track_id\": track_id,\n                    \"start_frame_idx\": start_frame,\n                    \"end_frame_idx\": end_frame,\n                    \"locked\": bool(interval_data.locked) if interval_data is not None else False,\n                    \"opacity\": interval_data.opacity if interval_data is not None else 1.0,\n                    \"value\": interval_data.value if interval_data is not None else None,\n                }\n                if constraint_type == \"2D Root\" and session.constraints[\"2D Root\"].dense_path:\n                    motion = list(session.motions.values())[0]\n                    _update_dense_path(motion, session)\n\n        def remove_constraint_callback(\n            constraint_id: str,\n            constraint_type: str,\n            frame_range: tuple[int, int],\n            verbose: bool = True,\n        ) -> None:\n            if not demo.client_active(client_id):\n                return\n            session = demo.client_sessions[client_id]\n            session.updating_motions = True\n\n            is_interval = frame_range[1] != frame_range[0]\n            start_frame_idx = int(frame_range[0])\n            end_frame_idx = int(frame_range[1])\n\n            if is_interval:\n                clamped = clamp_interval_to_range(start_frame_idx, end_frame_idx, session.max_frame_idx)\n                if clamped is None:\n                    return\n                start_frame_idx, end_frame_idx = clamped\n            else:\n                if not validate_interval(start_frame_idx, end_frame_idx, session.max_frame_idx):\n                    print(\"Invalid interval! 
Couldn't remove constraint.\")\n                    return\n\n            if constraint_type in [\n                \"Left Hand\",\n                \"Right Hand\",\n                \"Left Foot\",\n                \"Right Foot\",\n            ]:\n                constraint_type = \"End-Effectors\"\n\n            constraint = session.constraints[constraint_type]\n            if is_interval:\n                constraint.remove_interval(constraint_id, start_frame_idx, end_frame_idx)\n            else:\n                constraint.remove_keyframe(constraint_id, start_frame_idx)\n\n            if verbose:\n                client.add_notification(\n                    title=\"Constraint removed\",\n                    body=\"\",\n                    auto_close_seconds=5.0,\n                    color=\"blue\",\n                )\n\n        @client.timeline.on_keyframe_move\n        def handle_keyframe_move(keyframe_id: str, new_frame: int):\n            \"\"\"Called when a keyframe is moved to a new frame.\"\"\"\n            # print(f\"Keyframe moved: {keyframe_id} to frame {new_frame}\")\n            if not demo.client_active(client_id):\n                return\n            session = demo.client_sessions[client_id]\n\n            # Cancel any pending timer for this keyframe\n            timeline_data = session.timeline_data\n            with timeline_data[\"keyframe_update_lock\"]:\n                if keyframe_id in timeline_data[\"keyframe_move_timers\"]:\n                    timeline_data[\"keyframe_move_timers\"][keyframe_id].cancel()\n\n                # Store the latest target frame\n                timeline_data[\"pending_keyframe_moves\"][keyframe_id] = new_frame\n                # Create a new timer to execute the actual move after a delay\n                # This debounces rapid movements - only execute when user stops moving\n                timer = threading.Timer(\n                    0.03,  # 30ms delay - adjust as needed\n                    
_execute_keyframe_move,\n                    args=(client_id, keyframe_id, new_frame, session),\n                )\n                timeline_data[\"keyframe_move_timers\"][keyframe_id] = timer\n                timer.start()\n\n        def _execute_keyframe_move(\n            client_id: int,\n            keyframe_id: str,\n            new_frame: int,\n            session: ClientSession,\n        ):\n            \"\"\"Actually execute the keyframe move operation (called after debounce delay).\"\"\"\n\n            timeline_data = session.timeline_data\n            with timeline_data[\"keyframe_update_lock\"]:\n                # Check if this move is still the latest one\n                if keyframe_id not in timeline_data[\"pending_keyframe_moves\"]:\n                    return  # Move was cancelled\n\n                if timeline_data[\"pending_keyframe_moves\"][keyframe_id] != new_frame:\n                    return  # A newer move superseded this one\n\n                # Remove from pending\n                del timeline_data[\"pending_keyframe_moves\"][keyframe_id]\n                if keyframe_id in timeline_data[\"keyframe_move_timers\"]:\n                    del timeline_data[\"keyframe_move_timers\"][keyframe_id]\n\n                # Now execute the actual move (keep it in the lock so we don't delete it while moving)\n                if keyframe_id not in timeline_data[\"keyframes\"]:\n                    # double check\n                    return\n                keyframe_data = timeline_data[\"keyframes\"][keyframe_id]\n                if not keyframe_data:\n                    return\n\n                # if the frame did not move, don't do anything\n                if keyframe_data[\"frame\"] == new_frame:\n                    return\n\n                track_id = keyframe_data[\"track_id\"]\n                constraint_type = timeline_data[\"tracks\"][track_id][\"name\"]\n                cur_frame = keyframe_data[\"frame\"]\n\n                # Remove constraint 
at old frame\n                remove_constraint_callback(\n                    keyframe_id,\n                    constraint_type,\n                    (cur_frame, cur_frame),\n                    verbose=False,\n                )\n                # Add constraint at new frame\n                add_constraint_callback(\n                    keyframe_id,\n                    constraint_type,\n                    (new_frame, new_frame),\n                    verbose=False,\n                )\n\n                # update our data\n                keyframe_data[\"frame\"] = new_frame\n\n                # Schedule path update only after user stops dragging (no move for 300ms).\n                if constraint_type == \"2D Root\":\n                    _schedule_dense_path_after_release(session)\n\n        @client.timeline.on_keyframe_delete\n        def handle_keyframe_delete(keyframe_id: str):\n            \"\"\"Called when a keyframe is deleted.\"\"\"\n            if not demo.client_active(client_id):\n                return\n            session = demo.client_sessions[client_id]\n            with session.timeline_data[\"keyframe_update_lock\"]:\n                if keyframe_id not in session.timeline_data[\"keyframes\"]:\n                    return\n                keyframe_data = session.timeline_data[\"keyframes\"][keyframe_id]\n                track_id = keyframe_data[\"track_id\"]\n                constraint_type = session.timeline_data[\"tracks\"][track_id][\"name\"]\n                cur_frame = keyframe_data[\"frame\"]\n                remove_constraint_callback(\n                    keyframe_id,\n                    constraint_type,\n                    (cur_frame, cur_frame),\n                    verbose=False,\n                )\n                del session.timeline_data[\"keyframes\"][keyframe_id]\n                if constraint_type == \"2D Root\" and session.constraints[\"2D Root\"].dense_path:\n                    motion = list(session.motions.values())[0]\n        
            _update_dense_path(motion, session)\n\n        @client.timeline.on_interval_move\n        def handle_interval_move(interval_id: str, new_start: int, new_end: int):\n            \"\"\"Called when an interval is moved or resized.\"\"\"\n            # print(f\"Interval moved: {interval_id} to {new_start}-{new_end}\")\n            if not demo.client_active(client_id):\n                return\n            session = demo.client_sessions[client_id]\n\n            # Cancel any pending timer for this interval\n            # We share the same lock for keyframe and interval moves assuming the user can't move both at the same time\n            timeline_data = session.timeline_data\n            with timeline_data[\"keyframe_update_lock\"]:\n                if interval_id in timeline_data[\"keyframe_move_timers\"]:\n                    timeline_data[\"keyframe_move_timers\"][interval_id].cancel()\n\n                # Store the latest target frame\n                new_interval = (new_start, new_end)\n                timeline_data[\"pending_keyframe_moves\"][interval_id] = new_interval\n                # Create a new timer to execute the actual move after a delay\n                # This debounces rapid movements - only execute when user stops moving\n                timer = threading.Timer(\n                    0.5,  # 500ms delay - adding interval is much slower than moving a keyframe\n                    _execute_interval_move,\n                    args=(client_id, interval_id, new_interval, session),\n                )\n                timeline_data[\"keyframe_move_timers\"][interval_id] = timer\n                timer.start()\n\n        def _execute_interval_move(\n            client_id: int,\n            interval_id: str,\n            new_interval: tuple[int, int],\n            session: ClientSession,\n        ):\n            \"\"\"Actually execute the interval move operation (called after debounce delay).\"\"\"\n\n            timeline_data = 
session.timeline_data\n            with timeline_data[\"keyframe_update_lock\"]:\n                # Check if this move is still the latest one\n                if interval_id not in timeline_data[\"pending_keyframe_moves\"]:\n                    return  # Move was cancelled\n\n                if timeline_data[\"pending_keyframe_moves\"][interval_id] != new_interval:\n                    return  # A newer move superseded this one\n\n                # Remove from pending\n                del timeline_data[\"pending_keyframe_moves\"][interval_id]\n                if interval_id in timeline_data[\"keyframe_move_timers\"]:\n                    del timeline_data[\"keyframe_move_timers\"][interval_id]\n\n                # Now execute the actual move\n                if interval_id not in timeline_data[\"intervals\"]:\n                    return\n                interval_data = timeline_data[\"intervals\"][interval_id]\n                if not interval_data:\n                    return\n\n                # if the interval did not move, don't do anything\n                if (\n                    interval_data[\"start_frame_idx\"] == new_interval[0]\n                    and interval_data[\"end_frame_idx\"] == new_interval[1]\n                ):\n                    return\n\n                track_id = interval_data[\"track_id\"]\n                constraint_type = timeline_data[\"tracks\"][track_id][\"name\"]\n                cur_range = (\n                    interval_data[\"start_frame_idx\"],\n                    interval_data[\"end_frame_idx\"],\n                )\n\n                # Remove constraint at old frame\n                remove_constraint_callback(\n                    interval_id,\n                    constraint_type,\n                    cur_range,\n                    verbose=False,\n                )\n                # Add constraint at new frame\n                add_constraint_callback(\n                    interval_id,\n                    
constraint_type,\n                    new_interval,\n                    verbose=False,\n                )\n\n                # update our data\n                interval_data[\"start_frame_idx\"] = new_interval[0]\n                interval_data[\"end_frame_idx\"] = new_interval[1]\n\n                # Schedule path update only after user stops dragging (no move for 300ms).\n                if constraint_type == \"2D Root\":\n                    _schedule_dense_path_after_release(session)\n\n        @client.timeline.on_interval_delete\n        def handle_interval_delete(interval_id: str):\n            \"\"\"Called when an interval is deleted.\"\"\"\n            if not demo.client_active(client_id):\n                return\n            session = demo.client_sessions[client_id]\n            with session.timeline_data[\"keyframe_update_lock\"]:\n                if interval_id not in session.timeline_data[\"intervals\"]:\n                    return\n                interval_data = session.timeline_data[\"intervals\"][interval_id]\n                track_id = interval_data[\"track_id\"]\n                constraint_type = session.timeline_data[\"tracks\"][track_id][\"name\"]\n                remove_constraint_callback(\n                    interval_id,\n                    constraint_type,\n                    (\n                        interval_data[\"start_frame_idx\"],\n                        interval_data[\"end_frame_idx\"],\n                    ),\n                    verbose=False,\n                )\n                del session.timeline_data[\"intervals\"][interval_id]\n                if constraint_type == \"2D Root\" and session.constraints[\"2D Root\"].dense_path:\n                    motion = list(session.motions.values())[0]\n                    _update_dense_path(motion, session)\n\n        @gui_snap_to_constraint_button.on_click\n        def _(event: viser.GuiEvent) -> None:\n            event_client = event.client\n            session = 
get_active_session(event_client)\n            if session is None:\n                return\n\n            target_character_motion = list(session.motions.values())[0]\n            frame_idx = session.frame_idx\n\n            if frame_idx >= target_character_motion.length:\n                # frame idx larger than the motion, could not snap\n                return\n\n            for constraint_name in [\"Full-Body\", \"End-Effectors\"]:\n                if (\n                    constraint_name in session.constraints\n                    and frame_idx in session.constraints[constraint_name].keyframes\n                ):\n                    pos = session.constraints[constraint_name].keyframes[frame_idx][\"joints_pos\"]\n                    rot = session.constraints[constraint_name].keyframes[frame_idx][\"joints_rot\"]\n\n                    # update the full joints_pos of the character to match the constraints\n                    target_character_motion.update_pose_at_frame(\n                        frame_idx,\n                        joints_pos=pos,\n                        joints_rot=rot,\n                    )\n                    target_character_motion.set_frame(frame_idx)\n                    return  # motion already fully changed\n\n            if \"2D Root\" in session.constraints and frame_idx in session.constraints[\"2D Root\"].keyframes:\n                # update only the root position\n                new_root_pos = session.constraints[\"2D Root\"].keyframes[frame_idx]\n                old_root_pos = target_character_motion.get_projected_root_pos(frame_idx)\n                root_diff = new_root_pos - old_root_pos\n                root_diff[1] = 0.0  # don't change height\n\n                new_joints_pos = (\n                    target_character_motion.joints_pos[frame_idx]\n                    + to_torch(\n                        root_diff,\n                        device=target_character_motion.joints_pos.device,\n                        
dtype=target_character_motion.joints_pos.dtype,\n                    )[None]\n                )\n                rot = target_character_motion.joints_rot[frame_idx]\n\n                target_character_motion.update_pose_at_frame(\n                    frame_idx,\n                    joints_pos=new_joints_pos,\n                    joints_rot=rot,\n                )\n                target_character_motion.set_frame(frame_idx)\n\n        @gui_clear_all_constraints_button.on_click\n        def _(event: viser.GuiEvent) -> None:\n            event_client = event.client\n            session = get_active_session(event_client)\n            if session is None:\n                return\n            with session.timeline_data[\"keyframe_update_lock\"]:\n                # use the lock here to wait for any constraint updates to finish\n                for constraint in list(session.constraints.values()):\n                    constraint.clear()\n                client.timeline.clear_keyframes()\n                client.timeline.clear_intervals()\n            if gui_dense_path_checkbox.value:\n                gui_dense_path_checkbox.value = False\n                if \"2D Root\" in session.constraints:\n                    session.constraints[\"2D Root\"].set_dense_path(False)\n\n        # generation callback\n        @gui_generate_button.on_click\n        def _(event: viser.GuiEvent) -> None:\n            event_client = event.client\n            session = get_active_session(event_client)\n            if session is None:\n                return\n\n            generating_notif = event_client.add_notification(\n                title=\"Generating motion...\",\n                body=\"Generating motions for the given prompt!\",\n                loading=True,\n                with_close_button=False,\n            )\n            gui_generate_button.disabled = True\n            client.timeline.disable_constraints()\n\n            num_samples = gui_num_samples_slider.value\n            
timeline = session.client.timeline\n\n            # sort them to avoid issues:\n            prompt_values = sorted([x for x in timeline._prompts.values()], key=lambda x: x.start_frame)\n\n            texts = [x.text for x in prompt_values]\n            num_frames = compute_prompt_num_frames(prompt_values)\n\n            # compute the total duration\n            total_nb_frames = sum(num_frames)\n            total_duration = total_nb_frames / session.model_fps\n\n            # update just in case\n            set_new_duration(client_id, total_duration)\n\n            transitions_parameters = {\n                \"num_transition_frames\": gui_num_transition_frames_slider.value,\n            }\n\n            # G1: postprocessing is disabled (does not work well for this model).\n            postprocess_parameters = {\n                \"post_processing\": (False if \"g1\" in session.model_name else gui_postprocess_checkbox.value),\n                \"root_margin\": gui_root_margin.value,\n            }\n            try:\n                demo.generate(\n                    event_client,\n                    texts,\n                    num_frames,\n                    num_samples,\n                    gui_seed.value,\n                    gui_diffusion_steps_slider.value,\n                    cfg_weight=[\n                        gui_cfg_text_weight_slider.value,\n                        gui_cfg_constraint_weight_slider.value,\n                    ],\n                    cfg_type=\"separated\" if gui_cfg_checkbox.value else \"nocfg\",\n                    postprocess_parameters=postprocess_parameters,\n                    transitions_parameters=transitions_parameters,\n                    real_robot_rotations=gui_real_robot_rotations_checkbox.value,\n                )\n                session.max_frame_idx = int(session.cur_duration * session.model_fps - 1)\n                session.max_frame_idx = int(session.cur_duration * session.model_fps) - 1\n                if 
session.frame_idx > session.max_frame_idx:\n                    session.frame_idx = session.max_frame_idx\n\n                if num_samples > 1:\n                    # add mesh selector to choose character to commit\n                    def commit_motion(event: viser.GuiEvent) -> None:\n                        target = event.target\n                        commit_name = target.name.split(\"/\")[1]  # e.g. /character0/simple_skinned\n                        print(f\"Committing motion for character: {commit_name}\")\n                        # delete non-selected motions\n                        new_motion_kwargs = None\n                        for character_name, motion in session.motions.items():\n                            if character_name == commit_name:\n                                new_motion_kwargs = {\n                                    \"skeleton\": session.skeleton,\n                                    \"joints_rot\": motion.joints_rot,\n                                    \"foot_contacts\": motion.foot_contacts,\n                                }\n                                root_x_offset = motion.joints_pos[0, session.skeleton.root_idx, 0]\n                                new_joints_pos = motion.joints_pos.clone()\n                                new_joints_pos[..., 0] -= root_x_offset\n                                new_motion_kwargs[\"joints_pos\"] = new_joints_pos\n                                break\n                        # clear and re-add the selected motion\n                        demo.clear_motions(event_client.client_id)\n                        demo.add_character_motion(event_client, **new_motion_kwargs)\n                        gui_edit_constraint_button.disabled = False\n                        gui_generate_button.disabled = False\n                        gui_snap_to_constraint_button.disabled = False\n                        client.timeline.enable_constraints()\n                        gui_generate_button.label = \"Generate\"\n 
                       gui_save_example_button.disabled = False\n                        gui_save_motion_button.disabled = False\n                        gui_download_button.disabled = False\n                        gui_save_constraints_button.disabled = False\n                        gui_load_example_button.disabled = False\n\n                    for motion in session.motions.values():\n                        char = motion.character\n                        character_name = char.name  # e.g. \"character0\"\n                        if char.skinned_mesh is not None:\n                            char.skinned_mesh.on_click(commit_motion)\n                        elif char.g1_mesh_rig is not None:\n                            # Register click on every part so any part can be clicked,\n                            # and use highlight_group so the whole robot highlights together.\n                            for handle in char.g1_mesh_rig.mesh_handles:\n                                handle.on_click(commit_motion, highlight_group=character_name)\n\n                    gui_edit_constraint_button.disabled = True\n                    gui_generate_button.disabled = True\n                    gui_snap_to_constraint_button.disabled = True\n                    gui_generate_button.label = \"Choose Sample Before Generating\"\n                    gui_save_example_button.disabled = True\n                    gui_save_motion_button.disabled = True\n                    gui_download_button.disabled = True\n                    gui_save_constraints_button.disabled = True\n                    gui_load_example_button.disabled = True\n                else:\n                    gui_edit_constraint_button.disabled = False\n                    gui_generate_button.disabled = False\n                    gui_snap_to_constraint_button.disabled = False\n                    client.timeline.enable_constraints()\n\n                generating_notif.title = \"Motion generation finished!\"\n               
 generating_notif.body = \"Motions have been generated successfully for the given prompt.\"\n                if num_samples > 1:\n                    generating_notif.body += \" Now choose which sample to commit.\"\n                generating_notif.loading = False\n                generating_notif.with_close_button = True\n                generating_notif.auto_close_seconds = 5.0\n                generating_notif.color = \"green\"\n\n                # put the motion at zero\n                demo.set_frame(client_id, 0)\n\n            except Exception as e:\n                import traceback\n\n                traceback.print_exc()\n                print(f\"Error during generation for client {event_client.client_id}: {e}\")\n                # Re-enable buttons and notify the user\n                if event_client.client_id in demo.client_sessions:\n                    session = demo.client_sessions[event_client.client_id]\n                    gui_generate_button.disabled = False\n                    gui_load_example_button.disabled = False\n                    gui_save_example_button.disabled = False\n                    gui_save_motion_button.disabled = False\n                    gui_download_button.disabled = False\n                    try:\n                        event_client.add_notification(\n                            title=\"Generation failed!\",\n                            body=f\"Error: {str(e)}\",\n                            auto_close_seconds=5.0,\n                            color=\"red\",\n                        )\n                    except Exception:\n                        pass\n                demo.check_cuda_health()\n\n    #\n    # Visualization settings\n    #\n    with tab_group.add_tab(\"Visualize\", viser.Icon.EYE):\n        with client.gui.add_folder(\"Playback\", expand_by_default=True):\n            gui_model_fps = client.gui.add_number(\"Model FPS\", initial_value=model_fps, disabled=True)\n            gui_playback_speed_buttons = 
client.gui.add_button_group(\n                \"Playback Speed\",\n                options=[\n                    \"0.5x\",\n                    \"1x\",\n                    \"2x\",\n                ],\n            )\n            gui_playback_speed_buttons.value = \"1x\"\n\n            @client.timeline.on_frame_change\n            def handle_timeline_frame_change(new_frame_idx: int):\n                \"\"\"Update the frame when the user clicks on the timeline.\"\"\"\n                demo.set_frame(client_id, new_frame_idx, update_timeline=False)\n                session = demo.client_sessions.get(client_id)\n                if session is not None:\n                    if session.edit_mode and session.motions:\n                        motion = list(session.motions.values())[0]\n                        snapshot_frame_idx = min(session.frame_idx, motion.length - 1)\n                        ensure_edit_snapshot(session, motion, snapshot_frame_idx)\n                    update_snap_to_constraint_button(session)\n\n            @client.timeline.on_prompt_add\n            async def _on_add(\n                prompt_id: str,\n                start_frame: int,\n                end_frame: int,\n                text: str,\n                color: tuple[int, int, int] | None,\n            ) -> None:\n                update_duration_auto()\n\n            @client.timeline.on_prompt_update\n            async def _on_update(prompt_id: str, new_text: str) -> None:\n                update_duration_auto()\n\n            @client.timeline.on_prompt_resize\n            async def _on_resize(prompt_id: str, new_start: int, new_end: int) -> None:\n                update_duration_auto()\n\n            @client.timeline.on_prompt_move\n            async def _on_move(prompt_id: str, new_start: int, new_end: int) -> None:\n                update_duration_auto()\n\n            @client.timeline.on_prompt_delete\n            async def _on_delete(prompt_id: str) -> None:\n                
update_duration_auto()\n\n            def play_pause_button_callback(session: ClientSession):\n                session.playing = not session.playing\n\n            def next_frame_callback(session: ClientSession):\n                if session.frame_idx < session.max_frame_idx:\n                    session.frame_idx += 1\n                if session.frame_idx == session.max_frame_idx:\n                    pass\n                demo.set_frame(client_id, session.frame_idx)\n\n            def prev_frame_callback(session: ClientSession):\n                if session.frame_idx > 0:\n                    session.frame_idx -= 1\n                if session.frame_idx == 0:\n                    pass\n                demo.set_frame(client_id, session.frame_idx)\n\n            @gui_playback_speed_buttons.on_click\n            def _(_) -> None:\n                if not demo.client_active(client_id):\n                    return\n                speed_map = {\n                    \"0.5x\": 0.5,\n                    \"1x\": 1.0,\n                    \"2x\": 2.0,\n                }\n                session = demo.client_sessions[client_id]\n                session.playback_speed = speed_map[gui_playback_speed_buttons.value]\n\n        with client.gui.add_folder(\"Body options\", expand_by_default=True):\n            gui_viz_skinned_mesh_checkbox = client.gui.add_checkbox(\"Show Mesh\", initial_value=True)\n            gui_viz_skinned_mesh_opacity_slider = client.gui.add_slider(\n                \"Mesh Opacity\", min=0.0, max=1.0, step=0.01, initial_value=1.0\n            )\n            gui_viz_skeleton_checkbox = client.gui.add_checkbox(\"Show Skeleton\", initial_value=False)\n            gui_viz_foot_contacts_checkbox = client.gui.add_checkbox(\"Show Foot Contacts\", initial_value=False)\n            gui_viz_foot_contacts_checkbox.visible = gui_viz_skeleton_checkbox.value\n        with client.gui.add_folder(\"Camera options\", expand_by_default=True):\n            gui_camera_fov_slider = 
client.gui.add_slider(\n                \"Camera FOV (deg)\",\n                min=30.0,\n                max=90.0,\n                step=1.0,\n                initial_value=45.0,\n            )\n            client.camera.fov = np.deg2rad(gui_camera_fov_slider.value)\n        with client.gui.add_folder(\"Interface options\", expand_by_default=True):\n            gui_show_timeline_checkbox = client.gui.add_checkbox(\n                \"Show Timeline\",\n                initial_value=True,\n            )\n            gui_show_constraint_tracks_checkbox = client.gui.add_checkbox(\n                \"Show Constraint tracks\",\n                initial_value=True,\n            )\n            gui_show_constraint_labels_checkbox = client.gui.add_checkbox(\n                \"Show Constraint labels\",\n                initial_value=True,\n            )\n            gui_show_starting_direction_checkbox = client.gui.add_checkbox(\n                \"Show Starting Direction\",\n                initial_value=True,\n            )\n            gui_dark_mode_checkbox = client.gui.add_checkbox(\n                \"Dark Mode\",\n                initial_value=False,  # Default to light mode\n            )\n            gui_show_constraint_tracks_checkbox.visible = gui_show_timeline_checkbox.value\n            demo.set_start_direction_visible(client_id, gui_show_starting_direction_checkbox.value)\n\n        @gui_dark_mode_checkbox.on_update\n        def _(_):\n            # Apply the theme using configure_theme (pass uuid so titlebar toggle stays)\n            demo.configure_theme(\n                client,\n                gui_dark_mode_checkbox.value,\n                titlebar_dark_mode_checkbox_uuid=gui_dark_mode_checkbox.uuid,\n            )\n            session = demo.client_sessions[client.client_id]\n            for motion in session.motions.values():\n                motion.character.change_theme(gui_dark_mode_checkbox.value)\n\n        # Show dark mode toggle in titlebar (right of 
Github), hide sidebar checkbox\n        demo.configure_theme(\n            client,\n            gui_dark_mode_checkbox.value,\n            titlebar_dark_mode_checkbox_uuid=gui_dark_mode_checkbox.uuid,\n        )\n        gui_dark_mode_checkbox.visible = False\n\n        @gui_show_constraint_labels_checkbox.on_update\n        def _(_):\n            if not demo.client_active(client_id):\n                return\n            session = demo.client_sessions[client_id]\n            for constraint in session.constraints.values():\n                constraint.set_label_visibility(gui_show_constraint_labels_checkbox.value)\n\n        @gui_show_timeline_checkbox.on_update\n        def _(_):\n            if not demo.client_active(client_id):\n                return\n            session = demo.client_sessions[client_id]\n            session.client.timeline.set_visible(gui_show_timeline_checkbox.value)\n            gui_show_constraint_tracks_checkbox.visible = gui_show_timeline_checkbox.value\n            if gui_show_timeline_checkbox.value:\n                demo.set_constraint_tracks_visible(session, gui_show_constraint_tracks_checkbox.value)\n\n        @gui_show_constraint_tracks_checkbox.on_update\n        def _(_):\n            if not demo.client_active(client_id):\n                return\n            session = demo.client_sessions[client_id]\n            demo.set_constraint_tracks_visible(session, gui_show_constraint_tracks_checkbox.value)\n\n        @gui_show_starting_direction_checkbox.on_update\n        def _(_):\n            if not demo.client_active(client_id):\n                return\n            demo.set_start_direction_visible(client_id, gui_show_starting_direction_checkbox.value)\n\n        @gui_viz_skeleton_checkbox.on_update\n        def _(_) -> None:\n            if not demo.client_active(client_id):\n                return\n            session = demo.client_sessions[client_id]\n            gui_viz_foot_contacts_checkbox.visible = 
gui_viz_skeleton_checkbox.value\n            if not gui_viz_skeleton_checkbox.value:\n                gui_viz_foot_contacts_checkbox.value = False\n            for motion in session.motions.values():\n                motion.character.set_skeleton_visibility(gui_viz_skeleton_checkbox.value)\n\n        @gui_viz_foot_contacts_checkbox.on_update\n        def _(_) -> None:\n            if not demo.client_active(client_id):\n                return\n            session = demo.client_sessions[client_id]\n            for motion in session.motions.values():\n                motion.character.set_show_foot_contacts(\n                    gui_viz_foot_contacts_checkbox.value, frame_idx=motion.cur_frame_idx\n                )\n\n        @gui_viz_skinned_mesh_checkbox.on_update\n        def _(_) -> None:\n            if not demo.client_active(client_id):\n                return\n            session = demo.client_sessions[client_id]\n            for motion in session.motions.values():\n                motion.character.set_skinned_mesh_visibility(gui_viz_skinned_mesh_checkbox.value)\n\n        @gui_viz_skinned_mesh_opacity_slider.on_update\n        def _(_) -> None:\n            if not demo.client_active(client_id):\n                return\n            session = demo.client_sessions[client_id]\n            for motion in session.motions.values():\n                motion.character.set_skinned_mesh_opacity(gui_viz_skinned_mesh_opacity_slider.value)\n\n        @gui_camera_fov_slider.on_update\n        def _(_) -> None:\n            if not demo.client_active(client_id):\n                return\n            client.camera.fov = np.deg2rad(gui_camera_fov_slider.value)\n\n            #\n\n    # Instructions tab\n    #\n    with tab_group.add_tab(\"Instructions\", viser.Icon.INFO_CIRCLE):\n        client.gui.add_markdown(DEMO_UI_INSTRUCTIONS_TAB_MD)\n\n    #\n    # Keyboard events\n    #\n    space_pressed = [False]\n\n    @client.scene.on_keyboard_event(\"keydown\", debounce_ms=100)\n    def 
handle_key(event: viser.KeyboardEvent) -> None:\n        # Check if client session still exists\n        if client_id not in demo.client_sessions:\n            return\n\n        session = demo.client_sessions[client_id]\n\n        if event.event_type == \"keyup\":\n            if event.key == \" \":\n                space_pressed[0] = False\n            return\n\n        # Space bar: only toggle on FIRST press\n        if event.key == \" \":\n            if not space_pressed[0]:\n                space_pressed[0] = True\n                play_pause_button_callback(session)\n            return\n\n        # Handle arrow keys: frame navigation (fast OS repeat with 100ms debounce).\n        elif event.key == \"ArrowLeft\":\n            prev_frame_callback(session)\n        elif event.key == \"ArrowRight\":\n            next_frame_callback(session)\n\n    gui_elements = GuiElements(\n        gui_play_pause_button=gui_play_pause_button,\n        gui_next_frame_button=gui_next_frame_button,\n        gui_prev_frame_button=gui_prev_frame_button,\n        gui_generate_button=gui_generate_button,\n        gui_model_fps=gui_model_fps,\n        gui_timeline=gui_timeline,\n        gui_viz_skeleton_checkbox=gui_viz_skeleton_checkbox,\n        gui_viz_foot_contacts_checkbox=gui_viz_foot_contacts_checkbox,\n        gui_viz_skinned_mesh_checkbox=gui_viz_skinned_mesh_checkbox,\n        gui_viz_skinned_mesh_opacity_slider=gui_viz_skinned_mesh_opacity_slider,\n        gui_camera_fov_slider=gui_camera_fov_slider,\n        gui_duration_slider=gui_duration_slider,\n        gui_num_samples_slider=gui_num_samples_slider,\n        gui_cfg_checkbox=gui_cfg_checkbox,\n        gui_cfg_text_weight_slider=gui_cfg_text_weight_slider,\n        gui_cfg_constraint_weight_slider=gui_cfg_constraint_weight_slider,\n        gui_diffusion_steps_slider=gui_diffusion_steps_slider,\n        gui_seed=gui_seed,\n        gui_postprocess_checkbox=gui_postprocess_checkbox,\n        gui_root_margin=gui_root_margin,\n 
       gui_real_robot_rotations_checkbox=gui_real_robot_rotations_checkbox,\n        gui_dark_mode_checkbox=gui_dark_mode_checkbox,\n        gui_use_soma_layer_checkbox=gui_use_soma_layer_checkbox,\n    )\n    return (\n        gui_elements,\n        timeline_tracks,\n        example_dict,\n        gui_examples_dropdown,\n        gui_save_example_path_text,\n        gui_model_selector,\n    )\n"
  },
  {
    "path": "kimodo/exports/__init__.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Export utilities: MuJoCo, BVH, SMPLX/AMASS, and motion I/O helpers.\"\"\"\n\nfrom .bvh import bvh_to_kimodo_motion, motion_to_bvh_bytes, read_bvh_frame_time_seconds, save_motion_bvh\nfrom .motion_convert_lib import convert_motion_files\nfrom .motion_formats import (\n    infer_npz_kind,\n    infer_source_format_from_path,\n    infer_target_format_from_path,\n    resolve_source_fps,\n)\nfrom .motion_io import (\n    KIMODO_CONVERT_TARGET_FPS,\n    amass_npz_to_bytes,\n    complete_motion_dict,\n    g1_csv_to_bytes,\n    kimodo_npz_to_bytes,\n    load_amass_npz,\n    load_g1_csv,\n    load_kimodo_npz,\n    load_kimodo_npz_as_torch,\n    load_motion_file,\n    motion_dict_to_numpy,\n    save_kimodo_npz,\n    save_kimodo_npz_at_target_fps,\n)\nfrom .mujoco import MujocoQposConverter, apply_g1_real_robot_projection\nfrom .smplx import (\n    AMASSConverter,\n    amass_npz_to_kimodo_motion,\n    get_amass_parameters,\n    kimodo_y_up_to_amass_coord_rotation_matrix,\n)\n\n__all__ = [\n    \"AMASSConverter\",\n    \"KIMODO_CONVERT_TARGET_FPS\",\n    \"MujocoQposConverter\",\n    \"amass_npz_to_bytes\",\n    \"amass_npz_to_kimodo_motion\",\n    \"apply_g1_real_robot_projection\",\n    \"bvh_to_kimodo_motion\",\n    \"complete_motion_dict\",\n    \"convert_motion_files\",\n    \"g1_csv_to_bytes\",\n    \"get_amass_parameters\",\n    \"infer_npz_kind\",\n    \"infer_source_format_from_path\",\n    \"infer_target_format_from_path\",\n    \"kimodo_npz_to_bytes\",\n    \"kimodo_y_up_to_amass_coord_rotation_matrix\",\n    \"load_amass_npz\",\n    \"load_g1_csv\",\n    \"load_kimodo_npz\",\n    \"load_kimodo_npz_as_torch\",\n    \"load_motion_file\",\n    \"motion_dict_to_numpy\",\n    \"motion_to_bvh_bytes\",\n    \"read_bvh_frame_time_seconds\",\n    \"resolve_source_fps\",\n    \"save_kimodo_npz\",\n    
\"save_kimodo_npz_at_target_fps\",\n    \"save_motion_bvh\",\n]\n"
  },
  {
    "path": "kimodo/exports/bvh.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Export utilities for converting internal motion representations into common file formats.\n\nThis module is intended to hold lightweight serialization / export helpers that can be reused\noutside of interactive demos.\n\"\"\"\n\nimport os\nimport tempfile\nfrom pathlib import Path\nfrom typing import Tuple, Union\n\nimport numpy as np\nimport torch\n\nfrom kimodo.geometry import matrix_to_quaternion as _matrix_to_quaternion\n\n\ndef _strip_end_site_blocks(bvh_text: str) -> str:\n    \"\"\"Remove all 'End Site { ... }' blocks from BVH text so output matches original format.\n\n    bvhio adds an End Site for every leaf joint when writing; we do not set EndSite on joints, so we\n    post-process the string to remove these blocks for Blender/original compatibility.\n    \"\"\"\n    lines = bvh_text.splitlines(keepends=True)\n    result = []\n    i = 0\n    while i < len(lines):\n        line = lines[i]\n        if \"End Site\" in line:\n            # Skip this line and the following block { ... }; brace-count to find closing }\n            i += 1\n            if i < len(lines) and \"{\" in lines[i]:\n                i += 1\n                depth = 1\n                while i < len(lines) and depth > 0:\n                    if \"{\" in lines[i]:\n                        depth += 1\n                    if \"}\" in lines[i]:\n                        depth -= 1\n                    i += 1\n            continue\n        result.append(line)\n        i += 1\n    return \"\".join(result)\n\n\ndef _coerce_batch(name: str, x: torch.Tensor, *, expected_ndim: int) -> torch.Tensor:\n    \"\"\"Coerce (T, ...) or (1, T, ...) 
into (T, ...).\"\"\"\n    if x.ndim == expected_ndim:\n        return x\n    if x.ndim == expected_ndim + 1:\n        if int(x.shape[0]) != 1:\n            raise ValueError(\n                f\"{name} has batch dimension B={int(x.shape[0])}, but BVH export \" \"only supports a single clip (B==1).\"\n            )\n        return x[0]\n    raise ValueError(f\"{name} must have shape (T, ...) or (1, T, ...); got {tuple(x.shape)}\")\n\n\ndef motion_to_bvh(\n    local_rot_mats: torch.Tensor,\n    root_positions: torch.Tensor,\n    *,\n    skeleton,\n    fps: float,\n    standard_tpose: bool = False,\n) -> str:\n    \"\"\"Convert local rotations and root positions to BVH format; return UTF-8 string.\n\n    Args:\n        local_rot_mats: (T, J, 3, 3) or (1, T, J, 3, 3) local rotation matrices.\n        root_positions: (T, 3) or (1, T, 3) root joint positions (e.g. from posed joints).\n        skeleton: Skeleton with bone_order_names, bvh_neutral_joints, etc.\n        fps: Frames per second for the motion.\n        standard_tpose: If True, export with the rest pose being the standard T-pose rather than the rest pose consistent with the BONES-SEED dataset.\n    Notes:\n        BVH is plain-text. Root is named \"Root\" with ZYX rotation order; leaf joints\n        have no End Site block.\n    \"\"\"\n    try:\n        import bvhio  # type: ignore[import-not-found]\n        import glm  # type: ignore[import-not-found]\n        from SpatialTransform import Pose  # type: ignore[import-not-found]\n    except Exception as e:  # pragma: no cover\n        raise ImportError(\n            \"BVH export requires `bvhio` (and its deps `PyGLM` + `SpatialTransform`). 
\"\n            \"Install with: `pip install bvhio`.\"\n        ) from e\n\n    local_rot_mats = local_rot_mats.detach()\n    root_positions = root_positions.detach()\n    # SOMA: accept either somaskel30 (convert to 77) or somaskel77 (use as-is)\n    if skeleton.name == \"somaskel30\":\n        local_rot_mats = skeleton.to_SOMASkeleton77(local_rot_mats)\n        skeleton = skeleton.somaskel77\n\n    if standard_tpose:\n        neutral = skeleton.neutral_joints.detach().cpu().numpy()\n    else:\n        # transform local rots to the original rest pose consistent with the BONES-SEED dataset\n        local_rot_mats, _ = skeleton.from_standard_tpose(local_rot_mats)\n        neutral = skeleton.bvh_neutral_joints.detach().cpu().numpy()\n\n    joint_names = list(skeleton.bone_order_names)\n    parents = skeleton.joint_parents.detach().cpu().numpy().astype(int)\n    root_idx = int(skeleton.root_idx)\n\n    local_rot_mats = _coerce_batch(\"local_rot_mats\", local_rot_mats, expected_ndim=4)\n    T, J = local_rot_mats.shape[:2]\n    q_wxyz = _matrix_to_quaternion(local_rot_mats).detach().cpu().numpy()  # [T, J, 4]\n\n    root_xyz = _coerce_batch(\"root_positions\", root_positions, expected_ndim=2)\n    root_xyz = root_xyz.cpu().numpy()  # [T, 3]\n\n    # Build BVH hierarchy: Root (wrapper at origin) -> Hips (pelvis with offset in meters) -> ...\n    # Offsets are in meters to match the original format.\n    children: dict[int, list[int]] = {i: [] for i in range(J)}\n    for i, p in enumerate(parents):\n        if p >= 0:\n            children[int(p)].append(int(i))\n\n    _ROOT_CHANNELS = [\n        \"Xposition\",\n        \"Yposition\",\n        \"Zposition\",\n        \"Zrotation\",\n        \"Yrotation\",\n        \"Xrotation\",\n    ]\n    _JOINT_CHANNELS = [\"Zrotation\", \"Yrotation\", \"Xrotation\"]\n\n    # Scale from meters to centimeters (match original SEED data BVH scale).\n    neutral = neutral * 100\n    root_xyz = root_xyz * 100\n\n    # Hips offset from Root: 
use skeleton neutral; if root is at origin (zeros), use a\n    # nominal pelvis height so the hierarchy is non-degenerate in Blender.\n    hips_offset = neutral[root_idx]\n    if (hips_offset == 0).all():\n        hips_offset = np.array([0.0, 100.0, 0.0], dtype=neutral.dtype)  # 1 m in cm\n\n    def _make_joint(i: int) -> \"bvhio.BvhJoint\":\n        name = joint_names[i]\n        j = bvhio.BvhJoint(name, offset=glm.vec3(0, 0, 0))\n        if i == root_idx:\n            # Hips: offset from Root (origin) in cm\n            off = hips_offset\n            j.Offset = glm.vec3(float(off[0]), float(off[1]), float(off[2]))\n            j.Channels = _ROOT_CHANNELS.copy()\n        else:\n            p = int(parents[i])\n            off = neutral[i] - neutral[p]\n            j.Offset = glm.vec3(float(off[0]), float(off[1]), float(off[2]))\n            j.Channels = _JOINT_CHANNELS.copy()\n\n        for c in children[i]:\n            j.Children.append(_make_joint(c))\n        return j\n\n    # Wrapper Root at origin; single child is Hips (skeleton root).\n    root_wrapper = bvhio.BvhJoint(\"Root\", offset=glm.vec3(0.0, 0.0, 0.0))\n    root_wrapper.Channels = _ROOT_CHANNELS.copy()\n    root_wrapper.Children.append(_make_joint(root_idx))\n    root_joint = root_wrapper\n\n    # Populate keyframes: Root = identity/zero, Hips = root motion, others = local rotation.\n    bvh_layout = root_joint.layout()\n    name_to_id = {n: idx for idx, n in enumerate(joint_names)}\n    ordered_joint_ids = []\n    for bj, _, _ in bvh_layout:\n        if bj.Name == \"Root\":\n            ordered_joint_ids.append(None)\n        else:\n            ordered_joint_ids.append(name_to_id[bj.Name])\n\n    bvh_joints = [bj for bj, _, _ in bvh_layout]\n    for bj in bvh_joints:\n        bj.Keyframes = [None] * T  # type: ignore[list-item]\n\n    identity_quat = glm.quat(1.0, 0.0, 0.0, 0.0)\n    zero_vec = glm.vec3(0.0, 0.0, 0.0)\n    for t in range(T):\n        for bj, jid in zip(bvh_joints, 
ordered_joint_ids):\n            if jid is None:\n                position = zero_vec\n                rotation = identity_quat\n            elif jid == root_idx:\n                pos = root_xyz[t]\n                position = glm.vec3(float(pos[0]), float(pos[1]), float(pos[2]))\n                qw, qx, qy, qz = q_wxyz[t, jid]\n                rotation = glm.quat(float(qw), float(qx), float(qy), float(qz))\n            else:\n                position = zero_vec\n                qw, qx, qy, qz = q_wxyz[t, jid]\n                rotation = glm.quat(float(qw), float(qx), float(qy), float(qz))\n            bj.Keyframes[t] = Pose(position, rotation)  # type: ignore[index]\n\n    container = bvhio.BvhContainer(root_joint, frameCount=T, frameTime=1.0 / float(fps))\n    with tempfile.NamedTemporaryFile(mode=\"w\", suffix=\".bvh\", delete=False, encoding=\"utf-8\") as f:\n        tmp_path = f.name\n    try:\n        bvhio.writeBvh(tmp_path, container, percision=6)\n        bvh_text = Path(tmp_path).read_text(encoding=\"utf-8\")\n        return _strip_end_site_blocks(bvh_text)\n    finally:\n        try:\n            os.remove(tmp_path)\n        except Exception:\n            pass\n\n\ndef motion_to_bvh_bytes(\n    local_rot_mats: torch.Tensor,\n    root_positions: torch.Tensor,\n    *,\n    skeleton,\n    fps: float,\n    standard_tpose: bool = False,\n) -> bytes:\n    \"\"\"Convert local rotations and root positions to BVH bytes (UTF-8).\n\n    Convenience wrapper around :func:`motion_to_bvh`.\n    \"\"\"\n    return motion_to_bvh(\n        local_rot_mats,\n        root_positions,\n        skeleton=skeleton,\n        fps=fps,\n        standard_tpose=standard_tpose,\n    ).encode(\"utf-8\")\n\n\ndef save_motion_bvh(\n    path: Union[str, Path],\n    local_rot_mats: torch.Tensor,\n    root_positions: torch.Tensor,\n    *,\n    skeleton,\n    fps: float,\n    standard_tpose: bool = False,\n) -> None:\n    \"\"\"Write local rotations and root positions to a BVH file at the 
given path.\"\"\"\n    Path(path).write_text(\n        motion_to_bvh(local_rot_mats, root_positions, skeleton=skeleton, fps=fps, standard_tpose=standard_tpose),\n        encoding=\"utf-8\",\n    )\n\n\ndef read_bvh_frame_time_seconds(path: Union[str, Path]) -> float:\n    \"\"\"Read ``Frame Time`` from a BVH file (seconds per frame).\"\"\"\n    with open(path, encoding=\"utf-8\") as f:\n        for line in f:\n            if \"Frame Time:\" in line:\n                parts = line.split()\n                return float(parts[-1])\n    raise ValueError(f\"Could not find 'Frame Time:' in {path}\")\n\n\ndef bvh_to_kimodo_motion(\n    path: Union[str, Path],\n    skeleton=None,\n    *,\n    standard_tpose: bool = False,\n) -> Tuple:\n    \"\"\"Load a Kimodo-style SOMA BVH into a Kimodo motion dict.\n\n    Expects the same hierarchy as :func:`save_motion_bvh` (``Root`` wrapper + SOMA77 joints).\n    The frame rate is always read from the BVH ``Frame Time`` header.  Callers\n    that need a different playback rate should resample the returned motion dict\n    (see :func:`~kimodo.exports.motion_io.resample_motion_dict_to_kimodo_fps`).\n\n    Returns:\n        ``(motion_dict, source_fps)`` where ``source_fps`` is the native BVH\n        frame rate read from the file header.\n    \"\"\"\n    from kimodo.exports.motion_io import complete_motion_dict\n    from kimodo.skeleton.bvh import parse_bvh_motion\n    from kimodo.skeleton.registry import build_skeleton\n\n    if skeleton is None:\n        skeleton = build_skeleton(77)\n    device = skeleton.neutral_joints.device\n\n    local_rot_mats, root_trans, bvh_fps = parse_bvh_motion(str(path))\n    local_rot_mats = local_rot_mats.to(device=device)\n    root_trans = root_trans.to(device=device)\n\n    if int(local_rot_mats.shape[1]) != int(skeleton.nbjoints):\n        raise ValueError(\n            f\"BVH has {local_rot_mats.shape[1]} joints but skeleton has {skeleton.nbjoints}; \"\n            \"use a Kimodo-exported SOMA BVH or 
matching skeleton.\"\n        )\n    if not standard_tpose:\n        local_rot_mats, _ = skeleton.to_standard_tpose(local_rot_mats)\n\n    return complete_motion_dict(local_rot_mats, root_trans, skeleton, float(bvh_fps)), bvh_fps\n"
  },
  {
    "path": "kimodo/exports/motion_convert_lib.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Library API for converting between Kimodo NPZ, AMASS NPZ, SOMA BVH, and G1 MuJoCo CSV.\"\"\"\n\nfrom __future__ import annotations\n\nimport warnings\n\nimport numpy as np\n\nfrom kimodo.exports.bvh import bvh_to_kimodo_motion, save_motion_bvh\nfrom kimodo.exports.motion_formats import (\n    infer_source_format_from_path,\n    infer_target_format_from_path,\n    resolve_source_fps,\n)\nfrom kimodo.exports.motion_io import (\n    load_amass_npz,\n    load_g1_csv,\n    load_kimodo_npz_as_torch,\n    save_kimodo_npz_at_target_fps,\n)\nfrom kimodo.exports.mujoco import MujocoQposConverter\nfrom kimodo.exports.smplx import AMASSConverter\nfrom kimodo.skeleton.registry import build_skeleton\n\n\ndef convert_motion_files(\n    input_path: str,\n    output_path: str,\n    *,\n    from_fmt: str | None = None,\n    to_fmt: str | None = None,\n    source_fps: float | None = None,\n    z_up: bool = True,\n    mujoco_rest_zero: bool = False,\n    bvh_standard_tpose: bool = False,\n) -> None:\n    \"\"\"Convert a motion file between Kimodo-supported formats.\n\n    Supported pairs (hub-and-spoke through Kimodo NPZ):\n\n    - amass <-> kimodo\n    - soma-bvh <-> kimodo\n    - g1-csv <-> kimodo\n\n    Args:\n        input_path: Source file (``.npz``, ``.bvh``, or ``.csv``).\n        output_path: Destination file.\n        from_fmt: Source format; inferred from extension/contents when ``None``.\n        to_fmt: Target format; inferred from extension when ``None``.\n        source_fps: Source motion frame rate (Hz).  
If provided, trusted as-is.\n            If ``None``, auto-detected from BVH ``Frame Time``, AMASS\n            ``mocap_frame_rate``, or default 30.\n        z_up: For AMASS conversions, apply the Z-up <-> Kimodo Y-up transform.\n        mujoco_rest_zero: For G1 CSV, joint angles relative to MuJoCo rest pose.\n        bvh_standard_tpose: If input or output is BVH: the BVH file uses the standard T-pose \n            as its rest pose instead of the BONES-SEED rest pose.\n    \"\"\"\n    from_fmt = from_fmt or infer_source_format_from_path(input_path)\n    to_fmt = to_fmt or infer_target_format_from_path(output_path, from_fmt)\n\n    _validate_output_extension(to_fmt, output_path)\n\n    pair = (from_fmt, to_fmt)\n\n    if pair == (\"amass\", \"kimodo\"):\n        sk = build_skeleton(22)\n        effective_source = source_fps\n        if effective_source is None:\n            with np.load(input_path, allow_pickle=True) as z:\n                effective_source = float(z[\"mocap_frame_rate\"]) if \"mocap_frame_rate\" in z.files else 30.0\n        motion = load_amass_npz(input_path, source_fps=effective_source, z_up=z_up)\n        save_kimodo_npz_at_target_fps(motion, sk, effective_source, output_path)\n        return\n\n    if pair == (\"kimodo\", \"amass\"):\n        data, J = load_kimodo_npz_as_torch(input_path, ensure_complete=False)\n        if J != 22:\n            raise ValueError(f\"Kimodo→AMASS requires 22 joints (SMPL-X); this file has J={J}.\")\n        sk = build_skeleton(22)\n        effective_source = resolve_source_fps(source_fps, \"kimodo\", input_path, None)\n        converter = AMASSConverter(fps=effective_source, skeleton=sk)\n        converter.convert_save_npz(data, output_path, z_up=z_up)\n        return\n\n    if pair == (\"soma-bvh\", \"kimodo\"):\n        sk = build_skeleton(77)\n        motion, bvh_fps = bvh_to_kimodo_motion(input_path, skeleton=sk, standard_tpose=bvh_standard_tpose)\n        effective_source = source_fps if source_fps is not None 
else bvh_fps\n        save_kimodo_npz_at_target_fps(motion, sk, effective_source, output_path)\n        return\n\n    if pair == (\"kimodo\", \"soma-bvh\"):\n        data, J = load_kimodo_npz_as_torch(input_path, ensure_complete=False)\n        if J == 30:\n            warnings.warn(\n                f\"Input has 30 joints (somaskel30); expanding to somaskel77 for BVH export.\",\n                UserWarning,\n                stacklevel=2,\n            )\n            sk = build_skeleton(30)\n        elif J == 77:\n            sk = build_skeleton(77)\n        else:\n            raise ValueError(f\"Kimodo→BVH requires a SOMA skeleton (30 or 77 joints); this file has J={J}.\")\n        effective_source = resolve_source_fps(source_fps, \"kimodo\", input_path, None)\n        save_motion_bvh(\n            output_path,\n            data[\"local_rot_mats\"],\n            data[\"root_positions\"],\n            skeleton=sk,\n            fps=effective_source,\n            standard_tpose=bvh_standard_tpose,\n        )\n        return\n\n    if pair == (\"g1-csv\", \"kimodo\"):\n        sk = build_skeleton(34)\n        effective_source = resolve_source_fps(source_fps, \"g1-csv\", input_path, None)\n        motion = load_g1_csv(input_path, source_fps=effective_source, mujoco_rest_zero=mujoco_rest_zero)\n        save_kimodo_npz_at_target_fps(motion, sk, effective_source, output_path)\n        return\n\n    if pair == (\"kimodo\", \"g1-csv\"):\n        data, J = load_kimodo_npz_as_torch(input_path, ensure_complete=False)\n        if J != 34:\n            raise ValueError(f\"Kimodo→CSV requires G1 with 34 joints; this file has J={J}.\")\n        sk = build_skeleton(34)\n        effective_source = resolve_source_fps(source_fps, \"kimodo\", input_path, None)\n        converter = MujocoQposConverter(sk)\n        qpos = converter.dict_to_qpos(\n            {k: v for k, v in data.items() if k in (\"local_rot_mats\", \"root_positions\")},\n            
device=str(sk.neutral_joints.device),\n            numpy=True,\n            mujoco_rest_zero=mujoco_rest_zero,\n        )\n        converter.save_csv(qpos, output_path)\n        return\n\n    raise ValueError(\n        f\"Unsupported conversion {from_fmt!r} → {to_fmt!r}. \"\n        \"Supported: amass↔kimodo (SMPL-X NPZ), soma-bvh↔kimodo, g1-csv↔kimodo.\"\n    )\n\n\ndef _validate_output_extension(to_fmt: str, output_path: str) -> None:\n    lower = output_path.lower()\n    if to_fmt == \"kimodo\" and lower.endswith(\".npz\"):\n        return\n    if to_fmt == \"amass\":\n        if not lower.endswith(\".npz\"):\n            raise ValueError(\"AMASS output must use a .npz path.\")\n    elif to_fmt == \"soma-bvh\":\n        if not lower.endswith(\".bvh\"):\n            raise ValueError(\"SOMA BVH output must use a .bvh path.\")\n    elif to_fmt == \"g1-csv\":\n        if not lower.endswith(\".csv\"):\n            raise ValueError(\"G1 CSV output must use a .csv path.\")\n"
  },
  {
    "path": "kimodo/exports/motion_formats.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Infer motion file formats from paths and NPZ contents.\"\"\"\n\nfrom __future__ import annotations\n\nimport os\nfrom typing import Literal\n\nimport numpy as np\n\nMotionSourceFormat = Literal[\"amass\", \"kimodo\", \"soma-bvh\", \"g1-csv\"]\nMotionTargetFormat = Literal[\"amass\", \"kimodo\", \"soma-bvh\", \"g1-csv\"]\nNpzMotionKind = Literal[\"amass\", \"kimodo\"]\n\n\ndef infer_npz_kind(path: str) -> NpzMotionKind:\n    \"\"\"Classify a ``.npz`` as AMASS SMPL-X or Kimodo from required array keys.\"\"\"\n    with np.load(path, allow_pickle=False) as z:\n        keys = set(z.files)\n    if \"trans\" in keys and \"pose_body\" in keys and \"root_orient\" in keys:\n        return \"amass\"\n    if \"local_rot_mats\" in keys or \"posed_joints\" in keys:\n        return \"kimodo\"\n    raise ValueError(\n        f\"Unrecognized NPZ {path!r}: expected AMASS keys (trans, pose_body, ...) 
\"\n        \"or Kimodo keys (local_rot_mats, posed_joints, ...).\"\n    )\n\n\ndef infer_source_format_from_path(path: str) -> MotionSourceFormat:\n    \"\"\"Infer converter input format from file extension and NPZ contents when needed.\"\"\"\n    ext = os.path.splitext(path)[1].lower()\n    if ext == \".bvh\":\n        return \"soma-bvh\"\n    if ext == \".csv\":\n        return \"g1-csv\"\n    if ext == \".npz\":\n        return infer_npz_kind(path)  # type: ignore[return-value]\n    raise ValueError(f\"Cannot infer format from extension of {path!r}\")\n\n\ndef infer_target_format_from_path(path: str, from_fmt: MotionSourceFormat) -> MotionTargetFormat:\n    \"\"\"Infer converter output format from destination path and source format.\"\"\"\n    ext = os.path.splitext(path)[1].lower()\n    if ext == \".bvh\":\n        return \"soma-bvh\"\n    if ext == \".csv\":\n        return \"g1-csv\"\n    if ext == \".npz\":\n        if from_fmt == \"amass\":\n            return \"kimodo\"\n        if from_fmt == \"kimodo\":\n            return \"amass\"\n        if from_fmt in (\"g1-csv\", \"soma-bvh\"):\n            return \"kimodo\"\n        raise ValueError(\n            \"Ambiguous .npz output: set --to to 'kimodo' or 'amass' when the input format is not amass/kimodo.\"\n        )\n    raise ValueError(f\"Cannot infer output format from extension of {path!r}\")\n\n\ndef resolve_source_fps(\n    fps: float | None,\n    from_kind: str,\n    input_path: str,\n    data: dict | None,\n) -> float:\n    \"\"\"Resolve source frame rate (Hz) for conversion when ``fps`` is not overridden.\"\"\"\n    if fps is not None:\n        return float(fps)\n    if data is not None and \"mocap_frame_rate\" in data:\n        return float(np.asarray(data[\"mocap_frame_rate\"]).item())\n    if from_kind == \"soma-bvh\":\n        from kimodo.exports.bvh import read_bvh_frame_time_seconds\n\n        return 1.0 / read_bvh_frame_time_seconds(input_path)\n    return 30.0\n"
  },
  {
    "path": "kimodo/exports/motion_io.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Assemble Kimodo NPZ-compatible motion dicts from local rotations + root trajectory.\"\"\"\n\nfrom __future__ import annotations\n\nimport os\nimport warnings\nfrom typing import Any, Dict, Tuple\n\nimport numpy as np\nimport torch\n\nfrom kimodo.geometry import matrix_to_quaternion, quaternion_to_matrix\nfrom kimodo.motion_rep.feature_utils import compute_heading_angle, compute_vel_xyz\nfrom kimodo.motion_rep.feet import foot_detect_from_pos_and_vel\nfrom kimodo.motion_rep.smooth_root import get_smooth_root_pos\nfrom kimodo.skeleton import SkeletonBase\nfrom kimodo.skeleton.registry import build_skeleton\nfrom kimodo.tools import to_numpy\n\n# Default motion rate for Kimodo NPZ produced by format conversion (matches common model FPS).\nKIMODO_CONVERT_TARGET_FPS = 30.0\n\n\ndef _quaternion_slerp(q0: torch.Tensor, q1: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n    \"\"\"Spherical linear interpolation; ``q0``, ``q1`` (..., 4) wxyz; ``t`` broadcastable to (...,\n    1).\"\"\"\n    if t.dim() < q0.dim():\n        t = t.unsqueeze(-1)\n    dot = (q0 * q1).sum(dim=-1, keepdim=True)\n    q1 = torch.where(dot < 0, -q1, q1)\n    dot = torch.abs(dot).clamp(-1.0, 1.0)\n    theta_0 = torch.acos(dot)\n    sin_theta = torch.sin(theta_0)\n    s0 = torch.sin((1.0 - t) * theta_0) / sin_theta.clamp(min=1e-8)\n    s1 = torch.sin(t * theta_0) / sin_theta.clamp(min=1e-8)\n    q = s0 * q0 + s1 * q1\n    return q / torch.linalg.norm(q, dim=-1, keepdim=True).clamp(min=1e-8)\n\n\ndef resample_motion_dict_to_kimodo_fps(\n    motion_dict: Dict[str, torch.Tensor],\n    skeleton: SkeletonBase,\n    source_fps: float,\n    target_fps: float = KIMODO_CONVERT_TARGET_FPS,\n) -> Tuple[Dict[str, torch.Tensor], bool]:\n    \"\"\"Resample a Kimodo motion dict to ``target_fps``.\n\n    When the fps ratio is close to an integer (e.g. 
120 / 30 = 4), the faster\n    stepping method is used (take every *step*-th frame).  Otherwise falls back\n    to linear interp (root) + quaternion slerp (joints).\n\n    Re-runs :func:`complete_motion_dict` at the target rate so derived channels stay consistent.\n\n    Returns:\n        The motion dict and ``True`` if time resampling was applied, else ``False`` (already at\n        ``target_fps`` with matching frame count; only re-derived via FK).\n    \"\"\"\n    local_rot_mats = motion_dict[\"local_rot_mats\"]\n    root_positions = motion_dict[\"root_positions\"]\n    local_rot_mats, root_positions = _coerce_time_local_root(local_rot_mats, root_positions)\n    t_in = int(local_rot_mats.shape[0])\n    if t_in < 1:\n        raise ValueError(\"Motion must have at least one frame.\")\n    if source_fps <= 0:\n        raise ValueError(f\"source_fps must be positive; got {source_fps}\")\n\n    t_out = max(1, int(round(t_in * target_fps / source_fps)))\n    if t_out == t_in and abs(float(source_fps) - float(target_fps)) < 1e-3:\n        return complete_motion_dict(local_rot_mats, root_positions, skeleton, float(target_fps)), False\n\n    ratio = source_fps / target_fps\n    step = round(ratio)\n    if step >= 2 and abs(ratio - step) < 0.05:\n        local_out = local_rot_mats[::step]\n        root_out = root_positions[::step]\n    else:\n        device = local_rot_mats.device\n        dtype = local_rot_mats.dtype\n        u = torch.linspace(0, t_in - 1, t_out, device=device, dtype=dtype)\n        i0 = u.floor().long().clamp(0, t_in - 1)\n        i1 = torch.minimum(i0 + 1, torch.tensor(t_in - 1, device=device))\n        tau_1d = (u - i0.float()).unsqueeze(-1)\n        rp0 = root_positions[i0]\n        rp1 = root_positions[i1]\n        root_out = (1.0 - tau_1d) * rp0 + tau_1d * rp1\n\n        quats = matrix_to_quaternion(local_rot_mats)\n        q0 = quats[i0]\n        q1 = quats[i1]\n        tau_q = (u - i0.float()).view(t_out, 1, 1)\n        quat_out = 
_quaternion_slerp(q0, q1, tau_q)\n        local_out = quaternion_to_matrix(quat_out)\n\n    return complete_motion_dict(local_out, root_out, skeleton, float(target_fps)), True\n\n\ndef warn_kimodo_npz_framerate(source_fps: float, t_before: int, t_after: int) -> None:\n    \"\"\"Emit a warning after time resampling for Kimodo NPZ (linear root, quaternion slerp per\n    joint).\"\"\"\n    warnings.warn(\n        f\"Resampled motion to {KIMODO_CONVERT_TARGET_FPS:.0f} Hz for Kimodo NPZ \"\n        f\"(source ~{source_fps:.4g} Hz, {t_before} input frames → {t_after} output frames). \"\n        \"Pass --source-fps if the detected source rate is wrong.\",\n        UserWarning,\n        stacklevel=3,\n    )\n\n\ndef _coerce_time_local_root(\n    local_rot_mats: torch.Tensor,\n    root_positions: torch.Tensor,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"Normalize to shapes (T, J, 3, 3) and (T, 3).\"\"\"\n    if local_rot_mats.dim() == 5:\n        if int(local_rot_mats.shape[0]) != 1:\n            raise ValueError(f\"local_rot_mats batch size must be 1 for single clip; got {local_rot_mats.shape[0]}\")\n        local_rot_mats = local_rot_mats[0]\n    if root_positions.dim() == 3:\n        if int(root_positions.shape[0]) != 1:\n            raise ValueError(f\"root_positions batch size must be 1; got {root_positions.shape[0]}\")\n        root_positions = root_positions[0]\n    if local_rot_mats.dim() != 4:\n        raise ValueError(f\"local_rot_mats must be (T,J,3,3); got {tuple(local_rot_mats.shape)}\")\n    if root_positions.dim() != 2 or int(root_positions.shape[-1]) != 3:\n        raise ValueError(f\"root_positions must be (T,3); got {tuple(root_positions.shape)}\")\n    if int(local_rot_mats.shape[0]) != int(root_positions.shape[0]):\n        raise ValueError(\"local_rot_mats and root_positions must have the same number of frames\")\n    return local_rot_mats, root_positions\n\n\ndef complete_motion_dict(\n    local_rot_mats: torch.Tensor,\n    root_positions: 
torch.Tensor,\n    skeleton: SkeletonBase,\n    fps: float,\n) -> Dict[str, torch.Tensor]:\n    \"\"\"Build the Kimodo motion output dict from local rotations and root positions.\n\n    Matches keys written by CLI generation (see docs/source/user_guide/output_formats.md).\n\n    Args:\n        local_rot_mats: (T, J, 3, 3) or (1, T, J, 3, 3) local rotation matrices.\n        root_positions: (T, 3) or (1, T, 3) root / pelvis world positions (meters).\n        skeleton: Skeleton instance (SOMA77, G1, SMPL-X, etc.).\n        fps: Sampling rate (Hz).\n\n    Returns:\n        Dict with tensors ``posed_joints``, ``global_rot_mats``, ``local_rot_mats``,\n        ``foot_contacts``, ``smooth_root_pos``, ``root_positions``, ``global_root_heading``.\n    \"\"\"\n    device = local_rot_mats.device\n    dtype = local_rot_mats.dtype\n    local_rot_mats, root_positions = _coerce_time_local_root(\n        local_rot_mats.to(device=device, dtype=dtype),\n        root_positions.to(device=device, dtype=dtype),\n    )\n\n    global_rot_mats, posed_joints, _ = skeleton.fk(local_rot_mats, root_positions)\n\n    smooth_root_pos = get_smooth_root_pos(root_positions.unsqueeze(0)).squeeze(0)\n\n    lengths = torch.tensor([posed_joints.shape[0]], device=device)\n    velocities = compute_vel_xyz(posed_joints.unsqueeze(0), fps, lengths=lengths).squeeze(0)\n\n    heading_angle = compute_heading_angle(posed_joints.unsqueeze(0), skeleton).squeeze(0)\n    global_root_heading = torch.stack([torch.cos(heading_angle), torch.sin(heading_angle)], dim=-1)\n\n    foot_contacts = foot_detect_from_pos_and_vel(\n        posed_joints.unsqueeze(0),\n        velocities.unsqueeze(0),\n        skeleton,\n        0.15,\n        0.10,\n    ).squeeze(0)\n\n    return {\n        \"posed_joints\": posed_joints,\n        \"global_rot_mats\": global_rot_mats,\n        \"local_rot_mats\": local_rot_mats,\n        \"foot_contacts\": foot_contacts,\n        \"smooth_root_pos\": smooth_root_pos,\n        \"root_positions\": 
root_positions,\n        \"global_root_heading\": global_root_heading,\n    }\n\n\ndef motion_dict_to_numpy(d: Dict[str, Any]) -> Dict[str, np.ndarray]:\n    \"\"\"Convert motion dict values to numpy arrays for ``np.savez``.\"\"\"\n    out: Dict[str, np.ndarray] = {}\n    for k, v in d.items():\n        if hasattr(v, \"detach\"):\n            out[k] = to_numpy(v)\n        elif isinstance(v, np.ndarray):\n            out[k] = v\n        else:\n            out[k] = np.asarray(v)\n    return out\n\n\ndef save_kimodo_npz(path: str, motion_dict: Dict[str, Any]) -> None:\n    \"\"\"Save a Kimodo-compatible motion dict to ``.npz`` (numpy arrays).\"\"\"\n    np.savez(path, **motion_dict_to_numpy(motion_dict))\n\n\ndef load_kimodo_npz(path: str) -> Dict[str, np.ndarray]:\n    \"\"\"Load arrays from a Kimodo ``.npz`` file.\"\"\"\n    with np.load(path, allow_pickle=False) as data:\n        return {k: np.asarray(data[k]) for k in data.files}\n\n\ndef load_g1_csv(\n    path: str,\n    source_fps: float = KIMODO_CONVERT_TARGET_FPS,\n    *,\n    mujoco_rest_zero: bool = False,\n) -> Dict[str, torch.Tensor]:\n    \"\"\"Load a G1 MuJoCo ``qpos`` CSV (``(T, 36)``) into a Kimodo motion dict.\n\n    Args:\n        path: CSV path (comma-separated, no header).\n        source_fps: Source frame rate (Hz) of the CSV data.\n        mujoco_rest_zero: Must match how the CSV was written (see :class:`MujocoQposConverter`).\n    \"\"\"\n    from kimodo.exports.mujoco import MujocoQposConverter\n\n    qpos = np.loadtxt(path, delimiter=\",\")\n    if qpos.ndim != 2 or qpos.shape[-1] != 36:\n        raise ValueError(f\"Expected G1 CSV with shape (T, 36); got {qpos.shape}\")\n    sk = build_skeleton(34)\n    converter = MujocoQposConverter(sk)\n    return converter.qpos_to_motion_dict(qpos, float(source_fps), mujoco_rest_zero=mujoco_rest_zero)\n\n\ndef load_amass_npz(\n    path: str,\n    source_fps: float | None = None,\n    *,\n    z_up: bool = True,\n) -> Dict[str, torch.Tensor]:\n    
\"\"\"Load an AMASS-style SMPL-X ``.npz`` into a Kimodo motion dict (22 joints).\n\n    Args:\n        path: NPZ with ``trans``, ``root_orient``, ``pose_body``, etc.\n        source_fps: Source frame rate (Hz); if ``None``, uses ``mocap_frame_rate``\n            from the file when present, else 30 Hz.\n        z_up: If ``True``, apply AMASS Z-up to Kimodo Y-up transform (same as CLI).\n    \"\"\"\n    from kimodo.exports.smplx import amass_npz_to_kimodo_motion\n\n    sk = build_skeleton(22)\n    return amass_npz_to_kimodo_motion(path, sk, source_fps=source_fps, z_up=z_up)\n\n\ndef load_kimodo_npz_as_torch(\n    path: str,\n    source_fps: float = KIMODO_CONVERT_TARGET_FPS,\n    *,\n    ensure_complete: bool = True,\n) -> tuple[Dict[str, torch.Tensor], int]:\n    \"\"\"Load a Kimodo NPZ and return all arrays as torch tensors on the skeleton device.\n\n    Args:\n        path: Kimodo NPZ file path.\n        source_fps: Source frame rate (Hz) used for derived channels when\n            ``ensure_complete=True``.\n        ensure_complete: If ``True`` and the NPZ lacks derived channels\n            (``posed_joints``, ``global_rot_mats``, …), run :func:`complete_motion_dict`\n            to fill them from ``local_rot_mats`` + ``root_positions``.\n            If ``False``, load all arrays verbatim (requires ``local_rot_mats``).\n\n    Returns:\n        ``(tensor_dict, num_joints)``\n    \"\"\"\n    raw = load_kimodo_npz(path)\n    if \"local_rot_mats\" in raw:\n        j = int(raw[\"local_rot_mats\"].shape[1])\n    elif \"posed_joints\" in raw:\n        j = int(raw[\"posed_joints\"].shape[1])\n    else:\n        raise ValueError(\"Kimodo NPZ must contain 'local_rot_mats' or 'posed_joints'.\")\n    sk = build_skeleton(j)\n    device = sk.neutral_joints.device\n    dtype = torch.float32\n\n    if not ensure_complete:\n        if \"local_rot_mats\" not in raw:\n            raise ValueError(\"Kimodo NPZ must contain 'local_rot_mats' (and typically 'root_positions').\")\n       
 out: Dict[str, torch.Tensor] = {}\n        for k, v in raw.items():\n            out[k] = torch.from_numpy(np.asarray(v)).to(device=device, dtype=dtype)\n        return out, j\n\n    if \"posed_joints\" in raw and \"global_rot_mats\" in raw:\n        out = {}\n        for k, v in raw.items():\n            out[k] = torch.from_numpy(np.asarray(v)).to(device=device, dtype=dtype)\n        return out, j\n\n    if \"local_rot_mats\" not in raw or \"root_positions\" not in raw:\n        raise ValueError(\"Kimodo NPZ must contain posed_joints+global_rot_mats, or local_rot_mats+root_positions.\")\n    local = torch.from_numpy(np.asarray(raw[\"local_rot_mats\"])).to(device=device, dtype=dtype)\n    root = torch.from_numpy(np.asarray(raw[\"root_positions\"])).to(device=device, dtype=dtype)\n    return complete_motion_dict(local, root, sk, float(source_fps)), j\n\n\ndef save_kimodo_npz_at_target_fps(\n    motion: Dict[str, torch.Tensor],\n    skeleton: SkeletonBase,\n    source_fps: float,\n    output_path: str,\n    target_fps: float = KIMODO_CONVERT_TARGET_FPS,\n) -> None:\n    \"\"\"Resample a motion dict to ``target_fps`` when needed, then save Kimodo NPZ.\"\"\"\n    t_before = int(motion[\"local_rot_mats\"].shape[0])\n    motion, did_resample = resample_motion_dict_to_kimodo_fps(motion, skeleton, source_fps, target_fps)\n    t_after = int(motion[\"local_rot_mats\"].shape[0])\n    if did_resample:\n        warn_kimodo_npz_framerate(source_fps, t_before, t_after)\n    save_kimodo_npz(output_path, motion)\n\n\ndef kimodo_npz_to_bytes(motion_dict: Dict[str, Any]) -> bytes:\n    \"\"\"Serialize a Kimodo motion dict to in-memory NPZ bytes.\"\"\"\n    import io\n\n    buf = io.BytesIO()\n    np.savez(buf, **motion_dict_to_numpy(motion_dict))\n    return buf.getvalue()\n\n\ndef g1_csv_to_bytes(motion_dict: Dict[str, Any], skeleton: SkeletonBase, device: Any) -> bytes:\n    \"\"\"Convert a motion dict to G1 MuJoCo CSV bytes via :class:`MujocoQposConverter`.\"\"\"\n    import 
io\n\n    from kimodo.exports.mujoco import MujocoQposConverter\n\n    converter = MujocoQposConverter(skeleton)\n    qpos = converter.dict_to_qpos(\n        {k: v for k, v in motion_dict.items() if k in (\"local_rot_mats\", \"root_positions\")},\n        device,\n        numpy=True,\n    )\n    buf = io.StringIO()\n    np.savetxt(buf, qpos, delimiter=\",\")\n    return buf.getvalue().encode(\"utf-8\")\n\n\ndef amass_npz_to_bytes(motion_dict: Dict[str, Any], skeleton: SkeletonBase, fps: float) -> bytes:\n    \"\"\"Convert a motion dict to AMASS NPZ bytes via :class:`AMASSConverter`.\"\"\"\n    import io\n\n    from kimodo.exports.smplx import AMASSConverter\n\n    converter = AMASSConverter(skeleton=skeleton, fps=fps)\n    buf = io.BytesIO()\n    converter.convert_save_npz(\n        {k: v for k, v in motion_dict.items() if k in (\"local_rot_mats\", \"root_positions\")},\n        buf,\n    )\n    return buf.getvalue()\n\n\ndef _read_amass_source_fps(path: str) -> float:\n    \"\"\"Read the source frame rate from an AMASS NPZ, defaulting to 30 Hz.\"\"\"\n    with np.load(path, allow_pickle=True) as z:\n        if \"mocap_frame_rate\" in z.files:\n            return float(z[\"mocap_frame_rate\"])\n    return 30.0\n\n\ndef load_motion_file(\n    path: str,\n    source_fps: float | None = None,\n    target_fps: float | None = None,\n    *,\n    z_up: bool = True,\n    mujoco_rest_zero: bool = False,\n) -> tuple[Dict[str, torch.Tensor], int]:\n    \"\"\"Load a motion file and return a Kimodo motion dict plus joint count.\n\n    Supports SOMA BVH (``.bvh``), G1 MuJoCo CSV (``.csv``), Kimodo NPZ, and AMASS SMPL-X NPZ\n    (``.npz``).\n\n    The motion is loaded at its native (or overridden) source rate, then\n    resampled to ``target_fps`` when they differ.\n\n    Args:\n        path: Path to ``.bvh``, ``.csv``, or ``.npz``.\n        source_fps: Source frame rate (Hz).  
If provided, trusted as-is.\n            If ``None``, auto-detected per format: BVH ``Frame Time`` header,\n            AMASS ``mocap_frame_rate``, or :data:`KIMODO_CONVERT_TARGET_FPS`\n            (30 Hz) for CSV / Kimodo NPZ.\n        target_fps: Desired output frame rate (Hz).  Defaults to\n            :data:`KIMODO_CONVERT_TARGET_FPS` (30 Hz).  The motion is\n            resampled when ``source_fps`` and ``target_fps`` differ.\n        z_up: AMASS NPZ only; passed to :func:`load_amass_npz`.\n        mujoco_rest_zero: G1 CSV only; passed to :func:`load_g1_csv`.\n\n    Returns:\n        ``(motion_dict, num_joints)`` with the same keys as :func:`complete_motion_dict`.\n    \"\"\"\n    from kimodo.exports.motion_formats import infer_npz_kind\n\n    if target_fps is None:\n        target_fps = KIMODO_CONVERT_TARGET_FPS\n\n    ext = os.path.splitext(path)[1].lower()\n    if ext == \".bvh\":\n        from kimodo.exports.bvh import bvh_to_kimodo_motion\n\n        motion_dict, bvh_fps = bvh_to_kimodo_motion(path)\n        effective_source = source_fps if source_fps is not None else bvh_fps\n        num_joints = int(motion_dict[\"local_rot_mats\"].shape[1])\n    elif ext == \".csv\":\n        effective_source = source_fps if source_fps is not None else KIMODO_CONVERT_TARGET_FPS\n        motion_dict = load_g1_csv(path, source_fps=effective_source, mujoco_rest_zero=mujoco_rest_zero)\n        num_joints = 34\n    elif ext == \".npz\":\n        kind = infer_npz_kind(path)\n        if kind == \"amass\":\n            effective_source = source_fps if source_fps is not None else _read_amass_source_fps(path)\n            motion_dict = load_amass_npz(path, source_fps=effective_source, z_up=z_up)\n            num_joints = 22\n        else:\n            effective_source = source_fps if source_fps is not None else KIMODO_CONVERT_TARGET_FPS\n            motion_dict, num_joints = load_kimodo_npz_as_torch(path, source_fps=effective_source)\n    else:\n        raise 
ValueError(f\"Unsupported motion file {path!r}; expected .bvh, .csv, or .npz\")\n\n    if abs(effective_source - target_fps) > 0.5:\n        sk = build_skeleton(num_joints)\n        motion_dict, did_resample = resample_motion_dict_to_kimodo_fps(motion_dict, sk, effective_source, target_fps)\n        if did_resample:\n            t_out = int(motion_dict[\"local_rot_mats\"].shape[0])\n            warnings.warn(\n                f\"Resampled motion from {effective_source:.4g} Hz to \" f\"{target_fps:.0f} Hz ({t_out} frames).\",\n                UserWarning,\n                stacklevel=2,\n            )\n\n    return motion_dict, num_joints\n"
  },
  {
    "path": "kimodo/exports/mujoco.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Convert kimodo motion (y-up, z-forward) to MuJoCo qpos (z-up, x-forward) for G1 skeleton.\"\"\"\n\nimport os\nimport xml.etree.ElementTree as ET\nfrom typing import Optional\n\nimport numpy as np\nimport torch\nfrom scipy.spatial.transform import Rotation\n\nfrom kimodo.assets import skeleton_asset_path\nfrom kimodo.geometry import (\n    axis_angle_to_matrix,\n    matrix_to_axis_angle,\n    matrix_to_quaternion,\n    quaternion_to_matrix,\n)\nfrom kimodo.skeleton import G1Skeleton34, SkeletonBase, global_rots_to_local_rots\nfrom kimodo.tools import ensure_batched, to_numpy, to_torch\n\n# Cache so that the same (skeleton, xml_path) returns the same converter instance.\n_converter_cache: dict[tuple[int, str], \"MujocoQposConverter\"] = {}\n\n\nclass MujocoQposConverter:\n    \"\"\"Fast batch converter from our dictionary format to mujoco qpos with precomputed transforms.\n\n    In mujoco, the coordination is z up and x forward, right handed.\n\n    Features (30 joints):\n    - root (pelvis, 7 = translation + rotation) + 29 dof joints (29)\n\n    In kimodo, the coordinate system is y up and z forward, right handed.\n    Features (34 joints):\n    - root (pelvis) + (34 - 1) joints; among these joints, 4 are end-effector joints added by kimodo.\n\n    Cached by (input_skeleton id, xml_path); repeated calls with the same args return the same instance.\n    \"\"\"\n\n    def __new__(\n        cls,\n        input_skeleton: SkeletonBase,\n        xml_path: str = str(skeleton_asset_path(\"g1skel34\", \"xml\", \"g1.xml\")),\n    ):\n        key = (id(input_skeleton), xml_path)\n        if key not in _converter_cache:\n            inst = object.__new__(cls)\n            _converter_cache[key] = inst\n        return _converter_cache[key]\n\n    def __init__(\n        self,\n        input_skeleton: SkeletonBase,\n        
xml_path: str = str(skeleton_asset_path(\"g1skel34\", \"xml\", \"g1.xml\")),\n    ):\n        \"\"\"Initialize converter with precomputed transforms.\n\n        Args:\n            xml_path: Path to the mujoco XML file containing joint definitions\n        \"\"\"\n        if getattr(self, \"_initialized\", False):\n            return\n        self.xml_path = xml_path\n        self.skeleton = input_skeleton\n        self._prepare_transforms()\n        self._subtree_joints = {}\n        self._initialized = True\n\n    def _prepare_transforms(self):\n        \"\"\"Precompute all necessary transforms for efficient batch processing.\"\"\"\n        # Define coordinate transformations between mujoco and kimodo space\n        # 1) R_zup_to_yup: rotation around x-axis by -90 degrees\n        # 2) x_forward_to_y_forward: rotation around z-axis by -90 degrees\n        # Combined transformation matrix: mujoco_to_kimodo = R_zup_to_yup * x_forward_to_y_forward\n        self.mujoco_to_kimodo_matrix = torch.tensor(\n            [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]], dtype=torch.float32\n        )\n        self.kimodo_to_mujoco_matrix = self.mujoco_to_kimodo_matrix.T  # Inverse transformation: kimodo_to_mujoco\n\n        # Parse XML once and extract joint information\n        tree = ET.parse(self.xml_path)\n        root = tree.getroot()\n\n        xml_classes = [x for x in tree.findall(\".//default\") if \"class\" in x.attrib]\n        joint_axes = dict()\n        class_ranges: dict[str, tuple[float, float]] = {}\n        for xml_class in xml_classes:\n            j = xml_class.findall(\"joint\")\n            if j:\n                joint_axes[xml_class.get(\"class\")] = j[0].get(\"axis\")\n                range_str = j[0].get(\"range\")\n                if range_str:\n                    range_vals = [float(x) for x in range_str.split()]\n                    if len(range_vals) == 2:\n                        class_ranges[xml_class.get(\"class\")] = (\n                   
         range_vals[0],\n                            range_vals[1],\n                        )\n\n        mujoco_hinge_joints = root.find(\"worldbody\").findall(\".//joint\")  # skip the base joint\n        self._mujoco_joint_axis_values_kimodo_space = torch.zeros(\n            (len(mujoco_hinge_joints), 3), dtype=torch.float32\n        )  # mujoco order but kimodo space\n        self._mujoco_joint_axis_values_mujoco_space = torch.zeros(\n            (len(mujoco_hinge_joints), 3), dtype=torch.float32\n        )  # mujoco order but mujoco space\n\n        # for the below indices, mujoco_indices_to_kimodo_indices does not include mujoco root (30 - 1 = 29 elements),\n        # while kimodo_indices_to_mujoco_indices inclues the kimodo root (32 elements).\n        self._mujoco_indices_to_kimodo_indices = torch.zeros((len(mujoco_hinge_joints),), dtype=torch.int32)\n        self._kimodo_indices_to_mujoco_indices = (\n            torch.ones((self.skeleton.nbjoints,), dtype=torch.int32) * -1\n        )  # -1 means not in the csv skeleton\n\n        self._nb_joints_mujoco = len(mujoco_hinge_joints) + 1\n        self._nb_joints_kimodo = self.skeleton.nbjoints\n        self._mujoco_joint_including_root_parent_list = torch.full(\n            (len(mujoco_hinge_joints) + 1,), -1, dtype=torch.int32\n        )\n        self._mujoco_joint_including_root_list = [\"pelvis_skel\"]\n\n        for joint_id_in_csv, joint in enumerate(mujoco_hinge_joints):\n            joint_name_in_skeleton = joint.get(\"name\").replace(\"_joint\", \"_skel\")\n            joint_parent_name_in_skeleton = self.skeleton.bone_parents[joint_name_in_skeleton]\n\n            self._mujoco_joint_including_root_list.append(joint_name_in_skeleton)\n            self._mujoco_joint_including_root_parent_list[joint_id_in_csv + 1] = (\n                self._mujoco_joint_including_root_list.index(joint_parent_name_in_skeleton)\n            )\n\n            joint_idx_in_kimodo_skeleton = 
self.skeleton.bone_order_names.index(joint_name_in_skeleton)\n            axis_values = [float(x) for x in (joint.get(\"axis\") or joint_axes[joint.get(\"class\")]).split(\" \")]\n\n            # the mapped axis in kimodo skeleton space is calculated as bones_axis = mujoco_to_kimodo.apply(axis_values)\n            # [1, 0, 0] -> [0, 0, 1]; [0, 1, 0] -> [1, 0, 0]; [0, 0, 1] -> [0, 1, 0]\n            mujoco_joint_axis_mapping_kimodo_space = [\n                torch.tensor([0, 0, 1]),\n                torch.tensor([1, 0, 0]),\n                torch.tensor([0, 1, 0]),\n            ][np.argmax(axis_values)]\n\n            self._mujoco_joint_axis_values_kimodo_space[joint_id_in_csv] = mujoco_joint_axis_mapping_kimodo_space\n            self._mujoco_joint_axis_values_mujoco_space[joint_id_in_csv] = torch.tensor(axis_values)\n\n            self._mujoco_indices_to_kimodo_indices[joint_id_in_csv] = joint_idx_in_kimodo_skeleton\n            self._kimodo_indices_to_mujoco_indices[joint_idx_in_kimodo_skeleton] = (\n                joint_id_in_csv + 1\n            )  # +1 for the root\n        self._kimodo_indices_to_mujoco_indices[0] = 0  # the root joint mapping\n\n        # Joint limits (min, max) in radians for each mujoco hinge, for clamping\n        self._joint_limits_min = torch.full((len(mujoco_hinge_joints),), float(\"-inf\"), dtype=torch.float32)\n        self._joint_limits_max = torch.full((len(mujoco_hinge_joints),), float(\"inf\"), dtype=torch.float32)\n        for joint_id_in_csv, joint in enumerate(mujoco_hinge_joints):\n            range_vals = None\n            if joint.get(\"range\"):\n                range_vals = [float(x) for x in joint.get(\"range\").split()]\n            elif joint.get(\"class\") and joint.get(\"class\") in class_ranges:\n                lo, hi = class_ranges[joint.get(\"class\")]\n                range_vals = [lo, hi]\n            if range_vals is not None and len(range_vals) == 2:\n                self._joint_limits_min[joint_id_in_csv] = 
range_vals[0]\n                self._joint_limits_max[joint_id_in_csv] = range_vals[1]\n\n        # load the offset matrices from the xml\n        R_zup_to_yup = Rotation.from_euler(\"x\", -90, degrees=True)\n        x_forward_to_y_forward = Rotation.from_euler(\"z\", -90, degrees=True)\n        mujoco_to_kimodo = R_zup_to_yup * x_forward_to_y_forward\n\n        self._rot_offsets_q2t = torch.zeros(len(self._kimodo_indices_to_mujoco_indices), 3, 3, dtype=torch.float32)\n        self._rot_offsets_q2t[...] = torch.eye(3)[None]\n\n        self._rot_offsets_f2q = torch.zeros(len(self._kimodo_indices_to_mujoco_indices), 3, 3, dtype=torch.float32)\n        self._rot_offsets_f2q[...] = torch.eye(3)[None]\n        parent_map = {child: parent for parent in root.iter() for child in parent}\n        for i, joint in enumerate(mujoco_hinge_joints):\n            body = parent_map[joint]\n            if \"quat\" in body.attrib:\n                rot = Rotation.from_quat(\n                    [float(x) for x in body.get(\"quat\").strip().split(\" \")],\n                    scalar_first=True,\n                )\n                idx = self._mujoco_indices_to_kimodo_indices[i]\n                self._rot_offsets_q2t[idx] = torch.from_numpy(rot.as_matrix())\n                rot = mujoco_to_kimodo * rot * mujoco_to_kimodo.inv()\n                self._rot_offsets_f2q[idx] = torch.from_numpy(rot.as_matrix().T)\n\n        # Hinge axis in f2q space so extraction uses the same frame as joint_rot_f2q.\n        # Then extract(offset) gives the angle s.t. 
axis_angle(angle * axis_f2q) = offset, and\n        # reconstruction R_local = offset.T @ axis_angle(angle * axis_f2q) = I when input is identity.\n        axis_kimodo = self._mujoco_joint_axis_values_kimodo_space\n        self._mujoco_joint_axis_values_f2q_space = torch.zeros_like(axis_kimodo)\n        for i in range(len(mujoco_hinge_joints)):\n            j = self._mujoco_indices_to_kimodo_indices[i].item()\n            axis_f2q = torch.mv(self._rot_offsets_f2q[j], axis_kimodo[i])\n            n = axis_f2q.norm()\n            if n > 1e-8:\n                axis_f2q = axis_f2q / n\n            self._mujoco_joint_axis_values_f2q_space[i] = axis_f2q\n\n        # Rest-pose DOFs: angle we extract when R_local = I (t-pose). MuJoCo limits are\n        # relative to joint zero (rest pose), so we must clamp in MuJoCo space: convert\n        # joint_dofs to mujoco_angle = joint_dofs - rest_dofs, clamp, then back.\n        rest_rot_f2q = self._rot_offsets_f2q[self._mujoco_indices_to_kimodo_indices]\n        rest_rot_f2q = rest_rot_f2q.unsqueeze(0).unsqueeze(0)\n        self._rest_dofs = self._local_rots_f2q_to_joint_dofs(rest_rot_f2q).squeeze(0).squeeze(0)\n        # Axis-angle rest DOFs: angle s.t. axis_angle(angle * axis_f2q) = offset. 
Used in\n        # project_to_real_robot_rotations so extract+reconstruct round-trip and t-pose is preserved.\n        rest_rot_f2q_flat = self._rot_offsets_f2q[self._mujoco_indices_to_kimodo_indices]\n        full_aa = matrix_to_axis_angle(rest_rot_f2q_flat)\n        self._rest_dofs_axis_angle = (full_aa * self._mujoco_joint_axis_values_f2q_space).sum(dim=-1)\n\n    def dict_to_qpos(\n        self,\n        output: dict,\n        device: Optional[str] = None,\n        root_quat_w_first: bool = True,\n        numpy: bool = True,\n        mujoco_rest_zero: bool = False,\n    ):\n        \"\"\"Convert kimodo output dict to mujoco qpos format.\n\n        Args:\n            output: dict with keys \"local_rot_mats\" and \"root_positions\".\n            device: device to use for the output.\n            root_quat_w_first: If True, quaternion in qpos is (w,x,y,z).\n            numpy: If True, convert the output to numpy array.\n            mujoco_rest_zero: If True, joint angles are written so that kimodo rest (t-pose)\n                maps to q=0 in MuJoCo. 
If False, write raw joint_dofs.\n\n        Returns:\n            qpos: (B, T, 7+J) mujoco qpos format.\n        \"\"\"\n        local_rot_mats = to_torch(output[\"local_rot_mats\"], device)\n        root_positions = to_torch(output[\"root_positions\"], device)\n\n        qpos = self.to_qpos(\n            local_rot_mats,\n            root_positions,\n            root_quat_w_first=root_quat_w_first,\n            mujoco_rest_zero=mujoco_rest_zero,\n        )\n        if numpy:\n            qpos = to_numpy(qpos)\n        return qpos\n\n    def qpos_to_motion_dict(\n        self,\n        qpos: torch.Tensor | np.ndarray,\n        source_fps: float,\n        *,\n        root_quat_w_first: bool = True,\n        mujoco_rest_zero: bool = False,\n    ):\n        \"\"\"Inverse of :meth:`to_qpos` / :meth:`dict_to_qpos` for MuJoCo CSV ``(T, 36)`` rows.\n\n        Args:\n            qpos: Shape ``(T, 36)`` or ``(1, T, 36)`` (root xyz, root quat wxyz, 29 joint angles).\n            source_fps: Source frame rate (Hz) of the qpos data.\n            root_quat_w_first: Must match how the CSV was written (default ``True``).\n            mujoco_rest_zero: Must match :meth:`dict_to_qpos` / :meth:`to_qpos`.\n\n        Returns:\n            Kimodo motion dict (see :func:`kimodo.exports.motion_io.complete_motion_dict`).\n        \"\"\"\n        from kimodo.exports.motion_io import complete_motion_dict\n\n        qpos = to_torch(qpos, None)\n        if qpos.dim() == 2:\n            qpos = qpos.unsqueeze(0)\n        device = qpos.device\n        dtype = qpos.dtype\n        batch_size, num_frames, ncols = qpos.shape\n        if ncols != 36:\n            raise ValueError(f\"Expected qpos last dim 36; got {ncols}\")\n\n        kimodo_to_mujoco_matrix = self.kimodo_to_mujoco_matrix.to(device=device, dtype=dtype)\n        mujoco_to_kimodo_matrix = kimodo_to_mujoco_matrix.T\n\n        root_mujoco = qpos[..., :3]\n        root_positions = torch.matmul(mujoco_to_kimodo_matrix[None, None, ...], 
root_mujoco[..., None]).squeeze(-1)\n\n        quat = qpos[..., 3:7]\n        if root_quat_w_first:\n            root_rot_mujoco = quaternion_to_matrix(quat)\n        else:\n            quat_wxyz = quat[..., [3, 0, 1, 2]]\n            root_rot_mujoco = quaternion_to_matrix(quat_wxyz)\n\n        O0 = self._rot_offsets_f2q[0].to(device=device, dtype=dtype)\n        # root_rot_mujoco is (..., 3, 3) after optional batch unsqueeze (e.g. (1, T, 3, 3)).\n        # Use ``...il`` so ``k`` sums with ``kl``; ``...ik`` incorrectly keeps ``k`` in the output.\n        R_f2q_root = torch.einsum(\n            \"ij,...jk,kl->...il\",\n            mujoco_to_kimodo_matrix,\n            root_rot_mujoco,\n            kimodo_to_mujoco_matrix,\n        )\n        R_kimodo_root = torch.einsum(\"ij,...jk->...ik\", O0.T, R_f2q_root)\n\n        joint_dofs = qpos[..., 7:]\n        if mujoco_rest_zero:\n            rest_dofs = self._rest_dofs.to(device=device, dtype=dtype)\n            angles = joint_dofs + rest_dofs[None, None, :]\n            use_relative = True\n        else:\n            angles = joint_dofs\n            use_relative = False\n\n        nb_joints = self.skeleton.nbjoints\n        template = torch.eye(3, device=device, dtype=dtype).expand(batch_size, num_frames, nb_joints, 3, 3).contiguous()\n        template[:, :, 0] = R_kimodo_root\n\n        local_rot_mats = self._joint_dofs_to_local_rot_mats(\n            angles,\n            template,\n            device,\n            dtype,\n            use_relative=use_relative,\n        )\n\n        if batch_size != 1:\n            raise ValueError(f\"Only a single clip is supported; got batch_size={batch_size}\")\n\n        return complete_motion_dict(local_rot_mats[0], root_positions[0], self.skeleton, source_fps)\n\n    def save_csv(self, qpos: torch.Tensor | np.ndarray, csv_path):\n        # comment this\n        qpos = to_numpy(qpos)\n        shape = qpos.shape\n        if len(shape) == 2:\n            # only one motion: save 
it\n            np.savetxt(csv_path, qpos, delimiter=\",\")\n        if len(shape) == 3:\n            # batch of motions\n            if shape[0] == 1:\n                # if only one motion, just save it\n                np.savetxt(csv_path, qpos[0], delimiter=\",\")\n            else:\n                csv_path_base, ext = os.path.splitext(csv_path)\n                for i in range(shape[0]):\n                    self.save_csv(qpos[i], csv_path_base + \"_\" + str(i).zfill(2) + ext)\n\n    def _local_rots_to_joint_dofs(\n        self,\n        local_rot_mats: torch.Tensor,\n        axis_vals: torch.Tensor,\n    ) -> torch.Tensor:\n        \"\"\"Extract per-joint single-DoF angles (radians) via Euler projection (for to_qpos/f2q).\"\"\"\n        x_joint_dof = torch.atan2(local_rot_mats[..., 2, 1], local_rot_mats[..., 2, 2])\n        y_joint_dof = torch.atan2(local_rot_mats[..., 0, 2], local_rot_mats[..., 0, 0])\n        z_joint_dof = torch.atan2(local_rot_mats[..., 1, 0], local_rot_mats[..., 1, 1])\n        xyz_joint_dofs = torch.stack([x_joint_dof, y_joint_dof, z_joint_dof], dim=-1)\n        axis_vals = axis_vals.to(device=local_rot_mats.device, dtype=local_rot_mats.dtype)\n        joint_dofs = (xyz_joint_dofs * axis_vals[None, None, :, :]).sum(dim=-1)\n        return joint_dofs\n\n    def _local_rots_to_joint_dofs_axis_angle(\n        self,\n        local_rot_mats: torch.Tensor,\n        axis_vals: torch.Tensor,\n    ) -> torch.Tensor:\n        \"\"\"Extract per-joint single-DoF angles (radians) via axis-angle; round-trips with\n        axis_angle_to_matrix.\n\n        Args:\n            local_rot_mats: (..., num_hinges, 3, 3) in same frame as axis_vals.\n            axis_vals: (num_hinges, 3) unit axis per hinge.\n        Returns:\n            joint_dofs: (..., num_hinges) signed angle = dot(axis_angle(R), axis).\n        \"\"\"\n        axis_vals = axis_vals.to(device=local_rot_mats.device, dtype=local_rot_mats.dtype)\n        full_aa = 
matrix_to_axis_angle(local_rot_mats)\n        joint_dofs = (full_aa * axis_vals).sum(dim=-1)\n        return joint_dofs\n\n    def _local_rots_f2q_to_joint_dofs(self, local_rot_mats_f2q: torch.Tensor) -> torch.Tensor:\n        \"\"\"Extract per-joint single-DoF angles from local rotations in f2q space (for to_qpos).\"\"\"\n        axis_vals = self._mujoco_joint_axis_values_f2q_space\n        return self._local_rots_to_joint_dofs(local_rot_mats_f2q, axis_vals)\n\n    def _clamp_to_limits(self, joint_dofs: torch.Tensor) -> torch.Tensor:\n        \"\"\"Clamp joint angles to XML limits (radians).\n\n        Angles are in kimodo convention (0 = rest).\n        \"\"\"\n        device = joint_dofs.device\n        lo = self._joint_limits_min.to(device=device, dtype=joint_dofs.dtype)\n        hi = self._joint_limits_max.to(device=device, dtype=joint_dofs.dtype)\n        return torch.clamp(joint_dofs, lo[None, None, :], hi[None, None, :])\n\n    def _clamp_joint_dofs(self, joint_dofs: torch.Tensor, rest_dofs: torch.Tensor) -> torch.Tensor:\n        \"\"\"Clamp joint angles to MuJoCo limits (radians), with rest_dofs conversion.\"\"\"\n        device = joint_dofs.device\n        rest_dofs = rest_dofs.to(device=device, dtype=joint_dofs.dtype)\n        mujoco_dofs = joint_dofs - rest_dofs[None, None, :]\n        lo = self._joint_limits_min.to(device=device, dtype=joint_dofs.dtype)\n        hi = self._joint_limits_max.to(device=device, dtype=joint_dofs.dtype)\n        mujoco_dofs = torch.clamp(mujoco_dofs, lo[None, None, :], hi[None, None, :])\n        return mujoco_dofs + rest_dofs[None, None, :]\n\n    def _joint_dofs_to_local_rot_mats(\n        self,\n        joint_dofs: torch.Tensor,\n        original_local_rot_mats: torch.Tensor,\n        device: torch.device,\n        dtype: torch.dtype,\n        use_relative: bool = False,\n    ) -> torch.Tensor:\n        \"\"\"Reconstruct full local rotation matrices from 1-DoF angles.\"\"\"\n        out = 
original_local_rot_mats.clone()\n        axis_kimodo = self._mujoco_joint_axis_values_kimodo_space.to(device=device, dtype=dtype)\n        for i in range(joint_dofs.shape[-1]):\n            j = self._mujoco_indices_to_kimodo_indices[i].item()\n            angle = joint_dofs[..., i]\n            axis = axis_kimodo[i]\n            if use_relative:\n                axis_angle = angle[..., None] * axis[None, None, :]\n                R_local = axis_angle_to_matrix(axis_angle)\n            else:\n                rot_offsets_f2q = self._rot_offsets_f2q.to(device=device, dtype=dtype)\n                axis_in_f2q = torch.mv(rot_offsets_f2q[j], axis)\n                axis_angle = angle[..., None] * axis_in_f2q[None, None, :]\n                R_f2q = axis_angle_to_matrix(axis_angle)\n                R_local = torch.einsum(\"ij,btjk->btik\", rot_offsets_f2q[j].T, R_f2q)\n            out[:, :, j, :, :] = R_local\n        return out\n\n    @ensure_batched(local_rot_mats=5, root_positions=3, lengths=1)\n    def project_to_real_robot_rotations(\n        self,\n        local_rot_mats: torch.Tensor,\n        root_positions: torch.Tensor,\n        clamp_to_limits: bool = True,\n        mujoco_rest_zero: bool = False,\n    ) -> dict:\n        \"\"\"Project full 3D local rotations to G1 real robot DoF and back to 3D for viz.\n\n        Joint angles are extracted along each hinge axis, optionally clamped to XML limits, then\n        reconstructed to 3D rotations. When mujoco_rest_zero=False (default), raw angles are used\n        (baked-with-quat). 
When True, angles are relative to rest (0 = T-pose in MuJoCo).\n        \"\"\"\n        device = local_rot_mats.device\n        dtype = local_rot_mats.dtype\n\n        # Transform to f2q frame and extract 1-DoF angles (axis-angle projection).\n        local_rot_f2q = torch.matmul(self._rot_offsets_f2q.to(device=device, dtype=dtype), local_rot_mats)\n        hinge_rots = local_rot_f2q[:, :, self._mujoco_indices_to_kimodo_indices, :, :]\n        axis_f2q = self._mujoco_joint_axis_values_f2q_space.to(device=device, dtype=dtype)\n        joint_dofs = self._local_rots_to_joint_dofs_axis_angle(hinge_rots, axis_f2q)\n\n        # Optionally express angles relative to rest (MuJoCo q=0 at T-pose).\n        if mujoco_rest_zero:\n            rest_dofs = self._rest_dofs_axis_angle.to(device=device, dtype=dtype)\n            angles = joint_dofs - rest_dofs[None, None, :]\n            use_relative = True\n        else:\n            angles = joint_dofs\n            use_relative = False\n\n        if clamp_to_limits:\n            if mujoco_rest_zero:\n                angles = self._clamp_to_limits(angles)\n            else:\n                rest_dofs_aa = self._rest_dofs_axis_angle.to(device=device, dtype=dtype)\n                angles = self._clamp_joint_dofs(angles, rest_dofs_aa)\n\n        # Reconstruct 3D local rotations from 1-DoF angles and run FK.\n        local_rot_mats_proj = self._joint_dofs_to_local_rot_mats(\n            angles, local_rot_mats, device, dtype, use_relative=use_relative\n        )\n        global_rot_mats, posed_joints, _ = self.skeleton.fk(local_rot_mats_proj, root_positions)\n        return {\n            \"local_rot_mats\": local_rot_mats_proj,\n            \"global_rot_mats\": global_rot_mats,\n            \"posed_joints\": posed_joints,\n            \"root_positions\": root_positions,\n        }\n\n    @ensure_batched(local_rot_mats=5, root_positions=3, lengths=1)\n    def to_qpos(\n        self,\n        local_rot_mats: torch.Tensor,\n        
root_positions: torch.Tensor,\n        root_quat_w_first: bool = True,\n        mujoco_rest_zero: bool = False,\n    ) -> torch.Tensor:\n        \"\"\"Fast batch conversion from kimodo features to mujoco qpos format.\n\n        Args:\n            local_rot_mats: (B, T, J, 3, 3) local rotation matrices (kimodo convention).\n            root_positions: (B, T, 3) root positions.\n            root_quat_w_first: If True, quaternion in qpos is (w,x,y,z).\n            mujoco_rest_zero: If True, joint angles are written so that kimodo rest (t-pose)\n                maps to q=0 in MuJoCo. If False, write raw joint_dofs.\n\n        Returns:\n            torch.Tensor of shape [batch, numFrames, 36] containing mujoco qpos data:\n            - root_trans (3) + root_quat (4) + joint_dofs (29) = 36 columns\n        \"\"\"\n\n        batch_size, num_frames, nb_joints = local_rot_mats.shape[:3]\n        device, dtype = local_rot_mats.device, local_rot_mats.dtype\n\n        local_rot_mats = torch.matmul(self._rot_offsets_f2q.to(device), local_rot_mats)\n\n        batch_size, num_frames = root_positions.shape[0], root_positions.shape[1]\n\n        # Move precomputed matrices to the same device/dtype\n        kimodo_to_mujoco_matrix = self.kimodo_to_mujoco_matrix.to(device=device, dtype=dtype)\n\n        # Initialize output tensor: [batch, numFrames, 36]\n        qpos = torch.zeros((batch_size, num_frames, 36), dtype=dtype, device=device)\n\n        # Convert root translation: apply coordinate transformation\n        root_positions_mujoco = torch.matmul(kimodo_to_mujoco_matrix[None, None, ...], root_positions[..., None])\n        qpos[:, :, :3] = root_positions_mujoco.view(batch_size, num_frames, 3)\n\n        # Convert root rotation: apply coordinate transformation to rotation matrix\n        root_rot = local_rot_mats[:, :, 0, :]  # [batch, numFrames, 3, 3]\n\n        # Apply coordinate transformation: R_mujoco = kimodo_to_mujoco * R_kimodo * kimodo_to_mujoco^T\n        
mujoco_to_kimodo_matrix = kimodo_to_mujoco_matrix.T\n        root_rot_mujoco = torch.matmul(\n            torch.matmul(kimodo_to_mujoco_matrix[None, None, ...], root_rot),\n            mujoco_to_kimodo_matrix[None, None, ...],\n        )\n        root_rot_quat = matrix_to_quaternion(root_rot_mujoco)  # [w, x, y, z]\n        if root_quat_w_first:\n            qpos[:, :, 3:7] = root_rot_quat[:, :, [0, 1, 2, 3]]  # [w, x, y, z]\n        else:\n            qpos[:, :, 3:7] = root_rot_quat[:, :, [1, 2, 3, 0]]  # [w, x, y, z] -> [x, y, z, w]\n\n        # Joint DOFs: raw angles or relative to rest (rest = q=0 in MuJoCo).\n        joint_rot_f2q = local_rot_mats[:, :, self._mujoco_indices_to_kimodo_indices, :, :]\n        joint_dofs = self._local_rots_f2q_to_joint_dofs(joint_rot_f2q)\n        if mujoco_rest_zero:\n            rest_dofs = self._rest_dofs.to(device=device, dtype=dtype)\n            qpos[:, :, 7:] = joint_dofs - rest_dofs[None, None, :]\n        else:\n            qpos[:, :, 7:] = joint_dofs\n        return qpos\n\n\ndef apply_g1_real_robot_projection(\n    skeleton: G1Skeleton34,\n    joints_pos: torch.Tensor,\n    joints_rot: torch.Tensor,\n    clamp_to_limits: bool = True,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"Project G1 motion to real robot DoF (1-DoF per joint) with optional axis limits.\n\n    Extracts a single angle per hinge along its axis (1-DoF), optionally clamps to\n    joint limits from the MuJoCo XML (when clamp_to_limits=True), then reconstructs\n    3D rotations and runs FK. 
T-pose (identity local rotations) is preserved.\n\n    Args:\n        skeleton: G1 skeleton instance.\n        joints_pos: (T, J, 3) or (B, T, J, 3) joint positions in global space.\n        joints_rot: (T, J, 3, 3) or (B, T, J, 3, 3) global rotation matrices.\n        clamp_to_limits: If True, clamp joint angles to XML axis limits (default True).\n\n    Returns:\n        (posed_joints, global_rot_mats) as tensors, same shape as inputs (batch preserved).\n    \"\"\"\n\n    local_rot_mats = global_rots_to_local_rots(joints_rot, skeleton)\n    root_positions = joints_pos[..., skeleton.root_idx, :]\n\n    # Converter expects batch dim (B, T, ...); add and remove if single sequence.\n    single_sequence = local_rot_mats.dim() == 4\n    if single_sequence:\n        local_rot_mats = local_rot_mats.unsqueeze(0)\n        root_positions = root_positions.unsqueeze(0)\n\n    converter = MujocoQposConverter(skeleton)\n    projected = converter.project_to_real_robot_rotations(\n        local_rot_mats, root_positions, clamp_to_limits=clamp_to_limits\n    )\n\n    out_pos = projected[\"posed_joints\"]\n    out_rot = projected[\"global_rot_mats\"]\n    if single_sequence:\n        out_pos = out_pos.squeeze(0)\n        out_rot = out_rot.squeeze(0)\n    return out_pos, out_rot\n"
  },
  {
    "path": "kimodo/exports/smplx.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Convert kimodo motion to AMASS/SMPL-X compatible parameters (axis-angle, Y-up or Z-up).\"\"\"\n\nimport os\nfrom typing import Optional\n\nimport einops\nimport numpy as np\nimport torch\n\nfrom kimodo.assets import skeleton_asset_path\nfrom kimodo.geometry import axis_angle_to_matrix, matrix_to_axis_angle\nfrom kimodo.tools import ensure_batched, to_numpy, to_torch\n\n\ndef kimodo_y_up_to_amass_coord_rotation_matrix() -> np.ndarray:\n    \"\"\"3x3 rotation mapping Kimodo Y-up (+Z forward) to AMASS Z-up (+Y forward).\n\n    Used by :func:`get_amass_parameters` and :func:`amass_arrays_to_kimodo_motion` (inverse).\n    \"\"\"\n    y_up_to_z_up = np.array(\n        [\n            [1.0, 0.0, 0.0],\n            [0.0, 0.0, -1.0],\n            [0.0, 1.0, 0.0],\n        ],\n        dtype=np.float32,\n    )\n    rot_z_180 = np.array(\n        [\n            [-1.0, 0.0, 0.0],\n            [0.0, -1.0, 0.0],\n            [0.0, 0.0, 1.0],\n        ],\n        dtype=np.float32,\n    )\n    return np.matmul(rot_z_180, y_up_to_z_up).astype(np.float32)\n\n\n@ensure_batched(local_rot_mats=5, root_positions=3, lengths=1)\ndef get_amass_parameters(\n    local_rot_mats,\n    root_positions,\n    skeleton,\n    z_up=True,\n):\n    \"\"\"Convert local rot mats and root positions to AMASS-style trans and pose_body; optional z_up\n    coordinate transform.\n\n    Our method generates motions with Y-up and +Z forward; if z_up=True, transform to Z-up and +Y\n    forward as in AMASS.\n    \"\"\"\n    # Our method generate motions with Y-up and +Z forward\n    # if z_up = True, we transform this to: Z-up with +Y forward, as in AMASS\n    # Remove the root offset; SMPL-X FK adds pelvis offset back.\n    pelvis_offset = skeleton.neutral_joints[skeleton.root_idx].cpu().numpy()\n    trans = root_positions - pelvis_offset\n\n    
root_rot_mats = to_numpy(local_rot_mats[:, :, 0])\n    local_rot_axis_angle = to_numpy(matrix_to_axis_angle(to_torch(local_rot_mats)))\n    pose_body = einops.rearrange(local_rot_axis_angle[:, :, 1:], \"b t j d -> b t (j d)\")\n\n    # Optionally convert from Y-up to Z-up coordinates.\n    if z_up:\n        y_up_to_z_up = kimodo_y_up_to_amass_coord_rotation_matrix()\n        root_rot_mats = np.matmul(y_up_to_z_up, root_rot_mats)\n        trans = np.matmul(trans + pelvis_offset, y_up_to_z_up.T) - pelvis_offset\n\n    root_orient = to_numpy(matrix_to_axis_angle(to_torch(root_rot_mats)))\n    return trans, root_orient, pose_body\n\n\ndef amass_arrays_to_kimodo_motion(\n    trans: np.ndarray,\n    root_orient: np.ndarray,\n    pose_body: np.ndarray,\n    skeleton,\n    source_fps: float,\n    *,\n    z_up: bool = True,\n):\n    \"\"\"Inverse of :func:`get_amass_parameters` for a single sequence (AMASS → Kimodo motion dict).\n\n    Args:\n        trans: ``(T, 3)`` AMASS root translation (same as ``trans`` in AMASS NPZ).\n        root_orient: ``(T, 3)`` axis-angle root orientation in AMASS coordinates (z-up when ``z_up``).\n        pose_body: ``(T, 63)`` body pose axis-angle (21 joints × 3).\n        skeleton: :class:`~kimodo.skeleton.definitions.SMPLXSkeleton22` instance.\n        source_fps: Source frame rate (Hz) of the AMASS recording.\n        z_up: If ``True``, invert the same Y-up↔Z-up transform as ``get_amass_parameters(..., z_up=True)``.\n\n    Returns:\n        Motion dict compatible with :func:`kimodo.exports.motion_io.save_kimodo_npz`.\n    \"\"\"\n    from kimodo.exports.motion_io import complete_motion_dict\n\n    trans = np.asarray(trans, dtype=np.float32)\n    root_orient = np.asarray(root_orient, dtype=np.float32)\n    pose_body = np.asarray(pose_body, dtype=np.float32)\n    if trans.ndim != 2 or trans.shape[-1] != 3:\n        raise ValueError(f\"trans must be (T, 3); got {trans.shape}\")\n    if root_orient.shape != trans.shape:\n        raise 
ValueError(f\"root_orient shape {root_orient.shape} must match trans {trans.shape}\")\n    t = trans.shape[0]\n    if pose_body.shape != (t, 63):\n        raise ValueError(f\"pose_body must be (T, 63); got {pose_body.shape}\")\n\n    pelvis_offset = skeleton.neutral_joints[skeleton.root_idx].detach().cpu().numpy().astype(np.float32)\n    device = skeleton.neutral_joints.device\n    dtype = torch.float32\n\n    Y_np = kimodo_y_up_to_amass_coord_rotation_matrix()\n    if z_up:\n        y_up_to_z_up = torch.from_numpy(Y_np).to(device=device, dtype=dtype)\n        # trans_amass = root_kimodo @ Y.T - pelvis_offset  =>  root_kimodo = (trans_amass + pelvis_offset) @ Y\n        root_positions_np = (trans + pelvis_offset) @ Y_np\n    else:\n        root_positions_np = trans + pelvis_offset\n\n    root_positions = torch.from_numpy(root_positions_np).to(device=device, dtype=dtype)\n\n    R_amass_root = axis_angle_to_matrix(torch.from_numpy(root_orient).to(device=device, dtype=dtype))\n    if z_up:\n        R_kimodo_root = torch.einsum(\"ij,tjk->tik\", y_up_to_z_up.T, R_amass_root)\n    else:\n        R_kimodo_root = R_amass_root\n\n    nb = skeleton.nbjoints\n    if nb != 22:\n        raise ValueError(f\"Expected SMPL-X body skeleton with 22 joints; got {nb}\")\n\n    local_rot_mats = torch.zeros((t, nb, 3, 3), device=device, dtype=dtype)\n    local_rot_mats[:, 0] = R_kimodo_root\n\n    pose_aa = torch.from_numpy(pose_body.reshape(t, 21, 3)).to(device=device, dtype=dtype)\n    local_rot_mats[:, 1:] = axis_angle_to_matrix(pose_aa.reshape(-1, 3)).reshape(t, 21, 3, 3)\n\n    return complete_motion_dict(local_rot_mats, root_positions, skeleton, source_fps)\n\n\ndef amass_npz_to_kimodo_motion(npz_path: str, skeleton, source_fps: Optional[float] = None, *, z_up: bool = True):\n    \"\"\"Load an AMASS-style ``.npz`` and return a Kimodo motion dict.\n\n    Args:\n        npz_path: Path to AMASS NPZ (``trans``, ``root_orient``, ``pose_body``, ...).\n        skeleton: SMPL-X skeleton 
instance.\n        source_fps: Source frame rate (Hz); if ``None``, uses ``mocap_frame_rate``\n            from the file when present, else ``30.0``.\n        z_up: Same meaning as :func:`amass_arrays_to_kimodo_motion`.\n    \"\"\"\n    with np.load(npz_path, allow_pickle=True) as data:\n        trans = np.asarray(data[\"trans\"], dtype=np.float32)\n        root_orient = np.asarray(data[\"root_orient\"], dtype=np.float32)\n        pose_body = np.asarray(data[\"pose_body\"], dtype=np.float32)\n        if source_fps is None:\n            source_fps = float(data[\"mocap_frame_rate\"]) if \"mocap_frame_rate\" in data.files else 30.0\n\n    return amass_arrays_to_kimodo_motion(trans, root_orient, pose_body, skeleton, source_fps, z_up=z_up)\n\n\nclass AMASSConverter:\n    def __init__(\n        self,\n        fps,\n        skeleton,\n        beta_path=str(skeleton_asset_path(\"smplx22\", \"beta.npy\")),\n        mean_hands_path=str(skeleton_asset_path(\"smplx22\", \"mean_hands.npy\")),\n    ):\n        self.fps = fps\n        self.skeleton = skeleton\n        # Load betas\n        if os.path.exists(beta_path):\n            # only use first 16 betas to match AMASS\n            betas = np.load(beta_path)[:16]\n        else:\n            betas = np.zeros(16)\n\n        # Load mean hands\n        if os.path.exists(mean_hands_path):\n            mean_hands = np.load(mean_hands_path)\n        else:\n            mean_hands = np.zeros(90)\n\n        self.default_frame_params = {\n            \"pose_jaw\": np.zeros(3),\n            \"pose_eye\": np.zeros(6),\n            \"pose_hand\": mean_hands,\n        }\n        self.output_dict_base = {\n            \"gender\": \"neutral\",\n            \"surface_model_type\": \"smplx\",\n            \"betas\": betas,\n            \"num_betas\": len(betas),\n            \"mocap_frame_rate\": float(fps),\n        }\n\n    def convert_save_npz(self, output: dict, npz_path, z_up=True):\n        trans, root_orient, pose_body = 
get_amass_parameters(\n            output[\"local_rot_mats\"],\n            output[\"root_positions\"],\n            self.skeleton,\n            z_up=z_up,\n        )\n        nb_frames = trans.shape[-2]\n\n        amass_output_base = self.output_dict_base.copy()\n        for key, val in self.default_frame_params.items():\n            amass_output_base[key] = einops.repeat(val, \"d -> t d\", t=nb_frames)\n\n        amass_output_base[\"mocap_time_length\"] = nb_frames / self.fps\n        self.save_npz(trans, root_orient, pose_body, amass_output_base, npz_path)\n\n    def save_npz(self, trans, root_orient, pose_body, base_output, npz_path):\n        shape = trans.shape\n        if len(shape) == 3 and shape[0] == 1:\n            # if only one motion, squeeze the data\n            trans = trans[0]\n            root_orient = root_orient[0]\n            pose_body = pose_body[0]\n            shape = trans.shape\n        if len(shape) == 2:\n            amass_output = {\n                \"trans\": trans,\n                \"root_orient\": root_orient,\n                \"pose_body\": pose_body,\n            } | base_output\n            np.savez(npz_path, **amass_output)\n\n        elif len(shape) == 3:\n            # real batch of motions\n            npz_path_base, ext = os.path.splitext(npz_path)\n            for i in range(shape[0]):\n                npz_path_i = npz_path_base + \"_\" + str(i).zfill(2) + ext\n                self.save_npz(trans[i], root_orient[i], pose_body[i], base_output, npz_path_i)\n\n\n# amass_output = {\n#     \"gender\": \"neutral\",\n#     \"surface_model_type\": \"smplx\",\n#     \"mocap_frame_rate\": float(fps),\n#     \"mocap_time_length\": len(motion) / float(fps)\n#     \"trans\": trans,\n#     \"betas\": betas,\n#     \"num_betas\": len(betas),\n#     \"root_orient\": np.array([T, 3]), # axis angle\n#     \"pose_body\": np.array([T, 63]), # 63=21*3, axis angle 21 = 22 - root\n#     \"pose_hand\": np.array([T, 90]), # 90=30*3=15*2*3 axis 
angle (load from mean_hands)\n#     \"pose_jaw\": np.array([T, 3]), # all zeros is fine\n#     \"pose_eye\": np.array([T, 6]), # all zeros is fine\n# }\n"
  },
  {
    "path": "kimodo/geometry.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Rotation and representation conversions: axis-angle, quaternion, matrix, 6D continuous.\"\"\"\n\nimport torch\nimport torch.nn.functional as F\n\n\ndef angle_to_Y_rotation_matrix(angle: torch.Tensor) -> torch.Tensor:\n    \"\"\"Build a rotation matrix around the Y axis from a scalar angle (radians).\n\n    Shape: angle.shape + (3, 3).\n    \"\"\"\n    cos, sin = torch.cos(angle), torch.sin(angle)\n    one, zero = torch.ones_like(angle), torch.zeros_like(angle)\n    mat = torch.stack((cos, zero, sin, zero, one, zero, -sin, zero, cos), -1)\n    mat = mat.reshape(angle.shape + (3, 3))\n    return mat\n\n\ndef matrix_to_cont6d(matrix: torch.Tensor) -> torch.Tensor:\n    \"\"\"Convert rotation matrix to 6D continuous representation (first two columns).\n\n    Shape: (..., 3, 3) -> (..., 6).\n    \"\"\"\n    cont_6d = torch.concat([matrix[..., 0], matrix[..., 1]], dim=-1)\n    return cont_6d\n\n\ndef cont6d_to_matrix(cont6d: torch.Tensor) -> torch.Tensor:\n    \"\"\"Convert 6D continuous representation to rotation matrix (Gram–Schmidt on two columns).\n\n    Last dim must be 6.\n    \"\"\"\n    assert cont6d.shape[-1] == 6, \"The last dimension must be 6\"\n    x_raw = cont6d[..., 0:3]\n    y_raw = cont6d[..., 3:6]\n\n    x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True)\n    z = torch.cross(x, y_raw, dim=-1)\n    z = z / torch.norm(z, dim=-1, keepdim=True)\n\n    y = torch.cross(z, x, dim=-1)\n\n    x = x[..., None]\n    y = y[..., None]\n    z = z[..., None]\n\n    mat = torch.cat([x, y, z], dim=-1)\n    return mat\n\n\ndef axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:\n    \"\"\"Convert axis-angle to rotation matrix.\n\n    Args:\n        axis_angle: (..., 3) axis-angle vectors (angle = norm, axis = normalized)\n    Returns:\n        rotmat: (..., 3, 3) rotation matrices\n    \"\"\"\n    
eps = 1e-6\n    angle = torch.norm(axis_angle, dim=-1, keepdim=True)  # (..., 1)\n    axis = axis_angle / (angle + eps)\n\n    x, y, z = axis.unbind(-1)\n\n    zero = torch.zeros_like(x)\n    K = torch.stack([zero, -z, y, z, zero, -x, -y, x, zero], dim=-1).reshape(*axis.shape[:-1], 3, 3)\n\n    eye = torch.eye(3, device=axis.device, dtype=axis.dtype)\n    eye = eye.expand(*axis.shape[:-1], 3, 3)\n\n    sin = torch.sin(angle)[..., None]\n    cos = torch.cos(angle)[..., None]\n\n    R = eye + sin * K + (1 - cos) * (K @ K)\n    return R\n\n\ndef matrix_to_axis_angle(R: torch.Tensor) -> torch.Tensor:\n    \"\"\"Convert rotation matrix to axis-angle via quaternions (more numerically stable).\n\n    Args:\n        R: (..., 3, 3) rotation matrices\n    Returns:\n        axis_angle: (..., 3)\n    \"\"\"\n    # Go through quaternions for numerical stability\n    quat = matrix_to_quaternion(R)  # (..., 4) with (w, x, y, z)\n    return quaternion_to_axis_angle(quat)\n\n\ndef quaternion_to_axis_angle(quat: torch.Tensor) -> torch.Tensor:\n    \"\"\"Convert quaternion to axis-angle representation.\n\n    Args:\n        quat: (..., 4) quaternions with real part first (w, x, y, z)\n    Returns:\n        axis_angle: (..., 3)\n    \"\"\"\n    eps = 1e-6\n\n    # Ensure canonical form to avoid sign ambiguity.\n    # Primary: prefer w > 0. 
When w ≈ 0 (angle ≈ π), prefer first nonzero xyz > 0.\n    w = quat[..., 0:1]\n    xyz = quat[..., 1:]\n\n    # Find first significant component of xyz for tie-breaking when w ≈ 0\n    first_significant = xyz[..., 0:1]  # use x component as tie-breaker\n\n    # Flip if: w < 0, OR (w ≈ 0 AND first xyz component < 0)\n    should_flip = (w < -eps) | ((w.abs() <= eps) & (first_significant < 0))\n    quat = torch.where(should_flip, -quat, quat)\n\n    w = quat[..., 0]\n    xyz = quat[..., 1:]\n\n    # sin(angle/2) = ||xyz||\n    sin_half_angle = xyz.norm(dim=-1)\n\n    # angle = 2 * atan2(sin(angle/2), cos(angle/2))\n    # This is more stable than 2 * acos(w) near angle=0\n    angle = 2.0 * torch.atan2(sin_half_angle, w)\n\n    # axis = xyz / sin(angle/2), but handle small angles\n    # For small angles: axis-angle ≈ 2 * xyz (since sin(x) ≈ x for small x)\n    small_angle = sin_half_angle.abs() < eps\n\n    # Safe division\n    scale = torch.where(\n        small_angle,\n        2.0 * torch.ones_like(angle),  # small angle: axis_angle ≈ 2 * xyz\n        angle / sin_half_angle.clamp(min=eps),\n    )\n\n    return xyz * scale.unsqueeze(-1)\n\n\ndef _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:\n    \"\"\"Returns torch.sqrt(torch.max(0, x)) subgradient is zero where x is 0.\"\"\"\n    return torch.sqrt(x * (x > 0).to(x.dtype))\n\n\ndef matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:\n    \"\"\"Convert rotations given as rotation matrices to quaternions.\n\n    Args:\n        matrix: Rotation matrices as tensor of shape (..., 3, 3).\n    Returns:\n        quaternions with real part first, as tensor of shape (..., 4).\n    \"\"\"\n    if matrix.size(-1) != 3 or matrix.size(-2) != 3:\n        raise ValueError(f\"Invalid rotation matrix shape {matrix.shape}.\")\n\n    batch_dim = matrix.shape[:-2]\n    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)\n\n    q_abs = _sqrt_positive_part(\n        
torch.stack(\n            [\n                1.0 + m00 + m11 + m22,\n                1.0 + m00 - m11 - m22,\n                1.0 - m00 + m11 - m22,\n                1.0 - m00 - m11 + m22,\n            ],\n            dim=-1,\n        )\n    )\n\n    quat_by_rijk = torch.stack(\n        [\n            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),\n            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),\n            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),\n            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),\n        ],\n        dim=-2,\n    )\n\n    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)\n    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))\n\n    return (\n        (F.one_hot(q_abs.argmax(dim=-1), num_classes=4)[..., None] * quat_candidates)\n        .sum(dim=-2)\n        .reshape(batch_dim + (4,))\n    )\n\n\ndef quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:\n    \"\"\"Convert rotations given as quaternions to rotation matrices.\n\n    Args:\n        quaternions: quaternions with real part first,\n            as tensor of shape (..., 4).\n    Returns:\n        Rotation matrices as tensor of shape (..., 3, 3).\n    \"\"\"\n    r, i, j, k = torch.unbind(quaternions, -1)\n    two_s = 2.0 / (quaternions * quaternions).sum(-1)\n\n    o = torch.stack(\n        (\n            1 - two_s * (j * j + k * k),\n            two_s * (i * j - k * r),\n            two_s * (i * k + j * r),\n            two_s * (i * j + k * r),\n            1 - two_s * (i * i + k * k),\n            two_s * (j * k - i * r),\n            two_s * (i * k - j * r),\n            two_s * (j * k + i * r),\n            1 - two_s * (i * i + j * j),\n        ),\n        -1,\n    )\n    return o.reshape(quaternions.shape[:-1] + (3, 3))\n"
  },
  {
    "path": "kimodo/meta.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Parse and normalize prompt text/duration data from meta dicts.\"\"\"\n\nimport os\nfrom typing import Any, Optional\n\nfrom kimodo.tools import load_json\n\nfrom .sanitize import sanitize_text, sanitize_texts\n\n\ndef load_prompts_from_meta(meta_path: str, **kwargs):\n    \"\"\"Load prompts from a meta dict or file. If fps is provided, the durations are converted to\n    frames.\n\n    Args:\n        meta_path: Path to the meta file.\n        **kwargs: Additional arguments to pass to parse_prompts_from_meta.\n\n    Returns:\n        texts: List of texts.\n        durations: List of durations in seconds or frames.\n    \"\"\"\n    if not os.path.exists(meta_path):\n        raise FileNotFoundError(f\"meta.json not found in input folder: {meta_path}\")\n\n    meta = load_json(meta_path)\n    return parse_prompts_from_meta(meta, **kwargs)\n\n\ndef parse_prompts_from_meta(\n    meta: dict[str, Any],\n    fps: Optional[float] = None,\n    sanitize: bool = False,\n) -> tuple[list[str], list[float]]:\n    \"\"\"Parse prompt texts and durations from a meta dict into normalized lists. 
If fps is provided,\n    the durations are converted to frames.\n\n    Accepts either:\n    - Single prompt: \"text\" (str) and \"duration\" (float) in seconds.\n    - Multiple prompts: \"texts\" (list of str) and \"durations\" (list of float) in seconds.\n\n    Returns:\n        (texts, durations): texts as list of str, durations as list of float (seconds or frames).\n        Lengths of both lists are equal.\n\n    Raises:\n        ValueError: If meta does not contain a recognized format.\n    \"\"\"\n    # Single prompt\n    if \"text\" in meta and \"duration\" in meta:\n        text = meta[\"text\"]\n        duration = float(meta[\"duration\"])\n        if fps is not None:\n            duration = int(duration * fps)\n        if isinstance(text, list):\n            raise ValueError(\"meta has 'text' but it is a list; use 'texts' for multiple prompts\")\n\n        if sanitize:\n            text = sanitize_text(text)\n        return ([text], [duration])\n\n    # Multiple prompts\n    if \"texts\" in meta and \"durations\" in meta:\n        texts = meta[\"texts\"]\n        durations = meta[\"durations\"]\n        if not isinstance(texts, list) or not isinstance(durations, list):\n            raise ValueError(\"meta 'texts' and 'durations' must be lists\")\n        if len(texts) != len(durations):\n            raise ValueError(f\"meta 'texts' and 'durations' length mismatch: {len(texts)} vs {len(durations)}\")\n        durations = [float(d) for d in durations]\n        if fps is not None:\n            durations = [int(d * fps) for d in durations]\n\n        if sanitize:\n            texts = sanitize_texts(texts)\n        return texts, durations\n\n    raise ValueError(\"meta must contain either 'text' and 'duration', or 'texts' and 'durations'.\")\n"
  },
  {
    "path": "kimodo/metrics/__init__.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Evaluation metrics for motion quality (foot skate, contact consistency, constraint following).\"\"\"\n\nfrom .base import (\n    Metric,\n    aggregate_metrics,\n    clear_metrics,\n    compute_metrics,\n)\nfrom .constraints import ContraintFollow\nfrom .foot_skate import (\n    FootContactConsistency,\n    FootSkateFromContacts,\n    FootSkateFromHeight,\n    FootSkateRatio,\n)\nfrom .tmr import (\n    TMR_EmbeddingMetric,\n    TMR_Metric,\n    compute_tmr_per_sample_retrieval,\n    compute_tmr_retrieval_metrics,\n)\n\n__all__ = [\n    \"Metric\",\n    \"ContraintFollow\",\n    \"FootContactConsistency\",\n    \"FootSkateFromContacts\",\n    \"FootSkateFromHeight\",\n    \"FootSkateRatio\",\n    \"TMR_EmbeddingMetric\",\n    \"TMR_Metric\",\n    \"aggregate_metrics\",\n    \"clear_metrics\",\n    \"compute_metrics\",\n    \"compute_tmr_per_sample_retrieval\",\n    \"compute_tmr_retrieval_metrics\",\n]\n"
  },
  {
    "path": "kimodo/metrics/base.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Base metric class and batch/aggregate helpers.\"\"\"\n\nfrom __future__ import annotations\n\nfrom collections import defaultdict\nfrom typing import Dict, List\n\nimport torch\n\n\nclass Metric:\n    \"\"\"Base class for metrics that accumulate results over multiple __call__ and expose\n    aggregate().\"\"\"\n\n    def __init__(self, **kwargs):\n        self.clear()\n\n    def __call__(self, *args, **kwargs):\n        \"\"\"Compute metric for current batch, append to saved_metrics, and return the batch\n        result.\"\"\"\n        metrics = self._compute(*args, **kwargs)\n        for key, val in metrics.items():\n            self.saved_metrics[key].append(val.detach().cpu().float())\n        return metrics\n\n    def _compute(self, **kwargs):\n        \"\"\"Subclasses implement this to compute metric dict from batch inputs.\"\"\"\n        raise NotImplementedError()\n\n    def clear(self):\n        \"\"\"Reset all accumulated metric values.\"\"\"\n        self.saved_metrics = defaultdict(list)\n\n    def aggregate(self):\n        \"\"\"Return a dict of concatenated/stacked tensors over all accumulated batches.\"\"\"\n        output = {}\n        for key, lst in self.saved_metrics.items():\n            try:\n                output[key] = torch.cat(lst)\n            except RuntimeError:\n                output[key] = torch.stack(lst)\n        return output\n\n\ndef compute_metrics(metrics_list: List[Metric], metrics_in: Dict) -> Dict:\n    \"\"\"Run each metric on metrics_in and return the combined dict of batch results.\"\"\"\n    metrics_out = {}\n    for metric in metrics_list:\n        metrics_out.update(metric(**metrics_in))\n    return metrics_out\n\n\ndef aggregate_metrics(metrics_list: List[Metric]) -> Dict:\n    \"\"\"Return combined aggregated results (concatenated over batches) for all 
metrics.\"\"\"\n    metrics_out = {}\n    for metric in metrics_list:\n        metrics_out.update(metric.aggregate())\n    return metrics_out\n\n\ndef clear_metrics(metrics_list: List[Metric]) -> None:\n    \"\"\"Clear accumulated values for all metrics in the list.\"\"\"\n    for metric in metrics_list:\n        metric.clear()\n"
  },
  {
    "path": "kimodo/metrics/constraints.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Constraint-following metrics.\"\"\"\n\nfrom __future__ import annotations\n\nfrom collections import defaultdict\nfrom typing import Dict, List, Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom kimodo.constraints import (\n    EndEffectorConstraintSet,\n    FullBodyConstraintSet,\n    Root2DConstraintSet,\n)\nfrom kimodo.tools import ensure_batched\n\nfrom .base import Metric\n\n\nclass ContraintFollow(Metric):\n    \"\"\"Constraint-following metric dispatcher for kimodo constraint sets.\"\"\"\n\n    def __init__(\n        self,\n        skeleton,\n        root_threshold: float = 0.10,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        self.skeleton = skeleton\n        self.root_threshold = root_threshold\n\n    @ensure_batched(posed_joints=4, constraints_lst=2, lengths=1)\n    def _compute(\n        self,\n        posed_joints: Tensor,\n        constraints_lst: Optional[List],\n        lengths: Optional[Tensor] = None,\n        **kwargs,\n    ) -> Dict:\n        if not constraints_lst:\n            return {}\n\n        root_idx = self.skeleton.root_idx\n        output = defaultdict(list)\n\n        for posed_joints_s, constraint_lst_s, lengths_s in zip(posed_joints, constraints_lst, lengths):\n            output_seq = defaultdict(list)\n            for constraint in constraint_lst_s:\n                frame_idx = constraint.frame_indices.to(device=posed_joints_s.device, dtype=torch.long)\n                assert frame_idx.max() < lengths_s, \"The constraint is defined outsite the lenght of the motion.\"\n                if frame_idx.numel() == 0:\n                    continue\n\n                if isinstance(constraint, Root2DConstraintSet):\n                    pred_root2d = posed_joints_s[frame_idx, root_idx][:, [0, 2]]\n                    target = 
constraint.smooth_root_2d.to(posed_joints_s.device)\n\n                    dist = torch.norm(pred_root2d - target, dim=-1)\n                    output_seq[\"constraint_root2d_err\"].append(dist)\n                    hit = (dist <= self.root_threshold).float()\n                    output_seq[\"constraint_root2d_acc\"].append(hit)\n\n                elif isinstance(constraint, FullBodyConstraintSet):\n                    pred = posed_joints_s[frame_idx]\n                    target = constraint.global_joints_positions.to(posed_joints_s.device)\n                    err = torch.norm(pred - target, dim=-1)\n                    output_seq[\"constraint_fullbody_keyframe\"].append(err)\n\n                elif isinstance(constraint, EndEffectorConstraintSet):\n                    pos_idx = constraint.pos_indices.to(device=posed_joints_s.device, dtype=torch.long)\n                    pred = posed_joints_s[frame_idx].index_select(1, pos_idx)\n                    target = constraint.global_joints_positions.to(posed_joints_s.device).index_select(1, pos_idx)\n                    err = torch.norm(pred - target, dim=-1)\n                    output_seq[\"constraint_end_effector\"].append(err)\n\n            # in case we have several same constraints in the list\n            for key, val in output_seq.items():\n                output[key].append(torch.cat(val).mean())\n\n        reduced = {}\n        for key, vals in output.items():\n            reduced[key] = torch.stack(vals, dim=0)\n        return reduced\n"
  },
  {
    "path": "kimodo/metrics/foot_skate.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Foot skate and contact consistency metrics.\"\"\"\n\nfrom __future__ import annotations\n\nfrom typing import Dict, Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom kimodo.motion_rep.feature_utils import compute_vel_xyz\nfrom kimodo.motion_rep.feet import foot_detect_from_pos_and_vel\nfrom kimodo.skeleton import SkeletonBase\nfrom kimodo.tools import ensure_batched\n\nfrom .base import Metric\n\n\ndef get_four_contacts(fidx: list):\n    if len(fidx) == 4:\n        return fidx\n    if len(fidx) == 6:\n        # For soma77\n        # remove \"LeftToeEnd\" and \"RightToeEnd\"\n        fidx = fidx[:2] + fidx[3:5]\n        return fidx\n    raise ValueError(\"Expects 4 or 6 foot joints (heel/toe per foot)\")\n\n\nclass FootSkateFromHeight(Metric):\n    \"\"\"When toe joint is near the floor, measures mean velocity of the toes.\"\"\"\n\n    def __init__(\n        self,\n        skeleton: SkeletonBase,\n        fps: float,\n        height_thresh: float = 0.05,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        self.height_thresh = height_thresh\n        self.skeleton = skeleton\n        self.fps = fps\n\n    @ensure_batched(posed_joints=4, lengths=1)\n    def _compute(\n        self,\n        posed_joints: Tensor,\n        lengths: Optional[Tensor] = None,\n        **kwargs,\n    ) -> Dict:\n        fidx = self.skeleton.foot_joint_idx\n        fidx = get_four_contacts(fidx)\n\n        feet_pos = posed_joints[:, :, fidx]\n        toe_pos = feet_pos[:, :, [1, 3]]\n\n        toe_on_floor = (toe_pos[..., 1] < self.height_thresh)[:, :-1]  # y-up [B, T, 2] where [left right]\n\n        dt = 1.0 / self.fps\n        toe_vel = torch.norm(toe_pos[:, 1:] - toe_pos[:, :-1], dim=-1) / dt  # [B, nframes-1, 2]\n\n        # compute err\n        contact_toe_vel = toe_vel * toe_on_floor  # vel when 
corresponding toe is on ground\n\n        # account for generated length\n        # since they are velocities use length-1 to avoid inaccurate vel going one frame past len\n        device = toe_on_floor.device\n        len_mask = torch.arange(toe_on_floor.shape[1], device=device)[None, :, None].expand(toe_on_floor.shape) < (\n            lengths[:, None, None] - 1\n        )\n        toe_on_floor = toe_on_floor * len_mask\n        contact_toe_vel = contact_toe_vel * len_mask\n\n        mean_vel = torch.sum(contact_toe_vel, (1, 2)) / (torch.sum(toe_on_floor, (1, 2)) + 1e-6)\n        return {\"foot_skate_from_height\": mean_vel}\n\n\nclass FootSkateFromContacts(Metric):\n    \"\"\"Measures velocity of the toes and ankles when predicted to be in contact.\"\"\"\n\n    def __init__(\n        self,\n        skeleton: SkeletonBase,\n        fps: float,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        self.skeleton = skeleton\n        self.fps = fps\n\n    @ensure_batched(posed_joints=4, foot_contacts=3, lengths=1)\n    def _compute(\n        self,\n        posed_joints: Tensor,\n        foot_contacts: Tensor,\n        lengths: Optional[Tensor] = None,\n        **kwargs,\n    ) -> Dict:\n        fidx = self.skeleton.foot_joint_idx\n        fidx = get_four_contacts(fidx)\n\n        feet_pos = posed_joints[:, :, fidx]\n        dt = 1.0 / self.fps\n        foot_vel = torch.norm(feet_pos[:, 1:] - feet_pos[:, :-1], dim=-1) / dt\n\n        if foot_contacts.shape[-1] == 6:\n            # For soma77\n            # remove \"LeftToeEnd\" and \"RightToeEnd\"\n            foot_contacts = foot_contacts[..., [0, 1, 3, 4]]\n\n        foot_contacts = foot_contacts[:, :-1]\n        vel_err = foot_vel * foot_contacts\n\n        # account for generated length\n        # since they are velocities use length-1 to avoid inaccurate vel going one frame past len\n        device = foot_contacts.device\n        len_mask = torch.arange(foot_contacts.shape[1], 
device=device)[None, :, None].expand(foot_contacts.shape) < (\n            lengths[:, None, None] - 1\n        )\n        foot_contacts = foot_contacts * len_mask\n        vel_err = vel_err * len_mask\n\n        mean_vel = torch.sum(vel_err, (1, 2)) / (torch.sum(foot_contacts, (1, 2)) + 1e-6)  # mean over contacting frames\n\n        # Compute max velocity error across all feet and frames (per batch)\n        max_vel = vel_err.amax(dim=(1, 2))  # [B]\n\n        return {\n            \"foot_skate_from_pred_contacts\": mean_vel,\n            \"foot_skate_max_vel\": max_vel,\n        }\n\n\nclass FootSkateRatio(Metric):\n    \"\"\"Compute fraction of frames where the foot skates when it is on the ground.\n\n    Inspired by GMD: https://github.com/korrawe/guided-motion-diffusion/blob/main/data_loaders/humanml/utils/metrics.py#L204\n    \"\"\"\n\n    def __init__(\n        self,\n        skeleton: SkeletonBase,\n        fps: float,\n        height_thresh=0.05,\n        vel_thresh=0.2,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        self.height_thresh = height_thresh\n        self.vel_thresh = vel_thresh\n\n        self.skeleton = skeleton\n        self.fps = fps\n\n    @ensure_batched(posed_joints=4, foot_contacts=3, lengths=1)\n    def _compute(\n        self,\n        posed_joints: Tensor,\n        foot_contacts: Tensor,\n        lengths: Optional[Tensor] = None,\n        **kwargs,\n    ) -> Dict:\n        fidx = self.skeleton.foot_joint_idx\n        fidx = get_four_contacts(fidx)\n\n        feet_pos = posed_joints[:, :, fidx]\n        toe_pos = feet_pos[:, :, [1, 3]]\n\n        toe_on_floor = toe_pos[..., 1] < self.height_thresh  # y-up [B, T, 2] where [left right]\n        # current and next frame on floor to consider it in contact\n        toe_on_floor = torch.logical_and(toe_on_floor[:, :-1], toe_on_floor[:, 1:])  # [B, T-1, 2]\n\n        dt = 1.0 / self.fps\n        toe_vel = torch.norm(toe_pos[:, 1:] - toe_pos[:, :-1], dim=-1) / dt  # [B, 
nframes-1, 2]\n\n        # compute err\n        contact_toe_vel = toe_vel * toe_on_floor  # vel when corresponding toe is on ground\n\n        # account for generated length\n        # since they are velocities use length-1 to avoid inaccurate vel going one frame past len\n        device = toe_on_floor.device\n        len_mask = torch.arange(toe_on_floor.shape[1], device=device)[None, :, None].expand(toe_on_floor.shape) < (\n            lengths[:, None, None] - 1\n        )\n        toe_on_floor = toe_on_floor * len_mask\n        contact_toe_vel = contact_toe_vel * len_mask\n\n        # skating if velocity during contact > thresh\n        toe_skate = contact_toe_vel > self.vel_thresh\n        skate_ratio = torch.sum(toe_skate, (1, 2)) / (torch.sum(toe_on_floor, (1, 2)) + 1e-6)\n        return {\"foot_skate_ratio\": skate_ratio}\n\n\nclass FootContactConsistency(Metric):\n    \"\"\"Measures consistency between heuristic detected foot contacts (from height and velocity) and\n    predicted foot contacts.\n\n    i.e. 
accuracy of how well predicted matches heuristic.\n    \"\"\"\n\n    def __init__(\n        self,\n        skeleton: SkeletonBase,\n        fps: float,\n        vel_thresh: float = 0.15,\n        height_thresh: float = 0.10,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        self.vel_thresh = vel_thresh\n        self.height_thresh = height_thresh\n\n        self.skeleton = skeleton\n        self.fps = fps\n\n    @ensure_batched(posed_joints=4, foot_contacts=3, lengths=1)\n    def _compute(\n        self,\n        posed_joints: Tensor,\n        foot_contacts: Tensor,\n        lengths: Optional[Tensor] = None,\n        **kwargs,\n    ) -> Dict:\n        velocity = compute_vel_xyz(posed_joints, float(self.fps), lengths=lengths)\n        heuristic_contacts = foot_detect_from_pos_and_vel(\n            posed_joints,\n            velocity,\n            self.skeleton,\n            self.vel_thresh,\n            self.height_thresh,\n        )\n\n        if foot_contacts.shape[-1] == 6:\n            # For soma77\n            # remove \"LeftToeEnd\" and \"RightToeEnd\"\n            foot_contacts = foot_contacts[..., [0, 1, 3, 4]]\n\n        num_contacts = foot_contacts.shape[-1]\n        incorrect = torch.logical_xor(heuristic_contacts, foot_contacts)\n        # account for generated length\n        # since they are velocities, use length-1 to avoid inaccurate vel going one frame past len\n        device = foot_contacts.device\n        len_mask = torch.arange(foot_contacts.shape[1], device=device)[None, :, None].expand(foot_contacts.shape) < (\n            lengths[:, None, None] - 1\n        )\n        incorrect = incorrect * len_mask\n\n        incorrect_ratio = torch.sum(incorrect, (1, 2)) / (num_contacts * (lengths - 1))\n        accuracy = 1 - incorrect_ratio\n\n        return {\"foot_contact_consistency\": accuracy}\n"
  },
  {
    "path": "kimodo/metrics/tmr.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"TMR evaluation metrics: text-motion retrieval, R-Precision, and related scores.\"\"\"\n\nfrom __future__ import annotations\n\nfrom collections import defaultdict\nfrom typing import Any, Dict, List, Optional\n\nimport numpy as np\nimport torch\nfrom scipy import linalg\nfrom torch import Tensor\n\nfrom kimodo.model.tmr import TMR\n\nfrom .base import Metric\n\n\n# Scores are between 0 and 1\ndef get_score_matrix_unit(x, y):\n    sim_matrix = np.einsum(\"b i, c i -> b c\", x, y)\n    scores = sim_matrix / 2 + 0.5\n    return scores\n\n\ndef get_scores_unit(x, y):\n    similarity = np.einsum(\"... i, ... i\", x, y)\n    scores = similarity / 2 + 0.5\n    return scores\n\n\ndef compute_tmr_per_sample_retrieval(\n    motion_emb: np.ndarray,\n    text_emb: np.ndarray,\n    sample_ids: List[str],\n    texts: List[str],\n    top_k: int = 5,\n) -> List[Dict[str, Any]]:\n    \"\"\"For each sample (text query i), compute t2m rank of motion i and top-k retrieved motions with\n    ids and texts.\n\n    Returns list of dicts: [{\"rank\": int, \"top_k\": [{\"id\": str, \"text\": str}, ...]}, ...].\n    \"\"\"\n    motion_emb = np.asarray(motion_emb).squeeze()\n    text_emb = np.asarray(text_emb).squeeze()\n    if motion_emb.ndim == 1:\n        motion_emb = motion_emb[np.newaxis, :]\n    if text_emb.ndim == 1:\n        text_emb = text_emb[np.newaxis, :]\n    n = motion_emb.shape[0]\n    assert text_emb.shape[0] == n and len(sample_ids) == n and len(texts) == n\n    scores = get_score_matrix_unit(text_emb, motion_emb)\n    out: List[Dict[str, Any]] = []\n    for i in range(n):\n        row = np.asarray(scores[i])\n        order = np.argsort(row)[::-1]\n        rank = int(np.where(order == i)[0][0]) + 1\n        top_indices = order[:top_k]\n        top_k_list = [{\"id\": sample_ids[j], \"text\": texts[j]} for j in 
top_indices]\n        out.append({\"rank\": rank, \"top_k\": top_k_list})\n    return out\n\n\nclass TMR_Metric(Metric):\n    def __init__(\n        self,\n        tmr_model: TMR,\n        ranks: List = [1, 2, 3, 5, 10],\n        ranks_rounding=2,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        self.tmr_model = tmr_model\n        self.ranks = ranks\n        self.ranks_rounding = ranks_rounding\n\n    def clear(self):\n        self.saved_metrics = defaultdict(list)\n        self.saved_text_latents = []\n        self.saved_motion_gen_latents = []\n        self.saved_motion_gt_latents = []\n\n    def _compute(\n        self,\n        motion_rep,\n        pred_joints_output: Dict,\n        gt_joints_output: Dict,\n        text_x_dict: Dict,\n        lengths: Tensor,\n        **kwargs,\n    ) -> Dict:\n        pred_posed_joints = pred_joints_output[\"posed_joints\"]\n        original_skeleton = motion_rep.skeleton if motion_rep is not None else None\n        latents_motion = self.tmr_model.encode_motion(\n            pred_posed_joints,\n            lengths=lengths,\n            original_skeleton=original_skeleton,\n            unit_vector=True,\n        )\n        latents_motion = latents_motion.cpu().numpy()\n\n        if isinstance(text_x_dict, dict) and \"texts\" in text_x_dict:\n            latents_text = self.tmr_model.encode_raw_text(text_x_dict[\"texts\"], unit_vector=True)\n        else:\n            latents_text = self.tmr_model.encode_text(text_x_dict, unit_vector=True)\n        if latents_text.dim() == 1:\n            latents_text = latents_text.unsqueeze(0)\n        latents_text = latents_text.cpu().numpy()\n\n        self.saved_text_latents.append(latents_text)\n        self.saved_motion_gen_latents.append(latents_motion)\n\n        scores_text = get_scores_unit(latents_motion, latents_text)\n        output = {\"TMR/t2m_sim\": scores_text}\n\n        if gt_joints_output is not None and \"posed_joints\" in gt_joints_output:\n          
  gt_posed_joints = gt_joints_output[\"posed_joints\"]\n            gt_latents_motion = self.tmr_model.encode_motion(\n                gt_posed_joints,\n                lengths=lengths,\n                original_skeleton=original_skeleton,\n                unit_vector=True,\n            )\n            gt_latents_motion = gt_latents_motion.cpu().numpy()\n            self.saved_motion_gt_latents.append(gt_latents_motion)\n\n            gt_scores_text = get_scores_unit(gt_latents_motion, latents_text)\n            scores_motion = get_scores_unit(latents_motion, gt_latents_motion)\n\n            output[\"TMR/t2m_gt_sim\"] = gt_scores_text\n            output[\"TMR/m2m_sim\"] = scores_motion\n\n        # pytorch tensors\n        for key, val in output.items():\n            output[key] = torch.tensor(val)\n        return output\n\n    def aggregate(self):\n        output = {}\n        for key, lst in self.saved_metrics.items():\n            output[key] = np.concatenate(lst)\n\n        assert self.saved_text_latents, \"Should call the metric at least once.\"\n\n        text_latents = np.concatenate(self.saved_text_latents)\n        motion_gen_latents = np.concatenate(self.saved_motion_gen_latents)\n\n        batch_size = len(text_latents)\n        assert text_latents.shape == motion_gen_latents.shape\n\n        scores_t2m = get_score_matrix_unit(text_latents, motion_gen_latents)\n        scores_t2t = get_score_matrix_unit(text_latents, text_latents)\n\n        t2m_metrics = contrastive_metrics(\n            scores=scores_t2m,\n            scores_t2t=scores_t2t,\n            threshold=0.99,\n            rounding=2,\n        )\n\n        for key, val in t2m_metrics.items():\n            output[\"TMR/t2m_R/\" + key] = val\n\n        mu_gen, cov_gen = calculate_activation_statistics(motion_gen_latents)\n        mu_text, cov_text = calculate_activation_statistics(text_latents)\n\n        fid_gen_text = calculate_frechet_distance(mu_gen, cov_gen, mu_text, cov_text)\n        
output[\"TMR/FID/gen_text\"] = fid_gen_text\n\n        if self.saved_motion_gt_latents:\n            motion_gt_latents = np.concatenate(self.saved_motion_gt_latents)\n            assert motion_gt_latents.shape == motion_gen_latents.shape\n\n            scores_m2gm = get_score_matrix_unit(motion_gen_latents, motion_gt_latents)\n            scores_t2gm = get_score_matrix_unit(text_latents, motion_gt_latents)\n\n            m2gm_metrics = contrastive_metrics(\n                scores=scores_m2gm,\n                scores_t2t=scores_t2t,\n                threshold=0.99,\n                rounding=2,\n            )\n            for key, val in m2gm_metrics.items():\n                output[\"TMR/m2m_R/\" + key] = val\n\n            t2gm_metrics = contrastive_metrics(\n                scores=scores_t2gm,\n                scores_t2t=scores_t2t,\n                threshold=0.99,\n                rounding=2,\n            )\n            for key, val in t2gm_metrics.items():\n                output[\"TMR/t2m_gt_R/\" + key] = val\n\n            mu_gt_motion, cov_gt_motion = calculate_activation_statistics(motion_gt_latents)\n            fid_gen_motion = calculate_frechet_distance(\n                mu_gen,\n                cov_gen,\n                mu_gt_motion,\n                cov_gt_motion,\n            )\n            output[\"TMR/FID/gen_gt\"] = fid_gen_motion\n\n            fid_gt_text = calculate_frechet_distance(\n                mu_gt_motion,\n                cov_gt_motion,\n                mu_text,\n                cov_text,\n            )\n            output[\"TMR/FID/gt_text\"] = fid_gt_text\n\n        for key, val in output.items():\n            if isinstance(val, (int, float, np.integer, np.floating)):\n                val = torch.tensor([val for _ in range(batch_size)])\n\n            if isinstance(val, np.ndarray):\n                val = torch.from_numpy(val)\n\n            output[key] = val.cpu().float()\n        return output\n\n\nclass 
TMR_EmbeddingMetric(Metric):\n    \"\"\"TMR metrics from precomputed motion and text embeddings (no model load).\n\n    Use in the loop: pass motion_emb and text_emb per sample; aggregate() computes retrieval metrics.\n    \"\"\"\n\n    def __init__(self, ranks_rounding: int = 2, **kwargs):\n        super().__init__(**kwargs)\n        self.ranks_rounding = ranks_rounding\n\n    def clear(self):\n        self.saved_metrics = defaultdict(list)\n        self.saved_text_latents = []\n        self.saved_motion_gen_latents = []\n        self.saved_motion_gt_latents = []\n\n    def _compute(\n        self,\n        motion_emb=None,\n        text_emb=None,\n        gt_motion_emb=None,\n        **kwargs,\n    ) -> Dict:\n        if motion_emb is None or text_emb is None:\n            return {}\n        motion_emb = np.asarray(motion_emb)\n        text_emb = np.asarray(text_emb)\n        if motion_emb.ndim == 1:\n            motion_emb = motion_emb[np.newaxis, :]\n        if text_emb.ndim == 1:\n            text_emb = text_emb[np.newaxis, :]\n        self.saved_text_latents.append(text_emb)\n        self.saved_motion_gen_latents.append(motion_emb)\n        if gt_motion_emb is not None:\n            gt_motion_emb = np.asarray(gt_motion_emb)\n            if gt_motion_emb.ndim == 1:\n                gt_motion_emb = gt_motion_emb[np.newaxis, :]\n            self.saved_motion_gt_latents.append(gt_motion_emb)\n        scores = get_scores_unit(motion_emb, text_emb)\n        return {\"TMR/t2m_sim\": torch.tensor(scores, dtype=torch.float32)}\n\n    def aggregate(self):\n        output = {}\n        for key, lst in self.saved_metrics.items():\n            output[key] = np.concatenate(lst)\n        if not self.saved_text_latents:\n            return output\n        text_latents = np.concatenate(self.saved_text_latents)\n        motion_gen_latents = np.concatenate(self.saved_motion_gen_latents)\n        batch_size = len(text_latents)\n        assert text_latents.shape == 
motion_gen_latents.shape\n        scores_t2m = get_score_matrix_unit(text_latents, motion_gen_latents)\n        scores_t2t = get_score_matrix_unit(text_latents, text_latents)\n        t2m_metrics = contrastive_metrics(\n            scores=scores_t2m,\n            scores_t2t=scores_t2t,\n            threshold=0.99,\n            rounding=self.ranks_rounding,\n        )\n        for key, val in t2m_metrics.items():\n            output[\"TMR/t2m_R/\" + key] = val\n        if batch_size >= 2:\n            mu_gen, cov_gen = calculate_activation_statistics(motion_gen_latents)\n            mu_text, cov_text = calculate_activation_statistics(text_latents)\n            output[\"TMR/FID/gen_text\"] = calculate_frechet_distance(mu_gen, cov_gen, mu_text, cov_text)\n        else:\n            output[\"TMR/FID/gen_text\"] = float(\"nan\")\n        if self.saved_motion_gt_latents:\n            motion_gt_latents = np.concatenate(self.saved_motion_gt_latents)\n            assert motion_gt_latents.shape == motion_gen_latents.shape\n            scores_m2gm = get_score_matrix_unit(motion_gen_latents, motion_gt_latents)\n            scores_t2gm = get_score_matrix_unit(text_latents, motion_gt_latents)\n            m2gm_metrics = contrastive_metrics(\n                scores=scores_m2gm,\n                scores_t2t=scores_t2t,\n                threshold=0.99,\n                rounding=self.ranks_rounding,\n            )\n            for key, val in m2gm_metrics.items():\n                output[\"TMR/m2m_R/\" + key] = val\n            t2gm_metrics = contrastive_metrics(\n                scores=scores_t2gm,\n                scores_t2t=scores_t2t,\n                threshold=0.99,\n                rounding=self.ranks_rounding,\n            )\n            for key, val in t2gm_metrics.items():\n                output[\"TMR/t2m_gt_R/\" + key] = val\n            if batch_size >= 2:\n                mu_gt_motion, cov_gt_motion = calculate_activation_statistics(motion_gt_latents)\n                
output[\"TMR/FID/gen_gt\"] = calculate_frechet_distance(mu_gen, cov_gen, mu_gt_motion, cov_gt_motion)\n                output[\"TMR/FID/gt_text\"] = calculate_frechet_distance(mu_gt_motion, cov_gt_motion, mu_text, cov_text)\n            else:\n                output[\"TMR/FID/gen_gt\"] = float(\"nan\")\n                output[\"TMR/FID/gt_text\"] = float(\"nan\")\n        for key, val in output.items():\n            if isinstance(val, (int, float, np.integer, np.floating)):\n                val = torch.tensor([val for _ in range(batch_size)])\n            if isinstance(val, np.ndarray):\n                val = torch.from_numpy(val)\n            output[key] = val.cpu().float()\n        return output\n\n\ndef compute_tmr_retrieval_metrics(\n    motion_emb: np.ndarray,\n    text_emb: np.ndarray,\n    gt_motion_emb: Optional[np.ndarray] = None,\n    rounding: int = 2,\n) -> Dict[str, float]:\n    \"\"\"Compute TMR retrieval metrics from precomputed embeddings.\"\"\"\n    if motion_emb.shape != text_emb.shape:\n        raise ValueError(f\"Expected same shape for motion/text embeddings, got {motion_emb.shape} vs {text_emb.shape}\")\n\n    scores_t2m = get_score_matrix_unit(text_emb, motion_emb)\n    scores_t2t = get_score_matrix_unit(text_emb, text_emb)\n\n    output: Dict[str, float] = {}\n    t2m_metrics = contrastive_metrics(\n        scores=scores_t2m,\n        scores_t2t=scores_t2t,\n        threshold=0.99,\n        rounding=rounding,\n    )\n    for key, val in t2m_metrics.items():\n        output[f\"TMR/t2m_R/{key}\"] = float(val)\n\n    n_samples = len(motion_emb)\n    if n_samples >= 2:\n        mu_gen, cov_gen = calculate_activation_statistics(motion_emb)\n        mu_text, cov_text = calculate_activation_statistics(text_emb)\n        output[\"TMR/FID/gen_text\"] = float(calculate_frechet_distance(mu_gen, cov_gen, mu_text, cov_text))\n    else:\n        output[\"TMR/FID/gen_text\"] = float(\"nan\")\n\n    if gt_motion_emb is not None:\n        if 
gt_motion_emb.shape != motion_emb.shape:\n            raise ValueError(f\"Expected gt motion embeddings shape {motion_emb.shape}, got {gt_motion_emb.shape}\")\n\n        scores_m2gm = get_score_matrix_unit(motion_emb, gt_motion_emb)\n        scores_t2gm = get_score_matrix_unit(text_emb, gt_motion_emb)\n\n        m2gm_metrics = contrastive_metrics(\n            scores=scores_m2gm,\n            scores_t2t=scores_t2t,\n            threshold=0.99,\n            rounding=rounding,\n        )\n        for key, val in m2gm_metrics.items():\n            output[f\"TMR/m2m_R/{key}\"] = float(val)\n\n        t2gm_metrics = contrastive_metrics(\n            scores=scores_t2gm,\n            scores_t2t=scores_t2t,\n            threshold=0.99,\n            rounding=rounding,\n        )\n        for key, val in t2gm_metrics.items():\n            output[f\"TMR/t2m_gt_R/{key}\"] = float(val)\n\n        if n_samples >= 2:\n            mu_gt_motion, cov_gt_motion = calculate_activation_statistics(gt_motion_emb)\n            output[\"TMR/FID/gen_gt\"] = float(calculate_frechet_distance(mu_gen, cov_gen, mu_gt_motion, cov_gt_motion))\n            output[\"TMR/FID/gt_text\"] = float(calculate_frechet_distance(mu_gt_motion, cov_gt_motion, mu_text, cov_text))\n        else:\n            output[\"TMR/FID/gen_gt\"] = float(\"nan\")\n            output[\"TMR/FID/gt_text\"] = float(\"nan\")\n\n    return output\n\n\ndef all_contrastive_metrics(sims, emb=None, threshold=None, rounding=2, return_cols=False):\n    text_selfsim = None\n    if emb is not None:\n        text_selfsim = emb @ emb.T\n\n    t2m_m, t2m_cols = contrastive_metrics(sims, text_selfsim, threshold, return_cols=True, rounding=rounding)\n    m2t_m, m2t_cols = contrastive_metrics(sims.T, text_selfsim, threshold, return_cols=True, rounding=rounding)\n\n    all_m = {}\n    for key in t2m_m:\n        all_m[f\"t2m/{key}\"] = t2m_m[key]\n        all_m[f\"m2t/{key}\"] = m2t_m[key]\n\n    all_m[\"t2m/len\"] = float(len(sims))\n    
all_m[\"m2t/len\"] = float(len(sims[0]))\n    if return_cols:\n        return all_m, t2m_cols, m2t_cols\n    return all_m\n\n\ndef contrastive_metrics(\n    scores,\n    scores_t2t=None,\n    threshold=None,\n    rounding=2,\n):\n    n, m = scores.shape\n    assert n == m\n    num_queries = n\n\n    dists = -scores\n    sorted_dists = np.sort(dists, axis=1)\n    # GT is in the diagonal\n    gt_dists = np.diag(dists)[:, None]\n\n    if scores_t2t is not None and threshold is not None:\n        real_threshold = 2 * threshold - 1\n        idx = np.argwhere(scores_t2t > real_threshold)\n        partition = np.unique(idx[:, 0], return_index=True)[1]\n        # take as GT the minimum score of similar values\n        gt_dists = np.minimum.reduceat(dists[tuple(idx.T)], partition)\n        gt_dists = gt_dists[:, None]\n\n    rows, cols = np.where((sorted_dists - gt_dists) == 0)  # find column position of GT\n\n    # if there are ties\n    if rows.size > num_queries:\n        assert np.unique(rows).size == num_queries, \"issue in metric evaluation\"\n        avg_cols = break_ties_average(sorted_dists, gt_dists)\n        cols = avg_cols\n\n    msg = \"expected ranks to match queries ({} vs {}) \"\n    assert cols.size == num_queries, msg\n\n    metrics = {}\n    vals = [str(x).zfill(2) for x in [1, 2, 3, 5, 10]]\n    for val in vals:\n        metrics[f\"R{val}\"] = 100 * float(np.sum(cols < int(val))) / num_queries\n\n    metrics[\"MedR\"] = float(np.median(cols) + 1)\n    metrics[\"len\"] = num_queries\n\n    if rounding is not None:\n        for key in metrics:\n            metrics[key] = round(metrics[key], rounding)\n    return metrics\n\n\ndef break_ties_average(sorted_dists, gt_dists):\n    # fast implementation, based on this code:\n    # https://stackoverflow.com/a/49239335\n    locs = np.argwhere((sorted_dists - gt_dists) == 0)\n\n    # Find the split indices\n    steps = np.diff(locs[:, 0])\n    splits = np.nonzero(steps)[0] + 1\n    splits = np.insert(splits, 0, 
0)\n\n    # Compute the result columns\n    summed_cols = np.add.reduceat(locs[:, 1], splits)\n    counts = np.diff(np.append(splits, locs.shape[0]))\n    avg_cols = summed_cols / counts\n    return avg_cols\n\n\ndef calculate_activation_statistics(activations):\n    \"\"\"\n    Params:\n    -- activation: num_samples x dim_feat\n    Returns:\n    -- mu: dim_feat\n    -- sigma: dim_feat x dim_feat\n    \"\"\"\n    mu = np.mean(activations, axis=0)\n    cov = np.cov(activations, rowvar=False)\n    return mu, cov\n\n\ndef calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):\n    \"\"\"Numpy implementation of the Frechet Distance. The Frechet distance between two multivariate\n    Gaussians X_1 ~ N(mu_1, C_1)\n\n    and X_2 ~ N(mu_2, C_2) is\n            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).\n    Stable version by Dougal J. Sutherland.\n    Params:\n    -- mu1   : Numpy array containing the activations of a layer of the\n               inception net (like returned by the function 'get_predictions')\n               for generated samples.\n    -- mu2   : The sample mean over activations, precalculated on an\n               representative dataset set.\n    -- sigma1: The covariance matrix over activations for generated samples.\n    -- sigma2: The covariance matrix over activations, precalculated on an\n               representative dataset set.\n    Returns:\n    --   : The Frechet Distance.\n    \"\"\"\n\n    mu1 = np.atleast_1d(mu1)\n    mu2 = np.atleast_1d(mu2)\n\n    sigma1 = np.atleast_2d(sigma1)\n    sigma2 = np.atleast_2d(sigma2)\n\n    assert mu1.shape == mu2.shape, \"Training and test mean vectors have different lengths\"\n    assert sigma1.shape == sigma2.shape, \"Training and test covariances have different dimensions\"\n\n    diff = mu1 - mu2\n\n    # Product might be almost singular\n    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)\n    if not np.isfinite(covmean).all():\n        msg = (\"fid calculation produces 
singular product; \" \"adding %s to diagonal of cov estimates\") % eps\n        print(msg)\n        offset = np.eye(sigma1.shape[0]) * eps\n        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))\n\n    # Numerical error might give slight imaginary component\n    if np.iscomplexobj(covmean):\n        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):\n            # try again with diagonal %s\n            offset = np.eye(sigma1.shape[0]) * eps\n            covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))\n            if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):\n                m = np.max(np.abs(covmean.imag))\n                raise ValueError(\"Imaginary component {}\".format(m))\n        covmean = covmean.real\n\n    tr_covmean = np.trace(covmean)\n\n    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean\n"
  },
  {
    "path": "kimodo/model/__init__.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Kimodo model package: main model class, text encoders, and loading utilities.\"\"\"\n\nfrom .common import resolve_target\nfrom .kimodo_model import Kimodo\nfrom .llm2vec import LLM2VecEncoder\nfrom .load_model import load_model\nfrom .loading import (\n    AVAILABLE_MODELS,\n    DEFAULT_MODEL,\n    DEFAULT_TEXT_ENCODER_URL,\n    MODEL_NAMES,\n    load_checkpoint_state_dict,\n)\nfrom .tmr import TMR\nfrom .twostage_denoiser import TwostageDenoiser\n\n__all__ = [\n    \"Kimodo\",\n    \"LLM2VecEncoder\",\n    \"TMR\",\n    \"TwostageDenoiser\",\n    \"load_model\",\n    \"load_checkpoint_state_dict\",\n    \"resolve_target\",\n    \"AVAILABLE_MODELS\",\n    \"DEFAULT_MODEL\",\n    \"DEFAULT_TEXT_ENCODER_URL\",\n    \"MODEL_NAMES\",\n]\n"
  },
  {
    "path": "kimodo/model/backbone.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Transformer backbone: padding, masking, and encoder stack for the denoiser.\"\"\"\n\nimport logging\nfrom typing import Optional, Union\n\nimport torch\nfrom omegaconf import ListConfig\nfrom pydantic.dataclasses import dataclass\nfrom torch import Tensor, nn\nfrom torch.nn import TransformerEncoder, TransformerEncoderLayer\n\nfrom kimodo.tools import validate\n\nlog = logging.getLogger(__name__)\n\n\ndef pad_x_and_mask_to_fixed_size(x: Tensor, mask: Tensor, size: int):\n    \"\"\"Pad a feature vector x and the mask to always have the same size.\n\n    Args:\n        x (torch.Tensor): [B, T, D]\n        mask (torch.Tensor): [B, T]\n        size (int)\n    Returns:\n        torch.Tensor: [B, size, D]\n        torch.Tensor: [B, size]\n    \"\"\"\n\n    batch_size, cur_max_size, dim = x.shape[0], x.shape[1], x.shape[2]\n\n    if cur_max_size == size:\n        # already padded to this size, probably in the collate function\n        return x, mask\n\n    if cur_max_size > size:\n        # This issue should have been handled in the collate function\n        # usefull as a check for test time\n        log.warn(\"The size of the tensor is larger than the maximum size. 
Cropping the input..\")\n        cur_max_size = size\n\n    new_x = torch.zeros(\n        (batch_size, size, dim),\n        dtype=x.dtype,\n        device=x.device,\n    )\n    new_x[:, :cur_max_size] = x\n\n    # same for the mask\n    new_mask = torch.zeros(\n        (batch_size, size),\n        dtype=mask.dtype,\n        device=mask.device,\n    )\n    new_mask[:, :cur_max_size] = mask\n    return new_x, new_mask\n\n\n@dataclass(frozen=True, config=dict(extra=\"forbid\", arbitrary_types_allowed=True))\nclass TransformerEncoderBlockConfig:\n    \"\"\"Configuration for the transformer encoder backbone.\"\"\"\n\n    # input features dimension\n    input_dim: int\n    # output features dimension\n    output_dim: int\n\n    # skeleton object\n    skeleton: object\n\n    # dimension of the text embeddings\n    llm_shape: Union[list[int], ListConfig]\n\n    # mask the text or not\n    use_text_mask: bool\n\n    # latent dimension of the model\n    latent_dim: int\n    # dimension of the feedforward network in transformer\n    ff_size: int\n    # num layers in transformer\n    num_layers: int\n    # num heads in transformer\n    num_heads: int\n    # activation in transformer\n    activation: str\n    # dropout rate for the transformer\n    dropout: float\n    # dropout rate for the positional embeddings\n    pe_dropout: float\n    # use norm first or not\n    norm_first: bool = False\n    # artificially extend the number of text tokens\n    num_text_tokens_override: Optional[int] = None\n\n    # Input first heading angle\n    input_first_heading_angle: bool = False\n\n\nclass TransformerEncoderBlock(nn.Module):\n    @validate(TransformerEncoderBlockConfig, save_args=True, super_init=True)\n    def __init__(self, conf):\n        self.nbjoints = self.skeleton.nbjoints\n        llm_dim = self.llm_shape[-1]\n        self.embed_text = nn.Linear(llm_dim, self.latent_dim)\n\n        self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.pe_dropout)\n\n        # 
maximum number of tokens\n        self.num_text_tokens = self.llm_shape[0]\n        if self.num_text_tokens_override is not None:\n            self.num_text_tokens = self.num_text_tokens_override\n\n        self.embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder)\n\n        self.input_linear = nn.Linear(self.input_dim, self.latent_dim)\n        self.output_linear = nn.Linear(self.latent_dim, self.output_dim)\n        self.linear_first_heading_angle = nn.Linear(2, self.latent_dim)\n\n        trans_enc_layer = TransformerEncoderLayer(\n            d_model=self.latent_dim,\n            nhead=self.num_heads,\n            dim_feedforward=self.ff_size,\n            dropout=self.dropout,\n            activation=self.activation,\n            batch_first=True,\n            norm_first=self.norm_first,\n        )\n        self.seqTransEncoder = TransformerEncoder(\n            trans_enc_layer,\n            num_layers=self.num_layers,\n            enable_nested_tensor=False,\n        )\n\n    def forward(\n        self,\n        x: Tensor,\n        x_pad_mask: torch.Tensor,\n        text_feat: torch.Tensor,\n        text_feat_pad_mask: torch.Tensor,\n        timesteps: Tensor,\n        first_heading_angle: Optional[Tensor] = None,\n    ) -> Tensor:\n        \"\"\"\n        Args:\n            x (torch.Tensor): [B, T, dim_motion] current noisy motion\n            x_pad_mask (torch.Tensor): [B, T] attention mask, positions with True are allowed to attend, False are not\n            text_feat (torch.Tensor): [B, max_text_len, llm_dim] embedded text prompts\n            text_feat_pad_mask (torch.Tensor): [B, max_text_len] attention mask, positions with True are allowed to attend, False are not\n            timesteps (torch.Tensor): [B,] current denoising step\n\n        Returns:\n            torch.Tensor: [B, T, output_dim]\n        \"\"\"\n        batch_size = len(x)\n        x = self.input_linear(x)  # [B, T, D]\n\n        # Pad the text tokens + mask to 
always have the same size == self.num_text_tokens\n        # done here if it was not done in the collate function\n        if self.num_text_tokens is not None:\n            text_feat, text_feat_pad_mask = pad_x_and_mask_to_fixed_size(\n                text_feat,\n                text_feat_pad_mask,\n                self.num_text_tokens,\n            )\n\n        # Encode the text features and the time information\n        emb_text = self.embed_text(text_feat)  # [B, max_text_len, D]\n        emb_time = self.embed_timestep(timesteps)  # [B, 1, D]\n\n        # Create mask for the time information\n        time_mask = torch.ones((batch_size, 1), dtype=bool, device=x.device)\n\n        # Create the prefix features (text, time, etc): [B, max_text_len + 1 + etc]\n        prefix_feats = torch.cat((emb_text, emb_time), axis=1)\n\n        # Behavior from old code: not use text mask -> True for all the tokens\n        if not self.use_text_mask:\n            text_feat_pad_mask = torch.ones(\n                (batch_size, emb_text.shape[1]),\n                dtype=torch.bool,\n                device=x.device,\n            )\n\n        prefix_mask = torch.cat((text_feat_pad_mask, time_mask), axis=1)\n\n        # add the input first heading angle\n        if self.input_first_heading_angle:\n            assert first_heading_angle is not None, \"The first heading angle is mandatory for this model\"\n            # cos(angle) / sin(angle)\n            first_heading_angle_feats = torch.stack(\n                [\n                    torch.cos(first_heading_angle),\n                    torch.sin(first_heading_angle),\n                ],\n                axis=-1,\n            )\n\n            first_heading_angle_feats = self.linear_first_heading_angle(first_heading_angle_feats)\n            first_heading_angle_feats = first_heading_angle_feats[:, None]  # for cat\n            first_heading_angle_mask = torch.ones(\n                (batch_size, 1),\n                dtype=bool,\n           
     device=x.device,\n            )\n            prefix_feats = torch.cat((prefix_feats, first_heading_angle_feats), axis=1)\n            prefix_mask = torch.cat((prefix_mask, first_heading_angle_mask), axis=1)\n\n        # compute the number of prefix features\n        pose_start_ind = prefix_feats.shape[1]\n\n        # Concatenate prefix and x: [B, len(prefix) + T, D]\n        xseq = torch.cat((prefix_feats, x), axis=1)\n\n        # Concatenate the masks and negate them: [B, len(prefix) + T]\n        src_key_padding_mask = ~torch.cat((prefix_mask, x_pad_mask), axis=1)\n\n        # Add positional encoding\n        xseq = self.sequence_pos_encoder(xseq)\n\n        # Input to the transformer and keep the motion indexes\n        if isinstance(self.seqTransEncoder, nn.TransformerEncoder):\n            assert not self.seqTransEncoder.use_nested_tensor, \"Flash attention should be disabled due to bug!\"\n\n        output = self.seqTransEncoder(\n            xseq,\n            src_key_padding_mask=src_key_padding_mask,\n        )\n        output = output[:, pose_start_ind:]  # [B, T, D]\n        output = self.output_linear(output)  # [B, T, OD]\n        return output\n\n\nclass PositionalEncoding(nn.Module):\n    \"\"\"Non-learned positional encoding.\"\"\"\n\n    def __init__(\n        self,\n        d_model: int,\n        dropout: Optional[float] = 0.1,\n        max_len: Optional[int] = 5000,\n    ):\n        \"\"\"\n        Args:\n            d_model (int): input dim\n            dropout (Optional[float] = 0.1): dropout probability on output\n            max_len (Optional[int] = 5000): maximum sequence length\n        \"\"\"\n        super(PositionalEncoding, self).__init__()\n        self.dropout = nn.Dropout(p=dropout)\n\n        pe = torch.zeros(max_len, d_model)\n        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)\n\n        # Note: have to replace torch.exp() and math.log() with torch.pow()\n        # due to MKL exp() and ln() throws 
floating point exceptions on certain CPUs\n        # see corresponding commit and MR\n        div_term = torch.pow(10000.0, -torch.arange(0, d_model, 2).float() / d_model)\n        # div_term = torch.exp(\n        #     torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)\n        # )\n\n        pe[:, 0::2] = torch.sin(position * div_term)\n        pe[:, 1::2] = torch.cos(position * div_term)\n        pe = pe.unsqueeze(0)  # [1, T, D]\n\n        self.register_buffer(\"pe\", pe, persistent=False)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"Apply positional encoding to input sequence.\n\n        Args:\n            x (torch.Tensor): [B, T, D] input motion sequence\n\n        Returns:\n            torch.Tensor: [B, T, D] input motion with PE added to it (and optionally dropout)\n        \"\"\"\n        x = x + self.pe[:, : x.shape[1], :]\n        return self.dropout(x)\n\n\nclass TimestepEmbedder(nn.Module):\n    \"\"\"Encoder for diffusion step.\"\"\"\n\n    def __init__(self, latent_dim: int, sequence_pos_encoder: PositionalEncoding):\n        \"\"\"\n        Args:\n            latent_dim (int): dim to encode to\n            sequence_pos_encoder (PositionalEncoding): the PE to use on timesteps\n        \"\"\"\n        super().__init__()\n        self.latent_dim = latent_dim\n        self.sequence_pos_encoder = sequence_pos_encoder\n\n        time_embed_dim = self.latent_dim\n        self.time_embed = nn.Sequential(\n            nn.Linear(self.latent_dim, time_embed_dim),\n            nn.SiLU(),\n            nn.Linear(time_embed_dim, time_embed_dim),\n        )\n\n    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:\n        \"\"\"Embed timesteps by adding PE then going through linear layers.\n\n        Args:\n            timesteps (torch.Tensor): [B]\n\n        Returns:\n            torch.Tensor: [B, 1, D]\n        \"\"\"\n        return self.time_embed(self.sequence_pos_encoder.pe.transpose(0, 1)[timesteps])\n"
  },
  {
    "path": "kimodo/model/cfg.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Classifier-free guidance wrapper for the denoiser at sampling time.\"\"\"\n\nfrom typing import Dict, Optional, Tuple, Union\n\nimport torch\nimport torch.nn as nn\n\nCFG_TYPES = [\"nocfg\", \"regular\", \"separated\"]\n\n\nclass ClassifierFreeGuidedModel(nn.Module):\n    \"\"\"Wrapper around denoiser to use classifier-free guidance at sampling time.\"\"\"\n\n    def __init__(self, model: nn.Module, cfg_type: Optional[str] = \"separated\"):\n        \"\"\"Wrap the denoiser for classifier-free guidance; cfg_type in CFG_TYPES (e.g. 'regular',\n        'nocfg').\"\"\"\n        super().__init__()\n        self.model = model\n        assert cfg_type in CFG_TYPES, f\"Invalid cfg_type: {cfg_type}\"\n        self.cfg_type_default = cfg_type\n\n    def forward(\n        self,\n        cfg_weight: Union[float, Tuple[float, float]],\n        x: torch.Tensor,\n        x_pad_mask: torch.Tensor,\n        text_feat: torch.Tensor,\n        text_feat_pad_mask: torch.Tensor,\n        timesteps: torch.Tensor,\n        first_heading_angle: Optional[torch.Tensor] = None,\n        motion_mask: Optional[torch.Tensor] = None,\n        observed_motion: Optional[torch.Tensor] = None,\n        cfg_type: Optional[str] = None,\n    ) -> torch.Tensor:\n        \"\"\"\n        Args:\n            cfg_weight (float): guidance weight float or tuple of floats with (text, constraint) weights if using separated cfg\n            x (torch.Tensor): [B, T, dim_motion] current noisy motion\n            x_pad_mask (torch.Tensor): [B, T] attention mask, positions with True are allowed to attend, False are not\n            text_feat (torch.Tensor): [B, max_text_len, llm_dim] embedded text prompts\n            text_feat_pad_mask (torch.Tensor): [B, max_text_len] attention mask, positions with True are allowed to attend, False are not\n            
timesteps (torch.Tensor): [B,] current denoising step\n            first_heading_angle (Optional[torch.Tensor]): optional heading input, forwarded to the wrapped model (repeated once per guidance pass)\n            motion_mask (Optional[torch.Tensor]): optional constraint mask, zeroed for the passes that must ignore the constraints\n            observed_motion (Optional[torch.Tensor]): optional observed motion, forwarded to the wrapped model (repeated once per guidance pass)\n            cfg_type (Optional[str]): one of CFG_TYPES; when None, the cfg_type given at construction is used\n\n        Returns:\n            torch.Tensor: same size as input x\n        \"\"\"\n\n        if cfg_type is None:\n            cfg_type = self.cfg_type_default\n\n        assert cfg_type in CFG_TYPES, f\"Invalid cfg_type: {cfg_type}\"\n\n        # batched conditional and uncond pass together\n        if cfg_type == \"nocfg\":\n            return self.model(\n                x,\n                x_pad_mask,\n                text_feat,\n                text_feat_pad_mask,\n                timesteps,\n                first_heading_angle=first_heading_angle,\n                motion_mask=motion_mask,\n                observed_motion=observed_motion,\n            )\n        elif cfg_type == \"regular\":\n            assert isinstance(cfg_weight, (float, int)), \"cfg_weight must be a single float for regular CFG\"\n            # out_uncond + w * (out_text_and_constraint - out_uncond)\n            text_feat = torch.concatenate([text_feat, 0 * text_feat], dim=0)\n            if motion_mask is not None:\n                motion_mask = torch.concatenate([motion_mask, 0 * motion_mask], dim=0)\n            if observed_motion is not None:\n                observed_motion = torch.concatenate([observed_motion, observed_motion], dim=0)\n            if first_heading_angle is not None:\n                first_heading_angle = torch.concatenate([first_heading_angle, first_heading_angle], dim=0)\n\n            out_cond_uncond = self.model(\n                torch.concatenate([x, x], dim=0),\n                torch.concatenate([x_pad_mask, x_pad_mask], dim=0),\n                text_feat,\n                torch.concatenate([text_feat_pad_mask, False * text_feat_pad_mask], dim=0),\n                torch.concatenate([timesteps, timesteps], dim=0),\n                
first_heading_angle=first_heading_angle,\n                motion_mask=motion_mask,\n                observed_motion=observed_motion,\n            )\n\n            out, out_uncond = torch.chunk(out_cond_uncond, 2)\n            out_new = out_uncond + (cfg_weight * (out - out_uncond))\n        elif cfg_type == \"separated\":\n            assert len(cfg_weight) == 2, \"cfg_weight must be a tuple of two floats for separated CFG\"\n            # out_uncond + w_text * (out_text - out_uncond) + w_constraint * (out_constraint - out_uncond)\n            text_feat = torch.concatenate([text_feat, 0 * text_feat, 0 * text_feat], dim=0)\n            if motion_mask is not None:\n                motion_mask = torch.concatenate([0 * motion_mask, motion_mask, 0 * motion_mask], dim=0)\n            if observed_motion is not None:\n                observed_motion = torch.concatenate([observed_motion, observed_motion, observed_motion], dim=0)\n            if first_heading_angle is not None:\n                first_heading_angle = torch.concatenate(\n                    [first_heading_angle, first_heading_angle, first_heading_angle],\n                    dim=0,\n                )\n\n            out_cond_uncond = self.model(\n                torch.concatenate([x, x, x], dim=0),\n                torch.concatenate([x_pad_mask, x_pad_mask, x_pad_mask], dim=0),\n                text_feat,\n                torch.concatenate(\n                    [\n                        text_feat_pad_mask,\n                        False * text_feat_pad_mask,\n                        False * text_feat_pad_mask,\n                    ],\n                    dim=0,\n                ),\n                torch.concatenate([timesteps, timesteps, timesteps], dim=0),\n                first_heading_angle=first_heading_angle,\n                motion_mask=motion_mask,\n                observed_motion=observed_motion,\n            )\n\n            out_text, out_constraint, out_uncond = torch.chunk(out_cond_uncond, 3)\n      
      out_new = (\n                out_uncond + (cfg_weight[0] * (out_text - out_uncond)) + (cfg_weight[1] * (out_constraint - out_uncond))\n            )\n        else:\n            raise ValueError(f\"Invalid cfg_type: {cfg_type}\")\n\n        return out_new\n"
  },
  {
    "path": "kimodo/model/common.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Config hydration: env vars, _target_ resolution, and recursive instantiation.\"\"\"\n\nimport importlib\nimport os\n\n\ndef get_env_var(name: str, default=None):\n    \"\"\"Read env var by name and by lowercased name; return default if neither set.\"\"\"\n    return os.getenv(name, os.getenv(name.lower(), default))\n\n\ndef resolve_target(target: str):\n    \"\"\"Import module and return the attribute named by a dotted path (e.g. 'pkg.mod.Class').\"\"\"\n    module_name, attr_name = target.rsplit(\".\", 1)\n    module = importlib.import_module(module_name)\n    return getattr(module, attr_name)\n\n\ndef materialize_value(value):\n    \"\"\"Recursively turn dicts with '_target_' into instances; lists/dicts traversed; leaves\n    unchanged.\"\"\"\n    if isinstance(value, dict):\n        if \"_target_\" in value:\n            return instantiate_from_dict(value)\n        return {k: materialize_value(v) for k, v in value.items()}\n    if isinstance(value, list):\n        return [materialize_value(v) for v in value]\n    return value\n\n\ndef instantiate_from_dict(node, overrides=None):\n    \"\"\"Build an instance from a config dict: '_target_' gives the class, other keys are kwargs; overrides merged in.\"\"\"\n    if not isinstance(node, dict) or \"_target_\" not in node:\n        raise ValueError(\"Config node must be a dict with a '_target_' key.\")\n\n    target = resolve_target(node[\"_target_\"])\n    kwargs = {}\n    for key, value in node.items():\n        if key == \"_target_\":\n            continue\n        kwargs[key] = materialize_value(value)\n\n    if overrides:\n        kwargs.update({k: v for k, v in overrides.items() if v is not None})\n\n    return target(**kwargs)\n"
  },
  {
    "path": "kimodo/model/diffusion.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Diffusion process and DDIM sampling for motion generation.\"\"\"\n\nimport math\nfrom typing import Optional, Tuple\n\nimport torch\nfrom torch import nn\n\n\ndef get_beta_schedule(\n    num_diffusion_timesteps: int,\n    max_beta: Optional[float] = 0.999,\n) -> torch.Tensor:\n    \"\"\"Get cosine beta schedule.\"\"\"\n\n    def alpha_bar(t):\n        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2\n\n    betas = []\n    for i in range(num_diffusion_timesteps):\n        t1 = i / num_diffusion_timesteps\n        t2 = (i + 1) / num_diffusion_timesteps\n        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))\n    return torch.tensor(betas, dtype=torch.float)\n\n\nclass Diffusion(torch.nn.Module):\n    \"\"\"Cosine-schedule diffusion process: betas, alphas, and DDIM step mapping.\"\"\"\n\n    def __init__(self, num_base_steps: int):\n        \"\"\"Set up cosine beta schedule and precompute diffusion variables for num_base_steps.\"\"\"\n        super().__init__()\n        self.num_base_steps = num_base_steps\n        betas_base = get_beta_schedule(self.num_base_steps)\n        self.register_buffer(\"betas_base\", betas_base, persistent=False)\n        alphas_cumprod_base = torch.cumprod(1.0 - self.betas_base, dim=0)\n        self.register_buffer(\"alphas_cumprod_base\", alphas_cumprod_base, persistent=False)\n        use_timesteps, _ = self.space_timesteps(self.num_base_steps)\n        self.calc_diffusion_vars(use_timesteps)\n\n    def extra_repr(self) -> str:\n        return f\"num_base_steps={self.num_base_steps}\"\n\n    @property\n    def device(self):\n        return self.betas_base.device\n\n    def space_timesteps(self, num_denoising_steps: int) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"Return (use_timesteps, map_tensor) for a subsampled denoising schedule of\n        
num_denoising_steps.\"\"\"\n        nsteps_train = self.num_base_steps\n        frac_stride = (nsteps_train - 1) / max(1, num_denoising_steps - 1)\n        use_timesteps = torch.round(torch.arange(nsteps_train, device=self.device) * frac_stride).to(torch.long)\n        use_timesteps = torch.clamp(use_timesteps, max=nsteps_train - 1)\n        map_tensor = torch.arange(nsteps_train, device=self.device, dtype=torch.long)[use_timesteps]\n        return use_timesteps, map_tensor\n\n    def calc_diffusion_vars(self, use_timesteps: torch.Tensor) -> None:\n        \"\"\"Update buffers (betas, alphas, alphas_cumprod, etc.) for the given subsampled\n        timesteps.\"\"\"\n        alphas_cumprod = self.alphas_cumprod_base[use_timesteps]\n        last_alpha_cumprod = torch.cat([torch.tensor([1.0]).to(alphas_cumprod), alphas_cumprod[:-1]])\n        betas = 1.0 - alphas_cumprod / last_alpha_cumprod\n        self.register_buffer(\"betas\", betas, persistent=False)\n\n        alphas = 1.0 - self.betas\n        self.register_buffer(\"alphas\", alphas, persistent=False)\n        alphas_cumprod = torch.cumprod(self.alphas, dim=0)\n        alphas_cumprod = torch.clamp(alphas_cumprod, min=1e-9)\n        self.register_buffer(\"alphas_cumprod\", alphas_cumprod, persistent=False)\n\n        alphas_cumprod_prev = torch.cat([torch.tensor([1.0]).to(self.alphas_cumprod), self.alphas_cumprod[:-1]])\n        self.register_buffer(\"alphas_cumprod_prev\", alphas_cumprod_prev, persistent=False)\n\n        sqrt_recip_alphas_cumprod = torch.rsqrt(self.alphas_cumprod)\n        self.register_buffer(\"sqrt_recip_alphas_cumprod\", sqrt_recip_alphas_cumprod, persistent=False)\n\n        sqrt_recipm1_alphas_cumprod = torch.rsqrt(self.alphas_cumprod / (1.0 - self.alphas_cumprod))\n        self.register_buffer(\"sqrt_recipm1_alphas_cumprod\", sqrt_recipm1_alphas_cumprod, persistent=False)\n\n        posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)\n        
self.register_buffer(\"posterior_variance\", posterior_variance, persistent=False)\n\n        sqrt_alphas_cumprod = torch.rsqrt(1.0 / self.alphas_cumprod)\n        self.register_buffer(\"sqrt_alphas_cumprod\", sqrt_alphas_cumprod, persistent=False)\n\n        sqrt_one_minus_alphas_cumprod = torch.rsqrt(1.0 / (1.0 - self.alphas_cumprod))\n        self.register_buffer(\n            \"sqrt_one_minus_alphas_cumprod\",\n            sqrt_one_minus_alphas_cumprod,\n            persistent=False,\n        )\n\n    def q_sample(\n        self,\n        x_start: torch.Tensor,\n        t: torch.Tensor,\n        noise: torch.Tensor = None,\n    ):\n        if noise is None:\n            noise = torch.randn_like(x_start)\n        assert noise.shape == x_start.shape\n\n        xt = (\n            self.sqrt_alphas_cumprod[t, None, None] * x_start\n            + self.sqrt_one_minus_alphas_cumprod[t, None, None] * noise\n        )\n        return xt\n\n\nclass DDIMSampler(nn.Module):\n    \"\"\"Deterministic DDIM sampler (eta = 0).\"\"\"\n\n    def __init__(self, diffusion: Diffusion):\n        super().__init__()\n        self.diffusion = diffusion\n\n    def __call__(\n        self,\n        use_timesteps: torch.Tensor,\n        x_t: torch.Tensor,\n        pred_xstart: torch.Tensor,\n        t: torch.Tensor,\n    ) -> torch.Tensor:\n        self.diffusion.calc_diffusion_vars(use_timesteps)\n        eps = (\n            self.diffusion.sqrt_recip_alphas_cumprod[t, None, None] * x_t - pred_xstart\n        ) / self.diffusion.sqrt_recipm1_alphas_cumprod[t, None, None]\n        alpha_bar_prev = self.diffusion.alphas_cumprod_prev[t, None, None]\n        x = pred_xstart * torch.sqrt(alpha_bar_prev) + torch.sqrt(1 - alpha_bar_prev) * eps\n        return x\n"
  },
  {
    "path": "kimodo/model/kimodo_model.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Kimodo model: denoiser, text encoder, diffusion sampling, and post-processing.\"\"\"\n\nimport logging\nfrom typing import Dict, List, Optional, Tuple, Union\n\nimport torch\nfrom torch import nn\nfrom tqdm.auto import tqdm\n\nfrom kimodo.constraints import EndEffectorConstraintSet, FullBodyConstraintSet\nfrom kimodo.motion_rep.feature_utils import compute_heading_angle, length_to_mask\nfrom kimodo.postprocess import post_process_motion\nfrom kimodo.sanitize import sanitize_texts\nfrom kimodo.skeleton import SOMASkeleton30\nfrom kimodo.tools import to_numpy\n\nfrom .cfg import ClassifierFreeGuidedModel\nfrom .diffusion import DDIMSampler, Diffusion\n\nlog = logging.getLogger(__name__)\n\n\nclass Kimodo(nn.Module):\n    \"\"\"Helper class for test time.\"\"\"\n\n    def __init__(\n        self,\n        denoiser: nn.Module,\n        text_encoder: nn.Module,\n        num_base_steps: int,\n        device: Optional[Union[str, torch.device]] = None,\n        cfg_type: Optional[str] = \"separated\",\n    ):\n        super().__init__()\n\n        self.denoiser = denoiser.eval()\n\n        if cfg_type is None:\n            cfg_type = \"nocfg\"\n\n        # Add Classifier-free guidance to the model if needed\n        self.denoiser = ClassifierFreeGuidedModel(self.denoiser, cfg_type=cfg_type)\n\n        self.motion_rep = denoiser.motion_rep\n        self.skeleton = self.motion_rep.skeleton\n\n        self.fps = denoiser.motion_rep.fps\n\n        self.diffusion = Diffusion(num_base_steps=num_base_steps)\n        self.sampler = DDIMSampler(self.diffusion)\n        self.text_encoder = text_encoder\n\n        self.device = device\n        # for classifier-free guidance\n\n        self.to(device)\n\n    @property\n    def output_skeleton(self):\n        \"\"\"Skeleton used for model output (somaskel77 for SOMA, else 
unchanged).\"\"\"\n        if isinstance(self.skeleton, SOMASkeleton30):\n            return self.skeleton.somaskel77\n        return self.skeleton\n\n    def train(self, mode: bool):\n        self.denoiser.train(mode)\n        return self\n\n    def eval(self):\n        self.denoiser.eval()\n        return self\n\n    def denoising_step(\n        self,\n        motion: torch.Tensor,\n        pad_mask: torch.Tensor,\n        text_feat: torch.Tensor,\n        text_pad_mask: torch.Tensor,\n        t: torch.Tensor,\n        first_heading_angle: Optional[torch.Tensor],\n        motion_mask: torch.Tensor,\n        observed_motion: torch.Tensor,\n        num_denoising_steps: torch.Tensor,\n        cfg_weight: Union[float, Tuple[float, float]],\n        guide_masks: Optional[Dict] = None,\n        cfg_type: Optional[str] = None,\n    ) -> torch.Tensor:\n        \"\"\"Single denoising step.\n\n        Returns:\n            torch.Tensor: [B, T, D] noisy motion input to t-1\n        \"\"\"\n        # subsample timesteps\n        #   NOTE: do this at every step due to ONNX export, i.e. 
num_denoising_steps may change dynamically when\n        #       running onnx version so need to account for that.\n        num_denoising_steps = num_denoising_steps[0]\n        use_timesteps, map_tensor = self.diffusion.space_timesteps(num_denoising_steps)\n        self.diffusion.calc_diffusion_vars(use_timesteps)\n\n        # first compute initial clean prediction from denoiser\n        t_map = map_tensor[t]\n\n        with torch.inference_mode():\n            pred_clean = self.denoiser(\n                cfg_weight,\n                motion,\n                pad_mask,\n                text_feat,\n                text_pad_mask,\n                t_map,\n                first_heading_angle,\n                motion_mask,\n                observed_motion,\n                cfg_type=cfg_type,\n            )\n\n        # sampler computes next step noisy motion\n        x_tm1 = self.sampler(use_timesteps, motion, pred_clean, t)\n        return x_tm1\n\n    def _multiprompt(\n        self,\n        prompts: list[str],\n        num_frames: int | list[int],\n        num_denoising_steps: int,\n        constraint_lst: Optional[list] = [],\n        cfg_weight: Optional[float] = [2.0, 2.0],\n        num_samples: Optional[int] = None,\n        cfg_type: Optional[str] = None,\n        return_numpy: bool = False,\n        first_heading_angle: Optional[torch.Tensor] = None,\n        # for transitioning\n        num_transition_frames: int = 5,\n        # for postprocess\n        post_processing: bool = False,\n        root_margin: float = 0.04,\n        # progress bar\n        progress_bar=tqdm,\n    ) -> torch.Tensor:\n        device = self.device\n\n        bs = num_samples\n        texts = sanitize_texts(prompts)\n\n        if isinstance(num_frames, int):\n            # same duration for all the segments\n            num_frames = [num_frames for _ in range(num_samples)]\n\n        tosqueeze = False\n        if num_samples is None:\n            num_samples = 1\n            tosqueeze = 
True\n\n        if constraint_lst is None:\n            constraint_lst = []\n\n        # Generate one chunk at a time\n        current_frame = 0\n        generated_motions = []\n\n        for idx, (text, num_frame) in enumerate(zip(texts, num_frames)):\n            texts_bs = [text for _ in range(num_samples)]\n\n            lengths = torch.tensor(\n                [num_frame for _ in range(num_samples)],\n                device=device,\n            )\n\n            is_first_motion = not generated_motions\n\n            observed_motion, motion_mask = None, None\n\n            # filter the constraint_lst to only keep the relevant ones\n            constraint_lst_base = [\n                constraint.crop_move(current_frame, current_frame + num_frame) for constraint in constraint_lst\n            ]  # this moves them temporally but not spatially\n\n            observed_motion, motion_mask = self.motion_rep.create_conditions_from_constraints_batched(\n                constraint_lst_base,\n                lengths,\n                to_normalize=False,  # don't normalize yet, it needs to be moved around\n                device=device,\n            )\n\n            if not is_first_motion:\n                nb_transition_frames = num_transition_frames\n\n                if nb_transition_frames < 1:\n                    raise ValueError(f\"num_transition_frames must be at least 1, got {nb_transition_frames}\")\n\n                latest_motions = generated_motions.pop()\n                # remove the transition part of A (will be put back afterward)\n                generated_motions.append(latest_motions[:, :-nb_transition_frames])\n                latest_frames = latest_motions[:, -nb_transition_frames:]\n\n                last_output = self.motion_rep.inverse(\n                    latest_frames,\n                    is_normalized=False,\n                    return_numpy=False,\n                )\n                smooth_root_2d = last_output[\"smooth_root_pos\"][..., [0, 2]]\n\n   
             # add constraints at the beginning to allow natural transitions\n                constraint_lst_transition = []\n                for batch_id in range(bs):\n                    new_constraint = FullBodyConstraintSet(\n                        self.skeleton,\n                        torch.arange(num_transition_frames),\n                        last_output[\"posed_joints\"][batch_id, :num_transition_frames],\n                        last_output[\"global_rot_mats\"][batch_id, :num_transition_frames],\n                        smooth_root_2d[batch_id, :num_transition_frames],\n                    )\n                    # separate end-effector constraint to capture hand/feet rotations\n                    new_ee_constraint = EndEffectorConstraintSet(\n                        self.skeleton,\n                        torch.arange(num_transition_frames),\n                        last_output[\"posed_joints\"][batch_id, :num_transition_frames],\n                        last_output[\"global_rot_mats\"][batch_id, :num_transition_frames],\n                        smooth_root_2d[batch_id, :num_transition_frames],\n                        joint_names=[\"LeftHand\", \"RightHand\", \"LeftFoot\", \"RightFoot\"],\n                    )\n\n                    constraint_lst_transition.append([new_constraint, new_ee_constraint])\n\n                transition_lengths = torch.tensor(\n                    [nb_transition_frames for _ in range(num_samples)],\n                    device=device,\n                )\n\n                observed_motion_transition, motion_mask_transition = (\n                    self.motion_rep.create_conditions_from_constraints_batched(\n                        constraint_lst_transition,\n                        transition_lengths,\n                        to_normalize=False,  # don't normalize yet\n                        device=device,\n                    )\n                )\n\n                # concatenate the observed motion / motion mask\n         
       observed_motion = torch.cat([observed_motion_transition, observed_motion], axis=1)\n                motion_mask = torch.cat([motion_mask_transition, motion_mask], axis=1)\n\n                # we need to move each observed motion in the batch to the new starting points\n                last_smooth_root_2d = smooth_root_2d[:, 0]\n                observed_motion = self.motion_rep.translate_2d(\n                    observed_motion, -last_smooth_root_2d\n                )  # equivalent to:  self.motion_rep.translate_2d_to_zero(observed_motion)\n\n                # remove dummy values after moving\n                observed_motion = observed_motion * motion_mask\n\n                lengths = lengths + transition_lengths\n                first_heading_angle = compute_heading_angle(last_output[\"posed_joints\"], self.skeleton)[:, 0]\n            else:\n                if first_heading_angle is None:\n                    # Start at 0 angle, but this will change afterward\n                    first_heading_angle = torch.tensor([0.0] * bs, device=device)\n                else:\n                    first_heading_angle = torch.as_tensor(first_heading_angle, device=device)\n                    if first_heading_angle.numel() == 1:\n                        first_heading_angle = first_heading_angle.repeat(bs)\n\n            observed_motion = self.motion_rep.normalize(observed_motion)\n\n            max_frames = max(lengths)\n            motion_pad_mask = length_to_mask(lengths)\n\n            motion = self._generate(\n                texts_bs,\n                max_frames,\n                num_denoising_steps=num_denoising_steps,\n                pad_mask=motion_pad_mask,\n                first_heading_angle=first_heading_angle,\n                motion_mask=motion_mask,\n                observed_motion=observed_motion,\n                cfg_weight=cfg_weight,\n                cfg_type=cfg_type,\n            )\n\n            motion = self.motion_rep.unnormalize(motion)\n\n        
    if not is_first_motion:\n                motion_with_transition = self.motion_rep.translate_2d(\n                    motion,\n                    last_smooth_root_2d,\n                )\n\n                if post_processing:\n                    # Per-segment postprocessing: inverse, postprocess, re-encode.\n                    # The full transition+segment is postprocessed together so the\n                    # transition constraints keep the junction smooth.\n                    seg_output = self.motion_rep.inverse(\n                        motion_with_transition, is_normalized=False, return_numpy=False,\n                    )\n                    seg_constraints = [list(cl) for cl in constraint_lst_transition]\n                    for bi in range(bs):\n                        seg_constraints[bi].extend(\n                            [c.crop_move(current_frame - nb_transition_frames,\n                                         current_frame - nb_transition_frames + num_frame + nb_transition_frames)\n                             for c in constraint_lst]\n                        )\n                    corrected = post_process_motion(\n                        seg_output[\"local_rot_mats\"],\n                        seg_output[\"root_positions\"],\n                        seg_output[\"foot_contacts\"],\n                        self.skeleton,\n                        seg_constraints,\n                        root_margin=root_margin,\n                    )\n                    seg_output.update(corrected)\n                    motion = self.motion_rep(\n                        seg_output[\"local_rot_mats\"],\n                        seg_output[\"root_positions\"],\n                        to_normalize=False,\n                        lengths=lengths,\n                    )\n                else:\n                    motion = motion_with_transition[:, num_transition_frames:]\n                    transition_frames = motion_with_transition[:, :num_transition_frames]\n\n    
                # linearly combine the previously generated transitions with the newly generated ones\n                    alpha = torch.linspace(1, 0, num_transition_frames, device=device)[:, None]\n                    new_transition_frames = (\n                        latest_frames[:, :num_transition_frames] * alpha + (1 - alpha) * transition_frames\n                    )\n\n                    # add new transitions frames for A (merging with B prediction of the history)\n                    generated_motions.append(new_transition_frames)\n\n            elif post_processing:\n                # First segment: postprocess immediately\n                seg_output = self.motion_rep.inverse(\n                    motion, is_normalized=False, return_numpy=False,\n                )\n                seg_constraints = constraint_lst_base if constraint_lst_base else []\n                corrected = post_process_motion(\n                    seg_output[\"local_rot_mats\"],\n                    seg_output[\"root_positions\"],\n                    seg_output[\"foot_contacts\"],\n                    self.skeleton,\n                    seg_constraints,\n                    root_margin=root_margin,\n                )\n                seg_output.update(corrected)\n                motion = self.motion_rep(\n                    seg_output[\"local_rot_mats\"],\n                    seg_output[\"root_positions\"],\n                    to_normalize=False,\n                    lengths=lengths,\n                )\n\n            generated_motions.append(motion)\n            current_frame += num_frame\n\n        generated_motions = torch.cat(generated_motions, axis=1)  # temporal axis (b, t, d)\n\n        if tosqueeze:\n            generated_motions = generated_motions[0]\n\n        output = self.motion_rep.inverse(\n            generated_motions,\n            is_normalized=False,\n            return_numpy=False,\n        )\n\n        # Post-processing: already applied per-segment inside the 
loop above,\n        # so no additional post-processing pass is needed here.\n\n        # Convert SOMA output to somaskel77 for external API\n        if isinstance(self.skeleton, SOMASkeleton30):\n            output = self.skeleton.output_to_SOMASkeleton77(output)\n\n        # Convert to numpy if requested\n        if return_numpy:\n            output = to_numpy(output)\n        return output\n\n    def __call__(\n        self,\n        prompts: str | list[str],\n        num_frames: int | list[int],\n        num_denoising_steps: int,\n        multi_prompt: bool = False,\n        constraint_lst: Optional[list] = [],\n        cfg_weight: Optional[float] = [2.0, 2.0],\n        num_samples: Optional[int] = None,\n        cfg_type: Optional[str] = None,\n        return_numpy: bool = False,\n        first_heading_angle: Optional[torch.Tensor] = None,\n        # for transitioning\n        num_transition_frames: int = 5,\n        # for postprocess\n        post_processing: bool = False,\n        root_margin: float = 0.04,\n        # progress bar\n        progress_bar=tqdm,\n    ) -> dict:\n        \"\"\"Generate motion from text prompts and optional kinematic constraints.\n\n        When a single prompt/num_frames pair is given, one motion is generated.\n        Passing lists of prompts and/or num_frames produces a batch of\n        independent motions. With ``multi_prompt=True``, the prompts are\n        treated as sequential segments that are generated and stitched together\n        with smooth transitions.\n\n        Args:\n            prompts: One or more text descriptions of the desired motion.\n                A single string generates one sample; a list generates a batch\n                (or sequential segments when ``multi_prompt=True``).\n            num_frames: Duration of the generated motion in frames.  Can be a\n                single int applied to every prompt or a per-prompt list.\n            num_denoising_steps: Number of DDIM denoising steps.  
More steps\n                generally improve quality at the cost of speed.\n            multi_prompt: If ``True``, treat ``prompts`` as an ordered sequence\n                of segments and concatenate them with transitions.\n            constraint_lst: Per-sample list of kinematic constraints (e.g.\n                keyframe poses, end-effector targets, 2-D paths).  Pass an\n                empty list for unconstrained generation.\n            cfg_weight: Classifier-free guidance scale(s).  A two-element list\n                ``[text_cfg, constraint_cfg]`` controls text and constraint\n                guidance independently.\n            num_samples: Number of samples to generate.\n            cfg_type: Override the default CFG strategy set at init\n                (e.g. ``\"separated\"``).\n            return_numpy: If ``True``, convert all output tensors to numpy\n                arrays.\n            first_heading_angle: Initial body heading in radians.  Shape\n                ``(B,)`` or scalar.  Defaults to ``0`` (facing +Z).\n            num_transition_frames: Number of overlapping frames used to blend\n                consecutive segments in multi-prompt mode.\n            post_processing: If ``True``, apply post-processing\n                (foot-skate cleanup and constraint enforcement).\n            root_margin: Horizontal margin (in meters) used by the post-processor\n                to determine when to correct root motion. When root deviates more than\n                margin from the constraint, the post-processor will correct it.\n            progress_bar: Callable wrapping an iterable to display progress\n                (default: ``tqdm``).  
Pass a no-op to silence output.\n\n        Returns:\n            dict: A dictionary of motion tensors (or numpy arrays if\n            ``return_numpy=True``) with the following keys:\n\n            - ``local_rot_mats`` – Local joint rotations as rotation matrices.\n            - ``global_rot_mats`` – Global joint rotations as rotation matrices.\n            - ``posed_joints`` – Joint positions in world space.\n            - ``root_positions`` – Root joint positions.\n            - ``smooth_root_pos`` – Smoothed root trajectory.\n            - ``foot_contacts`` – Boolean foot-contact labels [left heel, left toe, right heel, right toe].\n            - ``global_root_heading`` – Root heading angle over time.\n        \"\"\"\n        device = self.device\n\n        if multi_prompt:\n            # multi prompt generation\n            return self._multiprompt(\n                prompts,\n                num_frames,\n                num_denoising_steps,\n                constraint_lst,\n                cfg_weight,\n                num_samples,\n                cfg_type,\n                return_numpy,\n                first_heading_angle,\n                num_transition_frames,\n                post_processing,\n                root_margin,\n                progress_bar,\n            )\n\n        # Input checking\n        tosqueeze = False\n        if isinstance(prompts, list) and isinstance(num_frames, list):\n            assert len(prompts) == len(num_frames), \"The number of prompts should match the number of num_frames.\"\n            num_samples = len(prompts)\n        elif isinstance(prompts, list):\n            num_samples = len(prompts)\n            num_frames = [num_frames for _ in range(num_samples)]\n        elif isinstance(num_frames, list):\n            num_samples = len(num_frames)\n            prompts = [prompts for _ in range(num_samples)]\n        else:\n            if num_samples is None:\n                tosqueeze = True\n                num_samples = 1\n  
          prompts = [prompts for _ in range(num_samples)]\n            num_frames = [num_frames for _ in range(num_samples)]\n\n        bs = num_samples\n        texts = sanitize_texts(prompts)\n\n        lengths = torch.tensor(\n            num_frames,\n            device=device,\n        )\n        max_frames = max(lengths)\n        motion_pad_mask = length_to_mask(lengths)\n\n        if first_heading_angle is None:\n            # Start at 0 angle\n            first_heading_angle = torch.tensor([0.0] * bs, device=device)\n        else:\n            first_heading_angle = torch.as_tensor(first_heading_angle, device=device)\n            if first_heading_angle.numel() == 1:\n                first_heading_angle = first_heading_angle.repeat(bs)\n\n        observed_motion, motion_mask = None, None\n        if constraint_lst:\n            observed_motion, motion_mask = self.motion_rep.create_conditions_from_constraints_batched(\n                constraint_lst,\n                lengths,\n                to_normalize=True,\n                device=device,\n            )\n\n        motion = self._generate(\n            texts,\n            max_frames,\n            num_denoising_steps=num_denoising_steps,\n            pad_mask=motion_pad_mask,\n            first_heading_angle=first_heading_angle,\n            motion_mask=motion_mask,\n            observed_motion=observed_motion,\n            cfg_weight=cfg_weight,\n            cfg_type=cfg_type,\n            progress_bar=progress_bar,\n        )\n\n        if tosqueeze:\n            motion = motion[0]\n\n        output = self.motion_rep.inverse(\n            motion,\n            is_normalized=True,\n            return_numpy=False,  # Keep as tensor for potential post-processing\n        )\n\n        # Apply post-processing if requested\n        if post_processing:\n            corrected = post_process_motion(\n                output[\"local_rot_mats\"],\n                output[\"root_positions\"],\n                
output[\"foot_contacts\"],\n                self.skeleton,\n                constraint_lst,\n                root_margin=root_margin,\n            )\n            # key frame outputs / foot contacts are not changed\n            output.update(corrected)\n\n        # Convert SOMA output to somaskel77 for external API\n        if isinstance(self.skeleton, SOMASkeleton30):\n            output = self.skeleton.output_to_SOMASkeleton77(output)\n\n        # Convert to numpy if requested\n        if return_numpy:\n            output = to_numpy(output)\n        return output\n\n    def _generate(\n        self,\n        texts: List[str],\n        max_frames: int,\n        num_denoising_steps: int,\n        pad_mask: torch.Tensor,\n        first_heading_angle: Optional[torch.Tensor],\n        motion_mask: torch.Tensor,\n        observed_motion: torch.Tensor,\n        cfg_weight: Optional[float] = 2.0,\n        text_feat: Optional[torch.Tensor] = None,\n        text_pad_mask: Optional[torch.Tensor] = None,\n        guide_masks: Optional[Dict] = None,\n        cfg_type: Optional[str] = None,\n        progress_bar=tqdm,\n    ) -> torch.Tensor:\n        \"\"\"Sample full denoising loop.\n\n        Args:\n            texts (List[str]): batch of text prompts to use for sampling (if text_feat is not passed in)\n        \"\"\"\n\n        device = self.device\n        if text_feat is None:\n            assert text_pad_mask is None\n            log.info(\"Encoding text...\")\n            text_feat, text_length = self.text_encoder(texts)\n            text_feat = text_feat.to(device)\n\n            # handle empty string (set to zero)\n            empty_text_mask = [len(text.strip()) == 0 for text in texts]\n            text_feat[empty_text_mask] = 0\n\n            # Create the pad mask for the text\n            batch_size, maxlen = text_feat.shape[:2]\n            tensor_text_length = torch.tensor(text_length, device=device)\n            tensor_text_length[empty_text_mask] = 0\n           
 text_pad_mask = torch.arange(maxlen, device=device).expand(batch_size, maxlen) < tensor_text_length[:, None]\n\n        if motion_mask is not None:\n            if motion_mask.dtype == torch.bool:\n                motion_mask = 1 * motion_mask\n\n        batch_size = text_feat.shape[0]\n\n        # sample loop\n        indices = list(range(num_denoising_steps))[::-1]\n        shape = (batch_size, max_frames, self.motion_rep.motion_rep_dim)\n        cur_mot = torch.randn(shape, device=self.device)\n        num_denoising_steps = torch.tensor(\n            [num_denoising_steps], device=self.device\n        )  # this and t need to be tensor for onnx export\n        # init diffusion with correct num steps before looping\n        use_timesteps = self.diffusion.space_timesteps(num_denoising_steps[0])[0]\n        self.diffusion.calc_diffusion_vars(use_timesteps)\n        for i in progress_bar(indices):\n            t = torch.tensor([i] * cur_mot.size(0), device=self.device)\n            with torch.inference_mode():\n                cur_mot = self.denoising_step(\n                    cur_mot,\n                    pad_mask,\n                    text_feat,\n                    text_pad_mask,\n                    t,\n                    first_heading_angle,\n                    motion_mask,\n                    observed_motion,\n                    num_denoising_steps,\n                    cfg_weight,\n                    guide_masks=guide_masks,\n                    cfg_type=cfg_type,\n                )\n        return cur_mot\n"
  },
  {
    "path": "kimodo/model/llm2vec/README.md",
    "content": "This is a patched version of the original [LLM2Vec](https://github.com/McGill-NLP/llm2vec) codebase so that `McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-supervised` works with `transformers==5.0.0rc3`.\n"
  },
  {
    "path": "kimodo/model/llm2vec/__init__.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"LLM2Vec text encoder and wrapper for Kimodo.\"\"\"\n\nfrom .llm2vec import LLM2Vec\nfrom .llm2vec_wrapper import LLM2VecEncoder\n\n__all__ = [\n    \"LLM2Vec\",\n    \"LLM2VecEncoder\",\n]\n"
  },
  {
    "path": "kimodo/model/llm2vec/llm2vec.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2024 McGill NLP\n# SPDX-License-Identifier: MIT\n#\n# Permission is hereby granted, free of charge, to any person obtaining a\n# copy of this software and associated documentation files (the \"Software\"),\n# to deal in the Software without restriction, including without limitation\n# the rights to use, copy, modify, merge, publish, distribute, sublicense,\n# and/or sell copies of the Software, and to permit persons to whom the\n# Software is furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in\n# all copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL\n# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\n# DEALINGS IN THE SOFTWARE.\n\n\n# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport logging\nimport os\nfrom functools import partial\nfrom typing import Dict, List, Optional, Union\n\nimport numpy as np\nimport torch\nimport torch.multiprocessing as mp\nfrom peft import PeftModel\nfrom torch import Tensor, device, nn\nfrom tqdm.autonotebook import tqdm, trange\nfrom transformers import (\n    AutoConfig,\n    AutoModel,\n    AutoTokenizer,\n    GemmaConfig,\n    LlamaConfig,\n    MistralConfig,\n    PretrainedConfig,\n    Qwen2Config,\n)\n\nlogger = logging.getLogger(__name__)\n\n\ndef batch_to_device(batch, target_device: device):\n    \"\"\"Send a pytorch batch to a device (CPU/GPU)\"\"\"\n    for key in batch:\n        if isinstance(batch[key], Tensor):\n            batch[key] = batch[key].to(target_device)\n    return batch\n\n\nclass LLM2Vec(nn.Module):\n    def __init__(\n        self,\n        model: AutoModel,\n        tokenizer: AutoTokenizer,\n        pooling_mode: str = \"mean\",\n        max_length: int = 512,\n        doc_max_length: int = 400,\n        skip_instruction: bool = True,\n    ):\n        super().__init__()\n        self.model = model\n        self.tokenizer = tokenizer\n        self.pooling_mode = pooling_mode\n        self.skip_instruction = skip_instruction\n        self.max_length = max_length\n        self.doc_max_length = doc_max_length\n        self.config = model.config\n\n    
@classmethod\n    def _get_model_class(cls, config_class_name, enable_bidirectional):\n        if not enable_bidirectional:\n            return AutoModel\n        if config_class_name == \"MistralConfig\":\n            from .models.bidirectional_mistral import MistralBiModel\n\n            return MistralBiModel\n        elif config_class_name == \"LlamaConfig\":\n            from .models.bidirectional_llama import LlamaBiModel\n\n            return LlamaBiModel\n        elif config_class_name == \"GemmaConfig\":\n            from .models.bidirectional_gemma import GemmaBiModel\n\n            return GemmaBiModel\n        elif config_class_name == \"Qwen2Config\":\n            from .models.bidirectional_qwen2 import Qwen2BiModel\n\n            return Qwen2BiModel\n        else:\n            raise ValueError(f\"{config_class_name} is not supported yet with bidirectional models.\")\n\n    @classmethod\n    def from_pretrained(\n        cls,\n        base_model_name_or_path,\n        peft_model_name_or_path=None,\n        merge_peft=False,\n        enable_bidirectional=True,\n        **kwargs,\n    ):\n        # pop out encoder args\n        keys = [\"pooling_mode\", \"max_length\", \"doc_max_length\", \"skip_instruction\"]\n        encoder_args = {key: kwargs.pop(key, None) for key in keys if kwargs.get(key) is not None}\n\n        tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path)\n        tokenizer.pad_token = tokenizer.eos_token\n        tokenizer.padding_side = \"left\"\n\n        config = AutoConfig.from_pretrained(base_model_name_or_path)\n        config_class_name = config.__class__.__name__\n\n        model_class = cls._get_model_class(config_class_name, enable_bidirectional=enable_bidirectional)\n\n        model = model_class.from_pretrained(base_model_name_or_path, **kwargs)\n\n        if os.path.isdir(base_model_name_or_path) and os.path.exists(f\"{base_model_name_or_path}/config.json\"):\n            with 
open(f\"{base_model_name_or_path}/config.json\", \"r\") as fIn:\n                config_dict = json.load(fIn)\n            config = PretrainedConfig.from_dict(config_dict)\n            model.config._name_or_path = config._name_or_path\n\n        # For special case where config.json and adapter weights are in the same directory\n        if hasattr(model, \"peft_config\"):\n            model = PeftModel.from_pretrained(\n                model,\n                base_model_name_or_path,\n            )\n            model = model.merge_and_unload()\n\n        if peft_model_name_or_path is not None:\n            model = PeftModel.from_pretrained(\n                model,\n                peft_model_name_or_path,\n            )\n            if merge_peft:\n                model = model.merge_and_unload()\n\n        config = {}\n        config_addr = peft_model_name_or_path if peft_model_name_or_path is not None else base_model_name_or_path\n        if os.path.exists(f\"{config_addr}/llm2vec_config.json\"):\n            with open(f\"{config_addr}/llm2vec_config.json\", \"r\") as fIn:\n                llm2vec_config = json.load(fIn)\n            config.update(llm2vec_config)\n\n        for key, value in encoder_args.items():\n            config[key] = value\n\n        return cls(model=model, tokenizer=tokenizer, **config)\n\n    def prepare_for_tokenization(self, text):\n        if self.model.config._name_or_path == \"meta-llama/Meta-Llama-3-8B-Instruct\":\n            text = \"<|start_header_id|>user<|end_header_id|>\\n\\n\" + text.strip() + \"<|eot_id|>\"\n            return text\n        if self.model.config._name_or_path in [\n            \"mistralai/Mistral-7B-Instruct-v0.2\",\n            \"meta-llama/Llama-2-7b-chat-hf\",\n        ]:\n            text = \"[INST] \" + text.strip() + \" [/INST]\"\n        if self.model.config._name_or_path in [\n            \"google/gemma-2-9b-it\",\n        ]:\n            text = \"<bos><start_of_turn>user\\n\" + text.strip() + 
\"<end_of_turn>\"\n        if self.model.config._name_or_path in [\n            \"Qwen/Qwen2-1.5B-Instruct\",\n            \"Qwen/Qwen2-7B-Instruct\",\n        ]:\n            text = \"<|im_start|>user\\n\" + text.strip() + \"<|im_end|>\"\n        if self.pooling_mode == \"eos_token\":\n            if self.model.config._name_or_path == \"meta-llama/Meta-Llama-3-8B\":\n                text = text.strip() + \"<|end_of_text|>\"\n            elif isinstance(self.model.config, LlamaConfig) or isinstance(self.model.config, MistralConfig):\n                text = text.strip() + \" </s>\"\n            elif isinstance(self.model.config, GemmaConfig):\n                text = text.strip() + \"<eos>\"\n            elif isinstance(self.model.config, Qwen2Config):\n                text = text.strip() + \"<|endoftext|>\"\n        return text\n\n    def tokenize(self, texts):\n        texts_2 = []\n        original_texts = []\n        for text in texts:\n            t = text.split(\"!@#$%^&*()\")\n            texts_2.append(t[1] if len(t) > 1 else \"\")\n            original_texts.append(\"\".join(t))\n\n        original = self.tokenizer(\n            original_texts,\n            return_tensors=\"pt\",\n            padding=True,\n            truncation=True,\n            max_length=self.max_length,\n        )\n        embed_mask = None\n        for t_i, t in enumerate(texts_2):\n            ids = self.tokenizer(\n                [t],\n                return_tensors=\"pt\",\n                padding=True,\n                truncation=True,\n                max_length=self.max_length,\n                add_special_tokens=False,\n            )\n            if embed_mask is None:\n                e_m = torch.zeros_like(original[\"attention_mask\"][t_i])\n                if len(ids[\"input_ids\"][0]) > 0:\n                    e_m[-len(ids[\"input_ids\"][0]) :] = torch.ones(len(ids[\"input_ids\"][0]))\n                embed_mask = e_m.unsqueeze(0)\n            else:\n                e_m = 
torch.zeros_like(original[\"attention_mask\"][t_i])\n                if len(ids[\"input_ids\"][0]) > 0:\n                    e_m[-len(ids[\"input_ids\"][0]) :] = torch.ones(len(ids[\"input_ids\"][0]))\n                embed_mask = torch.cat((embed_mask, e_m.unsqueeze(0)), dim=0)\n\n        original[\"embed_mask\"] = embed_mask\n        return original\n\n    def _skip_instruction(self, sentence_feature):\n        assert sentence_feature[\"attention_mask\"].shape == sentence_feature[\"embed_mask\"].shape\n        sentence_feature[\"attention_mask\"] = sentence_feature[\"embed_mask\"]\n\n    def forward(self, sentence_feature: Dict[str, Tensor]):\n        embed_mask = None\n        if \"embed_mask\" in sentence_feature:\n            embed_mask = sentence_feature.pop(\"embed_mask\")\n        reps = self.model(**sentence_feature)\n        sentence_feature[\"embed_mask\"] = embed_mask\n\n        return self.get_pooling(sentence_feature, reps.last_hidden_state)\n\n    def get_pooling(self, features, last_hidden_states):  # All models padded from left\n        assert self.tokenizer.padding_side == \"left\", \"Pooling modes are implemented for padding from left.\"\n        if self.skip_instruction:\n            self._skip_instruction(features)\n        seq_lengths = features[\"attention_mask\"].sum(dim=-1)\n        if self.pooling_mode == \"mean\":\n            return torch.stack(\n                [last_hidden_states[i, -length:, :].mean(dim=0) for i, length in enumerate(seq_lengths)],\n                dim=0,\n            )\n        elif self.pooling_mode == \"weighted_mean\":\n            bs, l, _ = last_hidden_states.shape\n            complete_weights = torch.zeros(bs, l, device=last_hidden_states.device)\n            for i, seq_l in enumerate(seq_lengths):\n                if seq_l > 0:\n                    complete_weights[i, -seq_l:] = torch.arange(seq_l) + 1\n                    complete_weights[i] /= torch.clamp(complete_weights[i].sum(), min=1e-9)\n            
return torch.sum(last_hidden_states * complete_weights.unsqueeze(-1), dim=1)\n        elif self.pooling_mode == \"eos_token\" or self.pooling_mode == \"last_token\":\n            return last_hidden_states[:, -1]\n        elif self.pooling_mode == \"bos_token\":\n            return last_hidden_states[features[\"input_ids\"] == self.tokenizer.bos_token_id]\n        else:\n            raise ValueError(f\"{self.pooling_mode} is not implemented yet.\")\n\n    def _convert_to_str(self, instruction, text):\n        tokenized_q = self.tokenizer(\n            text,\n            return_tensors=\"pt\",\n            padding=True,\n            truncation=True,\n            max_length=self.max_length,\n            add_special_tokens=False,\n        )\n        tokenized_q_length = len(tokenized_q[\"input_ids\"][0])\n\n        while tokenized_q_length > self.doc_max_length:\n            reduction_ratio = self.doc_max_length / tokenized_q_length\n            reduced_length = int(len(text.split()) * reduction_ratio)\n            text = \" \".join(text.split()[:reduced_length])\n            tokenized_q = self.tokenizer(\n                text,\n                return_tensors=\"pt\",\n                padding=True,\n                truncation=True,\n                max_length=self.max_length,\n                add_special_tokens=False,\n            )\n            tokenized_q_length = len(tokenized_q[\"input_ids\"][0])\n\n        return f\"{instruction.strip()} !@#$%^&*(){text}\" if instruction else f\"!@#$%^&*(){text}\"\n\n    def encode(\n        self,\n        sentences: Union[str, List[str]],\n        batch_size: int = 32,\n        show_progress_bar: bool = True,\n        convert_to_numpy: bool = False,\n        convert_to_tensor: bool = False,\n        device: Optional[str] = None,\n    ):\n        \"\"\"\n        Encode a list of sentences to their respective embeddings. 
The sentences can be a list of strings or a string.\n        Args:\n            sentences: sentence or sentences to encode.\n            batch_size: batch size for turning sentence tokens into embeddings.\n            show_progress_bar: whether to show progress bars during encoding steps.\n            convert_to_numpy: If true, return numpy arrays instead of torch tensors.\n            convert_to_tensor: If true, return torch tensors (default).\n            device: torch backend device identifier (e.g., 'cuda', 'cpu','mps' etc.). If not specified,\n            the default is to use cuda when available, otherwise cpu. Note that only the choice of 'cuda' supports\n            multiprocessing as currently implemented.\n\n        Returns: embeddings of the sentences. Embeddings are detached and always on the CPU (see _encode implementation).\n\n        \"\"\"\n        if isinstance(sentences[0], str) and isinstance(sentences[-1], int):\n            sentences = [sentences]\n        # required for MEDI version of MTEB\n        if isinstance(sentences[0], str):\n            sentences = [[\"\"] + [sentence] for sentence in sentences]\n\n        if device is None:\n            device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n        concatenated_input_texts = []\n        for sentence in sentences:\n            assert isinstance(sentence[0], str)\n            assert isinstance(sentence[1], str)\n            concatenated_input_texts.append(self._convert_to_str(sentence[0], sentence[1]))\n        sentences = concatenated_input_texts\n\n        self.eval()\n\n        if convert_to_tensor:\n            convert_to_numpy = False\n\n        length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences])\n        sentences_sorted = [sentences[idx] for idx in length_sorted_idx]\n        all_embeddings = []\n\n        if torch.cuda.device_count() <= 1:\n            # This branch also support mps devices\n            self.to(device)\n            for 
start_index in trange(\n                0,\n                len(sentences),\n                batch_size,\n                desc=\"Batches\",\n                disable=not show_progress_bar,\n            ):\n                sentences_batch = sentences_sorted[start_index : start_index + batch_size]\n                embeddings = self._encode(sentences_batch, device=device, convert_to_numpy=convert_to_numpy)\n                all_embeddings.append(embeddings)\n        else:\n            num_proc = torch.cuda.device_count()\n            cuda_compatible_multiprocess = mp.get_context(\"spawn\")\n            with cuda_compatible_multiprocess.Pool(num_proc) as p:\n                sentences_batches = [\n                    sentences_sorted[start_index : start_index + batch_size]\n                    for start_index in range(0, len(sentences), batch_size)\n                ]\n\n                progress_bar = tqdm(\n                    total=len(sentences_batches),\n                    desc=\"Batches\",\n                    disable=not show_progress_bar,\n                )\n                results = []\n\n                def update(*args):\n                    progress_bar.update()\n\n                for batch in sentences_batches:\n                    results.append(\n                        p.apply_async(\n                            self._encode,\n                            args=(batch, None, convert_to_numpy, True),\n                            callback=update,\n                        )\n                    )\n\n                all_embeddings = [result.get() for result in results]\n                progress_bar.close()\n\n        all_embeddings = torch.cat(all_embeddings, dim=0)\n        all_embeddings = all_embeddings[np.argsort(length_sorted_idx)]\n        all_embeddings = all_embeddings.to(torch.float32)\n        if convert_to_numpy:\n            all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])\n        return all_embeddings\n\n    def save(self, 
output_path, merge_before_save=False, save_config=True):\n        if merge_before_save and isinstance(self.model, PeftModel):\n            self.model = self.model.merge_and_unload()\n            # Fixes the issue of saving - https://huggingface.co/McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp-unsup-simcse/discussions/1\n            if hasattr(self.model, \"_hf_peft_config_loaded\"):\n                self.model._hf_peft_config_loaded = False\n\n        self.model.save_pretrained(output_path)\n        self.tokenizer.save_pretrained(output_path)\n\n        llm2vec_config = {\n            \"pooling_mode\": self.pooling_mode,\n            \"max_length\": self.max_length,\n            \"doc_max_length\": self.doc_max_length,\n            \"skip_instruction\": self.skip_instruction,\n        }\n\n        if save_config:\n            os.makedirs(output_path, exist_ok=True)\n            with open(f\"{output_path}/llm2vec_config.json\", \"w\") as fOut:\n                json.dump(llm2vec_config, fOut, indent=4)\n\n    def _encode(\n        self,\n        sentences_batch,\n        device: Optional[str] = None,\n        convert_to_numpy: bool = False,\n        multiprocessing=False,\n    ):\n        if multiprocessing:\n            # multiprocessing only supports CUDA devices at this time, so we ignore the value of device\n            # and use cuda:rank for the device\n            rank = mp.current_process()._identity[0]\n            if device is None and torch.cuda.is_available():\n                device = f\"cuda:{rank % torch.cuda.device_count()}\"\n\n        self.to(device)\n        features = self.tokenize([self.prepare_for_tokenization(sentence) for sentence in sentences_batch])\n        features = batch_to_device(features, device)\n\n        with torch.no_grad():\n            embeddings = self.forward(features)\n            embeddings = embeddings.detach()\n            embeddings = embeddings.cpu()\n\n        return embeddings\n\n    def _text_length(self, text: 
Union[List[int], List[List[int]]]):\n        \"\"\"Helper function to get the length for the input text.\n\n        Text can be either a string (which means a single text), a list of ints (which means a single\n        tokenized text), or a tuple of lists of ints (representing several text inputs to the model).\n        \"\"\"\n        if (\n            isinstance(text, str) or (isinstance(text, list) and isinstance(text[0], int)) or len(text) == 0\n        ):  # Single text, list of ints, or empty\n            return len(text)\n        if isinstance(text, dict):  # {key: value} case\n            return len(next(iter(text.values())))\n        elif not hasattr(text, \"__len__\"):  # Object has no len() method\n            return 1\n        else:\n            return sum([len(t) for t in text])\n\n    def resize_token_embeddings(\n        self,\n        new_num_tokens: Optional[int] = None,\n        pad_to_multiple_of: Optional[int] = None,\n    ) -> nn.Embedding:\n        return self.model.resize_token_embeddings(new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of)\n\n    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):\n        self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)\n"
  },
  {
    "path": "kimodo/model/llm2vec/llm2vec_wrapper.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"LLM2Vec encoder wrapper for Kimodo text conditioning.\"\"\"\n\nimport os\n\nimport numpy as np\nimport torch\n\nfrom .llm2vec import LLM2Vec\n\n\nclass LLM2VecEncoder:\n    \"\"\"LLM2Vec text embeddings.\"\"\"\n\n    def __init__(\n        self,\n        base_model_name_or_path: str,\n        peft_model_name_or_path: str,\n        dtype: str,\n        llm_dim: int,\n        device: str = \"auto\",\n    ) -> None:\n        torch_dtype = getattr(torch, dtype)\n        self.llm_dim = llm_dim\n\n        cache_dir = os.environ.get(\"HUGGINGFACE_CACHE_DIR\")\n\n        if \"TEXT_ENCODERS_DIR\" in os.environ:\n            base_model_name_or_path = os.path.join(os.environ[\"TEXT_ENCODERS_DIR\"], base_model_name_or_path)\n            peft_model_name_or_path = os.path.join(os.environ[\"TEXT_ENCODERS_DIR\"], peft_model_name_or_path)\n\n        self.model = LLM2Vec.from_pretrained(\n            base_model_name_or_path=base_model_name_or_path,\n            peft_model_name_or_path=peft_model_name_or_path,\n            torch_dtype=torch_dtype,\n            cache_dir=cache_dir,\n        )\n\n        env_device = os.environ.get(\"TEXT_ENCODER_DEVICE\")\n        if env_device:\n            device = env_device\n        if device == \"auto\":\n            device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n        self._device = device\n        if device is not None:\n            self.model = self.model.to(device)\n\n        self.model.eval()\n        for p in self.model.parameters():\n            p.requires_grad = False\n\n    def to(self, device: torch.device):\n        self.model = self.model.to(device)\n        self._device = str(device) if not isinstance(device, str) else device\n        return self\n\n    def eval(self):\n        self.model.eval()\n        return self\n\n    def get_device(self):\n        return 
self.model.model.device\n\n    def __call__(self, text: list[str] | str):\n        is_string = False\n        if isinstance(text, str):\n            text = [text]\n            is_string = True\n\n        with torch.no_grad():\n            encoded_text = self.model.encode(\n                text,\n                # IMPORTANT: different batch sizes unexpectedly change the output embeddings, so we always set it to 1\n                #            here for repeatability no matter how many texts are being encoded. This\n                #            is a fundamental issue with transformers, and is especially bad at lower\n                #            precisions (https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535)\n                #            note: this is an internal batch size used by llm2vec - the text list can still be of arbitrary length.\n                batch_size=1,\n                show_progress_bar=False,\n                device=self._device,\n            )\n\n        assert len(encoded_text.shape)\n        assert self.llm_dim == encoded_text.shape[-1]\n\n        encoded_text = encoded_text[:, None]\n        lengths = np.ones(len(encoded_text), dtype=int).tolist()\n\n        if is_string:\n            encoded_text = encoded_text[0]\n            lengths = lengths[0]\n\n        encoded_text = torch.tensor(encoded_text).to(self._device)\n        return encoded_text, lengths\n"
  },
  {
    "path": "kimodo/model/llm2vec/models/__init__.py",
    "content": "# from .bidirectional_gemma import GemmaBiForMNTP, GemmaBiModel\n# from .bidirectional_llama import LlamaBiForMNTP, LlamaBiModel\n# from .bidirectional_mistral import MistralBiForMNTP, MistralBiModel\n# from .bidirectional_qwen2 import Qwen2BiForMNTP, Qwen2BiModel\n"
  },
  {
    "path": "kimodo/model/llm2vec/models/attn_mask_utils.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2024 McGill NLP\n# SPDX-License-Identifier: MIT\n#\n# Permission is hereby granted, free of charge, to any person obtaining a\n# copy of this software and associated documentation files (the \"Software\"),\n# to deal in the Software without restriction, including without limitation\n# the rights to use, copy, modify, merge, publish, distribute, sublicense,\n# and/or sell copies of the Software, and to permit persons to whom the\n# Software is furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in\n# all copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL\n# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\n# DEALINGS IN THE SOFTWARE.\n\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nfrom transformers.modeling_attn_mask_utils import AttentionMaskConverter\n\n\ndef _prepare_4d_causal_attention_mask(\n    attention_mask: Optional[torch.Tensor],\n    input_shape: Union[torch.Size, Tuple, List],\n    inputs_embeds: torch.Tensor,\n    past_key_values_length: int,\n    sliding_window: Optional[int] = None,\n):\n    \"\"\"Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D\n    mask of shape `(batch_size, key_value_length)`\n\n    Args:\n        attention_mask (`torch.Tensor` or `None`):\n            A 2D attention mask of shape `(batch_size, key_value_length)`\n        input_shape (`tuple(int)` or `list(int)` or `torch.Size`):\n            The input shape should be a tuple that 
defines `(batch_size, query_length)`.\n        inputs_embeds (`torch.Tensor`):\n            The embedded inputs as a torch Tensor.\n        past_key_values_length (`int`):\n            The length of the key value cache.\n        sliding_window (`int`, *optional*):\n            If the model uses windowed attention, a sliding window should be passed.\n    \"\"\"\n    attn_mask_converter = AttentionMaskConverter(\n        is_causal=False, sliding_window=sliding_window\n    )  # is_causal=True in original implementation\n\n    key_value_length = input_shape[-1] + past_key_values_length\n\n    # 4d mask is passed through the layers\n    if attention_mask is not None and len(attention_mask.shape) == 2:\n        attention_mask = attn_mask_converter.to_4d(\n            attention_mask,\n            input_shape[-1],\n            key_value_length=key_value_length,\n            dtype=inputs_embeds.dtype,\n        )\n    elif attention_mask is not None and len(attention_mask.shape) == 4:\n        expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)\n        if tuple(attention_mask.shape) != expected_shape:\n            raise ValueError(\n                f\"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}.\"\n            )\n        else:\n            # if the 4D mask has correct shape - invert it and fill with negative infinity\n            inverted_mask = 1.0 - attention_mask\n            attention_mask = inverted_mask.masked_fill(\n                inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min\n            )\n    else:\n        attention_mask = attn_mask_converter.to_causal_4d(\n            input_shape[0],\n            input_shape[-1],\n            key_value_length,\n            dtype=inputs_embeds.dtype,\n            device=inputs_embeds.device,\n        )\n\n    return attention_mask\n\n\n# Adapted from _prepare_4d_causal_attention_mask\ndef _prepare_4d_causal_attention_mask_for_sdpa(\n    
attention_mask: Optional[torch.Tensor],\n    input_shape: Union[torch.Size, Tuple, List],\n    inputs_embeds: torch.Tensor,\n    past_key_values_length: int,\n    sliding_window: Optional[int] = None,\n):\n    \"\"\"Prepares the correct `attn_mask` argument to be used by\n    `torch.nn.functional.scaled_dot_product_attention`.\n\n    In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and\n    `key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks,\n    allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).\n    \"\"\"\n    attn_mask_converter = AttentionMaskConverter(\n        is_causal=False, sliding_window=sliding_window\n    )  # is_causal=True in original implementation\n\n    key_value_length = input_shape[-1] + past_key_values_length\n    batch_size, query_length = input_shape\n\n    # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`\n    # used as an SDPA argument. 
We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.\n    # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).\n    is_tracing = (\n        torch.jit.is_tracing()\n        or isinstance(inputs_embeds, torch.fx.Proxy)\n        or (hasattr(torch, \"_dynamo\") and torch._dynamo.is_compiling())\n    )\n\n    if attention_mask is not None:\n        # 4d mask is passed through\n        if len(attention_mask.shape) == 4:\n            expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)\n            if tuple(attention_mask.shape) != expected_shape:\n                raise ValueError(\n                    f\"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}.\"\n                )\n            else:\n                # if the 4D mask has correct shape - invert it and fill with negative infinity\n                inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype)\n                attention_mask = inverted_mask.masked_fill(\n                    inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min\n                )\n                return attention_mask\n\n        elif not is_tracing and torch.all(attention_mask == 1):\n            if query_length == 1:\n                # For query_length == 1, causal attention and bi-directional attention are the same.\n                attention_mask = None\n            elif key_value_length == query_length:\n                attention_mask = None\n            else:\n                # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation\n                # may be wrong. 
We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.\n                # Reference: https://github.com/pytorch/pytorch/issues/108108\n                pass\n    elif query_length > 1 and key_value_length != query_length:\n        # See the comment above (https://github.com/pytorch/pytorch/issues/108108).\n        # Ugly: we set it to True here to dispatch in the following controlflow to `to_causal_4d`.\n        attention_mask = True\n    elif is_tracing:\n        raise ValueError(\n            'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation=\"eager\"` or pass an attention_mask input when tracing the model.'\n        )\n\n    if attention_mask is None:\n        expanded_4d_mask = None\n    elif attention_mask is True:\n        expanded_4d_mask = attn_mask_converter.to_causal_4d(\n            input_shape[0],\n            input_shape[-1],\n            key_value_length,\n            dtype=inputs_embeds.dtype,\n            device=inputs_embeds.device,\n        )\n    else:\n        expanded_4d_mask = attn_mask_converter.to_4d(\n            attention_mask,\n            input_shape[-1],\n            dtype=inputs_embeds.dtype,\n            key_value_length=key_value_length,\n        )\n\n        # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when\n        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.\n        # Details: https://github.com/pytorch/pytorch/issues/110213\n        if not is_tracing and expanded_4d_mask.device.type == \"cuda\":\n            expanded_4d_mask = AttentionMaskConverter._unmask_unattended(\n                expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min\n            )\n\n    return expanded_4d_mask\n"
  },
  {
    "path": "kimodo/model/llm2vec/models/bidirectional_llama.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2024 McGill NLP\n# SPDX-License-Identifier: MIT\n#\n# Permission is hereby granted, free of charge, to any person obtaining a\n# copy of this software and associated documentation files (the \"Software\"),\n# to deal in the Software without restriction, including without limitation\n# the rights to use, copy, modify, merge, publish, distribute, sublicense,\n# and/or sell copies of the Software, and to permit persons to whom the\n# Software is furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in\n# all copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL\n# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\n# DEALINGS IN THE SOFTWARE.\n\n# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nfrom peft import PeftModel\nfrom torch import nn\nfrom transformers import LlamaConfig, LlamaForCausalLM, LlamaModel, LlamaPreTrainedModel\nfrom transformers.cache_utils import Cache, StaticCache\nfrom transformers.modeling_attn_mask_utils import AttentionMaskConverter\nfrom transformers.models.llama.modeling_llama import (\n    LlamaAttention,\n    LlamaDecoderLayer,\n    # LlamaFlashAttention2,\n    LlamaMLP,\n    LlamaRMSNorm,\n    LlamaRotaryEmbedding,\n    # LlamaSdpaAttention,\n)\nfrom transformers.utils import logging\n\nfrom .utils import is_transformers_attn_greater_or_equal_4_43_1\n\nlogger = logging.get_logger(__name__)\n\n\nclass ModifiedLlamaAttention(LlamaAttention):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.is_causal = False\n\n\n# class ModifiedLlamaFlashAttention2(LlamaFlashAttention2):\n#     def __init__(self, *args, **kwargs):\n#         super().__init__(*args, **kwargs)\n#         self.is_causal = False\n\n\n# class ModifiedLlamaSdpaAttention(LlamaSdpaAttention):\n#     def __init__(self, *args, **kwargs):\n#         super().__init__(*args, **kwargs)\n#         self.is_causal = False\n\n\n# LLAMA_ATTENTION_CLASSES = {\n#     \"eager\": ModifiedLlamaAttention,\n#     \"flash_attention_2\": ModifiedLlamaFlashAttention2,\n#     \"sdpa\": ModifiedLlamaSdpaAttention,\n# 
}\n\n\nclass ModifiedLlamaDecoderLayer(LlamaDecoderLayer):\n    def __init__(self, config: LlamaConfig, layer_idx: int):\n        nn.Module.__init__(self)\n        self.hidden_size = config.hidden_size\n\n        self.self_attn = ModifiedLlamaAttention(config=config, layer_idx=layer_idx)\n        # self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](\n        # config=config, layer_idx=layer_idx\n        # )\n\n        self.mlp = LlamaMLP(config)\n        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n\nclass LlamaBiModel(LlamaModel):\n    _no_split_modules = [\"ModifiedLlamaDecoderLayer\"]\n\n    def __init__(self, config: LlamaConfig):\n        if not is_transformers_attn_greater_or_equal_4_43_1():\n            raise ValueError(\n                \"The current implementation of LlamaEncoderModel follows modeling_llama.py of transformers version >= 4.43.1\"\n            )\n        LlamaPreTrainedModel.__init__(self, config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)\n        self.layers = nn.ModuleList(\n            [ModifiedLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]\n        )\n        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.rotary_emb = LlamaRotaryEmbedding(config=config)\n        self.gradient_checkpointing = False\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def _update_causal_mask(\n        self,\n        attention_mask,\n        input_tensor,\n        cache_position,\n        past_key_values: Cache,\n        output_attentions: bool,\n    ):\n        if self.config._attn_implementation == \"flash_attention_2\":\n            if 
attention_mask is not None and 0.0 in attention_mask:\n                return attention_mask\n            return None\n\n        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in\n        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail\n        # to infer the attention mask.\n        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n        using_static_cache = isinstance(past_key_values, StaticCache)\n\n        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward\n        # if self.config._attn_implementation == \"sdpa\" and not using_static_cache and not output_attentions:\n        #     if AttentionMaskConverter._ignore_causal_mask_sdpa(\n        #         attention_mask,\n        #         inputs_embeds=input_tensor,\n        #         past_key_values_length=past_seen_tokens,\n        #         is_training=self.training,\n        #     ):\n        #         return None\n\n        dtype, device = input_tensor.dtype, input_tensor.device\n        min_dtype = torch.finfo(dtype).min\n        sequence_length = input_tensor.shape[1]\n        if using_static_cache:\n            target_length = past_key_values.get_max_length()\n        else:\n            target_length = (\n                attention_mask.shape[-1]\n                if isinstance(attention_mask, torch.Tensor)\n                else past_seen_tokens + sequence_length + 1\n            )\n\n        causal_mask = torch.zeros(\n            (sequence_length, target_length), dtype=dtype, device=device\n        )  # in original implementation - torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)\n        # Commenting out next 2 lines to disable causal masking\n        # if sequence_length != 1:\n        #     causal_mask = torch.triu(causal_mask, 
diagonal=1)\n        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)\n        causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)\n        if attention_mask is not None:\n            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit\n            if attention_mask.dim() == 2:\n                mask_length = attention_mask.shape[-1]\n                padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)\n                causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)\n            elif attention_mask.dim() == 4:\n                # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with\n                # cache. In that case, the 4D attention mask attends to the newest tokens only.\n                if attention_mask.shape[-2] < cache_position[0] + sequence_length:\n                    offset = cache_position[0]\n                else:\n                    offset = 0\n                mask_shape = attention_mask.shape\n                mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype\n                causal_mask[\n                    : mask_shape[0],\n                    : mask_shape[1],\n                    offset : mask_shape[2] + offset,\n                    : mask_shape[3],\n                ] = mask_slice\n\n        if (\n            self.config._attn_implementation == \"sdpa\"\n            and attention_mask is not None\n            and attention_mask.device.type == \"cuda\"\n            and not output_attentions\n        ):\n            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)\n\n        return causal_mask\n\n\nclass LlamaBiForMNTP(LlamaForCausalLM):\n    def __init__(self, config):\n        LlamaPreTrainedModel.__init__(self, config)\n        self.model = 
LlamaBiModel(config)\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    # getter for PEFT model\n    def get_model_for_peft(self):\n        return self.model\n\n    # setter for PEFT model\n    def set_model_for_peft(self, model: PeftModel):\n        self.model = model\n\n    # save the PEFT model\n    def save_peft_model(self, path):\n        self.model.save_pretrained(path)\n"
  },
  {
    "path": "kimodo/model/llm2vec/models/utils.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2024 McGill NLP\n# SPDX-License-Identifier: MIT\n#\n# Permission is hereby granted, free of charge, to any person obtaining a\n# copy of this software and associated documentation files (the \"Software\"),\n# to deal in the Software without restriction, including without limitation\n# the rights to use, copy, modify, merge, publish, distribute, sublicense,\n# and/or sell copies of the Software, and to permit persons to whom the\n# Software is furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in\n# all copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL\n# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\n# DEALINGS IN THE SOFTWARE.\n\nimport importlib.metadata\n\nfrom packaging import version\nfrom transformers.utils.import_utils import _is_package_available\n\n\ndef is_transformers_attn_greater_or_equal_4_43_1():\n    if not _is_package_available(\"transformers\"):\n        return False\n\n    return version.parse(importlib.metadata.version(\"transformers\")) >= version.parse(\"4.43.1\")\n"
  },
  {
    "path": "kimodo/model/load_model.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Load Kimodo diffusion models from local checkpoints or Hugging Face.\"\"\"\n\nfrom pathlib import Path\nfrom typing import Optional\n\nfrom huggingface_hub import snapshot_download\nfrom omegaconf import OmegaConf\n\nfrom .loading import (\n    AVAILABLE_MODELS,\n    DEFAULT_MODEL,\n    DEFAULT_TEXT_ENCODER_URL,\n    MODEL_NAMES,\n    TMR_MODELS,\n    get_env_var,\n    instantiate_from_dict,\n)\nfrom .registry import get_model_info, resolve_model_name\n\nDEFAULT_TEXT_ENCODER = \"llm2vec\"\nTEXT_ENCODER_PRESETS = {\n    \"llm2vec\": {\n        \"target\": \"kimodo.model.LLM2VecEncoder\",\n        \"kwargs\": {\n            \"base_model_name_or_path\": \"McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp\",\n            \"peft_model_name_or_path\": \"McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-supervised\",\n            \"dtype\": \"bfloat16\",\n            \"llm_dim\": 4096,\n            \"device\": \"auto\",\n        },\n    }\n}\n\n\ndef _resolve_hf_model_path(modelname: str) -> Path:\n    \"\"\"Resolve model name to a local path, using Hugging Face cache or CHECKPOINT_DIR.\"\"\"\n    try:\n        repo_id = MODEL_NAMES[modelname]\n    except KeyError:\n        raise ValueError(f\"Model '{modelname}' not found. 
Available models: {MODEL_NAMES.keys()}\")\n\n    local_cache = get_env_var(\"LOCAL_CACHE\", \"False\").lower() == \"true\"\n    if not local_cache:\n        snapshot_dir = snapshot_download(repo_id=repo_id)  # will check online no matter what\n        return Path(snapshot_dir)\n\n    try:\n        snapshot_dir = snapshot_download(repo_id=repo_id, local_files_only=True)  # will check local cache only\n        return Path(snapshot_dir)\n    except Exception:\n        # if local cache is not found, download from online\n        try:\n            snapshot_dir = snapshot_download(repo_id=repo_id)\n            return Path(snapshot_dir)\n        except Exception:\n            raise RuntimeError(f\"Could not resolve model '{modelname}' from Hugging Face (repo: {repo_id}). \") from None\n\n\ndef _build_api_text_encoder_conf(text_encoder_url: str) -> dict:\n    return {\n        \"_target_\": \"kimodo.model.text_encoder_api.TextEncoderAPI\",\n        \"url\": text_encoder_url,\n    }\n\n\ndef _build_local_text_encoder_conf(text_encoder_fp32: bool = False) -> dict:\n    text_encoder_name = get_env_var(\"TEXT_ENCODER\", DEFAULT_TEXT_ENCODER)\n    if text_encoder_name not in TEXT_ENCODER_PRESETS:\n        available = \", \".join(sorted(TEXT_ENCODER_PRESETS))\n        raise ValueError(f\"Unknown TEXT_ENCODER='{text_encoder_name}'. 
Available: {available}\")\n\n    preset = TEXT_ENCODER_PRESETS[text_encoder_name]\n    if text_encoder_fp32:\n        preset[\"kwargs\"][\"dtype\"] = \"float32\"\n    return {\n        \"_target_\": preset[\"target\"],\n        **preset[\"kwargs\"],\n    }\n\n\ndef _select_text_encoder_conf(text_encoder_url: str, text_encoder_fp32: bool = False) -> dict:\n    # TEXT_ENCODER_MODE options:\n    # - \"api\": force TextEncoderAPI\n    # - \"local\": force local LLM2VecEncoder\n    # - \"auto\": try API first, fallback to local if unreachable\n    mode = get_env_var(\"TEXT_ENCODER_MODE\", \"auto\").lower()\n    if mode == \"local\":\n        return _build_local_text_encoder_conf(text_encoder_fp32)\n    if mode == \"api\":\n        return _build_api_text_encoder_conf(text_encoder_url)\n\n    api_conf = _build_api_text_encoder_conf(text_encoder_url)\n    try:\n        text_encoder = instantiate_from_dict(api_conf)\n        # Probe availability early so inference doesn't fail later.\n        text_encoder([\"healthcheck\"])\n        return api_conf\n    except Exception as error:\n        print(\n            \"Text encoder service is unreachable, falling back to local LLM2Vec \"\n            f\"encoder. ({type(error).__name__}: {error})\"\n        )\n        return _build_local_text_encoder_conf(text_encoder_fp32)\n\n\ndef load_model(\n    modelname=None,\n    device=None,\n    eval_mode: bool = True,\n    default_family: Optional[str] = \"Kimodo\",\n    return_resolved_name: bool = False,\n    text_encoder=None,\n    text_encoder_fp32: bool = False,\n):\n    \"\"\"Load a kimodo model by name (e.g. 'g1', 'soma').\n\n    Resolution of partial/full names (e.g. Kimodo-SOMA-RP-v1, SOMA) is done\n    inside this function using default_family when the name is not a known\n    short key.\n\n    Args:\n        modelname: Model identifier; uses DEFAULT_MODEL if None. Can be a short key,\n            a full name (e.g. 
Kimodo-SOMA-RP-v1), or a partial name; unknown names\n            are resolved via resolve_model_name using default_family.\n        device: Target device for the model (e.g. 'cuda', 'cpu').\n        eval_mode: If True, set model to eval mode.\n        default_family: Used when modelname is not in AVAILABLE_MODELS to resolve\n            partial names (\"Kimodo\" for demo/generation, \"TMR\" for embed script).\n            Default \"Kimodo\".\n        return_resolved_name: If True, return (model, resolved_short_key). If False,\n            return only the model.\n        text_encoder: Pre-built text encoder to reuse. When provided, skips\n            text encoder selection/instantiation entirely.\n        text_encoder_fp32: If True, uses fp32 for the text encoder rather than default bfloat16.\n\n    Returns:\n        Loaded model in eval mode, or (model, resolved short key) if\n        return_resolved_name is True.\n\n    Raises:\n        ValueError: If modelname is not in AVAILABLE_MODELS and cannot be resolved.\n        FileNotFoundError: If config.yaml is missing in the checkpoint folder.\n    \"\"\"\n    if modelname is None:\n        modelname = DEFAULT_MODEL\n    if modelname not in AVAILABLE_MODELS:\n        if default_family is not None:\n            modelname = resolve_model_name(modelname, default_family)\n        else:\n            raise ValueError(\n                f\"\"\"The model is not recognized.\n            Please choose between: {AVAILABLE_MODELS}\"\"\"\n            )\n\n    resolved_modelname = modelname\n\n    # In case, we specify a custom checkpoint directory\n    configured_checkpoint_dir = get_env_var(\"CHECKPOINT_DIR\")\n    if configured_checkpoint_dir:\n        print(f\"CHECKPOINT_DIR is set to {configured_checkpoint_dir}, checking the local cache...\")\n        # Checkpoint folders are named by display name (e.g. 
Kimodo-SOMA-RP-v1)\n        info = get_model_info(modelname)\n        checkpoint_folder_name = info.display_name if info is not None else modelname\n        model_path = Path(configured_checkpoint_dir) / checkpoint_folder_name\n        if not model_path.exists() and modelname != checkpoint_folder_name:\n            # Fallback: try short_key for backward compatibility\n            model_path = Path(configured_checkpoint_dir) / modelname\n        if not model_path.exists():\n            print(f\"Model folder not found at '{model_path}', downloading it from Hugging Face...\")\n            model_path = _resolve_hf_model_path(modelname)\n    else:\n        # Otherwise, we load the model from the local cache or download it from Hugging Face.\n        model_path = _resolve_hf_model_path(modelname)\n\n    model_config_path = model_path / \"config.yaml\"\n    if not model_config_path.exists():\n        raise FileNotFoundError(f\"The model checkpoint folder exists but config.yaml is missing: {model_config_path}\")\n\n    model_conf = OmegaConf.load(model_config_path)\n\n    if modelname in TMR_MODELS:\n        # Same process at the moment for TMR and Kimodo\n        pass\n\n    if text_encoder is not None:\n        runtime_conf = OmegaConf.create({\"checkpoint_dir\": str(model_path)})\n    else:\n        text_encoder_url = get_env_var(\"TEXT_ENCODER_URL\", DEFAULT_TEXT_ENCODER_URL)\n        runtime_conf = OmegaConf.create(\n            {\n                \"checkpoint_dir\": str(model_path),\n                \"text_encoder\": _select_text_encoder_conf(text_encoder_url, text_encoder_fp32),\n            }\n        )\n\n    model_cfg = OmegaConf.to_container(OmegaConf.merge(model_conf, runtime_conf), resolve=True)\n    model_cfg.pop(\"checkpoint_dir\", None)\n\n    if text_encoder is not None:\n        # Prevent Hydra from instantiating a new text encoder; pass None so\n        # Kimodo.__init__ receives a placeholder we replace immediately after.\n        
model_cfg[\"text_encoder\"] = None\n\n    model = instantiate_from_dict(model_cfg, overrides={\"device\": device})\n\n    if text_encoder is not None:\n        model.text_encoder = text_encoder\n\n    if eval_mode:\n        model = model.eval()\n    if return_resolved_name:\n        return model, resolved_modelname\n    return model\n"
  },
  {
    "path": "kimodo/model/loading.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Model loading utilities: checkpoints, registry, env, and Hydra-based instantiation.\"\"\"\n\nimport os\nfrom pathlib import Path\nfrom typing import Any, Dict, Optional, Union\n\nimport torch\nfrom hydra.utils import instantiate\nfrom omegaconf import OmegaConf\nfrom safetensors.torch import load_file as load_safetensors\n\nfrom .registry import (\n    AVAILABLE_MODELS,\n    DEFAULT_MODEL,\n    DEFAULT_TEXT_ENCODER_URL,\n    KIMODO_MODELS,\n    MODEL_NAMES,\n    TMR_MODELS,\n)\n\n\ndef get_env_var(name: str, default: Optional[str] = None) -> Optional[str]:\n    \"\"\"Return environment variable value, or default if unset/empty.\"\"\"\n    return os.environ.get(name) or default\n\n\ndef instantiate_from_dict(\n    cfg: Dict[str, Any],\n    overrides: Optional[Dict[str, Any]] = None,\n):\n    \"\"\"Instantiate an object from a config dict (e.g. from OmegaConf.to_container).\n\n    The dict must contain _target_ with a fully qualified class path. Nested configs are\n    instantiated recursively.\n    \"\"\"\n    if overrides:\n        cfg = {**cfg, **overrides}\n    conf = OmegaConf.create(cfg)\n    return instantiate(conf)\n\n\ndef load_checkpoint_state_dict(ckpt_path: Union[str, Path]) -> dict:\n    \"\"\"Load a state dict from a checkpoint file.\n\n    If the checkpoint is a dict with a 'state_dict' key (e.g. 
PyTorch Lightning),\n    that is returned; otherwise the whole checkpoint is treated as the state dict.\n\n    Args:\n        ckpt_path: Path to the checkpoint file.\n\n    Returns:\n        state_dict suitable for model.load_state_dict().\n    \"\"\"\n    ckpt_path = str(ckpt_path)\n\n    if ckpt_path.endswith(\".safetensors\"):\n        state_dict = load_safetensors(ckpt_path)\n    else:\n        checkpoint = torch.load(ckpt_path, map_location=\"cpu\", weights_only=False)\n        if isinstance(checkpoint, dict) and \"state_dict\" in checkpoint:\n            state_dict = checkpoint[\"state_dict\"]\n        elif isinstance(checkpoint, dict):\n            state_dict = checkpoint\n        else:\n            raise ValueError(f\"Unsupported checkpoint format: {ckpt_path}\")\n    return {key: val.detach().cpu() for key, val in state_dict.items()}\n\n\n__all__ = [\n    \"get_env_var\",\n    \"instantiate_from_dict\",\n    \"KIMODO_MODELS\",\n    \"TMR_MODELS\",\n    \"AVAILABLE_MODELS\",\n    \"MODEL_NAMES\",\n    \"DEFAULT_MODEL\",\n    \"DEFAULT_TEXT_ENCODER_URL\",\n    \"load_checkpoint_state_dict\",\n]\n"
  },
  {
    "path": "kimodo/model/registry.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Registry of model names and Hugging Face repo IDs for Kimodo and TMR.\n\nCanonical source of truth is the list of repo IDs. Short keys (e.g. soma-rp) and metadata (dataset,\nskeleton, version, display name) are derived by parsing.\n\"\"\"\n\nimport re\nfrom dataclasses import dataclass\nfrom typing import Optional\n\n# Canonical list: repo IDs in the same syntax as Hugging Face (org/Model-Name-v1).\n# Parser expects: org/Family-SKELETON-DATASET-version (e.g. Kimodo-SOMA-RP-v1).\nKIMODO_REPO_IDS = [\n    \"nvidia/Kimodo-SOMA-RP-v1\",\n    \"nvidia/Kimodo-SOMA-RP-v1.1\",\n    \"nvidia/Kimodo-SMPLX-RP-v1\",\n    \"nvidia/Kimodo-G1-RP-v1\",\n    \"nvidia/Kimodo-SOMA-SEED-v1\",\n    \"nvidia/Kimodo-SOMA-SEED-v1.1\",\n    \"nvidia/Kimodo-G1-SEED-v1\",\n]\nTMR_REPO_IDS = [\n    \"nvidia/TMR-SOMA-RP-v1\",\n]\n\n# Repo ID without org, for display (e.g. 
Kimodo-SOMA-RP-v1).\n_REPO_NAME_PATTERN = re.compile(r\"^(Kimodo|TMR)-([A-Za-z0-9]+)-(RP|SEED)-v(\\d+(?:\\.\\d+)*)$\")\n\n\n@dataclass\nclass ModelInfo:\n    \"\"\"Structured metadata for one model, derived from its repo ID.\"\"\"\n\n    repo_id: str\n    short_key: str\n    family: str\n    skeleton: str\n    dataset: str\n    version: str\n    display_name: str\n\n    @property\n    def dataset_ui_label(self) -> str:\n        return \"Rigplay\" if self.dataset == \"RP\" else \"SEED\"\n\n\ndef _parse_repo_id(repo_id: str) -> Optional[ModelInfo]:\n    \"\"\"Parse a repo ID into ModelInfo.\n\n    Returns None if format is unrecognized.\n    \"\"\"\n    # repo_id is \"org/Model-Name-v1\"\n    if \"/\" in repo_id:\n        _, name = repo_id.split(\"/\", 1)\n    else:\n        name = repo_id\n    m = _REPO_NAME_PATTERN.match(name)\n    if not m:\n        return None\n    family, skeleton, dataset, ver = m.groups()\n    # Normalize skeleton for display (as is for now)\n    skeleton_display = skeleton\n    # Include family so Kimodo-SOMA-RP and TMR-SOMA-RP have distinct keys.\n    short_key = f\"{family.lower()}-{skeleton.lower()}-{dataset.lower()}\"\n    return ModelInfo(\n        repo_id=repo_id,\n        short_key=short_key,\n        family=family,\n        skeleton=skeleton_display,\n        dataset=dataset,\n        version=f\"v{ver}\",\n        display_name=name,\n    )\n\n\ndef _version_tuple(v: str) -> tuple[int, ...]:\n    \"\"\"Parse 'vN' or 'vN.M' into a comparable tuple of ints.\"\"\"\n    if v.startswith(\"v\"):\n        parts = v[1:].split(\".\")\n        if all(p.isdigit() for p in parts):\n            return tuple(int(p) for p in parts)\n    return (0,)\n\n\ndef _version_key(info: ModelInfo) -> tuple[int, ...]:\n    return _version_tuple(info.version)\n\n\ndef _build_registry() -> tuple[list[ModelInfo], dict[str, str], list[str]]:\n    \"\"\"Build model infos, short_key -> repo_id map, and list of short keys.\n\n    When multiple versions exist for the 
same (family, skeleton, dataset), each ModelInfo gets a\n    version-specific short_key (e.g. kimodo-soma-rp-v1, kimodo-soma-rp-v2) and a versionless alias\n    (kimodo-soma-rp) is added to MODEL_NAMES pointing to the latest version.  When only one version\n    exists, the short_key stays versionless (e.g. kimodo-smplx-rp).\n    \"\"\"\n    all_repos = KIMODO_REPO_IDS + TMR_REPO_IDS\n    infos: list[ModelInfo] = []\n    for repo_id in all_repos:\n        info = _parse_repo_id(repo_id)\n        if info is None:\n            raise ValueError(f\"Registry repo ID does not match expected pattern: {repo_id}\")\n        infos.append(info)\n\n    # Group by base short_key to detect multi-version families.\n    base_groups: dict[str, list[ModelInfo]] = {}\n    for info in infos:\n        base_groups.setdefault(info.short_key, []).append(info)\n\n    # For groups with multiple versions, make each short_key version-specific.\n    for base_key, group in base_groups.items():\n        if len(group) > 1:\n            for info in group:\n                info.short_key = f\"{base_key}-{info.version}\"\n\n    # Map each (now unique) short_key to its repo_id.\n    model_names: dict[str, str] = {}\n    for info in infos:\n        model_names[info.short_key] = info.repo_id\n\n    # Add versionless aliases for multi-version groups, pointing to the latest.\n    for base_key, group in base_groups.items():\n        if len(group) > 1:\n            latest = max(group, key=_version_key)\n            model_names[base_key] = latest.repo_id\n\n    return infos, model_names, list(model_names.keys())\n\n\nMODEL_INFOS, MODEL_NAMES, _SHORT_KEYS = _build_registry()\nAVAILABLE_MODELS = _SHORT_KEYS\n\n# Short-key lists for Kimodo vs TMR (load_model uses TMR_MODELS to branch).\nKIMODO_MODELS = [info.short_key for info in MODEL_INFOS if info.family == \"Kimodo\"]\nTMR_MODELS = [info.short_key for info in MODEL_INFOS if info.family == \"TMR\"]\n\n# Backward compatibility: FRIENDLY_NAMES for any code that 
still expects it.\n# Includes versioned short_keys and versionless aliases (latest display name).\nFRIENDLY_NAMES = {info.short_key: info.display_name for info in MODEL_INFOS}\nfor _key, _repo_id in MODEL_NAMES.items():\n    if _key not in FRIENDLY_NAMES:\n        for _info in MODEL_INFOS:\n            if _info.repo_id == _repo_id:\n                FRIENDLY_NAMES[_key] = _info.display_name\n                break\n\nDEFAULT_MODEL = \"kimodo-soma-rp\"\nDEFAULT_TEXT_ENCODER_URL = \"http://127.0.0.1:9550/\"\n\n# Friendly names for skeleton dropdown (key -> label).\nSKELETON_DISPLAY_NAMES = {\n    \"SOMA\": \"SOMA Human Body\",\n    \"SMPLX\": \"SMPLX Human Body\",\n    \"G1\": \"Unitree G1 Humanoid Robot\",\n}\n\n# Order for skeleton dropdown: SOMA, SMPLX, G1.\nSKELETON_ORDER = (\"SOMA\", \"SMPLX\", \"G1\")\n\n\ndef get_skeleton_display_name(skeleton_key: str) -> str:\n    \"\"\"Return the UI label for a skeleton key (e.g. SOMA -> SOMA Human Body).\"\"\"\n    return SKELETON_DISPLAY_NAMES.get(skeleton_key, skeleton_key)\n\n\ndef get_skeleton_key_from_display_name(display_name: str) -> Optional[str]:\n    \"\"\"Return the skeleton key for a UI label, or None.\"\"\"\n    for key, label in SKELETON_DISPLAY_NAMES.items():\n        if label == display_name:\n            return key\n    return None\n\n\ndef get_skeleton_display_names_for_dataset(dataset_ui_label: str, family: Optional[str] = None) -> list[str]:\n    \"\"\"Return skeleton UI labels for the given dataset.\n\n    If family is set (e.g. 
\"Kimodo\"), only skeletons with a model of that family are included.\n    \"\"\"\n    keys = get_skeletons_for_dataset(dataset_ui_label, family=family)\n    return [get_skeleton_display_name(k) for k in keys]\n\n\ndef get_short_key(repo_id: str) -> Optional[str]:\n    \"\"\"Return the short key for a repo ID, or None if not in registry.\"\"\"\n    for info in MODEL_INFOS:\n        if info.repo_id == repo_id:\n            return info.short_key\n    return None\n\n\ndef get_model_info(short_key: str) -> Optional[ModelInfo]:\n    \"\"\"Return ModelInfo for a short key, or None if not found.\n\n    When multiple versions share the same short_key, returns the one used for loading (the latest\n    version), so CHECKPOINT_DIR and HF use the same version.\n    \"\"\"\n    repo_id = MODEL_NAMES.get(short_key)\n    if repo_id is None:\n        return None\n    for info in MODEL_INFOS:\n        if info.repo_id == repo_id:\n            return info\n    return None\n\n\ndef get_short_key_from_display_name(display_name: str) -> Optional[str]:\n    \"\"\"Return short_key for a display name (e.g. Kimodo-SOMA-RP-v1), or None.\"\"\"\n    for info in MODEL_INFOS:\n        if info.display_name == display_name:\n            return info.short_key\n    return None\n\n\ndef get_models_for_demo() -> list[ModelInfo]:\n    \"\"\"Return all model infos in registry order (for demo model list).\"\"\"\n    return list(MODEL_INFOS)\n\n\ndef get_datasets(family: Optional[str] = None) -> list[str]:\n    \"\"\"Return unique dataset UI labels (Rigplay, SEED) present in registry.\n\n    If family is set (e.g. 
\"Kimodo\"), only datasets that have a model of that family are included.\n    \"\"\"\n    infos = MODEL_INFOS\n    if family is not None:\n        infos = [i for i in infos if i.family == family]\n    labels = set()\n    for info in infos:\n        labels.add(info.dataset_ui_label)\n    return sorted(labels)\n\n\ndef get_skeletons_for_dataset(dataset_ui_label: str, family: Optional[str] = None) -> list[str]:\n    \"\"\"Return skeleton names that have a model for the given dataset.\n\n    Order: SOMA, SMPLX, G1 (only those present for the dataset).\n    If family is set (e.g. \"Kimodo\"), only skeletons with a model of that\n    family are included.\n    \"\"\"\n    dataset = \"RP\" if dataset_ui_label == \"Rigplay\" else \"SEED\"\n    infos = MODEL_INFOS\n    if family is not None:\n        infos = [i for i in infos if i.family == family]\n    skeletons = set()\n    for info in infos:\n        if info.dataset == dataset:\n            skeletons.add(info.skeleton)\n    return [s for s in SKELETON_ORDER if s in skeletons]\n\n\ndef get_versions_for_dataset_skeleton(dataset_ui_label: str, skeleton: str) -> list[str]:\n    \"\"\"Return version strings (e.g. v1) for the given dataset/skeleton.\n\n    Sorted by version number so the last element is the highest (e.g. v1, v2).\n    \"\"\"\n    dataset = \"RP\" if dataset_ui_label == \"Rigplay\" else \"SEED\"\n    versions = []\n    for info in MODEL_INFOS:\n        if info.dataset == dataset and info.skeleton == skeleton:\n            versions.append(info.version)\n\n    return sorted(set(versions), key=_version_tuple)\n\n\ndef get_models_for_dataset_skeleton(\n    dataset_ui_label: str, skeleton: str, family: Optional[str] = None\n) -> list[ModelInfo]:\n    \"\"\"Return model infos for the given dataset/skeleton, sorted by version (max first).\n\n    Used to build the Version dropdown (options = full display names, one per model). If family is\n    set (e.g. 
\"Kimodo\"), only models of that family are returned.\n    \"\"\"\n    dataset = \"RP\" if dataset_ui_label == \"Rigplay\" else \"SEED\"\n    infos = [info for info in MODEL_INFOS if info.dataset == dataset and info.skeleton == skeleton]\n    if family is not None:\n        infos = [i for i in infos if i.family == family]\n\n    return sorted(infos, key=_version_key, reverse=True)\n\n\ndef resolve_to_short_key(dataset_ui_label: str, skeleton: str, version: str) -> Optional[str]:\n    \"\"\"Return the short key for (dataset, skeleton, version), or None.\"\"\"\n    for info in MODEL_INFOS:\n        if info.dataset_ui_label == dataset_ui_label and info.skeleton == skeleton and info.version == version:\n            return info.short_key\n    return None\n\n\n# -----------------------------------------------------------------------------\n# Flexible model name resolution (partial names, case-insensitive, defaults)\n# -----------------------------------------------------------------------------\n\n_FAMILY_ALIASES = {\"kimodo\": \"Kimodo\", \"tmr\": \"TMR\"}\n_DATASET_ALIASES = {\"rp\": \"RP\", \"rigplay\": \"RP\", \"seed\": \"SEED\"}\n_SKELETON_ALIASES = {\n    \"soma\": \"SOMA\",\n    \"smplx\": \"SMPLX\",\n    \"g1\": \"G1\",\n}\n\n\ndef _normalize_family(s: str) -> Optional[str]:\n    \"\"\"Return canonical family (Kimodo/TMR) or None if unknown.\"\"\"\n    return _FAMILY_ALIASES.get(s.strip().lower())\n\n\ndef _normalize_dataset(s: str) -> Optional[str]:\n    \"\"\"Return canonical dataset (RP/SEED) or None if unknown.\"\"\"\n    return _DATASET_ALIASES.get(s.strip().lower())\n\n\ndef _normalize_skeleton(s: str) -> Optional[str]:\n    \"\"\"Return canonical skeleton (SOMA/SMPLX/G1) or None if unknown.\"\"\"\n    return _SKELETON_ALIASES.get(s.strip().lower())\n\n\ndef _get_latest_for_family_skeleton_dataset(family: str, skeleton: str, dataset: str) -> Optional[ModelInfo]:\n    \"\"\"Return the model info with the highest version for (family, skeleton, 
dataset).\"\"\"\n    candidates = [\n        info for info in MODEL_INFOS if info.family == family and info.skeleton == skeleton and info.dataset == dataset\n    ]\n    if not candidates:\n        return None\n    return max(candidates, key=_version_key)\n\n\ndef kimodo_short_key_for_skeleton_dataset(skeleton: str, dataset: str) -> Optional[str]:\n    \"\"\"Return the latest Kimodo model short_key for ``skeleton`` and ``dataset`` (RP/SEED), or\n    None.\"\"\"\n    info = _get_latest_for_family_skeleton_dataset(\"Kimodo\", skeleton, dataset)\n    return info.short_key if info is not None else None\n\n\ndef registry_skeleton_for_joint_count(nb_joints: int) -> str:\n    \"\"\"Map motion joint count to registry skeleton key (SOMA / SMPLX / G1).\"\"\"\n    if nb_joints == 34:\n        return \"G1\"\n    if nb_joints == 22:\n        return \"SMPLX\"\n    if nb_joints in (77, 30):\n        return \"SOMA\"\n    raise ValueError(f\"No Kimodo model registered for motion with J={nb_joints}\")\n\n\n# Optional version: Family-Skeleton-Dataset-vN or Family-Skeleton-Dataset\n_RESOLVE_FULL_PATTERN = re.compile(\n    r\"^(Kimodo|TMR|kimodo|tmr)[\\-_]\" r\"([A-Za-z0-9]+)[\\-_]\" r\"(RP|SEED|rp|seed)\" r\"(?:[\\-_]v(\\d+(?:\\.\\d+)*))?$\",\n    re.IGNORECASE,\n)\n# Partial: Skeleton-Dataset or Skeleton or Dataset (no family)\n_RESOLVE_PARTIAL_PATTERN = re.compile(\n    r\"^([A-Za-z0-9]+)(?:[\\-_](RP|SEED|rp|seed))?(?:[\\-_]v(\\d+(?:\\.\\d+)*))?$\",\n    re.IGNORECASE,\n)\n\n\ndef resolve_model_name(name: Optional[str], default_family: Optional[str] = None) -> str:\n    \"\"\"Resolve a user-facing model name to a short_key.\n\n    Accepts full names (e.g. 
Kimodo-SOMA-RP-v1), case-insensitive matching,\n    and partial names with defaults: dataset=RP, skeleton=SOMA, family from\n    default_family (Kimodo for demo/generation, TMR for embed script).\n    Omitted version resolves to the latest for that model.\n\n    Args:\n        name: User-provided name (can be None or empty).\n        default_family: \"Kimodo\" or \"TMR\" when name is empty or omits family.\n\n    Returns:\n        Short key (e.g. kimodo-soma-rp) for use with load_model / MODEL_NAMES.\n\n    Raises:\n        ValueError: If name cannot be resolved or default_family is missing when needed.\n    \"\"\"\n    if name is not None:\n        name = name.strip()\n    if not name:\n        if default_family is None:\n            raise ValueError('Model name is empty; provide a name or set default_family (\"Kimodo\" or \"TMR\").')\n        fam = _normalize_family(default_family)\n        if fam is None:\n            raise ValueError(f\"default_family must be 'Kimodo' or 'TMR', got {default_family!r}\")\n        info = _get_latest_for_family_skeleton_dataset(fam, \"SOMA\", \"RP\")\n        if info is None:\n            raise ValueError(f\"No model found for {fam}-SOMA-RP. 
Available: {list(MODEL_NAMES.keys())}\")\n        return info.short_key\n\n    # Exact short_key\n    if name in MODEL_NAMES:\n        return name\n\n    # Case-insensitive match against short_key or display_name\n    name_lower = name.lower()\n    matches = []\n    for info in MODEL_INFOS:\n        if name_lower == info.short_key.lower():\n            matches.append(info)\n        disp = info.display_name.lower()\n        if name_lower == disp or name_lower == (\"nvidia/\" + disp):\n            matches.append(info)\n    if len(matches) == 1:\n        return matches[0].short_key\n    if len(matches) > 1:\n        return matches[0].short_key\n\n    # Parsed full form: Family-Skeleton-Dataset or Family-Skeleton-Dataset-vN\n    m = _RESOLVE_FULL_PATTERN.match(name)\n    if m:\n        fam_raw, skel_raw, ds_raw, ver_num = m.groups()\n        fam = _normalize_family(fam_raw)\n        skel = _normalize_skeleton(skel_raw)\n        ds = _normalize_dataset(ds_raw)\n        if fam is not None and skel is not None and ds is not None:\n            if ver_num is not None:\n                version = f\"v{ver_num}\"\n                for info in MODEL_INFOS:\n                    if info.family == fam and info.skeleton == skel and info.dataset == ds and info.version == version:\n                        return info.short_key\n            else:\n                info = _get_latest_for_family_skeleton_dataset(fam, skel, ds)\n                if info is not None:\n                    return info.short_key\n\n    # Parsed partial: Skeleton-Dataset, Skeleton, or Dataset (use default_family)\n    if default_family is not None:\n        m = _RESOLVE_PARTIAL_PATTERN.match(name)\n        if m:\n            tok1, ds_raw, ver_num = m.groups()\n            fam = _normalize_family(default_family)\n            if fam is not None:\n                skel = _normalize_skeleton(tok1)\n                ds_candidate = _normalize_dataset(ds_raw) if ds_raw else None\n                if skel is not None and 
ds_candidate is not None:\n                    ds = ds_candidate\n                elif skel is not None:\n                    ds = \"RP\"\n                else:\n                    skel = \"SOMA\"\n                    ds = _normalize_dataset(tok1) if tok1 else \"RP\"\n                    if ds is None:\n                        ds = \"RP\"\n                if ver_num is not None:\n                    version = f\"v{ver_num}\"\n                    for info in MODEL_INFOS:\n                        if (\n                            info.family == fam\n                            and info.skeleton == skel\n                            and info.dataset == ds\n                            and info.version == version\n                        ):\n                            return info.short_key\n                else:\n                    info = _get_latest_for_family_skeleton_dataset(fam, skel, ds)\n                    if info is not None:\n                        return info.short_key\n\n        # Single token: skeleton or dataset\n        fam = _normalize_family(default_family)\n        if fam is not None:\n            skel = _normalize_skeleton(name)\n            if skel is not None:\n                info = _get_latest_for_family_skeleton_dataset(fam, skel, \"RP\")\n                if info is not None:\n                    return info.short_key\n            ds = _normalize_dataset(name)\n            if ds is not None:\n                info = _get_latest_for_family_skeleton_dataset(fam, \"SOMA\", ds)\n                if info is not None:\n                    return info.short_key\n\n    raise ValueError(\n        f\"Model name {name!r} could not be resolved. \"\n        f\"Use a short key (e.g. {list(MODEL_NAMES.keys())[:3]}...), \"\n        \"a full name (e.g. Kimodo-SOMA-RP-v1), or a partial (e.g. SOMA-RP, SOMA) \"\n        \"with default_family set.\"\n    )\n"
  },
  {
    "path": "kimodo/model/text_encoder_api.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Remote text encoder API client (Gradio) for motion generation.\"\"\"\n\nimport logging\n\nimport numpy as np\nimport torch\nfrom gradio_client import Client\n\n# Suppress the [httpx] logs (GET requests)\nlogging.getLogger(\"httpx\").setLevel(logging.WARNING)\n\n# Suppress internal gradio_client logs\nlogging.getLogger(\"gradio_client\").setLevel(logging.WARNING)\n\n\nclass TextEncoderAPI:\n    \"\"\"Text encoder API client for motion generation.\"\"\"\n\n    def __init__(self, url: str):\n        self.client = Client(url, verbose=False)\n        self.device = \"cpu\"\n        self.dtype = torch.float\n\n    def _create_np_random_name(self):\n        import uuid\n\n        return str(uuid.uuid4()) + \".npy\"\n\n    def to(self, device=None, dtype=None):\n        if device is not None:\n            self.device = device\n        if dtype is not None:\n            self.dtype = dtype\n        return self\n\n    def __call__(self, texts):\n        \"\"\"Encode text prompts into tensors.\n\n        Args:\n            texts (str | list[str]): text prompts to encode\n\n        Returns:\n            tuple[torch.Tensor, list[int]]: encoded text tensors and their lengths\n        \"\"\"\n        if isinstance(texts, str):\n            texts = [texts]\n\n        tensors = []\n        lengths = []\n        for text in texts:\n            filename = self._create_np_random_name()\n\n            result = self.client.predict(\n                text=text,\n                filename=filename,\n                api_name=\"/DemoWrapper\",\n            )\n            path = result[0][\"value\"]\n            tensor = np.load(path)\n            length = tensor.shape[0]\n\n            tensors.append(tensor)\n            lengths.append(length)\n\n        padded_tensor = np.zeros((len(lengths), max(lengths), tensors[0].shape[-1]), 
dtype=tensors[0].dtype)\n        for idx, (tensor, length) in enumerate(zip(tensors, lengths)):\n            padded_tensor[idx, :length] = tensor\n\n        padded_tensor = torch.from_numpy(padded_tensor)\n        padded_tensor = padded_tensor.to(device=self.device, dtype=self.dtype)\n        return padded_tensor, lengths\n"
  },
  {
    "path": "kimodo/model/tmr.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"TMR model: encoder, and text-to-motion retrieval head.\"\"\"\n\nimport contextlib\nfrom pathlib import Path\nfrom typing import Dict, List, Optional, Tuple\n\nimport torch\nimport torch.nn as nn\nfrom einops import repeat\nfrom torch import Tensor\n\nfrom kimodo.model import load_checkpoint_state_dict\nfrom kimodo.motion_rep.feature_utils import length_to_mask\nfrom kimodo.sanitize import sanitize_texts\nfrom kimodo.skeleton import SkeletonBase, build_skeleton\nfrom kimodo.tools import ensure_batched\n\n\nclass PositionalEncoding(nn.Module):\n    \"\"\"Sinusoidal positional encoding for sequences (batch_first optional).\"\"\"\n\n    def __init__(self, d_model, dropout=0.1, max_len=5000, batch_first=False) -> None:\n        super().__init__()\n        self.batch_first = batch_first\n\n        self.dropout = nn.Dropout(p=dropout)\n\n        pe = torch.zeros(max_len, d_model)\n        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)\n        # Note: have to replace torch.exp() and math.log() with torch.pow()\n        # due to MKL exp() and ln() throws floating point exceptions on certain CPUs\n        div_term = torch.pow(10000.0, -torch.arange(0, d_model, 2).float() / d_model)\n        # div_term = torch.exp(\n        #     torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)\n        # )\n\n        pe[:, 0::2] = torch.sin(position * div_term)\n        pe[:, 1::2] = torch.cos(position * div_term)\n        pe = pe.unsqueeze(0).transpose(0, 1)\n        self.register_buffer(\"pe\", pe, persistent=False)\n\n    def forward(self, x: Tensor) -> Tensor:\n        if self.batch_first:\n            x = x + self.pe.permute(1, 0, 2)[:, : x.shape[1], :]\n        else:\n            x = x + self.pe[: x.shape[0], :]\n        return self.dropout(x)\n\n\ndef load_ckpt(self, ckpt_path):\n    
\"\"\"Load model weights from checkpoint path.\"\"\"\n    state_dict = load_checkpoint_state_dict(ckpt_path)\n    self.load_state_dict(state_dict)\n\n\nclass ACTORStyleEncoder(nn.Module):\n    \"\"\"Motion encoder in ACTOR style: optional motion_rep projection, VAE/MLP tokens, transformer.\"\"\"\n\n    def __init__(\n        self,\n        motion_rep: Optional[nn.Module],\n        llm_shape: Optional[Tuple],\n        vae: bool,\n        latent_dim: int = 256,\n        ff_size: int = 1024,\n        num_layers: int = 4,\n        num_heads: int = 4,\n        dropout: float = 0.1,\n        activation: str = \"gelu\",\n        ckpt_path: Optional[str] = None,\n    ) -> None:\n        super().__init__()\n\n        self.motion_rep = motion_rep\n        if motion_rep is not None and llm_shape is None:\n            nfeats = motion_rep.motion_rep_dim\n        elif motion_rep is None and llm_shape is not None:\n            nfeats = llm_shape[-1]\n        else:\n            raise ValueError\n\n        self.nfeats = nfeats\n        self.projection = nn.Linear(nfeats, latent_dim)\n\n        self.vae = vae\n        self.nbtokens = 2 if vae else 1\n        self.tokens = nn.Parameter(torch.randn(self.nbtokens, latent_dim))\n\n        self.sequence_pos_encoding = PositionalEncoding(latent_dim, dropout=dropout, batch_first=True)\n\n        seq_trans_encoder_layer = nn.TransformerEncoderLayer(\n            d_model=latent_dim,\n            nhead=num_heads,\n            dim_feedforward=ff_size,\n            dropout=dropout,\n            activation=activation,\n            batch_first=True,\n        )\n\n        self.seqTransEncoder = nn.TransformerEncoder(\n            seq_trans_encoder_layer,\n            num_layers=num_layers,\n            enable_nested_tensor=False,\n        )\n\n        if ckpt_path is not None:\n            load_ckpt(self, ckpt_path)\n\n    def forward(self, x_dict: Dict) -> Tensor:\n        x = x_dict[\"x\"]\n        mask = x_dict[\"mask\"]\n\n        x = 
self.projection(x)\n\n        device = x.device\n        bs = len(x)\n\n        tokens = repeat(self.tokens, \"nbtoken dim -> bs nbtoken dim\", bs=bs)\n        xseq = torch.cat((tokens, x), 1)\n\n        token_mask = torch.ones((bs, self.nbtokens), dtype=bool, device=device)\n        aug_mask = torch.cat((token_mask, mask), 1)\n\n        # add positional encoding\n        xseq = self.sequence_pos_encoding(xseq)\n        final = self.seqTransEncoder(xseq, src_key_padding_mask=~aug_mask)\n        return final[:, : self.nbtokens]\n\n\nclass TMR(nn.Module):\n    r\"\"\"TMR: Text-to-Motion Retrieval inference code (no decoder)\n    Find more information about the model on the following website:\n    https://mathis.petrovich.fr/tmr\n    \"\"\"\n\n    @classmethod\n    def from_args(\n        cls,\n        motion_rep: nn.Module,\n        llm_shape: tuple | list,\n        vae: bool,\n        latent_dim: int = 256,\n        ff_size: int = 1024,\n        num_layers: int = 4,\n        num_heads: int = 4,\n        dropout: float = 0.1,\n        activation: str = \"gelu\",\n        ckpt_folder: Optional[str] = None,\n        device: Optional[str] = None,\n        **kwargs,\n    ):\n        motion_encoder, top_text_encoder = None, None\n\n        motion_encoder = ACTORStyleEncoder(\n            motion_rep=motion_rep,\n            llm_shape=None,\n            vae=vae,\n            latent_dim=latent_dim,\n            ff_size=ff_size,\n            num_layers=num_layers,\n            num_heads=num_heads,\n            dropout=dropout,\n            activation=activation,\n            ckpt_path=Path(ckpt_folder) / \"motion_encoder.pt\",\n        ).to(device)\n\n        top_text_encoder = ACTORStyleEncoder(\n            motion_rep=None,\n            llm_shape=llm_shape,\n            vae=vae,\n            latent_dim=latent_dim,\n            ff_size=ff_size,\n            num_layers=num_layers,\n            num_heads=num_heads,\n            dropout=dropout,\n            
activation=activation,\n            ckpt_path=Path(ckpt_folder) / \"text_encoder.pt\",\n        ).to(device)\n        return cls(\n            motion_encoder,\n            top_text_encoder,\n            vae,\n            device=device,\n            **kwargs,\n        )\n\n    def __init__(\n        self,\n        motion_encoder: nn.Module,\n        top_text_encoder: nn.Module,\n        vae: bool,\n        text_encoder: Optional = None,\n        fact: Optional[float] = None,\n        sample_mean: Optional[bool] = True,\n        unit_vector: Optional[bool] = False,\n        compute_grads: bool = False,\n        device: Optional[str] = None,\n    ) -> None:\n        super().__init__()\n\n        self.motion_encoder = motion_encoder\n        self.text_encoder = top_text_encoder\n        self.raw_text_encoder = text_encoder\n\n        self.motion_rep = None\n        self.skeleton = None\n        if self.motion_encoder is not None:\n            self.motion_rep = self.motion_encoder.motion_rep\n        if self.motion_rep is not None:\n            self.skeleton = self.motion_rep.skeleton\n\n        self.compute_grads = compute_grads\n\n        self.device = device\n\n        # sampling parameters\n        self.vae = vae\n        self.fact = fact if fact is not None else 1.0\n        self.sample_mean = sample_mean\n        self.unit_vector = unit_vector\n\n    def full_text_encoder(self, texts: list[str]):\n        assert isinstance(texts, list), \"The input should be batched.\"\n        # sanitize the texts first\n        # then encode the text, and then use the top text encoder\n        texts = sanitize_texts(texts)\n        text_feat, text_length = self.raw_text_encoder(texts)\n        if isinstance(text_length, list):\n            text_length = torch.tensor(text_length, device=self.device)\n        else:\n            text_length = text_length.to(self.device)\n        inputs = {\n            \"x\": text_feat.to(self.device),\n            \"mask\": 
length_to_mask(text_length, device=self.device),\n        }\n        return self.text_encoder(inputs)\n\n    def _find_encoder(self, inputs, modality):\n        assert modality in [\"text\", \"motion\", \"raw_text\", \"auto\"]\n\n        if modality == \"text\":\n            return self.text_encoder\n        elif modality == \"motion\":\n            return self.motion_encoder\n        elif modality == \"raw_text\":\n            return self.full_text_encoder\n\n        if isinstance(inputs[0], str):\n            return self.full_text_encoder\n\n        m_nfeats = self.motion_encoder.nfeats\n        t_nfeats = self.text_encoder.nfeats\n\n        if m_nfeats == t_nfeats:\n            raise ValueError(\"Cannot automatically find the encoder, as they share the same input space.\")\n\n        nfeats = inputs[\"x\"].shape[-1]\n        if nfeats == m_nfeats:\n            return self.motion_encoder\n        elif nfeats == t_nfeats:\n            return self.text_encoder\n        else:\n            raise ValueError(\"The inputs is not recognized.\")\n\n    def _encode(\n        self,\n        inputs,\n        modality: str = \"auto\",\n        sample_mean: Optional[bool] = None,\n        fact: Optional[float] = None,\n        return_distribution: bool = False,\n        unit_vector: Optional[bool] = None,\n    ):\n        sample_mean = self.sample_mean if sample_mean is None else sample_mean\n        fact = self.fact if fact is None else fact\n        unit_vector = self.unit_vector if unit_vector is None else unit_vector\n\n        # Encode the inputs\n        encoder = self._find_encoder(inputs, modality)\n        encoded = encoder(inputs)\n\n        # Sampling\n        if self.vae:\n            dists = encoded.unbind(1)\n            mu, logvar = dists\n            if sample_mean:\n                latent_vectors = mu\n            else:\n                # Reparameterization trick\n                std = logvar.exp().pow(0.5)\n                eps = 
std.data.new(std.size()).normal_()\n                latent_vectors = mu + fact * eps * std\n        else:\n            dists = None\n            (latent_vectors,) = encoded.unbind(1)\n\n        if unit_vector:\n            latent_vectors = torch.nn.functional.normalize(latent_vectors, dim=-1)\n\n        if return_distribution:\n            return latent_vectors, dists\n\n        return latent_vectors\n\n    @ensure_batched(posed_joints=4, lengths=1)\n    def encode_motion(\n        self,\n        posed_joints: torch.Tensor,\n        original_skeleton: Optional[SkeletonBase] = None,\n        lengths: Optional[torch.Tensor] = None,\n        unit_vector: Optional[bool] = None,\n    ):\n        # TODO here.\n        convert_ctx = torch.no_grad() if not self.compute_grads else contextlib.nullcontext()\n\n        if original_skeleton is None:\n            original_skeleton = build_skeleton(posed_joints.shape[-2])\n\n        if lengths is None:\n            nbatch, nbframes = posed_joints.shape[:2]\n            device = posed_joints.device\n            assert nbatch == 1, \"If lenghts is not provided, the input should not be batched.\"\n            lengths = torch.tensor([nbframes], device=device)\n\n        # slice the posed joints if we use less joints\n        skel_slice = self.motion_rep.skeleton.get_skel_slice(original_skeleton)\n        posed_joints = posed_joints[..., skel_slice, :]\n\n        with convert_ctx:\n            features = self.motion_rep(\n                posed_joints=posed_joints,\n                to_canonicalize=True,\n                to_normalize=True,\n                lengths=lengths,\n            )\n            mask = length_to_mask(lengths, device=features.device)\n            x_dict = {\"x\": features, \"mask\": mask}\n            latent_vectors = self._encode(\n                x_dict,\n                modality=\"motion\",\n                unit_vector=unit_vector,\n            )\n        return latent_vectors\n\n    def encode_text(\n        
self,\n        x_dict: Dict,\n        unit_vector: Optional[bool] = None,\n    ):\n        # TODO: make it ensure batched\n        convert_ctx = torch.no_grad() if not self.compute_grads else contextlib.nullcontext()\n\n        with convert_ctx:\n            latent_vectors = self._encode(\n                x_dict,\n                modality=\"text\",\n                unit_vector=unit_vector,\n            )\n        return latent_vectors\n\n    def encode_raw_text(\n        self,\n        texts: List[str],\n        unit_vector: Optional[bool] = None,\n    ):\n        is_batched = True\n        if isinstance(texts, str):\n            is_batched = False\n            texts = [texts]\n\n        convert_ctx = torch.no_grad() if not self.compute_grads else contextlib.nullcontext()\n\n        with convert_ctx:\n            latent_vectors = self._encode(\n                texts,\n                modality=\"raw_text\",\n                unit_vector=unit_vector,\n            )\n        if not is_batched:\n            latent_vectors = latent_vectors[0]\n        return latent_vectors\n"
  },
  {
    "path": "kimodo/model/twostage_denoiser.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Two-stage transformer denoiser: root stage then body stage for motion diffusion.\"\"\"\n\nimport contextlib\nfrom typing import Optional\n\nimport torch\nfrom torch import nn\n\nfrom .backbone import TransformerEncoderBlock\nfrom .loading import load_checkpoint_state_dict\n\n\nclass TwostageDenoiser(nn.Module):\n    \"\"\"Two-stage denoiser: first predicts global root features, then body features conditioned on local root.\"\"\"\n\n    def __init__(\n        self,\n        motion_rep,\n        motion_mask_mode,\n        ckpt_path: Optional[str] = None,\n        **kwargs,\n    ):\n        \"\"\"Build root and body transformer blocks; optionally load checkpoint from ckpt_path.\"\"\"\n        super().__init__()\n        self.motion_rep = motion_rep\n        self.motion_mask_mode = motion_mask_mode\n\n        # it should be a dual motion_rep\n        # and be global by default\n        # global motion_rep as inpnut\n        input_dim = motion_rep.motion_rep_dim\n        will_concatenate = motion_mask_mode == \"concat\"\n\n        # stage 1: root only\n        root_input_dim = input_dim * 2 if will_concatenate else input_dim\n        root_output_dim = motion_rep.global_root_dim\n\n        self.root_model = TransformerEncoderBlock(\n            input_dim=root_input_dim,\n            output_dim=root_output_dim,\n            skeleton=self.motion_rep.skeleton,\n            **kwargs,\n        )\n\n        # replace the global root by the local root\n        local_motion_rep_dim = input_dim - motion_rep.global_root_dim + motion_rep.local_root_dim\n\n        # stage 2: local body\n        body_input_dim = local_motion_rep_dim + (\n            input_dim if will_concatenate else 0\n        )  # body stage always takes in local root info for motion (but still the global mask)\n\n        body_output_dim = input_dim - 
motion_rep.global_root_dim\n        self.body_model = TransformerEncoderBlock(\n            input_dim=body_input_dim,\n            output_dim=body_output_dim,\n            skeleton=self.motion_rep.skeleton,\n            **kwargs,\n        )\n\n        if ckpt_path:\n            self.load_ckpt(ckpt_path)\n\n    def load_ckpt(self, ckpt_path: str) -> None:\n        \"\"\"Load checkpoint from path; state dict keys are stripped of 'denoiser.backbone.'\n        prefix.\"\"\"\n        state_dict = load_checkpoint_state_dict(ckpt_path)\n        state_dict = {key.replace(\"denoiser.backbone.\", \"\"): val for key, val in state_dict.items()}\n        self.load_state_dict(state_dict)\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        x_pad_mask: torch.Tensor,\n        text_feat: torch.Tensor,\n        text_feat_pad_mask: torch.Tensor,\n        timesteps: torch.Tensor,\n        first_heading_angle: Optional[torch.Tensor] = None,\n        motion_mask: Optional[torch.Tensor] = None,\n        observed_motion: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        \"\"\"\n        Args:\n            x (torch.Tensor): [B, T, dim_motion] current noisy motion\n            x_pad_mask (torch.Tensor): [B, T] attention mask, positions with True are allowed to attend, False are not\n            text_feat (torch.Tensor): [B, max_text_len, llm_dim] embedded text prompts\n            text_feat_pad_mask (torch.Tensor): [B, max_text_len] attention mask, positions with True are allowed to attend, False are not\n            timesteps (torch.Tensor): [B,] current denoising step\n            motion_mask\n            observed_motion\n\n        Returns:\n            torch.Tensor: same size as input x\n        \"\"\"\n\n        if self.motion_mask_mode == \"concat\":\n            if motion_mask is None or observed_motion is None:\n                motion_mask = torch.zeros_like(x)\n                observed_motion = torch.zeros_like(x)\n            x = x * (1 - motion_mask) 
+ observed_motion * motion_mask\n            x_extended = torch.cat([x, motion_mask], axis=-1)\n        else:\n            x_extended = x\n\n        # Stage 1: predict root motion in global\n        root_motion_pred = self.root_model(\n            x_extended,\n            x_pad_mask,\n            text_feat,\n            text_feat_pad_mask,\n            timesteps,\n            first_heading_angle,\n        )  # [B, T, 5]\n\n        # Maybe pass this as argument instead of recomputing it\n        lengths = x_pad_mask.sum(-1)\n\n        # Convert root pred to local rep\n        # At test-time want to allow gradient through for guidance\n        convert_ctx = torch.no_grad() if self.training else contextlib.nullcontext()\n        with convert_ctx:\n            root_motion_local = self.motion_rep.global_root_to_local_root(\n                root_motion_pred,\n                normalized=True,\n                lengths=lengths,\n            )\n        if self.training:\n            root_motion_local = root_motion_local.detach()\n\n        # concatenate the predicted local root with the body motion\n        body_x = x[..., self.motion_rep.body_slice]\n        x_new = torch.cat([root_motion_local, body_x], axis=-1)\n\n        if self.motion_mask_mode == \"concat\":\n            x_new_extended = torch.cat([x_new, motion_mask], axis=-1)\n        else:\n            x_new_extended = x_new\n\n        # Stage 2: predict local body motion based on local root\n        predicted_body = self.body_model(\n            x_new_extended,\n            x_pad_mask,\n            text_feat,\n            text_feat_pad_mask,\n            timesteps,\n            first_heading_angle,\n        )\n\n        # concatenate the predicted local body with the predicted root\n        output = torch.cat([root_motion_pred, predicted_body], axis=-1)\n        return output\n"
  },
  {
    "path": "kimodo/motion_rep/__init__.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Motion representation utilities.\"\"\"\n\nfrom .reps import KimodoMotionRep, MotionRepBase, TMRMotionRep\n\n__all__ = [\n    \"MotionRepBase\",\n    \"KimodoMotionRep\",\n    \"TMRMotionRep\",\n]\n"
  },
  {
    "path": "kimodo/motion_rep/conditioning.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Constraint conditioning: build index and data dicts from constraint sets for the denoiser.\"\"\"\n\nfrom collections import defaultdict\n\nimport torch\n\n\ndef build_condition_dicts(constraints_lst: list):\n    index_dict = defaultdict(list)\n    data_dict = defaultdict(list)\n    for constraint in constraints_lst:\n        constraint.update_constraints(data_dict, index_dict)\n    return index_dict, data_dict\n\n\ndef get_unique_index_and_data(indices_lst, data):\n    # unique + sort them by t\n    indices_unique, inverse = torch.unique(indices_lst, dim=0, return_inverse=True)\n    # pick first value for each unique (t, j)\n    first_idx = torch.zeros(indices_unique.size(0), dtype=torch.long, device=inverse.device)\n    first_idx.scatter_(0, inverse, torch.arange(len(inverse), device=inverse.device))\n    assert (indices_lst[first_idx] == indices_unique).all()\n    # get the data\n    indices_lst = indices_lst[first_idx]\n    data = data[first_idx]\n    return indices_lst, data\n"
  },
  {
    "path": "kimodo/motion_rep/feature_utils.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Motion representation helpers: velocity, heading, masks, and rotation of features.\"\"\"\n\nfrom typing import List, Optional, Union\n\nimport einops\nimport torch\n\nfrom kimodo.geometry import cont6d_to_matrix, matrix_to_cont6d\nfrom kimodo.skeleton import SkeletonBase\nfrom kimodo.tools import ensure_batched\n\n\ndef diff_angles(angles: torch.Tensor, fps: float) -> torch.Tensor:\n    \"\"\"Compute frame-to-frame angular differences in radians, scaled by fps.\n\n    Args:\n        angles: [..., T] batched sequences of rotation angles in radians.\n        fps: Sampling rate used to convert frame differences to per-second rate.\n\n    Returns:\n        [..., T-1] difference between consecutive angles (rad/s).\n    \"\"\"\n\n    cos = torch.cos(angles)\n    sin = torch.sin(angles)\n\n    cos_diff = cos[..., 1:] * cos[..., :-1] + sin[..., 1:] * sin[..., :-1]\n    sin_diff = sin[..., 1:] * cos[..., :-1] - cos[..., 1:] * sin[..., :-1]\n\n    # should be close to angles.diff() but more robust\n    # multiply by fps = 1 / dt\n    angles_diff = fps * torch.arctan2(sin_diff, cos_diff)\n    return angles_diff\n\n\n@ensure_batched(positions=4, lengths=1)\ndef compute_vel_xyz(\n    positions: torch.Tensor,\n    fps: float,\n    lengths: Optional[torch.Tensor] = None,\n) -> torch.Tensor:\n    \"\"\"Compute the velocities from positions: dx/dt. Works with batches. The last velocity is duplicated to keep the same size.\n\n    Args:\n        positions (torch.Tensor): [..., T, J, 3] xyz positions of a human skeleton\n        fps (float): frame per seconds\n        lengths (Optional[torch.Tensor]): [...] size of each input batched. 
If not provided, positions should not be batched\n\n    Returns:\n        velocity (torch.Tensor): [..., T, J, 3] velocities computed from the positions\n    \"\"\"\n    device = positions.device\n\n    if lengths is None:\n        assert positions.shape[0] == 1, \"If lengths is not provided, the input should not be batched.\"\n        lengths = torch.tensor([len(positions)], device=device)\n\n    # useful for indexing\n    range_len = torch.arange(len(lengths))\n\n    # compute velocities with fps\n    velocity = fps * (positions[:, 1:] - positions[:, :-1])\n    # pading the velocity vector\n    vel_pad = torch.zeros_like(velocity[:, 0])\n    velocity, _ = einops.pack([velocity, vel_pad], \"batch * nbjoints dim\")\n\n    # repeat the last velocities\n    # with special care for different lengths with batches\n    velocity[(range_len, lengths - 1)] = velocity[(range_len, lengths - 2)]\n    return velocity\n\n\n@ensure_batched(root_rot_angles=2, lengths=1)\ndef compute_vel_angle(\n    root_rot_angles: torch.Tensor,\n    fps: float,\n    lengths: Optional[torch.Tensor] = None,\n) -> torch.Tensor:\n    \"\"\"Compute the local root rotation velocity: dtheta/dt.\n\n    Args:\n        root_rot_angles (torch.Tensor): [..., T] rotation angle (in radian)\n        fps (float): frame per seconds\n        lengths (Optional[torch.Tensor]): [...] size of each input batched. 
If not provided, root_rot_angles should not be batched\n\n    Returns:\n        local_root_rot_vel (torch.Tensor): [..., T] local root rotation velocity (in radian/s)\n    \"\"\"\n    device = root_rot_angles.device\n    if lengths is None:\n        assert root_rot_angles.shape[0] == 1, \"If lengths is not provided, the input should not be batched.\"\n        lengths = torch.tensor([len(root_rot_angles)], device=device)\n\n    # useful for indexing\n    range_len = torch.arange(len(lengths))\n\n    local_root_rot_vel = diff_angles(root_rot_angles, fps)\n    pad_rot_vel_angles = torch.zeros_like(root_rot_angles[:, 0])\n    local_root_rot_vel, _ = einops.pack(\n        [local_root_rot_vel, pad_rot_vel_angles],\n        \"batch *\",\n    )\n    # repeat the last rotation angle\n    # with special care for different lengths with batches\n    local_root_rot_vel[(range_len, lengths - 1)] = local_root_rot_vel[(range_len, lengths - 2)]\n    return local_root_rot_vel\n\n\n@ensure_batched(posed_joints=4)\ndef compute_heading_angle(posed_joints: torch.Tensor, skeleton: SkeletonBase) -> torch.Tensor:\n    \"\"\"Compute the heading direction from joint positions using the hip vector.\n\n    Args:\n        posed_joints: [B, T, J, 3] global joint positions.\n        skeleton: Skeleton instance used to get hip joint indices.\n\n    Returns:\n        [B] heading angle in radians.\n    \"\"\"\n    # compute root heading for the sequence from hip positions\n    r_hip, l_hip = skeleton.hip_joint_idx\n    diff = posed_joints[:, :, r_hip] - posed_joints[:, :, l_hip]\n    heading_angle = torch.atan2(diff[..., 2], -diff[..., 0])\n    return heading_angle\n\n\ndef length_to_mask(\n    length: Union[torch.Tensor, List],\n    max_len: Optional[int] = None,\n    device=None,\n) -> torch.Tensor:\n    \"\"\"Convert sequence lengths to a boolean validity mask.\n\n    Args:\n        length: Sequence lengths, either a tensor ``[B]`` or a Python list.\n        max_len: Optional mask width. 
If omitted, uses ``max(length)``.\n        device: Optional device. When ``length`` is a list, this controls where\n            the new tensor is created.\n\n    Returns:\n        A boolean tensor of shape ``[B, max_len]`` where ``True`` marks valid\n        timesteps.\n    \"\"\"\n    if isinstance(length, list):\n        if device is None:\n            device = \"cpu\"\n        length = torch.tensor(length, device=device)\n\n    # Use requested device for output; move length if needed so mask and length match\n    if device is not None:\n        target = torch.device(device)\n        if length.device != target:\n            length = length.to(target)\n    device = length.device\n\n    if max_len is None:\n        max_len = max(length)\n\n    mask = torch.arange(max_len, device=device).expand(len(length), max_len) < length.unsqueeze(1)\n    return mask\n\n\nclass RotateFeatures:\n    \"\"\"Helper that applies a global heading rotation to motion features.\"\"\"\n\n    def __init__(self, angle: torch.Tensor):\n        \"\"\"Precompute 2D and 3D rotation matrices for a batch of angles.\n\n        Args:\n            angle: Rotation angle(s) in radians, shaped ``[B]``.\n        \"\"\"\n        self.angle = angle\n\n        ## Create the necessary rotations matrices\n        cos, sin = torch.cos(angle), torch.sin(angle)\n        one, zero = torch.ones_like(angle), torch.zeros_like(angle)\n\n        # 2D rotation transposed (sin are -sin)\n        self.corrective_mat_2d_T = torch.stack((cos, sin, -sin, cos), -1).reshape(angle.shape + (2, 2))\n        # 3D rotation on Y axis\n        self.corrective_mat_Y = torch.stack((cos, zero, sin, zero, one, zero, -sin, zero, cos), -1).reshape(\n            angle.shape + (3, 3)\n        )\n        self.corrective_mat_Y_T = self.corrective_mat_Y.transpose(-2, -1).contiguous()\n\n    def rotate_positions(self, positions: torch.Tensor):\n        \"\"\"Rotate 3D positions around the Y axis.\"\"\"\n        return positions @ 
self.corrective_mat_Y_T\n\n    def rotate_2d_positions(self, positions_2d: torch.Tensor):\n        \"\"\"Rotate 2D ``(x, z)`` vectors in the ground plane.\"\"\"\n        return positions_2d @ self.corrective_mat_2d_T\n\n    def rotate_rotations(self, rotations: torch.Tensor):\n        \"\"\"Left-multiply global rotation matrices by the heading correction.\"\"\"\n        # \"Rotate\" the global rotations\n        # which means add an extra Y rotation after the transform\n        # so at the left R' = R_y R\n        # (since we use the convention x' = R x)\n        # \"bik,btdkj->btdij\"\n\n        B, T, J = rotations.shape[:3]\n        BTJ = B * T * J\n        return (\n            self.corrective_mat_Y[:, None, None].expand(B, T, J, 3, 3).reshape(BTJ, 3, 3) @ rotations.reshape(BTJ, 3, 3)\n        ).reshape(B, T, J, 3, 3)\n\n    def rotate_6d_rotations(self, rotations_6d: torch.Tensor):\n        \"\"\"Rotate 6D rotation features via matrix conversion.\"\"\"\n        return matrix_to_cont6d(self.rotate_rotations(cont6d_to_matrix(rotations_6d)))\n"
  },
  {
    "path": "kimodo/motion_rep/feet.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Foot contact detection from joint positions and velocities.\"\"\"\n\nimport torch\n\nfrom ..tools import ensure_batched\n\n\n@ensure_batched(positions=4, velocity=4)\ndef foot_detect_from_pos_and_vel(\n    positions: torch.Tensor,\n    velocity: torch.Tensor,\n    skeleton,\n    vel_thres: float,\n    height_thresh: float,\n) -> torch.Tensor:\n    \"\"\"Compute foot contact labels using heuristics combining joint height and velocities.\n\n    Args:\n        positions (torch.Tensor): [X, T, J, 3] global joint positions\n        velocity (torch.Tensor): [X, T, J, 3] velocities (already padded correctly), already multiplied by 1 / dt\n        vel_thres (float): threshold for joint velocity\n        height_thresh (float): threshold for joint height\n\n    Returns:\n        torch.Tensor: [X, T, 4] contact labels for left and right foot joints\n        (heel/toe order follows the skeleton joint index definition), where\n        ``1`` denotes contact.\n    \"\"\"\n\n    device = positions.device\n    # Use at most 2 foot joints per side (ankle + toe); SOMA77 defines a\n    # third end-effector (ToeEnd) that SOMA30 and other skeletons omit.\n    fid_l = skeleton.left_foot_joint_idx[:2]\n    fid_r = skeleton.right_foot_joint_idx[:2]\n\n    velfactor, heightfactor = (\n        torch.tensor([vel_thres, vel_thres], device=device),\n        torch.tensor([height_thresh, height_thresh], device=device),\n    )\n\n    feet_l_v = torch.linalg.norm(velocity[:, :, fid_l], axis=-1)\n    feet_l_h = positions[:, :, fid_l, 1]\n\n    feet_l = torch.logical_and(\n        feet_l_v < velfactor,\n        feet_l_h < heightfactor,\n    ).to(positions.dtype)\n\n    feet_r_v = torch.linalg.norm(velocity[:, :, fid_r], axis=-1)\n    feet_r_h = positions[:, :, fid_r, 1]\n\n    feet_r = torch.logical_and(\n        feet_r_v < velfactor,\n       
 feet_r_h < heightfactor,\n    ).to(positions.dtype)\n\n    foot_contacts = torch.cat((feet_l, feet_r), axis=-1)\n    return foot_contacts\n"
  },
  {
    "path": "kimodo/motion_rep/reps/__init__.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Motion representation implementations: base, Kimodo, and TMR.\"\"\"\n\nfrom .base import MotionRepBase\nfrom .kimodo_motionrep import KimodoMotionRep\nfrom .tmr_motionrep import TMRMotionRep\n\n__all__ = [\n    \"MotionRepBase\",\n    \"KimodoMotionRep\",\n    \"TMRMotionRep\",\n]\n"
  },
  {
    "path": "kimodo/motion_rep/reps/base.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Base motion representation: feature layout, normalization, and conditioning helpers.\"\"\"\n\nimport os\nfrom typing import Optional\n\nimport einops\nimport numpy as np\nimport torch\nfrom einops import repeat\n\nfrom ...tools import ensure_batched\nfrom ..conditioning import build_condition_dicts\nfrom ..feature_utils import compute_vel_angle, compute_vel_xyz\nfrom ..stats import Stats\n\n\ndef _require_split_stats_layout(stats_path: str) -> None:\n    \"\"\"Raise if stats_path does not contain the required global_root, local_root, body subdirs.\"\"\"\n    subdirs = (\"global_root\", \"local_root\", \"body\")\n    missing = []\n    for name in subdirs:\n        subpath = os.path.join(stats_path, name)\n        mean_path = os.path.join(subpath, \"mean.npy\")\n        if not os.path.isfile(mean_path):\n            missing.append(f\"{subpath}/ (mean.npy)\")\n    if missing:\n        raise FileNotFoundError(\n            f\"Checkpoint stats must use the split layout with subfolders \"\n            f\"global_root/, local_root/, and body/ under '{stats_path}'. \"\n            f\"Missing or incomplete: {', '.join(missing)}. 
\"\n        )\n\n\nclass MotionRepBase:\n    \"\"\"Base class for motion representations used in generation and conditioning.\n\n    Subclasses define:\n    - ``size_dict``: feature blocks and their shapes,\n    - ``last_root_feature``: last entry of the root block,\n    - ``local_root_size_dict``: local-root feature layout,\n    and implement transform-specific methods such as ``__call__``, ``inverse``,\n    ``rotate``, ``translate_2d`` and ``create_conditions``.\n    \"\"\"\n\n    def __init__(\n        self,\n        skeleton,\n        fps,\n        stats_path: Optional[str] = None,\n    ):\n        \"\"\"Initialize feature slicing metadata and optional normalization stats.\"\"\"\n\n        self.skeleton = skeleton\n        self.fps = fps\n        self.nbjoints = skeleton.nbjoints\n\n        self.feature_names = list(self.size_dict.keys())\n        self.ps = list(self.size_dict.values())\n        self.nfeats_dict = {key: val.numel() for key, val in self.size_dict.items()}\n        feats_cumsum = np.cumsum([0] + list(self.nfeats_dict.values())).tolist()\n        self.slice_dict = {key: slice(feats_cumsum[i], feats_cumsum[i + 1]) for i, key in enumerate(self.feature_names)}\n\n        self.motion_rep_dim = sum(self.nfeats_dict.values())\n        self.root_slice = slice(0, self.slice_dict[self.last_root_feature].stop)\n        self.body_slice = slice(self.root_slice.stop, self.motion_rep_dim)\n        self.body_dim = self.body_slice.stop - self.body_slice.start\n        self.global_root_dim = self.root_slice.stop\n        self.local_root_dim = sum(val.numel() for val in self.local_root_size_dict.values())\n\n        if stats_path:\n            _require_split_stats_layout(stats_path)\n            self.global_root_stats = Stats(os.path.join(stats_path, \"global_root\"))\n            self.local_root_stats = Stats(os.path.join(stats_path, \"local_root\"))\n            self.body_stats = Stats(os.path.join(stats_path, \"body\"))\n\n            # Global stats\n            
mean = torch.cat([self.global_root_stats.mean, self.body_stats.mean])\n            std = torch.cat([self.global_root_stats.std, self.body_stats.std])\n            assert len(mean) == len(std) == self.motion_rep_dim, \"There is an stat issue.\"\n            self.stats = Stats()\n            self.stats.register_from_tensors(mean, std)\n\n    def get_root_pos(self, features: torch.Tensor, fallback_to_smooth: bool = True):\n        \"\"\"Extract root positions from a feature tensor.\n\n        Supports both ``root_pos`` and ``smooth_root_pos`` representations.\n        \"\"\"\n        if \"root_pos\" in self.slice_dict:\n            return features[..., self.slice_dict[\"root_pos\"]]\n\n        if \"smooth_root_pos\" not in self.slice_dict:\n            raise TypeError(\"This motion rep should have either a root_pos or smooth_root_pos field\")\n\n        if fallback_to_smooth:\n            return features[:, :, self.slice_dict[\"smooth_root_pos\"]]\n\n        # else compute the root pos from the smooth root and local joints offset\n        smooth_root_pos = features[:, :, self.slice_dict[\"smooth_root_pos\"]].clone()\n        local_joints_positions_flatten = features[..., self.slice_dict[\"local_joints_positions\"]]\n        hips_offset = local_joints_positions_flatten[..., self.skeleton.root_idx : self.skeleton.root_idx + 3]\n        root_pos = torch.stack(\n            [\n                smooth_root_pos[..., 0] + hips_offset[..., 0],\n                smooth_root_pos[..., 1],\n                smooth_root_pos[..., 2] + hips_offset[..., 2],\n            ],\n            axis=-1,\n        )\n        return root_pos\n\n    @ensure_batched(root_features=3, lengths=1)\n    def global_root_to_local_root(\n        self,\n        root_features: torch.Tensor,\n        normalized: bool,\n        lengths: Optional[torch.Tensor],\n    ):\n        \"\"\"Convert global root features to local-root motion features.\n\n        Args:\n            root_features: Root feature tensor 
containing root position and\n                global heading, shaped ``[B, T, D_root]``.\n            normalized: Whether ``root_features`` are normalized.\n            lengths: Optional valid lengths per sequence.\n\n        Returns:\n            Tensor ``[B, T, 4]`` with local root rotational velocity, planar\n            velocity, and global root height.\n        \"\"\"\n        if normalized:\n            root_features = self.global_root_stats.unnormalize(root_features)\n\n        [root_pos, global_root_heading] = einops.unpack(root_features, self.ps[:2], \"batch time *\")\n        cos, sin = global_root_heading.unbind(-1)\n        heading_angle = torch.arctan2(sin, cos)\n\n        local_root_rot_vel = compute_vel_angle(heading_angle, self.fps, lengths=lengths)\n        local_root_vel = compute_vel_xyz(\n            root_pos[..., None, :],\n            self.fps,\n            lengths=lengths,\n        )[..., 0, [0, 2]]\n        global_root_y = root_pos[..., 1]\n        local_root_motion = torch.cat(\n            [\n                local_root_rot_vel[..., None],\n                local_root_vel,\n                global_root_y[..., None],\n            ],\n            axis=-1,\n        )\n\n        if normalized:\n            local_root_motion = self.local_root_stats.normalize(local_root_motion)\n        return local_root_motion\n\n    def get_root_heading_angle(self, features: torch.Tensor) -> torch.Tensor:\n        \"\"\"Compute root heading angle from cosine/sine heading features.\"\"\"\n        global_root_heading = features[:, :, self.slice_dict[\"global_root_heading\"]]\n        cos, sin = global_root_heading.unbind(-1)\n        return torch.arctan2(sin, cos)\n\n    @ensure_batched(features=3)\n    def rotate_to(\n        self,\n        features: torch.Tensor,\n        target_angle: torch.Tensor,\n        return_delta_angle=False,\n    ):\n        \"\"\"Rotate each sequence so frame-0 heading matches ``target_angle``.\"\"\"\n        # rotate so that the first 
frame angle is the target\n        # it put the motion_rep to the angle\n        current_first_angle = self.get_root_heading_angle(features)[:, 0]\n        delta_angle = target_angle - current_first_angle\n        rotated_features = self.rotate(features, delta_angle)\n        if return_delta_angle:\n            return rotated_features, delta_angle\n        return rotated_features\n\n    @ensure_batched(features=3)\n    def rotate_to_zero(\n        self,\n        features: torch.Tensor,\n        return_delta_angle=False,\n    ):\n        \"\"\"Rotate each sequence so frame-0 heading becomes zero.\"\"\"\n        target_angle = torch.zeros(len(features), device=features.device)\n        return self.rotate_to(features, target_angle, return_delta_angle=return_delta_angle)\n\n    @ensure_batched(features=3)\n    def randomize_first_heading(\n        self,\n        features: torch.Tensor,\n        return_delta_angle=False,\n    ) -> torch.Tensor:\n        \"\"\"Rotate each sequence to a random frame-0 heading.\"\"\"\n        target_heading_angle = torch.rand(features.shape[0]) * 2 * np.pi\n        return self.rotate_to(\n            features,\n            target_heading_angle,\n            return_delta_angle=return_delta_angle,\n        )\n\n    @ensure_batched(features=3, target_2d_pos=2)\n    def translate_2d_to(\n        self,\n        features: torch.Tensor,\n        target_2d_pos: torch.Tensor,\n        return_delta_pos: bool = False,\n    ) -> torch.Tensor:\n        \"\"\"Translate each sequence so frame-0 root ``(x, z)`` matches a target.\"\"\"\n        root_pos = self.get_root_pos(features)\n        current_first_2d_pos = root_pos[:, 0, [0, 2]].clone()\n        delta_2d_pos = target_2d_pos - current_first_2d_pos\n        translated_features = self.translate_2d(features, delta_2d_pos)\n        if return_delta_pos:\n            return translated_features, delta_2d_pos\n        return translated_features\n\n    @ensure_batched(features=3)\n    def 
translate_2d_to_zero(\n        self,\n        features: torch.Tensor,\n        return_delta_pos: bool = False,\n    ) -> torch.Tensor:\n        \"\"\"Translate each sequence so frame-0 root ``(x, z)`` is at the origin.\"\"\"\n        target_2d_pos = torch.zeros(len(features), 2, device=features.device)\n        return self.translate_2d_to(features, target_2d_pos, return_delta_pos=return_delta_pos)\n\n    @ensure_batched(features=3)\n    def canonicalize(self, features: torch.Tensor, normalized: bool = False):\n        \"\"\"Canonicalize heading and planar position at frame 0.\"\"\"\n        if normalized:\n            features = self.unnormalize(features)\n        rotated_features = self.rotate_to_zero(features)\n        canonicalized_features = self.translate_2d_to_zero(rotated_features)\n        if normalized:\n            canonicalized_features = self.normalize(canonicalized_features)\n        return canonicalized_features\n\n    def normalize(self, features):\n        \"\"\"Normalize features.\"\"\"\n        return self.stats.normalize(features)\n\n    def unnormalize(self, features):\n        \"\"\"Undo feature normalization.\"\"\"\n        return self.stats.unnormalize(features)\n\n    def create_conditions_from_constraints(\n        self,\n        constraints_lst: list,\n        length: int,\n        to_normalize: bool,\n        device: str,\n    ):\n        \"\"\"Create a conditioning tensor and mask from constraint objects.\"\"\"\n        index_dict, data_dict = build_condition_dicts(constraints_lst)\n        return self.create_conditions(index_dict, data_dict, length, to_normalize, device)\n\n    def create_conditions_from_constraints_batched(\n        self,\n        constraints_lst: list | list[list],\n        lengths: torch.Tensor,\n        to_normalize: bool,\n        device: str,\n    ):\n        \"\"\"Batched version of ``create_conditions_from_constraints``.\n\n        Supports either one shared constraint list for all batch elements, or a 
per-sample list of\n        constraint lists.\n        \"\"\"\n        num_samples = len(lengths)\n        if not constraints_lst or not isinstance(constraints_lst[0], list):\n            # If no constraints, or constraints are shared across the batch,\n            # build once and repeat.\n            observed_motion, motion_mask = self.create_conditions_from_constraints(\n                constraints_lst, int(lengths.max()), to_normalize, device\n            )\n            observed_motion = repeat(observed_motion, \"t d -> b t d\", b=num_samples)\n            motion_mask = repeat(motion_mask, \"t d -> b t d\", b=num_samples)\n            return observed_motion, motion_mask\n\n        length = int(lengths.max())\n        observed_motion_lst = []\n        motion_mask_lst = []\n        for constraints_lst_el in constraints_lst:\n            observed_motion, motion_mask = self.create_conditions_from_constraints(\n                constraints_lst_el,\n                length,\n                to_normalize,\n                device,\n            )\n            observed_motion_lst.append(observed_motion)\n            motion_mask_lst.append(motion_mask)\n        observed_motion = torch.stack(observed_motion_lst, axis=0)\n        motion_mask = torch.stack(motion_mask_lst, axis=0)\n        return observed_motion, motion_mask\n"
  },
  {
    "path": "kimodo/motion_rep/reps/kimodo_motionrep.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\nfrom typing import Optional\n\nimport einops\nimport torch\nfrom torch import Tensor\n\nfrom kimodo.tools import to_numpy\n\nfrom ...geometry import cont6d_to_matrix, matrix_to_cont6d\nfrom ...skeleton.kinematics import fk\nfrom ...skeleton.transforms import global_rots_to_local_rots\nfrom ...tools import ensure_batched\nfrom ..conditioning import get_unique_index_and_data\nfrom ..feature_utils import RotateFeatures, compute_heading_angle, compute_vel_xyz\nfrom ..feet import foot_detect_from_pos_and_vel\nfrom ..smooth_root import get_smooth_root_pos\nfrom .base import MotionRepBase\n\n\nclass KimodoMotionRep(MotionRepBase):\n    \"\"\"Global root / global joints rotations representation, relative to a smooth root.\"\"\"\n\n    def __init__(\n        self,\n        skeleton,\n        fps,\n        stats_path: Optional[str] = None,\n    ):\n        nbjoints = skeleton.nbjoints\n\n        self.size_dict = {\n            \"smooth_root_pos\": torch.Size([3]),\n            \"global_root_heading\": torch.Size([2]),\n            \"local_joints_positions\": torch.Size([nbjoints, 3]),\n            \"global_rot_data\": torch.Size([nbjoints, 6]),\n            \"velocities\": torch.Size([nbjoints, 3]),\n            \"foot_contacts\": torch.Size([4]),\n        }\n        self.last_root_feature = \"global_root_heading\"\n        self.local_root_size_dict = {\n            \"local_root_rot_vel\": torch.Size([1]),\n            \"local_root_vel\": torch.Size([2]),\n            \"global_root_y\": torch.Size([1]),\n        }\n        super().__init__(skeleton, fps, stats_path)\n\n    @ensure_batched(local_joint_rots=5, root_positions=3, lengths=1)\n    def __call__(\n        self,\n        local_joint_rots: torch.Tensor,\n        root_positions: torch.Tensor,\n        to_normalize: bool,\n        to_canonicalize: bool = False,\n    
    lengths: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        \"\"\"Convert local rotations and root trajectory into smooth-root features.\n\n        Args:\n            local_joint_rots: Local joint rotation matrices ``[B, T, J, 3, 3]``.\n            root_positions: Root positions ``[B, T, 3]``.\n            to_normalize: Whether to normalize output features.\n            to_canonicalize: Whether to canonicalize output features (False by default).\n            lengths: Optional valid lengths for variable-length batches.\n\n        Returns:\n            Motion features with shape ``[B, T, motion_rep_dim]``.\n        \"\"\"\n        device = local_joint_rots.device\n        if lengths is None:\n            assert local_joint_rots.shape[0] == 1, \"If lengths is not provided, the input should not be batched.\"\n            lengths = torch.tensor([local_joint_rots.shape[1]], device=device)\n\n        (\n            global_joints_rots,\n            global_joints_positions,\n            local_joints_positions_origin_is_pelvis,\n        ) = fk(local_joint_rots, root_positions, self.skeleton)\n\n        root_heading_angle = compute_heading_angle(global_joints_positions, self.skeleton)\n        global_root_heading = torch.stack([torch.cos(root_heading_angle), torch.sin(root_heading_angle)], dim=-1)\n\n        smooth_root_pos = get_smooth_root_pos(root_positions)\n        hips_offset = root_positions - smooth_root_pos\n        hips_offset[..., 1] = root_positions[..., 1]\n        local_joints_positions = local_joints_positions_origin_is_pelvis + hips_offset[:, :, None]\n\n        velocities = compute_vel_xyz(global_joints_positions, self.fps, lengths=lengths)\n        foot_contacts = foot_detect_from_pos_and_vel(global_joints_positions, velocities, self.skeleton, 0.15, 0.10)\n        global_rot_data = matrix_to_cont6d(global_joints_rots)\n\n        features, _ = einops.pack(\n            [\n                smooth_root_pos,\n                
global_root_heading,\n                local_joints_positions,\n                global_rot_data,\n                velocities,\n                foot_contacts,\n            ],\n            \"batch time *\",\n        )\n\n        if to_canonicalize:\n            features = self.canonicalize(features, normalized=False)\n\n        if to_normalize:\n            features = self.normalize(features)\n        return features\n\n    @ensure_batched(features=3, angle=1)\n    def rotate(self, features: torch.Tensor, angle: torch.Tensor):\n        \"\"\"Rotate root/joint positional and rotational features by heading.\"\"\"\n        # assume it is not normalized\n        bs = features.shape[0]\n        device = features.device\n        [\n            smooth_root_pos,\n            global_root_heading,\n            local_joints_positions,\n            global_rot_data,\n            velocities,\n            foot_contacts,\n        ] = einops.unpack(features, self.ps, \"batch time *\")\n\n        if not isinstance(angle, torch.Tensor):\n            angle = torch.tensor(angle, device=device)\n        if len(angle.shape) == 0:\n            angle = angle.repeat(bs)\n\n        RF = RotateFeatures(angle)\n        new_features, _ = einops.pack(\n            [\n                RF.rotate_positions(smooth_root_pos),\n                RF.rotate_2d_positions(global_root_heading),\n                RF.rotate_positions(local_joints_positions),\n                RF.rotate_6d_rotations(global_rot_data),\n                RF.rotate_positions(velocities),\n                foot_contacts,\n            ],\n            \"batch time *\",\n        )\n        return new_features\n\n    @ensure_batched(features=3, translation_2d=2)\n    def translate_2d(\n        self,\n        features: torch.Tensor,\n        translation_2d: torch.Tensor,\n    ) -> torch.Tensor:\n        \"\"\"Translate smooth root planar position by ``(dx, dz)``.\"\"\"\n        # only move on the ground\n        # If we need a translate_3D 
function, we should not forget to move the local_joints_positions as well\n        bs = features.shape[0]\n        if len(translation_2d.shape) == 1:\n            translation_2d = translation_2d.repeat(bs, 1)\n\n        new_features = features.clone()\n        new_smooth_root_pos = new_features[:, :, self.slice_dict[\"smooth_root_pos\"]]\n        new_smooth_root_pos[:, :, 0] += translation_2d[:, [0]]\n        new_smooth_root_pos[:, :, 2] += translation_2d[:, [1]]\n        return new_features\n\n    @ensure_batched(features=3)\n    def inverse(\n        self,\n        features: torch.Tensor,\n        is_normalized: bool,\n        posed_joints_from=\"rotations\",\n        return_numpy: bool = False,\n    ) -> torch.Tensor:\n        \"\"\"Decode smooth-root features into motion tensors.\"\"\"\n        assert posed_joints_from in [\n            \"rotations\",\n            \"positions\",\n        ], \"posed_joints_from should be 'rotations' or 'positions'\"\n\n        if is_normalized:\n            features = self.unnormalize(features)\n\n        [\n            smooth_root_pos,\n            global_root_heading,\n            local_joints_positions,\n            global_rot_data,\n            velocities,\n            foot_contacts,\n        ] = einops.unpack(features, self.ps, \"batch time *\")\n\n        global_rot_mats = cont6d_to_matrix(global_rot_data)\n        local_rot_mats = global_rots_to_local_rots(global_rot_mats, self.skeleton)\n\n        posed_joints_from_pos = local_joints_positions.clone()\n        posed_joints_from_pos[..., 0] += smooth_root_pos[..., None, 0]\n        posed_joints_from_pos[..., 2] += smooth_root_pos[..., None, 2]\n        root_positions = posed_joints_from_pos[..., self.skeleton.root_idx, :]\n        foot_contacts = foot_contacts > 0.5\n\n        if posed_joints_from == \"rotations\":\n            _, posed_joints, _ = self.skeleton.fk(\n                local_rot_mats,\n                root_positions,\n            )\n        else:\n            
posed_joints = posed_joints_from_pos\n\n        output_tensor_dict = {\n            \"local_rot_mats\": local_rot_mats,\n            \"global_rot_mats\": global_rot_mats,\n            \"posed_joints\": posed_joints,\n            \"root_positions\": root_positions,\n            \"smooth_root_pos\": smooth_root_pos,\n            \"foot_contacts\": foot_contacts,\n            \"global_root_heading\": global_root_heading,\n        }\n        if return_numpy:\n            return to_numpy(output_tensor_dict)\n        return output_tensor_dict\n\n    def create_conditions(\n        self,\n        index_dict: dict[Tensor],\n        data_dict: dict[Tensor],\n        length: int,\n        to_normalize: bool,\n        device: str,\n    ):\n        \"\"\"Build sparse conditioning tensors for smooth-root representation.\"\"\"\n        # create empty features and mask to be filled in\n        observed_motion = torch.zeros(length, self.motion_rep_dim, device=device)\n        motion_mask = torch.zeros(length, self.motion_rep_dim, dtype=bool, device=device)\n\n        def _cat_indices(indices_list: list[Tensor]) -> Tensor:\n            indices = torch.cat([torch.tensor(x) if not isinstance(x, Tensor) else x for x in indices_list])\n            return indices.to(device=device, dtype=torch.long)\n\n        def _match_obs_dtype(tensor: Tensor) -> Tensor:\n            return tensor.to(device=device, dtype=observed_motion.dtype)\n\n        if (fname := \"smooth_root_2d\") in index_dict and index_dict[fname]:\n            indices = _cat_indices(index_dict[fname])\n            indices, smooth_root_2d = get_unique_index_and_data(indices, torch.cat(data_dict[fname]))\n            smooth_root_2d = _match_obs_dtype(smooth_root_2d)\n            f_sliced = observed_motion[:, self.slice_dict[\"smooth_root_pos\"]]\n            f_sliced[indices, 0] = smooth_root_2d[:, 0]\n            f_sliced[indices, 2] = smooth_root_2d[:, 1]\n            m_sliced = motion_mask[:, 
self.slice_dict[\"smooth_root_pos\"]]\n            m_sliced[indices, 0] = True\n            m_sliced[indices, 2] = True\n\n        if (fname := \"root_y_pos\") in index_dict and index_dict[fname]:\n            indices = _cat_indices(index_dict[fname])\n            indices, root_pos_Y = get_unique_index_and_data(indices, torch.cat(data_dict[fname]))\n            root_pos_Y = _match_obs_dtype(root_pos_Y)\n            f_sliced = observed_motion[:, self.slice_dict[\"smooth_root_pos\"]]\n            f_sliced[indices, 1] = root_pos_Y\n            m_sliced = motion_mask[:, self.slice_dict[\"smooth_root_pos\"]]\n            m_sliced[indices, 1] = True\n\n        if (fname := \"global_root_heading\") in index_dict and index_dict[fname]:\n            indices = _cat_indices(index_dict[fname])\n            indices, global_root_heading = get_unique_index_and_data(indices, torch.cat(data_dict[fname]))\n            global_root_heading = _match_obs_dtype(global_root_heading)\n            f_sliced = observed_motion[:, self.slice_dict[fname]]\n            f_sliced[indices] = global_root_heading\n            m_sliced = motion_mask[:, self.slice_dict[fname]]\n            m_sliced[indices] = True\n\n        if (fname := \"global_joints_rots\") in index_dict and index_dict[fname]:\n            indices_lst = _cat_indices(index_dict[fname])\n            indices_lst, global_joints_rots = get_unique_index_and_data(indices_lst, torch.cat(data_dict[fname]))\n            global_joints_rots = _match_obs_dtype(global_joints_rots)\n            global_rot_data = matrix_to_cont6d(global_joints_rots)\n            f_sliced = observed_motion[:, self.slice_dict[\"global_rot_data\"]]\n            masking = torch.zeros(len(f_sliced) * self.nbjoints, 6, device=device, dtype=bool)\n            masking[indices_lst.T[0] * self.nbjoints + indices_lst.T[1]] = True\n            masking = masking.reshape(len(f_sliced), self.nbjoints * 6)\n            f_sliced[masking] = global_rot_data.flatten()\n            
m_sliced = motion_mask[:, self.slice_dict[\"global_rot_data\"]]\n            m_sliced[masking] = True\n\n        if (fname := \"global_joints_positions\") in index_dict and index_dict[fname]:\n            indices_lst = _cat_indices(index_dict[fname])\n            indices_lst, global_joints_positions = get_unique_index_and_data(indices_lst, torch.cat(data_dict[fname]))\n            global_joints_positions = _match_obs_dtype(global_joints_positions)\n            T_indices = indices_lst[:, 0].contiguous()\n            _test = motion_mask[T_indices, self.slice_dict[\"smooth_root_pos\"]]\n            if not _test[:, [0, 2]].all():\n                raise ValueError(\"For constraining global positions, the smooth root should also be constrained.\")\n            smooth_root_pos = observed_motion[T_indices, self.slice_dict[\"smooth_root_pos\"]].clone()\n            local_reference = smooth_root_pos.clone()\n            local_reference[..., 1] = 0.0\n            local_joints_positions = global_joints_positions - local_reference\n            f_sliced = observed_motion[:, self.slice_dict[\"local_joints_positions\"]]\n            masking = torch.zeros(len(f_sliced) * self.nbjoints, 3, device=device, dtype=bool)\n            masking[indices_lst.T[0] * self.nbjoints + indices_lst.T[1]] = True\n            masking = masking.reshape(len(f_sliced), self.nbjoints * 3)\n            f_sliced[masking] = local_joints_positions.flatten()\n            m_sliced = motion_mask[:, self.slice_dict[\"local_joints_positions\"]]\n            m_sliced[masking] = True\n\n        if to_normalize:\n            observed_motion = self.normalize(observed_motion)\n        return observed_motion, motion_mask\n"
  },
  {
    "path": "kimodo/motion_rep/reps/tmr_motionrep.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"TMR motion representation: global root, global joints, velocities, and foot contacts.\"\"\"\n\nfrom typing import Optional\n\nimport einops\nimport torch\n\nfrom ...skeleton.kinematics import fk\nfrom ...tools import ensure_batched, to_numpy\nfrom ..feature_utils import RotateFeatures, compute_heading_angle, compute_vel_xyz\nfrom ..feet import foot_detect_from_pos_and_vel\nfrom .base import MotionRepBase\n\n\nclass TMRMotionRep(MotionRepBase):\n    \"\"\"Motion representation with global root and local joint positions.\n    The local joint positions are rotation invariant (they all face z+)\n\n    Feature layout:\n    - root position ``(x, y, z)``\n    - root heading as ``(cos(theta), sin(theta))``\n    - local joint positions (root and rotation removed)\n    - local joint velocities (rotation removed)\n    - binary foot contacts\n    \"\"\"\n\n    def __init__(\n        self,\n        skeleton,\n        fps,\n        stats_path: Optional[str] = None,\n    ):\n        nbjoints = skeleton.nbjoints\n\n        self.size_dict = {\n            \"root_pos\": torch.Size([3]),\n            \"global_root_heading\": torch.Size([2]),\n            \"local_joints_positions\": torch.Size([nbjoints - 1, 3]),\n            \"velocities\": torch.Size([nbjoints, 3]),\n            \"foot_contacts\": torch.Size([4]),\n        }\n        self.last_root_feature = \"global_root_heading\"\n        self.local_root_size_dict = {\n            \"local_root_rot_vel\": torch.Size([1]),\n            \"local_root_vel\": torch.Size([2]),\n            \"global_root_y\": torch.Size([1]),\n        }\n        super().__init__(skeleton, fps, stats_path)\n\n    @ensure_batched(local_joint_rots=5, root_positions=3, posed_joints=4, lengths=1)\n    def __call__(\n        self,\n        local_joint_rots: Optional[torch.Tensor] = None,\n        
root_positions: Optional[torch.Tensor] = None,\n        posed_joints: Optional[torch.Tensor] = None,\n        *,\n        to_normalize: bool,\n        to_canonicalize: bool = False,\n        lengths: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        \"\"\"Convert motion inputs to this feature representation.\n\n        Args:\n            local_joint_rots: Local joint rotation matrices ``[B, T, J, 3, 3]``.\n                Required when ``posed_joints`` is not provided.\n            root_positions: Root translations ``[B, T, 3]``. Required when\n                ``posed_joints`` is not provided.\n            posed_joints: Optional precomputed global joint positions\n                ``[B, T, J, 3]``. If passed, FK is skipped.\n            to_normalize: Whether to normalize output features.\n            to_canonicalize: Whether to canonicalize output features (False by default).\n            lengths: Optional valid lengths for variable-length batches.\n\n        Returns:\n            Motion features with shape ``[B, T, motion_rep_dim]``.\n        \"\"\"\n        if posed_joints is not None:\n            device = posed_joints.device\n            nbatch, nbframes, nbjoints = posed_joints.shape[:3]\n        else:\n            device = local_joint_rots.device\n            nbatch, nbframes, nbjoints = local_joint_rots.shape[:3]\n\n        if lengths is None:\n            assert nbatch == 1, \"If lengths is not provided, the input should not be batched.\"\n            lengths = torch.tensor([nbframes], device=device)\n\n        if posed_joints is None:\n            _, global_positions, local_joints_positions_origin_is_pelvis = fk(\n                local_joint_rots, root_positions, self.skeleton\n            )\n        else:\n            global_positions = posed_joints\n            root_positions = posed_joints[:, :, 0]\n            local_joints_positions_origin_is_pelvis = posed_joints - root_positions[:, :, None]\n\n        root_heading_angle = 
compute_heading_angle(global_positions, self.skeleton)\n        global_root_heading = torch.stack([torch.cos(root_heading_angle), torch.sin(root_heading_angle)], dim=-1)\n\n        ground_offset = 0 * root_positions\n        ground_offset[..., 1] = root_positions[..., 1]\n\n        local_joints_positions = local_joints_positions_origin_is_pelvis[:, :, 1:] + ground_offset[:, :, None]\n        velocities = compute_vel_xyz(global_positions, self.fps, lengths=lengths)\n\n        # Remove the heading angle for each frame\n        RF = RotateFeatures(-root_heading_angle)\n        local_joints_positions = RF.rotate_positions(local_joints_positions)\n        velocities = RF.rotate_positions(velocities)\n\n        foot_contacts = foot_detect_from_pos_and_vel(global_positions, velocities, self.skeleton, 0.15, 0.10)\n        features, _ = einops.pack(\n            [\n                root_positions,\n                global_root_heading,\n                local_joints_positions,\n                velocities,\n                foot_contacts,\n            ],\n            \"batch time *\",\n        )\n\n        if to_canonicalize:\n            features = self.canonicalize(features, normalized=False)\n\n        if to_normalize:\n            features = self.normalize(features)\n        return features\n\n    @ensure_batched(features=3, angle=1)\n    def rotate(self, features: torch.Tensor, angle: torch.Tensor):\n        \"\"\"Rotate all spatial features by a heading delta (radians).\"\"\"\n        # rotate by the angle\n        # it add the angle to the current features\n        # assume it is not normalized\n        bs = features.shape[0]\n        device = features.device\n        [\n            root_pos,\n            global_root_heading,\n            local_joints_positions,\n            velocities,\n            foot_contacts,\n        ] = einops.unpack(features, self.ps, \"batch time *\")\n\n        if not isinstance(angle, torch.Tensor):\n            angle = torch.tensor(angle, 
device=device)\n        if len(angle.shape) == 0:\n            angle = angle.repeat(bs)\n\n        RF = RotateFeatures(angle)\n        new_features, _ = einops.pack(\n            [\n                RF.rotate_positions(root_pos),\n                RF.rotate_2d_positions(global_root_heading),\n                local_joints_positions,  # already rotation invariant\n                velocities,  # already rotation invariant\n                foot_contacts,\n            ],\n            \"batch time *\",\n        )\n        return new_features\n\n    @ensure_batched(features=3, translation_2d=2)\n    def translate_2d(\n        self,\n        features: torch.Tensor,\n        translation_2d: torch.Tensor,\n    ) -> torch.Tensor:\n        \"\"\"Translate root planar position by ``(dx, dz)``.\"\"\"\n        # only move on the ground\n        # For 3D, we should not forget to move the local_joints_positions as well\n        bs = features.shape[0]\n        if len(translation_2d.shape) == 1:\n            translation_2d = translation_2d.repeat(bs, 1)\n\n        new_features = features.clone()\n        new_root_pos = new_features[:, :, self.slice_dict[\"root_pos\"]]\n        new_root_pos[:, :, 0] += translation_2d[:, 0]\n        new_root_pos[:, :, 2] += translation_2d[:, 1]\n        return new_features\n\n    @ensure_batched(features=3)\n    def inverse(\n        self,\n        features: torch.Tensor,\n        is_normalized: bool,\n        posed_joints_from=\"positions\",\n        return_numpy: bool = False,\n    ) -> torch.Tensor:\n        \"\"\"Decode features back to a motion dictionary.\n\n        Args:\n            features: Feature tensor ``[B, T, D]``.\n            is_normalized: Whether input features are normalized.\n            posed_joints_from: Must be ``\"positions\"`` for this representation.\n            return_numpy: Whether to convert tensors to numpy arrays.\n\n        Returns:\n            Dictionary containing reconstructed positions and auxiliary data.\n        
\"\"\"\n        assert posed_joints_from == \"positions\"\n        if is_normalized:\n            features = self.unnormalize(features)\n\n        [\n            root_positions,\n            global_root_heading,\n            local_joints_positions,\n            velocities,\n            foot_contacts,\n        ] = einops.unpack(features, self.ps, \"batch time *\")\n\n        dummy_root = 0 * local_joints_positions[:, :, [0]]\n        posed_joints_from_pos = torch.stack([dummy_root, local_joints_positions], axis=2)\n        posed_joints_from_pos[..., 0] += root_positions[..., None, 0]\n        posed_joints_from_pos[..., 2] += root_positions[..., None, 2]\n        root_positions = posed_joints_from_pos[..., self.skeleton.root_idx, :]\n        foot_contacts = foot_contacts > 0.5\n        posed_joints = posed_joints_from_pos\n\n        output_tensor_dict = {\n            \"local_rot_mats\": None,\n            \"global_rot_mats\": None,\n            \"posed_joints\": posed_joints,\n            \"root_positions\": root_positions,\n            \"foot_contacts\": foot_contacts,\n            \"global_root_heading\": global_root_heading,\n        }\n        if return_numpy:\n            return to_numpy(output_tensor_dict)\n        return output_tensor_dict\n"
  },
  {
    "path": "kimodo/motion_rep/smooth_root.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Smooth root trajectory: ADMM-based smoother with margin constraints and get_smooth_root_pos helper.\"\"\"\n\nimport math\n\nimport numpy as np\nimport torch\nfrom scipy import sparse\nfrom scipy.sparse.linalg import splu\n\nfrom kimodo.tools import ensure_batched\n\n\nclass TrajectorySmoother:\n    \"\"\"Modify trajectories to hit target values while respecting soft constraints.\n\n    This smoother keeps the trajectory close to the original positions while minimizing\n    accelerations. Targets are enforced at specified frames via soft constraints.\n    \"\"\"\n\n    def __init__(\n        self,\n        margins,\n        pos_weight=0.0,\n        loop=False,\n        admm_iters=100,\n        alpha_overrelax=1.0,\n        circle_project=False,\n    ):\n        \"\"\"Initialize the TrajectorySmoother.\n\n        Args:\n            margins: Array of margin values for each frame.\n                    margins[i] < 0: unconstrained\n                    margins[i] == 0: pinned on this frame\n                    margins[i] > 0: can deviate within the margin\n            pos_weight: Weight for position preservation\n            loop: Whether the trajectory should loop\n            admm_iters: Number of ADMM iterations\n        \"\"\"\n        self.pos_weight = pos_weight\n        self.admm_iters = admm_iters\n        self.alpha_overrelax = alpha_overrelax\n        self.circle_project = circle_project\n        N = len(margins)\n\n        # Store margin information as numpy arrays\n        self.margin_vals = margins\n\n        # Build acceleration matrix A\n        a_data = []\n        a_rows = []\n        a_cols = []\n\n        for i in range(1, N - 1):\n            scale = 1.0\n            a_data.extend([-scale, 2.0 * scale, -scale])\n            a_rows.extend([i, i, i])\n            a_cols.extend([i - 1, i, i + 
1])\n\n        if loop:\n            # Add periodic accelerations\n            scale = 1.0\n            a_data.extend([-scale, 2.0 * scale, -scale])\n            a_rows.extend([0, 0, 0])\n            a_cols.extend([N - 1, 0, 1])\n\n            scale = 1.0\n            a_data.extend([-scale, 2.0 * scale, -scale])\n            a_rows.extend([N - 1, N - 1, N - 1])\n            a_cols.extend([N - 2, N - 1, 0])\n\n        A = sparse.csr_matrix((a_data, (a_rows, a_cols)), shape=(N, N))\n\n        # Build identity matrix\n        identity_matrix = sparse.eye(N)\n\n        # Build system matrix M\n        M = pos_weight * identity_matrix + A.T @ A\n\n        # Calculate ADMM step size\n        diag_max = max(abs(M.diagonal()))\n        self.admm_stepsize = 0.25 * np.sqrt(diag_max)\n\n        M = M + self.admm_stepsize * identity_matrix\n        self.system_lu = splu(M.tocsc())\n\n    def smooth(self, targets, x0):\n        \"\"\"Interpolate between reference positions while satisfying constraints.\n\n        Args:\n            targets: Target positions for constrained frames (numpy array)\n            x0: Initial guess for the trajectory, used as the starting\n                iterate (numpy array)\n\n        Returns:\n            Interpolated positions (numpy array)\n        \"\"\"\n        x_target = targets.copy()\n        x = x0.copy()\n        z = np.zeros_like(x)\n        u = np.zeros_like(x)\n\n        for _ in range(self.admm_iters):\n            self.z_update(z, x, x_target, u)\n            self.u_update(u, x, z)\n            self.x_update(x, z, u, x_target)\n\n        return x\n\n    def x_update(self, x, z, u, x_t):\n        \"\"\"Update x in the ADMM iteration.\"\"\"\n\n        # x = (wp * I + A^T A + p I)^-1 (wp * x_t + p (z - u))\n        r = self.pos_weight * x_t + self.admm_stepsize * (z - u)\n        x[:] = self.system_lu.solve(r)\n\n    def z_update(self, z, x, z_t, u):\n        \"\"\"Update z in the ADMM iteration using vectorized 
operations.\"\"\"\n        # Compute the difference from target for all margin locations at once\n        z[:] = x + u - z_t\n\n        # Check if we need to project back to margin\n        z_diff_norms = np.linalg.norm(z, axis=1)\n        mask = z_diff_norms > self.margin_vals\n        if np.any(mask):\n            scale_factors = self.margin_vals[mask] / z_diff_norms[mask]\n            z[mask] *= scale_factors[:, np.newaxis]\n\n        # Add back the target\n        z[:] += z_t\n\n        if self.circle_project:\n            z[:] = z / (np.linalg.norm(z, axis=1, keepdims=True) + 1.0e-6)\n\n    def u_update(self, u, x, z):\n        \"\"\"Update u in the ADMM iteration using vectorized operations.\"\"\"\n        u[:] += self.alpha_overrelax * (x - z)\n\n\ndef smooth_signal(x, margins, pos_weight=0, alpha_overrelax=1.8, admm_iters=500, circle_project=False):\n    \"\"\"Multigrid trajectory smoothing with margin constraints.\n\n    Args:\n        x: Input trajectory ``[T, D]`` as a NumPy array.\n        margins: Allowed radius around each target frame ``[T]``.\n        pos_weight: Weight for staying close to the original signal.\n        alpha_overrelax: ADMM over-relaxation coefficient.\n        admm_iters: ADMM iterations per multigrid level.\n        circle_project: If ``True``, project each vector to the unit sphere.\n\n    Returns:\n        Smoothed trajectory of shape ``[T, D]``.\n    \"\"\"\n    x_smoothed = x.copy()\n    x_smoothed[:] = x.mean(axis=0, keepdims=True)\n\n    # smooth the signal, multigrid style by starting out coarse,\n    # doubling the resolution and repeating until we're at the full\n    # resolution, using the previous result as the initial guess.\n    levels = int(math.floor(math.log2(len(x))))\n    levels = max(levels - 4, 1)\n\n    stepsize = 2**levels\n    while True:\n        # smooth signals at this level:\n        num_steps = len(x_smoothed[::stepsize])\n        smoother = TrajectorySmoother(\n            
margins=margins[::stepsize],\n            pos_weight=pos_weight,\n            alpha_overrelax=alpha_overrelax,\n            admm_iters=admm_iters,\n            circle_project=circle_project,\n        )\n        x_smoothed[::stepsize] = smoother.smooth(x[::stepsize], x_smoothed[::stepsize])\n\n        # interpolate to next level:\n        next_stepsize = stepsize // 2\n        num_interleaved = len(x_smoothed[next_stepsize::stepsize])\n        if num_interleaved == num_steps:\n            # linearly extrapolate the last value if we have to:\n            x_smoothed[next_stepsize::stepsize][-1] = (\n                x_smoothed[::stepsize][-1] + (x_smoothed[::stepsize][-1] - x_smoothed[::stepsize][-2]) / 2\n            )\n            num_interleaved = num_interleaved - 1\n\n        # linearly interpolate the remaining values:\n        x_smoothed[next_stepsize::stepsize][:num_interleaved] = (\n            x_smoothed[::stepsize][:-1] + x_smoothed[::stepsize][1:]\n        ) / 2\n\n        if stepsize == 1:\n            break\n\n        stepsize //= 2\n\n    return x_smoothed\n\n\n@ensure_batched(hip_translations=3)\ndef get_smooth_root_pos(hip_translations):\n    \"\"\"Smooth root trajectory in the ground plane while preserving height.\n\n    Args:\n        hip_translations: Root translations ``[B, T, 3]``.\n\n    Returns:\n        Smoothed root translations ``[B, T, 3]`` where ``x/z`` are smoothed and\n        ``y`` remains unchanged.\n    \"\"\"\n    root_translations_xz = hip_translations[..., [0, 2]]\n    root_translations_y = hip_translations[..., [1]]\n\n    batch_size, nframes = root_translations_xz.shape[:2]\n    margins = np.full(root_translations_xz.shape[1], 0.06)\n\n    root_translations_smoothed_xz = []\n    for batch in range(batch_size):\n        root_translations_smoothed_xz.append(\n            smooth_signal(root_translations_xz[batch].detach().cpu().numpy(), margins)[None]\n        )\n\n    root_translations_smoothed_xz = 
torch.tensor(np.concatenate(root_translations_smoothed_xz))\n\n    root_translations = torch.cat(\n        [\n            root_translations_smoothed_xz.to(root_translations_y.device),\n            root_translations_y,\n        ],\n        dim=-1,\n    )[..., [0, 2, 1]]\n\n    return root_translations\n"
  },
  {
    "path": "kimodo/motion_rep/stats.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Feature normalization statistics (mean/std) for motion representations.\"\"\"\n\nimport logging\nimport os\nfrom typing import Optional\n\nimport numpy as np\nimport torch\n\nlog = logging.getLogger(__name__)\n\n\nclass Stats(torch.nn.Module):\n    \"\"\"Utility module for feature normalization statistics.\n\n    Normalization follows:\n    ``(data - mean) / sqrt(std**2 + eps)``\n    \"\"\"\n\n    def __init__(\n        self,\n        folder: Optional[str] = None,\n        load: bool = True,\n        eps=1e-05,\n    ):\n        super().__init__()\n        self.folder = folder\n        self.eps = eps\n        if folder is not None and load:\n            self.load()\n\n    def sliced(self, indices):\n        \"\"\"Return a new ``Stats`` object containing selected feature indices.\"\"\"\n        new_stats = Stats(folder=self.folder, load=False, eps=self.eps)\n        new_stats.register_from_tensors(\n            self.mean[..., indices].clone(),\n            self.std[..., indices].clone(),\n        )\n        return new_stats\n\n    def load(self):\n        \"\"\"Load ``mean.npy`` and ``std.npy`` from ``self.folder``.\"\"\"\n        mean_path = os.path.join(self.folder, \"mean.npy\")\n        std_path = os.path.join(self.folder, \"std.npy\")\n        if not os.path.exists(mean_path) or not os.path.exists(std_path):\n            raise FileNotFoundError(\n                f\"Missing stats files in '{self.folder}'. 
Expected:\\n\"\n                f\"  - {mean_path}\\n\"\n                f\"  - {std_path}\\n\\n\"\n                \"Make sure the checkpoint/stats have been downloaded and are mounted into the container.\\n\"\n                \"If you're using Docker Compose, run it from the repo root so `./:/workspace` mounts the correct directory.\"\n            )\n\n        mean = torch.from_numpy(np.load(mean_path))\n        std = torch.from_numpy(np.load(std_path))\n        self.register_from_tensors(mean, std)\n\n    def register_from_tensors(self, mean: torch.Tensor, std: torch.Tensor):\n        \"\"\"Register mean/std tensors as non-persistent buffers.\"\"\"\n        self.register_buffer(\"mean\", mean, persistent=False)\n        self.register_buffer(\"std\", std, persistent=False)\n\n    def normalize(self, data: torch.Tensor) -> torch.Tensor:\n        \"\"\"Normalize data using the stored statistics.\"\"\"\n        mean = self.mean.to(device=data.device, dtype=data.dtype)\n        std = self.std.to(device=data.device, dtype=data.dtype)\n        # adjust std with eps\n        return (data - mean) / torch.sqrt(std**2 + self.eps)\n\n    def unnormalize(self, data: torch.Tensor) -> torch.Tensor:\n        \"\"\"Undo normalization using the stored statistics.\"\"\"\n        mean = self.mean.to(device=data.device, dtype=data.dtype)\n        std = self.std.to(device=data.device, dtype=data.dtype)\n        # adjust std with eps\n        return data * torch.sqrt(std**2 + self.eps) + mean\n\n    def is_loaded(self):\n        \"\"\"Return whether statistics are currently available.\"\"\"\n        return hasattr(self, \"mean\")\n\n    def get_dim(self):\n        \"\"\"Return feature dimensionality.\"\"\"\n        return self.mean.shape[0]\n\n    def save(\n        self,\n        folder: Optional[str] = None,\n        mean: Optional[torch.Tensor] = None,\n        std: Optional[torch.Tensor] = None,\n    ):\n        \"\"\"Save statistics to ``folder`` as ``mean.npy`` and 
``std.npy``.\"\"\"\n        if folder is None:\n            folder = self.folder\n            if folder is None:\n                raise ValueError(\"No folder to save stats\")\n\n        if mean is None and std is None:\n            try:\n                mean = self.mean.cpu().numpy()\n                std = self.std.cpu().numpy()\n            except AttributeError:\n                raise ValueError(\"Stats were not loaded\")\n\n        # don't override stats folder\n        os.makedirs(folder, exist_ok=False)\n\n        np.save(os.path.join(folder, \"mean.npy\"), mean)\n        np.save(os.path.join(folder, \"std.npy\"), std)\n\n    def __eq__(self, other):\n        return (self.mean.cpu() == other.mean.cpu()).all() and (self.std.cpu() == other.std.cpu()).all()\n\n    # should define a hash value for pytorch, as we defined __eq__\n    def __hash__(self):\n        # Convert mean and std to bytes for a consistent hash value\n        mean_hash = hash(self.mean.detach().cpu().numpy().tobytes())\n        std_hash = hash(self.std.detach().cpu().numpy().tobytes())\n        return hash((mean_hash, std_hash))\n\n    def __repr__(self):\n        return f'Stats(folder=\"{self.folder}\")'\n"
  },
  {
    "path": "kimodo/postprocess.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Post-processing utilities for motion generation output.\"\"\"\n\nfrom types import SimpleNamespace\nfrom typing import Dict, List, Optional, Tuple\n\nimport numpy as np\nimport torch\n\nfrom .constraints import (\n    EndEffectorConstraintSet,\n    FullBodyConstraintSet,\n    Root2DConstraintSet,\n)\nfrom .geometry import matrix_to_quaternion, quaternion_to_matrix\nfrom .skeleton import (\n    G1Skeleton34,\n    SkeletonBase,\n    SMPLXSkeleton22,\n    SOMASkeleton30,\n    SOMASkeleton77,\n    fk,\n)\n\n\ndef extract_input_motion_from_constraints(\n    constraint_lst: List,\n    skeleton: SkeletonBase,\n    num_frames: int,\n    num_joints: int,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"Extract hip translations and local rotations from constraints for postprocessing.\n\n    Args:\n        constraint_lst: List of constraints (FullBodyConstraintSet, EndEffectorConstraintSet, etc.)\n        skeleton: Skeleton instance\n        num_frames: Total number of frames in the motion\n        num_joints: Number of joints\n\n    Returns:\n        Tuple of (hip_translations_input, rotations_input):\n            - hip_translations_input: Hip translations, shape (T, 3)\n            - rotations_input: Local joint rotations as quaternions, shape (T, J, 4)\n    \"\"\"\n    # Initialize with zeros for all frames\n    hip_translations_input = torch.zeros(num_frames, 3)\n    rotations_input = torch.zeros(num_frames, num_joints, 4)\n    rotations_input[..., 0] = 1.0  # Initialize as identity quaternions (w=1, x=y=z=0)\n\n    def _match_hip_dtype(tensor: torch.Tensor) -> torch.Tensor:\n        return tensor.to(device=hip_translations_input.device, dtype=hip_translations_input.dtype)\n\n    def _match_rot_dtype(tensor: torch.Tensor) -> torch.Tensor:\n        return tensor.to(device=rotations_input.device, 
dtype=rotations_input.dtype)\n\n    if not constraint_lst:\n        return hip_translations_input, rotations_input\n\n    # Sort constraints to ensure FullBodyConstraintSet is processed last\n    #   This ensures it will get the last say on whether hip translations need to be exact root or smoothed root\n    sorted_constraints = sorted(constraint_lst, key=lambda c: isinstance(c, FullBodyConstraintSet))\n    for constraint in sorted_constraints:\n        frame_indices = constraint.frame_indices\n        if isinstance(frame_indices, torch.Tensor):\n            valid_mask = frame_indices < num_frames\n            if valid_mask.sum() == 0:\n                continue\n            frame_indices = frame_indices[valid_mask]\n        else:\n            valid_positions = [i for i, idx in enumerate(frame_indices) if idx < num_frames]\n            if not valid_positions:\n                continue\n            frame_indices = [frame_indices[i] for i in valid_positions]\n\n        # Handle Root2DConstraintSet separately - only assign smooth_root_2d at xz dimensions\n        if isinstance(constraint, Root2DConstraintSet):\n            smooth_root_2d = constraint.smooth_root_2d  # (K, 2) where K = len(frame_indices)\n            if isinstance(frame_indices, torch.Tensor):\n                smooth_root_2d = smooth_root_2d[valid_mask]\n            else:\n                smooth_root_2d = smooth_root_2d[valid_positions]\n            smooth_root_2d = _match_hip_dtype(smooth_root_2d)\n            hip_translations_input[frame_indices, 0] = smooth_root_2d[:, 0]  # x\n            hip_translations_input[frame_indices, 2] = smooth_root_2d[:, 1]  # z\n            continue\n        elif isinstance(constraint, FullBodyConstraintSet) or isinstance(constraint, EndEffectorConstraintSet):\n            global_rots = constraint.global_joints_rots  # (K, J, 3, 3) where K = len(frame_indices)\n            global_positions = constraint.global_joints_positions  # (K, J, 3)\n            if 
isinstance(frame_indices, torch.Tensor):\n                global_rots = global_rots[valid_mask]\n                global_positions = global_positions[valid_mask]\n                smooth_root_2d = constraint.smooth_root_2d[valid_mask]\n            else:\n                global_rots = global_rots[valid_positions]\n                global_positions = global_positions[valid_positions]\n                smooth_root_2d = constraint.smooth_root_2d[valid_positions]\n\n            root_positions = global_positions[:, skeleton.root_idx]  # (K, 3)\n            # replace xz with smooth_root_2d values for EE constraints that do not include Hips\n            #    since the hips themselves are not actually constrained in the model conditioning\n            if isinstance(constraint, EndEffectorConstraintSet) and \"Hips\" not in constraint.joint_names:\n                root_positions[:, 0] = smooth_root_2d[:, 0]  # x\n                root_positions[:, 2] = smooth_root_2d[:, 1]  # z\n\n            local_rot_mats = skeleton.global_rots_to_local_rots(global_rots)  # (K, J, 3, 3)\n            local_rot_quats = matrix_to_quaternion(local_rot_mats)  # (K, J, 4)\n\n            hip_translations_input[frame_indices] = _match_hip_dtype(root_positions)\n            rotations_input[frame_indices] = _match_rot_dtype(local_rot_quats)\n        else:\n            NotImplementedError(f\"Constraint {constraint.name} is not supported\")\n\n    return hip_translations_input, rotations_input\n\n\ndef create_working_rig_from_skeleton(\n    skeleton: SkeletonBase, above_ground_offset: float = 0.007\n) -> List[SimpleNamespace]:\n    \"\"\"Create the working rig as a list of SimpleNamespace objects from skeleton.\n\n    Args:\n        skeleton: SkeletonBase instance with bone_order_names, neutral_joints, joint_parents\n        above_ground_offset: Additional offset to position the rig slightly above ground\n    Returns:\n        List of SimpleNamespace objects representing the working rig\n    \"\"\"\n    
working_rig_joints = []\n\n    joint_names = skeleton.bone_order_names\n    neutral_positions = skeleton.neutral_joints.cpu().numpy()\n    parent_indices = skeleton.joint_parents.cpu().numpy()\n\n    if isinstance(skeleton, (G1Skeleton34, SMPLXSkeleton22)):\n        retarget_map = {\n            skeleton.bone_order_names[skeleton.root_idx]: \"Hips\",\n            skeleton.left_hand_joint_names[0]: \"LeftHand\",\n            skeleton.right_hand_joint_names[0]: \"RightHand\",\n            skeleton.left_foot_joint_names[0]: \"LeftFoot\",\n            skeleton.right_foot_joint_names[0]: \"RightFoot\",\n        }\n    else:\n        # works for SOMA\n        retarget_map = {\n            \"Hips\": \"Hips\",\n            \"Head\": \"Head\",\n            \"LeftHand\": \"LeftHand\",\n            \"RightHand\": \"RightHand\",\n            \"LeftFoot\": \"LeftFoot\",\n            \"RightFoot\": \"RightFoot\",\n        }\n\n    for i, joint_name in enumerate(joint_names):\n        parent_name = None if parent_indices[i] == -1 else joint_names[parent_indices[i]]\n\n        # Calculate local translation relative to parent\n        if parent_indices[i] == -1:\n            # Move the rig so that the lowest point (toe) is at ground level (y=0),\n            # plus a small offset to position the rig slightly above ground\n            toe_height = neutral_positions[:, 1].min()  # lowest y-coordinate (toe)\n            local_translation = (\n                neutral_positions[i] + np.array([0.0, -toe_height + above_ground_offset, 0.0])\n            ).tolist()\n        else:\n            parent_idx = parent_indices[i]\n            parent_position = neutral_positions[parent_idx]\n            joint_position = neutral_positions[i]\n            local_translation = (joint_position - parent_position).tolist()\n\n        # Default rotation (identity quaternion: x=0, y=0, z=0, w=1)\n        default_rotation = [0.0, 0.0, 0.0, 1.0]\n\n        joint_info = SimpleNamespace(\n            
name=joint_name,\n            parent=parent_name,\n            t_pose_rotation=default_rotation,\n            t_pose_translation=local_translation,\n            retarget_tag=retarget_map.get(joint_name),\n        )\n\n        working_rig_joints.append(joint_info)\n\n    return working_rig_joints\n\n\ndef post_process_motion(\n    local_rot_mats: torch.Tensor,\n    root_positions: torch.Tensor,\n    contacts: torch.Tensor,\n    skeleton: SkeletonBase,\n    constraint_lst: Optional[List] = None,\n    contact_threshold: float = 0.5,\n    root_margin: float = 0.04,\n) -> Dict[str, torch.Tensor]:\n    \"\"\"Post-process generated motion to reduce foot skating and improve quality.\n\n    Args:\n        local_rot_mats: Local joint rotation matrices, shape (B, T, J, 3, 3)\n        root_positions: Root joint positions, shape (B, T, 3)\n        contacts: Foot contact labels, shape (B, T, num_contacts)\n        skeleton: Skeleton instance\n        constraint_lst: Optional list of constraints (or list of lists of constraints for batched inference)(FullBodyConstraintSet, etc.)\n        contact_threshold: Threshold for foot contact detection\n        root_margin: Margin for root position correction\n\n    Returns:\n        Dictionary with corrected motion data:\n            - local_rot_mats: Corrected local rotation matrices (B, T, J, 3, 3)\n            - root_positions: Corrected root positions (B, T, 3)\n            - posed_joints: Corrected global joint positions (B, T, J, 3)\n            - global_rot_mats: Corrected global rotation matrices (B, T, J, 3, 3)\n    \"\"\"\n    # Ensure batch dimension\n    assert local_rot_mats.dim() == 5, \"local_rot_mats should be 5D, make sure to include the batch dimension\"\n\n    batch_size, num_frames, num_joints = local_rot_mats.shape[:3]\n\n    def _build_constraint_masks_dict(constraints: List) -> Dict[str, torch.Tensor]:\n        out = {\n            key: torch.zeros(num_frames, dtype=torch.float32)\n            for key in [\n         
       \"FullBody\",\n                \"LeftFoot\",\n                \"RightFoot\",\n                \"LeftHand\",\n                \"RightHand\",\n                \"Root\",\n            ]\n        }\n        for constraint in constraints:\n            frame_indices = constraint.frame_indices\n            if isinstance(frame_indices, torch.Tensor):\n                frame_indices = frame_indices[frame_indices < num_frames]\n                if frame_indices.numel() == 0:\n                    continue\n            else:\n                frame_indices = [idx for idx in frame_indices if idx < num_frames]\n                if not frame_indices:\n                    continue\n            if constraint.name == \"fullbody\":\n                out[\"FullBody\"][frame_indices] = 1.0\n            elif constraint.name == \"left-foot\":\n                out[\"LeftFoot\"][frame_indices] = 1.0\n            elif constraint.name == \"right-foot\":\n                out[\"RightFoot\"][frame_indices] = 1.0\n            elif constraint.name == \"left-hand\":\n                out[\"LeftHand\"][frame_indices] = 1.0\n            elif constraint.name == \"right-hand\":\n                out[\"RightHand\"][frame_indices] = 1.0\n            elif constraint.name == \"root2d\":\n                out[\"Root\"][frame_indices] = 1.0\n        return out\n\n    # Create constraint masks from constraint_lst (one dict per batch item when batched)\n    batched_constraints = bool(constraint_lst) and isinstance(constraint_lst[0], list)\n    if batched_constraints:\n        constraint_masks_dict_lst = [_build_constraint_masks_dict(constraint_lst[b]) for b in range(batch_size)]\n    else:\n        constraint_masks_dict = (\n            _build_constraint_masks_dict(constraint_lst)\n            if constraint_lst\n            else {\n                key: torch.zeros(num_frames, dtype=torch.float32)\n                for key in [\n                    \"FullBody\",\n                    \"LeftFoot\",\n                
    \"RightFoot\",\n                    \"LeftHand\",\n                    \"RightHand\",\n                    \"Root\",\n                ]\n            }\n        )\n\n    # Create working rig\n    above_ground_offset = 0.02 if isinstance(skeleton, (SOMASkeleton30, SOMASkeleton77)) else 0.007\n    # larger offset for SOMA since model tends to generate lower to the ground\n    working_rig = create_working_rig_from_skeleton(skeleton, above_ground_offset=above_ground_offset)\n    has_double_ankle_joints = isinstance(skeleton, G1Skeleton34)\n\n    # Prepare input tensors. The generated motion will be modified in place. Clone first.\n    neutral_joints_pelvis_offset = skeleton.neutral_joints[0].cpu().clone()\n    hip_translations_corrected = root_positions.cpu().clone()\n    rotations_corrected = matrix_to_quaternion(local_rot_mats).cpu().clone()  # (B, T, J, 4)\n    contacts = contacts.cpu()\n\n    # Extract input motion (target keyframes) from constraints for each batch\n    # For constrained keyframes, use the original motion from constraints\n    # For non-constrained frames, zeros are used\n    hip_translations_input = torch.zeros(batch_size, num_frames, 3)\n    rotations_input = torch.zeros(batch_size, num_frames, num_joints, 4)\n    rotations_input[..., 0] = 1.0  # Initialize as identity quaternions (w=1, x=y=z=0)\n\n    if constraint_lst:\n        for b in range(batch_size):\n            # Get constraints for this batch item (if batched) or use the same list\n            constraints_lst_el = (\n                constraint_lst[b]\n                if isinstance(\n                    constraint_lst[0], list\n                )  # when the constraint_list is in batch format, each item in a list is a constraintlist for one sample\n                else constraint_lst  # single constraint list shared for all samples in the batch\n            )\n            hip_translations_input[b], rotations_input[b] = extract_input_motion_from_constraints(\n                
constraints_lst_el,\n                skeleton,\n                num_frames,\n                num_joints,\n            )\n\n    # Call the motion correction for each batch (optional package)\n    try:\n        from motion_correction import motion_postprocess\n    except ImportError as e:\n        raise RuntimeError(\n            \"Motion correction is required for this postprocessing path but the \"\n            \"motion_correction package is not installed. Install with: pip install -e .\"\n        ) from e\n    for b in range(batch_size):\n        masks_b = constraint_masks_dict_lst[b] if batched_constraints else constraint_masks_dict\n        motion_postprocess.correct_motion(\n            hip_translations_corrected[b : b + 1],\n            rotations_corrected[b : b + 1],\n            contacts[b : b + 1],\n            hip_translations_input[b : b + 1],\n            rotations_input[b : b + 1],\n            masks_b,\n            contact_threshold,\n            root_margin,\n            working_rig,\n            has_double_ankle_joints,\n        )\n\n    local_rot_mats_corrected = quaternion_to_matrix(rotations_corrected)\n\n    # Compute posed joints using FK\n    device = local_rot_mats.device\n    global_rot_mats, posed_joints, _ = fk(\n        local_rot_mats_corrected.to(device),\n        hip_translations_corrected.to(device),\n        skeleton,\n    )\n\n    result = {\n        \"local_rot_mats\": local_rot_mats_corrected.to(device),\n        \"root_positions\": hip_translations_corrected.to(device),\n        \"posed_joints\": posed_joints,\n        \"global_rot_mats\": global_rot_mats,\n    }\n\n    return result\n"
  },
  {
    "path": "kimodo/sanitize.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Text prompt sanitization for motion generation (whitespace, punctuation, capitalization).\"\"\"\n\n\ndef sanitize_text(text: str, paragraph: bool = True) -> str:\n    \"\"\"Sanitize a text prompt: strip, collapse spaces, capitalize, trim non-alphanumeric, add/fix final punctuation.\n\n    Args:\n        text: Input text prompt.\n        paragraph: If True, capitalize after each sentence break and normalize spacing between sentences.\n\n    Returns:\n        Sanitized text.\n    \"\"\"\n    # remove any trailing or leading whitespace\n    text = text.strip()\n\n    # https://stackoverflow.com/a/1546251\n    # replace duplicate spaces by one space\n    text = \" \".join(text.split())\n\n    if text == \"\":\n        return text\n\n    # removing leading non alpha numeric characters\n    for i, c in enumerate(text):\n        if not str.isalnum(c):\n            continue\n        break\n    text = text[i:]\n\n    # Capitalize\n    text = text.capitalize()\n\n    final_punctuations = \".!?\\\"])'\"\n    # removing trailing non alpha numeric characters\n    # expect final punctuations\n    for i, c in reversed(list(enumerate(text))):\n        if not str.isalnum(c) and c not in final_punctuations:\n            continue\n        break\n    text = text[: i + 1]\n\n    # Adding period at the end if needed\n    if text[-1] not in \".!?\":\n        text = text + \".\"\n\n    if paragraph:\n        # fix end of sentences if several sentences\n        for sentence_break in \".!?\":\n            subtexts = text.split(sentence_break)\n            text = f\"{sentence_break} \".join(  # put back a space after the break\n                [\n                    y[0].capitalize() + y[1:]  # only capitalize the first character\n                    if y\n                    else y  # y is empty at the end\n                    for x 
in subtexts\n                    for y in [x.strip()]  # remove extra spaces\n                ]\n            ).strip()  # remove extra space at the end\n    return text\n\n\ndef sanitize_texts(texts: list[str]) -> list[str]:\n    \"\"\"Sanitize each text prompt in the list (see sanitize_text).\n\n    Args:\n        texts: List of input text prompts.\n\n    Returns:\n        List of sanitized texts.\n    \"\"\"\n    return [sanitize_text(text) for text in texts]\n\n\nif __name__ == \"__main__\":\n    texts = [\n        \" A person is    walking.\",\n        \"someone go forward\",\n        \"jump\",\n        \"jumping!\",\n        \"jumping)\",\n        \"-go\",\n        \"blocasdji  -----\",\n        \"\",\n    ]\n\n    print(\"Old texts\")\n    print(\"\\n\".join(texts))\n    print()\n\n    new_texts = sanitize_texts(texts)\n    print(\"Sanitized texts\")\n    print(\"\\n\".join(new_texts))\n"
  },
  {
    "path": "kimodo/scripts/__init__.py",
    "content": ""
  },
  {
    "path": "kimodo/scripts/docker-entrypoint.sh",
    "content": "#!/usr/bin/env bash\nset -euo pipefail\n\nHOST_UID=\"${HOST_UID:-}\"\nHOST_GID=\"${HOST_GID:-}\"\nHOST_USER=\"${HOST_USER:-user}\"\n\nif [[ -z \"${HOST_UID}\" || -z \"${HOST_GID}\" ]]; then\n  if [[ -d /workspace ]]; then\n    HOST_UID=\"$(stat -c %u /workspace)\"\n    HOST_GID=\"$(stat -c %g /workspace)\"\n  else\n    HOST_UID=\"${HOST_UID:-1000}\"\n    HOST_GID=\"${HOST_GID:-1000}\"\n  fi\nfi\n\nif ! getent group \"${HOST_GID}\" >/dev/null 2>&1; then\n  groupadd -g \"${HOST_GID}\" \"${HOST_USER}\"\nfi\n\nif ! getent passwd \"${HOST_UID}\" >/dev/null 2>&1; then\n  useradd -m -u \"${HOST_UID}\" -g \"${HOST_GID}\" -s /bin/bash \"${HOST_USER}\"\nfi\n\nexec gosu \"${HOST_UID}:${HOST_GID}\" \"$@\"\n"
  },
  {
    "path": "kimodo/scripts/generate.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\nimport argparse\nimport os\nimport shutil\nfrom typing import Any, Dict, Optional\n\nimport torch\n\nfrom kimodo import DEFAULT_MODEL, load_model\nfrom kimodo.constraints import load_constraints_lst\nfrom kimodo.exports.motion_io import save_kimodo_npz\nfrom kimodo.meta import load_prompts_from_meta\nfrom kimodo.model.cfg import CFG_TYPES\nfrom kimodo.model.registry import get_model_info\nfrom kimodo.tools import load_json, save_json, seed_everything\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Cmd line API for generation motions with kimodo\")\n    parser.add_argument(\n        \"prompt\",\n        nargs=\"?\",\n        type=str,\n        default=None,\n        help=\"Text prompt describing the motion to generate, or several prompts separated by periods.\",\n    )\n    parser.add_argument(\n        \"--model\",\n        type=str,\n        default=DEFAULT_MODEL,\n        help=\"Name of the model (e.g. Kimodo-SOMA-RP-v1, etc).\",\n    )\n    parser.add_argument(\n        \"--duration\",\n        type=str,\n        default=\"5.0\",\n        help=\"Duration in seconds (default: 5.0). 
Separate by spaces in a string for different durations per prompts\",\n    )\n    parser.add_argument(\n        \"--num_samples\",\n        type=int,\n        default=1,\n        help=\"Number of samples to generate (default: 1)\",\n    )\n    parser.add_argument(\n        \"--diffusion_steps\",\n        type=int,\n        default=100,\n        help=\"Number of diffusion steps (default: 100)\",\n    )\n    parser.add_argument(\n        \"--num_transition_frames\",\n        type=int,\n        default=5,\n        help=\"Number of frames to help transitioning (default: 5)\",\n    )\n    parser.add_argument(\n        \"--constraints\",\n        type=str,\n        default=None,\n        help=\"Saved constraint list\",\n    )\n    parser.add_argument(\n        \"--output\",\n        type=str,\n        default=\"output\",\n        help=\"Output stem name: with one sample writes a single file per format (e.g. test.npz, test.csv); with multiple samples creates a folder and writes test_00.npz, test_01.npz, ... inside it. Used for NPZ, AMASS NPZ, CSV, and BVH.\",\n    )\n    parser.add_argument(\n        \"--save_example_dir\",\n        action=\"store_true\",\n        help=(\n            \"Save demo-compatible example directories (each contains motion.npz, constraints.json, meta.json). \"\n            \"With one sample, writes <output>_example/. 
With multiple samples, writes \"\n            \"<output>_examples/<output>_example_00/, <output>_example_01/, ...\"\n        ),\n    )\n    parser.add_argument(\n        \"--bvh\",\n        action=\"store_true\",\n        help=\"Also export BVH (SOMA models only); uses the same stem as --output.\",\n    )\n    parser.add_argument(\n        \"--bvh_standard_tpose\",\n        action=\"store_true\",\n        help=\"If exporting BVH, export with the rest pose being the standard T-pose rather than the rest pose consistent with the BONES-SEED dataset.\",\n    )\n    parser.add_argument(\n        \"--no-postprocess\",\n        action=\"store_true\",\n        help=\"Don't apply motion post-processing to reduce foot skating (ignored for G1)\",\n    )\n    parser.add_argument(\n        \"--seed\",\n        type=int,\n        default=None,\n        help=\"Seed for reproducible results\",\n    )\n    parser.add_argument(\n        \"--input_folder\",\n        type=str,\n        default=None,\n        help=\"Folder containing meta.json and optional constraints.json. If set, generation settings are loaded from meta.json.\",\n    )\n    parser.add_argument(\n        \"--cfg_type\",\n        type=str,\n        default=argparse.SUPPRESS,\n        choices=CFG_TYPES,\n        help=(\n            \"Classifier-free guidance mode: nocfg (no CFG), regular (single scale on cond vs uncond), \"\n            \"or separated (custom: separate text and constraint scales). \"\n            \"Use with --cfg_weight as required by the mode.\"\n        ),\n    )\n    parser.add_argument(\n        \"--cfg_weight\",\n        type=float,\n        nargs=\"*\",\n        default=argparse.SUPPRESS,\n        help=(\n            \"CFG scale(s): one float for regular, or two floats [text_weight, constraint_weight] for separated. \"\n            \"Omit with --cfg_type nocfg. 
If omitted, two floats alone imply separated; one float alone implies regular.\"\n        ),\n    )\n    return parser.parse_args()\n\n\ndef get_texts_and_num_frames_from_prompt(prompt: str, duration: str, fps: float):\n    # Get the texts\n    texts = [text.strip() for text in prompt.split(\".\")]\n    texts = [text + \".\" for text in texts if text]\n\n    nb_prompts = len(texts)\n\n    # Get the durations\n    if \" \" not in duration:\n        duration_sec = float(duration)\n        # same for all the prompts\n        num_frames = [int(duration_sec * fps)] * nb_prompts\n    else:\n        durations = duration.split(\" \")\n        assert len(durations) == len(texts), \"The number of durations should match the number of prompts\"\n        num_frames = [int(float(duration.strip()) * fps) for duration in durations]\n        assert len(num_frames) == nb_prompts, \"The number of durations should be 1 or match the number of texts\"\n\n    return texts, num_frames\n\n\ndef _single_file_path(path: str, ext: str) -> str:\n    \"\"\"Return path for a single output file (no folder).\n\n    Adds ext if missing; creates parent dirs if any.\n    \"\"\"\n    if not path.endswith(ext):\n        path = path.rstrip(os.sep) + ext\n    parent = os.path.dirname(path)\n    if parent:\n        os.makedirs(parent, exist_ok=True)\n    return path\n\n\ndef _output_dir_and_path(path: str, default_base: str, ext: str):\n    \"\"\"Create output folder from path and return (dir_path, path_for_file_with_suffix, base_name).\n\n    If path has an extension, folder name is the path stem; else the path is the folder name.\n    base_name is the folder basename for _00, _01, ... 
when n_samples > 1.\n    \"\"\"\n    folder = os.path.splitext(path)[0] if os.path.splitext(path)[1] else path\n    os.makedirs(folder, exist_ok=True)\n    base_name = os.path.basename(folder.rstrip(os.sep))\n    return folder, os.path.join(folder, default_base + ext), base_name\n\n\ndef resolve_cfg_kwargs(args: argparse.Namespace, meta: Optional[Dict[str, Any]]) -> Dict[str, Any]:\n    \"\"\"Resolve cfg_type / cfg_weight for model(...).\n\n    Precedence: explicit CLI (--cfg_type / --cfg_weight) overrides meta.json ``cfg``;\n    if neither applies, returns {} so the model uses its own defaults.\n    \"\"\"\n    ns = vars(args)\n    has_type = \"cfg_type\" in ns\n    has_wflag = \"cfg_weight\" in ns\n    cli_type = ns.get(\"cfg_type\")\n    cli_w = ns.get(\"cfg_weight\")\n\n    if has_wflag:\n        if cli_w is None or len(cli_w) == 0:\n            raise ValueError(\"--cfg_weight requires one float (regular) or two floats (separated).\")\n\n    if has_type and cli_type == \"nocfg\":\n        if has_wflag:\n            raise ValueError(\"--cfg_weight is not used with --cfg_type nocfg.\")\n        return {\"cfg_type\": \"nocfg\"}\n\n    if has_type or has_wflag:\n        if has_type:\n            eff_type = cli_type\n            if has_wflag:\n                if eff_type == \"regular\" and len(cli_w) != 1:\n                    raise ValueError(\"--cfg_type regular requires exactly one --cfg_weight value.\")\n                if eff_type == \"separated\" and len(cli_w) != 2:\n                    raise ValueError(\"--cfg_type separated requires exactly two --cfg_weight values.\")\n            else:\n                if eff_type == \"regular\":\n                    raise ValueError(\"--cfg_type regular requires --cfg_weight with one float.\")\n                if eff_type == \"separated\":\n                    raise ValueError(\"--cfg_type separated requires --cfg_weight with two floats.\")\n        else:\n            if len(cli_w) == 1:\n                eff_type = 
\"regular\"\n            elif len(cli_w) == 2:\n                eff_type = \"separated\"\n            else:\n                raise ValueError(\"--cfg_weight expects 1 float (regular) or 2 floats (separated).\")\n\n        if eff_type == \"regular\":\n            return {\"cfg_type\": \"regular\", \"cfg_weight\": float(cli_w[0])}\n        return {\"cfg_type\": \"separated\", \"cfg_weight\": [float(cli_w[0]), float(cli_w[1])]}\n\n    if meta and isinstance(meta.get(\"cfg\"), dict):\n        cfg = meta[\"cfg\"]\n        enabled = cfg.get(\"enabled\", True)\n        if not enabled:\n            return {\"cfg_type\": \"nocfg\"}\n        return {\n            \"cfg_type\": \"separated\",\n            \"cfg_weight\": [\n                float(cfg.get(\"text_weight\", 2.0)),\n                float(cfg.get(\"constraint_weight\", 2.0)),\n            ],\n        }\n\n    return {}\n\n\ndef get_generation_inputs(args, fps: float):\n    \"\"\"Get texts/num_frames and parameter overrides from either CLI or input_folder.\"\"\"\n    if args.input_folder is None:\n        if not args.prompt:\n            raise ValueError(\"Either provide 'prompt' or '--input_folder'.\")\n        texts, num_frames = get_texts_and_num_frames_from_prompt(args.prompt, args.duration, fps)\n        return {\n            \"texts\": texts,\n            \"num_frames\": num_frames,\n            \"num_samples\": args.num_samples,\n            \"diffusion_steps\": args.diffusion_steps,\n            \"seed\": args.seed,\n            \"constraints_path\": args.constraints,\n            \"meta\": None,\n        }\n\n    meta_path = os.path.join(args.input_folder, \"meta.json\")\n    meta = load_json(meta_path)\n    texts, durations_sec = load_prompts_from_meta(meta_path)\n    num_frames = [int(float(duration) * fps) for duration in durations_sec]\n\n    constraints_path = args.constraints\n    default_constraints_path = os.path.join(args.input_folder, \"constraints.json\")\n    if constraints_path is None and 
os.path.exists(default_constraints_path):\n        constraints_path = default_constraints_path\n\n    return {\n        \"texts\": texts,\n        \"num_frames\": num_frames,\n        \"num_samples\": meta.get(\"num_samples\", args.num_samples),\n        \"diffusion_steps\": meta.get(\"diffusion_steps\", args.diffusion_steps),\n        \"seed\": meta.get(\"seed\", args.seed),\n        \"constraints_path\": constraints_path,\n        \"meta\": meta,\n    }\n\n\ndef main():\n    device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n    print(f\"Using device: {device}\")\n\n    args = parse_args()\n\n    # Load model (resolution of name done inside load_model)\n    model, resolved_model = load_model(\n        args.model,\n        device=device,\n        default_family=\"Kimodo\",\n        return_resolved_name=True,\n    )\n    info = get_model_info(resolved_model)\n    display = info.display_name if info else resolved_model\n    print(f\"Loaded model: {display} ({resolved_model})\")\n\n    # Get generation inputs\n    generation_inputs = get_generation_inputs(args, model.fps)\n    texts = generation_inputs[\"texts\"]\n    num_frames = generation_inputs[\"num_frames\"]\n    print(\"Will generate motions with the following prompts\")\n    for text, num_frame in zip(texts, num_frames):\n        print(f\"    '{text}' with {num_frame} frames\")\n\n    # Load constraints\n    constraints_path = generation_inputs[\"constraints_path\"]\n    if constraints_path:\n        constraint_lst = load_constraints_lst(constraints_path, model.skeleton)\n    else:\n        constraint_lst = []\n\n    if constraint_lst:\n        print(f\"Using {len(constraint_lst)} set of constraints\")\n        for constraint in constraint_lst:\n            print(f\"    {constraint}\")\n\n    if generation_inputs[\"seed\"] is not None:\n        seed_everything(generation_inputs[\"seed\"])\n\n    cfg_kwargs = resolve_cfg_kwargs(args, generation_inputs.get(\"meta\"))\n    if cfg_kwargs:\n        ct 
= cfg_kwargs.get(\"cfg_type\")\n        cw = cfg_kwargs.get(\"cfg_weight\")\n        if cw is not None:\n            print(f\"Using CFG: cfg_type={ct!r}, cfg_weight={cw!r}\")\n        else:\n            print(f\"Using CFG: cfg_type={ct!r}\")\n\n    # G1: postprocessing is disabled (does not work well for this model).\n    use_postprocess = False if \"g1\" in resolved_model else (not args.no_postprocess)\n    output = model(\n        texts,\n        num_frames,\n        constraint_lst=constraint_lst,\n        num_denoising_steps=generation_inputs[\"diffusion_steps\"],\n        num_samples=generation_inputs[\"num_samples\"],\n        multi_prompt=True,\n        num_transition_frames=args.num_transition_frames,\n        post_processing=use_postprocess,\n        return_numpy=True,\n        **cfg_kwargs,\n    )\n\n    n_samples = int(output[\"posed_joints\"].shape[0])\n    # Parse the output stem once; all formats (NPZ, AMASS NPZ, CSV, BVH) use this base name.\n    output_base = args.output\n\n    # Save the NPZ output\n    if n_samples == 1:\n        npz_path = _single_file_path(output_base, \".npz\")\n        print(f\"Saving the npz output to {npz_path}\")\n        single = {\n            k: (v[0] if hasattr(v, \"shape\") and len(v.shape) > 0 and v.shape[0] == n_samples else v)\n            for k, v in output.items()\n        }\n        save_kimodo_npz(npz_path, single)\n    else:\n        out_dir, _, base_name = _output_dir_and_path(output_base, \"motion\", \".npz\")\n        print(f\"Saving the npz output to {out_dir}/ ({base_name}_00.npz ...)\")\n        for i in range(n_samples):\n            single = {\n                k: (v[i] if hasattr(v, \"shape\") and len(v.shape) > 0 and v.shape[0] == n_samples else v)\n                for k, v in output.items()\n            }\n            save_kimodo_npz(os.path.join(out_dir, f\"{base_name}_{i:02d}.npz\"), single)\n\n    # Save the AMASS NPZ output\n    if resolved_model == \"kimodo-smplx-rp\":\n        from 
kimodo.exports.smplx import AMASSConverter\n\n        converter = AMASSConverter(skeleton=model.skeleton, fps=model.fps)\n        if n_samples == 1:\n            # Use distinct name so AMASS NPZ does not overwrite the main NPZ\n            amass_single_path = _single_file_path(output_base + \"_amass\", \".npz\")\n            print(f\"Saving the amass output to {amass_single_path}\")\n            converter.convert_save_npz(output, amass_single_path)\n        else:\n            out_dir, _, base_name = _output_dir_and_path(output_base, \"amass\", \".npz\")\n            print(f\"Saving the amass output to {out_dir}/ (amass_00.npz ...)\")\n            converter.convert_save_npz(output, os.path.join(out_dir, \"amass.npz\"))\n\n    # Save the CSV output\n    if resolved_model == \"kimodo-g1-rp\":\n        from kimodo.exports.mujoco import MujocoQposConverter\n\n        converter = MujocoQposConverter(model.skeleton)\n        qpos = converter.dict_to_qpos(output, device)\n        if n_samples == 1:\n            csv_path = _single_file_path(output_base, \".csv\")\n            print(f\"Saving the csv output to {csv_path}\")\n            converter.save_csv(qpos, csv_path)\n        else:\n            out_dir, _, base_name = _output_dir_and_path(output_base, \"qpos\", \".csv\")\n            print(f\"Saving the csv output to {out_dir}/ ({base_name}_00.csv ...)\")\n            converter.save_csv(qpos, os.path.join(out_dir, base_name + \".csv\"))\n\n    # Save the BVH output\n    if args.bvh:\n        skeleton = model.skeleton\n        if \"somaskel\" not in skeleton.name:\n            print(\"BVH export is only supported for SOMA skeletons. 
Skipping --bvh.\")\n        else:\n            from kimodo.exports.bvh import save_motion_bvh\n            from kimodo.skeleton import SOMASkeleton30, global_rots_to_local_rots\n\n            if isinstance(skeleton, SOMASkeleton30):\n                # Motion has already been converted to somaskel77 within the model for output\n                skeleton = skeleton.somaskel77.to(device)\n\n            if n_samples == 1:\n                bvh_path = _single_file_path(output_base, \".bvh\")\n                print(f\"Saving the BVH output to {bvh_path}\")\n                joints_pos = torch.from_numpy(output[\"posed_joints\"][0]).to(device)\n                joints_rot = torch.from_numpy(output[\"global_rot_mats\"][0]).to(device)\n                local_rot_mats = global_rots_to_local_rots(joints_rot, skeleton)\n                root_positions = joints_pos[:, skeleton.root_idx, :]\n                save_motion_bvh(\n                    bvh_path,\n                    local_rot_mats,\n                    root_positions,\n                    skeleton=skeleton,\n                    fps=model.fps,\n                    standard_tpose=args.bvh_standard_tpose,\n                )\n            else:\n                out_dir, _, base_name = _output_dir_and_path(output_base, \"motion\", \".bvh\")\n                print(f\"Saving the BVH output to {out_dir}/ ({base_name}_00.bvh ...)\")\n                for i in range(n_samples):\n                    joints_pos = torch.from_numpy(output[\"posed_joints\"][i]).to(device)\n                    joints_rot = torch.from_numpy(output[\"global_rot_mats\"][i]).to(device)\n                    local_rot_mats = global_rots_to_local_rots(joints_rot, skeleton)\n                    root_positions = joints_pos[:, skeleton.root_idx, :]\n                    save_motion_bvh(\n                        os.path.join(out_dir, f\"{base_name}_{i:02d}.bvh\"),\n                        local_rot_mats,\n                        root_positions,\n                        
skeleton=skeleton,\n                        fps=model.fps,\n                        standard_tpose=args.bvh_standard_tpose,\n                    )\n\n    # Save the example directory\n    if args.save_example_dir:\n        output_stem = os.path.splitext(output_base)[0].rstrip(os.sep)\n        base_name = os.path.basename(output_stem)\n\n        if n_samples == 1:\n            parent_dir = None\n            example_dirs = [output_stem + \"_example\"]\n        else:\n            parent_dir = output_stem + \"_examples\"\n            if os.path.exists(parent_dir):\n                raise FileExistsError(f\"Example directory already exists: {parent_dir}\")\n            os.makedirs(parent_dir)\n            example_dirs = [\n                os.path.join(parent_dir, f\"{base_name}_example_{i:02d}\") for i in range(n_samples)\n            ]\n\n        durations_sec = [nf / model.fps for nf in num_frames]\n        if len(texts) == 1:\n            meta_info: dict = {\"text\": texts[0], \"duration\": durations_sec[0]}\n        else:\n            meta_info = {\"texts\": texts, \"durations\": durations_sec}\n        meta_info[\"num_samples\"] = generation_inputs[\"num_samples\"]\n        if generation_inputs[\"seed\"] is not None:\n            meta_info[\"seed\"] = generation_inputs[\"seed\"]\n        meta_info[\"diffusion_steps\"] = generation_inputs[\"diffusion_steps\"]\n        if cfg_kwargs:\n            cfg_type = cfg_kwargs.get(\"cfg_type\", \"nocfg\")\n            cfg_weight = cfg_kwargs.get(\"cfg_weight\")\n            if cfg_type == \"nocfg\":\n                meta_info[\"cfg\"] = {\"enabled\": False}\n            elif cfg_type == \"separated\" and isinstance(cfg_weight, list) and len(cfg_weight) == 2:\n                meta_info[\"cfg\"] = {\n                    \"enabled\": True,\n                    \"text_weight\": cfg_weight[0],\n                    \"constraint_weight\": cfg_weight[1],\n                }\n            elif cfg_type == \"regular\" and cfg_weight is 
not None:\n                meta_info[\"cfg\"] = {\n                    \"enabled\": True,\n                    \"text_weight\": float(cfg_weight),\n                    \"constraint_weight\": float(cfg_weight),\n                }\n\n        for i, example_dir in enumerate(example_dirs):\n            if os.path.exists(example_dir):\n                raise FileExistsError(f\"Example directory already exists: {example_dir}\")\n            os.makedirs(example_dir)\n            sample = {\n                k: (v[i] if hasattr(v, \"shape\") and len(v.shape) > 0 and v.shape[0] == n_samples else v)\n                for k, v in output.items()\n            }\n            save_kimodo_npz(os.path.join(example_dir, \"motion.npz\"), sample)\n            if constraints_path:\n                shutil.copy2(constraints_path, os.path.join(example_dir, \"constraints.json\"))\n            save_json(os.path.join(example_dir, \"meta.json\"), meta_info)\n\n        if parent_dir is None:\n            print(f\"Saved demo example to {example_dirs[0]}\")\n        else:\n            print(f\"Saved {n_samples} demo examples to {parent_dir}/\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "kimodo/scripts/gradio_theme.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\nimport gradio as gr\n\n\ndef get_gradio_theme(remove_gradio_footer=False):\n    theme = gr.themes.Base(\n        primary_hue=\"blue\",\n        text_size=gr.themes.Size(lg=\"16px\", md=\"14px\", sm=\"12px\", xl=\"22px\", xs=\"10px\", xxl=\"35px\", xxs=\"9px\"),\n        font=[\n            gr.themes.GoogleFont(\"Source Sans Pro\"),\n            \"BlinkMacSystemFont\",\n            \"Segoe UI\",\n            \"Roboto\",\n        ],\n    ).set(\n        body_text_color=\"*neutral_900\",\n        body_text_color_subdued=\"*neutral_500\",\n        body_text_color_subdued_dark=\"*neutral_500\",\n    )\n\n    css = \"\"\"\n        @import url('https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600;700;900&display=swap');\n\n        /* Base text */\n        body, .gradio-container {\n          font-family: 'Source Sans Pro', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen-Sans, Ubuntu, Cantarell, 'Helvetica Neue', sans-serif !important;\n          font-size: 16px !important;\n        }\n\n        h1 {\n          // font-family: 'Source Sans Pro', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif !important;\n          font-weight: 700 !important;\n          font-size: 2.75rem !important;\n          // margin: 0px;\n          padding: 1.5rem 0px 0px 0px;\n          // line-height: 1.2;\n        }\n        h2 {\n          // font-family: 'Source Sans Pro', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif !important;\n          font-weight: 600 !important;\n          font-size: 1.5rem !important;\n        }\n    \"\"\"\n    if remove_gradio_footer:\n        css += \"\"\"\n        footer {\n        display: none !important;\n        }\n        \"\"\"\n    return theme, css\n"
  },
  {
    "path": "kimodo/scripts/lock_requirements.py",
    "content": "#!/usr/bin/env python3\n\n# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Regenerate `docker_requirements.txt` from `docker_requirements.in` using `uv`, targeting the\nDocker image runtime, and filter out `torch` + CUDA wheels so Docker doesn't try to reinstall\nPyTorch.\n\nUsage:\n  python3 kimodo/scripts/lock_requirements.py\n\nOptional args:\n  --python-version 3.10\n  --python-platform x86_64-manylinux2014\n  --in docker_requirements.in\n  --out docker_requirements.txt\n\"\"\"\n\nimport argparse\nimport shutil\nimport subprocess\nfrom pathlib import Path\nfrom typing import Iterable\n\nDEFAULT_PYTHON_VERSION = \"3.10\"\nDEFAULT_PYTHON_PLATFORM = \"x86_64-manylinux2014\"\n\n# Packages to omit from the lockfile because the Docker base image already provides torch+CUDA.\nOMIT_NAMES = {\"torch\", \"triton\", \"networkx\", \"sympy\", \"mpmath\"}\nOMIT_PREFIXES = (\"nvidia-\",)\n\n\ndef _run(cmd: list[str]) -> None:\n    print(\"+\", \" \".join(cmd))\n    subprocess.run(cmd, check=True)\n\n\ndef _ensure_uv() -> None:\n    if shutil.which(\"uv\") is None:\n        raise SystemExit(\n            \"ERROR: `uv` is not installed or not on PATH.\\n\"\n            \"Install it (one of):\\n\"\n            \"  - pipx install uv\\n\"\n            \"  - python -m pip install --user uv\\n\"\n            \"Then rerun this script.\"\n        )\n\n\ndef _parse_req_name(line: str) -> str:\n    # uv emits `name==version` lines.\n    s = line.strip()\n    if \"==\" in s:\n        return s.split(\"==\", 1)[0].strip()\n    # Fallback: treat the whole token before space as name.\n    return s.split()[0].strip()\n\n\ndef _iter_blocks(lines: list[str]) -> Iterable[list[str]]:\n    \"\"\"Split a docker_requirements.txt into blocks: [top-level req line + indented comments].\"\"\"\n    i = 0\n    n = len(lines)\n    while i < n:\n        line = lines[i]\n        # 
Header/comments/blank\n        if line.startswith(\"#\") or line.strip() == \"\":\n            yield [line]\n            i += 1\n            continue\n\n        # Top-level requirement line\n        if not line.startswith(\" \"):\n            block = [line]\n            i += 1\n            while i < n and (lines[i].startswith(\" \") or lines[i].strip() == \"\" or lines[i].startswith(\"#\")):\n                # Stop if we hit another top-level requirement line\n                if not lines[i].startswith(\" \") and not lines[i].startswith(\"#\") and lines[i].strip() != \"\":\n                    break\n                block.append(lines[i])\n                i += 1\n            yield block\n            continue\n\n        # Indented line without a requirement header (shouldn't happen, but keep)\n        yield [line]\n        i += 1\n\n\ndef _should_omit(req_line: str) -> bool:\n    name = _parse_req_name(req_line)\n    if name in OMIT_NAMES:\n        return True\n    for pfx in OMIT_PREFIXES:\n        if name.startswith(pfx):\n            return True\n    return False\n\n\ndef filter_lockfile(path: Path) -> None:\n    lines = path.read_text(encoding=\"utf-8\").splitlines(True)\n    out: list[str] = []\n\n    inserted_note = False\n    for block in _iter_blocks(lines):\n        first = block[0]\n\n        # After the uv header lines, insert a short note once.\n        if (not inserted_note) and first.startswith(\"# This file was autogenerated by uv\"):\n            out.extend(block)\n            out.append(\n                \"# NOTE: `torch` (and its CUDA wheels) are intentionally omitted from this lockfile.\\n\"\n                \"# The Docker base image (nvcr.io/nvidia/pytorch) already provides a tested PyTorch build.\\n\"\n                \"#\\n\"\n            )\n            inserted_note = True\n            continue\n\n        if first.startswith(\"#\") or first.strip() == \"\":\n            out.extend(block)\n            continue\n\n        if 
_should_omit(first):\n            continue\n\n        out.extend(block)\n\n    path.write_text(\"\".join(out), encoding=\"utf-8\")\n\n\ndef main() -> None:\n    ap = argparse.ArgumentParser()\n    ap.add_argument(\"--in\", dest=\"in_file\", default=\"docker_requirements.in\")\n    ap.add_argument(\"--out\", dest=\"out_file\", default=\"docker_requirements.txt\")\n    ap.add_argument(\"--python-version\", default=DEFAULT_PYTHON_VERSION)\n    ap.add_argument(\"--python-platform\", default=DEFAULT_PYTHON_PLATFORM)\n    args = ap.parse_args()\n\n    _ensure_uv()\n\n    in_path = Path(args.in_file)\n    out_path = Path(args.out_file)\n    if not in_path.exists():\n        raise SystemExit(f\"ERROR: missing {in_path}\")\n\n    _run(\n        [\n            \"uv\",\n            \"pip\",\n            \"compile\",\n            \"-U\",\n            str(in_path),\n            \"-o\",\n            str(out_path),\n            \"--python-version\",\n            args.python_version,\n            \"--python-platform\",\n            args.python_platform,\n        ]\n    )\n    filter_lockfile(out_path)\n    print(f\"OK: wrote {out_path} (filtered torch/CUDA wheels)\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "kimodo/scripts/motion_convert.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"CLI entry-point for motion format conversion.\n\nLibrary conversion logic lives in :mod:`kimodo.exports.motion_convert_lib`.\nFormat detection utilities live in :mod:`kimodo.exports.motion_formats`.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport argparse\nimport sys\n\nfrom kimodo.exports.motion_convert_lib import convert_motion_files\n\n\ndef run_convert(\n    input_path: str,\n    output_path: str,\n    from_fmt: str | None,\n    to_fmt: str | None,\n    source_fps: float | None,\n    z_up: bool,\n    mujoco_rest_zero: bool,\n    bvh_standard_tpose: bool = False,\n) -> None:\n    \"\"\"Thin wrapper kept for backward compatibility; delegates to :func:`convert_motion_files`.\"\"\"\n    convert_motion_files(\n        input_path,\n        output_path,\n        from_fmt=from_fmt,\n        to_fmt=to_fmt,\n        source_fps=source_fps,\n        z_up=z_up,\n        mujoco_rest_zero=mujoco_rest_zero,\n        bvh_standard_tpose=bvh_standard_tpose,\n    )\n\n\ndef build_argparser() -> argparse.ArgumentParser:\n    p = argparse.ArgumentParser(\n        description=\"Convert Kimodo NPZ, AMASS NPZ, SOMA BVH, and G1 MuJoCo CSV.\",\n    )\n    p.add_argument(\"input\", help=\"Input file path\")\n    p.add_argument(\"output\", help=\"Output file path\")\n    p.add_argument(\n        \"--from\",\n        dest=\"from_fmt\",\n        choices=(\"amass\", \"kimodo\", \"soma-bvh\", \"g1-csv\"),\n        default=None,\n        help=\"Input format (default: infer from file contents/extension)\",\n    )\n    p.add_argument(\n        \"--to\",\n        dest=\"to_fmt\",\n        choices=(\"kimodo\", \"amass\", \"soma-bvh\", \"g1-csv\"),\n        default=None,\n        help=\"Output format (default: infer from output extension)\",\n    )\n    p.add_argument(\n        \"--source-fps\",\n        \"--fps\",\n        
dest=\"source_fps\",\n        type=float,\n        default=None,\n        help=(\n            \"Source motion frame rate in Hz (default: auto-detected from \"\n            \"BVH Frame Time / AMASS mocap_frame_rate, or 30 Hz). \"\n            \"Kimodo NPZ output is always resampled to 30 Hz.\"\n        ),\n    )\n    p.add_argument(\n        \"--no-z-up\",\n        action=\"store_true\",\n        help=\"For AMASS paths: disable Z-up transform (treat trans/orient as already Kimodo Y-up).\",\n    )\n    p.add_argument(\n        \"--mujoco-rest-zero\",\n        action=\"store_true\",\n        default=False,\n        help=\"For G1 CSV: joint angles relative to MuJoCo rest (must match export).\",\n    )\n    p.add_argument(\n        \"--bvh_standard_tpose\",\n        action=\"store_true\",\n        default=False,\n        help=\"If input or output is BVH: the BVH file uses the standard T-pose as its rest pose instead of the BONES-SEED rest pose.\",\n    )\n    return p\n\n\ndef main(argv: list[str] | None = None) -> int:\n    args = build_argparser().parse_args(argv)\n    try:\n        convert_motion_files(\n            args.input,\n            args.output,\n            from_fmt=args.from_fmt,\n            to_fmt=args.to_fmt,\n            source_fps=args.source_fps,\n            z_up=not args.no_z_up,\n            mujoco_rest_zero=args.mujoco_rest_zero,\n            bvh_standard_tpose=args.bvh_standard_tpose,\n        )\n    except Exception as e:\n        print(f\"Error: {e}\", file=sys.stderr)\n        return 1\n    return 0\n\n\nif __name__ == \"__main__\":\n    raise SystemExit(main())\n"
  },
  {
    "path": "kimodo/scripts/mujoco_load.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\nimport time\n\nimport mujoco\nimport mujoco.viewer\nimport numpy as np\n\nfrom kimodo.assets import skeleton_asset_path\n\nqpos = np.loadtxt(\"motion.csv\", delimiter=\",\")\nmodel = mujoco.MjModel.from_xml_path(str(skeleton_asset_path(\"g1skel34\", \"xml\", \"g1.xml\")))\ndata = mujoco.MjData(model)\n\nfps = 30  # adjust to your intended playback rate\n\nwith mujoco.viewer.launch_passive(model, data) as viewer:\n    # loop the motion\n    while viewer.is_running():\n        for frame in qpos:\n            data.qpos[:] = frame\n            mujoco.mj_forward(model, data)\n            viewer.sync()\n            time.sleep(1.0 / fps)\n"
  },
  {
    "path": "kimodo/scripts/run_text_encoder_server.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\nimport argparse\nimport os\n\nimport gradio as gr\nimport numpy as np\n\nfrom kimodo.model import resolve_target\n\nfrom .gradio_theme import get_gradio_theme\n\nos.environ[\"HF_ENABLE_PARALLEL_LOADING\"] = \"YES\"\nDEFAULT_TEXT = \"A person walks and falls to the ground.\"\nDEFAULT_SERVER_NAME = \"0.0.0.0\"\nDEFAULT_SERVER_PORT = 9550\nDEFAULT_TMP_FOLDER = \"/tmp/text_encoder/\"\nDEFAULT_TEXT_ENCODER = \"llm2vec\"\nTEXT_ENCODER_PRESETS = {\n    \"llm2vec\": {\n        \"target\": \"kimodo.model.LLM2VecEncoder\",\n        \"kwargs\": {\n            \"base_model_name_or_path\": \"McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp\",\n            \"peft_model_name_or_path\": \"McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-supervised\",\n            \"dtype\": \"bfloat16\",\n            \"llm_dim\": 4096,\n            \"device\": \"auto\",\n        },\n        \"display_name\": \"LLM2Vec\",\n    }\n}\n\n\nclass DemoWrapper:\n    def __init__(self, text_encoder, tmp_folder):\n        self.text_encoder = text_encoder\n        self.tmp_folder = tmp_folder\n\n    def __call__(self, text, filename, progress=gr.Progress()):\n        # Compute text embedding\n        tensor, length = self.text_encoder(text)\n        embedding = tensor[:length]\n        embedding = embedding.cpu().numpy()\n\n        # Save text embedding\n        path = os.path.join(self.tmp_folder, filename)\n        np.save(path, embedding)\n\n        output_title = gr.Markdown(visible=True)\n        output_text = gr.Markdown(visible=True, value=f\"Text: {text}\")\n        download = gr.DownloadButton(visible=True, value=path)\n        return download, output_title, output_text\n\n\ndef _get_env(name: str, default):\n    return os.getenv(name, default)\n\n\ndef _build_text_encoder(name: str, fp32: bool = False):\n    if name not in 
TEXT_ENCODER_PRESETS:\n        available = \", \".join(sorted(TEXT_ENCODER_PRESETS))\n        raise ValueError(f\"Unknown TEXT_ENCODER='{name}'. Available: {available}\")\n    preset = TEXT_ENCODER_PRESETS[name]\n    target_cls = resolve_target(preset[\"target\"])\n    if fp32:\n        preset[\"kwargs\"][\"dtype\"] = \"float32\"\n    return target_cls(**preset[\"kwargs\"])\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Run text encoder Gradio server.\")\n    parser.add_argument(\n        \"--text-encoder\",\n        default=_get_env(\"TEXT_ENCODER\", DEFAULT_TEXT_ENCODER),\n        choices=sorted(TEXT_ENCODER_PRESETS.keys()),\n        help=\"Text encoder preset.\",\n    )\n    parser.add_argument(\n        \"--tmp-folder\",\n        default=_get_env(\"TEXT_ENCODER_TMP_FOLDER\", DEFAULT_TMP_FOLDER),\n    )\n    parser.add_argument(\n        \"--fp32\",\n        action=\"store_true\",\n        help=\"Uses fp32 for the text encoder rather than default bfloat16.\",\n    )\n    return parser.parse_args()\n\n\ndef main():\n    args = parse_args()\n    server_name = _get_env(\"GRADIO_SERVER_NAME\", DEFAULT_SERVER_NAME)\n    server_port = int(_get_env(\"GRADIO_SERVER_PORT\", DEFAULT_SERVER_PORT))\n    theme, css = get_gradio_theme()\n    os.makedirs(args.tmp_folder, exist_ok=True)\n    text_encoder = _build_text_encoder(args.text_encoder, args.fp32)\n    display_name = TEXT_ENCODER_PRESETS[args.text_encoder][\"display_name\"]\n    demo_wrapper_fn = DemoWrapper(text_encoder, args.tmp_folder)\n\n    with gr.Blocks(title=\"Text encoder\", css=css, theme=theme) as demo:\n        gr.Markdown(f\"# Text encoder: {display_name}\")\n        gr.Markdown(\"## Description\")\n        gr.Markdown(\"Get a embeddings from a text.\")\n\n        gr.Markdown(\"## Inputs\")\n        with gr.Row():\n            text = gr.Textbox(\n                placeholder=\"Type the motion you want to generate with a sentence\",\n                show_label=True,\n               
 label=\"Text prompt\",\n                value=DEFAULT_TEXT,\n                type=\"text\",\n            )\n        with gr.Row(scale=3):\n            with gr.Column(scale=1):\n                btn = gr.Button(\"Encode\", variant=\"primary\")\n            with gr.Column(scale=1):\n                clear = gr.Button(\"Clear\", variant=\"secondary\")\n            with gr.Column(scale=3):\n                pass\n\n        output_title = gr.Markdown(\"## Outputs\", visible=False)\n        output_text = gr.Markdown(\"\", visible=False)\n        with gr.Row(scale=3):\n            with gr.Column(scale=1):\n                download = gr.DownloadButton(\"Download\", variant=\"primary\", visible=False)\n            with gr.Column(scale=4):\n                pass\n\n        filename = gr.Textbox(\n            visible=False,\n            value=\"embedding.npy\",\n        )\n\n        def clear_fn():\n            return [\n                gr.DownloadButton(visible=False),\n                gr.Markdown(visible=False),\n                gr.Markdown(visible=False),\n            ]\n\n        outputs = [download, output_title, output_text]\n\n        gr.on(\n            triggers=[text.submit, btn.click],\n            fn=clear_fn,\n            inputs=None,\n            outputs=outputs,\n        ).then(\n            fn=demo_wrapper_fn,\n            inputs=[text, filename],\n            outputs=outputs,\n        )\n\n        def download_file():\n            return gr.DownloadButton()\n\n        download.click(\n            fn=download_file,\n            inputs=None,\n            outputs=[download],\n        )\n        clear.click(fn=clear_fn, inputs=None, outputs=outputs)\n\n    demo.launch(server_name=server_name, server_port=server_port)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "kimodo/skeleton/__init__.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Skeleton definitions and utilities used across kimodo.\"\"\"\n\nfrom .base import SkeletonBase\nfrom .definitions import (\n    G1Skeleton34,\n    SMPLXSkeleton22,\n    SOMASkeleton30,\n    SOMASkeleton77,\n)\nfrom .kinematics import batch_rigid_transform, fk\nfrom .registry import build_skeleton\nfrom .transforms import global_rots_to_local_rots, to_standard_tpose\n\n__all__ = [\n    \"SkeletonBase\",\n    \"G1Skeleton34\",\n    \"SOMASkeleton30\",\n    \"SOMASkeleton77\",\n    \"SMPLXSkeleton22\",\n    \"batch_rigid_transform\",\n    \"fk\",\n    \"build_skeleton\",\n    \"global_rots_to_local_rots\",\n    \"to_standard_tpose\",\n]\n"
  },
  {
    "path": "kimodo/skeleton/base.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Base skeleton class: hierarchy, joint metadata, and helpers for kinematics and motion.\"\"\"\n\nfrom pathlib import Path\nfrom typing import Optional\n\nimport torch\n\nfrom kimodo.assets import skeleton_asset_path\n\nfrom .kinematics import fk\nfrom .transforms import (\n    from_standard_tpose,\n    global_rots_to_local_rots,\n    to_standard_tpose,\n)\n\n\nclass SkeletonBase(torch.nn.Module):\n    \"\"\"Base class that stores a skeleton hierarchy and helper metadata.\n\n    Subclasses define the static joint layout (joint names and parent links) and semantic groups\n    (feet, hands, hips). This class builds index mappings, parent tensors, and convenience helpers\n    used by kinematics, constraints, and motion conversion utilities.\n    \"\"\"\n\n    # these should be defined in the subclass\n    name = None\n    bone_order_names_with_parents = None\n    bone_order_names_no_root = None\n    root_idx = None\n    foot_joint_names = None\n    foot_joint_idx = None\n    hip_joint_names = None  # in order [right, left]\n    hip_joint_idx = None  # in order [right, left]\n\n    def __init__(\n        self,\n        folder: Optional[str] = None,\n        name: Optional[str] = None,\n        load: bool = True,\n        **kwargs,  # to catch addition args in configs\n    ):\n        \"\"\"Initialize a skeleton instance and optional neutral-pose assets.\n\n        Args:\n            folder: Folder containing serialized skeleton assets (for example\n                `joints.p` and optional `standard_t_pose_global_offsets_rots.p`).\n            name: Optional runtime name used to validate subclass compatibility.\n            load: Whether to load tensor assets from `folder`.\n            **kwargs: Unused extra config keys kept for config compatibility.\n        \"\"\"\n        super().__init__()\n\n        if name 
is not None:\n            # Check that the name is not too far from the actual skeleton class name\n            assert self.name in name\n            self.name = name\n\n        if folder is None:\n            # Take the skeleton asset folder of the repo from the name\n            # in case we don't override it\n            folder = str(skeleton_asset_path(self.name))\n\n        self.folder = folder\n\n        self.dim = len(self.bone_order_names_with_parents)\n\n        if load and folder is not None:\n            pfolder = Path(folder)\n            neutral_joints = torch.load(pfolder / \"joints.p\").squeeze()\n            self.register_buffer(\"neutral_joints\", neutral_joints, persistent=False)\n\n            if (pfolder / \"bvh_joints.p\").exists():\n                bvh_neutral_joints = torch.load(pfolder / \"bvh_joints.p\").squeeze()\n                self.register_buffer(\"bvh_neutral_joints\", bvh_neutral_joints, persistent=False)\n\n            global_offset_path = pfolder / \"standard_t_pose_global_offsets_rots.p\"\n            if global_offset_path.exists():\n                global_rot_offsets = torch.load(global_offset_path).squeeze()\n                self.register_buffer(\"global_rot_offsets\", global_rot_offsets, persistent=False)\n            # Usefull for g1, where the rest pose is not zero\n            baked_rest_path = pfolder / \"rest_pose_local_rot.p\"\n            if baked_rest_path.exists():\n                rest_pose_local_rot = torch.load(baked_rest_path).squeeze()\n                self.register_buffer(\"rest_pose_local_rot\", rest_pose_local_rot, persistent=False)\n\n        self.bone_order_names = [x for x, y in self.bone_order_names_with_parents]\n\n        self.bone_parents = dict(self.bone_order_names_with_parents)\n        self.bone_index = {x: idx for idx, x in enumerate(self.bone_order_names)}\n        self.bone_order_names_index = self.bone_index\n\n        # create the parents tensor on the fly\n        joint_parents = 
torch.tensor(\n            [-1 if (y := self.bone_parents[x]) is None else self.bone_index[y] for x in self.bone_order_names]\n        )\n        self.register_buffer(\"joint_parents\", joint_parents, persistent=False)\n\n        self.nbjoints = len(self.bone_order_names)\n\n        # check lengths\n        assert self.nbjoints == len(self.joint_parents)\n        if \"neutral_joints\" in self.__dict__:\n            assert self.nbjoints == len(self.neutral_joints)\n\n        root_indices = torch.where(joint_parents == -1)[0]\n        assert len(root_indices) == 1  # should be one root only\n        self.root_idx = root_indices[0].item()\n\n        if \"neutral_joints\" in self.__dict__:\n            assert (self.neutral_joints[0] == 0).all()\n\n        # remove the root\n        self.bone_order_names_no_root = (\n            self.bone_order_names[: self.root_idx] + self.bone_order_names[self.root_idx + 1 :]\n        )\n\n        self.foot_joint_names = self.left_foot_joint_names + self.right_foot_joint_names\n        self.foot_joint_names_index = {x: idx for idx, x in enumerate(self.foot_joint_names)}\n\n        self.left_foot_joint_idx = [\n            self.bone_order_names.index(foot_joint) for foot_joint in self.left_foot_joint_names\n        ]\n\n        self.right_foot_joint_idx = [\n            self.bone_order_names.index(foot_joint) for foot_joint in self.right_foot_joint_names\n        ]\n\n        self.foot_joint_idx = self.left_foot_joint_idx + self.right_foot_joint_idx\n\n        self.hip_joint_idx = [self.bone_order_names.index(hip_joint) for hip_joint in self.hip_joint_names]\n\n    def expand_joint_names(self, joint_names):\n        \"\"\"Expand base EE names [LeftFoot, RightFoot, LeftHand, RightHand] actual joint names to\n        constrain position and rotations.\n\n        Args:\n            joint_names: list of list of base EE names to constrain\n\n        Returns:\n            rot_joint_names: list of list of joint names to constrain rotations\n   
         pos_joint_names: list of list of joint names to constrain positions\n        \"\"\"\n\n        base_ee = [\"LeftFoot\", \"RightFoot\", \"LeftHand\", \"RightHand\", \"Hips\"]\n\n        pelvis_name = self.bone_order_names[self.root_idx]\n\n        base_pos_names = [\n            self.left_foot_joint_names,\n            self.right_foot_joint_names,\n            self.left_hand_joint_names,\n            self.right_hand_joint_names,\n            [pelvis_name],\n        ]\n        # base of each chain\n        base_rot_names = [\n            self.left_foot_joint_names[:1],\n            self.right_foot_joint_names[:1],\n            self.left_hand_joint_names[:1],\n            self.right_hand_joint_names[:1],\n            [pelvis_name],\n        ]\n        rot_joint_names = []\n        pos_joint_names = []\n        # loop through each EE joint group to constrain in the current keyframe\n        for jname in joint_names:\n            idx = base_ee.index(jname)\n            rot_joint_names += base_rot_names[idx]\n            pos_joint_names += base_pos_names[idx]\n        return rot_joint_names, pos_joint_names\n\n    def expand_joint_names_batched(self, joint_names):\n        \"\"\"Expand base EE names [LeftFoot, RightFoot, LeftHand, RightHand] actual joint names to\n        constrain position and rotations.\n\n        Args:\n            joint_names: list of list of base EE names to constrain\n\n        Returns:\n            rot_joint_names: list of list of joint names to constrain rotations\n            pos_joint_names: list of list of joint names to constrain positions\n        \"\"\"\n\n        base_ee = [\"LeftFoot\", \"RightFoot\", \"LeftHand\", \"RightHand\", \"Hips\"]\n\n        pelvis_name = self.bone_order_names[self.root_idx]\n\n        base_pos_names = [\n            self.left_foot_joint_names,\n            self.right_foot_joint_names,\n            self.left_hand_joint_names,\n            self.right_hand_joint_names,\n            [pelvis_name],\n        
]\n        # base of each chain\n        base_rot_names = [\n            self.left_foot_joint_names[:1],\n            self.right_foot_joint_names[:1],\n            self.left_hand_joint_names[:1],\n            self.right_hand_joint_names[:1],\n            [pelvis_name],\n        ]\n        # loop through each keyframe\n        rot_joint_names = []\n        pos_joint_names = []\n        for key_joint_names in joint_names:\n            key_rot_names = []\n            key_pos_names = []\n            # loop through each EE joint group to constrain in the current keyframe\n            for jname in key_joint_names:\n                idx = base_ee.index(jname)\n                key_rot_names += base_rot_names[idx]\n                key_pos_names += base_pos_names[idx]\n            rot_joint_names.append(key_rot_names)\n            pos_joint_names.append(key_pos_names)\n        return rot_joint_names, pos_joint_names\n\n    def __repr__(self):\n        if self.folder is None:\n            return f\"{self.__class__.__name__}()\"\n        return f'{self.__class__.__name__}(folder=\"{self.folder}\")'\n\n    @property\n    def device(self):\n        \"\"\"Device where neutral-joint buffers are stored.\n\n        Returns 'cpu' if neutral_joints is not present.\n        \"\"\"\n        if getattr(self, \"neutral_joints\", None) is None:\n            return \"cpu\"\n        return self.neutral_joints.device\n\n    def fk(self, local_joint_rots: torch.Tensor, root_positions: torch.Tensor):\n        \"\"\"Run forward kinematics for this skeleton layout.\n\n        Args:\n            local_joint_rots: Local joint rotation matrices with shape\n                `(..., J, 3, 3)`.\n            root_positions: Root translations with shape `(..., 3)`.\n\n        Returns:\n            Tuple of `(global_joint_rots, posed_joints, posed_joints_norootpos)`.\n        \"\"\"\n        global_joint_rots, posed_joints, posed_joints_norootpos = fk(local_joint_rots, root_positions, self)\n        return 
global_joint_rots, posed_joints, posed_joints_norootpos\n\n    def to_standard_tpose(self, local_rot_mats: torch.Tensor):\n        \"\"\"Convert local rotations into the skeleton's standard T-pose frame.\"\"\"\n        return to_standard_tpose(local_rot_mats, self)\n\n    def from_standard_tpose(self, local_rot_mats: torch.Tensor):\n        \"\"\"Convert local rotations from the skeleton's standard T-pose frame.\"\"\"\n        return from_standard_tpose(local_rot_mats, self)\n\n    def global_rots_to_local_rots(self, global_joint_rots: torch.Tensor):\n        \"\"\"Convert global joint rotations to local rotations for this hierarchy.\"\"\"\n        return global_rots_to_local_rots(global_joint_rots, self)\n\n    def get_skel_slice(self, skeleton: \"SkeletonBase\"):\n        \"\"\"Build index mapping from another skeleton into this skeleton order.\n\n        Args:\n            skeleton: Source skeleton whose joint order is used by input tensors.\n\n        Returns:\n            A list of source indices ordered as `self.bone_order_names`.\n\n        Raises:\n            ValueError: If at least one required joint is missing from `skeleton`.\n        \"\"\"\n        try:\n            skel_slice = [skeleton.bone_index[x] for x in self.bone_order_names]\n        except KeyError:\n            raise ValueError(\"The current skeleton contain joints that are not in the input\")\n        return skel_slice\n"
  },
  {
    "path": "kimodo/skeleton/bvh.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"BVH parsing utilities and skeleton/animation conversion helpers.\"\"\"\n\nimport re\nfrom typing import Optional, Tuple\n\nimport numpy as np\nimport torch\nfrom scipy.spatial.transform import Rotation\n\n\nclass BvhNode:\n    \"\"\"Lightweight tree node used to represent parsed BVH hierarchy lines.\"\"\"\n\n    def __init__(self, value=[], parent=None):\n        \"\"\"Create a node from tokenized BVH line values.\"\"\"\n        self.value = value\n        self.children = []\n        self.parent = parent\n        if self.parent:\n            self.parent.add_child(self)\n\n    def add_child(self, item):\n        \"\"\"Attach a child node and set its parent reference.\"\"\"\n        item.parent = self\n        self.children.append(item)\n\n    def filter(self, key):\n        \"\"\"Yield direct children whose first token matches `key`.\"\"\"\n        for child in self.children:\n            if child.value[0] == key:\n                yield child\n\n    def __iter__(self):\n        for child in self.children:\n            yield child\n\n    def __getitem__(self, key):\n        \"\"\"Return all tokens following `key` from the first matching child node.\"\"\"\n        for child in self.children:\n            for index, item in enumerate(child.value):\n                if item == key:\n                    if index + 1 >= len(child.value):\n                        return None\n                    else:\n                        return child.value[index + 1 :]\n        raise IndexError(\"key {} not found\".format(key))\n\n    def __repr__(self):\n        return str(\" \".join(self.value))\n\n    @property\n    def name(self):\n        \"\"\"Joint name for `ROOT`/`JOINT` entries.\"\"\"\n        return self.value[1]\n\n\nclass Bvh:\n    \"\"\"Parsed BVH file with hierarchy graph and per-frame channel values.\"\"\"\n\n    
def __init__(self, data: str, backend: Optional[str] = \"graph\"):\n        \"\"\"\n        Args:\n            data: Raw BVH file content.\n            backend: Parsing mode. `\"graph\"` keeps list-based frame storage,\n                while `\"np\"` precomputes a NumPy array and index caches.\n        \"\"\"\n        self.data = data\n        self.root = BvhNode()\n        self.frames = []\n        self.backend = backend\n        self.tokenize()\n        if self.backend == \"np\":\n            # cache important info for quick access later\n            self.build_data_array()\n        elif self.backend == \"graph\":\n            pass\n        else:\n            raise ValueError(f\"Unknown backend for BVH loading: {backend}\")\n\n    def build_data_array(self):\n        \"\"\"Build cached channel indices and contiguous frame data for `\"np\"` backend.\"\"\"\n        joints = self.get_joints()\n        self.joint2idx = dict()\n        self.joint2channels = dict()\n        cur_idx = 0\n        for joint in joints:\n            self.joint2idx[joint.value[1]] = cur_idx\n            cur_idx += int(joint[\"CHANNELS\"][0])\n            self.joint2channels[joint.value[1]] = joint[\"CHANNELS\"][1:]\n        self.np_data_array = np.array(self.frames, dtype=np.float32)\n\n    def tokenize(self):\n        \"\"\"Tokenize BVH text and populate hierarchy plus frame values.\"\"\"\n        first_round = []\n        accumulator = \"\"\n        for char in self.data:\n            if char not in (\"\\n\", \"\\r\"):\n                accumulator += char\n            elif accumulator:\n                first_round.append(re.split(\"\\\\s+\", accumulator.strip()))\n                accumulator = \"\"\n        node_stack = [self.root]\n        frame_time_found = False\n        node = None\n        for item in first_round:\n            if frame_time_found:\n                self.frames.append(item)\n                continue\n            key = item[0]\n            if key == \"{\":\n              
  node_stack.append(node)\n            elif key == \"}\":\n                node_stack.pop()\n            else:\n                node = BvhNode(item)\n                # print(\"new node: \", node, \"\\nparent: \", node_stack[-1])\n                node_stack[-1].add_child(node)\n            if item[0] == \"Frame\" and item[1] == \"Time:\":\n                frame_time_found = True\n\n    def search(self, *items):\n        \"\"\"Depth-first search for nodes matching a prefix of tokens.\"\"\"\n        found_nodes = []\n\n        def check_children(node):\n            if len(node.value) >= len(items):\n                failed = False\n                for index, item in enumerate(items):\n                    if node.value[index] != item:\n                        failed = True\n                        break\n                if not failed:\n                    found_nodes.append(node)\n            for child in node:\n                check_children(child)\n\n        check_children(self.root)\n        return found_nodes\n\n    def get_joints(self):\n        \"\"\"Return all `ROOT`/`JOINT` hierarchy joints in BVH traversal order.\"\"\"\n        joints = []\n\n        def iterate_joints(joint):\n            joints.append(joint)\n            for child in joint.filter(\"JOINT\"):\n                iterate_joints(child)\n\n        iterate_joints(next(self.root.filter(\"ROOT\")))\n        return joints\n\n    def get_joints_names(self):\n        \"\"\"Return joint names in the same order as :meth:`get_joints`.\"\"\"\n        joints = []\n\n        def iterate_joints(joint):\n            joints.append(joint.value[1])\n            for child in joint.filter(\"JOINT\"):\n                iterate_joints(child)\n\n        iterate_joints(next(self.root.filter(\"ROOT\")))\n        return joints\n\n    def joint_direct_children(self, name):\n        \"\"\"Return direct child joints of the given joint name.\"\"\"\n        joint = self.get_joint(name)\n        return [child for child in 
joint.filter(\"JOINT\")]\n\n    def get_joint_index(self, name):\n        \"\"\"Return hierarchy index of the named joint.\"\"\"\n        return self.get_joints().index(self.get_joint(name))\n\n    def get_joint(self, name):\n        \"\"\"Return hierarchy node for a joint name.\"\"\"\n        found = self.search(\"ROOT\", name)\n        if not found:\n            found = self.search(\"JOINT\", name)\n        if found:\n            return found[0]\n        raise LookupError(\"joint not found\")\n\n    def joint_offset(self, name, idx=[0, 1, 2]):\n        \"\"\"Return selected `OFFSET` components for a joint.\"\"\"\n        joint = self.get_joint(name)\n        offset = joint[\"OFFSET\"]\n        if len(offset) < max(idx):\n            return None\n        return (float(offset[idx[0]]), float(offset[idx[1]]), float(offset[idx[2]]))\n\n    def joint_offset_rot(self, name):\n        \"\"\"Return optional rotational offset components from custom BVH files.\"\"\"\n        return self.joint_offset(name, idx=[3, 4, 5])\n\n    def joint_channels(self, name):\n        \"\"\"Return channel names declared for a joint.\"\"\"\n        if self.backend == \"np\":\n            return self.joint2channels[name]\n        else:\n            joint = self.get_joint(name)\n            return joint[\"CHANNELS\"][1:]\n\n    def get_joint_channels_index(self, joint_name):\n        \"\"\"Return the flattened starting channel index for one joint.\"\"\"\n        if self.backend == \"np\":\n            return self.joint2idx[joint_name]\n        else:\n            index = 0\n            for joint in self.get_joints():\n                if joint.value[1] == joint_name:\n                    return index\n                index += int(joint[\"CHANNELS\"][0])\n            raise LookupError(\"joint not found\")\n\n    def get_joint_channel_index(self, joint, channel):\n        \"\"\"Return per-joint channel offset for a specific channel name.\"\"\"\n        channels = self.joint_channels(joint)\n       
 if channel in channels:\n            channel_index = channels.index(channel)\n        else:\n            raise ValueError(f\"Channel {channel} not found in {channels}\")\n        return channel_index\n\n    def frame_joint_channel(self, frame_index, joint, channel, value=None):\n        \"\"\"Return one channel value for one joint at one frame index.\"\"\"\n        joint_index = self.get_joint_channels_index(joint)\n        channel_index = self.get_joint_channel_index(joint, channel)\n        if channel_index == -1 and value is not None:\n            return value\n        if self.backend == \"np\":\n            return self.np_data_array[frame_index, joint_index + channel_index]\n        else:\n            return float(self.frames[frame_index][joint_index + channel_index])\n\n    def frame_joint_channels(self, frame_index, joint, channels, value=None):\n        \"\"\"Get single frame data for on specific joint from multiple specific channels (e.g.\n        Xrotation, Yrotation, Zrotation).\"\"\"\n        values = []\n        joint_index = self.get_joint_channels_index(joint)\n        if self.backend == \"np\":\n            channel_idx = [self.get_joint_channel_index(joint, channel) for channel in channels]\n            channel_idx = np.array(channel_idx) + joint_index\n            values = self.np_data_array[frame_index, channel_idx]\n        else:\n            for channel in channels:\n                channel_index = self.get_joint_channel_index(joint, channel)\n                if channel_index == -1 and value is not None:\n                    values.append(value)\n                else:\n                    values.append(float(self.frames[frame_index][joint_index + channel_index]))\n        return values\n\n    def frames_joint_channels(self, joint, channels, value=None):\n        \"\"\"Get all frame data for one joint from multiple channels (e.g. 
Xrotation, Yrotation,\n        Zrotation).\"\"\"\n        joint_index = self.get_joint_channels_index(joint)\n        if self.backend == \"np\":\n            channel_idx = [self.get_joint_channel_index(joint, channel) for channel in channels]\n            channel_idx = np.array(channel_idx) + joint_index\n            all_frames = self.np_data_array[:, channel_idx]\n        else:\n            all_frames = []\n            for frame in self.frames:\n                values = []\n                for channel in channels:\n                    channel_index = self.get_joint_channel_index(joint, channel)\n                    if channel_index == -1 and value is not None:\n                        values.append(value)\n                    else:\n                        values.append(float(frame[joint_index + channel_index]))\n                all_frames.append(values)\n        return all_frames\n\n    def frames_joints_channels(self, joint_names, channels):\n        \"\"\"Get all frames for all specified joints with one specified set of channels.\"\"\"\n        if self.backend != \"np\":\n            raise NotImplementedError(\"Only np backend is supported for this function\")\n        joint_indices = [(joint_name, self.joint2idx[joint_name]) for joint_name in joint_names]\n        data_indices = []\n        for joint_name, joint_idx in joint_indices:\n            channel_indices = [self.get_joint_channel_index(joint_name, channel) for channel in channels]\n            data_indices.extend([joint_idx + channel_idx for channel_idx in channel_indices])\n        all_frames = self.np_data_array[:, data_indices]\n        all_frames = all_frames.reshape(-1, len(joint_names), len(channels))\n        return all_frames\n\n    def joint_parent(self, name):\n        \"\"\"Return parent joint node, or `None` for the root.\"\"\"\n        joint = self.get_joint(name)\n        if joint.parent == self.root:\n            return None\n        return joint.parent\n\n    def 
joint_parent_index(self, name):\n        \"\"\"Return parent joint index, or `-1` for the root.\"\"\"\n        joint = self.get_joint(name)\n        if joint.parent == self.root:\n            return -1\n        return self.get_joints().index(joint.parent)\n\n    @property\n    def nframes(self):\n        \"\"\"Number of motion frames declared in the BVH header.\"\"\"\n        try:\n            return int(next(self.root.filter(\"Frames:\")).value[1])\n        except StopIteration:\n            raise LookupError(\"number of frames not found\")\n\n    @property\n    def frame_time(self):\n        \"\"\"Frame duration in seconds declared in the BVH header.\"\"\"\n        try:\n            return float(next(self.root.filter(\"Frame\")).value[2])\n        except StopIteration:\n            raise LookupError(\"frame time not found\")\n\n\nclass Bone:\n    \"\"\"Container for one skeleton bone and its kinematic metadata.\"\"\"\n\n    def __init__(self):\n        # original bone info\n        self.id = None\n        self.name = None\n        self.orient = np.identity(3)\n        self.dof_index = []\n        self.channels = []  # bvh only\n        self.lb = []\n        self.ub = []\n        self.parent = None\n        self.child = []\n\n        # asf specific\n        self.dir = np.zeros(3)\n        self.len = 0\n        # bvh specific\n        self.offset = np.zeros(3)  # default offset for position\n        self.offset_rot = None  # rotation for custom nv bvh\n\n        # inferred info\n        self.pos = np.zeros(3)\n        self.end = np.zeros(3)\n\n    def __repr__(self):\n        return f\"{self.name}\"\n\n\nclass SkeletonBvh:\n    \"\"\"Skeleton structure reconstructed from BVH hierarchy metadata.\"\"\"\n\n    def __init__(self):\n        self.bones = []\n        self.name2bone = {}\n        self.mass_scale = 1.0\n        self.len_scale = 1.0\n        self.dof_name = [\"x\", \"y\", \"z\"]\n        self.root = None\n\n    def get_bones_names(self):\n        
\"\"\"Return bone names in skeleton order.\"\"\"\n        return [x.name for x in self.bones]\n\n    def get_parent_indices(self):\n        \"\"\"Return parent index array aligned with `self.bones`.\"\"\"\n        parent_indices = [-1] * len(self.bones)\n        for bone in self.bones:\n            if bone.parent:\n                parent_indices[bone.id] = bone.parent.id\n        return parent_indices\n\n    def get_neutral_joints(self):\n        \"\"\"Return neutral/rest joint positions as a NumPy array `(J, 3)`.\"\"\"\n        joints = []\n        for bone in self.bones:\n            joints.append(bone.pos)\n        joints = np.stack(joints, axis=0)\n        return joints\n\n    def load_from_bvh(self, fname, exclude_bones=None, spec_channels=None, mocap=None):\n        \"\"\"Load skeleton hierarchy and rest offsets from a BVH file.\n\n        Args:\n            fname: Path to a BVH file (ignored when *mocap* is given).\n            exclude_bones: Bone-name substrings to ignore while constructing the\n                skeleton.\n            spec_channels: Optional per-joint channel overrides.\n            mocap: Pre-parsed :class:`Bvh` object.  
When provided the file is\n                not re-read from disk.\n        \"\"\"\n        if exclude_bones is None:\n            exclude_bones = {}\n        if spec_channels is None:\n            spec_channels = dict()\n        if mocap is None:\n            with open(fname) as f:\n                mocap = Bvh(f.read())\n\n        joint_names = list(\n            filter(\n                lambda x: all([t not in x for t in exclude_bones]),\n                mocap.get_joints_names(),\n            )\n        )\n        dof_ind = {\"x\": 0, \"y\": 1, \"z\": 2}\n        self.len_scale = 1.0\n        self.root = Bone()\n        self.root.id = 0\n        self.root.name = joint_names[0]\n        self.root.channels = mocap.joint_channels(self.root.name)\n        self.root.offset = np.array(mocap.joint_offset(self.root.name)) * self.len_scale\n        self.root.offset_rot = mocap.joint_offset_rot(self.root.name)\n        if self.root.offset_rot is not None:\n            self.root.offset_rot = np.array(self.root.offset_rot)\n        # self.root.offset = np.zeros_like(self.root.offset) # TODO: remove this\n        self.name2bone[self.root.name] = self.root\n        self.bones.append(self.root)\n        for i, joint in enumerate(joint_names[1:]):\n            bone = Bone()\n            bone.id = i + 1\n            bone.name = joint\n            bone.channels = spec_channels[joint] if joint in spec_channels.keys() else mocap.joint_channels(joint)\n            bone.dof_index = [dof_ind[x[0].lower()] for x in bone.channels]\n            bone.offset = np.array(mocap.joint_offset(joint)) * self.len_scale\n            bone.offset_rot = mocap.joint_offset_rot(joint)\n            if bone.offset_rot is not None:\n                bone.offset_rot = np.array(bone.offset_rot)\n            bone.lb = [-180.0] * 3\n            bone.ub = [180.0] * 3\n            self.bones.append(bone)\n            self.name2bone[joint] = bone\n\n        # for bone in self.bones:\n        # print(bone.name, 
bone.channels, bone.offset)\n\n        for bone in self.bones[1:]:\n            parent_name = mocap.joint_parent(bone.name).name\n            if parent_name in self.name2bone.keys():\n                bone_p = self.name2bone[parent_name]\n                bone_p.child.append(bone)\n                bone.parent = bone_p\n\n        self.forward_bvh(self.root)\n        for bone in self.bones:\n            if len(bone.child) == 0:\n                child_vals = [str(node) for node in mocap.get_joint(bone.name).children]\n                if \"End Site\" in child_vals:\n                    end_site_idx = child_vals.index(\"End Site\")\n                    end_site_offset = mocap.get_joint(bone.name).children[end_site_idx][\"OFFSET\"]\n                    bone.end = bone.pos + np.array([float(x) for x in end_site_offset]) * self.len_scale\n                else:\n                    pass\n            else:\n                bone.end = sum([bone_c.pos for bone_c in bone.child]) / len(bone.child)\n\n    def forward_bvh(self, bone):\n        \"\"\"Recursively accumulate absolute joint positions from local offsets.\"\"\"\n        if bone.parent:\n            bone.pos = bone.parent.pos + bone.offset\n        else:\n            bone.pos = bone.offset\n        for bone_c in bone.child:\n            self.forward_bvh(bone_c)\n\n\ndef load_bvh_animation(\n    fname: str,\n    skeleton: SkeletonBvh,\n    rot_order: Optional[str] = \"native\",\n    backend: Optional[str] = \"np\",\n    return_quat: Optional[bool] = False,\n    mocap: Optional[\"Bvh\"] = None,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"Load motion channels from BVH into root translations and joint rotations.\n\n    Args:\n        fname: Full path to the BVH file (ignored when *mocap* is given).\n        skeleton: Parsed neutral skeleton built from compatible BVH hierarchy.\n        rot_order: Euler order to use for conversion (`\"native\"` keeps BVH order).\n        backend: BVH parser backend (`\"np\"` or 
`\"graph\"`).\n        return_quat: If `True`, return quaternions instead of rotation matrices.\n        mocap: Pre-parsed :class:`Bvh` object.  When provided the file is\n            not re-read from disk.\n\n    Returns:\n        Root translations `(T, 3)` and joint rotations `(T, J, 3, 3)` or\n        `(T, J, 4)` when `return_quat=True`.\n    \"\"\"\n    if mocap is None:\n        with open(fname) as f:\n            mocap = Bvh(f.read(), backend=backend)\n\n    # assume all joints are same ordering, load in with native ordering\n    root_channels = mocap.joint_channels(skeleton.root.name)\n    pos_channels = [channel for channel in root_channels if channel.endswith(\"position\")]\n    rot_channels = [channel for channel in root_channels if channel.endswith(\"rotation\")]\n\n    root_trans = np.array(mocap.frames_joint_channels(skeleton.root.name, pos_channels))\n\n    effective_backend = mocap.backend\n    if effective_backend == \"np\":\n        # NOTE: assumes rot channel ordering is the same for all joints\n        joint_eulers = mocap.frames_joints_channels(skeleton.get_bones_names(), rot_channels)\n        joint_eulers = np.deg2rad(joint_eulers)\n    elif effective_backend == \"graph\":\n        joint_eulers = []\n        for bone in skeleton.bones:\n            bone_channels = mocap.joint_channels(bone.name)\n            bone_rot_channels = [channel for channel in bone_channels if channel.endswith(\"rotation\")]\n            assert bone_rot_channels == rot_channels, \"Rotation channel ordering is not consistent across joints!\"\n            # use native rotation order\n            euler = np.deg2rad(np.array(mocap.frames_joint_channels(bone.name, rot_channels)))\n            joint_eulers.append(euler)\n        joint_eulers = np.stack(joint_eulers, axis=1)\n    else:\n        raise ValueError(f\"Unknown backend for BVH loading: {effective_backend}\")\n\n    if rot_order == \"native\":\n        rot_order = \"\"\n        for axis in rot_channels:\n            
rot_order += axis[0]\n    else:\n        # need to reorder dims\n        ordered_joint_eulers = []\n        for axis in rot_order:\n            i = rot_channels.index(axis + \"rotation\")\n            ordered_joint_eulers.append(joint_eulers[..., i])\n        joint_eulers = np.stack(ordered_joint_eulers, axis=-1)\n\n    rotations = Rotation.from_euler(rot_order, joint_eulers.reshape(-1, 3))\n    if return_quat:\n        joint_rots = rotations.as_quat(scalar_first=True).reshape(joint_eulers.shape[:-1] + (4,))\n    else:\n        joint_rots = rotations.as_matrix().reshape(joint_eulers.shape[:-1] + (3, 3))\n\n    return root_trans, joint_rots\n\n\ndef parse_bvh_motion(file_path_input: str, parse_neutral_joints: bool = False):\n    \"\"\"Parse a BVH motion into tensors used by kimodo motion pipelines.\n\n    Args:\n        file_path_input: Path to input BVH file.\n        parse_neutral_joints: If `True`, also return neutral joints in meters.\n\n    Returns:\n        ``(local_rot_mats, root_trans, fps)`` or\n        ``(local_rot_mats, root_trans, fps, neutral_joints)`` when requested.\n    \"\"\"\n    with open(file_path_input) as f:\n        mocap = Bvh(f.read(), backend=\"np\")\n\n    fps = 1.0 / mocap.frame_time\n\n    skeletonBVH = SkeletonBvh()\n    exclude_bones = {\"Root\"}\n    skeletonBVH.load_from_bvh(file_path_input, exclude_bones=exclude_bones, mocap=mocap)\n\n    root_trans, local_rot_mats = load_bvh_animation(file_path_input, skeletonBVH, mocap=mocap)\n    root_trans *= 0.01  # unit change: cm -> m\n    root_trans = torch.tensor(root_trans)\n    local_rot_mats = torch.tensor(local_rot_mats)\n\n    # Don't parse neutral_joints here\n    # it is not actually needed right now:\n    # the skeleton is always the same, and saved in the folder\n    # careful: the one saved in the folder is relative to the standard t_pose\n    # whereas the parsed one is not\n    if not parse_neutral_joints:\n        return local_rot_mats, root_trans, fps\n\n    neutral_joints = 
skeletonBVH.get_neutral_joints()\n    neutral_joints *= 0.01  # unit change: cm -> m\n    # remove the root position of the skeleton\n    # (it is already \"included\" in the root_translation)\n    root_idx = 0\n    neutral_joints = torch.tensor(neutral_joints - neutral_joints[root_idx])\n    return local_rot_mats, root_trans, fps, neutral_joints\n"
  },
  {
    "path": "kimodo/skeleton/definitions.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Concrete skeleton definitions: SOMA, G1, SMPLX with joint names and hierarchy.\"\"\"\n\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\n\nfrom ..tools import ensure_batched\nfrom .base import SkeletonBase\n\n\nclass SOMASkeleton77(SkeletonBase):\n    \"\"\"High-detail 77-joint SOMA skeleton with full finger and toe chains.\"\"\"\n\n    name = \"somaskel77\"\n\n    right_foot_joint_names = [\n        \"RightFoot\",\n        \"RightToeBase\",\n        \"RightToeEnd\",\n    ]  # in order of chain\n    left_foot_joint_names = [\n        \"LeftFoot\",\n        \"LeftToeBase\",\n        \"LeftToeEnd\",\n    ]  # in order of chain\n    right_hand_joint_names = [\n        \"RightHand\",\n        \"RightHandThumb1\",\n        \"RightHandThumb2\",\n        \"RightHandThumb3\",\n        \"RightHandThumbEnd\",\n        \"RightHandIndex1\",\n        \"RightHandIndex2\",\n        \"RightHandIndex3\",\n        \"RightHandIndex4\",\n        \"RightHandIndexEnd\",\n        \"RightHandMiddle1\",\n        \"RightHandMiddle2\",\n        \"RightHandMiddle3\",\n        \"RightHandMiddle4\",\n        \"RightHandMiddleEnd\",\n        \"RightHandRing1\",\n        \"RightHandRing2\",\n        \"RightHandRing3\",\n        \"RightHandRing4\",\n        \"RightHandRingEnd\",\n        \"RightHandPinky1\",\n        \"RightHandPinky2\",\n        \"RightHandPinky3\",\n        \"RightHandPinky4\",\n        \"RightHandPinkyEnd\",\n    ]  # in order of chain\n    left_hand_joint_names = [\n        \"LeftHand\",\n        \"LeftHandThumb1\",\n        \"LeftHandThumb2\",\n        \"LeftHandThumb3\",\n        \"LeftHandThumbEnd\",\n        \"LeftHandIndex1\",\n        \"LeftHandIndex2\",\n        \"LeftHandIndex3\",\n        \"LeftHandIndex4\",\n        \"LeftHandIndexEnd\",\n        \"LeftHandMiddle1\",\n        
\"LeftHandMiddle2\",\n        \"LeftHandMiddle3\",\n        \"LeftHandMiddle4\",\n        \"LeftHandMiddleEnd\",\n        \"LeftHandRing1\",\n        \"LeftHandRing2\",\n        \"LeftHandRing3\",\n        \"LeftHandRing4\",\n        \"LeftHandRingEnd\",\n        \"LeftHandPinky1\",\n        \"LeftHandPinky2\",\n        \"LeftHandPinky3\",\n        \"LeftHandPinky4\",\n        \"LeftHandPinkyEnd\",\n    ]  # in order of chain\n\n    hip_joint_names = [\"RightLeg\", \"LeftLeg\"]  # in order [right, left]\n\n    bone_order_names_with_parents = [\n        (\"Hips\", None),\n        (\"Spine1\", \"Hips\"),\n        (\"Spine2\", \"Spine1\"),\n        (\"Chest\", \"Spine2\"),\n        (\"Neck1\", \"Chest\"),\n        (\"Neck2\", \"Neck1\"),\n        (\"Head\", \"Neck2\"),\n        (\"HeadEnd\", \"Head\"),\n        (\"Jaw\", \"Head\"),\n        (\"LeftEye\", \"Head\"),\n        (\"RightEye\", \"Head\"),\n        (\"LeftShoulder\", \"Chest\"),\n        (\"LeftArm\", \"LeftShoulder\"),\n        (\"LeftForeArm\", \"LeftArm\"),\n        (\"LeftHand\", \"LeftForeArm\"),\n        (\"LeftHandThumb1\", \"LeftHand\"),\n        (\"LeftHandThumb2\", \"LeftHandThumb1\"),\n        (\"LeftHandThumb3\", \"LeftHandThumb2\"),\n        (\"LeftHandThumbEnd\", \"LeftHandThumb3\"),\n        (\"LeftHandIndex1\", \"LeftHand\"),\n        (\"LeftHandIndex2\", \"LeftHandIndex1\"),\n        (\"LeftHandIndex3\", \"LeftHandIndex2\"),\n        (\"LeftHandIndex4\", \"LeftHandIndex3\"),\n        (\"LeftHandIndexEnd\", \"LeftHandIndex4\"),\n        (\"LeftHandMiddle1\", \"LeftHand\"),\n        (\"LeftHandMiddle2\", \"LeftHandMiddle1\"),\n        (\"LeftHandMiddle3\", \"LeftHandMiddle2\"),\n        (\"LeftHandMiddle4\", \"LeftHandMiddle3\"),\n        (\"LeftHandMiddleEnd\", \"LeftHandMiddle4\"),\n        (\"LeftHandRing1\", \"LeftHand\"),\n        (\"LeftHandRing2\", \"LeftHandRing1\"),\n        (\"LeftHandRing3\", \"LeftHandRing2\"),\n        (\"LeftHandRing4\", \"LeftHandRing3\"),\n        
(\"LeftHandRingEnd\", \"LeftHandRing4\"),\n        (\"LeftHandPinky1\", \"LeftHand\"),\n        (\"LeftHandPinky2\", \"LeftHandPinky1\"),\n        (\"LeftHandPinky3\", \"LeftHandPinky2\"),\n        (\"LeftHandPinky4\", \"LeftHandPinky3\"),\n        (\"LeftHandPinkyEnd\", \"LeftHandPinky4\"),\n        (\"RightShoulder\", \"Chest\"),\n        (\"RightArm\", \"RightShoulder\"),\n        (\"RightForeArm\", \"RightArm\"),\n        (\"RightHand\", \"RightForeArm\"),\n        (\"RightHandThumb1\", \"RightHand\"),\n        (\"RightHandThumb2\", \"RightHandThumb1\"),\n        (\"RightHandThumb3\", \"RightHandThumb2\"),\n        (\"RightHandThumbEnd\", \"RightHandThumb3\"),\n        (\"RightHandIndex1\", \"RightHand\"),\n        (\"RightHandIndex2\", \"RightHandIndex1\"),\n        (\"RightHandIndex3\", \"RightHandIndex2\"),\n        (\"RightHandIndex4\", \"RightHandIndex3\"),\n        (\"RightHandIndexEnd\", \"RightHandIndex4\"),\n        (\"RightHandMiddle1\", \"RightHand\"),\n        (\"RightHandMiddle2\", \"RightHandMiddle1\"),\n        (\"RightHandMiddle3\", \"RightHandMiddle2\"),\n        (\"RightHandMiddle4\", \"RightHandMiddle3\"),\n        (\"RightHandMiddleEnd\", \"RightHandMiddle4\"),\n        (\"RightHandRing1\", \"RightHand\"),\n        (\"RightHandRing2\", \"RightHandRing1\"),\n        (\"RightHandRing3\", \"RightHandRing2\"),\n        (\"RightHandRing4\", \"RightHandRing3\"),\n        (\"RightHandRingEnd\", \"RightHandRing4\"),\n        (\"RightHandPinky1\", \"RightHand\"),\n        (\"RightHandPinky2\", \"RightHandPinky1\"),\n        (\"RightHandPinky3\", \"RightHandPinky2\"),\n        (\"RightHandPinky4\", \"RightHandPinky3\"),\n        (\"RightHandPinkyEnd\", \"RightHandPinky4\"),\n        (\"LeftLeg\", \"Hips\"),\n        (\"LeftShin\", \"LeftLeg\"),\n        (\"LeftFoot\", \"LeftShin\"),\n        (\"LeftToeBase\", \"LeftFoot\"),\n        (\"LeftToeEnd\", \"LeftToeBase\"),\n        (\"RightLeg\", \"Hips\"),\n        (\"RightShin\", \"RightLeg\"),\n        
(\"RightFoot\", \"RightShin\"),\n        (\"RightToeBase\", \"RightFoot\"),\n        (\"RightToeEnd\", \"RightToeBase\"),\n    ]\n\n    @property\n    def relaxed_hands_rest_pose(self):\n        # lazy loading\n        if hasattr(self, \"_relaxed_hands_rest_pose\"):\n            return self._relaxed_hands_rest_pose\n\n        relaxed_hands_pose_path = Path(self.folder) / \"relaxed_hands_rest_pose.npy\"\n        relaxed_hands_rest_pose = torch.from_numpy(np.load(relaxed_hands_pose_path)).squeeze()\n        self.register_buffer(\n            \"_relaxed_hands_rest_pose\",\n            relaxed_hands_rest_pose,\n            persistent=False,\n        )\n        return self._relaxed_hands_rest_pose\n\n\nclass SOMASkeleton30(SkeletonBase):\n    \"\"\"Compact 30-joint SOMA variant with reduced hand and end-effector detail.\"\"\"\n\n    name = \"somaskel30\"\n\n    right_foot_joint_names = [\n        \"RightFoot\",\n        \"RightToeBase\",\n    ]  # in order of chain\n    left_foot_joint_names = [\n        \"LeftFoot\",\n        \"LeftToeBase\",\n    ]  # in order of chain\n    right_hand_joint_names = [\n        \"RightHand\",\n        \"RightHandMiddleEnd\",\n    ]  # in order of chain\n    left_hand_joint_names = [\n        \"LeftHand\",\n        \"LeftHandMiddleEnd\",\n    ]  # in order of chain\n\n    hip_joint_names = [\"RightLeg\", \"LeftLeg\"]  # in order [right, left]\n\n    bone_order_names_with_parents = [\n        (\"Hips\", None),\n        (\"Spine1\", \"Hips\"),\n        (\"Spine2\", \"Spine1\"),\n        (\"Chest\", \"Spine2\"),\n        (\"Neck1\", \"Chest\"),\n        (\"Neck2\", \"Neck1\"),\n        (\"Head\", \"Neck2\"),\n        (\"Jaw\", \"Head\"),\n        (\"LeftEye\", \"Head\"),\n        (\"RightEye\", \"Head\"),\n        (\"LeftShoulder\", \"Chest\"),\n        (\"LeftArm\", \"LeftShoulder\"),\n        (\"LeftForeArm\", \"LeftArm\"),\n        (\"LeftHand\", \"LeftForeArm\"),\n        (\"LeftHandThumbEnd\", \"LeftHand\"),\n        
(\"LeftHandMiddleEnd\", \"LeftHand\"),\n        (\"RightShoulder\", \"Chest\"),\n        (\"RightArm\", \"RightShoulder\"),\n        (\"RightForeArm\", \"RightArm\"),\n        (\"RightHand\", \"RightForeArm\"),\n        (\"RightHandThumbEnd\", \"RightHand\"),\n        (\"RightHandMiddleEnd\", \"RightHand\"),\n        (\"LeftLeg\", \"Hips\"),\n        (\"LeftShin\", \"LeftLeg\"),\n        (\"LeftFoot\", \"LeftShin\"),\n        (\"LeftToeBase\", \"LeftFoot\"),\n        (\"RightLeg\", \"Hips\"),\n        (\"RightShin\", \"RightLeg\"),\n        (\"RightFoot\", \"RightShin\"),\n        (\"RightToeBase\", \"RightFoot\"),\n    ]\n\n    @property\n    def somaskel77(self):\n        # lazy loading\n        if not hasattr(self, \"_somaskel77\"):\n            self._somaskel77 = SOMASkeleton77()\n        return self._somaskel77\n\n    @ensure_batched(local_joint_rots_subset=4)\n    def to_SOMASkeleton77(self, local_joint_rots_subset: torch.Tensor):\n        # Converting from 30-joint to 77-joint to have relaxed hands\n\n        device = local_joint_rots_subset.device\n        nF = len(local_joint_rots_subset)\n        local_joint_rots_mats = self.somaskel77.relaxed_hands_rest_pose.clone().to(device).repeat(nF, 1, 1, 1)\n\n        skel_slice = self.get_skel_slice(self.somaskel77)\n        local_joint_rots_mats[:, skel_slice] = local_joint_rots_subset\n        return local_joint_rots_mats\n\n    @ensure_batched(local_joint_rots_full=4) # [BT, J, 3, 3]\n    def from_SOMASkeleton77(self, local_joint_rots_full: torch.Tensor) -> torch.Tensor:\n        \"\"\"Extract the 30-joint subset from 77-joint local rotation data.\"\"\"\n        skel_slice = self.get_skel_slice(self.somaskel77)\n        return local_joint_rots_full[:, skel_slice]\n\n    def output_to_SOMASkeleton77(self, output: dict) -> dict:\n        \"\"\"Convert model output dict from somaskel30 to somaskel77.\n\n        Expands local_rot_mats to 77 joints, re-runs FK for global_rot_mats and posed_joints. 
Foot\n        contacts are expanded from 4 channels to 6 (toe-end copies toe-base contact).\n        \"\"\"\n        local_rot_mats_77 = self.to_SOMASkeleton77(output[\"local_rot_mats\"])\n        root_positions = output[\"root_positions\"]\n        global_rot_mats_77, posed_joints_77, _ = self.somaskel77.fk(local_rot_mats_77, root_positions)\n        out_77 = dict(output)\n        out_77[\"local_rot_mats\"] = local_rot_mats_77\n        out_77[\"global_rot_mats\"] = global_rot_mats_77\n        out_77[\"posed_joints\"] = posed_joints_77\n\n        if \"foot_contacts\" in output:\n            fc = output[\"foot_contacts\"]  # [..., 4]: [L_heel, L_toe, R_heel, R_toe]\n            # -> [..., 6]: [L_heel, L_toe, L_toe_end, R_heel, R_toe, R_toe_end]\n            out_77[\"foot_contacts\"] = torch.cat([fc[..., :2], fc[..., 1:2], fc[..., 2:4], fc[..., 3:4]], dim=-1)\n\n        return out_77\n\n\nclass G1Skeleton34(SkeletonBase):\n    \"\"\"Unitree G1 skeleton with 32 articulated joints plus 2 toe endpoints.\"\"\"\n\n    name = \"g1skel34\"\n    right_foot_joint_names = [\"right_ankle_roll_skel\", \"right_toe_base\"]\n    left_foot_joint_names = [\"left_ankle_roll_skel\", \"left_toe_base\"]\n    right_hand_joint_names = [\"right_wrist_yaw_skel\", \"right_hand_roll_skel\"]\n    left_hand_joint_names = [\"left_wrist_yaw_skel\", \"left_hand_roll_skel\"]\n\n    hip_joint_names = [\n        \"right_hip_pitch_skel\",\n        \"left_hip_pitch_skel\",\n    ]  # used to calculate root orientation, only need 1 pair of hip joints\n\n    bone_order_names_with_parents = [\n        (\"pelvis_skel\", None),\n        (\"left_hip_pitch_skel\", \"pelvis_skel\"),\n        (\"left_hip_roll_skel\", \"left_hip_pitch_skel\"),\n        (\"left_hip_yaw_skel\", \"left_hip_roll_skel\"),\n        (\"left_knee_skel\", \"left_hip_yaw_skel\"),\n        (\"left_ankle_pitch_skel\", \"left_knee_skel\"),\n        (\"left_ankle_roll_skel\", \"left_ankle_pitch_skel\"),\n        (\"left_toe_base\", 
\"left_ankle_roll_skel\"),\n        (\"right_hip_pitch_skel\", \"pelvis_skel\"),\n        (\"right_hip_roll_skel\", \"right_hip_pitch_skel\"),\n        (\"right_hip_yaw_skel\", \"right_hip_roll_skel\"),\n        (\"right_knee_skel\", \"right_hip_yaw_skel\"),\n        (\"right_ankle_pitch_skel\", \"right_knee_skel\"),\n        (\"right_ankle_roll_skel\", \"right_ankle_pitch_skel\"),\n        (\"right_toe_base\", \"right_ankle_roll_skel\"),\n        (\"waist_yaw_skel\", \"pelvis_skel\"),\n        (\"waist_roll_skel\", \"waist_yaw_skel\"),\n        (\"waist_pitch_skel\", \"waist_roll_skel\"),\n        (\"left_shoulder_pitch_skel\", \"waist_pitch_skel\"),\n        (\"left_shoulder_roll_skel\", \"left_shoulder_pitch_skel\"),\n        (\"left_shoulder_yaw_skel\", \"left_shoulder_roll_skel\"),\n        (\"left_elbow_skel\", \"left_shoulder_yaw_skel\"),\n        (\"left_wrist_roll_skel\", \"left_elbow_skel\"),\n        (\"left_wrist_pitch_skel\", \"left_wrist_roll_skel\"),\n        (\"left_wrist_yaw_skel\", \"left_wrist_pitch_skel\"),\n        (\"left_hand_roll_skel\", \"left_wrist_yaw_skel\"),\n        (\"right_shoulder_pitch_skel\", \"waist_pitch_skel\"),\n        (\"right_shoulder_roll_skel\", \"right_shoulder_pitch_skel\"),\n        (\"right_shoulder_yaw_skel\", \"right_shoulder_roll_skel\"),\n        (\"right_elbow_skel\", \"right_shoulder_yaw_skel\"),\n        (\"right_wrist_roll_skel\", \"right_elbow_skel\"),\n        (\"right_wrist_pitch_skel\", \"right_wrist_roll_skel\"),\n        (\"right_wrist_yaw_skel\", \"right_wrist_pitch_skel\"),\n        (\"right_hand_roll_skel\", \"right_wrist_yaw_skel\"),\n    ]\n\n\nclass SMPLXSkeleton22(SkeletonBase):\n    \"\"\"SMPL-X skeleton with body-only 22 joints.\"\"\"\n\n    name = \"smplx22\"\n    right_foot_joint_names = [\"right_ankle\", \"right_foot\"]  # in order of chain\n    left_foot_joint_names = [\"left_ankle\", \"left_foot\"]  # in order of chain\n    right_hand_joint_names = [\"right_wrist\"]  # in order of chain\n   
 left_hand_joint_names = [\"left_wrist\"]  # in order of chain\n    hip_joint_names = [\"right_hip\", \"left_hip\"]  # in order [right, left]\n\n    bone_order_names_with_parents = [\n        (\"pelvis\", None),\n        (\"left_hip\", \"pelvis\"),\n        (\"right_hip\", \"pelvis\"),\n        (\"spine1\", \"pelvis\"),\n        (\"left_knee\", \"left_hip\"),\n        (\"right_knee\", \"right_hip\"),\n        (\"spine2\", \"spine1\"),\n        (\"left_ankle\", \"left_knee\"),\n        (\"right_ankle\", \"right_knee\"),\n        (\"spine3\", \"spine2\"),\n        (\"left_foot\", \"left_ankle\"),\n        (\"right_foot\", \"right_ankle\"),\n        (\"neck\", \"spine3\"),\n        (\"left_collar\", \"spine3\"),\n        (\"right_collar\", \"spine3\"),\n        (\"head\", \"neck\"),\n        (\"left_shoulder\", \"left_collar\"),\n        (\"right_shoulder\", \"right_collar\"),\n        (\"left_elbow\", \"left_shoulder\"),\n        (\"right_elbow\", \"right_shoulder\"),\n        (\"left_wrist\", \"left_elbow\"),\n        (\"right_wrist\", \"right_elbow\"),\n    ]\n"
  },
  {
    "path": "kimodo/skeleton/kinematics.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Forward-kinematics primitives for articulated skeletons.\"\"\"\n\nfrom typing import List\n\nimport einops\nimport torch\nimport torch.nn.functional as F\n\nfrom ..tools import ensure_batched\n\n\n@ensure_batched(local_joint_rots=4, root_positions=2)\ndef fk(\n    local_joint_rots: torch.Tensor,\n    root_positions: torch.Tensor,\n    skeleton,\n    root_positions_is_global: bool = True,\n):\n    \"\"\"Compute global joint rotations and positions from local rotations.\n\n    Args:\n        local_joint_rots: Local rotation matrices with shape `(..., J, 3, 3)`.\n        root_positions: Root translations with shape `(..., 3)`.\n        skeleton: Skeleton object exposing `neutral_joints`, `joint_parents`, and\n            `root_idx`.\n        root_positions_is_global: If `True`, neutral joints are recentered so root\n            translations are interpreted in world space.\n\n    Returns:\n        Tuple `(global_joint_rots, posed_joints, posed_joints_norootpos)`.\n    \"\"\"\n    device = local_joint_rots.device\n    dtype = local_joint_rots.dtype\n\n    # If skeleton has baked rest (e.g. from XML), identity local = baked rest pose.\n    # So training/inference local rotations are in reference to XML rest *orientations*.\n    rest_local = getattr(skeleton, \"rest_local_rots\", None)\n    if rest_local is not None:\n        rest_local = rest_local.to(device=device, dtype=dtype)\n        local_joint_rots = torch.einsum(\"jmn,...jno->...jmo\", rest_local, local_joint_rots)\n\n    # Rest positions for FK. Must be consistent with rest_local: when local = identity,\n    # FK(rest_local, neutral_joints) should equal the XML rest pose positions. 
So\n    # neutral_joints are not necessarily the raw XML joint positions; they are the\n    # rest layout that, when rotated by rest_local, yields the XML rest positions.\n    neutral_joints = skeleton.neutral_joints.to(device=device, dtype=dtype)\n\n    if root_positions_is_global is True:\n        # Removing the pelvis offset from the neutral joints\n        # as the root positions does not depends on the pelvis offset of the skeleton\n        pelvis_offset = neutral_joints[skeleton.root_idx]\n        neutral_joints = neutral_joints - pelvis_offset\n\n    # compute joint position and global rotations\n    joints = einops.repeat(\n        neutral_joints,\n        \"j k -> b j k\",\n        b=len(local_joint_rots),\n    )\n    posed_joints_norootpos, global_joint_rots = batch_rigid_transform(\n        local_joint_rots,\n        joints,\n        skeleton.joint_parents,\n        skeleton.root_idx,\n    )\n    # if root_positions_is_global is True:\n    # posed_joints_norootpos always start at zero\n    # otherwise it could start with the pelvis offset\n\n    posed_joints = posed_joints_norootpos + root_positions[:, None]\n    return global_joint_rots, posed_joints, posed_joints_norootpos\n\n\ndef compute_idx_levels(parents):\n    \"\"\"Group joint indices by hierarchy depth for level-wise FK updates.\n\n    Args:\n        parents: Parent index tensor of shape `(J,)` with root parent `-1`.\n\n    Returns:\n        List of index tensors, where each tensor contains joints at one depth.\n    \"\"\"\n    idx_levs = [[]]\n    lev_dicts = {0: -1}\n    for i in range(1, parents.shape[0]):\n        assert int(parents[i]) in lev_dicts\n        lev = lev_dicts[int(parents[i])] + 1\n        if lev + 1 > len(idx_levs):\n            idx_levs.append([])\n        idx_levs[lev].append(int(i))\n        lev_dicts[int(i)] = lev\n    idx_levs = [torch.tensor(x).long() for x in idx_levs]\n    return idx_levs\n\n\ndef batch_rigid_transform(rot_mats, joints, parents, root_idx):\n    
\"\"\"Perform batch rigid transformation on a skeletal structure.\n\n    Args:\n        rot_mats: Local rotation matrices for each joint: (B, J, 3, 3)\n        joints: Initial joint positions: (B, J, 3)\n        parents: Tensor indicating the parent of each joint: (J,)\n        root_idx (int): index of the root\n\n    Returns:\n        Transformed joint positions after applying forward kinematics.\n    \"\"\"\n\n    # Compute the hierarchical levels of joints based on their parent relationships\n    idx_levs = compute_idx_levels(parents)\n\n    # Apply forward kinematics to transform the joints\n    return forward_kinematics(rot_mats, joints, parents, idx_levs, root_idx)\n\n\n@torch.jit.script\ndef transform_mat(R, t):\n    \"\"\"Creates a batch of transformation matrices.\n\n    Args:\n        - R: Bx3x3 array of a batch of rotation matrices\n        - t: Bx3x1 array of a batch of translation vectors\n    Returns:\n        - T: Bx4x4 Transformation matrix\n    \"\"\"\n    # No padding left or right, only add an extra row\n    return torch.cat([F.pad(R, [0, 0, 0, 1]), F.pad(t, [0, 0, 0, 1], value=1.0)], dim=2)\n\n\n@torch.jit.script\ndef forward_kinematics(\n    rot_mats,\n    joints,\n    parents: torch.Tensor,\n    idx_levs: List[torch.Tensor],\n    root_idx: int,\n):\n    \"\"\"Perform forward kinematics to compute posed joints and global rotation matrices.\n\n    Args:\n        rot_mats: Local rotation matrices for each joint: (B, J, 3, 3)\n        joints: Initial joint positions: (B, J, 3)\n        parents: Tensor indicating the parent of each joint: (J,)\n        idx_levs: Tensors of joint indices grouped by depth in the kinematic tree.\n        root_idx (int): index of the root\n    Returns:\n        Posed joints: (B, J, 3)\n        Global rotation matrices: (B, J, 3, 3)\n    \"\"\"\n\n    # Add an extra dimension to joints\n    joints = torch.unsqueeze(joints, dim=-1)\n\n    # Compute relative joint positions\n    rel_joints = joints.clone()\n\n    
mask_no_root = torch.ones(joints.shape[1], dtype=torch.bool)\n    mask_no_root[root_idx] = False\n    rel_joints[:, mask_no_root] -= joints[:, parents[mask_no_root]].clone()\n\n    # Compute initial transformation matrices\n    # (B, J, 4, 4)\n    transforms_mat = transform_mat(rot_mats.reshape(-1, 3, 3), rel_joints.reshape(-1, 3, 1)).reshape(\n        -1, joints.shape[1], 4, 4\n    )\n\n    # Initialize the root transformation matrices\n    transforms = torch.zeros_like(transforms_mat)\n    transforms[:, root_idx] = transforms_mat[:, root_idx]\n\n    # Compute global transformations level by level\n    for indices in idx_levs:\n        curr_res = torch.matmul(transforms[:, parents[indices]], transforms_mat[:, indices])\n        transforms[:, indices] = curr_res\n\n    # Extract posed joint positions from the transformation matrices\n    posed_joints = transforms[:, :, :3, 3]\n\n    # Extract global rotation matrices from the transformation matrices\n    global_rot_mat = transforms[:, :, :3, :3]\n\n    return posed_joints, global_rot_mat\n"
  },
  {
    "path": "kimodo/skeleton/registry.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Factory helpers for building predefined skeleton variants.\"\"\"\n\nfrom pathlib import Path\n\nfrom kimodo.assets import SKELETONS_ROOT\n\nfrom .definitions import (\n    G1Skeleton34,\n    SMPLXSkeleton22,\n    SOMASkeleton30,\n    SOMASkeleton77,\n)\n\n\ndef build_skeleton(nbjoints: int, assets_folder: str | Path = SKELETONS_ROOT):\n    \"\"\"Instantiate a known skeleton class from its joint count.\n\n    Supported joint counts: 30 (SOMA compact), 34 (G1), 77 (SOMA full), 22 (SMPLX).\n\n    Args:\n        nbjoints: Number of joints expected in the skeleton representation.\n        assets_folder: Base skeleton-assets directory containing per-skeleton subfolders.\n\n    Returns:\n        A configured `SkeletonBase` subclass instance.\n\n    Raises:\n        ValueError: If `nbjoints` does not match a registered skeleton.\n    \"\"\"\n    assets_folder = Path(assets_folder)\n    if nbjoints == 34:\n        return G1Skeleton34(assets_folder / \"g1skel34\")\n    elif nbjoints == 22:\n        return SMPLXSkeleton22(assets_folder / \"smplx22\")\n    elif nbjoints == 30:\n        return SOMASkeleton30(assets_folder / \"somaskel30\")\n    elif nbjoints == 77:\n        return SOMASkeleton77(assets_folder / \"somaskel77\")\n    else:\n        raise ValueError(\"This skeleton is not recognized.\")\n"
  },
  {
    "path": "kimodo/skeleton/transforms.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Rotation-space conversion utilities for skeleton motion data.\"\"\"\n\nimport einops\nimport torch\n\nfrom ..tools import ensure_batched\nfrom .kinematics import batch_rigid_transform\n\n\ndef global_rots_to_local_rots(global_joint_rots: torch.Tensor, skeleton):\n    \"\"\"Convert global rotations to local rotations using a skeleton hierarchy.\n\n    Args:\n        global_joint_rots: Global rotation matrices with shape `(..., J, 3, 3)`.\n        skeleton: Skeleton object exposing `joint_parents` and `root_idx`.\n\n    Returns:\n        Local rotation matrices with the same leading shape as the input.\n    \"\"\"\n    # Doing big batch\n    global_joint_mats, ps = einops.pack(\n        [global_joint_rots],\n        \"* nbjoints dim1 dim2\",\n    )\n\n    # obtain back the local rotations from the new global rotations\n    parent_rot_mats = global_joint_mats[:, skeleton.joint_parents]\n\n    parent_rot_mats[:, skeleton.root_idx] = torch.eye(3)  # the root joint\n    parent_rot_mats_inv = parent_rot_mats.transpose(2, 3)\n    local_rot_mats = torch.einsum(\n        \"T N m n, T N n o -> T N m o\",\n        parent_rot_mats_inv,\n        global_joint_mats,\n    )\n    [local_rot_mats] = einops.unpack(local_rot_mats, ps, \"* nbjoints dim1 dim2\")\n    return local_rot_mats\n\n\n@ensure_batched(local_rot_mats=4)\ndef change_tpose(local_rot_mats: torch.Tensor, global_rot_offsets: torch.Tensor, skeleton):\n    \"\"\"Re-express local rotations in another t_pose based on the global rotation offsets.\n\n    Args:\n        local_rot_mats: Local rotation matrices with shape `(..., J, 3, 3)`.\n        global_rot_offsets: Global rotation offsets with shape `(..., J, 3, 3)`.\n        skeleton: Skeleton object exposing `joint_parents`,\n            `root_idx`, and `nbjoints`.\n\n    Returns:\n        Tuple 
`(new_local_rot_mats, new_global_rot_mats)` in the standard frame.\n    \"\"\"\n\n    device, dtype = local_rot_mats.device, local_rot_mats.dtype\n    global_rot_offsets = global_rot_offsets.to(device=device, dtype=dtype)\n\n    root_idx = skeleton.root_idx\n    joint_parents = skeleton.joint_parents\n    # These are dummy joint positions, will not be used\n    neutral_joints = torch.ones((len(local_rot_mats), skeleton.nbjoints, 3), device=device, dtype=dtype)\n\n    # get the old joint rotations in the same global space as the t-pose\n    #   Note: the neutral joints we use here doesn't matter, because we are only using the global rotation outputs\n    _, global_rot_mats = batch_rigid_transform(local_rot_mats, neutral_joints, joint_parents, root_idx)  # (T, N, 3, 3)\n\n    # compute the desired joint rotations in the frame of the new t-pose\n    new_global_rot_mats = torch.einsum(\"T N m n, N o n -> T N m o\", global_rot_mats, global_rot_offsets)\n    # convert back to local rotations\n    new_local_rot_mats = global_rots_to_local_rots(new_global_rot_mats, skeleton)\n    return new_local_rot_mats, new_global_rot_mats\n\n\n@ensure_batched(local_rot_mats=4)\ndef to_standard_tpose(local_rot_mats: torch.Tensor, skeleton):\n    \"\"\"Re-express local rotations in the skeleton's standard T-pose convention.\n\n    Args:\n        local_rot_mats: Local rotation matrices with shape `(..., J, 3, 3)`.\n        skeleton: Skeleton object exposing `global_rot_offsets`, `joint_parents`,\n            `root_idx`, and `nbjoints`.\n\n    Returns:\n        Tuple `(new_local_rot_mats, new_global_rot_mats)` in the standard frame.\n    \"\"\"\n    global_rot_offsets = skeleton.global_rot_offsets\n    return change_tpose(local_rot_mats, global_rot_offsets, skeleton)\n\n\n@ensure_batched(local_rot_mats=4)\ndef from_standard_tpose(local_rot_mats: torch.Tensor, skeleton):\n    \"\"\"Re-express local rotations from the skeleton's standard T-pose convention to the original\n    
formulation.\n\n    Args:\n        local_rot_mats: Local rotation matrices with shape `(..., J, 3, 3)`.\n        skeleton: Skeleton object exposing `global_rot_offsets`, `joint_parents`,\n            `root_idx`, and `nbjoints`.\n\n    Returns:\n        Tuple `(new_local_rot_mats, new_global_rot_mats)` in the original frame.\n    \"\"\"\n    global_rot_offsets = skeleton.global_rot_offsets\n    global_rot_offsets_T = global_rot_offsets.mT  # do the inverse transform\n    return change_tpose(local_rot_mats, global_rot_offsets_T, skeleton)\n"
  },
  {
    "path": "kimodo/tools.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Shared utilities: validation decorator, batching, JSON I/O, seeding, tensor conversion.\"\"\"\n\nimport inspect\nimport json\nimport math\nimport random\nfrom collections.abc import Mapping, Sequence\nfrom functools import wraps\nfrom math import prod\nfrom pathlib import Path\nfrom typing import Any, Callable, Mapping, Optional, ParamSpec, TypeVar, Union\n\nimport numpy as np\nimport torch\n\n\ndef validate(validator, save_args: bool = False, super_init: bool = False):\n    \"\"\"Create a decorator function for validating user inputs.\n\n    Args:\n        validator: the function to validate (pydantic dataclass)\n        save (bool): save all the attributes to the obj [args[0]]\n        super_init (bool): init parent with no arguments (useful for using save on a nn.Module)\n\n    Returns:\n        decorator: the decorator function\n    \"\"\"\n\n    def decorator(func):\n        @wraps(func)\n        def validated_func(*args, **kwargs):\n            conf = validator(**kwargs)\n\n            if save_args:\n                assert len(args) != 0\n                obj = args[0]\n\n                if super_init:\n                    # init the parent module\n                    super(type(obj), obj).__init__()\n\n                for key, val in conf.__dict__.items():\n                    setattr(obj, key, val)\n            return func(*args, conf)\n\n        return validated_func\n\n    return decorator\n\n\n# Type alias for clarity\nTensor = Any\n\nP = ParamSpec(\"P\")\nR = TypeVar(\"R\")\n\n\ndef ensure_batched(**spec: int) -> Callable[[Callable[P, R]], Callable[P, R]]:\n    \"\"\"Decorator to flatten complex batch dimensions.\n\n    Fixes included:\n    1. Handles 1D tensors (tail_ndim=0) correctly without slicing errors.\n    2. 
Skips .reshape() if the input is already purely flat (Optimization).\n    \"\"\"\n    if not spec:\n        raise ValueError(\"At least one argument spec must be provided.\")\n\n    def decorator(fn: Callable[P, R]) -> Callable[P, R]:\n        sig = inspect.signature(fn)\n\n        @wraps(fn)\n        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:\n            bound = sig.bind(*args, **kwargs)\n            bound.apply_defaults()\n\n            def _sequence_shape(name: str, value: Any) -> tuple[int, ...]:\n                if not isinstance(value, (list, tuple)):\n                    return ()\n                if len(value) == 0:\n                    return (0,)\n                first_shape = _sequence_shape(name, value[0])\n                for item in value[1:]:\n                    item_shape = _sequence_shape(name, item)\n                    if item_shape != first_shape:\n                        raise ValueError(f\"'{name}' must be a rectangular nested sequence, got ragged shape.\")\n                return (len(value), *first_shape)\n\n            def _shape_and_ndim(name: str, value: Any) -> tuple[tuple[int, ...], int]:\n                if hasattr(value, \"shape\") and hasattr(value, \"ndim\"):\n                    shape = tuple(value.shape)\n                    return shape, int(value.ndim)\n                if isinstance(value, (list, tuple)):\n                    shape = _sequence_shape(name, value)\n                    return shape, len(shape)\n                raise TypeError(f\"'{name}' must be tensor-like or a nested list/tuple, got {type(value)}.\")\n\n            def _reshape_like(value: Any, shape: tuple[int, ...], name: str) -> Any:\n                if hasattr(value, \"reshape\"):\n                    return value.reshape(*shape)\n\n                if not isinstance(value, (list, tuple)):\n                    raise TypeError(f\"Cannot reshape '{name}' of type {type(value)}.\")\n\n                flat: list[Any] = []\n\n                def 
_flatten(x: Any) -> None:\n                    if isinstance(x, (list, tuple)):\n                        for item in x:\n                            _flatten(item)\n                    else:\n                        flat.append(x)\n\n                _flatten(value)\n                expected_size = prod(shape) if shape else 1\n                if len(flat) != expected_size:\n                    raise ValueError(f\"Cannot reshape '{name}' with {len(flat)} elements into shape {shape}.\")\n\n                def _build(index: int, dims: tuple[int, ...]) -> tuple[Any, int]:\n                    if not dims:\n                        return flat[index], index + 1\n                    items = []\n                    for _ in range(dims[0]):\n                        item, index = _build(index, dims[1:])\n                        items.append(item)\n                    return items, index\n\n                rebuilt, used = _build(0, shape)\n                if used != len(flat):\n                    raise ValueError(f\"Internal reshape error for '{name}': used {used}/{len(flat)} elements.\")\n                if isinstance(value, tuple) and isinstance(rebuilt, list):\n                    return tuple(rebuilt)\n                return rebuilt\n\n            # --- 1. 
CANONICAL ARGUMENT ---\n            spec_items = list(spec.items())\n            canonical_name = None\n            canonical_ndim = None\n            x0 = None\n            for name, ndim in spec_items:\n                candidate = bound.arguments.get(name, None)\n                if candidate is not None:\n                    canonical_name = name\n                    canonical_ndim = ndim\n                    x0 = candidate\n                    break\n            if canonical_name is None:\n                raise ValueError(\n                    \"All canonical candidates are None: \" + \", \".join(f\"'{name}'\" for name, _ in spec_items)\n                )\n\n            # Calculate split between Batch dims and Feature dims\n            expected_tail_dims = canonical_ndim - 1  # e.g. 3 - 1 = 2 (Sequence, Feat)\n            x0_shape, x0_ndim = _shape_and_ndim(canonical_name, x0)\n\n            # Validation\n            if x0_ndim < expected_tail_dims:\n                raise ValueError(f\"'{canonical_name}' ndim={x0_ndim} < expected {expected_tail_dims} tail dims.\")\n\n            # --- LOGIC FIX 1: Handle 0 tail dims correctly ---\n            if expected_tail_dims == 0:\n                orig_batch_shape = x0_shape\n                tail_shape = ()\n            else:\n                orig_batch_shape = x0_shape[:-expected_tail_dims]\n                tail_shape = x0_shape[-expected_tail_dims:]\n\n            # Calculate flattened batch size\n            # If orig_batch_shape is () (scalar input), size is 1.\n            B_flat = prod(orig_batch_shape) if orig_batch_shape else 1\n\n            # Determine if we added a fake batch dim (unbatched input)\n            is_unbatched_input = len(orig_batch_shape) == 0\n\n            # --- LOGIC FIX 2: Skip reshape if already flat (Optimization) ---\n            # If batch shape is already 1D (e.g. 
[2]), we don't need to reshape [2, 140, 5] -> [2, 140, 5]\n            is_already_flat = len(orig_batch_shape) == 1\n\n            if is_unbatched_input:\n                # (H, W) -> (1, H, W)\n                x0_batched = _reshape_like(x0, (1, *tail_shape), canonical_name)\n            elif is_already_flat:\n                # (B, H, W) -> Keep as is\n                x0_batched = x0\n            else:\n                # (B1, B2, H, W) -> (B1*B2, H, W)\n                x0_batched = _reshape_like(x0, (B_flat, *tail_shape), canonical_name)\n\n            bound.arguments[canonical_name] = x0_batched\n\n            # --- 2. OTHER ARGUMENTS ---\n            for name, target_ndim in spec_items:\n                if name == canonical_name:\n                    continue\n                val = bound.arguments.get(name, None)\n                if val is None:\n                    continue\n\n                arg_tail_dims = target_ndim - 1  # e.g. for lengths=1, tail=0\n                val_shape, val_ndim = _shape_and_ndim(name, val)\n\n                # Validate\n                if val_ndim < arg_tail_dims:\n                    raise ValueError(f\"'{name}' ndim={val_ndim} too small.\")\n\n                # --- Get Batch Shape (With 0-tail fix) ---\n                if arg_tail_dims == 0:\n                    val_batch_shape = val_shape\n                    val_tail_shape = ()\n                else:\n                    val_batch_shape = val_shape[:-arg_tail_dims]\n                    val_tail_shape = val_shape[-arg_tail_dims:]\n\n                # --- Check Mismatch ---\n                # Unbatched inputs must match unbatched canonical\n                if len(val_batch_shape) == 0:\n                    if not is_unbatched_input:\n                        raise ValueError(f\"'{name}' is unbatched but canonical is batched.\")\n                    val_batched = _reshape_like(val, (1, *val_tail_shape), name)\n                else:\n                    # Batched inputs must match 
canonical batch shape EXACTLY\n                    if val_batch_shape != orig_batch_shape:\n                        raise ValueError(\n                            f\"Batch dimensions mismatch! '{canonical_name}' has {orig_batch_shape}, \"\n                            f\"but '{name}' has {val_batch_shape}.\"\n                        )\n\n                    # Optimization: Don't reshape if already flat\n                    if is_already_flat:\n                        val_batched = val\n                    else:\n                        val_batched = _reshape_like(val, (B_flat, *val_tail_shape), name)\n\n                bound.arguments[name] = val_batched\n\n            # --- 3. EXECUTION ---\n            out = fn(**bound.arguments)\n\n            # --- 4. RESTORE ---\n            def restore(obj):\n                if isinstance(obj, Mapping):\n                    return {k: restore(v) for k, v in obj.items()}\n                if isinstance(obj, (list, tuple)):\n                    return type(obj)(restore(x) for x in obj)\n\n                if hasattr(obj, \"shape\"):\n                    if obj.ndim == 0:\n                        return obj\n\n                    # Verify batch dimension exists and wasn't reduced\n                    if obj.shape[0] != B_flat:\n                        return obj\n\n                    # If input was simple (B, ...), return simple (B, ...)\n                    if is_already_flat:\n                        return obj\n\n                    rest = obj.shape[1:]\n\n                    if is_unbatched_input:\n                        assert obj.shape[0] == 1, \"The batch size should be 1 for unbatched.\"\n                        return obj[0]\n\n                    return obj.reshape(*orig_batch_shape, *rest)\n                return obj\n\n            return restore(out)\n\n        return wrapper\n\n    return decorator\n\n\ndef to_numpy(obj):\n    \"\"\"Recursively convert tensors in dicts/lists/tuples to numpy arrays; leave other 
types\n    unchanged.\"\"\"\n    if isinstance(obj, Mapping):\n        return {k: to_numpy(v) for k, v in obj.items()}\n    if isinstance(obj, (list, tuple)):\n        return type(obj)(to_numpy(x) for x in obj)\n    if isinstance(obj, torch.Tensor):\n        return obj.cpu().numpy()\n    return obj\n\n\ndef to_torch(obj, device=None, dtype=None):\n    \"\"\"Recursively convert numpy arrays in dicts/lists/tuples to torch tensors; optionally move to\n    device/dtype.\"\"\"\n    if isinstance(obj, Mapping):\n        return {k: to_torch(v, device, dtype) for k, v in obj.items()}\n    if isinstance(obj, (list, tuple)):\n        return type(obj)(to_torch(x, device, dtype) for x in obj)\n    if isinstance(obj, np.ndarray):\n        obj = torch.from_numpy(obj)\n    if isinstance(obj, torch.Tensor):\n        if dtype is not None:\n            obj = obj.to(dtype=dtype)\n        if device is None:\n            return obj\n        return obj.to(device)\n    return obj\n\n\ndef seed_everything(seed: int, deterministic: bool = False) -> None:\n    \"\"\"Seed all random number generators.\"\"\"\n    random.seed(seed)  # for Python random module.\n    np.random.seed(seed)  # for NumPy.\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    if deterministic:\n        torch.backends.cudnn.deterministic = True  # for deterministic behavior.\n        torch.backends.cudnn.benchmark = False  # if you want to make the behavior deterministic.\n\n\ndef load_json(path: Union[str, Path]) -> Any:\n    \"\"\"Load a JSON file and return its contents.\n\n    Args:\n        path (str | Path): Path to the JSON file.\n\n    Returns:\n        Any: Parsed JSON content (dict, list, etc.).\n\n    Raises:\n        FileNotFoundError: If the file does not exist.\n        ValueError: If the file is not valid JSON.\n    \"\"\"\n    path = Path(path)\n\n    if not path.exists():\n        raise FileNotFoundError(f\"JSON file not found: {path}\")\n\n    try:\n        with path.open(\"r\", 
encoding=\"utf-8\") as f:\n            return json.load(f)\n    except json.JSONDecodeError as e:\n        raise ValueError(f\"Invalid JSON in file {path}: {e}\") from e\n\n\ndef save_json(path: Union[str, Path], data: Any) -> None:\n    \"\"\"Save data to a JSON file.\n\n    Args:\n        path (str | Path): Path to the JSON file.\n        data (Any): Data to save (must be JSON serializable).\n\n    Raises:\n        ValueError: If the data is not JSON serializable.\n    \"\"\"\n    path = Path(path)\n\n    # Create parent directories if they don't exist\n    path.parent.mkdir(parents=True, exist_ok=True)\n\n    try:\n        with path.open(\"w\", encoding=\"utf-8\") as f:\n            json.dump(data, f, indent=2, ensure_ascii=False)\n    except (TypeError, ValueError) as e:\n        raise ValueError(f\"Data is not JSON serializable: {e}\") from e\n"
  },
  {
    "path": "kimodo/viz/__init__.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Viser-based 3D visualization for skeletons and motion.\"\"\"\n\nfrom . import viser_utils\nfrom .viser_utils import (\n    Character,\n    CharacterMotion,\n    ConstraintSet,\n    EEJointsKeyframeSet,\n    FullbodyKeyframeSet,\n    GuiElements,\n    RootKeyframe2DSet,\n    SkeletonMesh,\n    WaypointMesh,\n    load_example_cases,\n)\n\n__all__ = [\n    \"Character\",\n    \"CharacterMotion\",\n    \"ConstraintSet\",\n    \"EEJointsKeyframeSet\",\n    \"FullbodyKeyframeSet\",\n    \"GuiElements\",\n    \"RootKeyframe2DSet\",\n    \"SkeletonMesh\",\n    \"WaypointMesh\",\n    \"load_example_cases\",\n    \"viser_utils\",\n]\n"
  },
  {
    "path": "kimodo/viz/constraint_ui.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Constraint visualization and frame indexing for the viz UI.\"\"\"\n\nfrom typing import List, Optional\n\nimport numpy as np\nimport torch\n\nimport viser\nimport viser.transforms as tf\nfrom kimodo.motion_rep.smooth_root import get_smooth_root_pos\nfrom kimodo.skeleton import SkeletonBase\nfrom kimodo.tools import to_numpy, to_torch\n\nfrom .scene import SkeletonMesh, WaypointMesh\n\n\ndef update_interval(interval_start, interval_end, start_frame_idx, end_frame_idx):\n    \"\"\"Updates an interval after removing the range from start_frame_idx to end_frame_idx.\"\"\"\n    # Calculate new range after removing [start_frame_idx, end_frame_idx]\n    # Case 1: Removal fully contains the interval -> delete entirely\n    if start_frame_idx <= interval_start and end_frame_idx >= interval_end:\n        return None, None  # Already removed, don't recreate\n    # Case 2: Removal is at the start of interval -> shrink from start\n    elif start_frame_idx <= interval_start and end_frame_idx < interval_end:\n        new_start = end_frame_idx + 1\n        new_end = interval_end\n    # Case 3: Removal is at the end of interval -> shrink from end\n    elif start_frame_idx > interval_start and end_frame_idx >= interval_end:\n        new_start = interval_start\n        new_end = start_frame_idx - 1\n    # Case 4: Removal is in the middle -> keep the larger portion\n    else:  # start_frame_idx > interval_start and end_frame_idx < interval_end\n        left_size = start_frame_idx - interval_start\n        right_size = interval_end - end_frame_idx\n        if left_size >= right_size:\n            new_start = interval_start\n            new_end = start_frame_idx - 1\n        else:\n            new_start = end_frame_idx + 1\n            new_end = interval_end\n    return new_start, new_end\n\n\nclass ConstraintSet:\n    def 
__init__(\n        self,\n        name: str,\n        server: viser.ViserServer,\n        skeleton: SkeletonBase,\n        display_name: Optional[str] = None,\n    ):\n        self.name = name\n        self.server = server\n        self.skeleton = skeleton\n        self.display_name = display_name if display_name is not None else name\n\n        self.keyframes = dict()  # frame_idx -> poses\n        self.frame2keyid = dict()  # frame_idx -> list of keyframe ids at this frame\n        self.scene_elements = dict()  # frame_idx -> meshes, labels, etc.\n        self.interval_labels = dict()  # (start_frame_idx, end_frame_idx) -> interval_label\n        self.labels_visible = True\n\n    def set_label_visibility(self, visible: bool) -> None:\n        \"\"\"Show or hide constraint labels without deleting them.\"\"\"\n        self.labels_visible = visible\n        for scene_data in self.scene_elements.values():\n            label = scene_data.get(\"label\")\n            if label is not None:\n                label.visible = visible\n        for interval_label in self.interval_labels.values():\n            interval_label.visible = visible\n\n    def set_overlay_visibility(self, only_frame: Optional[int] = None) -> None:\n        \"\"\"Show all overlay elements, or only those at the given frame.\n\n        Args:\n            only_frame: If None, show all overlays. If int, show only overlays at that frame.\n        \"\"\"\n        raise NotImplementedError(\"Subclasses must implement this method\")\n\n    def add_keyframe(self, keyframe_id: str, frame_idx: int, pose_data: torch.Tensor):\n        \"\"\"Adds a single keyframe at the given frame with the given pose data.\n\n        Args:\n            keyframe_id: str, id for the keyframe. Must be unique within the given frame_idx.\n            frame_idx: int, frame index to add the keyframe at\n            pose_data: torch.Tensor, e.g. 
full-body pose, EE pose, 2D root pose, etc.\n        \"\"\"\n        raise NotImplementedError(\"Subclasses must implement this method\")\n\n    def add_interval(\n        self,\n        interval_id: str,\n        start_frame_idx: int,\n        end_frame_idx: int,\n        pose_seq_data: torch.Tensor,\n    ):\n        \"\"\"Adds a keyframe interval between the given start and end frames with the given pose data.\n\n        Args:\n            interval_id: str, id for the interval. Must be unique within the given start_frame_idx and end_frame_idx.\n            start_frame_idx: int, start frame index of the interval\n            end_frame_idx: int, end frame index of the interval\n            pose_seq_data: torch.Tensor, data for constrained interval, e.g. full-body poses, EE poses, 2D root poses, etc.\n        \"\"\"\n        raise NotImplementedError(\"Subclasses must implement this method\")\n\n    def _add_interval_label(self, start_frame_idx: int, end_frame_idx: int):\n        \"\"\"\n        Adds an interval label between the given start and end frames\n        Args:\n            start_frame_idx: int, start frame index of the interval\n            end_frame_idx: int, end frame index of the interval\n        \"\"\"\n        mid = int((start_frame_idx + end_frame_idx) / 2)\n        interval_label_pos = self._get_label_pos(mid)\n        interval_label = self.server.scene.add_label(\n            name=f\"/{self.name}/interval_label_{start_frame_idx}_{end_frame_idx}\",\n            text=f\"{self.display_name} @ [{start_frame_idx}, {end_frame_idx}]\",\n            position=interval_label_pos,\n            font_size_mode=\"screen\",\n            font_screen_scale=0.7,\n            anchor=\"center-center\",\n        )\n        interval_label.visible = self.labels_visible\n        self.interval_labels[(start_frame_idx, end_frame_idx)] = interval_label\n\n    def remove_keyframe(self, keyframe_id: str, frame_idx: int):\n        \"\"\"\n        Removes a keyframe at the 
given frame\n        Args:\n            keyframe_id: str, id for the keyframe to remove\n            frame_idx: int, frame index to remove the keyframe at\n        \"\"\"\n        raise NotImplementedError(\"Subclasses must implement this method\")\n\n    def remove_interval(self, interval_id: str, start_frame_idx: int, end_frame_idx: int):\n        \"\"\"\n        Removes an interval between the given start and end frames\n        Args:\n            interval_id: str, id for the interval to remove\n            start_frame_idx: int, start frame index of the interval\n            end_frame_idx: int, end frame index of the interval\n        \"\"\"\n        raise NotImplementedError(\"Subclasses must implement this method\")\n\n    def _get_label_pos(self, frame_idx: int):\n        \"\"\"\n        Returns the position of where to place the displayed label for the given frame index\n        Args:\n            frame_idx: int, frame index to get the label position for\n        \"\"\"\n        raise NotImplementedError(\"Subclasses must implement this method\")\n\n    def _remove_interval_and_update_label(self, interval_id: str, start_frame_idx: int, end_frame_idx: int):\n        \"\"\"\n        Removes an interval between the given start and end frames and updates the label\n        Args:\n            start_frame_idx: int, start frame index of the interval\n            end_frame_idx: int, end frame index of the interval\n        \"\"\"\n        for frame_idx in range(start_frame_idx, end_frame_idx + 1):\n            self.remove_keyframe(interval_id, frame_idx)\n\n        # Update interval labels that overlap with the removed range\n        intervals_to_update = []\n        for (interval_start, interval_end), label in list(self.interval_labels.items()):\n            # Check if intervals overlap\n            if interval_start <= end_frame_idx and interval_end >= start_frame_idx:\n                intervals_to_update.append((interval_start, interval_end, label))\n\n        
for interval_start, interval_end, label in intervals_to_update:\n            # Remove old label from scene and dict\n            self.server.scene.remove_by_name(label.name)\n            del self.interval_labels[(interval_start, interval_end)]\n\n            new_start, new_end = update_interval(interval_start, interval_end, start_frame_idx, end_frame_idx)\n\n            if new_start is None or new_end is None:\n                continue\n\n            # Create updated label with new range\n            if new_start <= new_end:\n                # Position label at midpoint - these keyframes are guaranteed to exist\n                # since the new range is outside the removal range\n                mid_frame = (new_start + new_end) // 2\n                label_pos = self._get_label_pos(mid_frame)\n                new_label = self.server.scene.add_label(\n                    name=f\"/{self.name}/interval_label_{new_start}_{new_end}\",\n                    text=f\"{self.display_name} @ [{new_start}, {new_end}]\",\n                    position=label_pos,\n                    font_size_mode=\"screen\",\n                    font_screen_scale=0.7,\n                    anchor=\"center-center\",\n                )\n                new_label.visible = self.labels_visible\n                self.interval_labels[(new_start, new_end)] = new_label\n\n    def get_constraint_info(self, device: Optional[str] = None):\n        \"\"\"Returns constraint information for generation (torch) or UI (numpy).\"\"\"\n        raise NotImplementedError(\"Subclasses must implement this method\")\n\n    def get_frame_idx(self):\n        \"\"\"Returns all constrained frame indices in the set.\"\"\"\n        return [frame_idx for frame_idx in list(self.keyframes.keys())]\n\n    def clear(self, frame_idx: Optional[int] = None):\n        \"\"\"\n        Clears all keyframes and intervals from the constraint set\n        Args:\n            frame_idx: int, sing frame index to clear if given\n        \"\"\"\n 
       raise NotImplementedError(\"Subclasses must implement this method\")\n\n\ndef build_constraint_set_table_markdown(constraint_list: List[ConstraintSet]):\n    markdown = \"| Track | Frame Num |\\n\"\n    markdown += \"|------|----------|\\n\"\n\n    # Sort constraints by frame_idx\n    for constraint in constraint_list:\n        frame_info = constraint.get_frame_idx()\n        if len(frame_info) > 0:\n            frame_info = \", \".join([str(frame) for frame in sorted(frame_info)])\n        else:\n            frame_info = \"-\"\n        markdown += f\"| {constraint.display_name} | {frame_info} |\\n\"\n\n    return markdown\n\n\nclass FullbodyKeyframeSet(ConstraintSet):\n    def __init__(\n        self,\n        name: str,\n        server: viser.ViserServer,\n        skeleton: SkeletonBase,\n        display_name: Optional[str] = None,\n    ):\n        super().__init__(name, server, skeleton, display_name=display_name)\n\n    def add_keyframe(\n        self,\n        keyframe_id: str,\n        frame_idx: int,\n        joints_pos: torch.Tensor | np.ndarray,\n        joints_rot: torch.Tensor | np.ndarray,\n        viz_label: bool = True,\n        exists_ok: bool = False,\n    ):\n        \"\"\"Adds a single full-body keyframe at the given frame or updates the existing one at this\n        frame. Note if a keyframe already exists at this frame, it will be updated to the given\n        pose.\n\n        Args:\n            keyframe_id: str, id for the keyframe. 
Must be unique within the given frame_idx.\n            frame_idx: int, frame index to add the keyframe at\n            joints_pos: torch.Tensor, [J, 3] joints positions to add the keyframe at\n        \"\"\"\n        # create/update scene elements\n        if frame_idx in self.keyframes:\n            skeleton_mesh = self.scene_elements[frame_idx][\"skeleton_mesh\"]\n            skeleton_mesh.set_pose(to_torch(joints_pos))\n            if viz_label and \"label\" in self.scene_elements[frame_idx]:\n                label = self.scene_elements[frame_idx][\"label\"]\n                label.position = to_numpy(joints_pos)[self.skeleton.root_idx]\n                label.visible = self.labels_visible\n        else:\n            # create skeleton to visualize the full-body constraint\n            skeleton_mesh = SkeletonMesh(\n                f\"/{self.name}/skeleton_{frame_idx}\",\n                self.server,\n                self.skeleton,\n                joint_color=(255, 235, 0),\n                bone_color=(255, 0, 0),\n                starting_joints_pos=to_torch(joints_pos),\n            )\n            self.scene_elements[frame_idx] = {\n                \"skeleton_mesh\": skeleton_mesh,\n            }\n            if viz_label:\n                label = self.server.scene.add_label(\n                    name=f\"/{self.name}/label_{frame_idx}\",\n                    text=f\"{self.display_name} @ {frame_idx}\",\n                    position=to_numpy(joints_pos)[self.skeleton.root_idx],\n                    font_size_mode=\"screen\",\n                    font_screen_scale=0.7,\n                    anchor=\"center-center\",\n                )\n                label.visible = self.labels_visible\n                self.scene_elements[frame_idx][\"label\"] = label\n\n        # set/update data\n        self.keyframes[frame_idx] = {\n            \"joints_pos\": to_numpy(joints_pos),\n            \"joints_rot\": to_numpy(joints_rot),\n        }\n\n        if frame_idx not in 
self.frame2keyid:\n            self.frame2keyid[frame_idx] = []\n\n        if keyframe_id in self.frame2keyid[frame_idx]:\n            if not exists_ok:\n                raise AssertionError(\"keyframe_id already exists in this frame!\")\n        else:\n            self.frame2keyid[frame_idx].append(keyframe_id)\n\n    def add_interval(\n        self,\n        interval_id: str,\n        start_frame_idx: int,\n        end_frame_idx: int,\n        joints_pos: torch.Tensor,\n        joints_rot: torch.Tensor,\n    ):\n        \"\"\"Adds a full-body keyframe interval between the given start and end frames.\n\n        Args:\n            start_frame_idx: int, start frame index of the interval\n            end_frame_idx: int, end frame index of the interval\n            joints_pos: torch.Tensor, [T, J, 3] joints positions within the interval\n        \"\"\"\n        assert joints_pos.shape[0] == end_frame_idx - start_frame_idx + 1\n        for frame_idx in range(start_frame_idx, end_frame_idx + 1):\n            rel_idx = frame_idx - start_frame_idx\n            self.add_keyframe(\n                interval_id,\n                frame_idx,\n                joints_pos[rel_idx],\n                joints_rot[rel_idx],\n                viz_label=False,\n            )\n\n        # add separate interval label\n        self._add_interval_label(start_frame_idx, end_frame_idx)\n\n    def remove_keyframe(self, keyframe_id: str, frame_idx: int):\n        if frame_idx not in self.keyframes:\n            return\n        if keyframe_id not in self.frame2keyid[frame_idx]:\n            return\n        self.frame2keyid[frame_idx].remove(keyframe_id)\n        if len(self.frame2keyid[frame_idx]) == 0:\n            del self.frame2keyid[frame_idx]\n            self.clear(frame_idx)\n\n    def _get_label_pos(self, frame_idx: int):\n        return self.keyframes[frame_idx][\"joints_pos\"][self.skeleton.root_idx]\n\n    def remove_interval(self, interval_id: str, start_frame_idx: int, end_frame_idx: 
int):\n        self._remove_interval_and_update_label(interval_id, start_frame_idx, end_frame_idx)\n\n    def get_constraint_info(self, device: Optional[str] = None):\n        all_joints_pos = []\n        all_joints_rot = []\n        for v in self.keyframes.values():\n            joints_pos = to_torch(v[\"joints_pos\"], device=device)\n            joints_rot = to_torch(v[\"joints_rot\"], device=device)\n            if len(joints_pos.shape) == 2:\n                all_joints_pos.append(joints_pos[None])\n            else:\n                all_joints_pos.append(joints_pos)\n            if len(joints_rot.shape) == 3:\n                all_joints_rot.append(joints_rot[None])\n            else:\n                all_joints_rot.append(joints_rot)\n\n        all_joints_pos = torch.cat(all_joints_pos, dim=0) if len(all_joints_pos) > 0 else None\n        all_joints_rot = torch.cat(all_joints_rot, dim=0) if len(all_joints_rot) > 0 else None\n\n        return {\n            \"frame_idx\": self.get_frame_idx(),\n            \"joints_pos\": all_joints_pos,\n            \"joints_rot\": all_joints_rot,\n        }\n\n    def clear(self, frame_idx: Optional[int] = None):\n        frame_idx_list = list(self.keyframes.keys()) if frame_idx is None else [frame_idx]\n        for fidx in frame_idx_list:\n            self.scene_elements[fidx][\"skeleton_mesh\"].clear()\n            if \"ee_rotation_axes\" in self.scene_elements[fidx]:\n                self.server.scene.remove_by_name(self.scene_elements[fidx][\"ee_rotation_axes\"].name)\n            if \"label\" in self.scene_elements[fidx]:\n                self.server.scene.remove_by_name(self.scene_elements[fidx][\"label\"].name)\n\n            self.keyframes.pop(fidx)\n            self.scene_elements.pop(fidx)\n            self.frame2keyid.pop(fidx, None)\n\n        if frame_idx is None:\n            # clear all interval labels if clearing all keyframes\n            for interval_label in list(self.interval_labels.values()):\n             
   self.server.scene.remove_by_name(interval_label.name)\n            self.interval_labels.clear()\n            self.frame2keyid.clear()\n\n    def set_overlay_visibility(self, only_frame: Optional[int] = None) -> None:\n        show_all = only_frame is None\n        for fidx, scene_data in self.scene_elements.items():\n            visible = show_all or fidx == only_frame\n            scene_data[\"skeleton_mesh\"].set_visibility(visible)\n            label = scene_data.get(\"label\")\n            if label is not None:\n                label.visible = visible and self.labels_visible\n        for interval_label in self.interval_labels.values():\n            interval_label.visible = show_all and self.labels_visible\n\n\nclass EEJointsKeyframeSet(ConstraintSet):\n    def __init__(\n        self,\n        name: str,\n        server: viser.ViserServer,\n        skeleton: SkeletonBase,\n        display_name: Optional[str] = None,\n    ):\n        super().__init__(name, server, skeleton, display_name=display_name)\n\n        # frame_idx -> list of (keyframe_id, joint_names) at this frame\n        self.frame2keyid = dict()\n\n    def create_scene_elements(\n        self,\n        frame_idx: int,\n        joints_pos: torch.Tensor | np.ndarray,\n        joints_rot: Optional[torch.Tensor | np.ndarray],\n        joint_names: List[str],\n        viz_label: bool = True,\n    ):\n        # create skeleton to visualize the full-body constraint\n        ee_joint_indices = []\n        ee_gizmo_indices = []\n        constrained_bone_idx = []\n        for joint_name in joint_names:\n            if joint_name == \"Hips\":\n                continue\n            elif joint_name in [\"LeftHand\", \"RightHand\", \"LeftFoot\", \"RightFoot\"]:\n                expanded_joint_names = {\n                    \"LeftHand\": self.skeleton.left_hand_joint_names,\n                    \"RightHand\": self.skeleton.right_hand_joint_names,\n                    \"LeftFoot\": 
self.skeleton.left_foot_joint_names,\n                    \"RightFoot\": self.skeleton.right_foot_joint_names,\n                }[joint_name]\n                ee_joint_indices.extend([self.skeleton.bone_order_names_index[joint] for joint in expanded_joint_names])\n                if len(expanded_joint_names) > 1:\n                    ee_gizmo_indices.extend(\n                        [self.skeleton.bone_order_names_index[joint] for joint in expanded_joint_names[:1]]\n                    )\n                constrained_bone_idx.extend(\n                    [self.skeleton.bone_order_names_index[joint] - 1 for joint in expanded_joint_names[1:]]\n                )\n            else:\n                raise ValueError(f\"Invalid joint name: {joint_name}\")\n\n        # de-duplicate while preserving order\n        ee_joint_indices = list(dict.fromkeys(ee_joint_indices))\n        ee_gizmo_indices = list(dict.fromkeys(ee_gizmo_indices))\n        constrained_bone_idx = list(dict.fromkeys(constrained_bone_idx))\n\n        constrained_idx = [self.skeleton.root_idx] + ee_joint_indices\n\n        constrained_idx = np.array(constrained_idx)\n        constrained_bone_idx = np.array(constrained_bone_idx)\n\n        # create skeleton to visualize the full-body constraint\n        joint_color = np.full((self.skeleton.nbjoints, 3), (220, 220, 220))\n        bone_color = np.full((self.skeleton.nbjoints - 1, 3), (220, 220, 220))\n        # color constrained joints differently\n        joint_color[constrained_idx] = (255, 0, 0)\n        bone_color[constrained_bone_idx] = (255, 0, 0)\n        skeleton_mesh = SkeletonMesh(\n            f\"/{self.name}/skeleton_{frame_idx}\",\n            self.server,\n            self.skeleton,\n            joint_color=joint_color,\n            bone_color=bone_color,\n            starting_joints_pos=to_torch(joints_pos),\n        )\n\n        self.scene_elements[frame_idx] = {\n            \"skeleton_mesh\": skeleton_mesh,\n        }\n        joints_pos_np = 
to_numpy(joints_pos)\n        joints_rot_np = to_numpy(joints_rot) if joints_rot is not None else None\n        if joints_rot_np is not None and len(ee_gizmo_indices) > 0:\n            ee_axes = self.server.scene.add_batched_axes(\n                f\"/{self.name}/ee_rot_axes_{frame_idx}\",\n                batched_wxyzs=tf.SO3.from_matrix(joints_rot_np[ee_gizmo_indices]).wxyz,\n                batched_positions=joints_pos_np[ee_gizmo_indices],\n                axes_length=0.07,\n                axes_radius=0.007,\n            )\n            self.scene_elements[frame_idx][\"ee_rotation_axes\"] = ee_axes\n        if viz_label:\n            label = self.server.scene.add_label(\n                name=f\"/{self.name}/label_{frame_idx}\",\n                text=f\"{self.display_name} @ {frame_idx}\",\n                position=joints_pos_np[self.skeleton.root_idx] + np.array([0.0, 0.05, 0.0]),\n                font_size_mode=\"screen\",\n                font_screen_scale=0.7,\n                anchor=\"bottom-center\",\n            )\n            label.visible = self.labels_visible\n            self.scene_elements[frame_idx][\"label\"] = label\n\n    def add_keyframe(\n        self,\n        keyframe_id: str,\n        frame_idx: int,\n        joints_pos: torch.Tensor | np.ndarray,\n        joints_rot: torch.Tensor | np.ndarray,\n        joint_names: List[str],\n        end_effector_type: str,\n        viz_label: bool = True,\n        exists_ok: bool = False,\n    ):\n        \"\"\"Adds a single EE keyframe at the given frame or updates the existing one at this frame.\n\n        Args:\n            keyframe_id: str, id for the keyframe. 
Must be unique within the given frame_idx.\n            frame_idx: int, frame index to add the keyframe at\n            joints_pos: torch.Tensor, [J, 3] joints positions to add the keyframe at\n            joints_rot: torch.Tensor, [J, 3, 3] joints rotation matrices to add the keyframe at\n            joint_names: List[str], names of the joints to add the keyframe at\n        \"\"\"\n        need_create_viz = True\n        joint_names_input = joint_names\n\n        if not isinstance(end_effector_type, set):\n            end_effector_type = set([end_effector_type])\n\n        # create/update scene elements\n        if frame_idx in self.keyframes:\n            if joint_names != self.keyframes[frame_idx][\"joint_names\"]:\n                # merge together with existing constraint if needed\n                joint_names = set(joint_names)\n                joint_names.update(set(self.keyframes[frame_idx][\"joint_names\"]))\n                joint_names = list(joint_names)\n                end_effector_type.update(self.keyframes[frame_idx][\"end_effector_type\"])\n                # need to re-create viz elements\n                self.clear(frame_idx)\n            else:\n                need_create_viz = False\n                # overwrite the pose with the latest one\n                skeleton_mesh = self.scene_elements[frame_idx][\"skeleton_mesh\"]\n                skeleton_mesh.set_pose(to_torch(joints_pos))\n                if \"ee_rotation_axes\" in self.scene_elements[frame_idx]:\n                    ee_gizmo_indices = []\n                    for joint_name in joint_names:\n                        if joint_name == \"Hips\":\n                            continue\n                        elif joint_name in [\n                            \"LeftHand\",\n                            \"RightHand\",\n                            \"LeftFoot\",\n                            \"RightFoot\",\n                        ]:\n                            expanded_joint_names = {\n            
                    \"LeftHand\": self.skeleton.left_hand_joint_names,\n                                \"RightHand\": self.skeleton.right_hand_joint_names,\n                                \"LeftFoot\": self.skeleton.left_foot_joint_names,\n                                \"RightFoot\": self.skeleton.right_foot_joint_names,\n                            }[joint_name]\n                            if len(expanded_joint_names) > 0:\n                                ee_gizmo_indices.extend(\n                                    [self.skeleton.bone_order_names_index[joint] for joint in expanded_joint_names[:1]]\n                                    # take only the base joint of the end effector (to avoid clutter)\n                                )\n                        else:\n                            raise ValueError(f\"Invalid joint name: {joint_name}\")\n                    ee_gizmo_indices = list(dict.fromkeys(ee_gizmo_indices))\n                    if len(ee_gizmo_indices) > 0:\n                        ee_axes = self.scene_elements[frame_idx][\"ee_rotation_axes\"]\n                        joints_pos_np = to_numpy(joints_pos)\n                        joints_rot_np = to_numpy(joints_rot)\n                        ee_axes.batched_positions = joints_pos_np[ee_gizmo_indices]\n                        ee_axes.batched_wxyzs = tf.SO3.from_matrix(joints_rot_np[ee_gizmo_indices]).wxyz\n                if viz_label and \"label\" in self.scene_elements[frame_idx]:\n                    label = self.scene_elements[frame_idx][\"label\"]\n                    label.position = to_numpy(joints_pos)[self.skeleton.root_idx]\n                    label.visible = self.labels_visible\n\n        if need_create_viz:\n            self.create_scene_elements(frame_idx, joints_pos, joints_rot, joint_names, viz_label=viz_label)\n\n        # set/update data\n        self.keyframes[frame_idx] = {\n            \"joints_pos\": to_numpy(joints_pos),\n            \"joints_rot\": to_numpy(joints_rot),\n 
           \"joint_names\": joint_names,\n            \"end_effector_type\": end_effector_type,\n        }\n\n        if frame_idx not in self.frame2keyid:\n            self.frame2keyid[frame_idx] = []\n\n        known_keyframe_ids = {k: idx for idx, (k, _) in enumerate(self.frame2keyid[frame_idx])}\n\n        if keyframe_id in known_keyframe_ids.keys():\n            if not exists_ok:\n                raise AssertionError(\"keyframe_id already exists in this frame!\")\n            idx = known_keyframe_ids[keyframe_id]\n            # override previous exisiting keyframe\n            self.frame2keyid[frame_idx][idx] = (keyframe_id, joint_names_input)\n        else:\n            # track which subset of joints are constrained by this keyframe_id\n            self.frame2keyid[frame_idx].append((keyframe_id, joint_names_input))\n\n    def add_interval(\n        self,\n        interval_id: str,\n        start_frame_idx: int,\n        end_frame_idx: int,\n        joints_pos: torch.Tensor | np.ndarray,\n        joints_rot: torch.Tensor | np.ndarray,\n        joint_names: List[str],\n        end_effector_type: str,\n    ):\n        \"\"\"Adds an interval of EE keyframes at the given frame or updates the existing one at this\n        frame.\n\n        Args:\n            interval_id: str, id for the interval. 
Must be unique within the given start_frame_idx and end_frame_idx.\n            start_frame_idx: int, start frame index to add the interval at\n            end_frame_idx: int, end frame index to add the interval at\n            joints_pos: torch.Tensor, [T, J, 3] joints positions to add the interval at\n            joints_rot: torch.Tensor, [T, J, 3, 3] joints rotation matrices to add the interval at\n            joint_names: List[str], names of the joints to add for the entire interval\n        \"\"\"\n        num_frames = end_frame_idx - start_frame_idx + 1\n        joints_pos_np = to_numpy(joints_pos)\n        joints_rot_np = to_numpy(joints_rot)\n        assert joints_pos_np.shape[0] == num_frames\n        assert joints_rot_np.shape[0] == num_frames\n\n        for frame_idx in range(start_frame_idx, end_frame_idx + 1):\n            rel_idx = frame_idx - start_frame_idx\n            self.add_keyframe(\n                interval_id,\n                frame_idx,\n                joints_pos_np[rel_idx],\n                joints_rot_np[rel_idx],\n                joint_names,\n                end_effector_type,\n                viz_label=False,\n            )\n        self._add_interval_label(start_frame_idx, end_frame_idx)\n\n    def remove_keyframe(self, keyframe_id: str, frame_idx: int):\n        \"\"\"Removes a keyframe at the given frame or updates the existing one at this frame by\n        removing the specified joints.\n\n        Args:\n            keyframe_id: str, id for the keyframe to remove. 
This determines which joints to remove.\n            frame_idx: int, frame index to remove the keyframe at\n        \"\"\"\n        if frame_idx not in self.keyframes:\n            return\n\n        remaining_joint_names = set()\n        delete_idx = None\n        for i, (keyid, joint_names) in enumerate(self.frame2keyid[frame_idx]):\n            if keyid == keyframe_id:\n                delete_idx = i\n            else:\n                remaining_joint_names.update(joint_names)\n        if delete_idx is None:\n            # this keyframe_id is not in the specified frame\n            return\n\n        self.frame2keyid[frame_idx].pop(delete_idx)\n        if len(remaining_joint_names) == 0:\n            # no more keyframes in this frame, clear the frame\n            del self.frame2keyid[frame_idx]\n            self.clear(frame_idx)\n            return\n\n        # only deleting part of keyframe (potentially some subset of joints)\n        # delete the old visualization and add a new one with the updated joint set\n        new_joint_names = list(remaining_joint_names)\n        self.clear(frame_idx, scene_elements_only=True)\n        joints_pos = self.keyframes[frame_idx][\"joints_pos\"]\n        joints_rot = self.keyframes[frame_idx][\"joints_rot\"]\n        self.create_scene_elements(frame_idx, joints_pos, joints_rot, new_joint_names)\n        self.keyframes[frame_idx][\"joint_names\"] = new_joint_names\n\n    def _get_label_pos(self, frame_idx: int):\n        return self.keyframes[frame_idx][\"joints_pos\"][self.skeleton.root_idx]\n\n    def remove_interval(self, interval_id: str, start_frame_idx: int, end_frame_idx: int):\n        self._remove_interval_and_update_label(interval_id, start_frame_idx, end_frame_idx)\n\n    def get_constraint_info(self, device: Optional[str] = None):\n        all_joints_pos = []\n        all_joints_rot = []\n        all_joints_names = []\n        all_end_effector_type = []\n        for v in self.keyframes.values():\n            
joints_pos = to_torch(v[\"joints_pos\"], device=device)\n            joints_rot = to_torch(v[\"joints_rot\"], device=device)\n            if len(joints_pos.shape) == 2:\n                all_joints_pos.append(joints_pos[None])\n            else:\n                all_joints_pos.append(joints_pos)\n            if len(joints_rot.shape) == 3:\n                all_joints_rot.append(joints_rot[None])\n            else:\n                all_joints_rot.append(joints_rot)\n            all_joints_names.append(v[\"joint_names\"])\n            all_end_effector_type.append(v[\"end_effector_type\"])\n\n        all_joints_pos = torch.cat(all_joints_pos, dim=0) if len(all_joints_pos) > 0 else None\n        all_joints_rot = torch.cat(all_joints_rot, dim=0) if len(all_joints_rot) > 0 else None\n\n        return {\n            \"frame_idx\": self.get_frame_idx(),\n            \"joints_pos\": all_joints_pos,\n            \"joints_rot\": all_joints_rot,\n            \"joint_names\": all_joints_names,\n            \"end_effector_type\": all_end_effector_type,\n        }\n\n    def clear(self, frame_idx: Optional[int] = None, scene_elements_only: bool = False):\n        frame_idx_list = list(self.keyframes.keys()) if frame_idx is None else [frame_idx]\n        for fidx in frame_idx_list:\n            self.scene_elements[fidx][\"skeleton_mesh\"].clear()\n            if \"ee_rotation_axes\" in self.scene_elements[fidx]:\n                self.server.scene.remove_by_name(self.scene_elements[fidx][\"ee_rotation_axes\"].name)\n            if \"label\" in self.scene_elements[fidx]:\n                self.server.scene.remove_by_name(self.scene_elements[fidx][\"label\"].name)\n            self.scene_elements.pop(fidx)\n            if not scene_elements_only:\n                self.keyframes.pop(fidx)\n\n        if frame_idx is None:\n            # clear all interval labels if clearing all keyframes\n            for interval_label in list(self.interval_labels.values()):\n                
self.server.scene.remove_by_name(interval_label.name)\n            self.interval_labels.clear()\n\n    def set_overlay_visibility(self, only_frame: Optional[int] = None) -> None:\n        show_all = only_frame is None\n        for fidx, scene_data in self.scene_elements.items():\n            visible = show_all or fidx == only_frame\n            scene_data[\"skeleton_mesh\"].set_visibility(visible)\n            if \"ee_rotation_axes\" in scene_data:\n                scene_data[\"ee_rotation_axes\"].visible = visible\n            label = scene_data.get(\"label\")\n            if label is not None:\n                label.visible = visible and self.labels_visible\n        for interval_label in self.interval_labels.values():\n            interval_label.visible = show_all and self.labels_visible\n\n\nclass RootKeyframe2DSet(ConstraintSet):\n    def __init__(\n        self,\n        name: str,\n        server: viser.ViserServer,\n        skeleton: SkeletonBase,\n        display_name: Optional[str] = None,\n    ):\n        super().__init__(name, server, skeleton, display_name=display_name)\n        self.dense_path = False\n        self.smooth_path = True\n        self.line_segments = None  # visualization of dense path\n        self.interval_line_segments = {}\n\n    def add_keyframe(\n        self,\n        keyframe_id: str,\n        frame_idx: int,\n        root_pos: torch.Tensor | np.ndarray,\n        viz_label: bool = True,\n        update_path: bool = True,\n        viz_waypoint: bool = True,\n        exists_ok: bool = False,\n    ):\n        \"\"\"Adds a single 2D root keyframe at the given frame or updates the existing one at this\n        frame.\n\n        Args:\n            keyframe_id: str, id for the keyframe. 
Must be unique within the given frame_idx.\n            frame_idx: int, frame index to add the keyframe at\n            root_pos: torch.Tensor, [3] root position to add the keyframe at, y entry (index 1) should be 0\n            viz_label: bool, whether to visualize the label for the keyframe\n        \"\"\"\n        root_pos_np = to_numpy(root_pos)\n        if frame_idx not in self.scene_elements:\n            self.scene_elements[frame_idx] = {}\n\n        scene_data = self.scene_elements[frame_idx]\n        if frame_idx in self.keyframes:\n            waypoint = scene_data.get(\"waypoint\")\n            if waypoint is not None:\n                waypoint.update_position(root_pos_np)\n            elif viz_waypoint:\n                waypoint = WaypointMesh(\n                    f\"/{self.name}/waypoint_{frame_idx}\",\n                    self.server,\n                    position=root_pos_np,\n                )\n                scene_data[\"waypoint\"] = waypoint\n\n            label = scene_data.get(\"label\")\n            if viz_label and label is not None:\n                label.position = root_pos_np\n                label.visible = self.labels_visible\n            elif viz_label and label is None:\n                label = self.server.scene.add_label(\n                    name=f\"/{self.name}/label_{frame_idx}\",\n                    text=f\"{self.display_name} @ {frame_idx}\",\n                    position=root_pos_np,\n                    font_size_mode=\"screen\",\n                    font_screen_scale=0.7,\n                    anchor=\"bottom-left\",\n                )\n                label.visible = self.labels_visible\n                scene_data[\"label\"] = label\n        else:\n            if viz_waypoint:\n                waypoint = WaypointMesh(\n                    f\"/{self.name}/waypoint_{frame_idx}\",\n                    self.server,\n                    position=root_pos_np,\n                )\n                scene_data[\"waypoint\"] = 
waypoint\n            if viz_label:\n                label = self.server.scene.add_label(\n                    name=f\"/{self.name}/label_{frame_idx}\",\n                    text=f\"{self.display_name} @ {frame_idx}\",\n                    position=root_pos_np,\n                    font_size_mode=\"screen\",\n                    font_screen_scale=0.7,\n                    anchor=\"bottom-left\",\n                )\n                label.visible = self.labels_visible\n                scene_data[\"label\"] = label\n\n        # set/update data\n        self.keyframes[frame_idx] = root_pos_np\n        if frame_idx not in self.frame2keyid:\n            self.frame2keyid[frame_idx] = []\n\n        if keyframe_id in self.frame2keyid[frame_idx]:\n            if not exists_ok:\n                raise AssertionError(\"keyframe_id already exists in this frame!\")\n        else:\n            self.frame2keyid[frame_idx].append(keyframe_id)\n\n        # need to update path visualization\n        if self.line_segments is not None and update_path:\n            self.update_line_segments()\n\n    def add_interval(\n        self,\n        interval_id: str,\n        start_frame_idx: int,\n        end_frame_idx: int,\n        root_pos: torch.Tensor | np.ndarray,\n    ):\n        \"\"\"Adds an interval of 2D root keyframes between the given start and end frames.\n\n        Args:\n            interval_id: str, id for the interval. 
Must be unique within the given start_frame_idx and end_frame_idx.\n            start_frame_idx: int, start frame index to add the interval at\n            end_frame_idx: int, end frame index to add the interval at\n            root_pos: torch.Tensor, [T, 3] root positions to add the interval at\n        \"\"\"\n        root_pos_np = to_numpy(root_pos)\n        assert root_pos_np.shape[0] == end_frame_idx - start_frame_idx + 1\n        if root_pos_np.shape[0] >= 2:\n            points = np.zeros((root_pos_np.shape[0] - 1, 2, 3))\n            points[:, 0] = root_pos_np[:-1]\n            points[:, 1] = root_pos_np[1:]\n            if interval_id in self.interval_line_segments:\n                self.server.scene.remove_by_name(self.interval_line_segments[interval_id].name)\n            self.interval_line_segments[interval_id] = self.server.scene.add_line_segments(\n                name=f\"/{self.name}/interval_{interval_id}_line\",\n                points=points,\n                colors=(255, 0, 0),\n                line_width=5.0,\n            )\n\n        for frame_idx in range(start_frame_idx, end_frame_idx + 1):\n            rel_idx = frame_idx - start_frame_idx\n            self.add_keyframe(\n                interval_id,\n                frame_idx,\n                root_pos_np[rel_idx],\n                viz_label=False,\n                update_path=False,\n                viz_waypoint=False,\n            )\n        self._add_interval_label(start_frame_idx, end_frame_idx)\n        if self.line_segments is not None:\n            self.update_line_segments()\n\n    def set_smooth_path(self, smooth_path: bool):\n        self.smooth_path = smooth_path\n        if self.line_segments is not None:\n            self.update_line_segments()\n\n    def set_dense_path(self, dense_path: bool):\n        \"\"\"If dense_path is True, will make the path dense by interpolated between added keyframes.\n\n        Args:\n            dense_path: bool, whether to make the path dense\n   
     \"\"\"\n        self.dense_path = dense_path\n        if self.dense_path:\n            # visualize dense path with line segments\n            self.line_segments = self.server.scene.add_line_segments(\n                name=f\"/{self.name}/line_segments\",\n                points=np.zeros((1, 2, 3)),\n                colors=(255, 0, 0),\n                line_width=5.0,\n            )\n            self.update_line_segments()\n        else:\n            if self.line_segments is not None:\n                self.server.scene.remove_by_name(self.line_segments.name)\n                self.line_segments = None\n\n    def interpolate_path(self, t: np.ndarray):\n        \"\"\"Interpolates the path between the given frame indices.\n\n        Args:\n            t: np.ndarray, frame indices to interpolate at\n        \"\"\"\n        from scipy.interpolate import interp1d\n\n        cur_info = self._get_sparse_constraint_info()\n        frame_idx = cur_info[\"frame_idx\"]\n        all_root_pos = cur_info[\"root_pos\"]\n\n        x = all_root_pos[:, 0]\n        z = all_root_pos[:, 2]\n\n        kind = \"linear\"\n        # if self.smooth_path and len(frame_idx) >= 3:\n        # kind = \"quadratic\"\n\n        interp_x = interp1d(frame_idx, x, kind=kind)\n        interp_z = interp1d(frame_idx, z, kind=kind)\n\n        x_new = interp_x(t)\n        z_new = interp_z(t)\n\n        path3d = np.stack([x_new, np.zeros_like(x_new), z_new], axis=1)\n\n        if self.smooth_path and len(frame_idx) >= 3:\n            path3d = get_smooth_root_pos(torch.from_numpy(path3d[None]))[0].numpy()\n        return path3d\n\n    def update_line_segments(self):\n        if len(self.keyframes) < 2:\n            return\n\n        t = np.array(sorted(self.get_frame_idx()))\n        if self.smooth_path:\n            # more points for smoothed curve\n            t = np.linspace(t[0], t[-1], 100)\n\n        path3d = self.interpolate_path(t)\n\n        points = np.zeros((len(path3d) - 1, 2, 3))\n        
points[:, 0] = path3d[:-1]\n        points[:, 1] = path3d[1:]\n\n        self.line_segments.points = points\n\n    def remove_keyframe(self, keyframe_id: str, frame_idx: int):\n        if frame_idx not in self.keyframes:\n            return\n        if keyframe_id not in self.frame2keyid[frame_idx]:\n            return\n        self.frame2keyid[frame_idx].remove(keyframe_id)\n        if len(self.frame2keyid[frame_idx]) == 0:\n            del self.frame2keyid[frame_idx]\n            self.clear(frame_idx)\n            if self.line_segments is not None:\n                self.update_line_segments()\n\n    def _get_label_pos(self, frame_idx: int):\n        return self.keyframes[frame_idx]\n\n    def remove_interval(self, interval_id: str, start_frame_idx: int, end_frame_idx: int):\n        if interval_id in self.interval_line_segments:\n            self.server.scene.remove_by_name(self.interval_line_segments[interval_id].name)\n            del self.interval_line_segments[interval_id]\n        self._remove_interval_and_update_label(interval_id, start_frame_idx, end_frame_idx)\n\n    def _get_sparse_constraint_info(self):\n        all_root_pos = []\n        for v in self.keyframes.values():\n            v_np = to_numpy(v)\n            if len(v_np.shape) == 1:\n                all_root_pos.append(v_np[None])\n            else:\n                all_root_pos.append(v_np)\n        if len(all_root_pos) > 0:\n            all_root_pos = np.concatenate(all_root_pos, axis=0)\n        else:\n            all_root_pos = None\n        return {\n            \"frame_idx\": self.get_frame_idx(),\n            \"root_pos\": all_root_pos,\n        }\n\n    def get_constraint_info(self, device: Optional[str] = None):\n        if not self.dense_path or len(self.keyframes) == 0:\n            info = self._get_sparse_constraint_info()\n            return {\n                \"frame_idx\": info[\"frame_idx\"],\n                \"root_pos\": to_torch(info[\"root_pos\"], device=device, 
dtype=torch.float32),\n            }\n        else:\n            frame_idx_list = self.get_frame_idx()\n            min_frame_idx = min(frame_idx_list)\n            max_frame_idx = max(frame_idx_list)\n            t = np.arange(min_frame_idx, max_frame_idx + 1)\n            path3d = self.interpolate_path(t)\n            return {\n                \"frame_idx\": t.tolist(),\n                \"root_pos\": to_torch(path3d, device=device, dtype=torch.float32),\n            }\n\n    def clear(self, frame_idx: Optional[int] = None):\n        frame_idx_list = list(self.keyframes.keys()) if frame_idx is None else [frame_idx]\n        for fidx in frame_idx_list:\n            scene_data = self.scene_elements.get(fidx, {})\n            waypoint = scene_data.get(\"waypoint\")\n            if waypoint is not None:\n                waypoint.clear()\n            label = scene_data.get(\"label\")\n            if label is not None:\n                self.server.scene.remove_by_name(label.name)\n\n            self.keyframes.pop(fidx)\n            self.scene_elements.pop(fidx)\n\n        if frame_idx is None:\n            # clear all interval labels if clearing all keyframes\n            for interval_label in list(self.interval_labels.values()):\n                self.server.scene.remove_by_name(interval_label.name)\n            self.interval_labels.clear()\n\n            # clear line segments if turning off dense path\n            if self.line_segments is not None:\n                self.server.scene.remove_by_name(self.line_segments.name)\n                self.line_segments = None\n\n            for interval_line in list(self.interval_line_segments.values()):\n                self.server.scene.remove_by_name(interval_line.name)\n            self.interval_line_segments.clear()\n\n    def set_overlay_visibility(self, only_frame: Optional[int] = None) -> None:\n        show_all = only_frame is None\n        for fidx, scene_data in self.scene_elements.items():\n            visible = 
show_all or fidx == only_frame\n            waypoint = scene_data.get(\"waypoint\")\n            if waypoint is not None:\n                waypoint.set_visible(visible)\n            label = scene_data.get(\"label\")\n            if label is not None:\n                label.visible = visible and self.labels_visible\n        if self.line_segments is not None:\n            self.line_segments.visible = show_all\n        for line_handle in self.interval_line_segments.values():\n            line_handle.visible = show_all\n        for interval_label in self.interval_labels.values():\n            interval_label.visible = show_all and self.labels_visible\n\n\n#\n# GUI Elements that need to be tracked\n"
  },
  {
    "path": "kimodo/viz/coords.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Pure numpy coordinate/rotation helpers for viz.\"\"\"\n\nimport numpy as np\n\n\ndef skew(v: np.ndarray) -> np.ndarray:\n    \"\"\"Skew-symmetric matrix for cross products: skew(v) @ x == np.cross(v, x).\"\"\"\n    vx, vy, vz = float(v[0]), float(v[1]), float(v[2])\n    return np.array([[0.0, -vz, vy], [vz, 0.0, -vx], [-vy, vx, 0.0]], dtype=np.float64)\n\n\ndef rotation_matrix_from_two_vec(v_from: np.ndarray, v_to: np.ndarray, eps: float = 1e-8) -> np.ndarray:\n    \"\"\"Return R such that R @ v_from ~= v_to (both treated as 3D vectors).\n\n    Uses a Rodrigues-style construction, with special handling for near-parallel and near-opposite\n    vectors for numerical stability.\n    \"\"\"\n    a = np.asarray(v_from, dtype=np.float64).reshape(3)\n    b = np.asarray(v_to, dtype=np.float64).reshape(3)\n    na = np.linalg.norm(a)\n    nb = np.linalg.norm(b)\n    if na < eps or nb < eps:\n        return np.eye(3, dtype=np.float64)\n    a = a / na\n    b = b / nb\n\n    c = float(np.clip(np.dot(a, b), -1.0, 1.0))  # cos(theta)\n    if c > 1.0 - eps:\n        return np.eye(3, dtype=np.float64)\n    if c < -1.0 + eps:\n        # 180 deg rotation about any axis orthogonal to a:\n        # R = -I + 2 * uu^T, where u is a unit axis orthogonal to a.\n        axis_seed = np.array([1.0, 0.0, 0.0], dtype=np.float64)\n        if abs(float(np.dot(a, axis_seed))) > 0.9:\n            axis_seed = np.array([0.0, 1.0, 0.0], dtype=np.float64)\n        u = np.cross(a, axis_seed)\n        u = u / np.linalg.norm(u).clip(min=eps)\n        return -np.eye(3, dtype=np.float64) + 2.0 * np.outer(u, u)\n\n    v = np.cross(a, b)\n    s2 = float(np.dot(v, v))  # ||v||^2 == sin^2(theta)\n    K = skew(v)\n    # R = I + K + K^2 * ((1 - c) / s^2)\n    return np.eye(3, dtype=np.float64) + K + (K @ K) * ((1.0 - c) / s2)\n"
  },
  {
    "path": "kimodo/viz/g1_rig.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"G1 robot rig: mesh loading, joint mapping, and viser scene setup for G1 skeleton.\"\"\"\n\nimport os\nimport xml.etree.ElementTree as ET\nfrom typing import Any, Optional, Tuple\n\nimport numpy as np\nimport trimesh\n\nimport viser\nimport viser.transforms as tf\nfrom kimodo.assets import skeleton_asset_path\nfrom kimodo.skeleton import G1Skeleton34\n\n# MuJoCo (z-up, x-forward) -> kimodo (y-up, z-forward)\nMUJOCO_TO_KIMODO = np.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]], dtype=np.float64)\n\n# MuJoCo (z-up, x-forward) -> kimodo (y-up, z-forward)\nMUJOCO_TO_KIMODO = np.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]], dtype=np.float64)\n\nG1_MESH_JOINT_MAP = {\n    \"pelvis_skel\": [\"pelvis.STL\", \"pelvis_contour_link.STL\"],\n    \"left_hip_pitch_skel\": [\"left_hip_pitch_link.STL\"],\n    \"left_hip_roll_skel\": [\"left_hip_roll_link.STL\"],\n    \"left_hip_yaw_skel\": [\"left_hip_yaw_link.STL\"],\n    \"left_knee_skel\": [\"left_knee_link.STL\"],\n    \"left_ankle_pitch_skel\": [\"left_ankle_pitch_link.STL\"],\n    \"left_ankle_roll_skel\": [\"left_ankle_roll_link.STL\"],\n    \"right_hip_pitch_skel\": [\"right_hip_pitch_link.STL\"],\n    \"right_hip_roll_skel\": [\"right_hip_roll_link.STL\"],\n    \"right_hip_yaw_skel\": [\"right_hip_yaw_link.STL\"],\n    \"right_knee_skel\": [\"right_knee_link.STL\"],\n    \"right_ankle_pitch_skel\": [\"right_ankle_pitch_link.STL\"],\n    \"right_ankle_roll_skel\": [\"right_ankle_roll_link.STL\"],\n    \"waist_yaw_skel\": [\"waist_yaw_link_rev_1_0.STL\", \"waist_yaw_link.STL\"],\n    \"waist_roll_skel\": [\"waist_roll_link_rev_1_0.STL\", \"waist_roll_link.STL\"],\n    \"waist_pitch_skel\": [\n        \"torso_link_rev_1_0.STL\",\n        \"torso_link.STL\",\n        \"logo_link.STL\",\n        \"head_link.STL\",\n    ],\n    
\"left_shoulder_pitch_skel\": [\"left_shoulder_pitch_link.STL\"],\n    \"left_shoulder_roll_skel\": [\"left_shoulder_roll_link.STL\"],\n    \"left_shoulder_yaw_skel\": [\"left_shoulder_yaw_link.STL\"],\n    \"left_elbow_skel\": [\"left_elbow_link.STL\"],\n    \"left_wrist_roll_skel\": [\"left_wrist_roll_link.STL\"],\n    \"left_wrist_pitch_skel\": [\"left_wrist_pitch_link.STL\"],\n    \"left_wrist_yaw_skel\": [\"left_wrist_yaw_link.STL\", \"left_rubber_hand.STL\"],\n    \"right_shoulder_pitch_skel\": [\"right_shoulder_pitch_link.STL\"],\n    \"right_shoulder_roll_skel\": [\"right_shoulder_roll_link.STL\"],\n    \"right_shoulder_yaw_skel\": [\"right_shoulder_yaw_link.STL\"],\n    \"right_elbow_skel\": [\"right_elbow_link.STL\"],\n    \"right_wrist_roll_skel\": [\"right_wrist_roll_link.STL\"],\n    \"right_wrist_pitch_skel\": [\"right_wrist_pitch_link.STL\"],\n    \"right_wrist_yaw_skel\": [\"right_wrist_yaw_link.STL\", \"right_rubber_hand.STL\"],\n}\n\n# Joint axis/limits from g1.xml (used by exports, e.g. 
MujocoQposConverter)\n_G1_JOINT_AXIS_INDEX_CACHE: Optional[dict[str, int]] = None\n_G1_JOINT_LIMITS_CACHE: Optional[dict[str, tuple[float, float]]] = None\n\n\ndef _get_g1_joint_axis_indices() -> dict[str, int]:\n    \"\"\"Return a map from G1 joint names to a single rotation axis index.\"\"\"\n    global _G1_JOINT_AXIS_INDEX_CACHE\n    if _G1_JOINT_AXIS_INDEX_CACHE is not None:\n        return _G1_JOINT_AXIS_INDEX_CACHE\n\n    xml_path = str(skeleton_asset_path(\"g1skel34\", \"xml\", \"g1.xml\"))\n    if not os.path.exists(xml_path):\n        _G1_JOINT_AXIS_INDEX_CACHE = {}\n        return _G1_JOINT_AXIS_INDEX_CACHE\n\n    tree = ET.parse(xml_path)\n    root = tree.getroot()\n\n    joint_axes = {}\n    for xml_class in tree.findall(\".//default\"):\n        if \"class\" not in xml_class.attrib:\n            continue\n        joint_nodes = xml_class.findall(\"joint\")\n        if joint_nodes:\n            joint_axes[xml_class.get(\"class\")] = joint_nodes[0].get(\"axis\")\n\n    axis_indices_by_name: dict[str, int] = {}\n    for joint in root.find(\"worldbody\").findall(\".//joint\"):\n        axis_str = joint.get(\"axis\") or joint_axes.get(joint.get(\"class\"))\n        if axis_str is None:\n            continue\n        axis_vals = np.array([float(x) for x in axis_str.split()], dtype=np.float64)\n        if not np.any(axis_vals):\n            continue\n        axis_kimodo = MUJOCO_TO_KIMODO @ axis_vals\n        axis_idx = int(np.argmax(np.abs(axis_kimodo)))\n        axis_indices_by_name[joint.get(\"name\").replace(\"_joint\", \"_skel\")] = axis_idx\n\n    _G1_JOINT_AXIS_INDEX_CACHE = axis_indices_by_name\n    return _G1_JOINT_AXIS_INDEX_CACHE\n\n\ndef _get_g1_joint_limits() -> dict[str, tuple[float, float]]:\n    \"\"\"Return a map from G1 joint names to (min, max) angle limits in radians.\"\"\"\n    global _G1_JOINT_LIMITS_CACHE\n    if _G1_JOINT_LIMITS_CACHE is not None:\n        return _G1_JOINT_LIMITS_CACHE\n\n    xml_path = 
str(skeleton_asset_path(\"g1skel34\", \"xml\", \"g1.xml\"))\n    if not os.path.exists(xml_path):\n        _G1_JOINT_LIMITS_CACHE = {}\n        return _G1_JOINT_LIMITS_CACHE\n\n    tree = ET.parse(xml_path)\n    root = tree.getroot()\n\n    class_ranges: dict[str, tuple[float, float]] = {}\n    for xml_class in tree.findall(\".//default\"):\n        class_name = xml_class.get(\"class\")\n        if not class_name:\n            continue\n        joint_nodes = xml_class.findall(\"joint\")\n        if not joint_nodes:\n            continue\n        range_str = joint_nodes[0].get(\"range\")\n        if not range_str:\n            continue\n        range_vals = [float(x) for x in range_str.split()]\n        if len(range_vals) != 2:\n            continue\n        class_ranges[class_name] = (range_vals[0], range_vals[1])\n\n    joint_limits: dict[str, tuple[float, float]] = {}\n    worldbody = root.find(\"worldbody\")\n    if worldbody is None:\n        _G1_JOINT_LIMITS_CACHE = {}\n        return _G1_JOINT_LIMITS_CACHE\n\n    for joint in worldbody.findall(\".//joint\"):\n        range_str = joint.get(\"range\") or class_ranges.get(joint.get(\"class\"))\n        if range_str is None:\n            continue\n        if isinstance(range_str, tuple):\n            joint_range = range_str\n        else:\n            range_vals = [float(x) for x in range_str.split()]\n            if len(range_vals) != 2:\n                continue\n            joint_range = (range_vals[0], range_vals[1])\n        joint_name = joint.get(\"name\")\n        if not joint_name:\n            continue\n        joint_limits[joint_name.replace(\"_joint\", \"_skel\")] = joint_range\n\n    _G1_JOINT_LIMITS_CACHE = joint_limits\n    return _G1_JOINT_LIMITS_CACHE\n\n\n_G1_JOINT_F2Q_DATA_CACHE: Optional[dict[str, dict[str, Any]]] = None\n\n\ndef get_g1_joint_f2q_data(\n    skeleton: G1Skeleton34,\n) -> dict[str, dict[str, Any]]:\n    \"\"\"Return per-hinge-joint f2q data for correct 1-DoF + limits in offset 
space.\n\n    Each entry is for a G1 hinge joint (by name) and contains:\n      - \"offset_f2q\": (3, 3) matrix such that R_f2q = offset_f2q @ R_local (kimodo).\n      - \"axis_f2q\": (3,) unit axis in f2q space; angle = dot(axis_angle(R_f2q), axis_f2q).\n      - \"rest_dof_axis_angle\": angle (rad) at T-pose in f2q space; MuJoCo q = angle_f2q - this.\n\n    Limits from the XML apply to q = angle_f2q - rest_dof_axis_angle.\n    \"\"\"\n    global _G1_JOINT_F2Q_DATA_CACHE\n    if _G1_JOINT_F2Q_DATA_CACHE is not None:\n        return _G1_JOINT_F2Q_DATA_CACHE\n\n    from kimodo.exports.mujoco import MujocoQposConverter\n\n    converter = MujocoQposConverter(skeleton)\n    # converter: _rot_offsets_f2q[kimodo_idx], _mujoco_joint_axis_values_f2q_space[hinge_idx],\n    # _rest_dofs_axis_angle[hinge_idx], _kimodo_indices_to_mujoco_indices[kimodo_idx] = hinge_idx+1 or 0\n    out: dict[str, dict[str, Any]] = {}\n    for j in range(skeleton.nbjoints):\n        mujoco_one_based = converter._kimodo_indices_to_mujoco_indices[j].item()\n        if mujoco_one_based <= 0:\n            continue\n        hinge_idx = mujoco_one_based - 1\n        joint_name = skeleton.bone_order_names[j]\n        offset_f2q = converter._rot_offsets_f2q[j].detach().cpu().numpy().astype(np.float64)\n        axis_f2q = converter._mujoco_joint_axis_values_f2q_space[hinge_idx].detach().cpu().numpy().astype(np.float64)\n        n = np.linalg.norm(axis_f2q)\n        if n > 1e-10:\n            axis_f2q = axis_f2q / n\n        rest_dof = float(converter._rest_dofs_axis_angle[hinge_idx].detach().cpu().numpy())\n        out[joint_name] = {\n            \"offset_f2q\": offset_f2q,\n            \"axis_f2q\": axis_f2q,\n            \"rest_dof_axis_angle\": rest_dof,\n        }\n    _G1_JOINT_F2Q_DATA_CACHE = out\n    return out\n\n\n# -----------------------------------------------------------------------------\n# Mesh loading cache (shared across G1 rig instances; each rig gets its own scene meshes)\n# 
-----------------------------------------------------------------------------\n_G1_MESH_DATA_CACHE: dict[str, list[dict]] = {}\n\n\ndef _load_g1_mesh_data(\n    mesh_dir: str,\n    skeleton: G1Skeleton34,\n) -> list[dict]:\n    \"\"\"Load STL meshes and XML transforms once per mesh_dir; shared across rig instances.\"\"\"\n    if mesh_dir in _G1_MESH_DATA_CACHE:\n        return _G1_MESH_DATA_CACHE[mesh_dir]\n\n    mesh_geom_cache = G1MeshRig._mesh_geom_cache\n    mesh_transform_cache = G1MeshRig._mesh_transform_cache\n\n    # Load XML-derived transforms (cached inside _get_mesh_local_transforms_impl)\n    mesh_file_transforms = _get_mesh_local_transforms_impl(mesh_dir, mesh_transform_cache)\n\n    data_list: list[dict] = []\n    for joint_name, mesh_files in G1_MESH_JOINT_MAP.items():\n        if joint_name not in skeleton.bone_index:\n            continue\n        joint_idx = skeleton.bone_index[joint_name]\n        for mesh_file in mesh_files:\n            mesh_path = os.path.join(mesh_dir, mesh_file)\n            if not os.path.exists(mesh_path):\n                continue\n            vertices, faces = _get_mesh_geometry_impl(mesh_file, mesh_path, mesh_dir, mesh_geom_cache)\n            if vertices is None:\n                continue\n            geom_pos, geom_rot = mesh_file_transforms.get(\n                mesh_file,\n                (np.zeros(3, dtype=np.float64), np.eye(3, dtype=np.float64)),\n            )\n            data_list.append(\n                {\n                    \"mesh_file\": mesh_file,\n                    \"vertices\": vertices,\n                    \"faces\": faces,\n                    \"joint_idx\": joint_idx,\n                    \"geom_pos\": geom_pos.copy(),\n                    \"geom_rot\": geom_rot.copy(),\n                }\n            )\n\n    _G1_MESH_DATA_CACHE[mesh_dir] = data_list\n    return data_list\n\n\ndef _get_mesh_geometry_impl(\n    mesh_file: str,\n    mesh_path: str,\n    mesh_dir: str,\n    mesh_geom_cache: 
dict,\n) -> tuple[Optional[np.ndarray], Optional[np.ndarray]]:\n    \"\"\"Load one STL; result cached per mesh_dir and shared across rigs.\"\"\"\n    cached = mesh_geom_cache.get(mesh_dir)\n    if cached is not None and mesh_file in cached:\n        vertices, faces = cached[mesh_file]\n        return vertices.copy(), faces.copy()\n\n    mesh = trimesh.load_mesh(mesh_path, process=True)\n    if isinstance(mesh, trimesh.Scene):\n        mesh = trimesh.util.concatenate(mesh.dump())\n    vertices = mesh.vertices @ MUJOCO_TO_KIMODO.T\n    faces = mesh.faces\n\n    if mesh_dir not in mesh_geom_cache:\n        mesh_geom_cache[mesh_dir] = {}\n    mesh_geom_cache[mesh_dir][mesh_file] = (vertices, faces)\n    return vertices.copy(), faces.copy()\n\n\ndef _get_mesh_local_transforms_impl(\n    mesh_dir: str,\n    mesh_transform_cache: dict,\n) -> dict[str, tuple[np.ndarray, np.ndarray]]:\n    \"\"\"Parse g1.xml once per mesh_dir; result shared across G1 rig instances.\"\"\"\n    cached = mesh_transform_cache.get(mesh_dir)\n    if cached is not None:\n        return {mesh_file: (pos.copy(), rot.copy()) for mesh_file, (pos, rot) in cached.items()}\n\n    xml_path = os.path.abspath(os.path.join(mesh_dir, \"..\", \"..\", \"xml\", \"g1.xml\"))\n    if not os.path.exists(xml_path):\n        return {}\n    tree = ET.parse(xml_path)\n    root = tree.getroot()\n\n    mesh_file_to_mesh_name = {}\n    for mesh in root.findall(\".//asset/mesh\"):\n        mesh_name = mesh.get(\"name\")\n        mesh_file = mesh.get(\"file\")\n        if mesh_name and mesh_file:\n            mesh_file_to_mesh_name[mesh_file] = mesh_name\n\n    mesh_name_to_transform = {}\n    for geom in root.findall(\".//geom\"):\n        mesh_name = geom.get(\"mesh\")\n        if mesh_name is None:\n            continue\n        pos = geom.get(\"pos\")\n        quat = geom.get(\"quat\")\n        if pos is None:\n            geom_pos = np.zeros(3, dtype=np.float64)\n        else:\n            geom_pos = np.array([float(x) 
for x in pos.split()], dtype=np.float64)\n        if quat is None:\n            geom_rot = np.eye(3, dtype=np.float64)\n        else:\n            wxyz = np.array([float(x) for x in quat.split()], dtype=np.float64)\n            geom_rot = tf.SO3(wxyz=wxyz).as_matrix()\n        mesh_name_to_transform[mesh_name] = (geom_pos, geom_rot)\n\n    mesh_file_transforms = {}\n    for mesh_file, mesh_name in mesh_file_to_mesh_name.items():\n        geom_pos, geom_rot = mesh_name_to_transform.get(\n            mesh_name,\n            (np.zeros(3, dtype=np.float64), np.eye(3, dtype=np.float64)),\n        )\n        geom_pos = MUJOCO_TO_KIMODO @ geom_pos\n        geom_rot = MUJOCO_TO_KIMODO @ geom_rot @ MUJOCO_TO_KIMODO.T\n        mesh_file_transforms[mesh_file] = (geom_pos, geom_rot)\n\n    mesh_transform_cache[mesh_dir] = {mf: (pos.copy(), rot.copy()) for mf, (pos, rot) in mesh_file_transforms.items()}\n    return mesh_file_transforms\n\n\nclass G1MeshRig:\n    \"\"\"Rig for G1 STL meshes.\n\n    Each instance has its own scene meshes (so clear() only removes one character). 
Loading is\n    shared: STL files and g1.xml are cached per mesh_dir via _load_g1_mesh_data() and the class-\n    level _mesh_*_cache dicts.\n    \"\"\"\n\n    _mesh_geom_cache: dict[str, dict[str, tuple[np.ndarray, np.ndarray]]] = {}\n    _mesh_transform_cache: dict[str, dict[str, tuple[np.ndarray, np.ndarray]]] = {}\n\n    def __init__(\n        self,\n        name: str,\n        server: viser.ViserServer | viser.ClientHandle,\n        skeleton: G1Skeleton34,\n        mesh_dir: str,\n        color: Tuple[int, int, int],\n    ):\n        self.server = server\n        self.skeleton = skeleton\n        self.mesh_dir = mesh_dir\n        self.color = color\n        self.mesh_handles: list[viser.SceneHandle] = []\n        self.mesh_items: list[dict[str, object]] = []\n        self._defer_initial_visibility = True\n\n        data_list = _load_g1_mesh_data(mesh_dir, skeleton)\n\n        for item in data_list:\n            mesh_file = item[\"mesh_file\"]\n            vertices = item[\"vertices\"]\n            faces = item[\"faces\"]\n            joint_idx = item[\"joint_idx\"]\n            geom_pos = item[\"geom_pos\"]\n            geom_rot = item[\"geom_rot\"]\n\n            handle = self.server.scene.add_mesh_simple(\n                f\"/{name}/g1_mesh/{os.path.splitext(mesh_file)[0]}\",\n                vertices=vertices,\n                faces=faces,\n                opacity=None,\n                color=self.color,\n                wireframe=False,\n                visible=not self._defer_initial_visibility,\n            )\n            self.mesh_handles.append(handle)\n            self.mesh_items.append(\n                {\n                    \"handle\": handle,\n                    \"joint_idx\": joint_idx,\n                    \"geom_pos\": geom_pos,\n                    \"geom_rot\": geom_rot,\n                }\n            )\n\n        if self._defer_initial_visibility:\n            for handle in self.mesh_handles:\n                handle.visible = True\n\n    
def set_visibility(self, visible: bool) -> None:\n        for handle in self.mesh_handles:\n            handle.visible = visible\n\n    def set_opacity(self, opacity: float) -> None:\n        for handle in self.mesh_handles:\n            handle.opacity = opacity\n\n    def set_wireframe(self, wireframe: bool) -> None:\n        for handle in self.mesh_handles:\n            handle.wireframe = wireframe\n\n    def set_color(self, color: Tuple[int, int, int]) -> None:\n        self.color = color\n        for handle in self.mesh_handles:\n            handle.color = color\n\n    def set_pose(self, joints_pos: np.ndarray, joints_rot: np.ndarray) -> None:\n        for item in self.mesh_items:\n            handle = item[\"handle\"]\n            joint_idx = item[\"joint_idx\"]\n            geom_pos = item[\"geom_pos\"]\n            geom_rot = item[\"geom_rot\"]\n\n            joint_pos = joints_pos[joint_idx]\n            joint_rot = joints_rot[joint_idx]\n            mesh_pos = joint_pos + joint_rot @ geom_pos\n            mesh_rot = joint_rot @ geom_rot\n\n            handle.position = mesh_pos\n            handle.wxyz = tf.SO3.from_matrix(mesh_rot).wxyz\n\n    def clear(self) -> None:\n        for handle in self.mesh_handles:\n            self.server.scene.remove_by_name(handle.name)\n        self.mesh_handles = []\n        self.mesh_items = []\n"
  },
  {
    "path": "kimodo/viz/gui.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"GUI element handles for the demo app.\"\"\"\n\nfrom dataclasses import dataclass\n\nimport viser\n\n\n@dataclass\nclass GuiElements:\n    gui_play_pause_button: viser.GuiInputHandle\n    gui_next_frame_button: viser.GuiInputHandle\n    gui_prev_frame_button: viser.GuiInputHandle\n    gui_generate_button: viser.GuiInputHandle\n    gui_model_fps: viser.GuiInputHandle[int]\n    gui_timeline: viser.GuiInputHandle[int]\n    gui_viz_skeleton_checkbox: viser.GuiInputHandle[bool]\n    gui_viz_foot_contacts_checkbox: viser.GuiInputHandle[bool]\n    gui_viz_skinned_mesh_checkbox: viser.GuiInputHandle[bool]\n    gui_viz_skinned_mesh_opacity_slider: viser.GuiInputHandle[float]\n    gui_camera_fov_slider: viser.GuiInputHandle[float]\n\n    # generation controls\n    gui_duration_slider: viser.GuiInputHandle[float]\n    gui_num_samples_slider: viser.GuiInputHandle[int]\n    gui_cfg_checkbox: viser.GuiCheckboxHandle\n    gui_cfg_text_weight_slider: viser.GuiInputHandle[float]\n    gui_cfg_constraint_weight_slider: viser.GuiInputHandle[float]\n    gui_diffusion_steps_slider: viser.GuiInputHandle[int]\n    gui_seed: viser.GuiInputHandle[int]\n    gui_postprocess_checkbox: viser.GuiCheckboxHandle\n    gui_root_margin: viser.GuiInputHandle[float]\n    gui_real_robot_rotations_checkbox: viser.GuiInputHandle[bool]\n    # appearance\n    gui_dark_mode_checkbox: viser.GuiCheckboxHandle\n\n    # which skinning method to use for SOMA\n    gui_use_soma_layer_checkbox: viser.GuiCheckboxHandle\n"
  },
  {
    "path": "kimodo/viz/playback.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Playback and motion editing: CharacterMotion.\"\"\"\n\nfrom typing import Callable, Literal, Optional\n\nimport numpy as np\nimport torch\n\nimport viser.transforms as tf\nfrom kimodo.skeleton import (\n    G1Skeleton34,\n    SOMASkeleton30,\n    SOMASkeleton77,\n    batch_rigid_transform,\n    global_rots_to_local_rots,\n)\nfrom kimodo.tools import to_numpy, to_torch\n\nfrom .g1_rig import (\n    _get_g1_joint_axis_indices,\n    _get_g1_joint_limits,\n    get_g1_joint_f2q_data,\n)\nfrom .scene import Character\n\n\nclass CharacterMotion:\n    def __init__(\n        self,\n        character: Character,\n        joints_pos: torch.Tensor,\n        joints_rot: torch.Tensor,\n        foot_contacts: Optional[torch.Tensor] = None,\n    ):\n        self.character = character\n        self.server = character.server\n        self.skeleton = character.skeleton\n        self.name = character.name\n\n        # [T, J, 3] global joint positions\n        self.joints_pos = joints_pos\n        # [T, J, 3, 3] global joint rotation matrices\n        self.joints_rot = joints_rot\n        assert joints_pos.shape[0] == joints_rot.shape[0]\n        # keep track of local rots as well for convenience during pose editing\n        self.joints_local_rot = global_rots_to_local_rots(joints_rot, self.skeleton)\n\n        self.length = joints_pos.shape[0]\n        self.cur_frame_idx = None\n\n        self.foot_contacts = foot_contacts\n        if foot_contacts is not None:\n            assert foot_contacts.shape[0] == self.length\n\n        self.precompute_mesh_info()\n\n        # gizmos for pose editing\n        self.root_translation_gizmo = None\n        self.updating_root_translation_gizmo = False\n        self.joint_gizmos = None\n        self.updating_joint_gizmos = False\n        self.gizmo_space: Literal[\"world\", \"local\"] = 
\"local\"\n        self._drag_start_world_rot: list = []\n        self._joint_gizmo_dragging: list[bool] = []\n\n    def precompute_mesh_info(self):\n        if self.character.skeleton_mesh is not None:\n            print(\"Caching skeleton mesh info...\")\n            self.character.skeleton_mesh.precompute_mesh_info(self.joints_pos)\n        if self.character.skinned_mesh is not None:\n            print(\"Caching skinning info...\")\n            self.character.precompute_skinning(self.joints_pos, self.joints_rot)\n\n    def set_frame(self, idx: int):\n        \"\"\"Sets the pose of the character to the given frame index.\"\"\"\n        idx = min(idx, self.length - 1)  # clamp to last frame\n        cur_foot_contacts = self.foot_contacts[idx] if self.foot_contacts is not None else None\n        self.character.set_pose(\n            self.joints_pos[idx],\n            self.joints_rot[idx],\n            frame_idx=idx,\n            foot_contacts=cur_foot_contacts,\n        )\n        self.cur_frame_idx = idx\n\n        # update gizmos if frame has changed due to playback\n        cur_root_pos = self.joints_pos[self.cur_frame_idx, self.skeleton.root_idx].clone()\n        cur_root_pos[1] = 0.0\n        if self.root_translation_gizmo is not None and not self.updating_root_translation_gizmo:\n            self.root_translation_gizmo.position = cur_root_pos.cpu().numpy()\n        if self.joint_gizmos is not None:\n            for i, joint_gizmo in enumerate(self.joint_gizmos):\n                # Do not push wxyz/position while this gizmo is being dragged;\n                # otherwise the client receives e.g. 
identity and the gizmo snaps back.\n                if not self.updating_joint_gizmos and not self._joint_gizmo_dragging[i]:\n                    joint_gizmo.position = self.joints_pos[self.cur_frame_idx, i].cpu().numpy()\n                    if self.gizmo_space == \"world\":\n                        joint_gizmo.wxyz = (1.0, 0.0, 0.0, 0.0)\n                    else:\n                        joint_gizmo.wxyz = tf.SO3.from_matrix(self.joints_rot[self.cur_frame_idx, i].cpu().numpy()).wxyz\n\n    def update_pose_at_frame(\n        self,\n        frame_idx: int,\n        joints_pos: Optional[torch.Tensor] = None,\n        joints_rot: Optional[torch.Tensor] = None,\n        joints_local_rot: Optional[torch.Tensor] = None,\n        foot_contacts: Optional[torch.Tensor] = None,\n    ):\n        \"\"\"Overwrites one or more of the pose components at the given frame.\n\n        If only a subset of joints_pos, joints_rot, or joints_local_rot are provided, the other\n        components will be updated with FK.\n        \"\"\"\n        if joints_pos is not None:\n            joints_pos = to_torch(joints_pos, device=self.joints_pos.device, dtype=self.joints_pos.dtype)\n            self.joints_pos[frame_idx] = joints_pos\n            if joints_local_rot is None and joints_rot is None:\n                raise NotImplementedError(\"No IK to update joint rotations accordingly.\")\n        if joints_rot is not None:\n            joints_rot = to_torch(joints_rot, device=self.joints_rot.device, dtype=self.joints_rot.dtype)\n            self.joints_rot[frame_idx] = joints_rot\n            if joints_local_rot is None:\n                # update local rots from global rots\n                self.joints_local_rot[frame_idx] = global_rots_to_local_rots(joints_rot, self.skeleton)\n            if joints_pos is None:\n                # need to update with FK\n                new_posed_joints, _ = batch_rigid_transform(\n                    self.joints_local_rot[frame_idx : frame_idx + 1],\n       
             self.skeleton.neutral_joints[None].to(self.joints_local_rot.device),\n                    self.skeleton.joint_parents.to(self.joints_local_rot.device),\n                    self.skeleton.root_idx,\n                )\n                new_posed_joints = (\n                    new_posed_joints[0]\n                    + self.joints_pos[frame_idx, self.skeleton.root_idx : self.skeleton.root_idx + 1]\n                    - self.skeleton.neutral_joints[[self.skeleton.root_idx]]\n                )\n                self.joints_pos[frame_idx] = new_posed_joints\n        if joints_local_rot is not None:\n            joints_local_rot = to_torch(joints_local_rot, device=self.joints_local_rot.device).to(\n                dtype=self.joints_local_rot.dtype\n            )\n            self.joints_local_rot[frame_idx] = joints_local_rot\n            if joints_rot is None or joints_pos is None:\n                # need to update with FK\n                new_posed_joints, new_global_rots = batch_rigid_transform(\n                    self.joints_local_rot[frame_idx : frame_idx + 1],\n                    self.skeleton.neutral_joints[None].to(self.joints_local_rot.device),\n                    self.skeleton.joint_parents.to(self.joints_local_rot.device),\n                    self.skeleton.root_idx,\n                )\n                new_posed_joints = (\n                    new_posed_joints[0]\n                    + self.joints_pos[frame_idx, self.skeleton.root_idx : self.skeleton.root_idx + 1]\n                    - self.skeleton.neutral_joints[[self.skeleton.root_idx]]\n                )\n                if joints_rot is None:\n                    self.joints_rot[frame_idx] = new_global_rots[0]\n                if joints_pos is None:\n                    self.joints_pos[frame_idx] = new_posed_joints\n        if foot_contacts is not None:\n            foot_contacts = to_torch(foot_contacts, device=self.foot_contacts.device).to(dtype=self.foot_contacts.dtype)\n            
self.foot_contacts[frame_idx] = foot_contacts\n\n        if self.character.skeleton_mesh is not None:\n            self.character.skeleton_mesh.update_mesh_info_cache(self.joints_pos[frame_idx], frame_idx)\n        if self.character.skinned_mesh is not None:\n            self.character.update_skinning_cache(self.joints_pos[frame_idx], self.joints_rot[frame_idx], frame_idx)\n\n    def clear(self):\n        self.character.clear()\n\n    #\n    # Editing helpers\n    #\n    def get_current_projected_root_pos(self) -> np.ndarray:\n        \"\"\"Get the projected root position on the ground at the current frame.\"\"\"\n        root_pos = self.joints_pos[self.cur_frame_idx, self.skeleton.root_idx].clone()\n        root_pos[1] = 0.0\n        return to_numpy(root_pos)\n\n    def get_projected_root_pos(self, start_frame_idx: int, end_frame_idx: int = None) -> np.ndarray:\n        \"\"\"If requested frames are out of range, simply pads with the last frame to get expected\n        length.\"\"\"\n        if end_frame_idx is None:\n            expected_len = 1\n        else:\n            expected_len = end_frame_idx - start_frame_idx + 1\n        if start_frame_idx >= self.length:\n            start_frame_idx = self.length - 1\n        if end_frame_idx is None or expected_len == 1:\n            root_pos = self.joints_pos[start_frame_idx, self.skeleton.root_idx].clone()\n            root_pos[1] = 0.0\n            return to_numpy(root_pos)\n        else:\n            if end_frame_idx >= self.length:\n                end_frame_idx = self.length - 1\n            root_pos = self.joints_pos[start_frame_idx : end_frame_idx + 1, self.skeleton.root_idx].clone()\n            root_pos[:, 1] = 0.0\n            if root_pos.shape[0] < expected_len:\n                # pad with the last root position\n                root_pos = torch.cat(\n                    [\n                        root_pos,\n                        root_pos[-1:].repeat(expected_len - root_pos.shape[0], 1),\n               
     ],\n                    dim=0,\n                )\n            return to_numpy(root_pos)\n\n    def set_projected_root_pos_path(\n        self,\n        root_pos_path: np.ndarray | torch.Tensor,\n        min_frame_idx: int = None,\n        max_frame_idx: int = None,\n    ):\n        \"\"\"Sets the projected root position path for the character motion. Can set only a subset of\n        the path by providing min_frame_idx and max_frame_idx. If not provided, will set the full\n        path.\n\n        Args:\n            root_pos_path: torch.Tensor, [T, 2] projected root positions\n            min_frame_idx: int, optional, minimum frame index to set the path at\n            max_frame_idx: int, optional, maximum frame index to set the path at\n        \"\"\"\n        if min_frame_idx is not None or max_frame_idx is not None:\n            assert (\n                min_frame_idx is not None and max_frame_idx is not None\n            ), \"min_frame_idx and max_frame_idx must be provided if setting path at specific frames\"\n            if min_frame_idx >= self.length:\n                # both are out of bounds\n                return\n            max_frame_idx = min(max_frame_idx, self.length - 1)\n            root_pos_path = root_pos_path[min_frame_idx : max_frame_idx + 1]\n        else:\n            assert root_pos_path.shape[0] == self.length\n            min_frame_idx = 0\n            max_frame_idx = self.length - 1\n\n        cur_joints_pos = self.joints_pos.clone()[min_frame_idx : max_frame_idx + 1]\n        root_pos_tensor = to_torch(root_pos_path, device=cur_joints_pos.device, dtype=cur_joints_pos.dtype)\n        diff = root_pos_tensor - cur_joints_pos[:, self.skeleton.root_idx, [0, 2]]\n        cur_joints_pos[:, :, [0, 2]] += diff.unsqueeze(1)\n        for frame_idx in range(min_frame_idx, max_frame_idx + 1):\n            rel_idx = frame_idx - min_frame_idx\n            self.update_pose_at_frame(\n                frame_idx,\n                
joints_pos=cur_joints_pos[rel_idx],\n                joints_rot=self.joints_rot[frame_idx],\n                joints_local_rot=self.joints_local_rot[frame_idx],\n            )\n        # update immediately to show changes\n        self.set_frame(self.cur_frame_idx)\n\n    def get_joints_pos(self, start_frame_idx: int, end_frame_idx: int = None) -> np.ndarray:\n        \"\"\"If requested frames are out of range, simply pads with the last frame to get expected\n        length.\"\"\"\n        if end_frame_idx is None:\n            expected_len = 1\n        else:\n            expected_len = end_frame_idx - start_frame_idx + 1\n        if start_frame_idx >= self.length:\n            start_frame_idx = self.length - 1\n        if end_frame_idx is None or expected_len == 1:\n            return to_numpy(self.joints_pos[start_frame_idx].clone())\n        else:\n            if end_frame_idx >= self.length:\n                end_frame_idx = self.length - 1\n            return_joints_pos = self.joints_pos[start_frame_idx : end_frame_idx + 1].clone()\n            if return_joints_pos.shape[0] < expected_len:\n                # pad with the last pose\n                return_joints_pos = torch.cat(\n                    [\n                        return_joints_pos,\n                        return_joints_pos[-1:].repeat(expected_len - return_joints_pos.shape[0], 1, 1),\n                    ],\n                    dim=0,\n                )\n            return to_numpy(return_joints_pos)\n\n    def get_joints_rot(self, start_frame_idx: int, end_frame_idx: int = None) -> np.ndarray:\n        \"\"\"If requested frames are out of range, simply pads with the last frame to get expected\n        length.\"\"\"\n        if end_frame_idx is None:\n            expected_len = 1\n        else:\n            expected_len = end_frame_idx - start_frame_idx + 1\n        if start_frame_idx >= self.length:\n            start_frame_idx = self.length - 1\n        if end_frame_idx is None or expected_len == 
1:\n            return to_numpy(self.joints_rot[start_frame_idx].clone())\n        else:\n            if end_frame_idx >= self.length:\n                end_frame_idx = self.length - 1\n            return_joints_rot = self.joints_rot[start_frame_idx : end_frame_idx + 1].clone()\n            if return_joints_rot.shape[0] < expected_len:\n                # pad with the last pose\n                return_joints_rot = torch.cat(\n                    [\n                        return_joints_rot,\n                        return_joints_rot[-1:].repeat(expected_len - return_joints_rot.shape[0], 1, 1, 1),\n                    ],\n                    dim=0,\n                )\n            return to_numpy(return_joints_rot)\n\n    def get_current_joints_pos(self) -> torch.Tensor:\n        return self.joints_pos[self.cur_frame_idx].clone()\n\n    def get_current_joints_rot(self) -> torch.Tensor:\n        return self.joints_rot[self.cur_frame_idx].clone()\n\n    def add_root_translation_gizmo(\n        self,\n        constraints: dict,\n        on_2d_root_drag_end: Optional[Callable[[], None]] = None,\n        on_drag_start: Optional[Callable[[], None]] = None,\n    ):\n        \"\"\"Create and initialize gizmo to control the root translation.\n\n        When the user drags the root 2D gizmo, path updates are skipped until release. Optional\n        on_2d_root_drag_end is called when the drag ends (e.g. to refresh dense path). on_drag_start\n        is called when the drag begins (e.g. 
to snapshot state for undo).\n        \"\"\"\n        # TODO: could also allow rotation around y-axis\n        self.root_translation_gizmo = self.server.scene.add_transform_controls(\n            f\"/{self.name}/gizmo_root_translation\",\n            scale=0.5,\n            line_width=2.5,\n            active_axes=(True, False, True),  # only allow translation on xz plane\n            disable_axes=False,\n            disable_sliders=False,\n            disable_rotations=True,\n            depth_test=False,  # render even when occluded\n        )\n        init_position = self.get_current_projected_root_pos()\n        self.root_translation_gizmo.position = init_position\n\n        @self.root_translation_gizmo.on_drag_start\n        def _(_):\n            if on_drag_start is not None:\n                on_drag_start()\n\n        @self.root_translation_gizmo.on_update\n        def _(_):\n            self.updating_root_translation_gizmo = True\n            # translate to gizmo position\n            new_root_pos = to_torch(\n                self.root_translation_gizmo.position,\n                device=self.joints_pos.device,\n            ).to(dtype=self.joints_pos.dtype)\n            cur_joints_pos = self.joints_pos[self.cur_frame_idx].clone()\n            root_diff = new_root_pos - cur_joints_pos[self.skeleton.root_idx]\n            root_diff[1] = 0.0  # don't change height\n            cur_joints_pos += root_diff[None]\n            self.update_pose_at_frame(\n                self.cur_frame_idx,\n                joints_pos=cur_joints_pos,\n                joints_rot=self.joints_rot[self.cur_frame_idx],\n                joints_local_rot=self.joints_local_rot[self.cur_frame_idx],\n            )\n\n            self.updating_root_translation_gizmo = False\n            # update immediately to show user changes\n            self.set_frame(self.cur_frame_idx)\n            # update the 2D waypoint constraints as well if there is one\n            if \"2D Root\" in constraints:\n  
              root_2d_contraints = constraints[\"2D Root\"]\n                # if there is a constraint at that frame, we want to update it\n                frame_idx = self.cur_frame_idx\n                if frame_idx in root_2d_contraints.keyframes:\n                    for keyframe_id in root_2d_contraints.frame2keyid[frame_idx]:\n                        # add will modify the existing constraint\n                        # update_path=False during drag to avoid lag; path refreshes on_drag_end\n                        root_2d_contraints.add_keyframe(\n                            keyframe_id,\n                            frame_idx,\n                            root_pos=new_root_pos,\n                            exists_ok=True,\n                            update_path=False,\n                        )\n            if \"Full-Body\" in constraints:\n                full_body_constraints = constraints[\"Full-Body\"]\n                # if there is a constraint at that frame, we want to update it\n                frame_idx = self.cur_frame_idx\n                if frame_idx in full_body_constraints.keyframes:\n                    current_dict = full_body_constraints.keyframes[frame_idx]\n                    for keyframe_id in full_body_constraints.frame2keyid[frame_idx]:\n                        # add will modify the existing constraint\n                        full_body_constraints.add_keyframe(\n                            keyframe_id,\n                            frame_idx,\n                            joints_pos=cur_joints_pos,\n                            joints_rot=current_dict[\"joints_rot\"],\n                            exists_ok=True,\n                        )\n            if \"End-Effectors\" in constraints:\n                end_effector_constraints = constraints[\"End-Effectors\"]\n                # if there is a constraint at that frame, we want to update it\n                frame_idx = self.cur_frame_idx\n                if frame_idx in 
end_effector_constraints.keyframes:\n                    current_dict = end_effector_constraints.keyframes[frame_idx]\n                    for keyframe_id, _ in end_effector_constraints.frame2keyid[frame_idx]:\n                        # add will modify the existing constraint\n                        end_effector_constraints.add_keyframe(\n                            keyframe_id,\n                            frame_idx,\n                            joints_pos=cur_joints_pos,\n                            joints_rot=current_dict[\"joints_rot\"],\n                            joint_names=current_dict[\"joint_names\"],\n                            end_effector_type=current_dict[\"end_effector_type\"],\n                            exists_ok=True,\n                        )\n\n        @self.root_translation_gizmo.on_drag_end\n        def _on_drag_end(_):\n            # Refresh path visualization and dense path after release.\n            if \"2D Root\" in constraints:\n                root_2d = constraints[\"2D Root\"]\n                if root_2d.line_segments is not None:\n                    root_2d.update_line_segments()\n            if on_2d_root_drag_end is not None:\n                on_2d_root_drag_end()\n\n    def add_joint_gizmos(\n        self,\n        constraints: dict,\n        space: Literal[\"world\", \"local\"] = \"local\",\n        on_drag_start: Optional[Callable[[], None]] = None,\n    ):\n        # Remove existing joint gizmos first so the client gets remove then add,\n        # avoiding in-place update that can briefly show duplicate gizmos.\n        if self.joint_gizmos is not None:\n            for joint_gizmo in self.joint_gizmos:\n                self.server.scene.remove_by_name(joint_gizmo.name)\n            self.joint_gizmos = None\n\n        self.joint_gizmos = []\n        self.gizmo_space = space\n        # For world mode: store joint world rotation at drag start to compose with\n        # PivotControls' cumulative-from-identity drag rotation.\n 
       self._drag_start_world_rot = [None] * self.skeleton.nbjoints\n        # Skip pushing wxyz/position in set_frame while a gizmo is being dragged,\n        # so the client does not receive \"snap back\" (e.g. identity for world mode).\n        self._joint_gizmo_dragging = [False] * self.skeleton.nbjoints\n\n        joint_axis_indices = None\n        joint_limits = None\n        joint_f2q_data = None\n        hidden_gizmo_joints = None\n        if isinstance(self.skeleton, G1Skeleton34):\n            joint_axis_indices = _get_g1_joint_axis_indices()\n            joint_limits = _get_g1_joint_limits()\n            joint_f2q_data = get_g1_joint_f2q_data(self.skeleton)\n            hidden_gizmo_joints = {\n                \"left_hand_roll_skel\",\n                \"right_hand_roll_skel\",\n                \"left_toe_base\",\n                \"right_toe_base\",\n            }\n        elif isinstance(self.skeleton, SOMASkeleton77):\n            skel30_names = {name for name, _ in SOMASkeleton30.bone_order_names_with_parents}\n            hidden_gizmo_joints = {name for name in self.skeleton.bone_order_names if name not in skel30_names}\n            hidden_gizmo_joints |= {\n                \"RightHandThumbEnd\",\n                \"RightHandMiddleEnd\",\n                \"LeftHandThumbEnd\",\n                \"LeftHandMiddleEnd\",\n                \"LeftEye\",\n                \"RightEye\",\n                \"Jaw\",\n            }\n        elif isinstance(self.skeleton, SOMASkeleton30):\n            hidden_gizmo_joints = {\n                \"RightHandThumbEnd\",\n                \"RightHandMiddleEnd\",\n                \"LeftHandThumbEnd\",\n                \"LeftHandMiddleEnd\",\n                \"LeftEye\",\n                \"RightEye\",\n                \"Jaw\",\n            }\n\n        if space == \"world\":\n            # World mode: gizmo rings stay scene-axis-aligned (identity).\n            joints_wxyzs = np.tile(\n                np.array([1.0, 0.0, 0.0, 
0.0], dtype=np.float64),\n                (self.skeleton.nbjoints, 1),\n            )\n        else:\n            # Local mode: gizmo shows joint world rotation so rings follow the joint.\n            joints_wxyzs = tf.SO3.from_matrix(self.joints_rot[self.cur_frame_idx].cpu().numpy()).wxyz\n        for joint_idx in range(self.skeleton.nbjoints):\n            disable_axes = True  # by default, only rotation controls\n            disable_sliders = True\n            if joint_idx == self.skeleton.root_idx:\n                disable_axes = False  # allow translation for root\n                disable_sliders = False\n            active_axes = (True, True, True)\n            if joint_axis_indices is not None:\n                joint_name = self.skeleton.bone_order_names[joint_idx]\n                axis_idx = joint_axis_indices.get(joint_name)\n                if axis_idx is not None:\n                    # PivotControls shows rotation handles when a plane is active.\n                    # To allow rotation about one axis, enable the other two axes.\n                    active_axes = (\n                        axis_idx != 0,\n                        axis_idx != 1,\n                        axis_idx != 2,\n                    )\n            joint_visible = True\n            if hidden_gizmo_joints is not None:\n                joint_name = self.skeleton.bone_order_names[joint_idx]\n                joint_visible = joint_name not in hidden_gizmo_joints\n            cur_joint_gizmo = self.server.scene.add_transform_controls(\n                f\"/{self.name}/gizmo_joint_{joint_idx}\",\n                scale=0.075,\n                line_width=4.0,\n                active_axes=active_axes,\n                disable_axes=disable_axes,\n                disable_sliders=disable_sliders,\n                disable_rotations=False,\n                depth_test=False,  # render even when occluded\n                position=self.joints_pos[self.cur_frame_idx, joint_idx].cpu().numpy(),\n           
     wxyz=joints_wxyzs[joint_idx],\n                visible=joint_visible,\n                space=space,\n            )\n            self.joint_gizmos.append(cur_joint_gizmo)\n\n            def set_callback_in_closure(i: int) -> None:\n                @cur_joint_gizmo.on_drag_start\n                def _on_drag_start(_) -> None:\n                    if on_drag_start is not None:\n                        on_drag_start()\n                    self._joint_gizmo_dragging[i] = True\n                    if self.gizmo_space == \"world\":\n                        self._drag_start_world_rot[i] = self.joints_rot[self.cur_frame_idx, i].clone().cpu().numpy()\n\n                @cur_joint_gizmo.on_drag_end\n                def _on_drag_end(_) -> None:\n                    self._joint_gizmo_dragging[i] = False\n                    # Force-sync so the client always receives the reset (viser setter skips on allclose).\n                    # Use self.joint_gizmos[i] (not cur_joint_gizmo) to avoid the\n                    # closure-in-loop bug: cur_joint_gizmo would point to the last handle.\n                    gizmo = self.joint_gizmos[i]\n                    gizmo.sync_position(self.joints_pos[self.cur_frame_idx, i].cpu().numpy())\n                    if self.gizmo_space == \"world\":\n                        gizmo.sync_wxyz((1.0, 0.0, 0.0, 0.0))\n                    else:\n                        gizmo.sync_wxyz(tf.SO3.from_matrix(self.joints_rot[self.cur_frame_idx, i].cpu().numpy()).wxyz)\n                    self.set_frame(self.cur_frame_idx)\n\n                @cur_joint_gizmo.on_update\n                def _(_) -> None:\n                    self.updating_joint_gizmos = True\n                    new_local_joint_rots = self.joints_local_rot[self.cur_frame_idx].clone()\n                    # Gizmo parent is identity; client sends rotation as wxyz.\n                    # World mode: wxyz is cumulative from identity, compose with\n                    # stored initial world 
rotation. Local mode: wxyz is new world rotation.\n                    gizmo_rot_mat = tf.SO3(self.joint_gizmos[i].wxyz).as_matrix()\n                    if self.gizmo_space == \"world\" and self._drag_start_world_rot[i] is not None:\n                        new_world_rot_mat = gizmo_rot_mat @ self._drag_start_world_rot[i]\n                    else:\n                        new_world_rot_mat = gizmo_rot_mat\n                    parent_idx = self.skeleton.joint_parents[i].item()\n                    if parent_idx >= 0:\n                        R_parent_world = self.joints_rot[self.cur_frame_idx, parent_idx].detach().cpu().numpy()\n                        new_local_rot_mat_np = (R_parent_world.T @ new_world_rot_mat).astype(np.float32)\n                    else:\n                        new_local_rot_mat_np = new_world_rot_mat.astype(np.float32)\n                    new_local_rot = tf.SO3.from_matrix(new_local_rot_mat_np)\n                    joint_name = self.skeleton.bone_order_names[i]\n                    if joint_f2q_data is not None and joint_name in joint_f2q_data:\n                        # G1 hinge: use offset (f2q) space so 1-DoF and limits match the robot.\n                        # R_f2q = offset_f2q @ R_local; angle_f2q = dot(axis_angle(R_f2q), axis_f2q);\n                        # MuJoCo q = angle_f2q - rest_dof; limits apply to q.\n                        f2q = joint_f2q_data[joint_name]\n                        offset_f2q = f2q[\"offset_f2q\"]\n                        axis_f2q = f2q[\"axis_f2q\"]\n                        rest_dof = f2q[\"rest_dof_axis_angle\"]\n                        R_local = new_local_rot_mat_np.astype(np.float64)\n                        R_f2q = offset_f2q @ R_local\n                        rotvec = tf.SO3.from_matrix(R_f2q).log()\n                        angle_f2q = float(np.dot(rotvec, axis_f2q))\n                        # Keep angle continuous relative to current pose.\n                        current_R_f2q = offset_f2q @ (\n    
                        self.joints_local_rot[self.cur_frame_idx, i].detach().cpu().numpy().astype(np.float64)\n                        )\n                        current_angle_f2q = float(np.dot(tf.SO3.from_matrix(current_R_f2q).log(), axis_f2q))\n                        two_pi = 2.0 * np.pi\n                        angle_f2q = angle_f2q + two_pi * np.round((current_angle_f2q - angle_f2q) / two_pi)\n                        q = angle_f2q - rest_dof\n                        if joint_limits is not None:\n                            joint_limit = joint_limits.get(joint_name)\n                            if joint_limit is not None:\n                                q = float(np.clip(q, joint_limit[0], joint_limit[1]))\n                        angle_f2q = q + rest_dof\n                        R_f2q_new = tf.SO3.exp(angle_f2q * axis_f2q).as_matrix()\n                        new_local_rot_mat_np = (offset_f2q.T @ R_f2q_new).astype(np.float32)\n                    elif joint_axis_indices is not None:\n                        axis_idx = joint_axis_indices.get(joint_name)\n                        if axis_idx is not None:\n                            rotvec = new_local_rot.log()\n                            axis = np.zeros(3, dtype=np.float64)\n                            axis[axis_idx] = 1.0\n                            angle = float(rotvec[axis_idx])\n                            # Keep angle continuous relative to current pose.\n                            current_rot = tf.SO3.from_matrix(\n                                self.joints_local_rot[self.cur_frame_idx, i].detach().cpu().numpy()\n                            )\n                            current_angle = float(current_rot.log()[axis_idx])\n                            two_pi = 2.0 * np.pi\n                            angle = angle + two_pi * np.round((current_angle - angle) / two_pi)\n                            if joint_limits is not None:\n                                joint_limit = joint_limits.get(joint_name)\n 
                               if joint_limit is not None:\n                                    angle = float(np.clip(angle, joint_limit[0], joint_limit[1]))\n                            new_local_rot_mat_np = tf.SO3.exp(angle * axis).as_matrix()\n                    new_local_rot_mat = torch.tensor(new_local_rot_mat_np).to(new_local_joint_rots.device)\n                    new_local_joint_rots[i] = new_local_rot_mat\n\n                    self.update_pose_at_frame(\n                        self.cur_frame_idx,\n                        joints_local_rot=new_local_joint_rots,\n                    )\n\n                    # handle root translation separately\n                    cur_joints_pos = self.joints_pos[self.cur_frame_idx].clone()\n                    if i == self.skeleton.root_idx:\n                        new_root_pos = to_torch(\n                            self.joint_gizmos[i].position,\n                            device=self.joints_pos.device,\n                        ).to(dtype=self.joints_pos.dtype)\n                        root_diff = new_root_pos - self.joints_pos[self.cur_frame_idx, i]\n                        if torch.norm(root_diff) > 1e-3:\n                            # the root translation has been changed\n                            # translate to gizmo position\n                            cur_joints_pos += root_diff[None]\n                            self.update_pose_at_frame(\n                                self.cur_frame_idx,\n                                joints_pos=cur_joints_pos,\n                                joints_rot=self.joints_rot[self.cur_frame_idx],\n                                joints_local_rot=self.joints_local_rot[self.cur_frame_idx],\n                            )\n\n                    # update immediately to show user changes. 
Keep updating_joint_gizmos\n                    # True so set_frame does not overwrite gizmo wxyz mid-drag.\n                    self.set_frame(self.cur_frame_idx)\n                    self.updating_joint_gizmos = False\n\n                    if i == self.skeleton.root_idx:\n                        # update the 2D waypoint constraints as well if there is one\n                        if \"2D Root\" in constraints:\n                            root_2d_contraints = constraints[\"2D Root\"]\n                            # if there is a constraint at that frame, we want to update it\n                            frame_idx = self.cur_frame_idx\n                            if frame_idx in root_2d_contraints.keyframes:\n                                new_root_pos[1] = 0.0  # force y to 0\n                                for keyframe_id in root_2d_contraints.frame2keyid[frame_idx]:\n                                    # add will modify the existing constraint\n                                    root_2d_contraints.add_keyframe(\n                                        keyframe_id,\n                                        frame_idx,\n                                        root_pos=new_root_pos,\n                                        exists_ok=True,\n                                        update_path=False,\n                                    )\n\n                    if \"Full-Body\" in constraints:\n                        full_body_constraints = constraints[\"Full-Body\"]\n                        # if there is a constraint at that frame, we want to update it\n                        frame_idx = self.cur_frame_idx\n                        if frame_idx in full_body_constraints.keyframes:\n                            for keyframe_id in full_body_constraints.frame2keyid[frame_idx]:\n                                # add will modify the existing constraint\n                                full_body_constraints.add_keyframe(\n                                    keyframe_id,\n 
                                   frame_idx,\n                                    joints_pos=self.joints_pos[frame_idx],\n                                    joints_rot=self.joints_rot[frame_idx],\n                                    exists_ok=True,\n                                )\n                    if \"End-Effectors\" in constraints:\n                        end_effector_constraints = constraints[\"End-Effectors\"]\n                        # if there is a constraint at that frame, we want to update it\n                        frame_idx = self.cur_frame_idx\n                        if frame_idx in end_effector_constraints.keyframes:\n                            current_dict = end_effector_constraints.keyframes[frame_idx]\n                            for keyframe_id, _ in end_effector_constraints.frame2keyid[frame_idx]:\n                                # add will modify the existing constraint\n                                end_effector_constraints.add_keyframe(\n                                    keyframe_id,\n                                    frame_idx,\n                                    joints_pos=self.joints_pos[frame_idx],\n                                    joints_rot=self.joints_rot[frame_idx],\n                                    joint_names=current_dict[\"joint_names\"],\n                                    end_effector_type=current_dict[\"end_effector_type\"],\n                                    exists_ok=True,\n                                )\n\n            set_callback_in_closure(joint_idx)\n\n    def clear_all_gizmos(self):\n        self.updating_root_translation_gizmo = True\n        self.updating_joint_gizmos = True\n        if self.root_translation_gizmo is not None:\n            self.server.scene.remove_by_name(self.root_translation_gizmo.name)\n            self.root_translation_gizmo = None\n        if self.joint_gizmos is not None:\n            for joint_gizmo in self.joint_gizmos:\n                
self.server.scene.remove_by_name(joint_gizmo.name)\n            self.joint_gizmos = None\n        self._drag_start_world_rot = []\n        self._joint_gizmo_dragging = []\n        self.updating_root_translation_gizmo = False\n        self.updating_joint_gizmos = False\n"
  },
  {
    "path": "kimodo/viz/scene.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Viser scene entities: waypoints, skeleton mesh, and character.\"\"\"\n\nimport os\nimport traceback\nfrom pathlib import Path\nfrom typing import Optional, Tuple\n\nimport numpy as np\nimport torch\nimport trimesh\n\nimport viser\nimport viser.transforms as tf\nfrom kimodo.skeleton import (\n    G1Skeleton34,\n    SkeletonBase,\n    SMPLXSkeleton22,\n    SOMASkeleton30,\n    SOMASkeleton77,\n)\n\nfrom .coords import rotation_matrix_from_two_vec\nfrom .g1_rig import (\n    G1MeshRig,\n)\nfrom .smplx_skin import SMPLXSkin\nfrom .soma_skin import SOMASkin\n\n\nclass WaypointMesh:\n    def __init__(\n        self,\n        name: str,\n        server: viser.ViserServer,\n        position: np.ndarray,\n        heading: Optional[np.ndarray] = None,\n        color: Optional[Tuple[int, int, int]] = (255, 0, 0),\n    ):\n        self.server = server\n\n        sphere = trimesh.creation.icosphere(subdivisions=3, radius=0.025)\n        annulus = trimesh.creation.annulus(r_min=0.1, r_max=0.2, height=0.005)\n\n        z_to_y_up = np.array([[1, 0, 0], [0, 0, 1], [0, -1, 0]])\n        annulus_vertices = annulus.vertices @ z_to_y_up\n\n        self.sphere = self.server.scene.add_mesh_simple(\n            name=f\"{name}/sphere\",\n            vertices=sphere.vertices,\n            faces=sphere.faces,\n            position=position,\n            color=color,\n        )\n        self.annulus = self.server.scene.add_mesh_simple(\n            name=f\"{name}/annulus\",\n            vertices=annulus_vertices,\n            faces=annulus.faces,\n            position=position,\n            color=color,\n        )\n\n        self.arrow_base = None\n        self.arrow_head = None\n        if heading is not None:\n            assert heading.shape == (2,), \"Heading must be a 2D vector\"\n            heading = 0.3 * (heading / 
np.linalg.norm(heading))\n            heading_3d = np.array([heading[0], 0, heading[1]])\n            arrow_base = trimesh.creation.cylinder(radius=0.01, height=0.3)\n            arrow_head = trimesh.creation.cone(radius=0.03, height=0.075)\n            arrow_base_vertices = arrow_base.vertices\n            arrow_head_vertices = arrow_head.vertices\n            self.arrow_base = self.server.scene.add_mesh_simple(\n                name=f\"{name}/arrow_base\",\n                vertices=arrow_base_vertices,\n                faces=arrow_base.faces,\n                position=position + (heading_3d / 2),\n                color=color,\n            )\n            self.arrow_head = self.server.scene.add_mesh_simple(\n                name=f\"{name}/arrow_head\",\n                vertices=arrow_head_vertices,\n                faces=arrow_head.faces,\n                position=position + heading_3d,\n                color=color,\n            )\n\n    def update_position(self, position: np.ndarray, heading: Optional[np.ndarray] = None):\n        self.sphere.position = position\n        self.annulus.position = position\n        if heading is not None:\n            assert heading.shape == (2,), \"Heading must be a 2D vector\"\n            heading = 0.3 * (heading / np.linalg.norm(heading))\n            heading_3d = np.array([heading[0], 0, heading[1]])\n            if self.arrow_base is not None:\n                self.arrow_base.position = position + (heading_3d / 2)\n            if self.arrow_head is not None:\n                self.arrow_head.position = position + heading_3d\n\n    def clear(self):\n        self.server.scene.remove_by_name(self.sphere.name)\n        self.server.scene.remove_by_name(self.annulus.name)\n        if self.arrow_base is not None:\n            self.server.scene.remove_by_name(self.arrow_base.name)\n        if self.arrow_head is not None:\n            self.server.scene.remove_by_name(self.arrow_head.name)\n\n    def set_visible(self, visible: bool) -> 
None:\n        self.sphere.visible = visible\n        self.annulus.visible = visible\n        if self.arrow_base is not None:\n            self.arrow_base.visible = visible\n        if self.arrow_head is not None:\n            self.arrow_head.visible = visible\n\n\nclass SkeletonMesh:\n    def __init__(\n        self,\n        name: str,\n        server: viser.ViserServer,\n        skeleton: SkeletonBase,\n        joint_color: Optional[Tuple[float, float, float] | np.ndarray] = (\n            255,\n            235,\n            0,\n        ),\n        bone_color: Optional[Tuple[float, float, float] | np.ndarray] = (\n            27,\n            106,\n            0,\n        ),\n        starting_joints_pos: Optional[torch.Tensor] = None,\n    ):\n        \"\"\"\n        name: str, name of the skeleton mesh\n        server: viser.ViserServer, server to add the skeleton mesh to\n        skeleton: SkeletonBase, skeleton to visualize\n        joint_color: Optional[Tuple[float, float, float] | np.ndarray], color of the joints\n        bone_color: Optional[Tuple[float, float, float] | np.ndarray], color of the bones\n        starting_joints_pos: Optional[torch.Tensor], starting joint positions\n        \"\"\"\n        self.server = server\n        self.skeleton = skeleton\n        joint_mesh = trimesh.creation.icosphere(subdivisions=3, radius=0.02)\n        bone_mesh = trimesh.creation.cylinder(radius=0.01, height=1.0)\n\n        init_joints_pos = skeleton.neutral_joints.clone()\n        self.num_joints = init_joints_pos.shape[0]\n        num_bones = self.num_joints - 1\n        non_root_bones = [\n            joint_name\n            for joint_name, parent_name in self.skeleton.bone_order_names_with_parents\n            if parent_name is not None\n        ]\n        self.bone_to_idx = {bone_name: idx for idx, bone_name in enumerate(non_root_bones)}\n\n        # initialize meshes\n        init_joints_wxyzs = np.concatenate([np.ones((self.num_joints, 1)), 
np.zeros((self.num_joints, 3))], axis=1)\n        if isinstance(joint_color, tuple):\n            self.joint_colors = np.full((self.num_joints, 3), joint_color)\n        elif isinstance(joint_color, np.ndarray):\n            assert joint_color.shape == (\n                self.num_joints,\n                3,\n            ), \"Joint colors must be (J, 3)\"\n            self.joint_colors = joint_color\n        joint_scales = np.ones((self.num_joints, 3))\n        hand_roots = {\"LeftHand\", \"RightHand\"}\n        finger_joint_names = set(skeleton.left_hand_joint_names + skeleton.right_hand_joint_names) - hand_roots\n        for jname in finger_joint_names:\n            if jname in skeleton.bone_index:\n                joint_scales[skeleton.bone_index[jname]] = 0.6\n        self.joint_scales = joint_scales\n\n        self.joints_batched_mesh = server.scene.add_batched_meshes_simple(\n            f\"{name}/joints\",\n            vertices=joint_mesh.vertices,\n            faces=joint_mesh.faces,\n            batched_wxyzs=init_joints_wxyzs,\n            batched_positions=np.zeros((self.num_joints, 3)),\n            batched_scales=joint_scales,\n            batched_colors=self.joint_colors,\n        )\n        init_bones_wxyzs = np.concatenate([np.ones((num_bones, 1)), np.zeros((num_bones, 3))], axis=1)\n        if isinstance(bone_color, tuple):\n            bone_color = np.full((num_bones, 3), bone_color)\n        elif isinstance(bone_color, np.ndarray):\n            assert bone_color.shape == (num_bones, 3), \"Bone colors must be (J-1, 3)\"\n            bone_color = bone_color\n        self.bones_batched_mesh = server.scene.add_batched_meshes_simple(\n            f\"{name}/bones\",\n            vertices=bone_mesh.vertices,\n            faces=bone_mesh.faces,\n            batched_wxyzs=init_bones_wxyzs,\n            batched_positions=np.zeros((num_bones, 3)),\n            batched_scales=np.ones((num_bones, 3)),\n            batched_colors=bone_color,\n        )\n\n      
  self.mesh_info_cache = None\n\n        if starting_joints_pos is not None:\n            self.set_pose(starting_joints_pos)\n        else:\n            if isinstance(skeleton, SOMASkeleton77):\n                skel30 = SOMASkeleton30(load=True)\n                min_height = skel30.neutral_joints[:, 1].min().item()\n            else:\n                min_height = init_joints_pos[:, 1].min().item()\n            init_joints_pos[:, 1] -= min_height  # move to be on ground\n            self.set_pose(init_joints_pos)\n\n    def compute_single_pose(self, joints_pos: np.ndarray):\n        \"\"\"Compute the mesh for a single frame.\n\n        joints_pos: [J, 3] global joint positions.\n        \"\"\"\n        new_batched_positions = np.zeros((self.skeleton.nbjoints - 1, 3))\n        new_batched_wxyzs = np.zeros((self.skeleton.nbjoints - 1, 4))\n        new_batched_scales = np.ones((self.skeleton.nbjoints - 1, 3))\n        for joint_name, parent_name in self.skeleton.bone_order_names_with_parents:\n            if parent_name is None:\n                continue\n            joint_idx = self.skeleton.bone_index[joint_name]\n            parent_idx = self.skeleton.bone_index[parent_name]\n            joint_pos = joints_pos[joint_idx]\n            parent_pos = joints_pos[parent_idx]\n\n            bone_pos = (joint_pos + parent_pos) / 2.0\n            bone_scale = np.linalg.norm(joint_pos - parent_pos)\n            if bone_scale < 1e-8:\n                bone_wxyz = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64)\n            else:\n                bone_dir = (joint_pos - parent_pos) / bone_scale\n                R = rotation_matrix_from_two_vec(np.array([0.0, 0.0, 1.0], dtype=np.float64), bone_dir)\n                bone_wxyz = tf.SO3.from_matrix(R).wxyz\n\n            bone_idx = self.bone_to_idx[joint_name]\n            new_batched_positions[bone_idx] = bone_pos\n            new_batched_wxyzs[bone_idx] = bone_wxyz\n            new_batched_scales[bone_idx] = np.array([1.0, 1.0, 
bone_scale], dtype=float)\n\n        return new_batched_positions, new_batched_wxyzs, new_batched_scales\n\n    def precompute_mesh_info(self, joints_pos: torch.Tensor):\n        \"\"\"Precompute the meshes for all frames at once.\n\n        joints_pos: [T, J, 3].\n        \"\"\"\n        joints_pos = joints_pos.cpu().numpy()\n        num_frames = joints_pos.shape[0]\n        self.mesh_info_cache = {\n            \"positions\": np.zeros((num_frames, self.skeleton.nbjoints - 1, 3)),\n            \"wxyzs\": np.zeros((num_frames, self.skeleton.nbjoints - 1, 4)),\n            \"scales\": np.ones((num_frames, self.skeleton.nbjoints - 1, 3)),\n        }\n        for i in range(num_frames):\n            new_batched_positions, new_batched_wxyzs, new_batched_scales = self.compute_single_pose(joints_pos[i])\n            self.mesh_info_cache[\"positions\"][i] = new_batched_positions\n            self.mesh_info_cache[\"wxyzs\"][i] = new_batched_wxyzs\n            self.mesh_info_cache[\"scales\"][i] = new_batched_scales\n\n    def update_mesh_info_cache(self, joints_pos: torch.Tensor, frame_idx: int):\n        \"\"\"Update the mesh info cache for the given frame.\"\"\"\n        assert self.mesh_info_cache is not None\n        new_batched_positions, new_batched_wxyzs, new_batched_scales = self.compute_single_pose(\n            joints_pos.cpu().numpy()\n        )\n        self.mesh_info_cache[\"positions\"][frame_idx] = new_batched_positions\n        self.mesh_info_cache[\"wxyzs\"][frame_idx] = new_batched_wxyzs\n        self.mesh_info_cache[\"scales\"][frame_idx] = new_batched_scales\n\n    def set_pose(\n        self,\n        joints_pos: torch.Tensor,\n        foot_contacts: Optional[torch.Tensor] = None,\n        frame_idx: Optional[int] = None,\n    ):\n        \"\"\"Set pose from [J, 3] global joint positions.\"\"\"\n        self.cur_joints_pos = joints_pos\n        joints_pos = joints_pos.cpu().numpy()\n\n        if self.mesh_info_cache is not None:\n            assert 
frame_idx is not None\n            new_batched_positions = self.mesh_info_cache[\"positions\"][frame_idx]\n            new_batched_wxyzs = self.mesh_info_cache[\"wxyzs\"][frame_idx]\n            new_batched_scales = self.mesh_info_cache[\"scales\"][frame_idx]\n        else:\n            new_batched_positions, new_batched_wxyzs, new_batched_scales = self.compute_single_pose(joints_pos)\n\n        self.bones_batched_mesh.batched_positions = new_batched_positions\n        self.bones_batched_mesh.batched_wxyzs = new_batched_wxyzs\n        self.bones_batched_mesh.batched_scales = new_batched_scales\n        self.joints_batched_mesh.batched_positions = joints_pos\n\n        if foot_contacts is not None:\n            cur_joint_colors = self.joint_colors.copy()\n            foot_contacts = foot_contacts.bool().cpu().numpy().astype(bool)\n            foot_joints = np.array(self.skeleton.foot_joint_idx, dtype=int)\n            contact_idx = foot_joints[foot_contacts]\n            cur_joint_colors[contact_idx] = (255, 0, 0)\n            self.joints_batched_mesh.batched_colors = cur_joint_colors\n        else:\n            self.joints_batched_mesh.batched_colors = self.joint_colors\n\n    def set_visibility(self, visible: bool):\n        self.joints_batched_mesh.visible = visible\n        self.bones_batched_mesh.visible = visible\n\n    def get_pose(self) -> np.ndarray:\n        return self.cur_joints_pos\n\n    def clear(self):\n        names = [mesh.name for mesh in [self.joints_batched_mesh, self.bones_batched_mesh]]\n        for name in names:\n            self.server.scene.remove_by_name(name)\n\n\nLIGHT_THEME = dict(\n    mesh=(152, 189, 255),\n)\n\nDARK_THEME = dict(\n    mesh=(100, 135, 195),\n)\n\nSKIN_CACHE = {}\n\n\nclass Character:\n    def __init__(\n        self,\n        name: str,\n        server: viser.ViserServer | viser.ClientHandle,\n        skeleton: SkeletonBase,\n        create_skeleton_mesh: bool = True,\n        create_skinned_mesh: bool = True,\n      
  visible_skeleton: bool = False,\n        visible_skinned_mesh: bool = True,\n        skinned_mesh_opacity: float = 1.0,\n        show_foot_contacts: bool = True,\n        dark_mode: bool = False,\n        mesh_mode: Optional[str] = None,\n        gui_use_soma_layer_checkbox: Optional[viser.GuiCheckboxHandle] = None,\n    ):\n        self.server = server\n        self.name = name\n        self.skeleton = skeleton\n        self.cur_joints_pos = None\n        self.cur_joints_rot = None\n        self.cur_foot_contacts = None\n\n        self.skeleton_mesh = None\n        self.show_foot_contacts = show_foot_contacts\n        if create_skeleton_mesh:\n            self.skeleton_mesh = SkeletonMesh(f\"/{name}/skeleton\", server, skeleton)\n            self.cur_joints_pos = self.skeleton_mesh.get_pose()\n            self.skeleton_mesh.set_visibility(visible_skeleton)\n\n        self.skinned_mesh = None\n        self.skin = None\n        self.mesh_mode = mesh_mode\n        self.g1_mesh_rig = None\n        if create_skinned_mesh:\n            if isinstance(self.skeleton, (SOMASkeleton30, SOMASkeleton77)) and mesh_mode in [\n                \"soma_skin\",\n                \"soma_layer_skin\",\n            ]:\n                if mesh_mode in SKIN_CACHE:\n                    # already okay\n                    pass\n                else:\n                    if mesh_mode == \"soma_layer_skin\":\n                        try:\n                            # try importing the lib\n                            from .soma_layer_skin import SOMASkin as SOMASkin_SOMA\n\n                            if mesh_mode not in SKIN_CACHE:\n                                SKIN_CACHE[mesh_mode] = SOMASkin_SOMA(self.skeleton)\n\n                        except (ModuleNotFoundError, FileNotFoundError) as e:\n                            if isinstance(e, ModuleNotFoundError):\n                                msg = \"SOMA layer skin is unavailable: the soma package is not installed.\"\n                   
         else:\n                                msg = \"SOMA layer skin is unavailable: SOMA asset files are missing.\"\n                            traceback.print_exc()\n                            if hasattr(self.server, \"add_notification\"):\n                                self.server.add_notification(\n                                    \"SOMA layer skin unavailable\",\n                                    msg,\n                                    auto_close_seconds=5.0,\n                                    with_close_button=True,\n                                )\n                            if gui_use_soma_layer_checkbox is not None:\n                                gui_use_soma_layer_checkbox.value = False\n                            mesh_mode = \"soma_skin\"\n\n                    # another if, in case mesh_mode changed\n                    if mesh_mode == \"soma_skin\" and mesh_mode not in SKIN_CACHE:\n                        SKIN_CACHE[mesh_mode] = SOMASkin(self.skeleton)\n\n                self.skin = SKIN_CACHE[mesh_mode]\n                self.skinned_mesh = server.scene.add_mesh_simple(\n                    f\"/{name}/simple_skinned\",\n                    vertices=self.skin.bind_vertices.cpu().numpy(),\n                    faces=self.skin.faces.cpu().numpy(),\n                    opacity=None,\n                    color=LIGHT_THEME[\"mesh\"] if not dark_mode else DARK_THEME[\"mesh\"],\n                    wireframe=False,\n                    visible=False,\n                )\n                self.skinned_verts_cache = None\n\n                bind_pos = self.skeleton.neutral_joints.clone()\n                if isinstance(self.skeleton, SOMASkeleton77):\n                    skel30 = SOMASkeleton30(load=True)\n                    min_height = skel30.neutral_joints[:, 1].min().item()\n                else:\n                    min_height = bind_pos[:, 1].min().item()\n                bind_pos[:, 1] -= min_height\n                bind_pos[:, 1] += 
0.02\n                bind_rotmat = torch.eye(3, device=bind_pos.device).repeat(bind_pos.shape[0], 1, 1)\n                self.set_pose(bind_pos, bind_rotmat)\n                self.skinned_mesh.visible = True\n                self.set_skinned_mesh_visibility(visible_skinned_mesh)\n                self.set_skinned_mesh_opacity(skinned_mesh_opacity)\n            elif isinstance(self.skeleton, SMPLXSkeleton22) and mesh_mode == \"smplx_skin\":\n                if mesh_mode not in SKIN_CACHE:\n                    SKIN_CACHE[mesh_mode] = SMPLXSkin(self.skeleton)\n                self.skin = SKIN_CACHE[mesh_mode]\n                self.skinned_mesh = server.scene.add_mesh_simple(\n                    f\"/{name}/simple_skinned\",\n                    vertices=self.skin.bind_vertices.cpu().numpy(),\n                    faces=self.skin.faces.cpu().numpy(),\n                    opacity=None,\n                    color=LIGHT_THEME[\"mesh\"] if not dark_mode else DARK_THEME[\"mesh\"],\n                    wireframe=False,\n                    visible=False,\n                )\n                self.skinned_verts_cache = None\n\n                bind_pos = self.skeleton.neutral_joints.clone()\n                min_height = bind_pos[:, 1].min().item()\n                bind_pos[:, 1] -= min_height\n                bind_rotmat = torch.eye(3, device=bind_pos.device).repeat(bind_pos.shape[0], 1, 1)\n                self.set_pose(bind_pos, bind_rotmat)\n                self.skinned_mesh.visible = True\n                self.set_skinned_mesh_visibility(visible_skinned_mesh)\n                self.set_skinned_mesh_opacity(skinned_mesh_opacity)\n            elif isinstance(self.skeleton, G1Skeleton34) and mesh_mode == \"g1_stl\":\n                g1_mesh_dir = Path(self.skeleton.folder) / \"meshes/g1\"\n                if not os.path.exists(g1_mesh_dir):\n                    raise ValueError(f\"G1 mesh directory not found: {g1_mesh_dir}\")\n                self.g1_mesh_rig = G1MeshRig(\n       
             name,\n                    server,\n                    self.skeleton,\n                    str(g1_mesh_dir),\n                    DARK_THEME[\"mesh\"] if dark_mode else LIGHT_THEME[\"mesh\"],\n                )\n                init_joints_rot = self.skeleton.rest_pose_local_rot.clone()\n                init_global_joint_rots, _, init_joints_pos = self.skeleton.fk(\n                    init_joints_rot,\n                    torch.zeros(3, device=init_joints_rot.device, dtype=init_joints_rot.dtype),\n                )\n                min_height = init_joints_pos[:, 1].min().item()\n                init_joints_pos[:, 1] -= min_height\n                self.set_pose(init_joints_pos, init_global_joint_rots)\n                self.set_skinned_mesh_visibility(visible_skinned_mesh)\n                self.set_skinned_mesh_opacity(skinned_mesh_opacity)\n            else:\n                raise ValueError(\n                    \"Unsupported mesh mode for skeleton type: \"\n                    f\"{type(self.skeleton).__name__} with mesh_mode={mesh_mode}\"\n                )\n\n    def change_theme(self, is_dark_mode):\n        color = DARK_THEME[\"mesh\"] if is_dark_mode else LIGHT_THEME[\"mesh\"]\n        if self.skinned_mesh is not None:\n            self.skinned_mesh.color = color\n        if self.g1_mesh_rig is not None:\n            self.g1_mesh_rig.set_color(color)\n\n    def set_skeleton_visibility(self, visible: bool):\n        if self.skeleton_mesh is not None:\n            self.skeleton_mesh.set_visibility(visible)\n\n    def set_show_foot_contacts(self, show: bool, frame_idx: Optional[int] = None):\n        self.show_foot_contacts = show\n        if self.skeleton_mesh is not None and self.cur_joints_pos is not None:\n            fc = self.cur_foot_contacts if show else None\n            self.skeleton_mesh.set_pose(self.cur_joints_pos, foot_contacts=fc, frame_idx=frame_idx)\n\n    def set_skinned_mesh_visibility(self, visible: bool):\n        if 
self.skinned_mesh is not None:\n            self.skinned_mesh.visible = visible\n        if self.g1_mesh_rig is not None:\n            self.g1_mesh_rig.set_visibility(visible)\n\n    def set_skinned_mesh_opacity(self, opacity: float):\n        if self.skinned_mesh is not None:\n            self.skinned_mesh.opacity = opacity\n        if self.g1_mesh_rig is not None:\n            self.g1_mesh_rig.set_opacity(opacity)\n\n    def set_skinned_mesh_wireframe(self, wireframe: bool):\n        if self.skinned_mesh is not None:\n            self.skinned_mesh.wireframe = wireframe\n        if self.g1_mesh_rig is not None:\n            self.g1_mesh_rig.set_wireframe(wireframe)\n\n    def precompute_skinning(self, joints_pos: torch.Tensor, joints_rot: torch.Tensor, chunk_size: int = 64):\n        \"\"\"Precompute skinning for all frames, processing in chunks to avoid OOM.\n\n        joints_pos: [T, J, 3], joints_rot: [T, J, 3, 3].\n\n        The LBS gather intermediate is ~V*W*48 bytes per frame (V=18k, W=8 for SOMA\n        gives ~7 MB/frame), so a chunk of 64 peaks around ~1 GB -- safe alongside\n        a loaded text encoder + diffusion model on a typical 24 GB GPU.\n        \"\"\"\n        assert self.skin is not None\n        T = joints_pos.shape[0]\n        with torch.no_grad():\n            if T <= chunk_size:\n                self.skinned_verts_cache = self.skin.skin(joints_rot, joints_pos, rot_is_global=True).cpu().numpy()\n            else:\n                chunks = []\n                for start in range(0, T, chunk_size):\n                    end = min(start + chunk_size, T)\n                    verts = self.skin.skin(joints_rot[start:end], joints_pos[start:end], rot_is_global=True).cpu().numpy()\n                    chunks.append(verts)\n                self.skinned_verts_cache = np.concatenate(chunks, axis=0)\n\n    def update_skinning_cache(self, joints_pos: torch.Tensor, joints_rot: torch.Tensor, frame_idx: int):\n        \"\"\"Update skinning cache for one 
frame.\"\"\"\n        if self.skinned_verts_cache is None:\n            return\n        with torch.no_grad():\n            new_skinned_verts = self.skin.skin(joints_rot[None], joints_pos[None], rot_is_global=True)[0].cpu().numpy()\n        self.skinned_verts_cache[frame_idx] = new_skinned_verts\n\n    def set_pose(\n        self,\n        joints_pos: torch.Tensor,\n        joints_rot: torch.Tensor,\n        foot_contacts: Optional[torch.Tensor] = None,\n        frame_idx: Optional[int] = None,\n    ):\n        if self.skeleton_mesh is not None:\n            self.cur_foot_contacts = foot_contacts\n            display_fc = foot_contacts if self.show_foot_contacts else None\n            self.skeleton_mesh.set_pose(joints_pos, foot_contacts=display_fc, frame_idx=frame_idx)\n\n        if self.skinned_mesh is not None:\n            if self.skinned_verts_cache is not None:\n                assert frame_idx is not None\n                skinned_verts = self.skinned_verts_cache[frame_idx]\n            else:\n                with torch.no_grad():\n                    skinned_verts = self.skin.skin(joints_rot[None], joints_pos[None], rot_is_global=True)[0].cpu().numpy()\n            self.skinned_mesh.vertices = skinned_verts\n        if self.g1_mesh_rig is not None:\n            joints_pos_np = joints_pos.detach().cpu().numpy()\n            joints_rot_np = joints_rot.detach().cpu().numpy()\n            self.g1_mesh_rig.set_pose(joints_pos_np, joints_rot_np)\n\n        self.cur_joints_pos = joints_pos\n        self.cur_joints_rot = joints_rot\n\n    def get_pose(self) -> torch.Tensor:\n        return self.cur_joints_pos, self.cur_joints_rot\n\n    def clear(self):\n        if self.skeleton_mesh is not None:\n            self.skeleton_mesh.clear()\n        if self.skinned_mesh is not None:\n            self.server.scene.remove_by_name(self.skinned_mesh.name)\n        if self.g1_mesh_rig is not None:\n            self.g1_mesh_rig.clear()\n"
  },
  {
    "path": "kimodo/viz/smplx_skin.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"SMPL-X skinning and joint mapping for visualization.\"\"\"\n\nimport os\nimport warnings\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\n\nfrom kimodo.geometry import axis_angle_to_matrix\nfrom kimodo.skeleton import SMPLXSkeleton22, batch_rigid_transform\n\nSKIN_NAME = \"SMPLX_NEUTRAL.npz\"\nBETA_NAME = \"beta.npy\"\nMEAN_HANDS_NAME = \"mean_hands.npy\"\n\nSMPLX_BODY_JOINT_NAME_MAP = {\n    \"pelvis\": \"Pelvis\",\n    \"left_hip\": \"L_Hip\",\n    \"right_hip\": \"R_Hip\",\n    \"spine1\": \"Spine1\",\n    \"left_knee\": \"L_Knee\",\n    \"right_knee\": \"R_Knee\",\n    \"spine2\": \"Spine2\",\n    \"left_ankle\": \"L_Ankle\",\n    \"right_ankle\": \"R_Ankle\",\n    \"spine3\": \"Spine3\",\n    \"left_foot\": \"L_Foot\",\n    \"right_foot\": \"R_Foot\",\n    \"neck\": \"Neck\",\n    \"left_collar\": \"L_Collar\",\n    \"right_collar\": \"R_Collar\",\n    \"head\": \"Head\",\n    \"left_shoulder\": \"L_Shoulder\",\n    \"right_shoulder\": \"R_Shoulder\",\n    \"left_elbow\": \"L_Elbow\",\n    \"right_elbow\": \"R_Elbow\",\n    \"left_wrist\": \"L_Wrist\",\n    \"right_wrist\": \"R_Wrist\",\n}\n\n# SMPL-X hand pose order (15 joints per hand) matching SMPL-X index order.\nSMPLX_HAND_JOINT_ORDER = [\n    \"Index1\",\n    \"Index2\",\n    \"Index3\",\n    \"Middle1\",\n    \"Middle2\",\n    \"Middle3\",\n    \"Pinky1\",\n    \"Pinky2\",\n    \"Pinky3\",\n    \"Ring1\",\n    \"Ring2\",\n    \"Ring3\",\n    \"Thumb1\",\n    \"Thumb2\",\n    \"Thumb3\",\n]\n\nSMPLX_FACE_JOINT_NAMES = [\"Jaw\", \"L_Eye\", \"R_Eye\"]\n\n\nclass SMPLXSkin:\n    def __init__(\n        self,\n        skeleton,\n        use_mean_hands=True,\n    ):\n        skel_dir = Path(skeleton.folder)\n        skin_data_path = skel_dir / SKIN_NAME\n\n        if not skin_data_path.exists():\n            raise FileExistsError(\n       
         f\"You should download the {SKIN_NAME} from the smplx website, and put it there: {skin_data_path}\"\n            )\n\n        beta_path = skel_dir / BETA_NAME\n        mean_hands_path = skel_dir / MEAN_HANDS_NAME\n\n        self.skeleton = skeleton\n        assert isinstance(skeleton, SMPLXSkeleton22), \"SMPLXSkin only supports SMPLXSkeleton22\"\n        assert skeleton.neutral_joints is not None, \"SMPLXSkeleton22 must have neutral joints instantiated\"\n\n        device = skeleton.neutral_joints.device\n        with warnings.catch_warnings():\n            # Ignore legacy object-dtype warning emitted while unpickling old SMPL-X assets.\n            warnings.filterwarnings(\n                \"ignore\",\n                message=r\"dtype\\(\\): align should be passed as Python or NumPy boolean.*\",\n                category=Warning,\n                module=r\"numpy\\.lib\\._format_impl\",\n            )\n            # np.load on .npz is lazy; materialize all fields while filter is active.\n            with np.load(skin_data_path, allow_pickle=True) as skin_npz:\n                skin_data = {key: skin_npz[key] for key in skin_npz.files}\n\n        joint2num = skin_data[\"joint2num\"]\n        if isinstance(joint2num, np.ndarray):\n            joint2num = joint2num.item()\n        self.full_joint_count = int(skin_data[\"weights\"].shape[1])\n        kintree_table = np.array(skin_data[\"kintree_table\"], dtype=np.int64)\n        parents = kintree_table[0].copy()\n        parents[parents > 1_000_000_000] = -1\n        self.full_joint_parents = torch.tensor(parents, device=device, dtype=torch.long)\n        root_candidates = np.where(parents == -1)[0]\n        self.full_root_idx = int(root_candidates[0]) if root_candidates.size else 0\n        self.joint_regressor = torch.tensor(\n            np.array(skin_data[\"J_regressor\"], dtype=np.float32),\n            device=device,\n            dtype=torch.float,\n        )\n\n        rig_joint_names = []\n        
rig_joint_indices = []\n        for joint_name in self.skeleton.bone_order_names:\n            mapped_name = SMPLX_BODY_JOINT_NAME_MAP.get(joint_name)\n            if mapped_name is None or mapped_name not in joint2num:\n                raise ValueError(f\"Missing SMPL-X joint mapping for '{joint_name}'\")\n            rig_joint_names.append(mapped_name)\n            rig_joint_indices.append(int(joint2num[mapped_name]))\n        self.body_joint_indices = np.array(rig_joint_indices, dtype=np.int64)\n\n        # Prepare mean hand pose rotations for joints not produced by the model.\n        if use_mean_hands and mean_hands_path is not None and os.path.exists(mean_hands_path):\n            mean_hands = np.array(np.load(mean_hands_path), dtype=np.float32)\n        else:\n            mean_hands = np.zeros(90, dtype=np.float32)\n        if mean_hands.shape[0] != 90:\n            raise ValueError(f\"Expected mean_hands shape (90,), got {mean_hands.shape}\")\n        mean_hands = mean_hands.reshape(30, 3)\n        mean_hands_rotmats = axis_angle_to_matrix(torch.tensor(mean_hands, device=device, dtype=torch.float))\n        left_hand_joint_names = [f\"L_{name}\" for name in SMPLX_HAND_JOINT_ORDER]\n        right_hand_joint_names = [f\"R_{name}\" for name in SMPLX_HAND_JOINT_ORDER]\n        left_indices = [joint2num[name] for name in left_hand_joint_names]\n        right_indices = [joint2num[name] for name in right_hand_joint_names]\n        self.hand_joint_indices = np.array(left_indices + right_indices, dtype=np.int64)\n        self.mean_hand_rotmats = mean_hands_rotmats\n        face_indices = [joint2num[name] for name in SMPLX_FACE_JOINT_NAMES if name in joint2num]\n        self.face_joint_indices = np.array(face_indices, dtype=np.int64)\n        self.mean_face_rotmats = torch.eye(3, device=device).repeat(len(self.face_joint_indices), 1, 1)\n\n        # bind_rig_transform: [J, 4, 4]\n        # bind_vertices: [V, 3]\n        # faces: [F, 3]\n        # lbs indices, lbs 
weights: [V, W] (W = number of joints)\n        v_template = np.array(skin_data[\"v_template\"], dtype=np.float32)\n        faces = np.array(skin_data[\"f\"], dtype=np.int64)\n        weights = np.array(skin_data[\"weights\"], dtype=np.float32)\n\n        shapedirs = np.array(skin_data[\"shapedirs\"], dtype=np.float32)\n        posedirs = np.array(skin_data[\"posedirs\"], dtype=np.float32)\n\n        if beta_path is not None and os.path.exists(beta_path):\n            betas = np.array(np.load(beta_path), dtype=np.float32)\n        else:\n            betas = np.zeros(300, dtype=np.float32)\n\n        num_shape_coeffs = shapedirs.shape[2]  # 400 = 300 + 100 (shape + expression)\n        if betas.shape[0] < num_shape_coeffs:\n            betas = np.pad(betas, (0, num_shape_coeffs - betas.shape[0]), mode=\"constant\")\n        elif betas.shape[0] > num_shape_coeffs:\n            betas = betas[:num_shape_coeffs]\n\n        v_shaped = v_template + np.tensordot(shapedirs, betas, axes=[2, 0])\n        self.v_shaped = torch.tensor(v_shaped, device=device, dtype=torch.float)\n        self.posedirs = torch.tensor(posedirs, device=device, dtype=torch.float)\n        self.joint_rest = torch.einsum(\"jv,vc->jc\", self.joint_regressor, self.v_shaped)\n\n        # Align SMPL-X body rest joints to the model skeleton rest pose.\n        body_rest = self.skeleton.neutral_joints.to(device=device, dtype=torch.float)\n        if body_rest.shape[0] == self.body_joint_indices.shape[0]:\n            # Treat mismatches as a warning and align to the skeleton pose anyway.\n            max_delta = (self.joint_rest[self.body_joint_indices] - body_rest).abs().max()\n            if max_delta > 1e-6:\n                print(\n                    \"Warning: SMPL-X rest pose mismatch (max_delta=\"\n                    f\"{max_delta:.2e}); aligning to skeleton neutral joints.\"\n                )\n            self.joint_rest[self.body_joint_indices] = body_rest\n\n        # Renormalize weights to 
avoid numerical issues.\n        weight_sums = weights.sum(axis=1, keepdims=True)\n        zero_mask = weight_sums[:, 0] < 1e-8\n        weights = weights / np.clip(weight_sums, 1e-8, None)\n        if np.any(zero_mask):\n            weights[zero_mask, :] = 0.0\n            weights[zero_mask, self.full_root_idx] = 1.0\n\n        joint_indices = np.arange(self.full_joint_count, dtype=np.int64)\n        lbs_indices = np.tile(joint_indices[None, :], (v_template.shape[0], 1))\n\n        bind_rig_np = np.zeros((self.full_joint_count, 4, 4), dtype=np.float32)\n        bind_rig_np[:, 3, 3] = 1.0\n        bind_rig_np[:, :3, :3] = np.eye(3, dtype=np.float32)\n        bind_rig_np[:, :3, 3] = self.joint_rest.detach().cpu().numpy()\n\n        self.bind_rig_transform = torch.from_numpy(bind_rig_np).to(device=device, dtype=torch.float)\n        bind_rig_inv_np = np.linalg.inv(bind_rig_np)\n        self.bind_rig_transform_inv = torch.from_numpy(bind_rig_inv_np).to(device=device, dtype=torch.float)\n        self.bind_vertices = torch.tensor(v_shaped, device=device, dtype=torch.float)\n        self.faces = torch.tensor(faces, device=device, dtype=torch.long)\n        self.lbs_indices = torch.tensor(lbs_indices, device=device, dtype=torch.long)\n        self.lbs_weights = torch.tensor(weights, device=device, dtype=torch.float)\n\n        # double check the rig matches expected skeleton order\n        for sname, rname in zip(self.skeleton.bone_order_names, rig_joint_names):\n            mapped_name = SMPLX_BODY_JOINT_NAME_MAP.get(sname)\n            if mapped_name != rname:\n                raise ValueError(f\"MISMATCH in skinning rig: expected='{mapped_name}' vs rig='{rname}'\")\n\n    def lbs(self, posed_transform, bind_vertices=None):\n        bind_rig_transform_inv = self.bind_rig_transform_inv\n        if bind_vertices is None:\n            bind_vertices = self.bind_vertices\n        lbs_weights = self.lbs_weights\n        # posed_transform: [B, F, J, 4, 4] or [B, J, 4, 4] or 
[J, 4, 4]\n        # unsqueeze to match posed_transform batch dims\n        batch_dims = posed_transform.shape[:-3]\n        if bind_vertices.dim() == 2:\n            for _ in batch_dims:\n                bind_vertices = bind_vertices.unsqueeze(0)\n        elif bind_vertices.dim() == 3:\n            if len(batch_dims) == 1:\n                if bind_vertices.shape[0] != batch_dims[0]:\n                    bind_vertices = bind_vertices.unsqueeze(0)\n            elif len(batch_dims) > 1:\n                for _ in range(len(batch_dims) - 1):\n                    bind_vertices = bind_vertices.unsqueeze(0)\n        for _ in batch_dims:\n            bind_rig_transform_inv = bind_rig_transform_inv.unsqueeze(0)\n            lbs_weights = lbs_weights.unsqueeze(0)\n        # bind_rig_transform_inv: [..., J, 4, 4]\n        # bind_vertices: [..., V, 3]\n        # lbs_weights: [..., V, W]\n\n        affine_mat = (posed_transform @ bind_rig_transform_inv)[..., :3, :]  # [..., J, 3, 4]\n        vs = (\n            affine_mat[..., self.lbs_indices, :, :]\n            @ torch.concat([bind_vertices, torch.ones_like(bind_vertices[..., 0:1])], dim=-1)[..., None, :, None]\n        )  # [..., V, W, 3, 1]\n        ws = lbs_weights[..., None, None]\n        resv = (vs * ws).sum(dim=-3).squeeze(-1)  # [..., V, 3]\n        return resv\n\n    def skin(self, joint_rotmat, joint_pos, rot_is_global=False):\n        \"\"\"\n        joint_rotmat: [T, J, 3, 3] local or global joint rotation matrices\n        joint_pos: [T, J, 3] global joint positions\n        rot_is_global: bool, if True, joint_rotmat is global rotation matrices,\n        otherwise it is local rotation matrices and FK is performed internally\n        \"\"\"\n        nF, nJ = joint_pos.shape[:2]\n        device = joint_rotmat.device\n\n        # import ipdb; ipdb.set_trace()\n        if rot_is_global:\n            if joint_rotmat.shape[1] == self.full_joint_count:\n                local_rotmat_full = joint_rotmat.clone()\n          
      parents = self.full_joint_parents.to(device)\n                parent_rot_mats = local_rotmat_full[:, parents]\n                parent_rot_mats[:, self.full_root_idx] = torch.eye(3, device=device)\n                parent_rot_mats_inv = parent_rot_mats.transpose(2, 3)\n                local_rotmat_full = torch.einsum(\n                    \"T N m n, T N n o -> T N m o\",\n                    parent_rot_mats_inv,\n                    local_rotmat_full,\n                )\n            else:\n                local_rotmat = self.skeleton.global_rots_to_local_rots(joint_rotmat)\n        else:\n            local_rotmat = joint_rotmat\n\n        if rot_is_global and joint_rotmat.shape[1] == self.full_joint_count:\n            full_local = local_rotmat_full\n        else:\n            full_local = torch.eye(3, device=device).reshape(1, 1, 3, 3).repeat(nF, self.full_joint_count, 1, 1)\n            full_local[:, self.body_joint_indices] = local_rotmat\n        if self.mean_hand_rotmats is not None:\n            full_local[:, self.hand_joint_indices] = self.mean_hand_rotmats[None]\n        if self.mean_face_rotmats is not None:\n            full_local[:, self.face_joint_indices] = self.mean_face_rotmats[None]\n        pose_feature = (full_local[:, 1:] - torch.eye(3, device=device)[None, None]).reshape(nF, -1)\n\n        pose_offsets = torch.einsum(\"vcp,tp->tvc\", self.posedirs, pose_feature)\n        v_posed = self.v_shaped[None] + pose_offsets\n        joints_rest = self.joint_rest[None].repeat(nF, 1, 1)\n        posed_joints, global_joint_rots = batch_rigid_transform(\n            full_local,\n            joints_rest,\n            self.full_joint_parents.to(device),\n            self.full_root_idx,\n        )\n        # remove the skeleton offset of the root joint\n        root_trans = joint_pos[:, self.skeleton.root_idx] - self.skeleton.neutral_joints[0:1]\n        posed_joints = posed_joints + root_trans[:, None, :]\n\n        fk_transform = torch.eye(4, 
device=device)[None, None].repeat(nF, self.full_joint_count, 1, 1)\n        fk_transform[..., :3, :3] = global_joint_rots\n        fk_transform[..., :3, 3] = posed_joints\n\n        vertices = self.lbs(fk_transform, bind_vertices=v_posed)\n        return vertices\n"
  },
  {
    "path": "kimodo/viz/soma_layer_skin.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"SOMA layer-based skinning for visualization (SOMASkeleton30 / SOMASkeleton77).\"\"\"\n\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nfrom huggingface_hub import snapshot_download\nfrom soma import SomaLayer as SOMALayer\n\nfrom kimodo.assets import SOMA_ASSETS_ROOT\nfrom kimodo.skeleton import SOMASkeleton30, SOMASkeleton77, global_rots_to_local_rots\n\nSOMA_MHR_NEUTRAL_PATH = \"somaskel30/soma_base_fit_mhr_params.npz\"\n\n\nclass SOMASkin:\n    def __init__(\n        self,\n        skeleton,\n    ):\n        self.skeleton = skeleton\n\n        assert isinstance(\n            skeleton, (SOMASkeleton30, SOMASkeleton77)\n        ), \"SOMASkin currently only supports SOMASkeleton30 or SOMASkeleton77\"\n        assert skeleton.neutral_joints is not None, \"The skeleton must have neutral joints instantiated\"\n\n        device = skeleton.neutral_joints.device\n        device = \"cpu\"\n        self.device = device\n\n        self._soma_model = SOMALayer(\n            identity_model_type=\"mhr\",\n            device=device,\n        )\n        self.faces = self._soma_model.faces\n\n        neutral_mhr_path = Path(skeleton.folder).parent / SOMA_MHR_NEUTRAL_PATH\n        neutral_mhr = np.load(neutral_mhr_path)\n\n        # one time call to prepare the identity\n        self.soma_identity = torch.from_numpy(neutral_mhr[\"identity_params\"])\n        self.scale_params = torch.from_numpy(neutral_mhr[\"scale_params\"])\n        self._soma_model.prepare_identity(self.soma_identity.to(device), scale_params=self.scale_params.to(device))\n\n        # dummy output to get bind_vertices\n        transl = torch.zeros(1, 3, device=device)\n\n        self._full_skeleton = SOMASkeleton77()\n        self.skel_slice = self.skeleton.get_skel_slice(self._full_skeleton)\n\n        self.bind_vertices = 
self.soma_model_pose(\n            self._full_skeleton.relaxed_hands_rest_pose[None],\n            transl=transl,\n            pose2rot=False,\n        )[\"vertices\"][0]\n\n    def soma_model_pose(self, *args, **kwargs):\n        with torch.inference_mode():\n            return self._soma_model.pose(*args, **kwargs)\n\n    def skin(self, joint_rotmat, joint_pos, rot_is_global=False):\n        \"\"\"\n        joint_rotmat: [T, J, 3, 3] local or global joint rotation matrices\n        joint_pos: [T, J, 3] global joint positions\n        rot_is_global: bool, if True, joint_rotmat is global rotation matrices, otherwise it is local rotation matrices and FK is performed internally\n        \"\"\"\n\n        nF, nJ = joint_pos.shape[:2]\n\n        if rot_is_global:\n            local_joint_rots_mats_subset = global_rots_to_local_rots(joint_rotmat, self.skeleton)\n        else:\n            local_joint_rots_mats_subset = joint_rotmat\n\n        if nJ != self._full_skeleton.nbjoints:\n            local_joint_rots_mats = self.skeleton.to_SOMASkeleton77(local_joint_rots_mats_subset)\n        else:\n            local_joint_rots_mats = local_joint_rots_mats_subset\n\n        # remove the skeleton offset of the root joint\n        transl = joint_pos[:, self.skeleton.root_idx] - self.skeleton.neutral_joints[0:1]\n\n        output = self.soma_model_pose(\n            local_joint_rots_mats.to(device=self.device, dtype=torch.float32),\n            transl=transl.to(device=self.device, dtype=torch.float32),\n            pose2rot=False,\n        )\n        return output[\"vertices\"]\n"
  },
  {
    "path": "kimodo/viz/soma_skin.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"SOMA skeleton skinning for visualization (SOMASkeleton30 / SOMASkeleton77).\"\"\"\n\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\n\nfrom kimodo.skeleton import (\n    SOMASkeleton30,\n    SOMASkeleton77,\n    batch_rigid_transform,\n    global_rots_to_local_rots,\n)\n\n# Skin for SOMASkeleton77\nSKEL_PATH = \"somaskel77\"\nSKIN_NAME = \"skin_standard.npz\"\n\n\nclass SOMASkin:\n    def __init__(self, skeleton):\n        skel_path = Path(skeleton.folder).parent / SKEL_PATH\n        skin_data_path = skel_path / SKIN_NAME\n\n        self.skeleton_input = skeleton\n        assert isinstance(\n            skeleton, (SOMASkeleton30, SOMASkeleton77)\n        ), \"SOMASkin currently only supports SOMASkeleton30 or SOMASkeleton77\"\n        assert skeleton.neutral_joints is not None, \"The skeleton must have neutral joints instantiated\"\n        device = skeleton.neutral_joints.device\n\n        # the skin is always the 77-joint skeleton\n        #   if user is using the 30-joint skeleton, we will pad it when skinning is called\n        self.skeleton_skin = SOMASkeleton77(skel_path).to(device)\n\n        # bind_rig_transform: [R, 4, 4]\n        # bind_vertices: [V, 3]\n        # faces: [F, 3]\n        # lbs indices, lbs weights: [V, W] (W = max (num joints vertice is related to), in our case W=5)\n        skin_data = np.load(skin_data_path)\n        bind_rig_np = np.array(skin_data[\"bind_rig_transform\"], dtype=np.float32)\n        self.bind_rig_transform = torch.from_numpy(bind_rig_np).to(device=device, dtype=torch.float)\n        # Precompute the inverse in numpy to avoid torch lazy evaluation issues\n        bind_rig_inv_np = np.linalg.inv(bind_rig_np)\n        self.bind_rig_transform_inv = torch.from_numpy(bind_rig_inv_np).to(device=device, dtype=torch.float)\n        self.bind_vertices = 
torch.tensor(skin_data[\"bind_vertices\"], device=device, dtype=torch.float)\n        self.faces = torch.tensor(skin_data[\"faces\"], device=device, dtype=torch.long)\n        self.lbs_indices = torch.tensor(skin_data[\"lbs_indices\"], device=device, dtype=torch.long)\n        self.lbs_weights = torch.tensor(skin_data[\"lbs_weights\"], device=device, dtype=torch.float)\n\n        # double check the rig matches expected skeleton\n        rig_joint_names = list(skin_data[\"rig_joint_names\"])  # list(str) : [R]\n        for sname, rname in zip(self.skeleton_skin.bone_order_names, rig_joint_names):\n            if sname != rname:\n                raise ValueError(f\"MISMATCH in skinning rig: expected='{sname}' vs rig='{rname}'\")\n\n    def lbs(self, posed_transform):\n        bind_rig_transform_inv = self.bind_rig_transform_inv\n        bind_vertices = self.bind_vertices\n        lbs_weights = self.lbs_weights\n        # posed_transform: [B, F, J, 4, 4] or [B, J, 4, 4] or [J, 4, 4]\n        # unsqueeze to match posed_transform dim\n        for _ in range(posed_transform.dim() - 3):\n            bind_rig_transform_inv = bind_rig_transform_inv.unsqueeze(0)\n            bind_vertices = bind_vertices.unsqueeze(0)\n            lbs_weights = lbs_weights.unsqueeze(0)\n            # bind_rig_transform_inv: [..., R, 4, 4]\n            # bind_vertices: [..., V, 3]\n            # lbs_weights: [..., V, W]\n\n        affine_mat = (posed_transform @ bind_rig_transform_inv)[..., :3, :]  # [..., J, 3, 4]\n        vs = (\n            affine_mat[..., self.lbs_indices, :, :]\n            @ torch.concat([bind_vertices, torch.ones_like(bind_vertices[..., 0:1])], dim=-1)[..., None, :, None]\n        )  # [..., V, W, 3, 1]\n        ws = lbs_weights[..., None, None]\n        resv = (vs * ws).sum(dim=-3).squeeze(-1)  # [..., V, 3]\n        return resv\n\n    def skin(self, joint_rotmat, joint_pos, rot_is_global=False):\n        \"\"\"\n        joint_rotmat: [T, J, 3, 3] local or global joint rotation matrices\n        joint_pos: [T, J, 3] global joint positions\n        rot_is_global: bool, if True, joint_rotmat is global rotation matrices, otherwise it is local rotation matrices and FK is performed internally\n        \"\"\"\n        nF, nJ = joint_pos.shape[:2]\n        device = joint_rotmat.device\n\n        if nJ != self.skeleton_skin.nbjoints:\n            assert nJ == 30, \"SOMASkin currently only supports 30-joint or 77-joint skeletons\"\n\n            # make sure we have local joint rotations\n            if rot_is_global:\n                local_joint_rots_mats_subset = global_rots_to_local_rots(joint_rotmat, self.skeleton_input)\n            else:\n                local_joint_rots_mats_subset = joint_rotmat\n\n            local_joint_rots_mats = self.skeleton_input.to_SOMASkeleton77(local_joint_rots_mats_subset)\n\n            # FK to get the global joint pos and rot\n            neutral_joints_seq = self.skeleton_skin.neutral_joints[None].repeat((nF, 1, 1)).to(device)\n            new_joint_pos, joint_rotmat = batch_rigid_transform(\n                local_joint_rots_mats,\n                neutral_joints_seq,\n                self.skeleton_skin.joint_parents.to(device),\n                self.skeleton_skin.root_idx,\n            )\n            joint_pos = new_joint_pos + joint_pos[:, self.skeleton_input.root_idx : self.skeleton_input.root_idx + 1]\n            nJ = self.skeleton_skin.nbjoints\n            rot_is_global = True\n\n        # prepare full transformation matrices\n        fk_transform = torch.eye(4, device=device)[None, None].repeat(nF, nJ, 1, 1)\n        fk_transform[..., :3, 3] = joint_pos\n        if rot_is_global:\n            fk_transform[..., :3, :3] = joint_rotmat\n        else:\n            neutral_joints_seq = self.skeleton_skin.neutral_joints[None].repeat((nF, 1, 1)).to(device)\n            # FK to get the global rotations\n            _, global_joint_rotmat = batch_rigid_transform(\n                joint_rotmat,\n                neutral_joints_seq,\n                self.skeleton_skin.joint_parents.to(device),\n                self.skeleton_skin.root_idx,\n            )\n            fk_transform[..., :3, :3] = global_joint_rotmat\n\n        vertices = self.lbs(fk_transform)\n        return vertices\n"
  },
  {
    "path": "kimodo/viz/viser_utils.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\"\"\"Viser-based 3D viz: re-exports from viz submodules for backward compatibility.\"\"\"\n\nimport os\n\nfrom .constraint_ui import (\n    ConstraintSet,\n    EEJointsKeyframeSet,\n    FullbodyKeyframeSet,\n    RootKeyframe2DSet,\n    build_constraint_set_table_markdown,\n    update_interval,\n)\nfrom .gui import GuiElements\nfrom .playback import CharacterMotion\nfrom .scene import (\n    DARK_THEME,\n    LIGHT_THEME,\n    SKIN_CACHE,\n    Character,\n    SkeletonMesh,\n    WaypointMesh,\n)\n\n\ndef load_example_cases(examples_base_dir):\n    \"\"\"List subdirectories of examples_base_dir as a name -> path dict.\"\"\"\n    example_dirs = os.listdir(examples_base_dir)\n    example_names = sorted([d for d in example_dirs if os.path.isdir(os.path.join(examples_base_dir, d))])\n    return {name: os.path.join(examples_base_dir, name) for name in example_names}\n\n\n__all__ = [\n    \"Character\",\n    \"CharacterMotion\",\n    \"ConstraintSet\",\n    \"DARK_THEME\",\n    \"EEJointsKeyframeSet\",\n    \"FullbodyKeyframeSet\",\n    \"GuiElements\",\n    \"LIGHT_THEME\",\n    \"RootKeyframe2DSet\",\n    \"SKIN_CACHE\",\n    \"SkeletonMesh\",\n    \"WaypointMesh\",\n    \"build_constraint_set_table_markdown\",\n    \"load_example_cases\",\n    \"update_interval\",\n]\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\"setuptools>=61.0\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"kimodo\"\nversion = \"1.0.0\"\ndescription = \"Kimodo motion generation model\"\nreadme = \"README.md\"\nrequires-python = \">=3.8\"\nlicense = {text = \"Apache-2.0\"}\ndependencies = [\n  \"hydra-core>=1.3\",\n  \"omegaconf>=2.3\",\n  \"numpy>=1.23\",\n  \"scipy>=1.10\",\n  \"transformers==5.1.0\",\n  \"urllib3>=2.6.3\",\n  \"boto3\",\n  \"peft>=0.18\",\n  \"einops>=0.7\",\n  \"tqdm>=4.0\",\n  \"packaging>=21.0\",\n  \"pydantic>=2.0\",\n  \"filelock>=3.20.3\",\n  \"gradio>=6.8.0\",\n  \"gradio_client>=1.0\",\n  \"trimesh>=3.21.7\",\n  \"scenepic>=1.1.0\",\n  \"pillow>=9.0\",\n  \"av>=16.1.0\",\n  \"bvhio\",\n]\n\n[project.optional-dependencies]\ndemo = [\n  \"viser @ git+https://github.com/nv-tlabs/kimodo-viser.git\",\n]\nsoma = [\n  \"py-soma-x @ git+https://github.com/NVlabs/SOMA-X.git\"\n]\nall = [\n  \"viser @ git+https://github.com/nv-tlabs/kimodo-viser.git\",\n  \"py-soma-x @ git+https://github.com/NVlabs/SOMA-X.git\"\n]\n\n[project.scripts]\nkimodo_gen = \"kimodo.scripts.generate:main\"\nkimodo_demo = \"kimodo.demo:main\"\nkimodo_textencoder = \"kimodo.scripts.run_text_encoder_server:main\"\nkimodo_convert = \"kimodo.scripts.motion_convert:main\"\n\n[tool.setuptools]\ninclude-package-data = true\nzip-safe = false\n\n[tool.setuptools.package-data]\nkimodo = [\"assets/**/*\"]\n\n[tool.flake8]\nmax-line-length = 120\n\n[tool.ruff]\nextend-select = [\"I001\"]  # Enable import sorting\nline-length = 120\n\n[tool.ruff.lint.isort]\nknown-first-party = [\"kimodo\"]\nknown-third-party = [\"torch\", \"numpy\", \"pytorch_lightning\", \"wandb\", \"tqdm\"]\nforce-sort-within-sections = false\n"
  },
  {
    "path": "setup.py",
    "content": "# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n\nimport os\nimport shutil\nimport subprocess\nimport sys\nfrom pathlib import Path\n\nfrom setuptools import Extension, find_packages, setup\nfrom setuptools.command.build_ext import build_ext\n\n\nclass CMakeExtension(Extension):\n    def __init__(self, name, sourcedir=\"\"):\n        super().__init__(name, sources=[])\n        self.sourcedir = os.path.abspath(sourcedir)\n\n\nclass CMakeBuild(build_ext):\n    def run(self):\n        try:\n            subprocess.check_output([\"cmake\", \"--version\"])\n        except OSError as exc:\n            raise RuntimeError(\"CMake must be installed to build this package\") from exc\n\n        for ext in self.extensions:\n            self.build_extension(ext)\n\n    def build_extension(self, ext):\n        extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))\n        cmake_args = [\n            f\"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}\",\n            f\"-DPYTHON_EXECUTABLE={sys.executable}\",\n        ]\n\n        cfg = \"Debug\" if self.debug else \"Release\"\n        build_args = [\"--config\", cfg]\n        cmake_args.append(f\"-DCMAKE_BUILD_TYPE={cfg}\")\n\n        use_mingw = False\n        mingw_bin = None\n\n        if sys.platform == \"win32\":\n            generator = os.environ.get(\"CMAKE_GENERATOR\", \"\")\n            if generator:\n                cmake_args = [\"-G\", generator] + cmake_args\n                if \"mingw\" in generator.lower():\n                    use_mingw = True\n                else:\n                    cmake_args.append(f\"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}\")\n            else:\n                try:\n                    subprocess.check_output([\"g++\", \"--version\"], stderr=subprocess.STDOUT)\n                    use_mingw = True\n                    cmake_args = [\"-G\", \"MinGW 
Makefiles\"] + cmake_args\n                    build_args = []\n                except (OSError, subprocess.CalledProcessError):\n                    cmake_args.append(f\"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}\")\n\n            if use_mingw:\n                gxx_path = shutil.which(\"g++\")\n                if gxx_path:\n                    mingw_bin = Path(gxx_path).parent\n        else:\n            build_args += [\"--\", \"-j4\"]\n\n        env = os.environ.copy()\n        env[\"CXXFLAGS\"] = f'{env.get(\"CXXFLAGS\", \"\")} -DVERSION_INFO=\\\\\"{self.distribution.get_version()}\\\\\"'\n\n        if not os.path.exists(self.build_temp):\n            os.makedirs(self.build_temp)\n\n        subprocess.check_call([\"cmake\", ext.sourcedir] + cmake_args, cwd=self.build_temp, env=env)\n        subprocess.check_call([\"cmake\", \"--build\", \".\"] + build_args, cwd=self.build_temp)\n\n        if use_mingw and mingw_bin is not None:\n            runtime_libs = [\n                \"libstdc++-6.dll\",\n                \"libgcc_s_seh-1.dll\",\n                \"libwinpthread-1.dll\",\n            ]\n            extdir_path = Path(extdir)\n            extdir_path.mkdir(parents=True, exist_ok=True)\n            for lib_name in runtime_libs:\n                src_path = mingw_bin / lib_name\n                if src_path.exists():\n                    shutil.copy2(src_path, extdir_path / lib_name)\n                else:\n                    self.announce(\n                        f\"Warning: Expected MinGW runtime DLL '{lib_name}' not found next to g++ (looked in {mingw_bin}). \"\n                        \"The built extension may fail to import if the DLL is not on PATH.\",\n                        level=3,\n                    )\n\n\nkimodo_packages = find_packages(include=[\"kimodo\", \"kimodo.*\"])\n\n# When set (e.g. in Docker), do not bundle motion_correction here; it is installed\n# separately (e.g. 
from docker_requirements.txt as ./MotionCorrection) non-editable.\nskip_motion_correction = os.environ.get(\"SKIP_MOTION_CORRECTION_IN_SETUP\", \"\").strip().lower() in (\"1\", \"true\", \"yes\")\n\nif skip_motion_correction:\n    packages = kimodo_packages\n    package_dir = {}\n    ext_modules = []\n    cmdclass = {}\nelse:\n    packages = kimodo_packages + [\"motion_correction\"]\n    package_dir = {\"motion_correction\": \"MotionCorrection/python/motion_correction\"}\n    ext_modules = [CMakeExtension(\"motion_correction._motion_correction\", \"MotionCorrection\")]\n    cmdclass = {\"build_ext\": CMakeBuild}\n\nsetup(\n    packages=packages,\n    package_dir=package_dir,\n    ext_modules=ext_modules,\n    cmdclass=cmdclass,\n)\n"
  }
]